axolotl_rs/adapters/mod.rs

//! Adapter integration layer.
//!
//! This module provides unified access to PEFT adapters (LoRA, QLoRA, etc.)
//! using either the real peft-rs/qlora-rs crates or mock implementations.
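//!
//! # Example
//!
//! A minimal usage sketch (assumes the crate is named `axolotl_rs` and that the
//! `llama2-7b` preset is available, as used in the tests below):
//!
//! ```ignore
//! use axolotl_rs::adapters::AdapterWrapper;
//! use axolotl_rs::config::{AdapterType, AxolotlConfig};
//! use candle_core::Device;
//!
//! // Start from a preset and switch it to plain (non-quantized) LoRA.
//! let mut config = AxolotlConfig::from_preset("llama2-7b").unwrap();
//! config.adapter = AdapterType::Lora;
//! config.quantization = None;
//!
//! let wrapper = AdapterWrapper::new(&config, &Device::Cpu).unwrap();
//! assert_eq!(wrapper.adapter_type, AdapterType::Lora);
//! println!("trainable params: {}", wrapper.trainable_param_count());
//! ```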

#[cfg(feature = "peft")]
use std::collections::HashMap;
#[cfg(feature = "peft")]
use std::path::Path;

use candle_core::Device;
use candle_nn::VarMap;

#[cfg(feature = "qlora")]
use crate::config::QuantizationSettings;
use crate::config::{AdapterType, AxolotlConfig, LoraSettings};
use crate::error::{AxolotlError, Result};

// Re-export based on features
#[cfg(feature = "peft")]
pub use peft_rs::{
    Adapter, AdapterConfig, LoraConfig as PeftLoraConfig, LoraLayer, Mergeable, PeftModel,
    SaveLoad, Trainable,
};

#[cfg(feature = "qlora")]
pub use qlora_rs::{QLoraConfig, QLoraLayer, QuantizationConfig, QuantizedLinear, QuantizedTensor};

/// Unified adapter wrapper that works with both real and mock implementations.
pub struct AdapterWrapper {
    /// The type of adapter being used
    pub adapter_type: AdapterType,
    /// Whether quantization is enabled
    pub quantized: bool,
    /// Trainable parameters (LoRA weights)
    pub trainable_params: VarMap,
    /// Device where adapter is loaded
    pub device: Device,
}

impl AdapterWrapper {
    /// Create a new adapter based on configuration.
    ///
    /// # Arguments
    /// * `config` - The axolotl configuration
    /// * `device` - Device to create adapter on
    ///
    /// # Errors
    /// Returns an error if the adapter cannot be created.
    pub fn new(config: &AxolotlConfig, device: &Device) -> Result<Self> {
        let trainable_params = VarMap::new();

        match config.adapter {
            AdapterType::None => Ok(Self {
                adapter_type: AdapterType::None,
                quantized: false,
                trainable_params,
                device: device.clone(),
            }),
            AdapterType::Lora => {
                tracing::info!(
                    "Creating LoRA adapter with r={}, alpha={}",
                    config.lora.r,
                    config.lora.alpha
                );
                Ok(Self {
                    adapter_type: AdapterType::Lora,
                    quantized: false,
                    trainable_params,
                    device: device.clone(),
                })
            }
            AdapterType::Qlora => {
                if config.quantization.is_none() {
                    return Err(AxolotlError::Config(
                        "QLoRA requires quantization settings".into(),
                    ));
                }
                tracing::info!(
                    "Creating QLoRA adapter with r={}, alpha={}, quantization enabled",
                    config.lora.r,
                    config.lora.alpha
                );
                Ok(Self {
                    adapter_type: AdapterType::Qlora,
                    quantized: true,
                    trainable_params,
                    device: device.clone(),
                })
            }
        }
    }

    /// Convert axolotl LoRA settings to peft-rs config.
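    ///
    /// # Example
    ///
    /// A small sketch (requires the `peft` feature; the field values are illustrative):
    ///
    /// ```ignore
    /// use axolotl_rs::adapters::AdapterWrapper;
    /// use axolotl_rs::config::LoraSettings;
    ///
    /// let settings = LoraSettings {
    ///     r: 16,
    ///     alpha: 32,
    ///     dropout: 0.05,
    ///     target_modules: vec!["q_proj".into(), "v_proj".into()],
    /// };
    /// let peft_config = AdapterWrapper::to_peft_lora_config(&settings);
    /// assert_eq!(peft_config.r, 16);
    /// assert_eq!(peft_config.alpha, 32);
    /// ```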
    #[cfg(feature = "peft")]
    pub fn to_peft_lora_config(settings: &LoraSettings) -> PeftLoraConfig {
        PeftLoraConfig {
            r: settings.r,
            alpha: settings.alpha,
            dropout: settings.dropout,
            target_modules: settings.target_modules.clone(),
            ..Default::default()
        }
    }

    /// Convert axolotl quantization settings to qlora-rs config.
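    ///
    /// # Example
    ///
    /// A small sketch (requires the `qlora` feature; it reuses the quantization
    /// settings carried by the QLoRA preset rather than constructing them by hand):
    ///
    /// ```ignore
    /// use axolotl_rs::adapters::AdapterWrapper;
    /// use axolotl_rs::config::AxolotlConfig;
    ///
    /// // The `llama2-7b` preset defaults to QLoRA, so quantization settings are present.
    /// let config = AxolotlConfig::from_preset("llama2-7b").unwrap();
    /// let quant = config.quantization.as_ref().expect("preset includes quantization");
    ///
    /// let qlora_config = AdapterWrapper::to_qlora_config(&config.lora, quant).unwrap();
    /// assert!(!qlora_config.cache_dequantized);
    /// ```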
    #[cfg(feature = "qlora")]
    pub fn to_qlora_config(
        lora: &LoraSettings,
        quant: &QuantizationSettings,
    ) -> Result<QLoraConfig> {
        let quant_config = QuantizationConfig {
            block_size: quant.block_size,
            double_quant: quant.double_quant,
            ..Default::default()
        };

        let lora_config = Self::to_peft_lora_config(lora);

        Ok(QLoraConfig {
            lora: lora_config,
            quantization: quant_config,
            target_modules: lora.target_modules.clone(),
            cache_dequantized: false, // On-the-fly dequant for training
        })
    }

    /// Get the number of trainable parameters.
    pub fn trainable_param_count(&self) -> usize {
        self.trainable_params
            .all_vars()
            .iter()
            .map(|v| v.elem_count())
            .sum()
    }

    /// Apply adapter to a linear layer, returning a wrapped layer.
    #[cfg(feature = "peft")]
    pub fn wrap_linear(
        &self,
        in_features: usize,
        out_features: usize,
        lora_config: &PeftLoraConfig,
        vb: candle_nn::VarBuilder,
    ) -> Result<LoraLayer> {
        LoraLayer::new(in_features, out_features, lora_config.clone(), vb)
            .map_err(|e| AxolotlError::Model(format!("Failed to create LoRA layer: {}", e)))
    }

    /// Save adapter weights to a directory.
    ///
    /// # Arguments
    /// * `path` - Directory to save adapter files to
    /// * `lora_config` - LoRA configuration to save
    /// * `layers` - Map of layer names to LoraLayer instances
    ///
    /// # Errors
    /// Returns an error if saving fails.
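    ///
    /// # Example
    ///
    /// A sketch of the call shape (requires the `peft` feature; `wrapper`,
    /// `lora_config`, and `layers` are assumed to have been built beforehand,
    /// e.g. as in `test_adapter_save_and_load` below):
    ///
    /// ```ignore
    /// // One entry per wrapped projection, keyed by its module path.
    /// // layers: HashMap<String, LoraLayer>
    /// wrapper.save_adapter("out/adapter", &lora_config, &layers)?;
    /// // Produces out/adapter/adapter_model.safetensors and out/adapter/adapter_config.json.
    /// ```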
    #[cfg(feature = "peft")]
    pub fn save_adapter<P: AsRef<Path>>(
        &self,
        path: P,
        lora_config: &PeftLoraConfig,
        layers: &HashMap<String, LoraLayer>,
    ) -> Result<()> {
        use candle_core::Tensor;
        // Import SaveLoad trait to enable state_dict() method
        use crate::adapters::SaveLoad;

        let dir = path.as_ref();
        std::fs::create_dir_all(dir)?;

        // Collect all adapter weights into a single state dict
        let mut all_tensors: Vec<(String, Tensor)> = Vec::new();

        for (name, layer) in layers {
            let state = layer.state_dict().map_err(|e| {
                AxolotlError::Model(format!("Failed to get state dict for {}: {}", name, e))
            })?;

            for (key, tensor) in state {
                all_tensors.push((format!("{}.{}", name, key), tensor));
            }
        }

        // Save weights to safetensors
        let weights_path = dir.join("adapter_model.safetensors");
        let tensors_ref: Vec<(&str, Tensor)> = all_tensors
            .iter()
            .map(|(name, tensor)| (name.as_str(), tensor.clone()))
            .collect();

        safetensors::tensor::serialize_to_file(tensors_ref, &None, &weights_path).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to save adapter weights: {}", e))
        })?;

        // Save config to JSON
        let config_path = dir.join("adapter_config.json");
        let config_json = serde_json::to_string_pretty(lora_config).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to serialize adapter config: {}", e))
        })?;
        std::fs::write(&config_path, config_json)?;

        tracing::info!("Saved adapter with {} layers to {:?}", layers.len(), dir);
        Ok(())
    }

    /// Load adapter weights from a directory.
    ///
    /// # Arguments
    /// * `path` - Directory containing adapter files
    ///
    /// # Returns
    /// A tuple of the adapter's `PeftLoraConfig` and a map of tensor names to tensors.
    ///
    /// # Errors
    /// Returns an error if loading fails.
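    ///
    /// # Example
    ///
    /// A sketch of the call shape (requires the `peft` feature; the directory must
    /// contain the files written by [`save_adapter`](Self::save_adapter)):
    ///
    /// ```ignore
    /// let (lora_config, tensors) = wrapper.load_adapter("out/adapter")?;
    /// println!("restored {} adapter tensors (r={})", tensors.len(), lora_config.r);
    /// ```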
    #[cfg(feature = "peft")]
    pub fn load_adapter<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(PeftLoraConfig, HashMap<String, candle_core::Tensor>)> {
        let dir = path.as_ref();

        // Load config
        let config_path = dir.join("adapter_config.json");
        let config_json = std::fs::read_to_string(&config_path).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to read adapter config: {}", e))
        })?;
        let config: PeftLoraConfig = serde_json::from_str(&config_json).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to parse adapter config: {}", e))
        })?;

        // Load weights
        let weights_path = dir.join("adapter_model.safetensors");
        let tensors = candle_core::safetensors::load(&weights_path, &self.device).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to load adapter weights: {}", e))
        })?;

        tracing::info!(
            "Loaded adapter with {} tensors from {:?}",
            tensors.len(),
            dir
        );
        Ok((config, tensors))
    }
}

/// Configuration for applying adapters to a model.
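///
/// A small sketch of deriving it from [`LoraSettings`] via the `From` impl below
/// (field values are illustrative):
///
/// ```ignore
/// use axolotl_rs::config::LoraSettings;
///
/// let settings = LoraSettings {
///     r: 16,
///     alpha: 32,
///     dropout: 0.1,
///     target_modules: vec!["q_proj".into(), "v_proj".into()],
/// };
/// let app: AdapterApplicationConfig = (&settings).into();
/// assert_eq!(app.r, 16);
/// assert!((app.dropout - 0.1).abs() < 0.001);
/// ```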
#[derive(Debug, Clone)]
pub struct AdapterApplicationConfig {
    /// Target module patterns (e.g., "q_proj", "v_proj")
    pub target_modules: Vec<String>,
    /// LoRA rank
    pub r: usize,
    /// LoRA alpha scaling
    pub alpha: usize,
    /// LoRA dropout
    pub dropout: f32,
}

impl From<&LoraSettings> for AdapterApplicationConfig {
    fn from(settings: &LoraSettings) -> Self {
        Self {
            target_modules: settings.target_modules.clone(),
            r: settings.r,
            alpha: settings.alpha,
            dropout: settings.dropout as f32,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adapter_wrapper_creation() {
        let mut config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        // Override to use LoRA (default preset uses QLoRA)
        config.adapter = AdapterType::Lora;
        config.quantization = None;
        let device = Device::Cpu;

        let wrapper = AdapterWrapper::new(&config, &device).unwrap();
        assert_eq!(wrapper.adapter_type, AdapterType::Lora);
        assert!(!wrapper.quantized);
    }

    #[test]
    fn test_adapter_wrapper_creation_qlora() {
        let config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        let device = Device::Cpu;

        let wrapper = AdapterWrapper::new(&config, &device).unwrap();
        assert_eq!(wrapper.adapter_type, AdapterType::Qlora);
        assert!(wrapper.quantized);
    }

    #[test]
    fn test_adapter_application_config_from_lora_settings() {
        let lora_settings = LoraSettings {
            r: 16,
            alpha: 32,
            dropout: 0.1,
            target_modules: vec!["q_proj".into(), "v_proj".into()],
        };

        let app_config: AdapterApplicationConfig = (&lora_settings).into();
        assert_eq!(app_config.r, 16);
        assert_eq!(app_config.alpha, 32);
        assert!((app_config.dropout - 0.1).abs() < 0.001);
    }

    #[cfg(feature = "peft")]
    #[test]
    fn test_to_peft_lora_config() {
        let lora_settings = LoraSettings {
            r: 8,
            alpha: 16,
            dropout: 0.05,
            target_modules: vec!["q_proj".into()],
        };

        let peft_config = AdapterWrapper::to_peft_lora_config(&lora_settings);
        assert_eq!(peft_config.r, 8);
        assert_eq!(peft_config.alpha, 16);
        assert!((peft_config.dropout - 0.05).abs() < 0.001);
    }

    #[cfg(feature = "peft")]
    #[test]
    fn test_adapter_save_and_load() {
        use candle_nn::VarBuilder;
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let device = Device::Cpu;

        // Create adapter wrapper
        let config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        let wrapper = AdapterWrapper::new(&config, &device).unwrap();

        // Create a test LoRA layer
        let lora_config = PeftLoraConfig {
            r: 8,
            alpha: 16,
            dropout: 0.0,
            target_modules: vec!["q_proj".into()],
            ..Default::default()
        };

        let vb = VarBuilder::zeros(candle_core::DType::F32, &device);
        let layer = LoraLayer::new(768, 768, lora_config.clone(), vb).unwrap();

        // Save adapter
        let mut layers = HashMap::new();
        layers.insert("model.layers.0.self_attn.q_proj".to_string(), layer);

        wrapper
            .save_adapter(temp_dir.path(), &lora_config, &layers)
            .unwrap();

        // Verify files exist
        assert!(temp_dir.path().join("adapter_model.safetensors").exists());
        assert!(temp_dir.path().join("adapter_config.json").exists());

        // Load adapter
        let (loaded_config, loaded_tensors) = wrapper.load_adapter(temp_dir.path()).unwrap();
        assert_eq!(loaded_config.r, 8);
        assert_eq!(loaded_config.alpha, 16);
        assert!(!loaded_tensors.is_empty());
    }
}