gam_sae/corpus/warm_state.rs
1//! Disk-backed per-row warm-state cache for the streaming SAE inner solve
2//! (#973).
3//!
4//! # The 3-vs-30 economics
5//!
6//! Each corpus row's SAE code is found by an inner solve (latent coords + an
7//! active set of dictionary atoms). Cold, that solve takes ~30 inner
8//! iterations. But across outer ρ passes (and across resumed runs) the *same
9//! row* is solved again and again, and its solution barely moves between
10//! neighbouring ρ values. Seeding the next solve from the previous one's latent
11//! coords + active set cuts it to ~3 iterations. Over a multi-million-row
12//! corpus that is the difference between a tractable and an intractable fit.
13//!
14//! This module persists that per-row warm start so the economics survive both
15//! the outer ρ loop *and* a process restart (SIGKILL-resume), keyed so that a
16//! warm start can never be applied to a structurally different model.
17//!
18//! # Keying
19//!
20//! The cache key is `(row_id, TermCollectionSpec structural hash)`. The
21//! structural hash is computed **the same way the persistent warm-start cache
22//! already does it** — via
23//! [`gam_terms::smooth::TermCollectionSpec::write_structural_shape_hash`]
24//! (#869), the topology-aware shape hash — so a sphere-vs-torus-vs-euclidean
25//! candidate on the same data gets a *distinct* per-row warm-start keyspace and
26//! the candidates never cross-seed each other with geometrically incompatible
//! coords. That shape hash is first reduced to a per-topology `u64`, which is
//! then written into a [`gam_runtime::warm_start::Fingerprinter`] together with
//! the `row_id` (after a domain-separation tag), matching the existing
//! warm-start key derivation byte framing.
30//!
31//! # Storage
32//!
33//! The on-disk tier reuses [`gam_runtime::warm_start::WarmStartStore`] (tmp-file + fsync +
34//! rename writes, per-entry checksums, bounded size + TTL eviction) so we
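//! inherit its crash-safety and disk-budget guarantees for free. In front of
//! it sits a bounded in-process LRU keyed by a `u64` reduction of the same
//! fingerprint. Each LRU entry also records its `row_id`, so a reduced-key
//! collision falls through to disk instead of returning another row's seed.
//! Reads of hot rows are served from the LRU without a disk round-trip; writes
//! still go through to disk. The serialized payload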
38//! is **bit-deterministic**: a fixed-layout little-endian encoding (no
39//! `HashMap` iteration, no float formatting), so the same warm state always
40//! hashes/round-trips identically.
41
42use gam_terms::smooth::TermCollectionSpec;
43use gam_runtime::warm_start::store::{EntryKind, StoreOptions, WarmStartStore};
44use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
45use std::collections::HashMap;
46use std::time::Duration;
47
/// Layout tag written at the front of every serialized [`RowWarmState`].
/// Bump it whenever the byte layout changes, so entries written under an
/// older layout are rejected on read instead of being mis-decoded.
const WARM_STATE_SCHEMA: u32 = 1;

/// Maximum number of entries held by the in-process LRU tier. This is a
/// fixed internal constant, not a CLI knob. It caps the resident entry count
/// independently of corpus size.
const LRU_CAPACITY: usize = 8 * 1024;

/// Size budget handed to the on-disk warm-state store (512 MiB).
const DISK_BUDGET_BYTES: u64 = 512 << 20;

/// Time-to-live handed to the on-disk warm-state store: 14 days, in seconds.
/// Expiry itself is enforced by the store.
const DISK_TTL_SECS: u64 = 14 * 86_400;
60
/// Warm-start seed for one corpus row's SAE inner solve.
///
/// The next solve for the same row starts from `latent_coords`, with
/// `active_set` as its initial support. Starting warm is what turns a
/// ~30-iteration cold solve into a ~3-iteration one.
#[derive(Clone, Debug, PartialEq)]
pub struct RowWarmState {
    /// SAE latent coordinates of the row at the previous solution.
    pub latent_coords: Vec<f64>,
    /// Indices of the dictionary atoms that were active at the previous
    /// solution.
    pub active_set: Vec<u32>,
    /// Number of inner iterations the previous solve needed to converge.
    /// Kept for diagnostics / adaptive scheduling only; it does not feed into
    /// the seed itself.
    pub last_inner_iters: u32,
}
74
75impl RowWarmState {
76 /// Bit-deterministic little-endian serialization.
77 ///
78 /// Layout: `schema(u32) | n_coords(u64) | coords[f64 LE]… | n_active(u64) |
79 /// active[u32 LE]… | last_inner_iters(u32)`. No map iteration, no float
80 /// formatting — the bytes are a pure function of the value, so two equal
81 /// `RowWarmState`s always serialize identically.
82 pub fn serialize(&self) -> Vec<u8> {
83 let mut out = Vec::with_capacity(
84 4 + 8 + self.latent_coords.len() * 8 + 8 + self.active_set.len() * 4 + 4,
85 );
86 out.extend_from_slice(&WARM_STATE_SCHEMA.to_le_bytes());
87 out.extend_from_slice(&(self.latent_coords.len() as u64).to_le_bytes());
88 for &c in &self.latent_coords {
89 // Normalize signed zero so two arithmetically-equal seeds serialize
90 // byte-identically (matches the Fingerprinter::write_f64 contract).
91 let v = if c == 0.0 { 0.0 } else { c };
92 out.extend_from_slice(&v.to_bits().to_le_bytes());
93 }
94 out.extend_from_slice(&(self.active_set.len() as u64).to_le_bytes());
95 for &a in &self.active_set {
96 out.extend_from_slice(&a.to_le_bytes());
97 }
98 out.extend_from_slice(&self.last_inner_iters.to_le_bytes());
99 out
100 }
101
102 /// Inverse of [`Self::serialize`]; returns `None` on schema mismatch,
103 /// truncation, or trailing garbage.
104 pub fn deserialize(bytes: &[u8]) -> Option<Self> {
105 let mut off = 0usize;
106 let take = |off: &mut usize, n: usize| -> Option<&[u8]> {
107 let end = off.checked_add(n)?;
108 if end > bytes.len() {
109 return None;
110 }
111 let s = &bytes[*off..end];
112 *off = end;
113 Some(s)
114 };
115 let schema = u32::from_le_bytes(take(&mut off, 4)?.try_into().ok()?);
116 if schema != WARM_STATE_SCHEMA {
117 return None;
118 }
119 let n_coords = u64::from_le_bytes(take(&mut off, 8)?.try_into().ok()?) as usize;
120 let mut latent_coords = Vec::with_capacity(n_coords);
121 for _ in 0..n_coords {
122 let bits = u64::from_le_bytes(take(&mut off, 8)?.try_into().ok()?);
123 latent_coords.push(f64::from_bits(bits));
124 }
125 let n_active = u64::from_le_bytes(take(&mut off, 8)?.try_into().ok()?) as usize;
126 let mut active_set = Vec::with_capacity(n_active);
127 for _ in 0..n_active {
128 active_set.push(u32::from_le_bytes(take(&mut off, 4)?.try_into().ok()?));
129 }
130 let last_inner_iters = u32::from_le_bytes(take(&mut off, 4)?.try_into().ok()?);
131 if off != bytes.len() {
132 // Trailing bytes => corrupt / wrong-schema payload.
133 return None;
134 }
135 Some(Self {
136 latent_coords,
137 active_set,
138 last_inner_iters,
139 })
140 }
141}
142
143/// The per-row warm-state cache seam.
144///
145/// This is the second half of the driver seam ([`super::shard_reader::CorpusRowSource`]
146/// is the first). The streaming SAE term will, per row: `get` a seed, run the
147/// inner solve from it, then `put` the refined state back.
148pub trait RowWarmCache {
149 /// Fetch the warm state for `row_id` under this cache's bound topology, if
150 /// present (in-process LRU first, then disk).
151 fn get(&mut self, row_id: u64) -> Option<RowWarmState>;
152 /// Store / overwrite the warm state for `row_id`.
153 fn put(&mut self, row_id: u64, state: &RowWarmState);
154}
155
156/// A bounded in-process LRU node.
157struct LruEntry {
158 /// The canonical row id this entry belongs to; used to detect the (rare)
159 /// u64-key collision so a colliding lookup falls through to the disk tier
160 /// rather than returning a wrong-row seed.
161 row_id: u64,
162 state: RowWarmState,
163 /// Monotonic stamp for LRU ordering (highest = most recently used).
164 stamp: u64,
165}
166
167/// mmap/LRU-backed disk cache implementing [`RowWarmCache`].
168///
169/// Bound to one `TermCollectionSpec` structural hash at construction; every key
170/// folds that hash so two topologies cannot collide.
171pub struct DiskRowWarmCache {
172 /// Per-topology structural hash, folded into every row key.
173 structural_hash: u64,
174 /// Bounded in-process LRU over `Fingerprint`-equivalent row keys.
175 lru: HashMap<u64, LruEntry>,
176 stamp: u64,
177 /// On-disk tier (None if the cache directory is unwritable; the cache then
178 /// degrades to in-process-LRU-only without erroring).
179 store: Option<WarmStartStore>,
180}
181
182impl DiskRowWarmCache {
183 /// Construct a cache bound to `spec`'s structural topology, deriving the
184 /// per-topology structural hash the same way the existing warm-start cache
185 /// does (#869) — via `write_structural_shape_hash`.
186 pub fn new(spec: &TermCollectionSpec) -> Self {
187 let mut fp = Fingerprinter::new();
188 fp.write_str("sae-corpus-row-warm-state-v1");
189 spec.write_structural_shape_hash(&mut fp);
190 let structural_hash = fingerprint_to_u64(&fp.finalize());
191 let store = Self::open_store();
192 Self {
193 structural_hash,
194 lru: HashMap::new(),
195 stamp: 0,
196 store,
197 }
198 }
199
200 /// Anchor the disk tier under the platform temp directory, mirroring the
201 /// persistent-warm-start root resolution (which avoids the banned
202 /// `env::var` path that `dirs::cache_dir()` would take).
203 fn open_store() -> Option<WarmStartStore> {
204 let root = std::env::temp_dir()
205 .join("gam")
206 .join("sae_corpus_warm")
207 .join("v1");
208 WarmStartStore::open(
209 root,
210 StoreOptions {
211 size_budget_bytes: DISK_BUDGET_BYTES,
212 ttl: Duration::from_secs(DISK_TTL_SECS),
213 },
214 )
215 .ok()
216 }
217
218 /// Compose the full disk/LRU key for a row under this cache's topology.
219 ///
220 /// Folds the schema tag, the per-topology structural hash, and the row id
221 /// into one fingerprint, matching the existing warm-start key framing
222 /// (length-prefixed `write_*` calls on a `Fingerprinter`). The `u64`
223 /// reduction keys the in-process LRU; the full `Fingerprint` keys disk.
224 fn row_fingerprint(&self, row_id: u64) -> Fingerprint {
225 let mut fp = Fingerprinter::new();
226 fp.write_str("sae-corpus-row-warm-state-key-v1");
227 fp.write_u64(self.structural_hash);
228 fp.write_u64(row_id);
229 fp.finalize()
230 }
231
232 #[inline]
233 fn lru_key(&self, row_id: u64) -> u64 {
234 fingerprint_to_u64(&self.row_fingerprint(row_id))
235 }
236
237 /// Evict the least-recently-used LRU entry when over capacity.
238 fn evict_if_full(&mut self) {
239 if self.lru.len() <= LRU_CAPACITY {
240 return;
241 }
242 if let Some((&victim, _)) = self.lru.iter().min_by_key(|(_, e)| e.stamp) {
243 self.lru.remove(&victim);
244 }
245 }
246}
247
248impl RowWarmCache for DiskRowWarmCache {
249 fn get(&mut self, row_id: u64) -> Option<RowWarmState> {
250 let key = self.lru_key(row_id);
251 // Hot path: in-process LRU. Guard against the (rare) u64-key collision
252 // by checking the stored row_id; a collision falls through to the disk
253 // tier which always re-checks the full Fingerprint.
254 if let Some(entry) = self.lru.get_mut(&key) {
255 if entry.row_id == row_id {
256 self.stamp += 1;
257 entry.stamp = self.stamp;
258 return Some(entry.state.clone());
259 }
260 // Collision: drop through to disk.
261 }
262 // Cold path: disk tier. Decode, then promote into the LRU (overwriting
263 // any colliding entry; the evicted entry's row_id diverges from key so
264 // a subsequent get for the displaced row will also fall through to disk
265 // — correct, just slower).
266 let store = self.store.as_ref()?;
267 let fp = self.row_fingerprint(row_id);
268 let cached = store.lookup(&fp).ok().flatten()?;
269 let state = RowWarmState::deserialize(&cached.payload)?;
270 self.stamp += 1;
271 self.lru.insert(
272 key,
273 LruEntry {
274 row_id,
275 state: state.clone(),
276 stamp: self.stamp,
277 },
278 );
279 self.evict_if_full();
280 Some(state)
281 }
282
283 fn put(&mut self, row_id: u64, state: &RowWarmState) {
284 let key = self.lru_key(row_id);
285 self.stamp += 1;
286 self.lru.insert(
287 key,
288 LruEntry {
289 row_id,
290 state: state.clone(),
291 stamp: self.stamp,
292 },
293 );
294 self.evict_if_full();
295 // Write-through to disk so the seed survives a process restart. A
296 // failed disk write is non-fatal: the LRU still holds the seed for the
297 // current run. `iteration` carries the converged inner-iter count for
298 // disk-side diagnostics; `Final` kind marks a converged seed.
299 if let Some(store) = self.store.as_ref() {
300 let payload = state.serialize();
301 let fp = self.row_fingerprint(row_id);
302 store
303 .save(
304 &fp,
305 &payload,
306 None,
307 Some(u64::from(state.last_inner_iters)),
308 EntryKind::Final,
309 )
310 .ok();
311 }
312 }
313}
314
315/// Reduce a 32-byte [`Fingerprint`] to a `u64` LRU bucket key by folding its
316/// raw leading bytes. Collisions in the `u64` space are harmless: the
317/// disk tier always re-checks the full `Fingerprint`, and an LRU bucket
318/// collision only risks a spurious in-process miss (then a correct disk hit),
319/// never a wrong-row seed.
320fn fingerprint_to_u64(fp: &Fingerprint) -> u64 {
321 // Take the first 8 raw bytes of the 32-byte fingerprint and assemble them
322 // into a u64. Using raw bytes (not the hex-string ASCII representation)
323 // gives full 8-bit entropy per lane rather than the biased 4-bit range
324 // that hex ASCII digits occupy (0x30-0x39, 0x61-0x66).
325 let bytes = fp.as_bytes();
326 let mut acc = 0u64;
327 for &b in bytes.iter().take(8) {
328 acc = acc.wrapping_shl(8) ^ u64::from(b);
329 }
330 // Mix so adjacent row ids (which share a long key prefix) spread across
331 // buckets. Reuses the canonical splitmix64 finalizer.
332 gam_linalg::utils::splitmix64_hash(acc)
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// A small non-trivial state: mixed-sign coords (including a zero) and a
    /// four-atom active set.
    fn sample_state() -> RowWarmState {
        RowWarmState {
            latent_coords: vec![1.0, -2.5, 0.0, 3.125],
            active_set: vec![0, 4, 9, 17],
            last_inner_iters: 3,
        }
    }

    /// Bit patterns of a coordinate slice, so NaN can be compared exactly.
    fn bits(xs: &[f64]) -> Vec<u64> {
        xs.iter().map(|x| x.to_bits()).collect()
    }

    #[test]
    fn serialize_round_trips() {
        let s = sample_state();
        let bytes = s.serialize();
        let back = RowWarmState::deserialize(&bytes).expect("decode");
        assert_eq!(s, back);
    }

    #[test]
    fn serialized_length_matches_layout() {
        // schema(4) + n_coords(8) + 4*8 + n_active(8) + 4*4 + last_inner_iters(4)
        assert_eq!(sample_state().serialize().len(), 4 + 8 + 32 + 8 + 16 + 4);
    }

    #[test]
    fn empty_state_round_trips() {
        let empty = RowWarmState {
            latent_coords: Vec::new(),
            active_set: Vec::new(),
            last_inner_iters: 0,
        };
        let bytes = empty.serialize();
        // Only the fixed-width header and trailer remain.
        assert_eq!(bytes.len(), 24);
        assert_eq!(RowWarmState::deserialize(&bytes), Some(empty));
    }

    #[test]
    fn serialize_preserves_non_finite_bit_patterns() {
        let s = RowWarmState {
            latent_coords: vec![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, f64::MIN_POSITIVE],
            active_set: vec![u32::MAX],
            last_inner_iters: u32::MAX,
        };
        let back = RowWarmState::deserialize(&s.serialize()).expect("decode");
        // NaN != NaN, so compare bit patterns rather than values.
        assert_eq!(bits(&back.latent_coords), bits(&s.latent_coords));
        assert_eq!(back.active_set, s.active_set);
        assert_eq!(back.last_inner_iters, s.last_inner_iters);
    }

    #[test]
    fn serialize_is_bit_deterministic() {
        // -0.0 normalizes to +0.0 so two arithmetically-equal seeds match.
        let a = RowWarmState {
            latent_coords: vec![-0.0, 1.0],
            active_set: vec![2],
            last_inner_iters: 1,
        };
        let b = RowWarmState {
            latent_coords: vec![0.0, 1.0],
            active_set: vec![2],
            last_inner_iters: 1,
        };
        assert_eq!(a.serialize(), b.serialize());
        // Re-serializing yields identical bytes.
        assert_eq!(a.serialize(), a.serialize());
    }

    #[test]
    fn deserialize_rejects_wrong_schema() {
        let mut bytes = sample_state().serialize();
        bytes[0] ^= 0xFF;
        assert!(RowWarmState::deserialize(&bytes).is_none());
    }

    #[test]
    fn deserialize_rejects_trailing_garbage() {
        let mut bytes = sample_state().serialize();
        bytes.push(0u8);
        assert!(RowWarmState::deserialize(&bytes).is_none());
    }

    #[test]
    fn deserialize_rejects_truncation() {
        // Every field is mandatory, so every strict prefix must be rejected,
        // not just one cut point.
        let bytes = sample_state().serialize();
        for cut in 0..bytes.len() {
            assert!(
                RowWarmState::deserialize(&bytes[..cut]).is_none(),
                "prefix of {} bytes decoded",
                cut
            );
        }
    }
}