1use super::error::AdapterError;
6use super::metadata::AdapterMetadata;
7use crate::lora::LoRALayer;
8use crate::Tensor;
9use serde::{Deserialize, Serialize};
10use std::fs::File;
11use std::io::{BufReader, BufWriter};
12use std::path::Path;
13
14#[derive(Serialize, Deserialize, Debug, Clone)]
19pub struct LoRAAdapter {
20 version: String,
22 rank: usize,
24 alpha: f32,
26 d_out: usize,
28 d_in: usize,
30 scale: f32,
32 lora_a: Vec<f32>,
34 lora_b: Vec<f32>,
36}
37
38impl LoRAAdapter {
39 const VERSION: &'static str = "1.0";
41
42 pub fn from_layer(layer: &LoRALayer, rank: usize, alpha: f32) -> Self {
49 Self {
50 version: Self::VERSION.to_string(),
51 rank,
52 alpha,
53 d_out: layer.d_out(),
54 d_in: layer.d_in(),
55 scale: layer.scale(),
56 lora_a: layer.lora_a().data().to_vec(),
57 lora_b: layer.lora_b().data().to_vec(),
58 }
59 }
60
61 pub fn to_layer(&self, base_weight: Tensor) -> Result<LoRALayer, AdapterError> {
69 if base_weight.len() != self.d_out * self.d_in {
71 return Err(AdapterError::DimensionMismatch {
72 expected: format!("{}x{} = {}", self.d_out, self.d_in, self.d_out * self.d_in),
73 actual: base_weight.len().to_string(),
74 });
75 }
76
77 if self.lora_a.len() != self.rank * self.d_in {
78 return Err(AdapterError::Validation(format!(
79 "LoRA A size mismatch: expected {} (rank {} * d_in {}), got {}",
80 self.rank * self.d_in,
81 self.rank,
82 self.d_in,
83 self.lora_a.len()
84 )));
85 }
86
87 if self.lora_b.len() != self.d_out * self.rank {
88 return Err(AdapterError::Validation(format!(
89 "LoRA B size mismatch: expected {} (d_out {} * rank {}), got {}",
90 self.d_out * self.rank,
91 self.d_out,
92 self.rank,
93 self.lora_b.len()
94 )));
95 }
96
97 let mut layer = LoRALayer::new(base_weight, self.d_out, self.d_in, self.rank, self.alpha);
99
100 *layer.lora_a_mut().data_mut() = ndarray::arr1(&self.lora_a);
102 *layer.lora_b_mut().data_mut() = ndarray::arr1(&self.lora_b);
103
104 Ok(layer)
105 }
106
107 pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), AdapterError> {
112 let file = File::create(path)?;
113 let writer = BufWriter::new(file);
114 serde_json::to_writer_pretty(writer, self)?;
115 Ok(())
116 }
117
118 pub fn load<P: AsRef<Path>>(path: P) -> Result<Self, AdapterError> {
123 let file = File::open(path)?;
124 let reader = BufReader::new(file);
125 let adapter: LoRAAdapter = serde_json::from_reader(reader)?;
126
127 if adapter.version != Self::VERSION {
129 return Err(AdapterError::Validation(format!(
130 "Unsupported adapter version: {} (expected {})",
131 adapter.version,
132 Self::VERSION
133 )));
134 }
135
136 Ok(adapter)
137 }
138
139 pub fn metadata(&self) -> AdapterMetadata {
141 AdapterMetadata {
142 version: self.version.clone(),
143 rank: self.rank,
144 alpha: self.alpha,
145 d_out: self.d_out,
146 d_in: self.d_in,
147 scale: self.scale,
148 num_params: self.lora_a.len() + self.lora_b.len(),
149 }
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Builds a self-consistent adapter (rank 4, d_out 8, d_in 16) with constant factor values.
    fn make_test_adapter() -> LoRAAdapter {
        LoRAAdapter {
            version: LoRAAdapter::VERSION.to_string(),
            rank: 4,
            alpha: 8.0,
            d_out: 8,
            d_in: 16,
            scale: 2.0,
            lora_a: vec![0.1; 4 * 16], // rank * d_in
            lora_b: vec![0.2; 8 * 4],  // d_out * rank
        }
    }

    #[test]
    fn test_adapter_from_layer() {
        let base_weight = Tensor::zeros(8 * 16, false);
        let layer = LoRALayer::new(base_weight, 8, 16, 4, 8.0);
        let adapter = LoRAAdapter::from_layer(&layer, 4, 8.0);
        assert_eq!(adapter.version, LoRAAdapter::VERSION);
        assert_eq!(adapter.rank, 4);
        assert_eq!(adapter.alpha, 8.0);
        assert_eq!(adapter.d_out, 8);
        assert_eq!(adapter.d_in, 16);
    }

    #[test]
    fn test_adapter_to_layer_valid() {
        let adapter = make_test_adapter();
        let base_weight = Tensor::zeros(8 * 16, false);
        let layer = adapter.to_layer(base_weight).expect("operation should succeed");
        assert_eq!(layer.d_out(), 8);
        assert_eq!(layer.d_in(), 16);
    }

    #[test]
    fn test_adapter_to_layer_preserves_factors() {
        // to_layer must install the stored factors, not keep the layer's initial ones.
        let adapter = make_test_adapter();
        let layer = adapter
            .to_layer(Tensor::zeros(8 * 16, false))
            .expect("operation should succeed");
        let restored = LoRAAdapter::from_layer(&layer, adapter.rank, adapter.alpha);
        assert_eq!(restored.d_out, adapter.d_out);
        assert_eq!(restored.d_in, adapter.d_in);
        assert_eq!(restored.lora_a, adapter.lora_a);
        assert_eq!(restored.lora_b, adapter.lora_b);
    }

    #[test]
    fn test_adapter_to_layer_dimension_mismatch() {
        let adapter = make_test_adapter();
        let base_weight = Tensor::zeros(100, false); // expected 8 * 16 = 128
        let result = adapter.to_layer(base_weight);
        assert!(result.is_err());
        match result {
            Err(AdapterError::DimensionMismatch { .. }) => {}
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    #[test]
    fn test_adapter_to_layer_lora_a_mismatch() {
        let mut adapter = make_test_adapter();
        adapter.lora_a = vec![0.1; 10]; // expected rank * d_in = 64
        let base_weight = Tensor::zeros(8 * 16, false);
        let result = adapter.to_layer(base_weight);
        assert!(result.is_err());
        match result {
            Err(AdapterError::Validation(msg)) => {
                assert!(msg.contains("LoRA A size mismatch"));
            }
            _ => panic!("Expected Validation error"),
        }
    }

    #[test]
    fn test_adapter_to_layer_lora_b_mismatch() {
        let mut adapter = make_test_adapter();
        adapter.lora_b = vec![0.2; 10]; // expected d_out * rank = 32
        let base_weight = Tensor::zeros(8 * 16, false);
        let result = adapter.to_layer(base_weight);
        assert!(result.is_err());
        match result {
            Err(AdapterError::Validation(msg)) => {
                assert!(msg.contains("LoRA B size mismatch"));
            }
            _ => panic!("Expected Validation error"),
        }
    }

    #[test]
    fn test_adapter_save_load_roundtrip() {
        let adapter = make_test_adapter();
        let file = NamedTempFile::new().expect("temp file creation should succeed");

        adapter.save(file.path()).expect("save should succeed");
        let loaded = LoRAAdapter::load(file.path()).expect("load should succeed");

        // Compare every field by value, not just the factor lengths.
        assert_eq!(adapter.version, loaded.version);
        assert_eq!(adapter.rank, loaded.rank);
        assert_eq!(adapter.alpha, loaded.alpha);
        assert_eq!(adapter.d_out, loaded.d_out);
        assert_eq!(adapter.d_in, loaded.d_in);
        assert_eq!(adapter.scale, loaded.scale);
        assert_eq!(adapter.lora_a, loaded.lora_a);
        assert_eq!(adapter.lora_b, loaded.lora_b);
    }

    #[test]
    fn test_adapter_load_invalid_version() {
        let mut adapter = make_test_adapter();
        adapter.version = "0.0".to_string();
        let file = NamedTempFile::new().expect("temp file creation should succeed");
        adapter.save(file.path()).expect("save should succeed");

        let result = LoRAAdapter::load(file.path());
        assert!(result.is_err());
        match result {
            Err(AdapterError::Validation(msg)) => {
                assert!(msg.contains("Unsupported adapter version"));
            }
            _ => panic!("Expected Validation error"),
        }
    }

    #[test]
    fn test_adapter_load_nonexistent_file() {
        let result = LoRAAdapter::load("/nonexistent/path/adapter.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_adapter_save_invalid_path() {
        let adapter = make_test_adapter();
        let result = adapter.save("/nonexistent/dir/adapter.json");
        assert!(result.is_err());
    }

    #[test]
    fn test_adapter_metadata() {
        let adapter = make_test_adapter();
        let meta = adapter.metadata();
        assert_eq!(meta.version, LoRAAdapter::VERSION);
        assert_eq!(meta.rank, 4);
        assert_eq!(meta.alpha, 8.0);
        assert_eq!(meta.d_out, 8);
        assert_eq!(meta.d_in, 16);
        assert_eq!(meta.scale, 2.0);
        assert_eq!(meta.num_params, 4 * 16 + 8 * 4);
    }

    #[test]
    fn test_adapter_clone() {
        let adapter = make_test_adapter();
        let cloned = adapter.clone();
        assert_eq!(adapter.rank, cloned.rank);
        assert_eq!(adapter.lora_a, cloned.lora_a);
        assert_eq!(adapter.lora_b, cloned.lora_b);
    }

    #[test]
    fn test_adapter_debug() {
        let adapter = make_test_adapter();
        let debug = format!("{adapter:?}");
        assert!(debug.contains("LoRAAdapter"));
    }
}