// batuta/oracle/cookbook/recipes_rlhf_training.rs
1//! RLHF training infrastructure recipes: PPO, Stability, Evaluation
2//!
3//! Extracted from `register_rlhf_recipes` for TDG compliance (Refs #22).
4
5use super::Recipe;
6
7pub fn register_rlhf_training_recipes(cookbook: &mut super::Cookbook) {
8    // PPO for RLHF
9    cookbook.add(
10        Recipe::new("rlhf-ppo", "PPO for RLHF")
11            .with_problem("Train language models with PPO using a reward model")
12            .with_components(vec!["entrenar", "aprender"])
13            .with_tags(vec!["rlhf", "ppo", "reinforcement-learning", "alignment"])
14            .with_code(
15                r#"use entrenar::prelude::*;
16use entrenar::ppo::*;
17
18// Load models
19let policy = Model::load("sft-model.apr")?;
20let reference = Model::load("sft-model.apr")?;  // KL anchor
21let reward_model = RewardModel::load("reward-model.apr")?;
22
23// Configure PPO
24let config = PpoConfig {
25    // Learning
26    learning_rate: 1e-6,
27    batch_size: 64,
28    mini_batch_size: 8,
29    gradient_accumulation: 8,
30
31    // PPO hyperparameters
32    ppo_epochs: 4,
33    clip_range: 0.2,
34    clip_range_value: 0.2,
35    gamma: 1.0,
36    lam: 0.95,                    // GAE lambda
37
38    // KL control
39    kl_penalty: KlPenalty::Adaptive {
40        target: 6.0,
41        horizon: 10000,
42    },
43
44    // Generation
45    max_new_tokens: 128,
46    temperature: 0.7,
47    top_p: 0.9,
48
49    ..Default::default()
50};
51
52// Create PPO trainer
53let trainer = PpoTrainer::new(policy, reference, reward_model, config);
54
55// Training loop with prompts
56let prompts = PromptDataset::load("prompts.txt")?;
57for epoch in 0..10 {
58    let stats = trainer.step(&prompts)?;
59    println!("Epoch {}: reward={:.3}, kl={:.3}, loss={:.4}",
60        epoch, stats.mean_reward, stats.kl_div, stats.loss);
61}
62
63trainer.policy().save("rlhf-model.apr")?;
64"#,
65            )
66            .with_related(vec!["rlhf-reward-model", "rlhf-sft", "rlhf-stability"])
67            .with_test_code(
68                r"#[cfg(test)]
69mod tests {
70    #[test]
71    fn test_clip_range_valid() {
72    let clip_range = 0.2_f64;
73    assert!(clip_range > 0.0 && clip_range <= 1.0);
74}
75
76    #[test]
77    fn test_ppo_epochs_positive() {
78    let ppo_epochs = 4;
79    assert!(ppo_epochs > 0);
80}
81
82    #[test]
83    fn test_temperature_positive() {
84    let temperature = 0.7_f64;
85    assert!(temperature > 0.0);
86}
87}",
88            ),
89    );
90
91    // RLHF Stability & Best Practices
92    cookbook.add(
93        Recipe::new("rlhf-stability", "RLHF Stability & Best Practices")
94            .with_problem("Ensure stable RLHF training and avoid common pitfalls")
95            .with_components(vec!["entrenar", "aprender", "trueno-viz"])
96            .with_tags(vec!["rlhf", "stability", "best-practices", "debugging", "monitoring"])
97            .with_code(
98                r#"use entrenar::prelude::*;
99use entrenar::ppo::*;
100
101// === Stability Techniques ===
102
103// 1. Reward normalization (running statistics)
104let config = PpoConfig {
105    reward_normalization: true,
106    reward_clip: 10.0,           // Clip extreme rewards
107    ..Default::default()
108};
109
110// 2. Advantage normalization
111let config = PpoConfig {
112    advantage_normalization: true,
113    ..config
114};
115
116// 3. Value function clipping
117let config = PpoConfig {
118    clip_range_value: 0.2,       // Clip value function updates
119    ..config
120};
121
122// 4. Adaptive KL penalty (InstructGPT style)
123let config = PpoConfig {
124    kl_penalty: KlPenalty::Adaptive {
125        target: 6.0,             // Target KL divergence
126        horizon: 10000,          // Adaptation horizon
127    },
128    ..config
129};
130
131// 5. Gradient clipping
132let config = PpoConfig {
133    max_grad_norm: 1.0,
134    ..config
135};
136
137// === Monitoring ===
138let callback = |stats: &PpoStats| {
139    // Check for reward hacking
140    if stats.mean_reward > 10.0 {
141        println!("Warning: Possible reward hacking!");
142    }
143
144    // Check KL divergence
145    if stats.kl_div > 15.0 {
146        println!("Warning: High KL divergence - policy drifting!");
147    }
148
149    // Check for mode collapse
150    if stats.response_entropy < 0.5 {
151        println!("Warning: Low entropy - possible mode collapse!");
152    }
153};
154
155// === Evaluation ===
156// Always evaluate on held-out prompts
157let eval_results = trainer.evaluate(&eval_prompts)?;
158println!("Win rate vs SFT: {:.2}%", eval_results.win_rate * 100.0);
159println!("Mean length: {:.1}", eval_results.mean_length);
160println!("Diversity: {:.3}", eval_results.diversity);
161"#,
162            )
163            .with_related(vec!["rlhf-ppo", "rlhf-evaluation"])
164            .with_test_code(
165                r"#[cfg(test)]
166mod tests {
167    #[test]
168    fn test_reward_clip_positive() {
169    let reward_clip = 10.0_f64;
170    assert!(reward_clip > 0.0);
171}
172
173    #[test]
174    fn test_grad_norm_positive() {
175    let max_grad_norm = 1.0_f64;
176    assert!(max_grad_norm > 0.0);
177}
178
179    #[test]
180    fn test_kl_target_positive() {
181    let kl_target = 6.0_f64;
182    assert!(kl_target > 0.0);
183}
184}",
185            ),
186    );
187
188    // RLHF Evaluation
189    cookbook.add(
190        Recipe::new("rlhf-evaluation", "RLHF Model Evaluation")
191            .with_problem("Comprehensively evaluate aligned models")
192            .with_components(vec!["entrenar", "aprender", "trueno-viz"])
193            .with_tags(vec!["rlhf", "evaluation", "benchmarks", "metrics", "alignment"])
194            .with_code(
195                r#"use entrenar::prelude::*;
196use entrenar::eval::*;
197
198// Load models to compare
199let sft_model = Model::load("sft-model.apr")?;
200let rlhf_model = Model::load("rlhf-model.apr")?;
201
202// === Pairwise Evaluation ===
203let evaluator = PairwiseEvaluator::new(reward_model);
204let results = evaluator.compare(
205    &sft_model,
206    &rlhf_model,
207    &eval_prompts,
208)?;
209println!("RLHF win rate: {:.2}%", results.model_b_wins * 100.0);
210println!("Tie rate: {:.2}%", results.ties * 100.0);
211
212// === Safety Evaluation ===
213let safety_eval = SafetyEvaluator::new()
214    .add_detector(ToxicityDetector::new())
215    .add_detector(BiasDetector::new())
216    .add_detector(HarmfulContentDetector::new());
217
218let safety_results = safety_eval.evaluate(&rlhf_model, &safety_prompts)?;
219println!("Toxicity rate: {:.3}%", safety_results.toxicity_rate * 100.0);
220println!("Refusal rate: {:.2}%", safety_results.refusal_rate * 100.0);
221
222// === Helpfulness Benchmarks ===
223let benchmarks = vec![
224    ("MT-Bench", MtBench::new()),
225    ("AlpacaEval", AlpacaEval::new()),
226    ("HumanEval", HumanEval::new()),
227];
228
229for (name, bench) in benchmarks {
230    let score = bench.evaluate(&rlhf_model)?;
231    println!("{}: {:.2}", name, score);
232}
233
234// === Diversity Metrics ===
235let diversity = DiversityMetrics::compute(&rlhf_model, &prompts)?;
236println!("Distinct-1: {:.3}", diversity.distinct_1);
237println!("Distinct-2: {:.3}", diversity.distinct_2);
238println!("Self-BLEU: {:.3}", diversity.self_bleu);
239"#,
240            )
241            .with_related(vec!["rlhf-stability", "rlhf-ppo", "rlhf-dpo"])
242            .with_test_code(
243                r"#[cfg(test)]
244mod tests {
245    #[test]
246    fn test_win_rate_in_range() {
247    let win_rate = 0.65_f64;
248    assert!(win_rate >= 0.0 && win_rate <= 1.0);
249}
250
251    #[test]
252    fn test_distinct_n_in_range() {
253    let distinct_2 = 0.82_f64;
254    assert!(distinct_2 >= 0.0 && distinct_2 <= 1.0);
255}
256
257    #[test]
258    fn test_toxicity_rate_in_range() {
259    let toxicity = 0.03_f64;
260    assert!(toxicity >= 0.0 && toxicity <= 1.0);
261}
262}",
263            ),
264    );
265}