1use super::Recipe;
6
/// Registers the RLHF / alignment recipe family with `cookbook`.
///
/// Covers the standard alignment pipeline in order:
/// 1. `rlhf-sft`          — supervised fine-tuning on instruction data,
/// 2. `rlhf-reward-model` — reward model training from human preferences,
/// 3. `rlhf-dpo`          — direct preference optimization,
/// 4. `rlhf-dpo-variants` — IPO / KTO / ORPO / SimPO variants.
///
/// Each recipe bundles a problem statement, the crates involved, search tags,
/// an example code listing, links to related recipes, and an embedded test
/// snippet. The listings are illustrative documentation strings — they are
/// stored verbatim and are not compiled as part of this crate.
pub fn register_rlhf_alignment_recipes(cookbook: &mut super::Cookbook) {
    // --- Stage 1: Supervised Fine-Tuning (SFT) ---
    cookbook.add(
        Recipe::new("rlhf-sft", "Supervised Fine-Tuning (SFT)")
            .with_problem("Fine-tune a base model on instruction-following data")
            .with_components(vec!["entrenar", "aprender", "alimentar"])
            .with_tags(vec!["rlhf", "sft", "instruction-tuning", "fine-tuning", "alignment"])
            // NOTE(review): as plain Rust this sample moves `model` into
            // `SftTrainer::new` and then calls `model.save(...)` — a
            // use-after-move — and references an unbound `eval_dataset`.
            // Confirm the intended `entrenar` API shape before publishing.
            .with_code(
                r#"use entrenar::prelude::*;
use entrenar::sft::*;

// Load base model
let model = Model::load("llama-7b.apr")?;

// Load instruction dataset (prompt, response pairs)
let dataset = SftDataset::load("alpaca_data.json")?
    .with_prompt_template(
        "Below is an instruction. Write a response.\n\n\
         ### Instruction:\n{instruction}\n\n\
         ### Response:\n"
    );

// Configure SFT trainer
let config = SftConfig {
    learning_rate: 2e-5,
    batch_size: 4,
    gradient_accumulation: 8,
    max_seq_length: 2048,
    warmup_ratio: 0.03,
    weight_decay: 0.0,
    ..Default::default()
};

// Train with LoRA for efficiency
let model = model.with_lora(LoraConfig::default())?;
let trainer = SftTrainer::new(model, config);

trainer.train(&dataset, 3)?; // 3 epochs
model.save("sft-model.apr")?;

// Evaluation
let eval_loss = trainer.evaluate(&eval_dataset)?;
println!("Eval loss: {:.4}", eval_loss);
"#,
            )
            .with_related(vec!["rlhf-reward-model", "rlhf-dpo", "training-lora"])
            .with_test_code(
                r#"#[cfg(test)]
mod tests {
    #[test]
    fn test_sft_config_defaults() {
        let epochs = 3;
        let batch_size = 4;
        assert!(epochs > 0 && batch_size > 0);
    }

    #[test]
    fn test_prompt_template_has_placeholder() {
        let template = "Instruction: {instruction} Response:";
        assert!(template.contains("{instruction}"));
    }

    #[test]
    fn test_warmup_ratio_in_range() {
        let warmup_ratio = 0.03_f64;
        assert!(warmup_ratio >= 0.0 && warmup_ratio <= 1.0);
    }
}"#,
            ),
    );

    // --- Stage 2: Reward model from pairwise preferences ---
    cookbook.add(
        Recipe::new("rlhf-reward-model", "Reward Model Training")
            .with_problem("Train a reward model from human preference data")
            .with_components(vec!["entrenar", "aprender"])
            .with_tags(vec!["rlhf", "reward-model", "preferences", "ranking", "alignment"])
            // NOTE(review): `reward_model` is moved into `RewardTrainer::new`
            // but used afterwards (`.save`, `.score`); `prompt`, `response`,
            // and `eval_dataset` are unbound in this snippet — verify against
            // the real API.
            .with_code(
                r#"use entrenar::prelude::*;
use entrenar::reward::*;

// Load SFT model as base
let model = Model::load("sft-model.apr")?;

// Convert to reward model (adds value head)
let reward_model = RewardModel::from_base(model)?;

// Load preference dataset (chosen vs rejected)
let dataset = PreferenceDataset::load("preferences.json")?;
// Each sample: { prompt, chosen, rejected }

// Configure reward model training
let config = RewardConfig {
    learning_rate: 1e-5,
    batch_size: 4,
    max_length: 512,
    // Use margin ranking loss
    loss_fn: RewardLoss::MarginRanking { margin: 0.0 },
    ..Default::default()
};

let trainer = RewardTrainer::new(reward_model, config);
trainer.train(&dataset, 1)?;

// Evaluate accuracy (chosen > rejected)
let accuracy = trainer.evaluate(&eval_dataset)?;
println!("Preference accuracy: {:.2}%", accuracy * 100.0);

reward_model.save("reward-model.apr")?;

// Inference: score a response
let score = reward_model.score(&prompt, &response)?;
println!("Reward score: {:.3}", score);
"#,
            )
            .with_related(vec!["rlhf-sft", "rlhf-ppo", "rlhf-dpo"])
            .with_test_code(
                r"#[cfg(test)]
mod tests {
    #[test]
    fn test_preference_accuracy_in_range() {
        let accuracy = 0.72_f64;
        assert!(accuracy >= 0.0 && accuracy <= 1.0);
    }

    #[test]
    fn test_margin_non_negative() {
        let margin = 0.0_f64;
        assert!(margin >= 0.0);
    }

    #[test]
    fn test_reward_score_ordering() {
        let chosen_reward = 1.5_f64;
        let rejected_reward = -0.3_f64;
        assert!(chosen_reward > rejected_reward);
    }
}",
            ),
    );

    // --- Stage 3: DPO — align directly from preferences, no reward model ---
    cookbook.add(
        Recipe::new("rlhf-dpo", "Direct Preference Optimization (DPO)")
            .with_problem("Align models directly from preferences without reward modeling")
            .with_components(vec!["entrenar", "aprender"])
            .with_tags(vec!["rlhf", "dpo", "alignment", "preferences", "efficient"])
            // NOTE(review): `policy` is moved into `DpoTrainer::new` and then
            // `policy.save(...)` is called; `eval_dataset` is unbound here.
            // Same caveat as the SFT sample — confirm intended API.
            .with_code(
                r#"use entrenar::prelude::*;
use entrenar::dpo::*;

// Load SFT model (policy) and reference model
let policy = Model::load("sft-model.apr")?;
let reference = Model::load("sft-model.apr")?; // Frozen copy

// Load preference dataset
let dataset = PreferenceDataset::load("preferences.json")?;

// Configure DPO
let config = DpoConfig {
    beta: 0.1, // KL penalty coefficient
    learning_rate: 5e-7,
    batch_size: 4,
    gradient_accumulation: 4,
    max_length: 512,
    max_prompt_length: 256,
    label_smoothing: 0.0,
    ..Default::default()
};

// DPO loss: -log σ(β * (log π(y_w|x) - log π(y_l|x)))
let trainer = DpoTrainer::new(policy, reference, config);
trainer.train(&dataset, 1)?;

// Evaluate
let metrics = trainer.evaluate(&eval_dataset)?;
println!("Accuracy: {:.2}%", metrics.accuracy * 100.0);
println!("Chosen reward: {:.3}", metrics.chosen_reward);
println!("Rejected reward: {:.3}", metrics.rejected_reward);

policy.save("dpo-model.apr")?;
"#,
            )
            .with_related(vec!["rlhf-sft", "rlhf-ipo", "rlhf-kto"])
            .with_test_code(
                r"#[cfg(test)]
mod tests {
    #[test]
    fn test_beta_positive() {
        let beta = 0.1_f64;
        assert!(beta > 0.0);
    }

    #[test]
    fn test_max_length_exceeds_prompt_length() {
        let max_prompt_length = 256;
        let max_length = 512;
        assert!(max_length > max_prompt_length);
    }

    #[test]
    fn test_label_smoothing_in_range() {
        let label_smoothing = 0.0_f64;
        assert!(label_smoothing >= 0.0 && label_smoothing <= 1.0);
    }
}",
            ),
    );

    // --- Stage 3b: DPO variants (IPO / KTO / ORPO / SimPO) ---
    cookbook.add(
        Recipe::new("rlhf-dpo-variants", "DPO Variants: IPO, KTO, ORPO")
            .with_problem("Use improved DPO variants for better alignment")
            .with_components(vec!["entrenar", "aprender"])
            .with_tags(vec!["rlhf", "dpo", "ipo", "kto", "orpo", "alignment"])
            // NOTE(review): `policy`, `reference`, and `dataset` are free
            // variables here (presumably carried over from the DPO recipe),
            // and `trainer` is rebound four times with earlier bindings
            // unused — fine for illustration, but worth confirming the
            // snippet is meant to be read as four independent fragments.
            .with_code(
                r#"use entrenar::prelude::*;
use entrenar::dpo::*;

// === IPO (Identity Preference Optimization) ===
// Addresses DPO's overfitting with identity mapping
let ipo_config = IpoConfig {
    tau: 0.1, // Temperature parameter
    learning_rate: 5e-7,
    ..Default::default()
};
let trainer = IpoTrainer::new(policy.clone(), reference.clone(), ipo_config);

// === KTO (Kahneman-Tversky Optimization) ===
// Works with unpaired data (no need for chosen/rejected pairs)
let kto_dataset = KtoDataset::load("ratings.json")?;
// Each sample: { prompt, response, is_desirable: bool }

let kto_config = KtoConfig {
    beta: 0.1,
    desirable_weight: 1.0,
    undesirable_weight: 1.0,
    ..Default::default()
};
let trainer = KtoTrainer::new(policy.clone(), reference.clone(), kto_config);
trainer.train(&kto_dataset, 1)?;

// === ORPO (Odds Ratio Preference Optimization) ===
// No reference model needed - uses odds ratio
let orpo_config = OrpoConfig {
    beta: 0.1,
    learning_rate: 8e-6,
    ..Default::default()
};
// ORPO combines SFT and preference learning
let trainer = OrpoTrainer::new(policy.clone(), orpo_config);
trainer.train(&dataset, 1)?;

// === SimPO (Simple Preference Optimization) ===
// Length-normalized, reference-free
let simpo_config = SimpoConfig {
    beta: 2.5,
    gamma: 0.5, // Target margin
    ..Default::default()
};
let trainer = SimpoTrainer::new(policy, simpo_config);
"#,
            )
            .with_related(vec!["rlhf-dpo", "rlhf-sft"])
            .with_test_code(
                r"#[cfg(test)]
mod tests {
    #[test]
    fn test_ipo_tau_positive() {
        let tau = 0.1_f64;
        assert!(tau > 0.0);
    }

    #[test]
    fn test_kto_weights_non_negative() {
        let desirable_weight = 1.0_f64;
        let undesirable_weight = 1.0_f64;
        assert!(desirable_weight >= 0.0 && undesirable_weight >= 0.0);
    }

    #[test]
    fn test_orpo_beta_in_valid_range() {
        let beta = 0.1_f64;
        assert!(beta > 0.0 && beta < 1.0);
    }
}",
            ),
    );
}