metal_candle/training/
adapter.rs

1//! `LoRA` adapter for applying low-rank adaptation to model layers.
2//!
3//! This module provides functionality to inject `LoRA` layers into existing
4//! transformer models, enabling efficient fine-tuning with a small number
5//! of trainable parameters.
6
7use super::lora::{LoRAConfig, LoRALayer};
8use crate::error::Result;
9use candle_core::{Device, Tensor};
10use std::collections::HashMap;
11
12/// Target modules for `LoRA` adaptation.
13///
14/// Specifies which layers in the model should have `LoRA` applied.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetModule {
    /// Query projection in attention
    QProj,
    /// Key projection in attention
    KProj,
    /// Value projection in attention
    VProj,
    /// Output projection in attention
    OProj,
    /// Gate projection in MLP
    GateProj,
    /// Up projection in MLP
    UpProj,
    /// Down projection in MLP
    DownProj,
}

impl TargetModule {
    /// Every variant, in declaration order; backs the name-based lookup
    /// in [`Self::from_name`].
    const ALL: [Self; 7] = [
        Self::QProj,
        Self::KProj,
        Self::VProj,
        Self::OProj,
        Self::GateProj,
        Self::UpProj,
        Self::DownProj,
    ];

    /// Returns the canonical snake_case name for this module.
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::QProj => "q_proj",
            Self::KProj => "k_proj",
            Self::VProj => "v_proj",
            Self::OProj => "o_proj",
            Self::GateProj => "gate_proj",
            Self::UpProj => "up_proj",
            Self::DownProj => "down_proj",
        }
    }

    /// Parses a module name into a `TargetModule`.
    ///
    /// Returns `None` when `name` does not match any known module.
    /// The lookup reuses [`Self::name`], so the two mappings can never
    /// drift out of sync.
    #[must_use]
    pub fn from_name(name: &str) -> Option<Self> {
        Self::ALL.iter().find(|module| module.name() == name).cloned()
    }
}
63
64/// Configuration for `LoRA` adapter.
65///
66/// Specifies which layers to apply `LoRA` to and the `LoRA` hyperparameters.
67///
68/// # Examples
69///
70/// ```
71/// use metal_candle::training::{LoRAAdapterConfig, TargetModule};
72///
73/// // Apply LoRA to Q and V projections only (common choice)
74/// let config = LoRAAdapterConfig {
75///     rank: 8,
76///     alpha: 16.0,
77///     dropout: 0.0,
78///     target_modules: vec![TargetModule::QProj, TargetModule::VProj],
79/// };
80/// ```
#[derive(Debug, Clone)]
pub struct LoRAAdapterConfig {
    /// Rank of the low-rank decomposition. Each adapted layer trains
    /// `rank * (in_features + out_features)` parameters, so a lower rank
    /// means fewer trainable parameters.
    pub rank: usize,

    /// Scaling factor for `LoRA` updates; the effective scale applied to
    /// the delta is `alpha / rank` (see `LoRAConfig::scaling`).
    pub alpha: f32,

    /// Dropout probability applied inside each `LoRA` layer
    pub dropout: f32,

