// metal_candle/training/adapter.rs
1//! `LoRA` adapter for applying low-rank adaptation to model layers.
2//!
3//! This module provides functionality to inject `LoRA` layers into existing
4//! transformer models, enabling efficient fine-tuning with a small number
5//! of trainable parameters.
6
7use super::lora::{LoRAConfig, LoRALayer};
8use crate::error::Result;
9use candle_core::{Device, Tensor};
10use std::collections::HashMap;
11
/// Target modules for `LoRA` adaptation.
///
/// Specifies which layers in the model should have `LoRA` applied.
// `Copy` is derived because this is a small fieldless enum; callers can pass
// it by value without clones.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TargetModule {
    /// Query projection in attention
    QProj,
    /// Key projection in attention
    KProj,
    /// Value projection in attention
    VProj,
    /// Output projection in attention
    OProj,
    /// Gate projection in MLP
    GateProj,
    /// Up projection in MLP
    UpProj,
    /// Down projection in MLP
    DownProj,
}

impl TargetModule {
    /// Returns the canonical name for this module (e.g. `"q_proj"`).
    ///
    /// This is the inverse of [`Self::from_name`].
    #[must_use]
    pub const fn name(&self) -> &'static str {
        match self {
            Self::QProj => "q_proj",
            Self::KProj => "k_proj",
            Self::VProj => "v_proj",
            Self::OProj => "o_proj",
            Self::GateProj => "gate_proj",
            Self::UpProj => "up_proj",
            Self::DownProj => "down_proj",
        }
    }

    /// Parses a module name into a `TargetModule`.
    ///
    /// Returns `None` when `name` is not one of the canonical module names.
    #[must_use]
    pub fn from_name(name: &str) -> Option<Self> {
        Some(match name {
            "q_proj" => Self::QProj,
            "k_proj" => Self::KProj,
            "v_proj" => Self::VProj,
            "o_proj" => Self::OProj,
            "gate_proj" => Self::GateProj,
            "up_proj" => Self::UpProj,
            "down_proj" => Self::DownProj,
            _ => return None,
        })
    }
}
63
/// Configuration for `LoRA` adapter.
///
/// Specifies which layers to apply `LoRA` to and the `LoRA` hyperparameters.
///
/// # Examples
///
/// ```
/// use metal_candle::training::{LoRAAdapterConfig, TargetModule};
///
/// // Apply LoRA to Q and V projections only (common choice)
/// let config = LoRAAdapterConfig {
///     rank: 8,
///     alpha: 16.0,
///     dropout: 0.0,
///     target_modules: vec![TargetModule::QProj, TargetModule::VProj],
/// };
/// ```
#[derive(Debug, Clone)]
pub struct LoRAAdapterConfig {
    /// Rank of the low-rank decomposition (larger rank = more trainable params)
    pub rank: usize,

    /// Scaling factor for `LoRA` updates
    pub alpha: f32,

    /// Dropout probability
    pub dropout: f32,

