Skip to main content

entrenar/lora/adapter/
io.rs

1//! LoRA adapter I/O convenience functions
2
3use super::error::AdapterError;
4use super::lora_adapter::LoRAAdapter;
5use super::peft_export::PeftAdapterBundle;
6use crate::lora::{LoRAConfig, LoRALayer};
7use crate::Tensor;
8use std::path::Path;
9
/// On-disk layout used when serializing a LoRA adapter.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AdapterFormat {
    /// Entrenar's own JSON representation, holding a single layer.
    EntrenarJson,
    /// HuggingFace PEFT layout: `adapter_config.json` next to `adapter_model.safetensors`.
    Peft,
}
18
19/// Save LoRA adapter to file (Entrenar JSON format)
20///
21/// # Arguments
22/// * `layer` - LoRALayer to save
23/// * `rank` - LoRA rank
24/// * `alpha` - LoRA alpha parameter
25/// * `path` - File path to save to
26pub fn save_adapter<P: AsRef<Path>>(
27    layer: &LoRALayer,
28    rank: usize,
29    alpha: f32,
30    path: P,
31) -> Result<(), AdapterError> {
32    let adapter = LoRAAdapter::from_layer(layer, rank, alpha);
33    adapter.save(path)
34}
35
36/// Load LoRA adapter from file (Entrenar JSON format)
37///
38/// # Arguments
39/// * `base_weight` - Frozen base weight to apply adapter to
40/// * `path` - File path to load from
41pub fn load_adapter<P: AsRef<Path>>(
42    base_weight: Tensor,
43    path: P,
44) -> Result<LoRALayer, AdapterError> {
45    let adapter = LoRAAdapter::load(path)?;
46    adapter.to_layer(base_weight)
47}
48
49/// Save LoRA adapters in PEFT-compatible format
50///
51/// # Arguments
52/// * `adapters` - Layer path to LoRA layer mappings
53/// * `config` - LoRA configuration
54/// * `base_model` - Optional base model name for adapter_config.json
55/// * `output_dir` - Output directory (will contain adapter_config.json + adapter_model.safetensors)
56pub fn save_adapter_peft<P: AsRef<Path>>(
57    adapters: &[(&str, &LoRALayer)],
58    config: &LoRAConfig,
59    base_model: Option<&str>,
60    output_dir: P,
61) -> Result<(), AdapterError> {
62    let mut bundle = PeftAdapterBundle::new(config.clone());
63    if let Some(name) = base_model {
64        bundle = bundle.with_base_model(name);
65    }
66    for (path, layer) in adapters {
67        bundle.add_adapter(*path, layer);
68    }
69    bundle.save_peft(output_dir)
70}
71
72/// Load LoRA adapter from PEFT-compatible format
73///
74/// Reads `adapter_config.json` and `adapter_model.safetensors` from the given directory
75/// and returns the adapter configuration along with tensor name → weight data.
76pub fn load_adapter_peft<P: AsRef<Path>>(
77    dir: P,
78) -> Result<(super::peft_config::PeftAdapterConfig, Vec<(String, Vec<f32>)>), AdapterError> {
79    let dir = dir.as_ref();
80
81    // Read adapter_config.json
82    let config_path = dir.join("adapter_config.json");
83    let config_str = std::fs::read_to_string(&config_path)?;
84    let config = super::peft_config::PeftAdapterConfig::from_json(&config_str)
85        .map_err(|e| AdapterError::PeftFormatError(format!("Invalid adapter_config.json: {e}")))?;
86
87    // Read adapter_model.safetensors
88    let model_path = dir.join("adapter_model.safetensors");
89    let model_data = std::fs::read(&model_path)?;
90    let tensors = safetensors::SafeTensors::deserialize(&model_data).map_err(|e| {
91        AdapterError::SafeTensors(format!("Failed to load adapter_model.safetensors: {e}"))
92    })?;
93
94    let mut weights = Vec::new();
95    for name in tensors.names() {
96        let tensor = tensors.tensor(name).map_err(|e| {
97            AdapterError::SafeTensors(format!("Failed to read tensor '{name}': {e}"))
98        })?;
99        let data: Vec<f32> = bytemuck::cast_slice::<u8, f32>(tensor.data()).to_vec();
100        weights.push((name.to_string(), data));
101    }
102
103    Ok((config, weights))
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Equality on `AdapterFormat` must tell the two variants apart.
    #[test]
    fn test_adapter_format_eq() {
        let json = AdapterFormat::EntrenarJson;
        let peft = AdapterFormat::Peft;
        assert_eq!(json, AdapterFormat::EntrenarJson);
        assert_eq!(peft, AdapterFormat::Peft);
        assert_ne!(json, peft);
    }

    /// Saving one layer in PEFT format and loading it back should preserve
    /// rank/alpha and yield exactly the two LoRA factor tensors.
    #[test]
    fn test_save_load_peft_roundtrip() {
        let config = LoRAConfig::new(4, 8.0).target_qv_projections();
        let layer = LoRALayer::new(Tensor::zeros(8 * 16, false), 8, 16, 4, 8.0);
        let adapters = [("model.layers.0.self_attn.q_proj", &layer)];

        let out_dir = TempDir::new().expect("temp file creation should succeed");
        save_adapter_peft(&adapters, &config, Some("test/model"), out_dir.path())
            .expect("operation should succeed");

        let (loaded_config, weights) =
            load_adapter_peft(out_dir.path()).expect("load should succeed");
        assert_eq!(loaded_config.r, 4);
        assert_eq!(loaded_config.lora_alpha, 8.0);
        // One lora_A and one lora_B tensor for the single adapted layer.
        assert_eq!(weights.len(), 2);
    }
}