    /// Which modules to apply `LoRA` to; modules not listed here are
    /// left fully frozen with no adapter attached.
    pub target_modules: Vec<TargetModule>,
}
95
96impl Default for LoRAAdapterConfig {
97    fn default() -> Self {
98        Self {
99            rank: 8,
100            alpha: 16.0,
101            dropout: 0.0,
102            // By default, apply LoRA to Q and V projections (most common)
103            target_modules: vec![TargetModule::QProj, TargetModule::VProj],
104        }
105    }
106}
107
108impl LoRAAdapterConfig {
109    /// Creates a `LoRAConfig` from this adapter configuration.
110    #[must_use]
111    pub const fn to_lora_config(&self) -> LoRAConfig {
112        LoRAConfig {
113            rank: self.rank,
114            alpha: self.alpha,
115            dropout: self.dropout,
116        }
117    }
118
119    /// Checks if a module is targeted for `LoRA`.
120    #[must_use]
121    pub fn is_target(&self, module: &TargetModule) -> bool {
122        self.target_modules.contains(module)
123    }
124}
125
126/// `LoRA` adapter for a transformer model.
127///
128/// Manages `LoRA` layers applied to specific modules in the model.
129/// Each `LoRA` layer adds a trainable low-rank update to a frozen linear layer.
130///
131/// # Architecture
132///
133/// For a frozen linear layer with weight W:
134/// ```text
135/// output = (W + ΔW) @ input
136///        = W @ input + ΔW @ input
137///        = frozen_output + lora_output
138/// ```
139///
140/// # Examples
141///
142/// ```no_run
143/// use metal_candle::training::{LoRAAdapter, LoRAAdapterConfig, TargetModule};
144/// use candle_core::Device;
145///
146/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
147/// let device = Device::Cpu;
148/// let config = LoRAAdapterConfig::default();
149///
150/// // Create adapter for a model with hidden_size=768
151/// let adapter = LoRAAdapter::new(768, 768, 32, &config, &device)?;
152///
153/// // Get number of trainable parameters
154/// println!("Trainable params: {}", adapter.num_trainable_parameters());
155/// # Ok(())
156/// # }
157/// ```
#[derive(Debug)]
pub struct LoRAAdapter {
    /// `LoRA` layers indexed by the key `"layers.{layer_idx}.{module_name}"`
    /// (e.g. `"layers.0.q_proj"`), one entry per (layer, target module) pair
    layers: HashMap<String, LoRALayer>,

    /// Adapter configuration this adapter was built from (cloned in `new`)
    config: LoRAAdapterConfig,

