Skip to main content

entrenar/efficiency/paradigm/
fine_tune.rs

1//! Fine-tuning methods for model adaptation.
2
3use serde::{Deserialize, Serialize};
4
5/// Fine-tuning methods
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
7pub enum FineTuneMethod {
8    /// Low-Rank Adaptation
9    LoRA {
10        /// Rank of the low-rank matrices
11        rank: u32,
12        /// Scaling factor (alpha)
13        alpha: f32,
14    },
15    /// Quantized LoRA (4-bit base weights)
16    QLoRA {
17        /// Rank of the low-rank matrices
18        rank: u32,
19        /// Quantization bits (typically 4)
20        bits: u8,
21    },
22    /// Adapter layers
23    Adapter,
24    /// Prefix tuning
25    Prefix,
26    /// IA3 (Infused Adapter by Inhibiting and Amplifying Inner Activations)
27    IA3,
28    /// Full fine-tuning (all parameters)
29    Full,
30}
31
32impl FineTuneMethod {
33    /// Create LoRA with default alpha = rank
34    pub fn lora(rank: u32) -> Self {
35        Self::LoRA { rank, alpha: rank as f32 }
36    }
37
38    /// Create QLoRA with 4-bit quantization
39    pub fn qlora(rank: u32) -> Self {
40        Self::QLoRA { rank, bits: 4 }
41    }
42
43    /// Get the memory reduction factor compared to full fine-tuning
44    pub fn memory_reduction_factor(&self) -> f64 {
45        match self {
46            Self::LoRA { rank, .. } => {
47                // LoRA typically uses ~0.1-1% of parameters
48                // Higher rank = more parameters
49                100.0 / f64::from(*rank).max(1.0)
50            }
51            Self::QLoRA { rank, bits } => {
52                // QLoRA: quantized base + low-rank adapters
53                // 4-bit = 8x compression on base, plus LoRA overhead
54                let base_compression = 32.0 / f64::from(*bits);
55                let lora_overhead = f64::from(*rank) * 0.01;
56                base_compression / (1.0 + lora_overhead)
57            }
58            Self::Adapter => 10.0, // ~10% of full
59            Self::Prefix => 20.0,  // ~5% of full
60            Self::IA3 => 50.0,     // ~2% of full
61            Self::Full => 1.0,     // No reduction
62        }
63    }
64
65    /// Get typical trainable parameter percentage
66    pub fn trainable_params_percent(&self) -> f64 {
67        match self {
68            Self::LoRA { rank, .. } => 0.1 * (f64::from(*rank) / 8.0).min(2.0),
69            Self::QLoRA { rank, .. } => 0.1 * (f64::from(*rank) / 8.0).min(2.0),
70            Self::Adapter => 5.0,
71            Self::Prefix => 1.0,
72            Self::IA3 => 0.01,
73            Self::Full => 100.0,
74        }
75    }
76
77    /// Memory multiplier relative to inference-only model size during training.
78    pub fn memory_multiplier(&self) -> f64 {
79        match self {
80            Self::Full => 4.0,
81            Self::LoRA { .. } => 1.2,
82            Self::QLoRA { .. } => 1.1,
83            Self::Adapter => 1.5,
84            Self::Prefix => 1.3,
85            Self::IA3 => 1.1,
86        }
87    }
88
89    /// Training speedup factor relative to training from scratch.
90    pub fn training_speedup(&self) -> f64 {
91        match self {
92            Self::Full => 2.0,
93            Self::LoRA { rank, .. } => 5.0 / (1.0 + (f64::from(*rank) / 64.0)),
94            Self::QLoRA { .. } => 6.0,
95            Self::Adapter => 4.0,
96            Self::Prefix => 5.0,
97            Self::IA3 => 8.0,
98        }
99    }
100
101    /// Expected quality retention as a fraction of full fine-tuning quality.
102    pub fn quality_retention(&self) -> f64 {
103        match self {
104            Self::Full => 1.0,
105            Self::LoRA { rank, .. } => 0.95 + (f64::from(*rank) / 256.0).min(0.05),
106            Self::QLoRA { .. } => 0.93,
107            Self::Adapter => 0.92,
108            Self::Prefix => 0.88,
109            Self::IA3 => 0.90,
110        }
111    }
112
113    /// Recommended batch size multiplier due to lower memory usage.
114    pub fn batch_size_multiplier(&self) -> f64 {
115        match self {
116            Self::Full => 1.0,
117            Self::LoRA { .. } => 2.0,
118            Self::QLoRA { .. } => 4.0,
119            Self::Adapter => 1.5,
120            Self::Prefix => 1.8,
121            Self::IA3 => 3.0,
122        }
123    }
124}
125
126impl std::fmt::Display for FineTuneMethod {
127    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
128        match self {
129            Self::LoRA { rank, alpha } => write!(f, "LoRA(r={rank}, α={alpha})"),
130            Self::QLoRA { rank, bits } => write!(f, "QLoRA(r={rank}, {bits}-bit)"),
131            Self::Adapter => write!(f, "Adapter"),
132            Self::Prefix => write!(f, "Prefix"),
133            Self::IA3 => write!(f, "IA³"),
134            Self::Full => write!(f, "Full"),
135        }
136    }
137}