Skip to main content

entrenar/lora/adapter/
peft_export.rs

1//! PEFT-compatible adapter export (adapter_model.safetensors + adapter_config.json)
2//!
3//! Produces output compatible with `peft.PeftModel.from_pretrained()`.
4
5use super::error::AdapterError;
6use super::peft_config::PeftAdapterConfig;
7use crate::lora::LoRAConfig;
8use crate::lora::LoRALayer;
9use safetensors::tensor::{Dtype, TensorView};
10use std::collections::HashMap;
11use std::path::Path;
12
/// A bundle of LoRA adapters keyed by layer path
///
/// Collects multiple LoRA layer adapters and exports them in PEFT format.
///
/// Adapters are stored in insertion order. Entries are appended as-is, so the
/// bundle itself does not de-duplicate layer paths.
pub struct PeftAdapterBundle {
    /// Adapters keyed by layer path (e.g., "model.layers.0.self_attn.q_proj"),
    /// kept as `(layer_path, weights)` pairs in the order they were added
    adapters: Vec<(String, AdapterWeights)>,
    /// LoRA configuration used to generate `adapter_config.json`
    config: LoRAConfig,
    /// Base model name (for adapter_config.json); `None` until
    /// `with_base_model` is called
    base_model: Option<String>,
}
24
/// Extracted adapter weights for a single layer
///
/// The matrices are held as flat `f32` buffers; the dimensions below are what
/// the exporter uses to declare each tensor's shape.
struct AdapterWeights {
    /// LoRA A matrix, flattened; exported with shape [rank, d_in]
    lora_a: Vec<f32>,
    /// LoRA B matrix, flattened; exported with shape [d_out, rank]
    lora_b: Vec<f32>,
    /// LoRA rank
    rank: usize,
    /// Input dimension
    d_in: usize,
    /// Output dimension
    d_out: usize,
}
38
39impl PeftAdapterBundle {
40    /// Create a new bundle with the given LoRA config
41    pub fn new(config: LoRAConfig) -> Self {
42        Self { adapters: Vec::new(), config, base_model: None }
43    }
44
45    /// Set the base model name
46    pub fn with_base_model(mut self, name: impl Into<String>) -> Self {
47        self.base_model = Some(name.into());
48        self
49    }
50
51    /// Add a LoRA layer adapter with its full layer path
52    ///
53    /// The layer path should follow the model's naming convention, e.g.:
54    /// `"model.layers.0.self_attn.q_proj"`
55    pub fn add_adapter(&mut self, layer_path: impl Into<String>, layer: &LoRALayer) {
56        let weights = AdapterWeights {
57            lora_a: layer.lora_a().data().to_vec(),
58            lora_b: layer.lora_b().data().to_vec(),
59            rank: layer.rank(),
60            d_in: layer.d_in(),
61            d_out: layer.d_out(),
62        };
63        self.adapters.push((layer_path.into(), weights));
64    }
65
66    /// Add raw LoRA weights (for GPU pipeline where LoRALayer isn't available)
67    pub fn add_raw_adapter(
68        &mut self,
69        layer_path: impl Into<String>,
70        lora_a: Vec<f32>,
71        lora_b: Vec<f32>,
72        rank: usize,
73        d_in: usize,
74        d_out: usize,
75    ) {
76        self.adapters
77            .push((layer_path.into(), AdapterWeights { lora_a, lora_b, rank, d_in, d_out }));
78    }
79
80    /// Save PEFT-compatible adapter to output directory
81    ///
82    /// Creates:
83    /// - `adapter_config.json` — PEFT configuration
84    /// - `adapter_model.safetensors` — adapter weights in PEFT naming convention
85    pub fn save_peft(&self, output_dir: impl AsRef<Path>) -> Result<(), AdapterError> {
86        let output_dir = output_dir.as_ref();
87        std::fs::create_dir_all(output_dir)?;
88
89        // Write adapter_config.json
90        let peft_config =
91            PeftAdapterConfig::from_lora_config(&self.config, self.base_model.as_deref());
92        let config_json = peft_config.to_json().map_err(AdapterError::Serialization)?;
93        std::fs::write(output_dir.join("adapter_config.json"), config_json)?;
94
95        // Build tensor data for safetensors
96        // PEFT naming convention: "base_model.model.{layer_path}.lora_A.weight" / "lora_B.weight"
97        let mut tensor_data: Vec<(String, Vec<u8>, Vec<usize>)> = Vec::new();
98
99        for (layer_path, weights) in &self.adapters {
100            // LoRA A: [rank, d_in]
101            let a_name = format!("base_model.model.{layer_path}.lora_A.weight");
102            let a_bytes: Vec<u8> = bytemuck::cast_slice(&weights.lora_a).to_vec();
103            let a_shape = vec![weights.rank, weights.d_in];
104            tensor_data.push((a_name, a_bytes, a_shape));
105
106            // LoRA B: [d_out, rank]
107            let b_name = format!("base_model.model.{layer_path}.lora_B.weight");
108            let b_bytes: Vec<u8> = bytemuck::cast_slice(&weights.lora_b).to_vec();
109            let b_shape = vec![weights.d_out, weights.rank];
110            tensor_data.push((b_name, b_bytes, b_shape));
111        }
112
113        // Create TensorViews
114        let views: Vec<(&str, TensorView<'_>)> = tensor_data
115            .iter()
116            .map(|(name, bytes, shape)| {
117                let view = TensorView::new(Dtype::F32, shape.clone(), bytes)
118                    .expect("TensorView construction must not fail for valid F32 data");
119                (name.as_str(), view)
120            })
121            .collect();
122
123        // Metadata
124        let mut metadata = HashMap::new();
125        metadata.insert("format".to_string(), "pt".to_string());
126
127        let safetensor_bytes = safetensors::serialize(views, Some(metadata)).map_err(|e| {
128            AdapterError::SafeTensors(format!("SafeTensors serialization failed: {e}"))
129        })?;
130
131        std::fs::write(output_dir.join("adapter_model.safetensors"), safetensor_bytes)?;
132
133        Ok(())
134    }
135
136    /// Number of adapter layers in the bundle
137    pub fn len(&self) -> usize {
138        self.adapters.len()
139    }
140
141    /// Check if bundle is empty
142    pub fn is_empty(&self) -> bool {
143        self.adapters.is_empty()
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lora::LoRALayer;
    use crate::Tensor;
    use tempfile::TempDir;

    /// Layer path used by the single-layer tests.
    const Q_PROJ_PATH: &str = "model.layers.0.self_attn.q_proj";

    /// Build a LoRA layer over an all-zero base weight (alpha fixed at 16.0).
    fn make_test_layer(d_out: usize, d_in: usize, rank: usize) -> LoRALayer {
        LoRALayer::new(Tensor::zeros(d_out * d_in, false), d_out, d_in, rank, 16.0)
    }

    /// Export `bundle` into a fresh temporary directory and hand the directory back.
    fn save_to_temp_dir(bundle: &PeftAdapterBundle) -> TempDir {
        let dir = TempDir::new().expect("temp file creation should succeed");
        bundle.save_peft(dir.path()).expect("save should succeed");
        dir
    }

    /// Read the raw bytes of the exported safetensors file.
    fn read_safetensors_bytes(dir: &TempDir) -> Vec<u8> {
        std::fs::read(dir.path().join("adapter_model.safetensors"))
            .expect("file read should succeed")
    }

    /// Assert that both PEFT output files are present in `dir`.
    fn assert_peft_files_exist(dir: &TempDir) {
        assert!(dir.path().join("adapter_config.json").exists());
        assert!(dir.path().join("adapter_model.safetensors").exists());
    }

    #[test]
    fn test_bundle_creation() {
        let bundle = PeftAdapterBundle::new(LoRAConfig::new(8, 16.0).target_qv_projections());

        assert!(bundle.is_empty());
        assert_eq!(bundle.len(), 0);
    }

    #[test]
    fn test_add_adapter() {
        let mut bundle = PeftAdapterBundle::new(LoRAConfig::new(8, 16.0).target_qv_projections());

        bundle.add_adapter(Q_PROJ_PATH, &make_test_layer(64, 64, 8));

        assert_eq!(bundle.len(), 1);
        assert!(!bundle.is_empty());
    }

    #[test]
    fn test_save_peft_creates_files() {
        let mut bundle = PeftAdapterBundle::new(LoRAConfig::new(4, 8.0).target_qv_projections())
            .with_base_model("meta-llama/Llama-2-7b");
        bundle.add_adapter(Q_PROJ_PATH, &make_test_layer(16, 16, 4));

        let out = save_to_temp_dir(&bundle);

        // Verify files exist
        assert_peft_files_exist(&out);
    }

    #[test]
    fn test_save_peft_config_content() {
        let bundle =
            PeftAdapterBundle::new(LoRAConfig::new(16, 32.0).target_attention_projections())
                .with_base_model("test/model");

        let out = save_to_temp_dir(&bundle);

        let raw_json = std::fs::read_to_string(out.path().join("adapter_config.json"))
            .expect("file read should succeed");
        let parsed: PeftAdapterConfig =
            serde_json::from_str(&raw_json).expect("JSON deserialization should succeed");

        assert_eq!(parsed.peft_type, "LORA");
        assert_eq!(parsed.r, 16);
        assert_eq!(parsed.lora_alpha, 32.0);
        assert_eq!(parsed.base_model_name_or_path, Some("test/model".to_string()));
    }

    #[test]
    fn test_save_peft_safetensors_content() {
        let a_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight";
        let b_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight";

        let mut bundle = PeftAdapterBundle::new(LoRAConfig::new(4, 8.0).target_qv_projections());
        bundle.add_adapter(Q_PROJ_PATH, &make_test_layer(8, 16, 4));

        // Load and verify safetensors
        let out = save_to_temp_dir(&bundle);
        let bytes = read_safetensors_bytes(&out);
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");

        let names = loaded.names();
        assert!(names.contains(&a_name));
        assert!(names.contains(&b_name));

        // Check shapes
        let lora_a = loaded.tensor(a_name).expect("operation should succeed");
        assert_eq!(lora_a.shape(), &[4, 16]); // [rank, d_in]

        let lora_b = loaded.tensor(b_name).expect("operation should succeed");
        assert_eq!(lora_b.shape(), &[8, 4]); // [d_out, rank]
    }

    #[test]
    fn test_save_peft_multiple_layers() {
        let mut bundle = PeftAdapterBundle::new(LoRAConfig::new(4, 8.0).target_qv_projections());

        (0..3).for_each(|i| {
            let layer = make_test_layer(8, 8, 4);
            bundle.add_adapter(format!("model.layers.{i}.self_attn.q_proj"), &layer);
        });
        assert_eq!(bundle.len(), 3);

        let out = save_to_temp_dir(&bundle);
        let bytes = read_safetensors_bytes(&out);
        let loaded = safetensors::SafeTensors::deserialize(&bytes).expect("load should succeed");

        // 3 layers * 2 matrices (A + B) = 6 tensors
        assert_eq!(loaded.len(), 6);
    }

    #[test]
    fn test_save_peft_empty_bundle() {
        let bundle = PeftAdapterBundle::new(LoRAConfig::new(4, 8.0));

        let out = save_to_temp_dir(&bundle);

        // Should still create both files
        assert_peft_files_exist(&out);
    }
}