// batuta/oracle/cookbook/recipes_rlhf_alignment.rs
//! RLHF core alignment recipes: SFT, Reward Model, DPO, DPO Variants
//!
//! Extracted from `register_rlhf_recipes` for TDG compliance (Refs #22).

use super::Recipe;

7pub fn register_rlhf_alignment_recipes(cookbook: &mut super::Cookbook) {
8    // Supervised Fine-Tuning (SFT)
9    cookbook.add(
10        Recipe::new("rlhf-sft", "Supervised Fine-Tuning (SFT)")
11            .with_problem("Fine-tune a base model on instruction-following data")
12            .with_components(vec!["entrenar", "aprender", "alimentar"])
13            .with_tags(vec!["rlhf", "sft", "instruction-tuning", "fine-tuning", "alignment"])
14            .with_code(
15                r#"use entrenar::prelude::*;
16use entrenar::sft::*;
17
18// Load base model
19let model = Model::load("llama-7b.apr")?;
20
21// Load instruction dataset (prompt, response pairs)
22let dataset = SftDataset::load("alpaca_data.json")?
23    .with_prompt_template(
24        "Below is an instruction. Write a response.\n\n\
25         ### Instruction:\n{instruction}\n\n\
26         ### Response:\n"
27    );
28
29// Configure SFT trainer
30let config = SftConfig {
31    learning_rate: 2e-5,
32    batch_size: 4,
33    gradient_accumulation: 8,
34    max_seq_length: 2048,
35    warmup_ratio: 0.03,
36    weight_decay: 0.0,
37    ..Default::default()
38};
39
40// Train with LoRA for efficiency
41let model = model.with_lora(LoraConfig::default())?;
42let trainer = SftTrainer::new(model, config);
43
44trainer.train(&dataset, 3)?;  // 3 epochs
45model.save("sft-model.apr")?;
46
47// Evaluation
48let eval_loss = trainer.evaluate(&eval_dataset)?;
49println!("Eval loss: {:.4}", eval_loss);
50"#,
51            )
52            .with_related(vec!["rlhf-reward-model", "rlhf-dpo", "training-lora"])
53            .with_test_code(
54                r#"#[cfg(test)]
55mod tests {
56    #[test]
57    fn test_sft_config_defaults() {
58    let epochs = 3;
59    let batch_size = 4;
60    assert!(epochs > 0 && batch_size > 0);
61}
62
63    #[test]
64    fn test_prompt_template_has_placeholder() {
65    let template = "Instruction: {instruction} Response:";
66    assert!(template.contains("{instruction}"));
67}
68
69    #[test]
70    fn test_warmup_ratio_in_range() {
71    let warmup_ratio = 0.03_f64;
72    assert!(warmup_ratio >= 0.0 && warmup_ratio <= 1.0);
73}
74}"#,
75            ),
76    );
77
78    // Reward Modeling
79    cookbook.add(
80        Recipe::new("rlhf-reward-model", "Reward Model Training")
81            .with_problem("Train a reward model from human preference data")
82            .with_components(vec!["entrenar", "aprender"])
83            .with_tags(vec!["rlhf", "reward-model", "preferences", "ranking", "alignment"])
84            .with_code(
85                r#"use entrenar::prelude::*;
86use entrenar::reward::*;
87
88// Load SFT model as base
89let model = Model::load("sft-model.apr")?;
90
91// Convert to reward model (adds value head)
92let reward_model = RewardModel::from_base(model)?;
93
94// Load preference dataset (chosen vs rejected)
95let dataset = PreferenceDataset::load("preferences.json")?;
96// Each sample: { prompt, chosen, rejected }
97
98// Configure reward model training
99let config = RewardConfig {
100    learning_rate: 1e-5,
101    batch_size: 4,
102    max_length: 512,
103    // Use margin ranking loss
104    loss_fn: RewardLoss::MarginRanking { margin: 0.0 },
105    ..Default::default()
106};
107
108let trainer = RewardTrainer::new(reward_model, config);
109trainer.train(&dataset, 1)?;
110
111// Evaluate accuracy (chosen > rejected)
112let accuracy = trainer.evaluate(&eval_dataset)?;
113println!("Preference accuracy: {:.2}%", accuracy * 100.0);
114
115reward_model.save("reward-model.apr")?;
116
117// Inference: score a response
118let score = reward_model.score(&prompt, &response)?;
119println!("Reward score: {:.3}", score);
120"#,
121            )
122            .with_related(vec!["rlhf-sft", "rlhf-ppo", "rlhf-dpo"])
123            .with_test_code(
124                r"#[cfg(test)]
125mod tests {
126    #[test]
127    fn test_preference_accuracy_in_range() {
128    let accuracy = 0.72_f64;
129    assert!(accuracy >= 0.0 && accuracy <= 1.0);
130}
131
132    #[test]
133    fn test_margin_non_negative() {
134    let margin = 0.0_f64;
135    assert!(margin >= 0.0);
136}
137
138    #[test]
139    fn test_reward_score_ordering() {
140    let chosen_reward = 1.5_f64;
141    let rejected_reward = -0.3_f64;
142    assert!(chosen_reward > rejected_reward);
143}
144}",
145            ),
146    );
147
148    // Direct Preference Optimization (DPO)
149    cookbook.add(
150        Recipe::new("rlhf-dpo", "Direct Preference Optimization (DPO)")
151            .with_problem("Align models directly from preferences without reward modeling")
152            .with_components(vec!["entrenar", "aprender"])
153            .with_tags(vec!["rlhf", "dpo", "alignment", "preferences", "efficient"])
154            .with_code(
155                r#"use entrenar::prelude::*;
156use entrenar::dpo::*;
157
158// Load SFT model (policy) and reference model
159let policy = Model::load("sft-model.apr")?;
160let reference = Model::load("sft-model.apr")?;  // Frozen copy
161
162// Load preference dataset
163let dataset = PreferenceDataset::load("preferences.json")?;
164
165// Configure DPO
166let config = DpoConfig {
167    beta: 0.1,                    // KL penalty coefficient
168    learning_rate: 5e-7,
169    batch_size: 4,
170    gradient_accumulation: 4,
171    max_length: 512,
172    max_prompt_length: 256,
173    label_smoothing: 0.0,
174    ..Default::default()
175};
176
177// DPO loss: -log σ(β * (log π(y_w|x) - log π(y_l|x)))
178let trainer = DpoTrainer::new(policy, reference, config);
179trainer.train(&dataset, 1)?;
180
181// Evaluate
182let metrics = trainer.evaluate(&eval_dataset)?;
183println!("Accuracy: {:.2}%", metrics.accuracy * 100.0);
184println!("Chosen reward: {:.3}", metrics.chosen_reward);
185println!("Rejected reward: {:.3}", metrics.rejected_reward);
186
187policy.save("dpo-model.apr")?;
188"#,
189            )
190            .with_related(vec!["rlhf-sft", "rlhf-ipo", "rlhf-kto"])
191            .with_test_code(
192                r"#[cfg(test)]
193mod tests {
194    #[test]
195    fn test_beta_positive() {
196    let beta = 0.1_f64;
197    assert!(beta > 0.0);
198}
199
200    #[test]
201    fn test_max_length_exceeds_prompt_length() {
202    let max_prompt_length = 256;
203    let max_length = 512;
204    assert!(max_length > max_prompt_length);
205}
206
207    #[test]
208    fn test_label_smoothing_in_range() {
209    let label_smoothing = 0.0_f64;
210    assert!(label_smoothing >= 0.0 && label_smoothing <= 1.0);
211}
212}",
213            ),
214    );
215
216    // DPO Variants (IPO, KTO, ORPO)
217    cookbook.add(
218        Recipe::new("rlhf-dpo-variants", "DPO Variants: IPO, KTO, ORPO")
219            .with_problem("Use improved DPO variants for better alignment")
220            .with_components(vec!["entrenar", "aprender"])
221            .with_tags(vec!["rlhf", "dpo", "ipo", "kto", "orpo", "alignment"])
222            .with_code(
223                r#"use entrenar::prelude::*;
224use entrenar::dpo::*;
225
226// === IPO (Identity Preference Optimization) ===
227// Addresses DPO's overfitting with identity mapping
228let ipo_config = IpoConfig {
229    tau: 0.1,                     // Temperature parameter
230    learning_rate: 5e-7,
231    ..Default::default()
232};
233let trainer = IpoTrainer::new(policy.clone(), reference.clone(), ipo_config);
234
235// === KTO (Kahneman-Tversky Optimization) ===
236// Works with unpaired data (no need for chosen/rejected pairs)
237let kto_dataset = KtoDataset::load("ratings.json")?;
238// Each sample: { prompt, response, is_desirable: bool }
239
240let kto_config = KtoConfig {
241    beta: 0.1,
242    desirable_weight: 1.0,
243    undesirable_weight: 1.0,
244    ..Default::default()
245};
246let trainer = KtoTrainer::new(policy.clone(), reference.clone(), kto_config);
247trainer.train(&kto_dataset, 1)?;
248
249// === ORPO (Odds Ratio Preference Optimization) ===
250// No reference model needed - uses odds ratio
251let orpo_config = OrpoConfig {
252    beta: 0.1,
253    learning_rate: 8e-6,
254    ..Default::default()
255};
256// ORPO combines SFT and preference learning
257let trainer = OrpoTrainer::new(policy.clone(), orpo_config);
258trainer.train(&dataset, 1)?;
259
260// === SimPO (Simple Preference Optimization) ===
261// Length-normalized, reference-free
262let simpo_config = SimpoConfig {
263    beta: 2.5,
264    gamma: 0.5,                   // Target margin
265    ..Default::default()
266};
267let trainer = SimpoTrainer::new(policy, simpo_config);
268"#,
269            )
270            .with_related(vec!["rlhf-dpo", "rlhf-sft"])
271            .with_test_code(
272                r"#[cfg(test)]
273mod tests {
274    #[test]
275    fn test_ipo_tau_positive() {
276    let tau = 0.1_f64;
277    assert!(tau > 0.0);
278}
279
280    #[test]
281    fn test_kto_weights_non_negative() {
282    let desirable_weight = 1.0_f64;
283    let undesirable_weight = 1.0_f64;
284    assert!(desirable_weight >= 0.0 && undesirable_weight >= 0.0);
285}
286
287    #[test]
288    fn test_orpo_beta_in_valid_range() {
289    let beta = 0.1_f64;
290    assert!(beta > 0.0 && beta < 1.0);
291}
292}",
293            ),
294    );
295}