// gam_sae/corpus/warm_state.rs

//! Disk-backed per-row warm-state cache for the streaming SAE inner solve
//! (#973).
//!
//! # The 3-vs-30 economics
//!
//! Each corpus row's SAE code is found by an inner solve (latent coords + an
//! active set of dictionary atoms). Cold, that solve takes ~30 inner
//! iterations. But across outer ρ passes (and across resumed runs) the *same
//! row* is solved again and again, and its solution barely moves between
//! neighbouring ρ values. Seeding the next solve from the previous one's latent
//! coords + active set cuts it to ~3 iterations. Over a multi-million-row
//! corpus that is the difference between a tractable and an intractable fit.
//!
//! This module persists that per-row warm start so the economics survive both
//! the outer ρ loop *and* a process restart (SIGKILL-resume). Entries are keyed
//! so that a warm start is never applied to a structurally different model
//! (barring a collision of the 64-bit structural hash).
//!
//! # Keying
//!
//! The cache key is `(row_id, TermCollectionSpec structural hash)`. The
//! structural hash is computed **the same way the persistent warm-start cache
//! already computes it** — via
//! [`gam_terms::smooth::TermCollectionSpec::write_structural_shape_hash`]
//! (#869), the topology-aware shape hash. Sphere, torus, and euclidean
//! candidates on the same data therefore get *distinct* per-row warm-start
//! keyspaces and never cross-seed each other with geometrically incompatible
//! coords. We hash that shape into a [`gam_runtime::warm_start::Fingerprinter`]
//! together with the `row_id`, matching the byte framing of the existing
//! warm-start key derivation.
//!
//! # Storage
//!
//! The on-disk tier reuses [`gam_runtime::warm_start::WarmStartStore`]
//! (tmp-file + fsync + rename writes, per-entry checksums, bounded size + TTL
//! eviction), so we inherit its crash-safety and disk-budget guarantees for
//! free. In front of it sits a bounded in-process LRU keyed by a `u64`
//! reduction of the same fingerprint, so the hot rows of the current batch
//! never round-trip to disk. The serialized payload is **bit-deterministic**:
//! a fixed-layout little-endian encoding (no `HashMap` iteration, no float
//! formatting), so the same warm state always hashes and round-trips
//! identically.
41
42use gam_terms::smooth::TermCollectionSpec;
43use gam_runtime::warm_start::store::{EntryKind, StoreOptions, WarmStartStore};
44use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
45use std::collections::HashMap;
46use std::time::Duration;
47
/// On-disk payload schema tag. Bump on any layout change so stale entries are
/// rejected rather than mis-decoded.
const WARM_STATE_SCHEMA: u32 = 1;

/// In-process LRU capacity (entries). Fixed internally (not a CLI knob); bounds
/// resident warm-state memory regardless of corpus size.
const LRU_CAPACITY: usize = 8192;

/// On-disk budget for the per-row warm-state tier (512 MiB).
const DISK_BUDGET_BYTES: u64 = 512 * 1024 * 1024;
/// Disk TTL (14 days): warm states older than this are evicted on save.
const DISK_TTL_SECS: u64 = 14 * 24 * 60 * 60;

/// A serialized inner-solve warm start for one corpus row.
///
/// `latent_coords` are the SAE latent coordinates; `active_set` holds the
/// indices of the dictionary atoms that were active at the previous solution.
/// Together they let the next solve start ~3 iterations from convergence
/// instead of ~30.
#[derive(Debug, Clone, PartialEq)]
pub struct RowWarmState {
    pub latent_coords: Vec<f64>,
    pub active_set: Vec<u32>,
    /// Inner iteration count the previous solve converged in. Carried for
    /// diagnostics / adaptive scheduling; does not affect the seed itself.
    pub last_inner_iters: u32,
}

impl RowWarmState {
    /// Bit-deterministic little-endian serialization.
    ///
    /// Layout: `schema(u32) | n_coords(u64) | coords[f64 LE]… | n_active(u64) |
    /// active[u32 LE]… | last_inner_iters(u32)`. No map iteration, no float
    /// formatting — the bytes are a pure function of the value, so two equal
    /// `RowWarmState`s always serialize identically.
    pub fn serialize(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(
            4 + 8 + self.latent_coords.len() * 8 + 8 + self.active_set.len() * 4 + 4,
        );
        out.extend_from_slice(&WARM_STATE_SCHEMA.to_le_bytes());
        out.extend_from_slice(&(self.latent_coords.len() as u64).to_le_bytes());
        for &c in &self.latent_coords {
            // Normalize signed zero so two arithmetically-equal seeds serialize
            // byte-identically (matches the Fingerprinter::write_f64 contract).
            // `-0.0 == 0.0`, so both map to `+0.0` here.
            let v = if c == 0.0 { 0.0 } else { c };
            out.extend_from_slice(&v.to_bits().to_le_bytes());
        }
        out.extend_from_slice(&(self.active_set.len() as u64).to_le_bytes());
        for &a in &self.active_set {
            out.extend_from_slice(&a.to_le_bytes());
        }
        out.extend_from_slice(&self.last_inner_iters.to_le_bytes());
        out
    }

