Skip to main content

batuta/oracle/cookbook/
recipes_rlhf_efficiency.rs

//! RLHF efficiency recipes: Quantization, PEFT Adapters
//!
//! Extracted from `register_rlhf_recipes` for TDG compliance (Refs #22).

use super::Recipe;

7pub fn register_rlhf_efficiency_recipes(cookbook: &mut super::Cookbook) {
8    // Quantization for Alignment
9    cookbook.add(
10        Recipe::new("rlhf-quantization", "Quantization for Efficient Alignment")
11            .with_problem("Apply quantization techniques throughout the alignment pipeline")
12            .with_components(vec!["entrenar", "aprender", "realizar"])
13            .with_tags(vec!["rlhf", "quantization", "4bit", "8bit", "efficient", "memory"])
14            .with_code(
15                r#"use entrenar::prelude::*;
16use entrenar::quantization::*;
17
18// === QLoRA for SFT ===
19let model = Model::load_quantized("llama-7b.q4_k.gguf")?;
20let model = model.with_qlora(QLoraConfig {
21    lora: LoraConfig { r: 64, alpha: 16, ..Default::default() },
22    nf4: true,
23    double_quant: true,
24    ..Default::default()
25})?;
26
27// SFT on 24GB GPU
28let trainer = SftTrainer::new(model, SftConfig::default());
29trainer.train(&dataset, 3)?;
30
31// === Quantized Reward Model ===
32let reward_model = RewardModel::load("reward-model.apr")?;
33let quantized_rm = reward_model.quantize(Quantization::Int8)?;
34// 2x faster inference, minimal accuracy loss
35
36// === INT8 PPO Training ===
37let config = PpoConfig {
38    // Use 8-bit Adam optimizer (saves 75% optimizer memory)
39    optimizer: Optimizer::Adam8bit {
40        lr: 1e-6,
41        betas: (0.9, 0.999),
42    },
43    // Mixed precision training
44    mixed_precision: MixedPrecision::Bf16,
45    // Gradient checkpointing
46    gradient_checkpointing: true,
47    ..Default::default()
48};
49
50// === Post-Training Quantization ===
51let rlhf_model = Model::load("rlhf-model.apr")?;
52
53// GPTQ quantization (4-bit, minimal quality loss)
54let gptq_model = rlhf_model.quantize_gptq(GptqConfig {
55    bits: 4,
56    group_size: 128,
57    calibration_data: &calibration_samples,
58})?;
59gptq_model.save("rlhf-model.q4.gguf")?;
60
61// AWQ quantization (activation-aware)
62let awq_model = rlhf_model.quantize_awq(AwqConfig {
63    bits: 4,
64    group_size: 128,
65})?;
66"#,
67            )
68            .with_related(vec!["training-qlora", "rlhf-sft", "rlhf-ppo"])
69            .with_test_code(
70                r"#[cfg(test)]
71mod tests {
72    #[test]
73    fn test_quantization_bits_valid() {
74    let bits = 4;
75    assert!(bits == 4 || bits == 8);
76}
77
78    #[test]
79    fn test_group_size_positive() {
80    let group_size = 128;
81    assert!(group_size > 0);
82}
83
84    #[test]
85    fn test_mixed_precision_flag() {
86    let use_bf16 = true;
87    assert!(use_bf16);
88}
89}",
90            ),
91    );
92
93    // PEFT Adapters
94    cookbook.add(
95        Recipe::new("rlhf-peft", "PEFT Adapters for Alignment")
96            .with_problem("Use parameter-efficient methods beyond LoRA for alignment")
97            .with_components(vec!["entrenar", "aprender"])
98            .with_tags(vec!["rlhf", "peft", "lora", "adapters", "efficient", "fine-tuning"])
99            .with_code(
100                r#"use entrenar::prelude::*;
101use entrenar::peft::*;
102
103// === LoRA (Low-Rank Adaptation) ===
104let lora = LoraConfig {
105    r: 16,
106    alpha: 32,
107    dropout: 0.1,
108    target_modules: vec!["q_proj", "v_proj", "k_proj", "o_proj"],
109    ..Default::default()
110};
111
112// === DoRA (Weight-Decomposed LoRA) ===
113// Decomposes weights into magnitude and direction
114let dora = DoraConfig {
115    r: 16,
116    alpha: 32,
117    use_dora: true,              // Enable magnitude learning
118    ..Default::default()
119};
120
121// === AdaLoRA (Adaptive LoRA) ===
122// Dynamically allocates rank budget
123let adalora = AdaLoraConfig {
124    init_r: 12,
125    target_r: 8,
126    beta1: 0.85,
127    beta2: 0.85,
128    ..Default::default()
129};
130
131// === IA3 (Infused Adapter by Inhibiting and Amplifying) ===
132// Even more efficient than LoRA
133let ia3 = Ia3Config {
134    target_modules: vec!["k_proj", "v_proj", "down_proj"],
135    feedforward_modules: vec!["down_proj"],
136};
137
138// === Prefix Tuning ===
139let prefix = PrefixTuningConfig {
140    num_virtual_tokens: 20,
141    encoder_hidden_size: 512,
142};
143
144// === Apply adapter to model ===
145let model = Model::load("llama-7b.apr")?;
146let model = model.with_adapter(AdapterConfig::LoRA(lora))?;
147
148// Train with any alignment method
149let trainer = DpoTrainer::new(model, reference, DpoConfig::default());
150trainer.train(&dataset, 1)?;
151
152// Save only adapter weights (small file)
153model.save_adapter("dpo-adapter.lora")?;
154
155// Merge adapters for inference
156let merged = model.merge_adapter()?;
157merged.save("dpo-merged.apr")?;
158"#,
159            )
160            .with_related(vec!["training-lora", "training-qlora", "rlhf-sft"])
161            .with_test_code(
162                r#"#[cfg(test)]
163mod tests {
164    #[test]
165    fn test_lora_rank_positive() {
166    let rank = 16;
167    assert!(rank > 0);
168}
169
170    #[test]
171    fn test_target_modules_non_empty() {
172    let target_modules = vec!["q_proj", "v_proj"];
173    assert!(!target_modules.is_empty());
174}
175
176    #[test]
177    fn test_prefix_tuning_virtual_tokens_positive() {
178    let num_virtual_tokens = 20;
179    assert!(num_virtual_tokens > 0);
180}
181}"#,
182            ),
183    );
184}