Skip to main content

entrenar/lora/adapter/
lora_adapter.rs

1//! LoRA adapter serialization and deserialization
2//!
3//! Contains the main LoRAAdapter struct for saving and loading adapters.
4
5use super::error::AdapterError;
6use super::metadata::AdapterMetadata;
7use crate::lora::LoRALayer;
8use crate::Tensor;
9use serde::{Deserialize, Serialize};
10use std::fs::File;
11use std::io::{BufReader, BufWriter};
12use std::path::Path;
13
14/// Serializable LoRA adapter format
15///
16/// Contains all information needed to reconstruct a LoRA adapter
17/// (excluding the base weight, which remains frozen and separate)
18#[derive(Serialize, Deserialize, Debug, Clone)]
19pub struct LoRAAdapter {
20    /// Format version for future compatibility
21    version: String,
22    /// LoRA rank
23    rank: usize,
24    /// LoRA alpha parameter
25    alpha: f32,
26    /// Output dimension
27    d_out: usize,
28    /// Input dimension
29    d_in: usize,
30    /// Computed scale factor (alpha/rank)
31    scale: f32,
32    /// LoRA A matrix weights [rank * d_in]
33    lora_a: Vec<f32>,
34    /// LoRA B matrix weights [d_out * rank]
35    lora_b: Vec<f32>,
36}
37
38impl LoRAAdapter {
39    /// Current adapter format version
40    const VERSION: &'static str = "1.0";
41
42    /// Create adapter from LoRALayer
43    ///
44    /// # Arguments
45    /// * `layer` - LoRALayer to extract adapter from
46    /// * `rank` - LoRA rank
47    /// * `alpha` - LoRA alpha parameter
48    pub fn from_layer(layer: &LoRALayer, rank: usize, alpha: f32) -> Self {
49        Self {
50            version: Self::VERSION.to_string(),
51            rank,
52            alpha,
53            d_out: layer.d_out(),
54            d_in: layer.d_in(),
55            scale: layer.scale(),
56            lora_a: layer.lora_a().data().to_vec(),
57            lora_b: layer.lora_b().data().to_vec(),
58        }
59    }
60
61    /// Load adapter and apply to base weight
62    ///
63    /// # Arguments
64    /// * `base_weight` - Frozen base weight tensor [d_out * d_in]
65    ///
66    /// # Returns
67    /// LoRALayer with loaded adapter weights
68    pub fn to_layer(&self, base_weight: Tensor) -> Result<LoRALayer, AdapterError> {
69        // Validate dimensions
70        if base_weight.len() != self.d_out * self.d_in {
71            return Err(AdapterError::DimensionMismatch {
72                expected: format!("{}x{} = {}", self.d_out, self.d_in, self.d_out * self.d_in),
73                actual: base_weight.len().to_string(),
74            });
75        }
76
77        if self.lora_a.len() != self.rank * self.d_in {
78            return Err(AdapterError::Validation(format!(
79                "LoRA A size mismatch: expected {} (rank {} * d_in {}), got {}",
80                self.rank * self.d_in,
81                self.rank,
82                self.d_in,
83                self.lora_a.len()
84            )));
85        }
86
87        if self.lora_b.len() != self.d_out * self.rank {
88            return Err(AdapterError::Validation(format!(
89                "LoRA B size mismatch: expected {} (d_out {} * rank {}), got {}",
90                self.d_out * self.rank,
91                self.d_out,
92                self.rank,
93                self.lora_b.len()
94            )));
95        }
96
97        // Create layer with loaded weights
98        let mut layer = LoRALayer::new(base_weight, self.d_out, self.d_in, self.rank, self.alpha);
99
100        // Replace LoRA weights with loaded values
101        *layer.lora_a_mut().data_mut() = ndarray::arr1(&self.lora_a);
102        *layer.lora_b_mut().data_mut() = ndarray::arr1(&self.lora_b);
103
104        Ok(layer)
105    }
106
107    /// Save adapter to JSON file
108    ///
109    /// # Arguments
110    /// * `path` - File path to save to
111    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), AdapterError> {
112        let file = File::create(path)?;
113        let writer = BufWriter::new(file);
114        serde_json::to_writer_pretty(writer, self)?;
115        Ok(())
116    }
117
118    /// Load adapter from JSON file
119    ///
120    /// # Arguments
121    /// * `path` - File path to load from
122    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, AdapterError> {
123        let file = File::open(path)?;
124        let reader = BufReader::new(file);
125        let adapter: LoRAAdapter = serde_json::from_reader(reader)?;
126
127        // Validate version
128        if adapter.version != Self::VERSION {
129            return Err(AdapterError::Validation(format!(
130                "Unsupported adapter version: {} (expected {})",
131                adapter.version,
132                Self::VERSION
133            )));
134        }
135
136        Ok(adapter)
137    }
138
139    /// Get adapter metadata
140    pub fn metadata(&self) -> AdapterMetadata {
141        AdapterMetadata {
142            version: self.version.clone(),
143            rank: self.rank,
144            alpha: self.alpha,
145            d_out: self.d_out,
146            d_in: self.d_in,
147            scale: self.scale,
148            num_params: self.lora_a.len() + self.lora_b.len(),
149        }
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Consistent adapter: rank 4, d_out 8, d_in 16, alpha 8 (scale 2).
    fn make_test_adapter() -> LoRAAdapter {
        LoRAAdapter {
            version: "1.0".to_string(),
            rank: 4,
            alpha: 8.0,
            d_out: 8,
            d_in: 16,
            scale: 2.0,
            lora_a: vec![0.1; 4 * 16], // rank * d_in
            lora_b: vec![0.2; 8 * 4],  // d_out * rank
        }
    }

    #[test]
    fn test_adapter_from_layer() {
        let base_weight = Tensor::zeros(8 * 16, false);
        let layer = LoRALayer::new(base_weight, 8, 16, 4, 8.0);
        let adapter = LoRAAdapter::from_layer(&layer, 4, 8.0);
        assert_eq!(adapter.version, LoRAAdapter::VERSION);
        assert_eq!(adapter.rank, 4);
        assert_eq!(adapter.alpha, 8.0);
        assert_eq!(adapter.d_out, 8);
        assert_eq!(adapter.d_in, 16);
        assert_eq!(adapter.scale, layer.scale());
    }

    #[test]
    fn test_adapter_to_layer_valid() {
        let adapter = make_test_adapter();
        let base_weight = Tensor::zeros(8 * 16, false);
        let layer = adapter.to_layer(base_weight).expect("operation should succeed");
        assert_eq!(layer.d_out(), 8);
        assert_eq!(layer.d_in(), 16);
        // The stored weights, not freshly initialised ones, must end up in the layer.
        assert_eq!(layer.lora_a().data().to_vec(), adapter.lora_a);
        assert_eq!(layer.lora_b().data().to_vec(), adapter.lora_b);
    }

    #[test]
    fn test_adapter_layer_roundtrip() {
        let layer = LoRALayer::new(Tensor::zeros(8 * 16, false), 8, 16, 4, 8.0);
        let adapter = LoRAAdapter::from_layer(&layer, 4, 8.0);
        let rebuilt = adapter
            .to_layer(Tensor::zeros(8 * 16, false))
            .expect("adapter extracted from a layer should rebuild it");
        assert_eq!(rebuilt.lora_a().data().to_vec(), layer.lora_a().data().to_vec());
        assert_eq!(rebuilt.lora_b().data().to_vec(), layer.lora_b().data().to_vec());
    }

    #[test]
    fn test_adapter_to_layer_dimension_mismatch() {
        let adapter = make_test_adapter();
        let base_weight = Tensor::zeros(100, false); // Wrong size
        let result = adapter.to_layer(base_weight);
        assert!(
            matches!(result, Err(AdapterError::DimensionMismatch { .. })),
            "Expected DimensionMismatch error"
        );
    }

    #[test]
    fn test_adapter_to_layer_lora_a_mismatch() {
        let mut adapter = make_test_adapter();
        adapter.lora_a = vec![0.1; 10]; // Wrong size
        let base_weight = Tensor::zeros(8 * 16, false);
        match adapter.to_layer(base_weight) {
            Err(AdapterError::Validation(msg)) => {
                assert!(msg.contains("LoRA A size mismatch"), "unexpected message: {msg}");
            }
            _ => panic!("Expected Validation error"),
        }
    }

    #[test]
    fn test_adapter_to_layer_lora_b_mismatch() {
        let mut adapter = make_test_adapter();
        adapter.lora_b = vec![0.2; 10]; // Wrong size
        let base_weight = Tensor::zeros(8 * 16, false);
        match adapter.to_layer(base_weight) {
            Err(AdapterError::Validation(msg)) => {
                assert!(msg.contains("LoRA B size mismatch"), "unexpected message: {msg}");
            }
            _ => panic!("Expected Validation error"),
        }
    }

    #[test]
    fn test_adapter_save_load_roundtrip() {
        let adapter = make_test_adapter();
        let file = NamedTempFile::new().expect("temp file creation should succeed");

        adapter.save(file.path()).expect("save should succeed");
        let loaded = LoRAAdapter::load(file.path()).expect("load should succeed");

        assert_eq!(adapter.version, loaded.version);
        assert_eq!(adapter.rank, loaded.rank);
        assert_eq!(adapter.alpha, loaded.alpha);
        assert_eq!(adapter.d_out, loaded.d_out);
        assert_eq!(adapter.d_in, loaded.d_in);
        assert_eq!(adapter.scale, loaded.scale);
        // Compare values, not just lengths, so silent corruption is caught.
        assert_eq!(adapter.lora_a, loaded.lora_a);
        assert_eq!(adapter.lora_b, loaded.lora_b);
    }

    #[test]
    fn test_adapter_load_invalid_version() {
        let mut adapter = make_test_adapter();
        adapter.version = "0.0".to_string();
        let file = NamedTempFile::new().expect("temp file creation should succeed");
        adapter.save(file.path()).expect("save should succeed");

        match LoRAAdapter::load(file.path()) {
            Err(AdapterError::Validation(msg)) => {
                assert!(msg.contains("Unsupported adapter version"), "unexpected message: {msg}");
            }
            _ => panic!("Expected Validation error"),
        }
    }

    #[test]
    fn test_adapter_load_nonexistent_file() {
        let result = LoRAAdapter::load("/nonexistent/path/adapter.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_adapter_save_invalid_path() {
        let adapter = make_test_adapter();
        let result = adapter.save("/nonexistent/dir/adapter.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_adapter_metadata() {
        let adapter = make_test_adapter();
        let meta = adapter.metadata();
        assert_eq!(meta.version, "1.0");
        assert_eq!(meta.rank, 4);
        assert_eq!(meta.alpha, 8.0);
        assert_eq!(meta.d_out, 8);
        assert_eq!(meta.d_in, 16);
        assert_eq!(meta.scale, 2.0);
        assert_eq!(meta.num_params, 4 * 16 + 8 * 4);
    }

    #[test]
    fn test_adapter_clone() {
        let adapter = make_test_adapter();
        let cloned = adapter.clone();
        assert_eq!(adapter.rank, cloned.rank);
        assert_eq!(adapter.lora_a, cloned.lora_a);
        assert_eq!(adapter.lora_b, cloned.lora_b);
    }

    #[test]
    fn test_adapter_debug() {
        let adapter = make_test_adapter();
        let debug = format!("{adapter:?}");
        assert!(debug.contains("LoRAAdapter"));
    }
}