entrenar/lora/adapter/
io.rs1use super::error::AdapterError;
4use super::lora_adapter::LoRAAdapter;
5use super::peft_export::PeftAdapterBundle;
6use crate::lora::{LoRAConfig, LoRALayer};
7use crate::Tensor;
8use std::path::Path;
9
/// On-disk representation used when persisting a LoRA adapter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AdapterFormat {
    /// Entrenar's own adapter file, as produced by [`LoRAAdapter::save`]
    /// (presumably JSON, going by the variant name).
    EntrenarJson,
    /// PEFT directory layout: `adapter_config.json` alongside
    /// `adapter_model.safetensors`.
    Peft,
}
18
19pub fn save_adapter<P: AsRef<Path>>(
27 layer: &LoRALayer,
28 rank: usize,
29 alpha: f32,
30 path: P,
31) -> Result<(), AdapterError> {
32 let adapter = LoRAAdapter::from_layer(layer, rank, alpha);
33 adapter.save(path)
34}
35
36pub fn load_adapter<P: AsRef<Path>>(
42 base_weight: Tensor,
43 path: P,
44) -> Result<LoRALayer, AdapterError> {
45 let adapter = LoRAAdapter::load(path)?;
46 adapter.to_layer(base_weight)
47}
48
49pub fn save_adapter_peft<P: AsRef<Path>>(
57 adapters: &[(&str, &LoRALayer)],
58 config: &LoRAConfig,
59 base_model: Option<&str>,
60 output_dir: P,
61) -> Result<(), AdapterError> {
62 let mut bundle = PeftAdapterBundle::new(config.clone());
63 if let Some(name) = base_model {
64 bundle = bundle.with_base_model(name);
65 }
66 for (path, layer) in adapters {
67 bundle.add_adapter(*path, layer);
68 }
69 bundle.save_peft(output_dir)
70}
71
72pub fn load_adapter_peft<P: AsRef<Path>>(
77 dir: P,
78) -> Result<(super::peft_config::PeftAdapterConfig, Vec<(String, Vec<f32>)>), AdapterError> {
79 let dir = dir.as_ref();
80
81 let config_path = dir.join("adapter_config.json");
83 let config_str = std::fs::read_to_string(&config_path)?;
84 let config = super::peft_config::PeftAdapterConfig::from_json(&config_str)
85 .map_err(|e| AdapterError::PeftFormatError(format!("Invalid adapter_config.json: {e}")))?;
86
87 let model_path = dir.join("adapter_model.safetensors");
89 let model_data = std::fs::read(&model_path)?;
90 let tensors = safetensors::SafeTensors::deserialize(&model_data).map_err(|e| {
91 AdapterError::SafeTensors(format!("Failed to load adapter_model.safetensors: {e}"))
92 })?;
93
94 let mut weights = Vec::new();
95 for name in tensors.names() {
96 let tensor = tensors.tensor(name).map_err(|e| {
97 AdapterError::SafeTensors(format!("Failed to read tensor '{name}': {e}"))
98 })?;
99 let data: Vec<f32> = bytemuck::cast_slice::<u8, f32>(tensor.data()).to_vec();
100 weights.push((name.to_string(), data));
101 }
102
103 Ok((config, weights))
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Equality on `AdapterFormat` distinguishes the two variants.
    #[test]
    fn test_adapter_format_eq() {
        let (json, peft) = (AdapterFormat::EntrenarJson, AdapterFormat::Peft);

        assert_eq!(json, AdapterFormat::EntrenarJson);
        assert_eq!(peft, AdapterFormat::Peft);
        assert_ne!(json, peft);
    }

    /// Saving one layer in PEFT format and loading it back preserves the
    /// config's rank and alpha and yields two tensors.
    #[test]
    fn test_save_load_peft_roundtrip() {
        let lora_config = LoRAConfig::new(4, 8.0).target_qv_projections();

        let base_weight = Tensor::zeros(8 * 16, false);
        let layer = LoRALayer::new(base_weight, 8, 16, 4, 8.0);

        let out_dir = TempDir::new().expect("temp file creation should succeed");
        let adapters = [("model.layers.0.self_attn.q_proj", &layer)];
        save_adapter_peft(&adapters, &lora_config, Some("test/model"), out_dir.path())
            .expect("operation should succeed");

        let loaded = load_adapter_peft(out_dir.path()).expect("load should succeed");
        let (loaded_config, weights) = loaded;

        assert_eq!(loaded_config.r, 4);
        assert_eq!(loaded_config.lora_alpha, 8.0);
        // Presumably the A and B matrices of the single saved layer.
        assert_eq!(weights.len(), 2);
    }
}