1use crate::diff::ModelDiff;
10use pf_core::cas::BlobStore;
11use pf_core::digest::Digest256;
12use serde::{Deserialize, Serialize};
13
/// Versioned layout tag written into every stored envelope; `load_diff`
/// refuses blobs whose tag differs, so bumping this string invalidates
/// previously stored encodings.
const LAYOUT: &str = "model.diff.v1";
15
/// JSON wire wrapper pairing a `ModelDiff` with its layout/version tag.
///
/// NOTE(review): field names are part of the stored JSON format — renaming
/// them would break compatibility with existing blobs.
#[derive(Serialize, Deserialize)]
struct Envelope {
    // Expected to equal `LAYOUT`; verified in `load_diff`.
    layout: String,
    // The canonicalized diff payload.
    diff: ModelDiff,
}
21
22pub fn store_diff(blobs: &dyn BlobStore, mut diff: ModelDiff) -> pf_core::Result<Digest256> {
25 diff.validate_and_canonicalize()?;
26 let env = Envelope {
27 layout: LAYOUT.into(),
28 diff,
29 };
30 blobs.put(&serde_json::to_vec(&env)?)
31}
32
33pub fn load_diff(blobs: &dyn BlobStore, digest: &Digest256) -> pf_core::Result<ModelDiff> {
35 let bytes = blobs.get(digest)?;
36 let env: Envelope = serde_json::from_slice(&bytes)?;
37 if env.layout != LAYOUT {
38 return Err(pf_core::Error::Integrity(format!(
39 "expected layout {LAYOUT}, got {}",
40 env.layout
41 )));
42 }
43 Ok(env.diff)
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diff::{FullDelta, IA3Delta, InPlaceTttDelta, LoraAdapter, LoraDelta, TttStep};
    use pf_core::cas::MemBlobStore;
    use std::collections::BTreeMap;

    /// Stores `diff` in a fresh in-memory store, loads it back by digest,
    /// and asserts the round trip is lossless.
    fn assert_round_trip(diff: ModelDiff) {
        let store = MemBlobStore::new();
        let digest = store_diff(&store, diff.clone()).unwrap();
        let reloaded = load_diff(&store, &digest).unwrap();
        assert_eq!(reloaded, diff);
    }

    #[test]
    fn lora_round_trip() {
        let adapter = LoraAdapter {
            layer_id: 0,
            matrix: "q_proj".into(),
            rank: 2,
            in_dim: 4,
            out_dim: 4,
            a: vec![1.0; 8],
            b: vec![2.0; 8],
        };
        assert_round_trip(ModelDiff::Lora(LoraDelta {
            adapters: vec![adapter],
        }));
    }

    #[test]
    fn ia3_round_trip() {
        let mut per_matrix = BTreeMap::new();
        per_matrix.insert("k_proj".to_owned(), vec![0.5_f32, 1.5_f32]);
        let mut scaling = BTreeMap::new();
        scaling.insert("0".to_owned(), per_matrix);
        assert_round_trip(ModelDiff::Ia3(IA3Delta { scaling }));
    }

    #[test]
    fn full_round_trip() {
        let mut params = BTreeMap::new();
        params.insert("layer_0/q_proj".to_owned(), vec![0.1_f32, 0.2_f32, 0.3_f32]);
        assert_round_trip(ModelDiff::Full(FullDelta { params }));
    }

    #[test]
    fn ttt_round_trip_in_canonical_order() {
        let store = MemBlobStore::new();
        // Steps deliberately submitted out of order (2 before 1) to exercise
        // canonicalization during `store_diff`.
        let mut later = TttStep {
            step_id: 2,
            deltas: BTreeMap::new(),
        };
        later.deltas.insert("x".into(), vec![1.0]);
        let earlier = TttStep {
            step_id: 1,
            deltas: BTreeMap::new(),
        };
        let diff = ModelDiff::InPlaceTtt(InPlaceTttDelta {
            steps: vec![later, earlier],
        });
        let digest = store_diff(&store, diff).unwrap();
        match load_diff(&store, &digest).unwrap() {
            ModelDiff::InPlaceTtt(t) => {
                assert_eq!(t.steps[0].step_id, 1, "canonicalize sorted by step_id");
                assert_eq!(t.steps[1].step_id, 2);
            }
            _ => panic!("variant changed"),
        }
    }

    #[test]
    fn rejects_wrong_layout_on_load() {
        let store = MemBlobStore::new();
        // Well-formed envelope JSON, but with an unknown layout version.
        let payload = serde_json::json!({
            "layout": "model.diff.v9",
            "diff": { "kind": "lora", "adapters": [] }
        });
        let bytes = serde_json::to_vec(&payload).unwrap();
        let digest = store.put(&bytes).unwrap();
        let err = load_diff(&store, &digest).unwrap_err();
        assert!(matches!(err, pf_core::Error::Integrity(_)));
    }

    #[test]
    fn rejects_lora_with_wrong_dims_on_store() {
        let store = MemBlobStore::new();
        // `b` is one element short of the 8 implied by the declared dims,
        // so validation during `store_diff` should fail.
        let malformed = ModelDiff::Lora(LoraDelta {
            adapters: vec![LoraAdapter {
                layer_id: 0,
                matrix: "q".into(),
                rank: 2,
                in_dim: 4,
                out_dim: 4,
                a: vec![0.0; 8],
                b: vec![0.0; 7],
            }],
        });
        assert!(store_diff(&store, malformed).is_err());
    }
}