use super::error::AdapterError;
use super::lora_adapter::LoRAAdapter;
use super::peft_export::PeftAdapterBundle;
use crate::lora::{LoRAConfig, LoRALayer};
use crate::Tensor;
use std::path::Path;
/// On-disk serialization format for a LoRA adapter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AdapterFormat {
    /// Entrenar's native single-file JSON format (see [`save_adapter`]).
    EntrenarJson,
    /// HuggingFace PEFT directory layout: `adapter_config.json` plus
    /// `adapter_model.safetensors` (see [`save_adapter_peft`]).
    Peft,
}
/// Serializes a LoRA layer to the Entrenar JSON adapter format at `path`.
///
/// The layer is snapshotted into a [`LoRAAdapter`] together with the
/// `rank` and `alpha` hyperparameters before being written out.
///
/// # Errors
/// Returns an [`AdapterError`] if the adapter cannot be serialized or
/// written to disk.
pub fn save_adapter<P: AsRef<Path>>(
    layer: &LoRALayer,
    rank: usize,
    alpha: f32,
    path: P,
) -> Result<(), AdapterError> {
    LoRAAdapter::from_layer(layer, rank, alpha).save(path)
}
/// Loads a LoRA adapter from `path` and reconstitutes it as a
/// [`LoRALayer`] on top of the supplied frozen `base_weight`.
///
/// # Errors
/// Returns an [`AdapterError`] if the file cannot be read, parsed, or
/// converted into a layer.
pub fn load_adapter<P: AsRef<Path>>(
    base_weight: Tensor,
    path: P,
) -> Result<LoRALayer, AdapterError> {
    LoRAAdapter::load(path)?.to_layer(base_weight)
}
/// Writes a set of named LoRA layers to `output_dir` in the HuggingFace
/// PEFT format (`adapter_config.json` + `adapter_model.safetensors`).
///
/// `adapters` pairs each module path (e.g.
/// `"model.layers.0.self_attn.q_proj"`) with the layer to export;
/// `base_model` optionally records the base model identifier in the
/// emitted config.
///
/// # Errors
/// Returns an [`AdapterError`] if the bundle cannot be written.
pub fn save_adapter_peft<P: AsRef<Path>>(
    adapters: &[(&str, &LoRALayer)],
    config: &LoRAConfig,
    base_model: Option<&str>,
    output_dir: P,
) -> Result<(), AdapterError> {
    let mut bundle = match base_model {
        Some(name) => PeftAdapterBundle::new(config.clone()).with_base_model(name),
        None => PeftAdapterBundle::new(config.clone()),
    };
    for &(module_path, layer) in adapters {
        bundle.add_adapter(module_path, layer);
    }
    bundle.save_peft(output_dir)
}
/// Loads a PEFT-format adapter from `dir`, returning the parsed
/// `adapter_config.json` together with every tensor in
/// `adapter_model.safetensors` as `(name, f32 values)` pairs.
///
/// Tensors are decoded from safetensors' little-endian byte layout with
/// `f32::from_le_bytes` rather than `bytemuck::cast_slice`:
/// `TensorView::data()` borrows the raw file buffer at an arbitrary
/// offset, so it carries no 4-byte alignment guarantee and a direct
/// slice cast could panic (and would also be wrong on big-endian hosts).
///
/// # Errors
/// Returns an [`AdapterError`] if either file is missing or unreadable,
/// the config JSON is invalid, the safetensors payload is malformed, or
/// any tensor has a dtype other than `F32`.
pub fn load_adapter_peft<P: AsRef<Path>>(
    dir: P,
) -> Result<(super::peft_config::PeftAdapterConfig, Vec<(String, Vec<f32>)>), AdapterError> {
    let dir = dir.as_ref();

    // Parse adapter_config.json.
    let config_path = dir.join("adapter_config.json");
    let config_str = std::fs::read_to_string(&config_path)?;
    let config = super::peft_config::PeftAdapterConfig::from_json(&config_str)
        .map_err(|e| AdapterError::PeftFormatError(format!("Invalid adapter_config.json: {e}")))?;

    // Parse adapter_model.safetensors.
    let model_path = dir.join("adapter_model.safetensors");
    let model_data = std::fs::read(&model_path)?;
    let tensors = safetensors::SafeTensors::deserialize(&model_data).map_err(|e| {
        AdapterError::SafeTensors(format!("Failed to load adapter_model.safetensors: {e}"))
    })?;

    let names = tensors.names();
    let mut weights = Vec::with_capacity(names.len());
    for name in names {
        let tensor = tensors.tensor(name).map_err(|e| {
            AdapterError::SafeTensors(format!("Failed to read tensor '{name}': {e}"))
        })?;
        // Reject non-f32 tensors up front instead of silently
        // reinterpreting their bytes as f32 garbage.
        if tensor.dtype() != safetensors::Dtype::F32 {
            return Err(AdapterError::SafeTensors(format!(
                "Tensor '{name}' has dtype {:?}, expected F32",
                tensor.dtype()
            )));
        }
        // Alignment- and endian-safe decode of the raw little-endian bytes.
        let data: Vec<f32> = tensor
            .data()
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();
        weights.push((name.to_string(), data));
    }
    Ok((config, weights))
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Format variants compare by identity: equal to themselves,
    /// unequal to each other.
    #[test]
    fn test_adapter_format_eq() {
        assert_eq!(AdapterFormat::EntrenarJson, AdapterFormat::EntrenarJson);
        assert_eq!(AdapterFormat::Peft, AdapterFormat::Peft);
        assert_ne!(AdapterFormat::Peft, AdapterFormat::EntrenarJson);
    }

    /// Saving a one-layer PEFT bundle and loading it back preserves the
    /// LoRA hyperparameters and yields two weight tensors (A and B).
    #[test]
    fn test_save_load_peft_roundtrip() {
        let config = LoRAConfig::new(4, 8.0).target_qv_projections();
        let layer = LoRALayer::new(Tensor::zeros(8 * 16, false), 8, 16, 4, 8.0);
        let dir = TempDir::new().expect("temp file creation should succeed");

        save_adapter_peft(
            &[("model.layers.0.self_attn.q_proj", &layer)],
            &config,
            Some("test/model"),
            dir.path(),
        )
        .expect("operation should succeed");

        let (loaded_config, weights) =
            load_adapter_peft(dir.path()).expect("load should succeed");
        assert_eq!(loaded_config.r, 4);
        assert_eq!(loaded_config.lora_alpha, 8.0);
        assert_eq!(weights.len(), 2);
    }
}