Skip to main content

peft_rs/
io.rs

//! I/O utilities for saving and loading adapter weights and configurations.
//!
//! This module provides functionality for:
//! - Saving adapter weights to safetensors format
//! - Loading adapter weights from safetensors format
//! - Saving adapter configurations to JSON
//! - Loading adapter configurations from JSON
8
9use std::collections::HashMap;
10use std::fs;
11use std::path::Path;
12
13use candle_core::{Device, Tensor};
14use serde::{de::DeserializeOwned, Serialize};
15
16use crate::error::{PeftError, Result};
17
18/// Trait for adapters that can be saved and loaded.
19pub trait SaveLoad {
20    /// Get all adapter tensors as a map of name -> tensor.
21    ///
22    /// # Errors
23    ///
24    /// Returns an error if tensor retrieval fails.
25    fn state_dict(&self) -> Result<HashMap<String, Tensor>>;
26
27    /// Load adapter tensors from a state dict.
28    ///
29    /// # Errors
30    ///
31    /// Returns an error if tensor loading fails.
32    fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()>;
33}
34
35/// Save adapter weights to a safetensors file.
36///
37/// # Arguments
38/// * `adapter` - The adapter implementing `SaveLoad` trait
39/// * `path` - Path to save the safetensors file
40///
41/// # Errors
42/// Returns an error if:
43/// - Failed to get state dict from adapter
44/// - Failed to serialize tensors to safetensors format
45/// - Failed to write file to disk
46pub fn save_adapter_weights<P: AsRef<Path>>(adapter: &dyn SaveLoad, path: P) -> Result<()> {
47    let state_dict = adapter.state_dict()?;
48
49    // Convert HashMap to Vec for safetensors
50    let tensors: Vec<(&str, Tensor)> = state_dict
51        .iter()
52        .map(|(name, tensor)| (name.as_str(), tensor.clone()))
53        .collect();
54
55    // Use candle's built-in safetensors serialization
56    safetensors::tensor::serialize_to_file(tensors, None, path.as_ref())
57        .map_err(|e| PeftError::Io(format!("Failed to save safetensors: {e}")))?;
58
59    Ok(())
60}
61
62/// Load adapter weights from a safetensors file.
63///
64/// # Arguments
65/// * `adapter` - The adapter to load weights into
66/// * `path` - Path to the safetensors file
67/// * `device` - Device to load tensors on
68///
69/// # Errors
70/// Returns an error if:
71/// - Failed to read file from disk
72/// - Failed to parse safetensors format
73/// - Failed to load tensors into adapter
74pub fn load_adapter_weights<P: AsRef<Path>>(
75    adapter: &mut dyn SaveLoad,
76    path: P,
77    device: &Device,
78) -> Result<()> {
79    // Use candle's built-in safetensors loading
80    let tensors = candle_core::safetensors::load(path.as_ref(), device)?;
81
82    // Load into adapter
83    adapter.load_state_dict(tensors)?;
84
85    Ok(())
86}
87
88/// Save adapter configuration to a JSON file.
89///
90/// # Arguments
91/// * `config` - The configuration to save
92/// * `path` - Path to save the JSON file
93///
94/// # Errors
95/// Returns an error if serialization or file writing fails
96pub fn save_adapter_config<T: Serialize, P: AsRef<Path>>(config: &T, path: P) -> Result<()> {
97    let json = serde_json::to_string_pretty(config)
98        .map_err(|e| PeftError::Io(format!("Failed to serialize config: {e}")))?;
99
100    fs::write(path, json)
101        .map_err(|e| PeftError::Io(format!("Failed to write config file: {e}")))?;
102
103    Ok(())
104}
105
106/// Load adapter configuration from a JSON file.
107///
108/// # Arguments
109/// * `path` - Path to the JSON file
110///
111/// # Errors
112/// Returns an error if file reading or deserialization fails
113pub fn load_adapter_config<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> Result<T> {
114    let json = fs::read_to_string(path)
115        .map_err(|e| PeftError::Io(format!("Failed to read config file: {e}")))?;
116
117    let config = serde_json::from_str(&json)
118        .map_err(|e| PeftError::Io(format!("Failed to parse config: {e}")))?;
119
120    Ok(config)
121}
122
/// Default filename for adapter weights in `HuggingFace` PEFT format.
///
/// `save_pretrained` and `load_pretrained` join this name onto the adapter directory.
pub const ADAPTER_WEIGHTS_FILENAME: &str = "adapter_model.safetensors";