    /// Which modules to apply `LoRA` to; modules not listed here keep their
    /// frozen weights untouched
    pub target_modules: Vec<TargetModule>,
}
95
96impl Default for LoRAAdapterConfig {
97 fn default() -> Self {
98 Self {
99 rank: 8,
100 alpha: 16.0,
101 dropout: 0.0,
102 // By default, apply LoRA to Q and V projections (most common)
103 target_modules: vec![TargetModule::QProj, TargetModule::VProj],
104 }
105 }
106}
107
108impl LoRAAdapterConfig {
109 /// Creates a `LoRAConfig` from this adapter configuration.
110 #[must_use]
111 pub const fn to_lora_config(&self) -> LoRAConfig {
112 LoRAConfig {
113 rank: self.rank,
114 alpha: self.alpha,
115 dropout: self.dropout,
116 }
117 }
118
119 /// Checks if a module is targeted for `LoRA`.
120 #[must_use]
121 pub fn is_target(&self, module: &TargetModule) -> bool {
122 self.target_modules.contains(module)
123 }
124}
125
/// `LoRA` adapter for a transformer model.
///
/// Manages `LoRA` layers applied to specific modules in the model.
/// Each `LoRA` layer adds a trainable low-rank update to a frozen linear layer.
///
/// # Architecture
///
/// For a frozen linear layer with weight W:
/// ```text
/// output = (W + ΔW) @ input
///        = W @ input + ΔW @ input
///        = frozen_output + lora_output
/// ```
///
/// # Examples
///
/// ```no_run
/// use metal_candle::training::{LoRAAdapter, LoRAAdapterConfig, TargetModule};
/// use candle_core::Device;
///
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let device = Device::Cpu;
/// let config = LoRAAdapterConfig::default();
///
/// // Create adapter for a model with hidden_size=768
/// let adapter = LoRAAdapter::new(768, 768, 32, &config, &device)?;
///
/// // Get number of trainable parameters
/// println!("Trainable params: {}", adapter.num_trainable_parameters());
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct LoRAAdapter {
    /// `LoRA` layers keyed by `"layers.{layer_idx}.{module_name}"`
    /// (e.g. `"layers.0.q_proj"`); absent keys mean no `LoRA` for that slot
    layers: HashMap<String, LoRALayer>,

    /// Adapter configuration (cloned at construction time)
    config: LoRAAdapterConfig,

    /// Number of transformer layers
    num_layers: usize,
}
169
170impl LoRAAdapter {
171 /// Creates a new `LoRA` adapter.
172 ///
173 /// # Arguments
174 ///
175 /// * `hidden_size` - Model hidden dimension
176 /// * `intermediate_size` - MLP intermediate dimension (for MLP modules)
177 /// * `num_layers` - Number of transformer layers in the model
178 /// * `config` - Adapter configuration
179 /// * `device` - Device to place tensors on
180 ///
181 /// # Errors
182 ///
183 /// Returns an error if `LoRA` layer creation fails.
184 pub fn new(
185 hidden_size: usize,
186 intermediate_size: usize,
187 num_layers: usize,
188 config: &LoRAAdapterConfig,
189 device: &Device,
190 ) -> Result<Self> {
191 let lora_config = config.to_lora_config();
192 let mut layers = HashMap::new();
193
194 // Create LoRA layers for each target module in each transformer layer
195 for layer_idx in 0..num_layers {
196 for target in &config.target_modules {
197 let (in_features, out_features) = match target {
198 TargetModule::QProj
199 | TargetModule::KProj
200 | TargetModule::VProj
201 | TargetModule::OProj => (hidden_size, hidden_size),
202 TargetModule::GateProj | TargetModule::UpProj => {
203 (hidden_size, intermediate_size)
204 }
205 TargetModule::DownProj => (intermediate_size, hidden_size),
206 };
207
208 let lora_layer = LoRALayer::new(in_features, out_features, &lora_config, device)?;
209
210 let key = format!("layers.{}.{}", layer_idx, target.name());
211 layers.insert(key, lora_layer);
212 }
213 }
214
215 Ok(Self {
216 layers,
217 config: config.clone(),
218 num_layers,
219 })
220 }
221
222 /// Applies `LoRA` to a layer's output.
223 ///
224 /// # Arguments
225 ///
226 /// * `layer_idx` - Index of the transformer layer
227 /// * `module` - Which module (`q_proj`, `v_proj`, etc.)
228 /// * `input` - Input to the linear layer (before frozen projection)
229 ///
230 /// # Returns
231 ///
232 /// The `LoRA` delta to add to the frozen layer output.
233 /// Returns `None` if this layer/module doesn't have `LoRA` applied.
234 ///
235 /// # Errors
236 ///
237 /// Returns an error if the forward pass fails.
238 pub fn forward(
239 &self,
240 layer_idx: usize,
241 module: &TargetModule,
242 input: &Tensor,
243 ) -> Result<Option<Tensor>> {
244 let key = format!("layers.{}.{}", layer_idx, module.name());
245
246 if let Some(lora_layer) = self.layers.get(&key) {
247 let delta = lora_layer.forward(input)?;
248 Ok(Some(delta))
249 } else {
250 Ok(None)
251 }
252 }
253
254 /// Returns the total number of trainable parameters.
255 ///
256 /// This is the sum of parameters in all `LoRA` layers.
257 #[must_use]
258 pub fn num_trainable_parameters(&self) -> usize {
259 self.layers.values().map(LoRALayer::num_parameters).sum()
260 }
261
262 /// Returns the number of frozen (non-trainable) parameters.
263 ///
264 /// This would be all model parameters minus the `LoRA` parameters.
265 /// Note: This requires knowing the model's total parameter count.
266 #[must_use]
267 pub fn num_frozen_parameters(&self, total_model_params: usize) -> usize {
268 total_model_params.saturating_sub(self.num_trainable_parameters())
269 }
270
271 /// Returns the adapter configuration.
272 #[must_use]
273 pub const fn config(&self) -> &LoRAAdapterConfig {
274 &self.config
275 }
276
277 /// Returns the number of transformer layers.
278 #[must_use]
279 pub const fn num_layers(&self) -> usize {
280 self.num_layers
281 }
282
283 /// Returns an iterator over all `LoRA` layers.
284 pub fn layers(&self) -> impl Iterator<Item = (&String, &LoRALayer)> {
285 self.layers.iter()
286 }
287
288 /// Gets a specific `LoRA` layer by key.
289 #[must_use]
290 pub fn get_layer(&self, layer_idx: usize, module: &TargetModule) -> Option<&LoRALayer> {
291 let key = format!("layers.{}.{}", layer_idx, module.name());
292 self.layers.get(&key)
293 }
294
295 /// Merges `LoRA` weights back into the base model weights.
296 ///
297 /// Computes: `W_new = W_base + (B @ A) * scaling`
298 ///
299 /// This is useful for inference after training, as it eliminates
300 /// the overhead of separate `LoRA` computation.
301 ///
302 /// # Arguments
303 ///
304 /// * `base_weight` - The frozen base weight matrix (`out_features`, `in_features`)
305 /// * `layer_idx` - Index of the transformer layer
306 /// * `module` - Which module to merge
307 ///
308 /// # Returns
309 ///
310 /// The merged weight matrix, or the original if no `LoRA` is applied.
311 ///
312 /// # Errors
313 ///
314 /// Returns an error if tensor operations fail.
315 pub fn merge_weights(
316 &self,
317 base_weight: &Tensor,
318 layer_idx: usize,
319 module: &TargetModule,
320 ) -> Result<Tensor> {
321 let key = format!("layers.{}.{}", layer_idx, module.name());
322
323 if let Some(lora_layer) = self.layers.get(&key) {
324 // Compute ΔW = B @ A * scaling
325 // NOTE: LoRA matrices are stored in transposed form for optimization:
326 // - lora_a is stored as (in_features, rank) instead of (rank, in_features)
327 // - lora_b is stored as (rank, out_features) instead of (out_features, rank)
328 let lora_a = lora_layer.lora_a_tensor();
329 let lora_b = lora_layer.lora_b_tensor();
330
331 // We need: B_std @ A_std where
332 // - A_std: (rank, in_features) = lora_a^T
333 // - B_std: (out_features, rank) = lora_b^T
334 // Therefore: B_std @ A_std = lora_b^T @ lora_a^T = (lora_a @ lora_b)^T
335
336 // Step 1: lora_a @ lora_b
337 // (in_features, rank) @ (rank, out_features) = (in_features, out_features)
338 let temp = lora_a.matmul(lora_b)?;
339
340 // Step 2: Transpose to get (out_features, in_features)
341 let delta_w = temp.t()?;
342
343 // Scale by alpha/rank
344 let scaling = lora_layer.config().scaling();
345 let scaled_delta = (delta_w * f64::from(scaling))?;
346
347 // W_new = W_base + scaled_delta
348 let merged = base_weight.add(&scaled_delta)?;
349 Ok(merged)
350 } else {
351 // No LoRA for this layer/module, return base weight unchanged
352 Ok(base_weight.clone())
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_target_module_name() {
        assert_eq!(TargetModule::QProj.name(), "q_proj");
        assert_eq!(TargetModule::VProj.name(), "v_proj");
        assert_eq!(TargetModule::GateProj.name(), "gate_proj");
    }

    #[test]
    fn test_target_module_from_name() {
        // Known names parse; unknown names yield None.
        assert_eq!(TargetModule::from_name("q_proj"), Some(TargetModule::QProj));
        assert_eq!(TargetModule::from_name("v_proj"), Some(TargetModule::VProj));
        assert_eq!(TargetModule::from_name("invalid"), None);
    }

    #[test]
    fn test_lora_adapter_config_default() {
        let config = LoRAAdapterConfig::default();

        assert_eq!(config.rank, 8);
        assert!((f64::from(config.alpha) - 16.0).abs() < 1e-7);

        // Defaults target exactly the Q and V projections.
        assert_eq!(config.target_modules.len(), 2);
        assert!(config.is_target(&TargetModule::QProj));
        assert!(config.is_target(&TargetModule::VProj));
        assert!(!config.is_target(&TargetModule::KProj));
    }

    #[test]
    fn test_lora_adapter_creation() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();

        let adapter =
            LoRAAdapter::new(768, 2048, 4, &config, &device).expect("adapter creation failed");
        assert_eq!(adapter.num_layers(), 4);

        // 2 default target modules * 4 transformer layers = 8 LoRA layers.
        assert_eq!(adapter.layers.len(), 8);
    }

    #[test]
    fn test_lora_adapter_trainable_parameters() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig {
            rank: 8,
            target_modules: vec![TargetModule::QProj, TargetModule::VProj],
            ..Default::default()
        };

        let adapter = LoRAAdapter::new(768, 2048, 4, &config, &device).unwrap();

        // Per LoRA layer: rank * (in_features + out_features)
        //   = 8 * (768 + 768) = 12,288 params.
        // Total: 2 modules * 4 layers * 12,288 = 98,304 params.
        assert_eq!(adapter.num_trainable_parameters(), 98_304);
    }

    #[test]
    fn test_lora_adapter_forward() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();
        let adapter = LoRAAdapter::new(768, 2048, 2, &config, &device).unwrap();

        let input = Tensor::randn(0f32, 1f32, (2, 16, 768), &device).unwrap();

        // q_proj is targeted by the default config, so a delta is produced.
        let delta = adapter.forward(0, &TargetModule::QProj, &input).unwrap();
        assert!(delta.is_some());

        // k_proj is not targeted by default, so no delta is produced.
        let delta = adapter.forward(0, &TargetModule::KProj, &input).unwrap();
        assert!(delta.is_none());
    }

    #[test]
    fn test_lora_adapter_get_layer() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();
        let adapter = LoRAAdapter::new(768, 2048, 2, &config, &device).unwrap();

        // Present: q_proj in layer 0 is a default target.
        assert!(adapter.get_layer(0, &TargetModule::QProj).is_some());

        // Absent: k_proj is not among the default target modules.
        assert!(adapter.get_layer(0, &TargetModule::KProj).is_none());

        // Absent: layer index 5 is out of range for a 2-layer adapter.
        assert!(adapter.get_layer(5, &TargetModule::QProj).is_none());
    }

    #[test]
    fn test_lora_adapter_merge_weights() {
        let device = Device::Cpu;
        let config = LoRAAdapterConfig::default();
        let adapter = LoRAAdapter::new(768, 2048, 1, &config, &device).unwrap();

        // Base weight shaped (out_features=768, in_features=768).
        let base_weight = Tensor::zeros((768, 768), candle_core::DType::F32, &device).unwrap();

        // q_proj has LoRA applied; the merged weight keeps the base shape.
        let merged = adapter
            .merge_weights(&base_weight, 0, &TargetModule::QProj)
            .unwrap();
        assert_eq!(merged.dims(), &[768, 768]);

        // k_proj has no LoRA; merging must still succeed (returns base).
        assert!(adapter
            .merge_weights(&base_weight, 0, &TargetModule::KProj)
            .is_ok());
    }
}
477}