use crate::terms::smooth::TermCollectionSpec;
use crate::warm_start::store::{EntryKind, StoreOptions, WarmStartStore};
use crate::warm_start::{Fingerprint, Fingerprinter};
use std::collections::HashMap;
use std::time::Duration;
/// Version tag written as the first four bytes of every serialized `RowWarmState`;
/// `RowWarmState::deserialize` rejects payloads that carry any other value.
const WARM_STATE_SCHEMA: u32 = 1;
/// Entry count the in-memory map is allowed to hold; `evict_if_full` removes one
/// entry whenever the map has grown past this.
const LRU_CAPACITY: usize = 8192;
/// Size budget passed to the on-disk store as `StoreOptions::size_budget_bytes` (512 MiB).
const DISK_BUDGET_BYTES: u64 = 512 * 1024 * 1024;
/// Time-to-live passed to the on-disk store as `StoreOptions::ttl`, in seconds (14 days).
const DISK_TTL_SECS: u64 = 14 * 24 * 60 * 60;
/// Per-row state that can be serialized, stored, and later reloaded.
///
/// Wire format produced by [`RowWarmState::serialize`] (all integers little-endian):
///
/// | field                | encoding                         |
/// |----------------------|----------------------------------|
/// | schema               | `u32`, must equal `WARM_STATE_SCHEMA` |
/// | `latent_coords` len  | `u64`                            |
/// | `latent_coords`      | one `u64` of IEEE-754 bits each  |
/// | `active_set` len     | `u64`                            |
/// | `active_set`         | one `u32` each                   |
/// | `last_inner_iters`   | `u32`                            |
#[derive(Debug, Clone, PartialEq)]
pub struct RowWarmState {
    /// Latent coordinates of the row (presumably the most recent solve — confirm with the caller).
    pub latent_coords: Vec<f64>,
    /// Active-set indices of the row (index space is defined by the caller — TODO confirm).
    pub active_set: Vec<u32>,
    /// Inner-iteration count recorded together with the state.
    pub last_inner_iters: u32,
}

impl RowWarmState {
    /// Encodes the state into the wire format described on the type.
    ///
    /// Negative zero is written as positive zero, so two states that compare
    /// equal under `==` on their zero coordinates produce identical bytes.
    /// As a consequence `-0.0` is read back as `0.0`.
    pub fn serialize(&self) -> Vec<u8> {
        let mut out = Vec::with_capacity(
            4 + 8 + self.latent_coords.len() * 8 + 8 + self.active_set.len() * 4 + 4,
        );
        out.extend_from_slice(&WARM_STATE_SCHEMA.to_le_bytes());

        out.extend_from_slice(&(self.latent_coords.len() as u64).to_le_bytes());
        for &coord in &self.latent_coords {
            // `-0.0 == 0.0` is true, so this branch folds both zeros onto `+0.0`.
            let canonical = if coord == 0.0 { 0.0 } else { coord };
            out.extend_from_slice(&canonical.to_bits().to_le_bytes());
        }

        out.extend_from_slice(&(self.active_set.len() as u64).to_le_bytes());
        for &index in &self.active_set {
            out.extend_from_slice(&index.to_le_bytes());
        }

        out.extend_from_slice(&self.last_inner_iters.to_le_bytes());
        out
    }