/// Default filename for adapter config in `HuggingFace` PEFT format.
///
/// `save_pretrained` and `load_pretrained` join this name onto the adapter directory.
pub const ADAPTER_CONFIG_FILENAME: &str = "adapter_config.json";
128
129/// Save adapter weights and configuration to a directory in `HuggingFace` PEFT format.
130///
131/// Creates the directory if it doesn't exist. Saves:
132/// - `adapter_model.safetensors` - Adapter weights
133/// - `adapter_config.json` - Adapter configuration
134///
135/// # Arguments
136/// * `adapter` - The adapter implementing `SaveLoad` trait
137/// * `config` - The adapter configuration
138/// * `dir` - Directory path to save to
139///
140/// # Errors
141/// Returns an error if:
142/// - Failed to create directory
143/// - Failed to save weights or config
144///
145/// # Example
146/// ```rust,ignore
147/// use peft_rs::{save_pretrained, LoraConfig, LoraLayer};
148///
149/// let adapter = LoraLayer::new_with_zeros(768, 768, config, &device)?;
150/// save_pretrained(&adapter, &config, "path/to/adapter")?;
151/// ```
152pub fn save_pretrained<T: Serialize, P: AsRef<Path>>(
153    adapter: &dyn SaveLoad,
154    config: &T,
155    dir: P,
156) -> Result<()> {
157    let dir = dir.as_ref();
158
159    // Create directory if it doesn't exist
160    if !dir.exists() {
161        fs::create_dir_all(dir)
162            .map_err(|e| PeftError::Io(format!("Failed to create directory: {e}")))?;
163    }
164
165    // Save weights
166    let weights_path = dir.join(ADAPTER_WEIGHTS_FILENAME);
167    save_adapter_weights(adapter, &weights_path)?;
168
169    // Save config
170    let config_path = dir.join(ADAPTER_CONFIG_FILENAME);
171    save_adapter_config(config, &config_path)?;
172
173    Ok(())
174}
175
176/// Load adapter weights and configuration from a directory in `HuggingFace` PEFT format.
177///
178/// Expects:
179/// - `adapter_model.safetensors` - Adapter weights
180/// - `adapter_config.json` - Adapter configuration
181///
182/// # Arguments
183/// * `adapter` - The adapter to load weights into
184/// * `dir` - Directory path to load from
185/// * `device` - Device to load tensors on
186///
187/// # Returns
188/// The loaded adapter configuration
189///
190/// # Errors
191/// Returns an error if:
192/// - Directory doesn't exist
193/// - Failed to load weights or config
194///
195/// # Example
196/// ```rust,ignore
197/// use peft_rs::{load_pretrained, LoraConfig, LoraLayer};
198///
199/// let mut adapter = LoraLayer::new_with_zeros(768, 768, LoraConfig::default(), &device)?;
200/// let config: LoraConfig = load_pretrained(&mut adapter, "path/to/adapter", &device)?;
201/// ```
202pub fn load_pretrained<T: DeserializeOwned, P: AsRef<Path>>(
203    adapter: &mut dyn SaveLoad,
204    dir: P,
205    device: &Device,
206) -> Result<T> {
207    let dir = dir.as_ref();
208
209    if !dir.exists() {
210        return Err(PeftError::Io(format!(
211            "Directory does not exist: {}",
212            dir.display()
213        )));
214    }
215
216    // Load weights
217    let weights_path = dir.join(ADAPTER_WEIGHTS_FILENAME);
218    load_adapter_weights(adapter, &weights_path, device)?;
219
220    // Load config
221    let config_path = dir.join(ADAPTER_CONFIG_FILENAME);
222    let config: T = load_adapter_config(&config_path)?;
223
224    Ok(config)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// In-memory `SaveLoad` implementation that simply stores the state dict it is given.
    struct MockAdapter {
        weights: HashMap<String, Tensor>,
    }

    impl SaveLoad for MockAdapter {
        fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
            Ok(self.weights.clone())
        }

        fn load_state_dict(&mut self, state_dict: HashMap<String, Tensor>) -> Result<()> {
            self.weights = state_dict;
            Ok(())
        }
    }

    /// Builds the two random f32 tensors (`lora_a`: 64x8, `lora_b`: 8x64) shared by the
    /// round-trip tests.
    fn sample_weights(device: &Device) -> anyhow::Result<HashMap<String, Tensor>> {
        let mut weights = HashMap::new();
        weights.insert(
            "lora_a".to_string(),
            Tensor::randn(0f32, 1f32, (64, 8), device)?,
        );
        weights.insert(
            "lora_b".to_string(),
            Tensor::randn(0f32, 1f32, (8, 64), device)?,
        );
        Ok(weights)
    }

    /// Asserts that `actual` holds exactly the tensors in `expected`: same names, shapes,
    /// dtypes and element values.
    ///
    /// Values are compared element by element rather than through an aggregate such as
    /// the sum, so permuted or partially corrupted data cannot pass unnoticed.
    /// Assumes every tensor is rank-2 `f32`, which holds for `sample_weights`.
    fn assert_same_weights(
        expected: &HashMap<String, Tensor>,
        actual: &HashMap<String, Tensor>,
    ) -> anyhow::Result<()> {
        assert_eq!(actual.len(), expected.len(), "tensor count mismatch");

        for (name, want) in expected {
            let got = actual
                .get(name)
                .unwrap_or_else(|| panic!("tensor `{name}` is missing after loading"));

            assert_eq!(got.dims(), want.dims(), "shape mismatch for `{name}`");
            assert_eq!(got.dtype(), want.dtype(), "dtype mismatch for `{name}`");
            assert_eq!(
                got.to_vec2::<f32>()?,
                want.to_vec2::<f32>()?,
                "value mismatch for `{name}`"
            );
        }

        Ok(())
    }

    #[test]
    fn test_save_load_adapter_weights() -> anyhow::Result<()> {
        let device = Device::Cpu;
        let temp_dir = TempDir::new()?;
        let weights_path = temp_dir.path().join("adapter.safetensors");

        let adapter = MockAdapter {
            weights: sample_weights(&device)?,
        };

        // Save weights
        save_adapter_weights(&adapter, &weights_path)?;
        assert!(weights_path.exists());

        // Load weights into a fresh, empty adapter
        let mut loaded_adapter = MockAdapter {
            weights: HashMap::new(),
        };
        load_adapter_weights(&mut loaded_adapter, &weights_path, &device)?;

        // Names, shapes and values must all survive the round trip
        assert!(loaded_adapter.weights.contains_key("lora_a"));
        assert!(loaded_adapter.weights.contains_key("lora_b"));
        assert_same_weights(&adapter.weights, &loaded_adapter.weights)
    }

    #[test]
    fn test_save_load_config() -> anyhow::Result<()> {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct TestConfig {
            r: usize,
            alpha: usize,
            dropout: f64,
        }

        let temp_dir = TempDir::new()?;
        let config_path = temp_dir.path().join("config.json");

        let config = TestConfig {
            r: 8,
            alpha: 16,
            dropout: 0.1,
        };

        // Save config
        save_adapter_config(&config, &config_path)?;
        assert!(config_path.exists());

        // Load config
        let loaded_config: TestConfig = load_adapter_config(&config_path)?;
        assert_eq!(config, loaded_config);

        Ok(())
    }

    #[test]
    fn test_save_load_pretrained() -> anyhow::Result<()> {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
        struct TestConfig {
            r: usize,
            alpha: usize,
        }

        let device = Device::Cpu;
        let temp_dir = TempDir::new()?;

        let adapter = MockAdapter {
            weights: sample_weights(&device)?,
        };
        let config = TestConfig { r: 8, alpha: 16 };

        // Save pretrained
        save_pretrained(&adapter, &config, temp_dir.path())?;

        // Verify both files were written under their default names
        assert!(temp_dir.path().join(ADAPTER_WEIGHTS_FILENAME).exists());
        assert!(temp_dir.path().join(ADAPTER_CONFIG_FILENAME).exists());

        // Load pretrained
        let mut loaded_adapter = MockAdapter {
            weights: HashMap::new(),
        };
        let loaded_config: TestConfig =
            load_pretrained(&mut loaded_adapter, temp_dir.path(), &device)?;

        // Verify config
        assert_eq!(config, loaded_config);

        // Verify weights, including their values
        assert_same_weights(&adapter.weights, &loaded_adapter.weights)
    }

    #[test]
    fn test_load_pretrained_missing_dir() -> anyhow::Result<()> {
        let temp_dir = TempDir::new()?;
        let missing = temp_dir.path().join("does_not_exist");

        let mut adapter = MockAdapter {
            weights: HashMap::new(),
        };
        let result: Result<serde_json::Value> =
            load_pretrained(&mut adapter, &missing, &Device::Cpu);

        // A missing directory is an error and must leave the adapter untouched
        assert!(result.is_err());
        assert!(adapter.weights.is_empty());

        Ok(())
    }
}