Skip to main content

pf_model/
serialize.rs

1// SPDX-License-Identifier: MIT
//! Round-trip every [`ModelDiff`] variant through a [`pf_core::cas::BlobStore`].
//!
//! Wire format `model.diff.v1`: a single JSON blob holding an envelope object
//! `{ "layout": "model.diff.v1", "diff": <ModelDiff> }`. The `diff` value uses
//! serde's internally tagged enum representation (`#[serde(tag = "kind")]` on
//! [`ModelDiff`]), so its `kind` field is the variant discriminator.
//! Validation and canonicalization happen on store; the layout-version check
//! happens on load.
8
9use crate::diff::ModelDiff;
10use pf_core::cas::BlobStore;
11use pf_core::digest::Digest256;
12use serde::{Deserialize, Serialize};
13
14const LAYOUT: &str = "model.diff.v1";
15
16#[derive(Serialize, Deserialize)]
17struct Envelope {
18    layout: String,
19    diff: ModelDiff,
20}
21
22/// Validate, canonicalize, and persist a [`ModelDiff`] into `blobs`. Returns
23/// the digest of the resulting blob.
24pub fn store_diff(blobs: &dyn BlobStore, mut diff: ModelDiff) -> pf_core::Result<Digest256> {
25    diff.validate_and_canonicalize()?;
26    let env = Envelope {
27        layout: LAYOUT.into(),
28        diff,
29    };
30    blobs.put(&serde_json::to_vec(&env)?)
31}
32
33/// Load a [`ModelDiff`] previously written by [`store_diff`].
34pub fn load_diff(blobs: &dyn BlobStore, digest: &Digest256) -> pf_core::Result<ModelDiff> {
35    let bytes = blobs.get(digest)?;
36    let env: Envelope = serde_json::from_slice(&bytes)?;
37    if env.layout != LAYOUT {
38        return Err(pf_core::Error::Integrity(format!(
39            "expected layout {LAYOUT}, got {}",
40            env.layout
41        )));
42    }
43    Ok(env.diff)
44}
45
46#[cfg(test)]
47mod tests {
48    use super::*;
49    use crate::diff::{FullDelta, IA3Delta, InPlaceTttDelta, LoraAdapter, LoraDelta, TttStep};
50    use pf_core::cas::MemBlobStore;
51    use std::collections::BTreeMap;
52
53    #[test]
54    fn lora_round_trip() {
55        let blobs = MemBlobStore::new();
56        let d = ModelDiff::Lora(LoraDelta {
57            adapters: vec![LoraAdapter {
58                layer_id: 0,
59                matrix: "q_proj".into(),
60                rank: 2,
61                in_dim: 4,
62                out_dim: 4,
63                a: vec![1.0; 8],
64                b: vec![2.0; 8],
65            }],
66        });
67        let cid = store_diff(&blobs, d.clone()).unwrap();
68        let back = load_diff(&blobs, &cid).unwrap();
69        assert_eq!(back, d);
70    }
71
72    #[test]
73    fn ia3_round_trip() {
74        let blobs = MemBlobStore::new();
75        let mut s = BTreeMap::new();
76        let mut inner = BTreeMap::new();
77        inner.insert("k_proj".to_owned(), vec![0.5_f32, 1.5_f32]);
78        s.insert("0".to_owned(), inner);
79        let d = ModelDiff::Ia3(IA3Delta { scaling: s });
80        let cid = store_diff(&blobs, d.clone()).unwrap();
81        let back = load_diff(&blobs, &cid).unwrap();
82        assert_eq!(back, d);
83    }
84
85    #[test]
86    fn full_round_trip() {
87        let blobs = MemBlobStore::new();
88        let mut p = BTreeMap::new();
89        p.insert("layer_0/q_proj".to_owned(), vec![0.1_f32, 0.2_f32, 0.3_f32]);
90        let d = ModelDiff::Full(FullDelta { params: p });
91        let cid = store_diff(&blobs, d.clone()).unwrap();
92        let back = load_diff(&blobs, &cid).unwrap();
93        assert_eq!(back, d);
94    }
95
96    #[test]
97    fn ttt_round_trip_in_canonical_order() {
98        let blobs = MemBlobStore::new();
99        let mut step_a = TttStep {
100            step_id: 2,
101            deltas: BTreeMap::new(),
102        };
103        step_a.deltas.insert("x".into(), vec![1.0]);
104        let step_b = TttStep {
105            step_id: 1,
106            deltas: BTreeMap::new(),
107        };
108        let d = ModelDiff::InPlaceTtt(InPlaceTttDelta {
109            steps: vec![step_a, step_b],
110        });
111        let cid = store_diff(&blobs, d).unwrap();
112        let back = load_diff(&blobs, &cid).unwrap();
113        if let ModelDiff::InPlaceTtt(t) = back {
114            assert_eq!(t.steps[0].step_id, 1, "canonicalize sorted by step_id");
115            assert_eq!(t.steps[1].step_id, 2);
116        } else {
117            panic!("variant changed");
118        }
119    }
120
121    #[test]
122    fn rejects_wrong_layout_on_load() {
123        let blobs = MemBlobStore::new();
124        let bogus = serde_json::json!({
125            "layout": "model.diff.v9",
126            "diff": { "kind": "lora", "adapters": [] }
127        });
128        let cid = blobs.put(&serde_json::to_vec(&bogus).unwrap()).unwrap();
129        let err = load_diff(&blobs, &cid).unwrap_err();
130        assert!(matches!(err, pf_core::Error::Integrity(_)));
131    }
132
133    #[test]
134    fn rejects_lora_with_wrong_dims_on_store() {
135        let blobs = MemBlobStore::new();
136        let d = ModelDiff::Lora(LoraDelta {
137            adapters: vec![LoraAdapter {
138                layer_id: 0,
139                matrix: "q".into(),
140                rank: 2,
141                in_dim: 4,
142                out_dim: 4,
143                a: vec![0.0; 8],
144                b: vec![0.0; 7], // wrong
145            }],
146        });
147        assert!(store_diff(&blobs, d).is_err());
148    }
149}