    /// Inverse of [`Self::serialize`].
    ///
    /// Returns `None` on schema mismatch, truncation, trailing garbage, or a
    /// length prefix that claims more elements than the payload holds.
    ///
    /// Payloads are read back from disk, so they are untrusted. Every length
    /// prefix is checked against the bytes actually remaining *before* any
    /// allocation. A corrupt `n_coords`/`n_active` (e.g. `u64::MAX`) therefore
    /// yields `None`, never a capacity-overflow panic or an out-of-memory abort.
    pub fn deserialize(bytes: &[u8]) -> Option<Self> {
        /// Split the next `n` bytes off the front of `rest`, or `None` if fewer remain.
        fn take<'a>(rest: &mut &'a [u8], n: usize) -> Option<&'a [u8]> {
            let whole: &'a [u8] = *rest;
            if whole.len() < n {
                return None;
            }
            let (head, tail) = whole.split_at(n);
            *rest = tail;
            Some(head)
        }

        /// Read the next fixed-width field as an `N`-byte array.
        fn take_array<const N: usize>(rest: &mut &[u8]) -> Option<[u8; N]> {
            take(rest, N)?.try_into().ok()
        }

        /// Read a `u64` element count, then exactly the `count * width` bytes it
        /// describes. Overflowing or unsatisfiable counts yield `None`.
        fn take_counted<'a>(rest: &mut &'a [u8], width: usize) -> Option<&'a [u8]> {
            let count = usize::try_from(u64::from_le_bytes(take_array(rest)?)).ok()?;
            take(rest, count.checked_mul(width)?)
        }

        let mut rest = bytes;
        let schema = u32::from_le_bytes(take_array(&mut rest)?);
        if schema != WARM_STATE_SCHEMA {
            return None;
        }
        let latent_coords: Vec<f64> = take_counted(&mut rest, 8)?
            .chunks_exact(8)
            .map(|c| {
                f64::from_bits(u64::from_le_bytes(
                    c.try_into().expect("chunks_exact(8) yields 8-byte chunks"),
                ))
            })
            .collect();
        let active_set: Vec<u32> = take_counted(&mut rest, 4)?
            .chunks_exact(4)
            .map(|c| u32::from_le_bytes(c.try_into().expect("chunks_exact(4) yields 4-byte chunks")))
            .collect();
        let last_inner_iters = u32::from_le_bytes(take_array(&mut rest)?);
        if !rest.is_empty() {
            // Trailing bytes => corrupt / wrong-schema payload.
            return None;
        }
        Some(Self {
            latent_coords,
            active_set,
            last_inner_iters,
        })
    }
}
142
143/// The per-row warm-state cache seam.
144///
145/// This is the second half of the driver seam ([`super::shard_reader::CorpusRowSource`]
146/// is the first). The streaming SAE term will, per row: `get` a seed, run the
147/// inner solve from it, then `put` the refined state back.
148pub trait RowWarmCache {
149    /// Fetch the warm state for `row_id` under this cache's bound topology, if
150    /// present (in-process LRU first, then disk).
151    fn get(&mut self, row_id: u64) -> Option<RowWarmState>;
152    /// Store / overwrite the warm state for `row_id`.
153    fn put(&mut self, row_id: u64, state: &RowWarmState);
154}
155
156/// A bounded in-process LRU node.
157struct LruEntry {
158    /// The canonical row id this entry belongs to; used to detect the (rare)
159    /// u64-key collision so a colliding lookup falls through to the disk tier
160    /// rather than returning a wrong-row seed.
161    row_id: u64,
162    state: RowWarmState,
163    /// Monotonic stamp for LRU ordering (highest = most recently used).
164    stamp: u64,
165}
166
167/// mmap/LRU-backed disk cache implementing [`RowWarmCache`].
168///
169/// Bound to one `TermCollectionSpec` structural hash at construction; every key
170/// folds that hash so two topologies cannot collide.
171pub struct DiskRowWarmCache {
172    /// Per-topology structural hash, folded into every row key.
173    structural_hash: u64,
174    /// Bounded in-process LRU over `Fingerprint`-equivalent row keys.
175    lru: HashMap<u64, LruEntry>,
176    stamp: u64,
177    /// On-disk tier (None if the cache directory is unwritable; the cache then
178    /// degrades to in-process-LRU-only without erroring).
179    store: Option<WarmStartStore>,
180}
181
182impl DiskRowWarmCache {
183    /// Construct a cache bound to `spec`'s structural topology, deriving the
184    /// per-topology structural hash the same way the existing warm-start cache
185    /// does (#869) — via `write_structural_shape_hash`.
186    pub fn new(spec: &TermCollectionSpec) -> Self {
187        let mut fp = Fingerprinter::new();
188        fp.write_str("sae-corpus-row-warm-state-v1");
189        spec.write_structural_shape_hash(&mut fp);
190        let structural_hash = fingerprint_to_u64(&fp.finalize());
191        let store = Self::open_store();
192        Self {
193            structural_hash,
194            lru: HashMap::new(),
195            stamp: 0,
196            store,
197        }
198    }
199
200    /// Anchor the disk tier under the platform temp directory, mirroring the
201    /// persistent-warm-start root resolution (which avoids the banned
202    /// `env::var` path that `dirs::cache_dir()` would take).
203    fn open_store() -> Option<WarmStartStore> {
204        let root = std::env::temp_dir()
205            .join("gam")
206            .join("sae_corpus_warm")
207            .join("v1");
208        WarmStartStore::open(
209            root,
210            StoreOptions {
211                size_budget_bytes: DISK_BUDGET_BYTES,
212                ttl: Duration::from_secs(DISK_TTL_SECS),
213            },
214        )
215        .ok()
216    }
217
218    /// Compose the full disk/LRU key for a row under this cache's topology.
219    ///
220    /// Folds the schema tag, the per-topology structural hash, and the row id
221    /// into one fingerprint, matching the existing warm-start key framing
222    /// (length-prefixed `write_*` calls on a `Fingerprinter`). The `u64`
223    /// reduction keys the in-process LRU; the full `Fingerprint` keys disk.
224    fn row_fingerprint(&self, row_id: u64) -> Fingerprint {
225        let mut fp = Fingerprinter::new();
226        fp.write_str("sae-corpus-row-warm-state-key-v1");
227        fp.write_u64(self.structural_hash);
228        fp.write_u64(row_id);
229        fp.finalize()
230    }
231
232    #[inline]
233    fn lru_key(&self, row_id: u64) -> u64 {
234        fingerprint_to_u64(&self.row_fingerprint(row_id))
235    }
236
237    /// Evict the least-recently-used LRU entry when over capacity.
238    fn evict_if_full(&mut self) {
239        if self.lru.len() <= LRU_CAPACITY {
240            return;
241        }
242        if let Some((&victim, _)) = self.lru.iter().min_by_key(|(_, e)| e.stamp) {
243            self.lru.remove(&victim);
244        }
245    }
246}
247
248impl RowWarmCache for DiskRowWarmCache {
249    fn get(&mut self, row_id: u64) -> Option<RowWarmState> {
250        let key = self.lru_key(row_id);
251        // Hot path: in-process LRU. Guard against the (rare) u64-key collision
252        // by checking the stored row_id; a collision falls through to the disk
253        // tier which always re-checks the full Fingerprint.
254        if let Some(entry) = self.lru.get_mut(&key) {
255            if entry.row_id == row_id {
256                self.stamp += 1;
257                entry.stamp = self.stamp;
258                return Some(entry.state.clone());
259            }
260            // Collision: drop through to disk.
261        }
262        // Cold path: disk tier. Decode, then promote into the LRU (overwriting
263        // any colliding entry; the evicted entry's row_id diverges from key so
264        // a subsequent get for the displaced row will also fall through to disk
265        // — correct, just slower).
266        let store = self.store.as_ref()?;
267        let fp = self.row_fingerprint(row_id);
268        let cached = store.lookup(&fp).ok().flatten()?;
269        let state = RowWarmState::deserialize(&cached.payload)?;
270        self.stamp += 1;
271        self.lru.insert(
272            key,
273            LruEntry {
274                row_id,
275                state: state.clone(),
276                stamp: self.stamp,
277            },
278        );
279        self.evict_if_full();
280        Some(state)
281    }
282
283    fn put(&mut self, row_id: u64, state: &RowWarmState) {
284        let key = self.lru_key(row_id);
285        self.stamp += 1;
286        self.lru.insert(
287            key,
288            LruEntry {
289                row_id,
290                state: state.clone(),
291                stamp: self.stamp,
292            },
293        );
294        self.evict_if_full();
295        // Write-through to disk so the seed survives a process restart. A
296        // failed disk write is non-fatal: the LRU still holds the seed for the
297        // current run. `iteration` carries the converged inner-iter count for
298        // disk-side diagnostics; `Final` kind marks a converged seed.
299        if let Some(store) = self.store.as_ref() {
300            let payload = state.serialize();
301            let fp = self.row_fingerprint(row_id);
302            store
303                .save(
304                    &fp,
305                    &payload,
306                    None,
307                    Some(u64::from(state.last_inner_iters)),
308                    EntryKind::Final,
309                )
310                .ok();
311        }
312    }
313}
314
315/// Reduce a 32-byte [`Fingerprint`] to a `u64` LRU bucket key by folding its
316/// raw leading bytes. Collisions in the `u64` space are harmless: the
317/// disk tier always re-checks the full `Fingerprint`, and an LRU bucket
318/// collision only risks a spurious in-process miss (then a correct disk hit),
319/// never a wrong-row seed.
320fn fingerprint_to_u64(fp: &Fingerprint) -> u64 {
321    // Take the first 8 raw bytes of the 32-byte fingerprint and assemble them
322    // into a u64. Using raw bytes (not the hex-string ASCII representation)
323    // gives full 8-bit entropy per lane rather than the biased 4-bit range
324    // that hex ASCII digits occupy (0x30-0x39, 0x61-0x66).
325    let bytes = fp.as_bytes();
326    let mut acc = 0u64;
327    for &b in bytes.iter().take(8) {
328        acc = acc.wrapping_shl(8) ^ u64::from(b);
329    }
330    // Mix so adjacent row ids (which share a long key prefix) spread across
331    // buckets. Reuses the canonical splitmix64 finalizer.
332    gam_linalg::utils::splitmix64_hash(acc)
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// A small but non-trivial state: mixed-sign coords (including a zero) and
    /// a sparse active set.
    fn sample_state() -> RowWarmState {
        RowWarmState {
            latent_coords: vec![1.0, -2.5, 0.0, 3.125],
            active_set: vec![0, 4, 9, 17],
            last_inner_iters: 3,
        }
    }

    #[test]
    fn serialize_round_trips() {
        let s = sample_state();
        let bytes = s.serialize();
        let back = RowWarmState::deserialize(&bytes).expect("decode");
        assert_eq!(s, back);
    }

    #[test]
    fn serialize_round_trips_empty_state() {
        // Zero-length coords / active set are legal and must survive a round trip.
        let empty = RowWarmState {
            latent_coords: Vec::new(),
            active_set: Vec::new(),
            last_inner_iters: 0,
        };
        let bytes = empty.serialize();
        // schema(u32) + n_coords(u64) + n_active(u64) + last_inner_iters(u32).
        assert_eq!(bytes.len(), 4 + 8 + 8 + 4);
        assert_eq!(RowWarmState::deserialize(&bytes), Some(empty));
    }

    #[test]
    fn serialize_is_bit_deterministic() {
        // -0.0 normalizes to +0.0 so two arithmetically-equal seeds match.
        let a = RowWarmState {
            latent_coords: vec![-0.0, 1.0],
            active_set: vec![2],
            last_inner_iters: 1,
        };
        let b = RowWarmState {
            latent_coords: vec![0.0, 1.0],
            active_set: vec![2],
            last_inner_iters: 1,
        };
        assert_eq!(a.serialize(), b.serialize());
        // Re-serializing yields identical bytes.
        assert_eq!(a.serialize(), a.serialize());
    }

    #[test]
    fn deserialize_rejects_wrong_schema() {
        let mut bytes = sample_state().serialize();
        bytes[0] ^= 0xFF;
        assert!(RowWarmState::deserialize(&bytes).is_none());
    }

    #[test]
    fn deserialize_rejects_trailing_garbage() {
        let mut bytes = sample_state().serialize();
        bytes.push(0u8);
        assert!(RowWarmState::deserialize(&bytes).is_none());
    }

    #[test]
    fn deserialize_rejects_truncation() {
        // Every strict prefix of a valid payload must be rejected, not just one.
        let bytes = sample_state().serialize();
        for cut in 0..bytes.len() {
            assert!(
                RowWarmState::deserialize(&bytes[..cut]).is_none(),
                "a {}-byte prefix of a {}-byte payload decoded",
                cut,
                bytes.len()
            );
        }
    }

    #[test]
    fn deserialize_rejects_count_exceeding_payload() {
        // A coordinate count larger than the payload carries must be rejected
        // rather than read past the end.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&WARM_STATE_SCHEMA.to_le_bytes());
        bytes.extend_from_slice(&3u64.to_le_bytes());
        bytes.extend_from_slice(&1.0f64.to_bits().to_le_bytes());
        bytes.extend_from_slice(&2.0f64.to_bits().to_le_bytes());
        assert!(RowWarmState::deserialize(&bytes).is_none());
    }
}