    /// Decodes a payload produced by [`RowWarmState::serialize`].
    ///
    /// Returns `None` when the schema tag does not match, when the payload is
    /// truncated, when a length prefix claims more elements than the remaining
    /// bytes can hold, or when bytes are left over after the last field.
    ///
    /// Length prefixes are validated against the remaining input *before* any
    /// allocation, so a corrupted payload can neither trigger a capacity-overflow
    /// panic nor a huge speculative allocation.
    pub fn deserialize(bytes: &[u8]) -> Option<Self> {
        /// Splits `n` bytes off the front of `rest`; `None` if fewer remain.
        fn take<'a>(rest: &mut &'a [u8], n: usize) -> Option<&'a [u8]> {
            let current: &'a [u8] = *rest;
            if n > current.len() {
                return None;
            }
            let (head, tail) = current.split_at(n);
            *rest = tail;
            Some(head)
        }

        /// Reads one little-endian `u32`.
        fn read_u32(rest: &mut &[u8]) -> Option<u32> {
            let mut buf = [0u8; 4];
            // Cannot panic: `take` returned exactly four bytes.
            buf.copy_from_slice(take(rest, 4)?);
            Some(u32::from_le_bytes(buf))
        }

        /// Reads one little-endian `u64`.
        fn read_u64(rest: &mut &[u8]) -> Option<u64> {
            let mut buf = [0u8; 8];
            // Cannot panic: `take` returned exactly eight bytes.
            buf.copy_from_slice(take(rest, 8)?);
            Some(u64::from_le_bytes(buf))
        }

        /// Reads a `u64` element count followed by that many `width`-byte
        /// elements and returns the raw element bytes.
        fn take_array<'a>(rest: &mut &'a [u8], width: usize) -> Option<&'a [u8]> {
            let count = read_u64(rest)?;
            // Compare in the u64 domain so an oversized count is rejected
            // without ever being narrowed to `usize`.
            let max_count = (rest.len() / width) as u64;
            if count > max_count {
                return None;
            }
            // Lossless and overflow-free: `count <= rest.len() / width`.
            take(rest, count as usize * width)
        }

        let mut rest = bytes;

        if read_u32(&mut rest)? != WARM_STATE_SCHEMA {
            return None;
        }

        let coord_bytes = take_array(&mut rest, 8)?;
        let mut latent_coords = Vec::with_capacity(coord_bytes.len() / 8);
        for chunk in coord_bytes.chunks_exact(8) {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(chunk);
            latent_coords.push(f64::from_bits(u64::from_le_bytes(buf)));
        }

        let active_bytes = take_array(&mut rest, 4)?;
        let mut active_set = Vec::with_capacity(active_bytes.len() / 4);
        for chunk in active_bytes.chunks_exact(4) {
            let mut buf = [0u8; 4];
            buf.copy_from_slice(chunk);
            active_set.push(u32::from_le_bytes(buf));
        }

        let last_inner_iters = read_u32(&mut rest)?;

        // Anything left over means the payload is not one we wrote.
        if !rest.is_empty() {
            return None;
        }

        Some(Self {
            latent_coords,
            active_set,
            last_inner_iters,
        })
    }
}
/// Keyed storage for [`RowWarmState`] values, addressed by a `u64` row id.
pub trait RowWarmCache {
/// Returns the state stored for `row_id`, or `None` when nothing usable is available.
///
/// Takes `&mut self`, so an implementation may update internal bookkeeping on reads.
fn get(&mut self, row_id: u64) -> Option<RowWarmState>;
/// Stores `state` under `row_id`.
///
/// The method returns nothing, so a failure to store cannot be reported through it.
fn put(&mut self, row_id: u64, state: &RowWarmState);
}
/// One slot of the in-memory map inside `DiskRowWarmCache`.
struct LruEntry {
// Row the state belongs to; `get` compares it with the requested id before
// treating the slot as a hit, because the map key is a 64-bit hash.
row_id: u64,
// The cached state itself; callers receive clones of it.
state: RowWarmState,
// Value of the cache's counter when this slot was last inserted or read;
// the slot with the smallest stamp is the one `evict_if_full` removes.
stamp: u64,
}
/// `RowWarmCache` implementation that keeps recent entries in a `HashMap` and
/// additionally reads from / writes to a `WarmStartStore` when one could be opened.
pub struct DiskRowWarmCache {
// Hash derived in `new` from the spec's structural shape; mixed into every
// row fingerprint so keys differ between structurally different specs.
structural_hash: u64,
// In-memory entries, keyed by `lru_key(row_id)`.
lru: HashMap<u64, LruEntry>,
// Counter incremented on every insert and every in-memory hit; provides the
// recency ordering used for eviction.
stamp: u64,
// Backing store; `None` when `WarmStartStore::open` returned an error, in
// which case only the in-memory map is used.
store: Option<WarmStartStore>,
}
impl DiskRowWarmCache {
    /// Creates a cache whose keys are scoped to the structural shape of `spec`.
    ///
    /// The backing store is opened on a best-effort basis; if that fails the
    /// cache still works, using only its in-memory map.
    pub fn new(spec: &TermCollectionSpec) -> Self {
        let mut shape_fp = Fingerprinter::new();
        shape_fp.write_str("sae-corpus-row-warm-state-v1");
        spec.write_structural_shape_hash(&mut shape_fp);
        let shape_digest = shape_fp.finalize();

        Self {
            structural_hash: fingerprint_to_u64(&shape_digest),
            lru: HashMap::new(),
            stamp: 0,
            store: Self::open_store(),
        }
    }

    /// Opens the store under `<temp dir>/gam/sae_corpus_warm/v1`, returning
    /// `None` if `WarmStartStore::open` reports an error.
    fn open_store() -> Option<WarmStartStore> {
        let mut root = std::env::temp_dir();
        root.push("gam");
        root.push("sae_corpus_warm");
        root.push("v1");

        let options = StoreOptions {
            size_budget_bytes: DISK_BUDGET_BYTES,
            ttl: Duration::from_secs(DISK_TTL_SECS),
        };

        match WarmStartStore::open(root, options) {
            Ok(store) => Some(store),
            Err(_) => None,
        }
    }

    /// Fingerprint used as the store key for `row_id`: a fixed domain tag,
    /// then this cache's structural hash, then the row id.
    fn row_fingerprint(&self, row_id: u64) -> Fingerprint {
        let mut key_fp = Fingerprinter::new();
        key_fp.write_str("sae-corpus-row-warm-state-key-v1");
        key_fp.write_u64(self.structural_hash);
        key_fp.write_u64(row_id);
        key_fp.finalize()
    }

    /// 64-bit reduction of the row fingerprint, used as the in-memory map key.
    #[inline]
    fn lru_key(&self, row_id: u64) -> u64 {
        let digest = self.row_fingerprint(row_id);
        fingerprint_to_u64(&digest)
    }

