//! Management of multiple named LoRA layers behind a single adapter.

use indexmap::IndexMap;

use super::config::LoraConfig;
use super::error::{LoraError, LoraResult};
use super::layer::LoraLayer;

/// Per-layer statistics reported by [`LoraAdapter::summary`].
#[derive(Debug, Clone)]
pub struct LayerStats {
    pub name: String,
    /// Number of rows (`d`) of the base weight matrix.
    pub d: usize,
    /// Number of columns (`k`) of the base weight matrix.
    pub k: usize,
    pub rank: usize,
    pub trainable_params: usize,
    pub total_params: usize,
    pub compression_ratio: f64,
    pub merged: bool,
}

/// Aggregate statistics across all layers in a [`LoraAdapter`].
#[derive(Debug, Clone)]
pub struct LoraAdapterSummary {
    pub layers: Vec<LayerStats>,
    pub total_trainable: usize,
    pub total_params: usize,
}

/// A set of named LoRA layers sharing one [`LoraConfig`], kept in
/// insertion order by the backing [`IndexMap`].
pub struct LoraAdapter {
    config: LoraConfig,
    layers: IndexMap<String, LoraLayer>,
}

impl LoraAdapter {
    /// Creates an empty adapter; `config` is applied to every layer added to it.
    pub fn new(config: LoraConfig) -> Self {
        Self {
            config,
            layers: IndexMap::new(),
        }
    }

    /// Builds a [`LoraLayer`] around `base_weight` and registers it under `name`.
    ///
    /// Re-using an existing name replaces the previous layer.
    pub fn add_layer(&mut self, name: &str, base_weight: Vec<Vec<f64>>) -> LoraResult<()> {
        let layer = LoraLayer::new(base_weight, self.config.clone())?;
        self.layers.insert(name.to_string(), layer);
        Ok(())
    }

    /// Runs `input` through the LoRA layer registered under `name`.
    pub fn forward(&mut self, name: &str, input: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
        // Unknown layer names are reported through the existing
        // `DimensionMismatch` variant rather than a dedicated error.
        let layer = self
            .layers
            .get_mut(name)
            .ok_or_else(|| LoraError::DimensionMismatch {
                expected: format!("layer '{name}' exists"),
                got: "not found".into(),
            })?;
        layer.forward(input)
    }

    /// Merges the low-rank update into the base weights of every layer
    /// that is not already merged.
    pub fn merge_all(&mut self) -> LoraResult<()> {
        for layer in self.layers.values_mut() {
            if !layer.merged {
                layer.merge()?;
            }
        }
        Ok(())
    }

    /// Restores the original base weights of every merged layer.
    pub fn unmerge_all(&mut self) -> LoraResult<()> {
        for layer in self.layers.values_mut() {
            if layer.merged {
                layer.unmerge()?;
            }
        }
        Ok(())
    }

    /// Total number of trainable (LoRA) parameters across all layers.
    pub fn total_trainable_params(&self) -> usize {
        self.layers.values().map(|l| l.trainable_params()).sum()
    }

    /// Collects per-layer and aggregate statistics for reporting.
    pub fn summary(&self) -> LoraAdapterSummary {
        let mut layers = Vec::with_capacity(self.layers.len());
        for (name, layer) in &self.layers {
            let d = layer.base_weight.len();
            // Guard against an empty weight matrix instead of indexing row 0 directly.
            let k = layer.base_weight.first().map_or(0, |row| row.len());
            layers.push(LayerStats {
                name: name.clone(),
                d,
                k,
                rank: layer.config.rank,
                trainable_params: layer.trainable_params(),
                total_params: layer.total_params(),
                compression_ratio: layer.compression_ratio(),
                merged: layer.merged,
            });
        }
        let total_trainable = layers.iter().map(|s| s.trainable_params).sum();
        let total_params = layers.iter().map(|s| s.total_params).sum();
        LoraAdapterSummary {
            layers,
            total_trainable,
            total_params,
        }
    }
}
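
// A minimal usage sketch, kept as a test so it compiles alongside the module.
// It makes a few assumptions that are not guaranteed by this file: that
// `LoraConfig` provides a `Default` implementation, that `LoraError` derives
// `Debug` (needed for a `Result`-returning test), and that `LoraLayer::forward`
// takes row-major inputs with `k` columns and returns one output row per input
// row. The layer name "q_proj" is purely illustrative; adjust the config
// construction to the real `LoraConfig` API if it differs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn adapter_round_trip() -> LoraResult<()> {
        // Hypothetical constructor: replace with the real way to build a config.
        let config = LoraConfig::default();
        let mut adapter = LoraAdapter::new(config);

        // A 2x3 base weight for a single named projection layer (d = 2, k = 3).
        let base = vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]];
        adapter.add_layer("q_proj", base)?;

        // Forward one input row with k = 3 features through the layer.
        let output = adapter.forward("q_proj", &[vec![1.0, 0.0, -1.0]])?;
        assert_eq!(output.len(), 1);

        // Merging folds the low-rank update into the base weights; unmerging
        // restores them, so a merge/unmerge cycle should round-trip cleanly.
        adapter.merge_all()?;
        adapter.unmerge_all()?;

        let summary = adapter.summary();
        assert_eq!(summary.layers.len(), 1);
        assert_eq!(summary.total_trainable, adapter.total_trainable_params());

        Ok(())
    }
}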