//! Adapter wrappers for parameter-efficient fine-tuning (LoRA and QLoRA).

#[cfg(feature = "peft")]
use std::collections::HashMap;
#[cfg(feature = "peft")]
use std::path::Path;

use candle_core::Device;
use candle_nn::VarMap;

#[cfg(feature = "qlora")]
use crate::config::QuantizationSettings;
use crate::config::{AdapterType, AxolotlConfig, LoraSettings};
use crate::error::{AxolotlError, Result};

/// Re-exports from `peft-rs` used to build, merge, and persist LoRA adapters.
#[cfg(feature = "peft")]
pub use peft_rs::{
    Adapter, AdapterConfig, LoraConfig as PeftLoraConfig, LoraLayer, Mergeable, PeftModel,
    SaveLoad, Trainable,
};

/// Re-exports from `qlora-rs` used for quantized (QLoRA) linear layers.
#[cfg(feature = "qlora")]
pub use qlora_rs::{QLoraConfig, QLoraLayer, QuantizationConfig, QuantizedLinear, QuantizedTensor};

/// Holds the adapter state attached to a base model during fine-tuning.
pub struct AdapterWrapper {
    /// Which adapter strategy is active (none, LoRA, or QLoRA).
    pub adapter_type: AdapterType,
    /// Whether the base model weights are quantized (true for QLoRA).
    pub quantized: bool,
    /// Variables registered as trainable for the optimizer.
    pub trainable_params: VarMap,
    /// Device on which adapter tensors are allocated.
    pub device: Device,
}

impl AdapterWrapper {
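    /// Creates an adapter wrapper for the adapter type selected in `config`.
    ///
    /// For `AdapterType::Qlora` the configuration must also provide
    /// quantization settings; otherwise an error is returned.
    ///
    /// A usage sketch (imports elided; it mirrors the tests at the bottom of
    /// this file):
    ///
    /// ```ignore
    /// let mut config = AxolotlConfig::from_preset("llama2-7b")?;
    /// config.adapter = AdapterType::Lora;
    /// config.quantization = None;
    /// let wrapper = AdapterWrapper::new(&config, &Device::Cpu)?;
    /// assert_eq!(wrapper.adapter_type, AdapterType::Lora);
    /// ```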
    pub fn new(config: &AxolotlConfig, device: &Device) -> Result<Self> {
        let trainable_params = VarMap::new();

        match config.adapter {
            AdapterType::None => Ok(Self {
                adapter_type: AdapterType::None,
                quantized: false,
                trainable_params,
                device: device.clone(),
            }),
            AdapterType::Lora => {
                tracing::info!(
                    "Creating LoRA adapter with r={}, alpha={}",
                    config.lora.r,
                    config.lora.alpha
                );
                Ok(Self {
                    adapter_type: AdapterType::Lora,
                    quantized: false,
                    trainable_params,
                    device: device.clone(),
                })
            }
            AdapterType::Qlora => {
                if config.quantization.is_none() {
                    return Err(AxolotlError::Config(
                        "QLoRA requires quantization settings".into(),
                    ));
                }
                tracing::info!(
                    "Creating QLoRA adapter with r={}, alpha={}, quantization enabled",
                    config.lora.r,
                    config.lora.alpha
                );
                Ok(Self {
                    adapter_type: AdapterType::Qlora,
                    quantized: true,
                    trainable_params,
                    device: device.clone(),
                })
            }
        }
    }

    /// Converts the crate's LoRA settings into a `peft-rs` `PeftLoraConfig`.
    #[cfg(feature = "peft")]
    pub fn to_peft_lora_config(settings: &LoraSettings) -> PeftLoraConfig {
        PeftLoraConfig {
            r: settings.r,
            alpha: settings.alpha,
            dropout: settings.dropout,
            target_modules: settings.target_modules.clone(),
            ..Default::default()
        }
    }

    /// Builds a `qlora-rs` `QLoraConfig` from LoRA and quantization settings.
    #[cfg(feature = "qlora")]
    pub fn to_qlora_config(
        lora: &LoraSettings,
        quant: &QuantizationSettings,
    ) -> Result<QLoraConfig> {
        let quant_config = QuantizationConfig {
            block_size: quant.block_size,
            double_quant: quant.double_quant,
            ..Default::default()
        };

        let lora_config = Self::to_peft_lora_config(lora);

        Ok(QLoraConfig {
            lora: lora_config,
            quantization: quant_config,
            target_modules: lora.target_modules.clone(),
            cache_dequantized: false,
        })
    }

    /// Total number of scalar parameters currently registered as trainable.
    pub fn trainable_param_count(&self) -> usize {
        self.trainable_params
            .all_vars()
            .iter()
            .map(|v| v.elem_count())
            .sum()
    }

    /// Wraps a linear projection of the given shape in a LoRA layer.
    #[cfg(feature = "peft")]
    pub fn wrap_linear(
        &self,
        in_features: usize,
        out_features: usize,
        lora_config: &PeftLoraConfig,
        vb: candle_nn::VarBuilder,
    ) -> Result<LoraLayer> {
        LoraLayer::new(in_features, out_features, lora_config.clone(), vb)
            .map_err(|e| AxolotlError::Model(format!("Failed to create LoRA layer: {}", e)))
    }

    /// Saves the adapter weights and configuration to `path`.
    ///
    /// The on-disk layout follows the PEFT convention: LoRA weights are
    /// written to `adapter_model.safetensors` and the LoRA configuration to
    /// `adapter_config.json`.
    #[cfg(feature = "peft")]
    pub fn save_adapter<P: AsRef<Path>>(
        &self,
        path: P,
        lora_config: &PeftLoraConfig,
        layers: &HashMap<String, LoraLayer>,
    ) -> Result<()> {
        use candle_core::Tensor;

        // Bring the `SaveLoad` trait into scope for `state_dict()`.
        use crate::adapters::SaveLoad;

        let dir = path.as_ref();
        std::fs::create_dir_all(dir)?;

        // Collect every layer's tensors under a fully qualified name.
        let mut all_tensors: Vec<(String, Tensor)> = Vec::new();

        for (name, layer) in layers {
            let state = layer.state_dict().map_err(|e| {
                AxolotlError::Model(format!("Failed to get state dict for {}: {}", name, e))
            })?;

            for (key, tensor) in state {
                all_tensors.push((format!("{}.{}", name, key), tensor));
            }
        }

        // Serialize all adapter weights into a single safetensors file.
        let weights_path = dir.join("adapter_model.safetensors");
        let named_tensors: Vec<(&str, Tensor)> = all_tensors
            .iter()
            .map(|(name, tensor)| (name.as_str(), tensor.clone()))
            .collect();

        safetensors::tensor::serialize_to_file(named_tensors, &None, &weights_path).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to save adapter weights: {}", e))
        })?;

        // Write the LoRA configuration alongside the weights.
        let config_path = dir.join("adapter_config.json");
        let config_json = serde_json::to_string_pretty(lora_config).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to serialize adapter config: {}", e))
        })?;
        std::fs::write(&config_path, config_json)?;

        tracing::info!("Saved adapter with {} layers to {:?}", layers.len(), dir);
        Ok(())
    }

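    /// Loads a LoRA adapter configuration and its weight tensors from `path`.
    ///
    /// This is the inverse of `save_adapter`: it reads `adapter_config.json`
    /// and `adapter_model.safetensors` from the directory and returns the
    /// parsed config together with the raw tensors.
    ///
    /// A round-trip sketch (error handling elided; see
    /// `test_adapter_save_and_load` below):
    ///
    /// ```ignore
    /// wrapper.save_adapter(dir, &lora_config, &layers)?;
    /// let (loaded_config, tensors) = wrapper.load_adapter(dir)?;
    /// assert_eq!(loaded_config.r, lora_config.r);
    /// ```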
    #[cfg(feature = "peft")]
    pub fn load_adapter<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> Result<(PeftLoraConfig, HashMap<String, candle_core::Tensor>)> {
        let dir = path.as_ref();

        // Read and parse the adapter configuration.
        let config_path = dir.join("adapter_config.json");
        let config_json = std::fs::read_to_string(&config_path).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to read adapter config: {}", e))
        })?;
        let config: PeftLoraConfig = serde_json::from_str(&config_json).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to parse adapter config: {}", e))
        })?;

        // Load the weight tensors onto the wrapper's device.
        let weights_path = dir.join("adapter_model.safetensors");
        let tensors = candle_core::safetensors::load(&weights_path, &self.device).map_err(|e| {
            AxolotlError::Checkpoint(format!("Failed to load adapter weights: {}", e))
        })?;

        tracing::info!(
            "Loaded adapter with {} tensors from {:?}",
            tensors.len(),
            dir
        );
        Ok((config, tensors))
    }
}

/// Backend-agnostic description of how LoRA adapters are applied to a model.
#[derive(Debug, Clone)]
pub struct AdapterApplicationConfig {
    /// Module names whose linear layers receive adapters.
    pub target_modules: Vec<String>,
    /// LoRA rank.
    pub r: usize,
    /// LoRA alpha scaling factor.
    pub alpha: usize,
    /// Dropout probability applied to the LoRA path.
    pub dropout: f32,
}

impl From<&LoraSettings> for AdapterApplicationConfig {
    fn from(settings: &LoraSettings) -> Self {
        Self {
            target_modules: settings.target_modules.clone(),
            r: settings.r,
            alpha: settings.alpha,
            dropout: settings.dropout as f32,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_adapter_wrapper_creation() {
        let mut config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        config.adapter = AdapterType::Lora;
        config.quantization = None;
        let device = Device::Cpu;

        let wrapper = AdapterWrapper::new(&config, &device).unwrap();
        assert_eq!(wrapper.adapter_type, AdapterType::Lora);
        assert!(!wrapper.quantized);
    }
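
    // A freshly constructed wrapper registers no trainable parameters until
    // layers are wrapped, so the count should start at zero.
    #[test]
    fn test_trainable_param_count_starts_empty() {
        let mut config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        config.adapter = AdapterType::Lora;
        config.quantization = None;
        let device = Device::Cpu;

        let wrapper = AdapterWrapper::new(&config, &device).unwrap();
        // Nothing has been added to the VarMap yet.
        assert_eq!(wrapper.trainable_param_count(), 0);
    }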

    #[test]
    fn test_adapter_wrapper_creation_qlora() {
        // The llama2-7b preset defaults to a QLoRA adapter with quantization enabled.
        let config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        let device = Device::Cpu;

        let wrapper = AdapterWrapper::new(&config, &device).unwrap();
        assert_eq!(wrapper.adapter_type, AdapterType::Qlora);
        assert!(wrapper.quantized);
    }
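
    // The constructor should reject a QLoRA adapter when no quantization
    // settings are present (mirrors the check inside `AdapterWrapper::new`).
    #[test]
    fn test_qlora_requires_quantization_settings() {
        let mut config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        config.adapter = AdapterType::Qlora;
        config.quantization = None;
        let device = Device::Cpu;

        assert!(AdapterWrapper::new(&config, &device).is_err());
    }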

    #[test]
    fn test_adapter_application_config_from_lora_settings() {
        let lora_settings = LoraSettings {
            r: 16,
            alpha: 32,
            dropout: 0.1,
            target_modules: vec!["q_proj".into(), "v_proj".into()],
        };

        let app_config: AdapterApplicationConfig = (&lora_settings).into();
        assert_eq!(app_config.r, 16);
        assert_eq!(app_config.alpha, 32);
        assert!((app_config.dropout - 0.1).abs() < 0.001);
    }

    #[cfg(feature = "peft")]
    #[test]
    fn test_to_peft_lora_config() {
        let lora_settings = LoraSettings {
            r: 8,
            alpha: 16,
            dropout: 0.05,
            target_modules: vec!["q_proj".into()],
        };

        let peft_config = AdapterWrapper::to_peft_lora_config(&lora_settings);
        assert_eq!(peft_config.r, 8);
        assert_eq!(peft_config.alpha, 16);
        assert!((peft_config.dropout - 0.05).abs() < 0.001);
    }
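
    // Sketch of a QLoRA conversion check; it assumes the llama2-7b preset
    // ships quantization settings, as the QLoRA creation test above implies.
    #[cfg(feature = "qlora")]
    #[test]
    fn test_to_qlora_config_from_preset() {
        let config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        let quant = config
            .quantization
            .as_ref()
            .expect("preset should provide quantization settings");

        let qlora_config = AdapterWrapper::to_qlora_config(&config.lora, quant).unwrap();
        assert_eq!(qlora_config.lora.r, config.lora.r);
        assert_eq!(qlora_config.target_modules, config.lora.target_modules);
    }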

    #[cfg(feature = "peft")]
    #[test]
    fn test_adapter_save_and_load() {
        use candle_nn::VarBuilder;
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let device = Device::Cpu;

        let config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        let wrapper = AdapterWrapper::new(&config, &device).unwrap();

        // Build a single LoRA layer to persist.
        let lora_config = PeftLoraConfig {
            r: 8,
            alpha: 16,
            dropout: 0.0,
            target_modules: vec!["q_proj".into()],
            ..Default::default()
        };

        let vb = VarBuilder::zeros(candle_core::DType::F32, &device);
        let layer = LoraLayer::new(768, 768, lora_config.clone(), vb).unwrap();

        let mut layers = HashMap::new();
        layers.insert("model.layers.0.self_attn.q_proj".to_string(), layer);

        wrapper
            .save_adapter(temp_dir.path(), &lora_config, &layers)
            .unwrap();

        // Both the weights file and the config file should exist on disk.
        assert!(temp_dir.path().join("adapter_model.safetensors").exists());
        assert!(temp_dir.path().join("adapter_config.json").exists());

        // Round-trip: the loaded config and tensors should match what was saved.
        let (loaded_config, loaded_tensors) = wrapper.load_adapter(temp_dir.path()).unwrap();
        assert_eq!(loaded_config.r, 8);
        assert_eq!(loaded_config.alpha, 16);
        assert!(!loaded_tensors.is_empty());
    }
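
    // A minimal sketch exercising `wrap_linear`, assuming the same
    // `LoraLayer` construction behaviour as in the save/load test above.
    #[cfg(feature = "peft")]
    #[test]
    fn test_wrap_linear_creates_lora_layer() {
        use candle_nn::VarBuilder;

        let device = Device::Cpu;
        let config = AxolotlConfig::from_preset("llama2-7b").unwrap();
        let wrapper = AdapterWrapper::new(&config, &device).unwrap();

        let lora_config = PeftLoraConfig {
            r: 4,
            alpha: 8,
            dropout: 0.0,
            target_modules: vec!["q_proj".into()],
            ..Default::default()
        };
        let vb = VarBuilder::zeros(candle_core::DType::F32, &device);

        assert!(wrapper.wrap_linear(64, 64, &lora_config, vb).is_ok());
    }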
}