use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
/// Payload-free discriminator identifying which variant a [`ModelDiff`] is.
///
/// Returned by [`ModelDiff::kind`]. Because of `rename_all = "kebab-case"`, the
/// variants serialize as `"lora"`, `"ia3"`, `"full"` and `"in-place-ttt"`. These
/// are the same strings [`ModelDiff`] uses for its `kind` tag.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum DiffKind {
    /// Low-rank adapter diff ([`LoraDelta`]).
    Lora,
    /// Per-parameter scaling diff ([`IA3Delta`]).
    Ia3,
    /// Full-parameter diff ([`FullDelta`]).
    Full,
    /// Step-wise in-place TTT diff ([`InPlaceTttDelta`]).
    InPlaceTtt,
}
/// A single LoRA adapter targeting one weight matrix of one layer.
///
/// The two low-rank factors are stored as flattened `f32` buffers.
/// [`LoraAdapter::validate`] checks that `a` holds `rank * in_dim` elements
/// and `b` holds `out_dim * rank` elements. NOTE(review): the element layout
/// (row- vs column-major) is not established here; confirm against the code
/// that applies the adapter.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LoraAdapter {
    /// Layer identifier; the primary key of the canonical adapter order.
    pub layer_id: u32,
    /// Name of the targeted matrix within the layer (e.g. `"q_proj"`); the
    /// secondary key of the canonical adapter order.
    pub matrix: String,
    /// Rank of the low-rank decomposition.
    pub rank: u32,
    /// Input dimension; `a` must contain `rank * in_dim` values.
    pub in_dim: u32,
    /// Output dimension; `b` must contain `out_dim * rank` values.
    pub out_dim: u32,
    /// First factor, flattened (`rank * in_dim` values).
    pub a: Vec<f32>,
    /// Second factor, flattened (`out_dim * rank` values).
    pub b: Vec<f32>,
}
impl LoraAdapter {
    /// Checks that the flattened factor buffers have the lengths implied by
    /// the adapter's dimensions.
    ///
    /// `a` must hold exactly `rank * in_dim` elements and `b` exactly
    /// `out_dim * rank` elements. The `a` check runs first, so when both are
    /// wrong only the `a` mismatch is reported.
    ///
    /// # Errors
    ///
    /// Returns `pf_core::Error::Integrity` describing the first length
    /// mismatch found.
    pub fn validate(&self) -> pf_core::Result<()> {
        // Compute expected lengths in u64. The product of two u32 values always
        // fits there, whereas `usize` multiplication can overflow on 32-bit
        // targets: it panics in debug builds, and in release builds it wraps,
        // which could let a malformed adapter pass. These fields may come from
        // deserialized (untrusted) data, so they cannot be assumed small.
        let a_expected = u64::from(self.rank) * u64::from(self.in_dim);
        let b_expected = u64::from(self.out_dim) * u64::from(self.rank);
        // `usize` -> `u64` is lossless: `usize` is at most 64 bits on all
        // current Rust targets.
        let a_len = self.a.len() as u64;
        let b_len = self.b.len() as u64;
        if a_len != a_expected {
            return Err(pf_core::Error::Integrity(format!(
                "LoraAdapter L{}/{}: a.len {} ≠ rank·in_dim {}",
                self.layer_id, self.matrix, a_len, a_expected
            )));
        }
        if b_len != b_expected {
            return Err(pf_core::Error::Integrity(format!(
                "LoraAdapter L{}/{}: b.len {} ≠ out_dim·rank {}",
                self.layer_id, self.matrix, b_len, b_expected
            )));
        }
        Ok(())
    }
}
/// A LoRA diff: a collection of adapters, one per targeted (layer, matrix).
///
/// [`LoraDelta::canonicalize`] sorts `adapters` by `(layer_id, matrix)`.
/// Duplicate keys are not rejected by this type or by
/// [`ModelDiff::validate_and_canonicalize`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct LoraDelta {
    /// The adapters making up this diff.
    pub adapters: Vec<LoraAdapter>,
}
impl LoraDelta {
    /// Rearranges `adapters` into canonical order: ascending `layer_id`, then
    /// ascending `matrix` name (byte-wise lexicographic `String` ordering).
    ///
    /// The sort is stable. Adapters that share a `(layer_id, matrix)` key keep
    /// their relative input order, and duplicates are not removed.
    pub fn canonicalize(&mut self) {
        self.adapters.sort_by(|lhs, rhs| {
            lhs.layer_id
                .cmp(&rhs.layer_id)
                .then_with(|| lhs.matrix.cmp(&rhs.matrix))
        });
    }
}
/// An IA3 diff made of per-parameter scaling vectors.
///
/// NOTE(review): the two map levels are presumably keyed by layer and then by
/// module/parameter name. Confirm this against the producer. Both levels are
/// `BTreeMap`s, so keys iterate (and serialize) in sorted order, and no
/// explicit canonicalization is needed.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct IA3Delta {
    /// Nested map of name -> name -> scaling values.
    pub scaling: BTreeMap<String, BTreeMap<String, Vec<f32>>>,
}
/// A full-parameter diff: flattened `f32` buffers keyed by name.
///
/// NOTE(review): keys are presumably parameter names. This block does not
/// establish whether the values are additive deltas or replacement weights;
/// confirm with the consumer. The `BTreeMap` keeps keys in sorted order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FullDelta {
    /// Flattened values per (presumably) parameter name.
    pub params: BTreeMap<String, Vec<f32>>,
}
/// One step of an in-place TTT diff (presumably "test-time training";
/// confirm).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TttStep {
    /// Ordering key; [`ModelDiff::validate_and_canonicalize`] sorts steps by
    /// ascending `step_id`.
    pub step_id: u32,
    /// Flattened `f32` values keyed by name (presumably parameter names;
    /// verify against the producer). Stored in sorted key order.
    pub deltas: BTreeMap<String, Vec<f32>>,
}
/// An in-place TTT diff: a sequence of [`TttStep`]s.
///
/// Canonical order (as produced by [`ModelDiff::validate_and_canonicalize`])
/// is ascending `step_id`. Duplicate step ids are not rejected.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct InPlaceTttDelta {
    /// The steps making up this diff.
    pub steps: Vec<TttStep>,
}
/// A model diff carrying the payload for one of the supported diff kinds.
///
/// Serialized as an internally tagged enum. The payload's fields appear
/// alongside a `"kind"` field set to `"lora"`, `"ia3"`, `"full"` or
/// `"in-place-ttt"`, the same strings [`DiffKind`] serializes to. Use
/// [`ModelDiff::validate_and_canonicalize`] to check LoRA buffer lengths and
/// to put LoRA/TTT payloads into canonical order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum ModelDiff {
    /// Low-rank adapter diff.
    #[serde(rename = "lora")]
    Lora(LoraDelta),
    /// IA3 scaling diff.
    #[serde(rename = "ia3")]
    Ia3(IA3Delta),
    /// Full-parameter diff.
    #[serde(rename = "full")]
    Full(FullDelta),
    /// Step-wise in-place TTT diff.
    #[serde(rename = "in-place-ttt")]
    InPlaceTtt(InPlaceTttDelta),
}
impl ModelDiff {
#[must_use]
pub fn kind(&self) -> DiffKind {
match self {
Self::Lora(_) => DiffKind::Lora,
Self::Ia3(_) => DiffKind::Ia3,
Self::Full(_) => DiffKind::Full,
Self::InPlaceTtt(_) => DiffKind::InPlaceTtt,
}
}
pub fn validate_and_canonicalize(&mut self) -> pf_core::Result<()> {
match self {
Self::Lora(d) => {
d.canonicalize();
for a in &d.adapters {
a.validate()?;
}
}
Self::Ia3(_) | Self::Full(_) => {
}
Self::InPlaceTtt(d) => {
d.steps.sort_by_key(|s| s.step_id);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an adapter whose `a`/`b` buffers have exactly the lengths
    /// `LoraAdapter::validate` expects for the given dimensions.
    fn adapter(
        layer_id: u32,
        matrix: &str,
        rank: u32,
        in_dim: u32,
        out_dim: u32,
    ) -> LoraAdapter {
        LoraAdapter {
            layer_id,
            matrix: matrix.into(),
            rank,
            in_dim,
            out_dim,
            a: vec![0.0; (rank * in_dim) as usize],
            b: vec![0.0; (out_dim * rank) as usize],
        }
    }

    #[test]
    fn lora_validate_accepts_matching_dims() {
        assert!(adapter(0, "q_proj", 4, 8, 16).validate().is_ok());
    }

    #[test]
    fn lora_validate_catches_dim_mismatch() {
        let bad = LoraAdapter {
            layer_id: 0,
            matrix: "q_proj".into(),
            rank: 4,
            in_dim: 8,
            out_dim: 8,
            a: vec![0.0; 4 * 8],
            // `b` should hold out_dim·rank = 32 elements.
            b: vec![0.0; 5],
        };
        assert!(bad.validate().is_err());
    }

    #[test]
    fn lora_validate_catches_a_mismatch() {
        let mut bad = adapter(0, "q_proj", 4, 8, 8);
        bad.a.pop();
        assert!(bad.validate().is_err());
    }

    #[test]
    fn lora_canonicalize_orders_by_layer_then_matrix() {
        let mut d = LoraDelta {
            adapters: vec![
                adapter(1, "v_proj", 2, 4, 4),
                adapter(0, "v_proj", 2, 4, 4),
                adapter(0, "q_proj", 2, 4, 4),
            ],
        };
        d.canonicalize();
        assert_eq!(d.adapters[0].layer_id, 0);
        assert_eq!(d.adapters[0].matrix, "q_proj");
        assert_eq!(d.adapters[1].layer_id, 0);
        assert_eq!(d.adapters[1].matrix, "v_proj");
        assert_eq!(d.adapters[2].layer_id, 1);
    }

    #[test]
    fn kind_discriminator_matches_variant() {
        let lora = ModelDiff::Lora(LoraDelta { adapters: vec![] });
        assert_eq!(lora.kind(), DiffKind::Lora);
        let ia3 = ModelDiff::Ia3(IA3Delta {
            scaling: BTreeMap::new(),
        });
        assert_eq!(ia3.kind(), DiffKind::Ia3);
        let full = ModelDiff::Full(FullDelta {
            params: BTreeMap::new(),
        });
        assert_eq!(full.kind(), DiffKind::Full);
        let ttt = ModelDiff::InPlaceTtt(InPlaceTttDelta { steps: vec![] });
        assert_eq!(ttt.kind(), DiffKind::InPlaceTtt);
    }

    #[test]
    fn ttt_canonicalize_sorts_by_step_id() {
        let mut d = ModelDiff::InPlaceTtt(InPlaceTttDelta {
            steps: vec![
                TttStep {
                    step_id: 5,
                    deltas: BTreeMap::new(),
                },
                TttStep {
                    step_id: 1,
                    deltas: BTreeMap::new(),
                },
                TttStep {
                    step_id: 3,
                    deltas: BTreeMap::new(),
                },
            ],
        });
        d.validate_and_canonicalize().unwrap();
        if let ModelDiff::InPlaceTtt(t) = d {
            let ids: Vec<_> = t.steps.iter().map(|s| s.step_id).collect();
            assert_eq!(ids, vec![1, 3, 5]);
        } else {
            panic!("variant changed");
        }
    }

    #[test]
    fn validate_and_canonicalize_sorts_valid_lora() {
        let mut d = ModelDiff::Lora(LoraDelta {
            adapters: vec![adapter(1, "q_proj", 2, 4, 4), adapter(0, "v_proj", 2, 4, 4)],
        });
        d.validate_and_canonicalize().unwrap();
        if let ModelDiff::Lora(l) = d {
            let keys: Vec<_> = l
                .adapters
                .iter()
                .map(|a| (a.layer_id, a.matrix.as_str()))
                .collect();
            assert_eq!(keys, vec![(0, "v_proj"), (1, "q_proj")]);
        } else {
            panic!("variant changed");
        }
    }

    #[test]
    fn validate_and_canonicalize_rejects_malformed_lora() {
        let mut bad = adapter(0, "q_proj", 2, 4, 4);
        bad.b.clear();
        let mut d = ModelDiff::Lora(LoraDelta {
            adapters: vec![adapter(1, "v_proj", 2, 4, 4), bad],
        });
        assert!(d.validate_and_canonicalize().is_err());
    }

    #[test]
    fn validate_and_canonicalize_leaves_map_based_diffs_untouched() {
        let mut inner = BTreeMap::new();
        inner.insert("scale".to_string(), vec![1.0, 2.0]);
        let mut scaling = BTreeMap::new();
        scaling.insert("layer0".to_string(), inner);
        let mut params = BTreeMap::new();
        params.insert("w".to_string(), vec![0.5]);
        let diffs = vec![
            ModelDiff::Ia3(IA3Delta { scaling }),
            ModelDiff::Full(FullDelta { params }),
        ];
        for original in diffs {
            let mut d = original.clone();
            d.validate_and_canonicalize().unwrap();
            assert_eq!(d, original);
        }
    }
}