// batuta/oracle/cookbook/recipes_rlhf_efficiency.rs

use super::Recipe;
7pub fn register_rlhf_efficiency_recipes(cookbook: &mut super::Cookbook) {
8 cookbook.add(
10 Recipe::new("rlhf-quantization", "Quantization for Efficient Alignment")
11 .with_problem("Apply quantization techniques throughout the alignment pipeline")
12 .with_components(vec!["entrenar", "aprender", "realizar"])
13 .with_tags(vec!["rlhf", "quantization", "4bit", "8bit", "efficient", "memory"])
14 .with_code(
15 r#"use entrenar::prelude::*;
16use entrenar::quantization::*;
17
18// === QLoRA for SFT ===
19let model = Model::load_quantized("llama-7b.q4_k.gguf")?;
20let model = model.with_qlora(QLoraConfig {
21 lora: LoraConfig { r: 64, alpha: 16, ..Default::default() },
22 nf4: true,
23 double_quant: true,
24 ..Default::default()
25})?;
26
27// SFT on 24GB GPU
28let trainer = SftTrainer::new(model, SftConfig::default());
29trainer.train(&dataset, 3)?;
30
31// === Quantized Reward Model ===
32let reward_model = RewardModel::load("reward-model.apr")?;
33let quantized_rm = reward_model.quantize(Quantization::Int8)?;
34// 2x faster inference, minimal accuracy loss
35
36// === INT8 PPO Training ===
37let config = PpoConfig {
38 // Use 8-bit Adam optimizer (saves 75% optimizer memory)
39 optimizer: Optimizer::Adam8bit {
40 lr: 1e-6,
41 betas: (0.9, 0.999),
42 },
43 // Mixed precision training
44 mixed_precision: MixedPrecision::Bf16,
45 // Gradient checkpointing
46 gradient_checkpointing: true,
47 ..Default::default()
48};
49
50// === Post-Training Quantization ===
51let rlhf_model = Model::load("rlhf-model.apr")?;
52
53// GPTQ quantization (4-bit, minimal quality loss)
54let gptq_model = rlhf_model.quantize_gptq(GptqConfig {
55 bits: 4,
56 group_size: 128,
57 calibration_data: &calibration_samples,
58})?;
59gptq_model.save("rlhf-model.q4.gguf")?;
60
61// AWQ quantization (activation-aware)
62let awq_model = rlhf_model.quantize_awq(AwqConfig {
63 bits: 4,
64 group_size: 128,
65})?;
66"#,
67 )
68 .with_related(vec!["training-qlora", "rlhf-sft", "rlhf-ppo"])
69 .with_test_code(
70 r"#[cfg(test)]
71mod tests {
72 #[test]
73 fn test_quantization_bits_valid() {
74 let bits = 4;
75 assert!(bits == 4 || bits == 8);
76}
77
78 #[test]
79 fn test_group_size_positive() {
80 let group_size = 128;
81 assert!(group_size > 0);
82}
83
84 #[test]
85 fn test_mixed_precision_flag() {
86 let use_bf16 = true;
87 assert!(use_bf16);
88}
89}",
90 ),
91 );
92
93 cookbook.add(
95 Recipe::new("rlhf-peft", "PEFT Adapters for Alignment")
96 .with_problem("Use parameter-efficient methods beyond LoRA for alignment")
97 .with_components(vec!["entrenar", "aprender"])
98 .with_tags(vec!["rlhf", "peft", "lora", "adapters", "efficient", "fine-tuning"])
99 .with_code(
100 r#"use entrenar::prelude::*;
101use entrenar::peft::*;
102
103// === LoRA (Low-Rank Adaptation) ===
104let lora = LoraConfig {
105 r: 16,
106 alpha: 32,
107 dropout: 0.1,
108 target_modules: vec!["q_proj", "v_proj", "k_proj", "o_proj"],
109 ..Default::default()
110};
111
112// === DoRA (Weight-Decomposed LoRA) ===
113// Decomposes weights into magnitude and direction
114let dora = DoraConfig {
115 r: 16,
116 alpha: 32,
117 use_dora: true, // Enable magnitude learning
118 ..Default::default()
119};
120
121// === AdaLoRA (Adaptive LoRA) ===
122// Dynamically allocates rank budget
123let adalora = AdaLoraConfig {
124 init_r: 12,
125 target_r: 8,
126 beta1: 0.85,
127 beta2: 0.85,
128 ..Default::default()
129};
130
131// === IA3 (Infused Adapter by Inhibiting and Amplifying) ===
132// Even more efficient than LoRA
133let ia3 = Ia3Config {
134 target_modules: vec!["k_proj", "v_proj", "down_proj"],
135 feedforward_modules: vec!["down_proj"],
136};
137
138// === Prefix Tuning ===
139let prefix = PrefixTuningConfig {
140 num_virtual_tokens: 20,
141 encoder_hidden_size: 512,
142};
143
144// === Apply adapter to model ===
145let model = Model::load("llama-7b.apr")?;
146let model = model.with_adapter(AdapterConfig::LoRA(lora))?;
147
148// Train with any alignment method
149let trainer = DpoTrainer::new(model, reference, DpoConfig::default());
150trainer.train(&dataset, 1)?;
151
152// Save only adapter weights (small file)
153model.save_adapter("dpo-adapter.lora")?;
154
155// Merge adapters for inference
156let merged = model.merge_adapter()?;
157merged.save("dpo-merged.apr")?;
158"#,
159 )
160 .with_related(vec!["training-lora", "training-qlora", "rlhf-sft"])
161 .with_test_code(
162 r#"#[cfg(test)]
163mod tests {
164 #[test]
165 fn test_lora_rank_positive() {
166 let rank = 16;
167 assert!(rank > 0);
168}
169
170 #[test]
171 fn test_target_modules_non_empty() {
172 let target_modules = vec!["q_proj", "v_proj"];
173 assert!(!target_modules.is_empty());
174}
175
176 #[test]
177 fn test_prefix_tuning_virtual_tokens_positive() {
178 let num_virtual_tokens = 20;
179 assert!(num_virtual_tokens > 0);
180}
181}"#,
182 ),
183 );
184}