1use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6
7#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "kebab-case")]
10pub enum DiffKind {
11 Lora,
13 Ia3,
15 Full,
17 InPlaceTtt,
19}
20
21#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
23pub struct LoraAdapter {
24 pub layer_id: u32,
26 pub matrix: String,
28 pub rank: u32,
30 pub in_dim: u32,
32 pub out_dim: u32,
34 pub a: Vec<f32>,
36 pub b: Vec<f32>,
38}
39
40impl LoraAdapter {
41 pub fn validate(&self) -> pf_core::Result<()> {
44 let a_expected = (self.rank as usize) * (self.in_dim as usize);
45 let b_expected = (self.out_dim as usize) * (self.rank as usize);
46 if self.a.len() != a_expected {
47 return Err(pf_core::Error::Integrity(format!(
48 "LoraAdapter L{}/{}: a.len {} ≠ rank·in_dim {}",
49 self.layer_id,
50 self.matrix,
51 self.a.len(),
52 a_expected
53 )));
54 }
55 if self.b.len() != b_expected {
56 return Err(pf_core::Error::Integrity(format!(
57 "LoraAdapter L{}/{}: b.len {} ≠ out_dim·rank {}",
58 self.layer_id,
59 self.matrix,
60 self.b.len(),
61 b_expected
62 )));
63 }
64 Ok(())
65 }
66}
67
68#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
70pub struct LoraDelta {
71 pub adapters: Vec<LoraAdapter>,
73}
74
75impl LoraDelta {
76 pub fn canonicalize(&mut self) {
79 self.adapters
80 .sort_by(|a, b| (a.layer_id, &a.matrix).cmp(&(b.layer_id, &b.matrix)));
81 }
82}
83
84#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
92pub struct IA3Delta {
93 pub scaling: BTreeMap<String, BTreeMap<String, Vec<f32>>>,
94}
95
96#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
98pub struct FullDelta {
99 pub params: BTreeMap<String, Vec<f32>>,
101}
102
103#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
105pub struct TttStep {
106 pub step_id: u32,
108 pub deltas: BTreeMap<String, Vec<f32>>,
110}
111
112#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
114pub struct InPlaceTttDelta {
115 pub steps: Vec<TttStep>,
117}
118
119#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
121#[serde(tag = "kind")]
122pub enum ModelDiff {
123 #[serde(rename = "lora")]
125 Lora(LoraDelta),
126 #[serde(rename = "ia3")]
128 Ia3(IA3Delta),
129 #[serde(rename = "full")]
131 Full(FullDelta),
132 #[serde(rename = "in-place-ttt")]
134 InPlaceTtt(InPlaceTttDelta),
135}
136
137impl ModelDiff {
138 #[must_use]
140 pub fn kind(&self) -> DiffKind {
141 match self {
142 Self::Lora(_) => DiffKind::Lora,
143 Self::Ia3(_) => DiffKind::Ia3,
144 Self::Full(_) => DiffKind::Full,
145 Self::InPlaceTtt(_) => DiffKind::InPlaceTtt,
146 }
147 }
148
149 pub fn validate_and_canonicalize(&mut self) -> pf_core::Result<()> {
151 match self {
152 Self::Lora(d) => {
153 d.canonicalize();
154 for a in &d.adapters {
155 a.validate()?;
156 }
157 }
158 Self::Ia3(_) | Self::Full(_) => {
159 }
161 Self::InPlaceTtt(d) => {
162 d.steps.sort_by_key(|s| s.step_id);
163 }
164 }
165 Ok(())
166 }
167}
168
#[cfg(test)]
mod tests {
    //! Unit tests for adapter validation, canonical ordering and the
    //! variant-to-kind mapping.

    use super::*;

    /// Builds an internally consistent adapter (rank 2, 4×4, zero-filled buffers)
    /// so each test only has to spell out the field it cares about.
    fn adapter(layer_id: u32, matrix: &str) -> LoraAdapter {
        LoraAdapter {
            layer_id,
            matrix: matrix.into(),
            rank: 2,
            in_dim: 4,
            out_dim: 4,
            a: vec![0.0; 8],
            b: vec![0.0; 8],
        }
    }

    #[test]
    fn lora_validate_accepts_matching_dims() {
        // Guards against a validator that rejects everything.
        assert!(adapter(0, "q_proj").validate().is_ok());
    }

    #[test]
    fn lora_validate_catches_dim_mismatch() {
        let bad_b = LoraAdapter {
            rank: 4,
            in_dim: 8,
            out_dim: 8,
            a: vec![0.0; 4 * 8],
            b: vec![0.0; 5],
            ..adapter(0, "q_proj")
        };
        assert!(bad_b.validate().is_err());
    }

    #[test]
    fn lora_validate_catches_a_mismatch() {
        let bad_a = LoraAdapter {
            a: vec![0.0; 7],
            ..adapter(0, "q_proj")
        };
        assert!(bad_a.validate().is_err());
    }

    #[test]
    fn lora_canonicalize_orders_by_layer_then_matrix() {
        let mut d = LoraDelta {
            adapters: vec![
                adapter(1, "v_proj"),
                adapter(0, "v_proj"),
                adapter(0, "q_proj"),
            ],
        };
        d.canonicalize();
        let order: Vec<(u32, &str)> = d
            .adapters
            .iter()
            .map(|a| (a.layer_id, a.matrix.as_str()))
            .collect();
        assert_eq!(order, vec![(0, "q_proj"), (0, "v_proj"), (1, "v_proj")]);
    }

    #[test]
    fn kind_discriminator_matches_variant() {
        let lora = ModelDiff::Lora(LoraDelta { adapters: vec![] });
        assert_eq!(lora.kind(), DiffKind::Lora);
        let ia3 = ModelDiff::Ia3(IA3Delta {
            scaling: BTreeMap::new(),
        });
        assert_eq!(ia3.kind(), DiffKind::Ia3);
        let full = ModelDiff::Full(FullDelta {
            params: BTreeMap::new(),
        });
        assert_eq!(full.kind(), DiffKind::Full);
        let ttt = ModelDiff::InPlaceTtt(InPlaceTttDelta { steps: vec![] });
        assert_eq!(ttt.kind(), DiffKind::InPlaceTtt);
    }

    #[test]
    fn validate_and_canonicalize_propagates_lora_errors() {
        let broken = LoraAdapter {
            b: vec![0.0; 5],
            ..adapter(0, "q_proj")
        };
        let mut d = ModelDiff::Lora(LoraDelta {
            adapters: vec![adapter(1, "q_proj"), broken],
        });
        assert!(d.validate_and_canonicalize().is_err());
    }

    #[test]
    fn ttt_canonicalize_sorts_by_step_id() {
        let mut d = ModelDiff::InPlaceTtt(InPlaceTttDelta {
            steps: [5, 1, 3]
                .iter()
                .map(|&step_id| TttStep {
                    step_id,
                    deltas: BTreeMap::new(),
                })
                .collect(),
        });
        d.validate_and_canonicalize().unwrap();
        match d {
            ModelDiff::InPlaceTtt(t) => {
                let ids: Vec<_> = t.steps.iter().map(|s| s.step_id).collect();
                assert_eq!(ids, vec![1, 3, 5]);
            }
            _ => panic!("variant changed"),
        }
    }
}