    /// Number of transformer layers the adapter was created for
    num_layers: usize,
}
169
170impl LoRAAdapter {
171    /// Creates a new `LoRA` adapter.
172    ///
173    /// # Arguments
174    ///
175    /// * `hidden_size` - Model hidden dimension
176    /// * `intermediate_size` - MLP intermediate dimension (for MLP modules)
177    /// * `num_layers` - Number of transformer layers in the model
178    /// * `config` - Adapter configuration
179    /// * `device` - Device to place tensors on
180    ///
181    /// # Errors
182    ///
183    /// Returns an error if `LoRA` layer creation fails.
184    pub fn new(
185        hidden_size: usize,
186        intermediate_size: usize,
187        num_layers: usize,
188        config: &LoRAAdapterConfig,
189        device: &Device,
190    ) -> Result<Self> {
191        let lora_config = config.to_lora_config();
192        let mut layers = HashMap::new();
193
194        // Create LoRA layers for each target module in each transformer layer
195        for layer_idx in 0..num_layers {
196            for target in &config.target_modules {
197                let (in_features, out_features) = match target {
198                    TargetModule::QProj
199                    | TargetModule::KProj
200                    | TargetModule::VProj
201                    | TargetModule::OProj => (hidden_size, hidden_size),
202                    TargetModule::GateProj | TargetModule::UpProj => {
203                        (hidden_size, intermediate_size)
204                    }
205                    TargetModule::DownProj => (intermediate_size, hidden_size),
206                };
207
208                let lora_layer = LoRALayer::new(in_features, out_features, &lora_config, device)?;
209
210                let key = format!("layers.{}.{}", layer_idx, target.name());
211                layers.insert(key, lora_layer);
212            }
213        }
214
215        Ok(Self {
216            layers,
217            config: config.clone(),
218            num_layers,
219        })
220    }
221
222    /// Applies `LoRA` to a layer's output.
223    ///
224    /// # Arguments
225    ///
226    /// * `layer_idx` - Index of the transformer layer
227    /// * `module` - Which module (`q_proj`, `v_proj`, etc.)
228    /// * `input` - Input to the linear layer (before frozen projection)
229    ///
230    /// # Returns
231    ///
232    /// The `LoRA` delta to add to the frozen layer output.
233    /// Returns `None` if this layer/module doesn't have `LoRA` applied.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if the forward pass fails.
238    pub fn forward(
239        &self,
240        layer_idx: usize,
241        module: &TargetModule,
242        input: &Tensor,
243    ) -> Result<Option<Tensor>> {
244        let key = format!("layers.{}.{}", layer_idx, module.name());
245
246        if let Some(lora_layer) = self.layers.get(&key) {
247            let delta = lora_layer.forward(input)?;
248            Ok(Some(delta))
249        } else {
250            Ok(None)
251        }
252    }
253
254    /// Returns the total number of trainable parameters.
255    ///
256    /// This is the sum of parameters in all `LoRA` layers.
257    #[must_use]
258    pub fn num_trainable_parameters(&self) -> usize {
259        self.layers.values().map(LoRALayer::num_parameters).sum()
260    }
261
262    /// Returns the number of frozen (non-trainable) parameters.
263    ///
264    /// This would be all model parameters minus the `LoRA` parameters.
265    /// Note: This requires knowing the model's total parameter count.
266    #[must_use]
267    pub fn num_frozen_parameters(&self, total_model_params: usize) -> usize {
268        total_model_params.saturating_sub(self.num_trainable_parameters())
269    }
270
271    /// Returns the adapter configuration.
272    #[must_use]
273    pub const fn config(&self) -> &LoRAAdapterConfig {
274        &self.config
275    }
276
277    /// Returns the number of transformer layers.
278    #[must_use]
279    pub const fn num_layers(&self) -> usize {
280        self.num_layers
281    }
282
283    /// Returns an iterator over all `LoRA` layers.
284    pub fn layers(&self) -> impl Iterator<Item = (&String, &LoRALayer)> {
285        self.layers.iter()
286    }
287
288    /// Gets a specific `LoRA` layer by key.
289    #[must_use]
290    pub fn get_layer(&self, layer_idx: usize, module: &TargetModule) -> Option<&LoRALayer> {
291        let key = format!("layers.{}.{}", layer_idx, module.name());
292        self.layers.get(&key)
293    }
294
295    /// Merges `LoRA` weights back into the base model weights.
296    ///
297    /// Computes: `W_new = W_base + (B @ A) * scaling`
298    ///
299    /// This is useful for inference after training, as it eliminates
300    /// the overhead of separate `LoRA` computation.
301    ///
302    /// # Arguments
303    ///
304    /// * `base_weight` - The frozen base weight matrix (`out_features`, `in_features`)
305    /// * `layer_idx` - Index of the transformer layer
306    /// * `module` - Which module to merge
307    ///
308    /// # Returns
309    ///
310    /// The merged weight matrix, or the original if no `LoRA` is applied.
311    ///
312    /// # Errors
313    ///
314    /// Returns an error if tensor operations fail.
315    pub fn merge_weights(
316        &self,
317        base_weight: &Tensor,
318        layer_idx: usize,
319        module: &TargetModule,
320    ) -> Result<Tensor> {
321        let key = format!("layers.{}.{}", layer_idx, module.name());
322
323        if let Some(lora_layer) = self.layers.get(&key) {
324            // Compute ΔW = B @ A * scaling
325            // NOTE: LoRA matrices are stored in transposed form for optimization:
326            // - lora_a is stored as (in_features, rank) instead of (rank, in_features)
327            // - lora_b is stored as (rank, out_features) instead of (out_features, rank)
328            let lora_a = lora_layer.lora_a_tensor();
329            let lora_b = lora_layer.lora_b_tensor();
330
331            // We need: B_std @ A_std where
332            // - A_std: (rank, in_features) = lora_a^T
333            // - B_std: (out_features, rank) = lora_b^T
334            // Therefore: B_std @ A_std = lora_b^T @ lora_a^T = (lora_a @ lora_b)^T
335
336            // Step 1: lora_a @ lora_b
337            // (in_features, rank) @ (rank, out_features) = (in_features, out_features)
338            let temp = lora_a.matmul(lora_b)?;
339
340            // Step 2: Transpose to get (out_features, in_features)
341            let delta_w = temp.t()?;
342
343            // Scale by alpha/rank
344            let scaling = lora_layer.config().scaling();
345            let scaled_delta = (delta_w * f64::from(scaling))?;
346
347            // W_new = W_base + scaled_delta
348            let merged = base_weight.add(&scaled_delta)?;
349            Ok(merged)
350        } else {
351            // No LoRA for this layer/module, return base weight unchanged
352            Ok(base_weight.clone())
353        }
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_target_module_name() {
        // Spot-check the canonical snake_case names.
        assert_eq!(TargetModule::QProj.name(), "q_proj");
        assert_eq!(TargetModule::VProj.name(), "v_proj");
        assert_eq!(TargetModule::GateProj.name(), "gate_proj");
    }

    #[test]
    fn test_target_module_from_name() {
        // Round-trip a couple of names, plus the unknown-name case.
        assert_eq!(TargetModule::from_name("q_proj"), Some(TargetModule::QProj));
        assert_eq!(TargetModule::from_name("v_proj"), Some(TargetModule::VProj));
        assert_eq!(TargetModule::from_name("invalid"), None);
    }

    #[test]
    fn test_lora_adapter_config_default() {
        let config = LoRAAdapterConfig::default();
        assert_eq!(config.rank, 8);
        // Compare alpha through f64 to avoid exact f32 equality.
        assert!((f64::from(config.alpha) - 16.0).abs() < 1e-7);
        assert_eq!(config.target_modules.len(), 2);
        // Defaults target Q and V projections only.
        assert!(config.is_target(&TargetModule::QProj));
        assert!(config.is_target(&TargetModule::VProj));
        assert!(!config.is_target(&TargetModule::KProj));
    }

    #[test]
    fn test_lora_adapter_creation() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();

        let adapter = LoRAAdapter::new(768, 2048, 4, &config, &device);
        assert!(adapter.is_ok());

        let adapter = adapter.unwrap();
        assert_eq!(adapter.num_layers(), 4);

        // Should have LoRA for 2 modules * 4 layers = 8 LoRA layers
        assert_eq!(adapter.layers.len(), 8);
    }

    #[test]
    fn test_lora_adapter_trainable_parameters() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig {
            rank: 8,
            target_modules: vec![TargetModule::QProj, TargetModule::VProj],
            ..Default::default()
        };

        let adapter = LoRAAdapter::new(768, 2048, 4, &config, &device).unwrap();

        // Each LoRA layer: rank * (in_features + out_features)
        // q_proj, v_proj: 8 * (768 + 768) = 12,288 params each
        // Total: 2 modules * 4 layers * 12,288 = 98,304 params
        assert_eq!(adapter.num_trainable_parameters(), 98_304);
    }

    #[test]
    fn test_lora_adapter_forward() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();

        let adapter = LoRAAdapter::new(768, 2048, 2, &config, &device).unwrap();

        // Create input tensor shaped (batch=2, seq=16, hidden=768).
        let input = Tensor::randn(0f32, 1f32, (2, 16, 768), &device).unwrap();

        // Forward through layer 0, q_proj (should have LoRA)
        let output = adapter.forward(0, &TargetModule::QProj, &input);
        assert!(output.is_ok());
        assert!(output.unwrap().is_some());

        // Forward through layer 0, k_proj (should NOT have LoRA by default)
        let output = adapter.forward(0, &TargetModule::KProj, &input);
        assert!(output.is_ok());
        assert!(output.unwrap().is_none());
    }

    #[test]
    fn test_lora_adapter_get_layer() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();

        let adapter = LoRAAdapter::new(768, 2048, 2, &config, &device).unwrap();

        // Should find q_proj in layer 0
        assert!(adapter.get_layer(0, &TargetModule::QProj).is_some());

        // Should not find k_proj (not in target modules)
        assert!(adapter.get_layer(0, &TargetModule::KProj).is_none());

        // Should not find q_proj in layer 5 (only 2 layers)
        assert!(adapter.get_layer(5, &TargetModule::QProj).is_none());
    }

    #[test]
    fn test_lora_adapter_merge_weights() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();

        let adapter = LoRAAdapter::new(768, 2048, 1, &config, &device).unwrap();

        // Create a base weight matrix: (out_features=768, in_features=768)
        let base_weight = Tensor::zeros((768, 768), candle_core::DType::F32, &device).unwrap();

        // Merge with q_proj (should have LoRA)
        let merged = adapter.merge_weights(&base_weight, 0, &TargetModule::QProj);
        assert!(merged.is_ok());

        // Merged weight keeps the base (out_features, in_features) shape.
        let merged = merged.unwrap();
        assert_eq!(merged.dims(), &[768, 768]);

        // Merge with k_proj (no LoRA, should return base_weight)
        let merged = adapter.merge_weights(&base_weight, 0, &TargetModule::KProj);
        assert!(merged.is_ok());
    }
}