Skip to main content

pf_model/
diff.rs

1// SPDX-License-Identifier: MIT
2//! Typed weight-diff payloads for the four supported diff kinds.
3
4use serde::{Deserialize, Serialize};
5use std::collections::BTreeMap;
6
7/// Discriminator tag — useful for API consumers that just want the kind.
8#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
9#[serde(rename_all = "kebab-case")]
10pub enum DiffKind {
11    /// Low-rank adapters (LoRA).
12    Lora,
13    /// IA³ per-head scaling vectors.
14    Ia3,
15    /// Dense full-finetune delta.
16    Full,
17    /// In-place test-time training trace.
18    InPlaceTtt,
19}
20
21/// One LoRA adapter for one matrix in one layer.
22#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
23pub struct LoraAdapter {
24    /// Layer index.
25    pub layer_id: u32,
26    /// Which matrix this adapter targets (e.g. `"q_proj"`, `"v_proj"`).
27    pub matrix: String,
28    /// Adapter rank (= shared inner dim of A and B).
29    pub rank: u32,
30    /// Input dimension (= columns of A, rows of the original matrix).
31    pub in_dim: u32,
32    /// Output dimension (= rows of B, rows of the original matrix).
33    pub out_dim: u32,
34    /// `A` matrix, shape `[rank, in_dim]`, row-major.
35    pub a: Vec<f32>,
36    /// `B` matrix, shape `[out_dim, rank]`, row-major.
37    pub b: Vec<f32>,
38}
39
40impl LoraAdapter {
41    /// Verify the declared dimensions match the supplied vectors. Cheap;
42    /// always called by `store_diff` before sealing.
43    pub fn validate(&self) -> pf_core::Result<()> {
44        let a_expected = (self.rank as usize) * (self.in_dim as usize);
45        let b_expected = (self.out_dim as usize) * (self.rank as usize);
46        if self.a.len() != a_expected {
47            return Err(pf_core::Error::Integrity(format!(
48                "LoraAdapter L{}/{}: a.len {} ≠ rank·in_dim {}",
49                self.layer_id,
50                self.matrix,
51                self.a.len(),
52                a_expected
53            )));
54        }
55        if self.b.len() != b_expected {
56            return Err(pf_core::Error::Integrity(format!(
57                "LoraAdapter L{}/{}: b.len {} ≠ out_dim·rank {}",
58                self.layer_id,
59                self.matrix,
60                self.b.len(),
61                b_expected
62            )));
63        }
64        Ok(())
65    }
66}
67
68/// LoRA diff: a list of per-matrix adapters.
69#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
70pub struct LoraDelta {
71    /// Adapters sorted by `(layer_id, matrix)` for deterministic digests.
72    pub adapters: Vec<LoraAdapter>,
73}
74
75impl LoraDelta {
76    /// Sort adapters into canonical order so the diff's serialized digest
77    /// is invariant w.r.t. caller iteration order.
78    pub fn canonicalize(&mut self) {
79        self.adapters
80            .sort_by(|a, b| (a.layer_id, &a.matrix).cmp(&(b.layer_id, &b.matrix)));
81    }
82}
83
84/// IA³ diff: per-layer per-matrix scaling vector.
85///
86/// Outer key: layer id, encoded as a base-10 string (e.g. `"0"`, `"31"`).
87/// Stored as a `String` because JSON object keys are always strings; using
88/// `String` here keeps the wire format trivially round-trippable.
89/// Inner key: matrix name.
90/// Value: scaling vector (length = head_dim).
91#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
92pub struct IA3Delta {
93    pub scaling: BTreeMap<String, BTreeMap<String, Vec<f32>>>,
94}
95
96/// Full-finetune diff: dense per-parameter delta tensors.
97#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
98pub struct FullDelta {
99    /// Map from canonical parameter name to its dense delta.
100    pub params: BTreeMap<String, Vec<f32>>,
101}
102
103/// One in-place TTT step.
104#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
105pub struct TttStep {
106    /// Step counter (0-based).
107    pub step_id: u32,
108    /// Per-parameter delta applied at this step.
109    pub deltas: BTreeMap<String, Vec<f32>>,
110}
111
112/// In-place TTT diff: an ordered trace of training steps.
113#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
114pub struct InPlaceTttDelta {
115    /// Steps in causal order.
116    pub steps: Vec<TttStep>,
117}
118
119/// Top-level diff payload. Wire format `model.diff.v1`.
120#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
121#[serde(tag = "kind")]
122pub enum ModelDiff {
123    /// LoRA / low-rank adapters.
124    #[serde(rename = "lora")]
125    Lora(LoraDelta),
126    /// IA³ scaling vectors.
127    #[serde(rename = "ia3")]
128    Ia3(IA3Delta),
129    /// Dense full-finetune delta.
130    #[serde(rename = "full")]
131    Full(FullDelta),
132    /// In-place test-time training trace.
133    #[serde(rename = "in-place-ttt")]
134    InPlaceTtt(InPlaceTttDelta),
135}
136
137impl ModelDiff {
138    /// Discriminator.
139    #[must_use]
140    pub fn kind(&self) -> DiffKind {
141        match self {
142            Self::Lora(_) => DiffKind::Lora,
143            Self::Ia3(_) => DiffKind::Ia3,
144            Self::Full(_) => DiffKind::Full,
145            Self::InPlaceTtt(_) => DiffKind::InPlaceTtt,
146        }
147    }
148
149    /// Validate internal invariants and put the diff into canonical order.
150    pub fn validate_and_canonicalize(&mut self) -> pf_core::Result<()> {
151        match self {
152            Self::Lora(d) => {
153                d.canonicalize();
154                for a in &d.adapters {
155                    a.validate()?;
156                }
157            }
158            Self::Ia3(_) | Self::Full(_) => {
159                // BTreeMap is already canonical.
160            }
161            Self::InPlaceTtt(d) => {
162                d.steps.sort_by_key(|s| s.step_id);
163            }
164        }
165        Ok(())
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Adapter for `(layer_id, matrix)` whose `a`/`b` lengths agree with the
    /// given dimensions, so `validate` accepts it unless a test tampers with it.
    fn adapter(layer_id: u32, matrix: &str, rank: u32, in_dim: u32, out_dim: u32) -> LoraAdapter {
        LoraAdapter {
            layer_id,
            matrix: matrix.into(),
            rank,
            in_dim,
            out_dim,
            a: vec![0.0; (rank * in_dim) as usize],
            b: vec![0.0; (out_dim * rank) as usize],
        }
    }

    /// TTT step carrying no parameter deltas.
    fn step(step_id: u32) -> TttStep {
        TttStep {
            step_id,
            deltas: BTreeMap::new(),
        }
    }

    #[test]
    fn lora_validate_accepts_matching_dims() {
        assert!(adapter(3, "q_proj", 4, 8, 16).validate().is_ok());
        // Rank 0 means both factors are legitimately empty.
        assert!(adapter(0, "q_proj", 0, 8, 16).validate().is_ok());
    }

    #[test]
    fn lora_validate_catches_dim_mismatch() {
        let mut bad = adapter(0, "q_proj", 4, 8, 8);
        bad.b = vec![0.0; 5]; // wrong: out_dim·rank = 32
        assert!(bad.validate().is_err());
    }

    #[test]
    fn lora_validate_catches_a_mismatch() {
        let mut bad = adapter(0, "q_proj", 4, 8, 8);
        bad.a.pop(); // wrong: rank·in_dim = 32
        assert!(bad.validate().is_err());
    }

    #[test]
    fn lora_canonicalize_orders_by_layer_then_matrix() {
        let mut d = LoraDelta {
            adapters: vec![
                adapter(1, "v_proj", 2, 4, 4),
                adapter(0, "v_proj", 2, 4, 4),
                adapter(0, "q_proj", 2, 4, 4),
            ],
        };
        d.canonicalize();
        let order: Vec<(u32, &str)> = d
            .adapters
            .iter()
            .map(|a| (a.layer_id, a.matrix.as_str()))
            .collect();
        assert_eq!(order, vec![(0, "q_proj"), (0, "v_proj"), (1, "v_proj")]);
    }

    #[test]
    fn lora_validate_and_canonicalize_sorts_then_checks_dims() {
        let mut good = ModelDiff::Lora(LoraDelta {
            adapters: vec![adapter(2, "v_proj", 2, 4, 4), adapter(0, "q_proj", 2, 4, 4)],
        });
        good.validate_and_canonicalize().unwrap();
        if let ModelDiff::Lora(d) = &good {
            assert_eq!(d.adapters[0].layer_id, 0);
            assert_eq!(d.adapters[1].layer_id, 2);
        } else {
            panic!("variant changed");
        }

        let mut broken = adapter(0, "q_proj", 2, 4, 4);
        broken.b.clear();
        let mut bad = ModelDiff::Lora(LoraDelta {
            adapters: vec![broken],
        });
        assert!(bad.validate_and_canonicalize().is_err());
    }

    #[test]
    fn kind_discriminator_matches_variant() {
        let lora = ModelDiff::Lora(LoraDelta { adapters: vec![] });
        assert_eq!(lora.kind(), DiffKind::Lora);
        let ia3 = ModelDiff::Ia3(IA3Delta {
            scaling: BTreeMap::new(),
        });
        assert_eq!(ia3.kind(), DiffKind::Ia3);
        let full = ModelDiff::Full(FullDelta {
            params: BTreeMap::new(),
        });
        assert_eq!(full.kind(), DiffKind::Full);
        let ttt = ModelDiff::InPlaceTtt(InPlaceTttDelta { steps: vec![] });
        assert_eq!(ttt.kind(), DiffKind::InPlaceTtt);
    }

    #[test]
    fn map_backed_variants_are_left_unchanged() {
        let mut layer = BTreeMap::new();
        layer.insert("k_proj".to_string(), vec![1.0, 2.0]);
        let mut scaling = BTreeMap::new();
        scaling.insert("0".to_string(), layer.clone());
        scaling.insert("31".to_string(), layer);
        let ia3 = ModelDiff::Ia3(IA3Delta { scaling });
        let mut canon = ia3.clone();
        canon.validate_and_canonicalize().unwrap();
        assert_eq!(canon, ia3);

        let mut params = BTreeMap::new();
        params.insert("lm_head.weight".to_string(), vec![0.5]);
        let full = ModelDiff::Full(FullDelta { params });
        let mut canon = full.clone();
        canon.validate_and_canonicalize().unwrap();
        assert_eq!(canon, full);
    }

    #[test]
    fn ttt_canonicalize_sorts_by_step_id() {
        let mut d = ModelDiff::InPlaceTtt(InPlaceTttDelta {
            steps: vec![step(5), step(1), step(3)],
        });
        d.validate_and_canonicalize().unwrap();
        if let ModelDiff::InPlaceTtt(t) = d {
            let ids: Vec<_> = t.steps.iter().map(|s| s.step_id).collect();
            assert_eq!(ids, vec![1, 3, 5]);
        } else {
            panic!("variant changed");
        }
    }
}