1use super::Recipe;
6
/// Registers the RLHF (Reinforcement Learning from Human Feedback) training
/// recipes with the given [`Cookbook`].
///
/// Three recipes are added:
/// 1. `rlhf-ppo` — the core PPO training loop (policy + frozen reference
///    model as a KL anchor + reward model),
/// 2. `rlhf-stability` — stability techniques and monitoring best practices,
/// 3. `rlhf-evaluation` — evaluating aligned models (pairwise, safety,
///    benchmarks, diversity).
///
/// Each recipe bundles a problem statement, the component crates it touches,
/// searchable tags, an illustrative example snippet, links to related
/// recipes, and a snippet of sanity-check tests. Note: the `with_code` /
/// with_test_code` payloads are raw-string *data* displayed to users — they
/// are not compiled as part of this crate.
pub fn register_rlhf_training_recipes(cookbook: &mut super::Cookbook) {
    // Recipe 1: PPO training loop for RLHF.
    cookbook.add(
        Recipe::new("rlhf-ppo", "PPO for RLHF")
            .with_problem("Train language models with PPO using a reward model")
            .with_components(vec!["entrenar", "aprender"])
            .with_tags(vec!["rlhf", "ppo", "reinforcement-learning", "alignment"])
            .with_code(
                r#"use entrenar::prelude::*;
use entrenar::ppo::*;

// Load models
let policy = Model::load("sft-model.apr")?;
let reference = Model::load("sft-model.apr")?; // KL anchor
let reward_model = RewardModel::load("reward-model.apr")?;

// Configure PPO
let config = PpoConfig {
    // Learning
    learning_rate: 1e-6,
    batch_size: 64,
    mini_batch_size: 8,
    gradient_accumulation: 8,

    // PPO hyperparameters
    ppo_epochs: 4,
    clip_range: 0.2,
    clip_range_value: 0.2,
    gamma: 1.0,
    lam: 0.95, // GAE lambda

    // KL control
    kl_penalty: KlPenalty::Adaptive {
        target: 6.0,
        horizon: 10000,
    },

    // Generation
    max_new_tokens: 128,
    temperature: 0.7,
    top_p: 0.9,

    ..Default::default()
};

// Create PPO trainer
let trainer = PpoTrainer::new(policy, reference, reward_model, config);

// Training loop with prompts
let prompts = PromptDataset::load("prompts.txt")?;
for epoch in 0..10 {
    let stats = trainer.step(&prompts)?;
    println!("Epoch {}: reward={:.3}, kl={:.3}, loss={:.4}",
        epoch, stats.mean_reward, stats.kl_div, stats.loss);
}

trainer.policy().save("rlhf-model.apr")?;
"#,
            )
            .with_related(vec!["rlhf-reward-model", "rlhf-sft", "rlhf-stability"])
            // Simple range/positivity checks on the hyperparameters shown in
            // the example above.
            .with_test_code(
                r"#[cfg(test)]
mod tests {
    #[test]
    fn test_clip_range_valid() {
        let clip_range = 0.2_f64;
        assert!(clip_range > 0.0 && clip_range <= 1.0);
    }

    #[test]
    fn test_ppo_epochs_positive() {
        let ppo_epochs = 4;
        assert!(ppo_epochs > 0);
    }

    #[test]
    fn test_temperature_positive() {
        let temperature = 0.7_f64;
        assert!(temperature > 0.0);
    }
}",
            ),
    );

    // Recipe 2: stability techniques (normalization, clipping, adaptive KL)
    // plus runtime monitoring for reward hacking, policy drift, and mode
    // collapse.
    cookbook.add(
        Recipe::new("rlhf-stability", "RLHF Stability & Best Practices")
            .with_problem("Ensure stable RLHF training and avoid common pitfalls")
            .with_components(vec!["entrenar", "aprender", "trueno-viz"])
            .with_tags(vec!["rlhf", "stability", "best-practices", "debugging", "monitoring"])
            .with_code(
                r#"use entrenar::prelude::*;
use entrenar::ppo::*;

// === Stability Techniques ===

// 1. Reward normalization (running statistics)
let config = PpoConfig {
    reward_normalization: true,
    reward_clip: 10.0, // Clip extreme rewards
    ..Default::default()
};

// 2. Advantage normalization
let config = PpoConfig {
    advantage_normalization: true,
    ..config
};

// 3. Value function clipping
let config = PpoConfig {
    clip_range_value: 0.2, // Clip value function updates
    ..config
};

// 4. Adaptive KL penalty (InstructGPT style)
let config = PpoConfig {
    kl_penalty: KlPenalty::Adaptive {
        target: 6.0, // Target KL divergence
        horizon: 10000, // Adaptation horizon
    },
    ..config
};

// 5. Gradient clipping
let config = PpoConfig {
    max_grad_norm: 1.0,
    ..config
};

// === Monitoring ===
let callback = |stats: &PpoStats| {
    // Check for reward hacking
    if stats.mean_reward > 10.0 {
        println!("Warning: Possible reward hacking!");
    }

    // Check KL divergence
    if stats.kl_div > 15.0 {
        println!("Warning: High KL divergence - policy drifting!");
    }

    // Check for mode collapse
    if stats.response_entropy < 0.5 {
        println!("Warning: Low entropy - possible mode collapse!");
    }
};

// === Evaluation ===
// Always evaluate on held-out prompts
let eval_results = trainer.evaluate(&eval_prompts)?;
println!("Win rate vs SFT: {:.2}%", eval_results.win_rate * 100.0);
println!("Mean length: {:.1}", eval_results.mean_length);
println!("Diversity: {:.3}", eval_results.diversity);
"#,
            )
            .with_related(vec!["rlhf-ppo", "rlhf-evaluation"])
            // Positivity checks on the stability knobs used in the example.
            .with_test_code(
                r"#[cfg(test)]
mod tests {
    #[test]
    fn test_reward_clip_positive() {
        let reward_clip = 10.0_f64;
        assert!(reward_clip > 0.0);
    }

    #[test]
    fn test_grad_norm_positive() {
        let max_grad_norm = 1.0_f64;
        assert!(max_grad_norm > 0.0);
    }

    #[test]
    fn test_kl_target_positive() {
        let kl_target = 6.0_f64;
        assert!(kl_target > 0.0);
    }
}",
            ),
    );

    // Recipe 3: comprehensive evaluation of an aligned model — pairwise
    // comparison against the SFT baseline, safety detectors, helpfulness
    // benchmarks, and diversity metrics.
    cookbook.add(
        Recipe::new("rlhf-evaluation", "RLHF Model Evaluation")
            .with_problem("Comprehensively evaluate aligned models")
            .with_components(vec!["entrenar", "aprender", "trueno-viz"])
            .with_tags(vec!["rlhf", "evaluation", "benchmarks", "metrics", "alignment"])
            .with_code(
                r#"use entrenar::prelude::*;
use entrenar::eval::*;

// Load models to compare
let sft_model = Model::load("sft-model.apr")?;
let rlhf_model = Model::load("rlhf-model.apr")?;

// === Pairwise Evaluation ===
let evaluator = PairwiseEvaluator::new(reward_model);
let results = evaluator.compare(
    &sft_model,
    &rlhf_model,
    &eval_prompts,
)?;
println!("RLHF win rate: {:.2}%", results.model_b_wins * 100.0);
println!("Tie rate: {:.2}%", results.ties * 100.0);

// === Safety Evaluation ===
let safety_eval = SafetyEvaluator::new()
    .add_detector(ToxicityDetector::new())
    .add_detector(BiasDetector::new())
    .add_detector(HarmfulContentDetector::new());

let safety_results = safety_eval.evaluate(&rlhf_model, &safety_prompts)?;
println!("Toxicity rate: {:.3}%", safety_results.toxicity_rate * 100.0);
println!("Refusal rate: {:.2}%", safety_results.refusal_rate * 100.0);

// === Helpfulness Benchmarks ===
let benchmarks = vec![
    ("MT-Bench", MtBench::new()),
    ("AlpacaEval", AlpacaEval::new()),
    ("HumanEval", HumanEval::new()),
];

for (name, bench) in benchmarks {
    let score = bench.evaluate(&rlhf_model)?;
    println!("{}: {:.2}", name, score);
}

// === Diversity Metrics ===
let diversity = DiversityMetrics::compute(&rlhf_model, &prompts)?;
println!("Distinct-1: {:.3}", diversity.distinct_1);
println!("Distinct-2: {:.3}", diversity.distinct_2);
println!("Self-BLEU: {:.3}", diversity.self_bleu);
"#,
            )
            .with_related(vec!["rlhf-stability", "rlhf-ppo", "rlhf-dpo"])
            // Range checks on the rates/metrics printed by the example; all
            // are proportions expected in [0, 1].
            .with_test_code(
                r"#[cfg(test)]
mod tests {
    #[test]
    fn test_win_rate_in_range() {
        let win_rate = 0.65_f64;
        assert!(win_rate >= 0.0 && win_rate <= 1.0);
    }

    #[test]
    fn test_distinct_n_in_range() {
        let distinct_2 = 0.82_f64;
        assert!(distinct_2 >= 0.0 && distinct_2 <= 1.0);
    }

    #[test]
    fn test_toxicity_rate_in_range() {
        let toxicity = 0.03_f64;
        assert!(toxicity >= 0.0 && toxicity <= 1.0);
    }
}",
            ),
    );
}