    /// Removes the entry with the smallest stamp once the map holds more than
    /// `LRU_CAPACITY` entries; does nothing otherwise.
    fn evict_if_full(&mut self) {
        if self.lru.len() > LRU_CAPACITY {
            let oldest_key = self
                .lru
                .iter()
                .min_by_key(|(_, entry)| entry.stamp)
                .map(|(key, _)| *key);
            if let Some(key) = oldest_key {
                self.lru.remove(&key);
            }
        }
    }
}
impl RowWarmCache for DiskRowWarmCache {
    /// Looks `row_id` up in memory first and in the backing store second.
    ///
    /// A state found in the store is copied into the in-memory map before it
    /// is returned. Lookup errors and undecodable payloads yield `None`.
    fn get(&mut self, row_id: u64) -> Option<RowWarmState> {
        let slot = self.lru_key(row_id);

        // In-memory hit: the stored row id must match, since `slot` is a hash.
        match self.lru.get_mut(&slot) {
            Some(hit) if hit.row_id == row_id => {
                self.stamp += 1;
                hit.stamp = self.stamp;
                return Some(hit.state.clone());
            }
            _ => {}
        }

        // Fall back to the backing store, if there is one.
        let store = self.store.as_ref()?;
        let fp = self.row_fingerprint(row_id);
        let cached = match store.lookup(&fp) {
            Ok(Some(found)) => found,
            Ok(None) | Err(_) => return None,
        };
        let state = RowWarmState::deserialize(&cached.payload)?;

        // Keep a copy in memory so the next read avoids the store.
        self.stamp += 1;
        let promoted = LruEntry {
            row_id,
            state: state.clone(),
            stamp: self.stamp,
        };
        self.lru.insert(slot, promoted);
        self.evict_if_full();

        Some(state)
    }

    /// Records `state` for `row_id` in memory and, when a backing store is
    /// present, saves it there as well. The result of the save is discarded.
    fn put(&mut self, row_id: u64, state: &RowWarmState) {
        let slot = self.lru_key(row_id);

        self.stamp += 1;
        let fresh = LruEntry {
            row_id,
            state: state.clone(),
            stamp: self.stamp,
        };
        self.lru.insert(slot, fresh);
        self.evict_if_full();

        let store = match self.store.as_ref() {
            Some(store) => store,
            None => return,
        };
        let payload = state.serialize();
        let fp = self.row_fingerprint(row_id);
        let iters = Some(u64::from(state.last_inner_iters));
        // Best effort: a failed save leaves the in-memory entry in place.
        let _ = store.save(&fp, &payload, None, iters, EntryKind::Final);
    }
}
/// Reduces a fingerprint to 64 bits.
///
/// Up to the first eight bytes are folded into an accumulator, each step
/// shifting the accumulator left by eight bits and XOR-ing in the next byte;
/// the accumulator is then passed through `splitmix64_hash`.
fn fingerprint_to_u64(fp: &Fingerprint) -> u64 {
    let folded = fp
        .as_bytes()
        .iter()
        .take(8)
        .fold(0u64, |acc, &byte| acc.wrapping_shl(8) ^ u64::from(byte));
    crate::linalg::utils::splitmix64_hash(folded)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture with positive, negative, and zero coordinates.
    fn sample_state() -> RowWarmState {
        let latent_coords = vec![1.0, -2.5, 0.0, 3.125];
        let active_set = vec![0, 4, 9, 17];
        RowWarmState {
            latent_coords,
            active_set,
            last_inner_iters: 3,
        }
    }

    /// Encoding followed by decoding yields an equal state.
    #[test]
    fn serialize_round_trips() {
        let original = sample_state();
        let encoded = original.serialize();
        let decoded = RowWarmState::deserialize(&encoded).expect("decode");
        assert_eq!(original, decoded);
    }

    /// States differing only in the sign of zero encode to the same bytes,
    /// and encoding the same state twice gives the same bytes.
    #[test]
    fn serialize_is_bit_deterministic() {
        let with_leading = |zero: f64| RowWarmState {
            latent_coords: vec![zero, 1.0],
            active_set: vec![2],
            last_inner_iters: 1,
        };
        let negative = with_leading(-0.0);
        let positive = with_leading(0.0);
        assert_eq!(negative.serialize(), positive.serialize());
        assert_eq!(negative.serialize(), negative.serialize());
    }

    /// Flipping every bit of the first schema byte makes the payload undecodable.
    #[test]
    fn deserialize_rejects_wrong_schema() {
        let mut encoded = sample_state().serialize();
        encoded[0] = !encoded[0];
        assert!(RowWarmState::deserialize(&encoded).is_none());
    }

    /// One extra byte after a valid payload makes it undecodable.
    #[test]
    fn deserialize_rejects_trailing_garbage() {
        let mut encoded = sample_state().serialize();
        encoded.extend_from_slice(&[0u8]);
        assert!(RowWarmState::deserialize(&encoded).is_none());
    }

    /// Dropping the last two bytes of a valid payload makes it undecodable.
    #[test]
    fn deserialize_rejects_truncation() {
        let encoded = sample_state().serialize();
        let cut = encoded.len() - 2;
        assert!(RowWarmState::deserialize(&encoded[..cut]).is_none());
    }
}