Skip to main content

gam_solve/
latent_cache.rs

1//! Persistent cache for latent-coordinate REML design evaluations.
2//!
3//! This follows the same invalidation shape as the design-revision pattern in
4//! `src/terms/smooth.rs`'s `SpatialLogKappa` path (around the
5//! `SpatialLogKappa` cache near line 12805) and the `EvalShared` rho-keyed
6//! cache in `src/solver/reml/mod.rs` (around line 3525).  REML's outer
7//! evaluator is reentrant for each theta: the rho component is already covered
8//! by `EvalShared`, while the design-moving component is fully determined by
9//! the latent fingerprint.  Together, `(rho, latent_fingerprint)` is sufficient
10//! to reuse the realized surface until the caller bumps the design revision or
11//! explicitly invalidates this cache.
12
13use std::collections::hash_map::DefaultHasher;
14use std::collections::{HashMap, VecDeque};
15use std::hash::{Hash, Hasher};
16use std::path::PathBuf;
17use std::sync::{Arc, Mutex, OnceLock};
18
19use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_terms::basis::{DuchonNullspaceOrder, MaternNu, RadialScalarKind};
22use crate::estimate::EstimationError;
23use crate::estimate::reml::DirectionalHyperParam;
24pub use gam_problem::LatentRetractionRegistry;
25use gam_terms::latent::{
26    AuxPriorFamily, AuxPriorStrength, LatentCoordValues, LatentIdMode, LatentManifold,
27};
28use gam_terms::smooth::{TermCollectionDesign, TermCollectionSpec};
29use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
30
/// Entry limit used by `LatentDesignCache::default()`.
const DEFAULT_LATENT_CACHE_CAPACITY: usize = 4;
/// Entry limit used by `PersistentLatentDesignCache::default()`.
const DEFAULT_PERSISTENT_LATENT_CACHE_CAPACITY: usize = 16;
/// Byte budget given to every `PersistentLatentDesignCache` (1 GiB).
const DEFAULT_PERSISTENT_LATENT_CACHE_BYTE_BUDGET: usize = 1 << 30;
34
/// Process-wide persistent design cache.
///
/// Starts uninitialised; the lookup/insert helpers in this file create it on
/// first use through `get_or_init` and access it under the `Mutex`.
static PERSISTENT_LATENT_DESIGN_CACHE: OnceLock<Mutex<PersistentLatentDesignCache>> =
    OnceLock::new();
37
/// Single-pass identity summary for a flat latent-coordinate vector.
#[derive(Clone, Debug)]
pub(crate) struct LatentFingerprint {
    /// Hash of the element count followed by every value's bit pattern.
    pub(crate) hash: u64,
    /// Number of values that went into the hash.
    pub(crate) len: usize,
}

impl LatentFingerprint {
    /// Fingerprints `flat` by raw bit pattern rather than by numeric value,
    /// so `0.0` and `-0.0` are hashed as distinct inputs.
    pub(crate) fn from_flat(flat: &[f64]) -> Self {
        let len = flat.len();
        let mut state = DefaultHasher::new();
        // The length is hashed first so the count is part of the identity.
        len.hash(&mut state);
        flat.iter()
            .map(|value| value.to_bits())
            .for_each(|bits| bits.hash(&mut state));
        let hash = state.finish();
        Self { hash, len }
    }
}
58
59pub type CacheDigest = Fingerprint;
60
61/// Open a [`Fingerprinter`] pre-seeded with a length-prefixed namespace
62/// string so different cache-digest call sites cannot alias.
63///
64/// This is a thin convenience wrapper over [`Fingerprinter::write_str`]
65/// — it exists so the call sites read as `cache_digest_builder("…-v1")`
66/// instead of repeating the namespace-framing pattern at every callsite.
67fn cache_digest_builder(namespace: &str) -> Fingerprinter {
68    let mut out = Fingerprinter::new();
69    out.write_str(namespace);
70    out
71}
72
/// Basis/evaluator family for Phi(t); the per-row latent values live in
/// `LatentCoordValues`.
///
/// Every field of every variant is fed into `cache_digest`, so two values
/// share a digest only when their configuration hashes identically.
#[derive(Clone)]
pub enum LatentBasisKind {
    /// Matern radial family. `chunk_size: Some(_)` selects the streaming
    /// path, for which no radial distance matrices are cached.
    Matern {
        centers: Array2<f64>,
        length_scale: f64,
        nu: MaternNu,
        aniso_log_scales: Vec<f64>,
        chunk_size: Option<usize>,
    },
    /// Duchon radial family; has no `chunk_size`, so it never streams.
    Duchon {
        centers: Array2<f64>,
        length_scale: Option<f64>,
        power: f64,
        nullspace_order: DuchonNullspaceOrder,
        aniso_log_scales: Vec<f64>,
    },
    /// Spherical radial family. `chunk_size: Some(_)` selects the streaming
    /// path, for which no radial distance matrices are cached.
    Sphere {
        centers: Array2<f64>,
        penalty_order: usize,
        chunk_size: Option<usize>,
    },
    /// Periodic B-spline family; carries no centers.
    PeriodicBspline {
        domain_start: f64,
        period: f64,
        degree: usize,
        num_basis: usize,
        chunk_size: Option<usize>,
    },
    /// Tensor-product B-spline family; carries no centers.
    // NOTE(review): presumably `knots[i]` pairs with `degrees[i]` per axis — confirm.
    TensorBspline {
        knots: Vec<Array1<f64>>,
        degrees: Vec<usize>,
        chunk_size: Option<usize>,
    },
    /// PCA family; carries no centers.
    Pca {
        basis_matrix: Array2<f64>,
        centered: bool,
        // NOTE(review): presumably produced by `pca_center_mean_fingerprint` — confirm.
        center_mean_fingerprint: Option<u64>,
        smooth_penalty: f64,
        // When set, the digest also covers this file's size and modification time.
        pca_basis_path: Option<PathBuf>,
        chunk_size: usize,
    },
}
116
117impl LatentBasisKind {
118    fn centers(&self) -> Option<&Array2<f64>> {
119        match self {
120            Self::Matern { centers, .. }
121            | Self::Duchon { centers, .. }
122            | Self::Sphere { centers, .. } => Some(centers),
123            Self::PeriodicBspline { .. } | Self::TensorBspline { .. } => None,
124            Self::Pca { .. } => None,
125        }
126    }
127
128    fn streams_radial_cache(&self) -> bool {
129        matches!(
130            self,
131            Self::Matern {
132                chunk_size: Some(_),
133                ..
134            } | Self::Sphere {
135                chunk_size: Some(_),
136                ..
137            }
138        )
139    }
140
141    fn cache_digest(&self) -> CacheDigest {
142        let mut hasher = cache_digest_builder("latent-basis-v1");
143        match self {
144            Self::Matern {
145                centers,
146                length_scale,
147                nu,
148                aniso_log_scales,
149                chunk_size,
150            } => {
151                hasher.write_usize(0);
152                hasher.write_usize(centers.nrows());
153                hasher.write_usize(centers.ncols());
154                hasher.write_f64(*length_scale);
155                hasher.write_usize(matern_nu_signature(*nu));
156                hasher.write_f64_slice(aniso_log_scales);
157                hash_optional_usize(*chunk_size, &mut hasher);
158                hasher.write_f64_array2(centers);
159            }
160            Self::Duchon {
161                centers,
162                length_scale,
163                power,
164                nullspace_order,
165                aniso_log_scales,
166            } => {
167                hasher.write_usize(1);
168                hasher.write_usize(centers.nrows());
169                hasher.write_usize(centers.ncols());
170                hash_optional_f64(*length_scale, &mut hasher);
171                hasher.write_u64(power.to_bits());
172                hash_duchon_nullspace_order(*nullspace_order, &mut hasher);
173                hasher.write_f64_slice(aniso_log_scales);
174                hasher.write_f64_array2(centers);
175            }
176            Self::Sphere {
177                centers,
178                penalty_order,
179                chunk_size,
180            } => {
181                hasher.write_usize(2);
182                hasher.write_usize(centers.nrows());
183                hasher.write_usize(centers.ncols());
184                hasher.write_usize(*penalty_order);
185                hash_optional_usize(*chunk_size, &mut hasher);
186                hasher.write_f64_array2(centers);
187            }
188            Self::PeriodicBspline {
189                domain_start,
190                period,
191                degree,
192                num_basis,
193                chunk_size,
194            } => {
195                hasher.write_usize(3);
196                hasher.write_f64(*domain_start);
197                hasher.write_f64(*period);
198                hasher.write_usize(*degree);
199                hasher.write_usize(*num_basis);
200                hash_optional_usize(*chunk_size, &mut hasher);
201            }
202            Self::TensorBspline {
203                knots,
204                degrees,
205                chunk_size,
206            } => {
207                hasher.write_usize(4);
208                hasher.write_usize(degrees.len());
209                for &degree in degrees {
210                    hasher.write_usize(degree);
211                }
212                hash_optional_usize(*chunk_size, &mut hasher);
213                hasher.write_usize(knots.len());
214                for axis_knots in knots {
215                    hasher.write_f64_array1(axis_knots);
216                }
217            }
218            Self::Pca {
219                basis_matrix,
220                centered,
221                center_mean_fingerprint,
222                smooth_penalty,
223                pca_basis_path,
224                chunk_size,
225            } => {
226                hasher.write_usize(5);
227                hasher.write_u8(*centered as u8);
228                if let Some(fp) = center_mean_fingerprint {
229                    hasher.write_u64(*fp);
230                }
231                hasher.write_u64(smooth_penalty.to_bits());
232                if let Some(path) = pca_basis_path {
233                    hasher.write_u8(1);
234                    hasher.write_bytes(path.to_string_lossy().as_bytes());
235                    if let Ok(meta) = std::fs::metadata(path) {
236                        hasher.write_u64(meta.len());
237                        if let Ok(modified) = meta.modified()
238                            && let Ok(elapsed) =
239                                modified.duration_since(std::time::SystemTime::UNIX_EPOCH)
240                        {
241                            hasher.write_u64(elapsed.as_secs());
242                            hasher.write_u64(elapsed.subsec_nanos() as u64);
243                        }
244                    }
245                } else {
246                    hasher.write_u8(0);
247                }
248                hasher.write_usize(*chunk_size);
249                hasher.write_usize(basis_matrix.nrows());
250                hasher.write_usize(basis_matrix.ncols());
251                hasher.write_f64_array2(basis_matrix);
252            }
253        }
254        hasher.finalize()
255    }
256}
257
258pub fn pca_center_mean_fingerprint(mean: &Array1<f64>) -> u64 {
259    let mut hasher = Fingerprinter::new();
260    hasher.write_usize(mean.len());
261    for &value in mean.iter() {
262        hasher.write_f64(value);
263    }
264    hasher.finish_u64()
265}
266
267fn matern_nu_signature(nu: MaternNu) -> usize {
268    match nu {
269        MaternNu::Half => 0,
270        MaternNu::ThreeHalves => 1,
271        MaternNu::FiveHalves => 2,
272        MaternNu::SevenHalves => 3,
273        MaternNu::NineHalves => 4,
274    }
275}
276
277fn hash_duchon_nullspace_order(order: DuchonNullspaceOrder, hasher: &mut Fingerprinter) {
278    match order {
279        DuchonNullspaceOrder::Zero => {
280            hasher.write_usize(0);
281        }
282        DuchonNullspaceOrder::Linear => {
283            hasher.write_usize(1);
284        }
285        DuchonNullspaceOrder::Degree(degree) => {
286            hasher.write_usize(2);
287            hasher.write_usize(degree);
288        }
289    }
290}
291
292fn hash_optional_f64(value: Option<f64>, hasher: &mut Fingerprinter) {
293    match value {
294        Some(value) => {
295            hasher.write_bool(true);
296            hasher.write_f64(value);
297        }
298        None => {
299            hasher.write_bool(false);
300        }
301    }
302}
303
304fn hash_optional_usize(value: Option<usize>, hasher: &mut Fingerprinter) {
305    match value {
306        Some(value) => {
307            hasher.write_bool(true);
308            hasher.write_usize(value);
309        }
310        None => {
311            hasher.write_bool(false);
312        }
313    }
314}
315
316fn latent_metadata_cache_digest(latent: &LatentCoordValues) -> CacheDigest {
317    let mut hasher = cache_digest_builder("latent-cache-metadata-v1");
318    hasher.write_usize(latent.n_obs());
319    hasher.write_usize(latent.latent_dim());
320    hash_latent_manifold(latent.manifold(), &mut hasher);
321    hash_latent_id_mode(latent.id_mode(), &mut hasher);
322    hasher.finalize()
323}
324
325pub fn latent_design_context_cache_digest(
326    data: ArrayView2<'_, f64>,
327    spec: &TermCollectionSpec,
328    term_index: gam_problem::SmoothTermIdx,
329    analytic_rho_count: usize,
330    feature_cols: &[usize],
331) -> Result<CacheDigest, EstimationError> {
332    let mut hasher = cache_digest_builder("latent-design-context-v1");
333    hasher.write_usize(data.nrows());
334    hasher.write_usize(data.ncols());
335    for row in 0..data.nrows() {
336        for col in 0..data.ncols() {
337            hasher.write_f64(data[[row, col]]);
338        }
339    }
340    let spec_bytes = serde_json::to_vec(spec).map_err(|err| {
341        EstimationError::InvalidInput(format!(
342            "failed to serialize latent design cache context: {err}"
343        ))
344    })?;
345    hasher.write_usize(spec_bytes.len());
346    hasher.write_bytes(&spec_bytes);
347    hasher.write_usize(term_index.get());
348    hasher.write_usize(analytic_rho_count);
349    hasher.write_usize(feature_cols.len());
350    for &col in feature_cols {
351        hasher.write_usize(col);
352    }
353    Ok(hasher.finalize())
354}
355
356fn hash_latent_id_mode(id_mode: &LatentIdMode, hasher: &mut Fingerprinter) {
357    match id_mode {
358        LatentIdMode::AuxPrior {
359            u,
360            family,
361            strength,
362        } => {
363            hasher.write_usize(0);
364            hasher.write_f64_array2(u);
365            hash_aux_prior_family(*family, hasher);
366            hash_aux_prior_strength(*strength, hasher);
367        }
368        LatentIdMode::AuxPriorDimSelection {
369            u,
370            family,
371            strength,
372            init_log_precision,
373        } => {
374            hasher.write_usize(1);
375            hasher.write_f64_array2(u);
376            hash_aux_prior_family(*family, hasher);
377            hash_aux_prior_strength(*strength, hasher);
378            hash_optional_vector(init_log_precision.as_ref(), hasher);
379        }
380        LatentIdMode::DimSelection { init_log_precision } => {
381            hasher.write_usize(2);
382            hash_optional_vector(init_log_precision.as_ref(), hasher);
383        }
384        LatentIdMode::IsometryToReference {
385            reference,
386            strength,
387        } => {
388            hasher.write_usize(5);
389            hasher.write_f64_array2(reference);
390            hash_aux_prior_strength(*strength, hasher);
391        }
392        LatentIdMode::AuxOutcome {
393            head,
394            init_log_precision,
395        } => {
396            hasher.write_usize(4);
397            hash_behavioral_head(head, hasher);
398            hash_optional_vector(init_log_precision.as_ref(), hasher);
399        }
400        LatentIdMode::None => {
401            hasher.write_usize(3);
402        }
403    }
404}
405
406fn hash_behavioral_head(
407    head: &gam_terms::decoders::behavioral_head::BehavioralHead,
408    hasher: &mut Fingerprinter,
409) {
410    use gam_terms::decoders::behavioral_head::AuxOutcomeFamily;
411    match head.family() {
412        AuxOutcomeFamily::Binomial => hasher.write_usize(0),
413        AuxOutcomeFamily::Multinomial { n_classes } => {
414            hasher.write_usize(1);
415            hasher.write_usize(n_classes);
416        }
417    }
418    hasher.write_usize(head.n_obs());
419    hasher.write_f64(head.effective_labeled_count());
420}
421
422fn hash_aux_prior_family(family: AuxPriorFamily, hasher: &mut Fingerprinter) {
423    hasher.write_usize(match family {
424        AuxPriorFamily::Ridge => 0,
425        AuxPriorFamily::Linear => 1,
426    });
427}
428
429fn hash_aux_prior_strength(strength: AuxPriorStrength, hasher: &mut Fingerprinter) {
430    match strength {
431        AuxPriorStrength::Auto => {
432            hasher.write_usize(0);
433        }
434        AuxPriorStrength::Fixed(value) => {
435            hasher.write_usize(1);
436            hasher.write_f64(value);
437        }
438    }
439}
440
441fn hash_optional_vector(vector: Option<&Array1<f64>>, hasher: &mut Fingerprinter) {
442    match vector {
443        Some(vector) => {
444            hasher.write_bool(true);
445            hasher.write_f64_array1(vector);
446        }
447        None => {
448            hasher.write_bool(false);
449        }
450    }
451}
452
453fn hash_latent_manifold(manifold: &LatentManifold, hasher: &mut Fingerprinter) {
454    match manifold {
455        LatentManifold::Euclidean => {
456            hasher.write_usize(0);
457        }
458        LatentManifold::Circle { period } => {
459            hasher.write_usize(1);
460            hasher.write_f64(*period);
461        }
462        LatentManifold::Sphere { dim } => {
463            hasher.write_usize(2);
464            hasher.write_usize(*dim);
465        }
466        LatentManifold::Interval { lo, hi } => {
467            hasher.write_usize(3);
468            hasher.write_f64(*lo);
469            hasher.write_f64(*hi);
470        }
471        LatentManifold::Product(parts) => {
472            hasher.write_usize(4);
473            hasher.write_usize(parts.len());
474            for part in parts {
475                hash_latent_manifold(part, hasher);
476            }
477        }
478        LatentManifold::ProductWithMetric { manifolds, weights } => {
479            hasher.write_usize(5);
480            hasher.write_usize(manifolds.len());
481            for part in manifolds {
482                hash_latent_manifold(part, hasher);
483            }
484            hasher.write_f64_slice(weights);
485        }
486    }
487}
488
/// Radial distance matrices cached alongside a design.
///
/// Both matrices are 0×0 when the basis streams its radial cache or has no
/// centers.
// NOTE(review): presumably latent-row × center distances, with `squared`
// holding the squares of `distance` — confirm against `build_radial_distances`.
#[derive(Clone)]
pub(crate) struct RadialDistanceMatrices {
    pub(crate) squared: Array2<f64>,
    pub(crate) distance: Array2<f64>,
}
494
/// Optional per-basis matrices cached alongside a design.
///
/// Every present matrix is counted in `CachedDesign::resident_scalar_count`,
/// and `operator_resident` adds one further unit to that count.
// NOTE(review): the mathematical meaning of `phi`, `q`, `t`, `phi_r` and
// `phi_rr` is not visible in this part of the file — confirm against
// `build_basis_derivative_jets`.
#[derive(Clone)]
pub(crate) struct BasisDerivativeJets {
    pub(crate) phi: Option<Array2<f64>>,
    pub(crate) q: Option<Array2<f64>>,
    pub(crate) t: Option<Array2<f64>>,
    pub(crate) phi_r: Option<Array2<f64>>,
    pub(crate) phi_rr: Option<Array2<f64>>,
    pub(crate) operator_resident: bool,
}

impl BasisDerivativeJets {
    /// Jets with no matrices stored and `operator_resident` unset.
    fn empty() -> Self {
        Self {
            phi: None,
            q: None,
            t: None,
            phi_r: None,
            phi_rr: None,
            operator_resident: false,
        }
    }
}
517
/// One realized design together with the identity data used to validate a
/// cache hit.
#[derive(Clone)]
pub struct CachedDesign {
    // `latent.latent_id()` at the time the entry was built.
    pub(crate) latent_id: u64,
    // Hash/length summary of the flat latent values.
    pub(crate) fingerprint: LatentFingerprint,
    // Digests of the basis configuration, the latent metadata and the design
    // context; all three must match for a hit.
    basis_digest: CacheDigest,
    latent_metadata_digest: CacheDigest,
    design_context_digest: CacheDigest,
    // Exact bit patterns of the latent values, compared element-wise on
    // lookup so a fingerprint collision cannot produce a false hit.
    latent_bits: Arc<[u64]>,
    // False when any latent value was non-finite at build time; such entries
    // are never matched by lookups and never stored in the persistent cache.
    cacheable: bool,
    /// Design returned by the compute closure.
    pub design: TermCollectionDesign,
    /// Directional hyper-parameters returned by the compute closure.
    pub hyper_dirs: Vec<DirectionalHyperParam>,
    pub(crate) radial_distances: RadialDistanceMatrices,
    pub(crate) basis_derivative_jets: BasisDerivativeJets,
}
532
/// Result of the caller-supplied compute closure passed to
/// `LatentDesignCache::lookup_or_compute`; both fields are moved into the
/// resulting `CachedDesign`.
pub struct ComputedLatentDesign {
    pub design: TermCollectionDesign,
    pub hyper_dirs: Vec<DirectionalHyperParam>,
}
537
/// Borrowed view of a cache entry returned by
/// `LatentDesignCache::lookup_or_compute`.
pub struct LatentDesignLookup<'a> {
    /// The cached design, borrowed from the cache.
    pub cached: &'a CachedDesign,
    /// Identifier of the local cache entry holding `cached`.
    pub entry_id: u64,
}
542
/// Hash-map key of the persistent cache: latent id, fingerprint hash and the
/// three digests. A key match is only a candidate; `lookup` still compares
/// the stored latent bit patterns before reporting a hit.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct PersistentLatentDesignKey {
    latent_id: u64,
    flat_hash: u64,
    basis_digest: CacheDigest,
    latent_metadata_digest: CacheDigest,
    design_context_digest: CacheDigest,
}
551
/// Value stored in the persistent cache for one key.
struct PersistentLatentDesignEntry {
    // Copy of `cached.fingerprint`, re-checked on lookup.
    fingerprint: LatentFingerprint,
    cached: Arc<CachedDesign>,
    // `cached.resident_byte_count()` at insertion; used for budget accounting.
    bytes: usize,
}
557
558pub(crate) struct PersistentLatentDesignCache {
559    entries: HashMap<PersistentLatentDesignKey, PersistentLatentDesignEntry>,
560    lru: VecDeque<PersistentLatentDesignKey>,
561    capacity: usize,
562    byte_budget: usize,
563    cache_bytes: usize,
564}
565
566impl Default for PersistentLatentDesignCache {
567    fn default() -> Self {
568        Self::new(DEFAULT_PERSISTENT_LATENT_CACHE_CAPACITY)
569    }
570}
571
572impl PersistentLatentDesignCache {
573    pub(crate) fn new(capacity: usize) -> Self {
574        Self {
575            entries: HashMap::new(),
576            lru: VecDeque::new(),
577            capacity: capacity.max(1),
578            byte_budget: DEFAULT_PERSISTENT_LATENT_CACHE_BYTE_BUDGET,
579            cache_bytes: 0,
580        }
581    }
582
583    pub(crate) fn lookup(
584        &mut self,
585        latent: &LatentCoordValues,
586        basis_digest: CacheDigest,
587        latent_metadata_digest: CacheDigest,
588        design_context_digest: CacheDigest,
589        fingerprint: &LatentFingerprint,
590    ) -> Result<Option<Arc<CachedDesign>>, EstimationError> {
591        let key = PersistentLatentDesignKey {
592            latent_id: latent.latent_id(),
593            flat_hash: fingerprint.hash,
594            basis_digest,
595            latent_metadata_digest,
596            design_context_digest,
597        };
598        let Some(entry) = self.entries.get(&key) else {
599            return Ok(None);
600        };
601        let cached = entry.cached.clone();
602        let entry_fingerprint = entry.fingerprint.clone();
603        self.touch(key);
604        if entry_fingerprint.len != fingerprint.len {
605            return Ok(None);
606        }
607        if entry_fingerprint.hash == fingerprint.hash
608            && cached.cacheable
609            && cached.basis_digest == basis_digest
610            && cached.latent_metadata_digest == latent_metadata_digest
611            && cached.design_context_digest == design_context_digest
612            && latent_bits_match(latent, &cached.latent_bits)
613        {
614            return Ok(Some(cached));
615        }
616        Ok(None)
617    }
618
619    pub(crate) fn insert(&mut self, cached: Arc<CachedDesign>) {
620        if !cached.cacheable {
621            return;
622        }
623        let bytes = cached.resident_byte_count();
624        if bytes > self.byte_budget {
625            return;
626        }
627        let key = PersistentLatentDesignKey {
628            latent_id: cached.latent_id,
629            flat_hash: cached.fingerprint.hash,
630            basis_digest: cached.basis_digest,
631            latent_metadata_digest: cached.latent_metadata_digest,
632            design_context_digest: cached.design_context_digest,
633        };
634        let entry = PersistentLatentDesignEntry {
635            fingerprint: cached.fingerprint.clone(),
636            cached,
637            bytes,
638        };
639        if let Some(old) = self.entries.insert(key, entry) {
640            self.cache_bytes = self.cache_bytes.saturating_sub(old.bytes);
641        }
642        self.cache_bytes = self.cache_bytes.saturating_add(bytes);
643        self.touch(key);
644        self.evict_to_limits();
645    }
646
647    fn evict_to_limits(&mut self) {
648        while self.entries.len() > self.capacity || self.cache_bytes > self.byte_budget {
649            let Some(evicted) = self.lru.pop_front() else {
650                break;
651            };
652            if let Some(entry) = self.entries.remove(&evicted) {
653                self.cache_bytes = self.cache_bytes.saturating_sub(entry.bytes);
654            }
655        }
656    }
657
658    fn touch(&mut self, key: PersistentLatentDesignKey) {
659        if let Some(index) = self.lru.iter().position(|queued| *queued == key) {
660            self.lru.remove(index);
661        }
662        self.lru.push_back(key);
663    }
664}
665
666pub struct LatentDesignCache {
667    entries: Vec<LatentDesignCacheEntry>,
668    capacity: usize,
669    clock: u64,
670    iteration: u64,
671    next_entry_id: u64,
672}
673
674struct LatentDesignCacheEntry {
675    id: u64,
676    cached: Arc<CachedDesign>,
677    last_used: u64,
678    iteration: u64,
679}
680
681impl Default for LatentDesignCache {
682    fn default() -> Self {
683        Self::new(DEFAULT_LATENT_CACHE_CAPACITY)
684    }
685}
686
687impl LatentDesignCache {
688    pub(crate) fn new(capacity: usize) -> Self {
689        Self {
690            entries: Vec::new(),
691            capacity: capacity.max(1),
692            clock: 0,
693            iteration: 0,
694            next_entry_id: 0,
695        }
696    }
697
698    pub fn invalidate(&mut self) {
699        self.entries.clear();
700    }
701
702    pub fn invalidate_all(&mut self) {
703        self.entries.clear();
704        self.clock = self.clock.wrapping_add(1);
705        self.iteration = self.iteration.wrapping_add(1);
706    }
707
708    pub fn lookup_or_compute<F>(
709        &mut self,
710        latent: Arc<LatentCoordValues>,
711        basis_kind: LatentBasisKind,
712        design_context_digest: CacheDigest,
713        compute: F,
714    ) -> Result<LatentDesignLookup<'_>, EstimationError>
715    where
716        F: FnOnce() -> Result<ComputedLatentDesign, EstimationError>,
717    {
718        self.iteration = self.iteration.wrapping_add(1);
719        self.clock = self.clock.wrapping_add(1);
720        let flat = latent.as_flat();
721        let flat_slice = flat
722            .as_slice()
723            .expect("LatentCoordValues flat storage must be contiguous");
724        let fingerprint = LatentFingerprint::from_flat(flat_slice);
725        let basis_digest = basis_kind.cache_digest();
726        let latent_metadata_digest = latent_metadata_cache_digest(&latent);
727        let cacheable = flat_slice.iter().all(|value| value.is_finite());
728        if cacheable
729            && let Some(index) = self.find_entry(
730                &latent,
731                basis_digest,
732                latent_metadata_digest,
733                design_context_digest,
734            )
735        {
736            self.entries[index].last_used = self.clock;
737            return Ok(LatentDesignLookup {
738                cached: self.entries[index].cached.as_ref(),
739                entry_id: self.entries[index].id,
740            });
741        }
742        if cacheable
743            && let Some(cached) = lookup_persistent_latent_design(
744                &latent,
745                basis_digest,
746                latent_metadata_digest,
747                design_context_digest,
748                &fingerprint,
749            )?
750        {
751            let id = self.next_entry_id;
752            self.next_entry_id = self.next_entry_id.wrapping_add(1);
753            self.insert(cached, id);
754            return self.lookup_inserted(id);
755        }
756
757        let computed = compute()?;
758        let radial_distances = if basis_kind.streams_radial_cache() {
759            RadialDistanceMatrices {
760                squared: Array2::<f64>::zeros((0, 0)),
761                distance: Array2::<f64>::zeros((0, 0)),
762            }
763        } else {
764            match basis_kind.centers() {
765                Some(centers) => build_radial_distances(&latent, centers)?,
766                None => RadialDistanceMatrices {
767                    squared: Array2::<f64>::zeros((0, 0)),
768                    distance: Array2::<f64>::zeros((0, 0)),
769                },
770            }
771        };
772        let basis_derivative_jets = build_basis_derivative_jets(&basis_kind, &radial_distances)?;
773        let id = self.next_entry_id;
774        self.next_entry_id = self.next_entry_id.wrapping_add(1);
775        let entry = Arc::new(CachedDesign {
776            latent_id: latent.latent_id(),
777            fingerprint,
778            basis_digest,
779            latent_metadata_digest,
780            design_context_digest,
781            latent_bits: latent_bits(&latent),
782            cacheable,
783            design: computed.design,
784            hyper_dirs: computed.hyper_dirs,
785            radial_distances,
786            basis_derivative_jets,
787        });
788        if cacheable {
789            insert_persistent_latent_design(Arc::clone(&entry))?;
790        }
791        self.insert(entry, id);
792        self.lookup_inserted(id)
793    }
794
795    fn find_entry(
796        &mut self,
797        latent: &LatentCoordValues,
798        basis_digest: CacheDigest,
799        latent_metadata_digest: CacheDigest,
800        design_context_digest: CacheDigest,
801    ) -> Option<usize> {
802        self.entries.iter().position(|entry| {
803            entry.cached.cacheable
804                && entry.cached.basis_digest == basis_digest
805                && entry.cached.latent_metadata_digest == latent_metadata_digest
806                && entry.cached.design_context_digest == design_context_digest
807                && entry.cached.latent_id == latent.latent_id()
808                && latent_bits_match(latent, &entry.cached.latent_bits)
809        })
810    }
811
812    fn lookup_inserted(&self, id: u64) -> Result<LatentDesignLookup<'_>, EstimationError> {
813        let Some(index) = self.entries.iter().position(|entry| entry.id == id) else {
814            return Err(EstimationError::InvalidInput(
815                "inserted latent design cache entry missing".to_string(),
816            ));
817        };
818        Ok(LatentDesignLookup {
819            cached: self.entries[index].cached.as_ref(),
820            entry_id: self.entries[index].id,
821        })
822    }
823
824    fn insert(&mut self, cached: Arc<CachedDesign>, id: u64) {
825        self.entries.push(LatentDesignCacheEntry {
826            id,
827            cached,
828            last_used: self.clock,
829            iteration: self.iteration,
830        });
831        while self.entries.len() > self.capacity {
832            if let Some(evict_index) = self
833                .entries
834                .iter()
835                .enumerate()
836                .min_by_key(|(_, entry)| (entry.last_used, entry.iteration))
837                .map(|(index, _)| index)
838            {
839                self.entries.remove(evict_index);
840            } else {
841                break;
842            }
843        }
844    }
845}
846
847impl CachedDesign {
848    fn resident_byte_count(&self) -> usize {
849        self.resident_scalar_count()
850            .saturating_mul(std::mem::size_of::<f64>())
851            .saturating_add(
852                self.hyper_dirs
853                    .iter()
854                    .map(DirectionalHyperParam::resident_byte_count)
855                    .sum::<usize>(),
856            )
857    }
858
859    fn resident_scalar_count(&self) -> usize {
860        let mut count = self
861            .design
862            .design
863            .nrows()
864            .saturating_mul(self.design.design.ncols());
865        count = count.saturating_add(
866            self.design
867                .coefficient_lower_bounds
868                .as_ref()
869                .map_or(0, |values| values.len()),
870        );
871        count = count.saturating_add(self.radial_distances.squared.len());
872        count = count.saturating_add(self.radial_distances.distance.len());
873        count = count.saturating_add(
874            self.basis_derivative_jets
875                .phi
876                .as_ref()
877                .map_or(0, |values| values.len()),
878        );
879        count = count.saturating_add(
880            self.basis_derivative_jets
881                .q
882                .as_ref()
883                .map_or(0, |values| values.len()),
884        );
885        count = count.saturating_add(
886            self.basis_derivative_jets
887                .t
888                .as_ref()
889                .map_or(0, |values| values.len()),
890        );
891        count = count.saturating_add(
892            self.basis_derivative_jets
893                .phi_r
894                .as_ref()
895                .map_or(0, |values| values.len()),
896        );
897        count = count.saturating_add(
898            self.basis_derivative_jets
899                .phi_rr
900                .as_ref()
901                .map_or(0, |values| values.len()),
902        );
903        count.saturating_add(usize::from(self.basis_derivative_jets.operator_resident))
904    }
905}
906
907fn lookup_persistent_latent_design(
908    latent: &LatentCoordValues,
909    basis_digest: CacheDigest,
910    latent_metadata_digest: CacheDigest,
911    design_context_digest: CacheDigest,
912    fingerprint: &LatentFingerprint,
913) -> Result<Option<Arc<CachedDesign>>, EstimationError> {
914    let cache = PERSISTENT_LATENT_DESIGN_CACHE
915        .get_or_init(|| Mutex::new(PersistentLatentDesignCache::default()));
916    let mut guard = cache.lock().map_err(|_| {
917        EstimationError::InvalidInput("persistent latent design cache mutex poisoned".to_string())
918    })?;
919    guard.lookup(
920        latent,
921        basis_digest,
922        latent_metadata_digest,
923        design_context_digest,
924        fingerprint,
925    )
926}
927
928fn insert_persistent_latent_design(cached: Arc<CachedDesign>) -> Result<(), EstimationError> {
929    let cache = PERSISTENT_LATENT_DESIGN_CACHE
930        .get_or_init(|| Mutex::new(PersistentLatentDesignCache::default()));
931    let mut guard = cache.lock().map_err(|_| {
932        EstimationError::InvalidInput("persistent latent design cache mutex poisoned".to_string())
933    })?;
934    guard.insert(cached);
935    Ok(())
936}
937
938fn latent_bits(latent: &LatentCoordValues) -> Arc<[u64]> {
939    latent
940        .as_flat()
941        .iter()
942        .map(|value| value.to_bits())
943        .collect::<Vec<_>>()
944        .into()
945}
946
947fn latent_bits_match(latent: &LatentCoordValues, cached_bits: &[u64]) -> bool {
948    latent.as_flat().len() == cached_bits.len()
949        && latent
950            .as_flat()
951            .iter()
952            .zip(cached_bits.iter())
953            .all(|(value, bits)| value.to_bits() == *bits)
954}
955
956fn build_radial_distances(
957    latent: &LatentCoordValues,
958    centers: &Array2<f64>,
959) -> Result<RadialDistanceMatrices, EstimationError> {
960    let t = latent.as_matrix();
961    if t.ncols() != centers.ncols() {
962        return Err(EstimationError::InvalidInput(format!(
963            "latent design cache center dimension mismatch: latent d={}, centers d={}",
964            t.ncols(),
965            centers.ncols()
966        )));
967    }
968    let mut squared = Array2::<f64>::zeros((t.nrows(), centers.nrows()));
969    let mut distance = Array2::<f64>::zeros((t.nrows(), centers.nrows()));
970    for row in 0..t.nrows() {
971        for center in 0..centers.nrows() {
972            let mut r2 = 0.0_f64;
973            for axis in 0..t.ncols() {
974                let delta = t[[row, axis]] - centers[[center, axis]];
975                r2 += delta * delta;
976            }
977            squared[[row, center]] = r2;
978            distance[[row, center]] = r2.sqrt();
979        }
980    }
981    Ok(RadialDistanceMatrices { squared, distance })
982}
983
984fn build_basis_derivative_jets(
985    basis_kind: &LatentBasisKind,
986    distances: &RadialDistanceMatrices,
987) -> Result<BasisDerivativeJets, EstimationError> {
988    match basis_kind {
989        LatentBasisKind::Matern {
990            length_scale,
991            nu,
992            chunk_size,
993            ..
994        } => {
995            if chunk_size.is_some() {
996                return Ok(BasisDerivativeJets {
997                    operator_resident: true,
998                    ..BasisDerivativeJets::empty()
999                });
1000            }
1001            let radial = RadialScalarKind::Matern {
1002                length_scale: *length_scale,
1003                nu: *nu,
1004            };
1005            let mut phi = Array2::<f64>::zeros(distances.distance.raw_dim());
1006            let mut q = Array2::<f64>::zeros(distances.distance.raw_dim());
1007            let mut t = Array2::<f64>::zeros(distances.distance.raw_dim());
1008            for row in 0..distances.distance.nrows() {
1009                for center in 0..distances.distance.ncols() {
1010                    let (phi_value, q_value, t_value) = radial
1011                        .eval_design_triplet(distances.distance[[row, center]])
1012                        .map_err(EstimationError::from)?;
1013                    phi[[row, center]] = phi_value;
1014                    q[[row, center]] = q_value;
1015                    t[[row, center]] = t_value;
1016                }
1017            }
1018            Ok(BasisDerivativeJets {
1019                phi: Some(phi),
1020                q: Some(q),
1021                t: Some(t),
1022                phi_r: None,
1023                phi_rr: None,
1024                operator_resident: false,
1025            })
1026        }
1027        LatentBasisKind::Duchon { .. } => Ok(BasisDerivativeJets {
1028            operator_resident: true,
1029            ..BasisDerivativeJets::empty()
1030        }),
1031        LatentBasisKind::Sphere { .. }
1032        | LatentBasisKind::PeriodicBspline { .. }
1033        | LatentBasisKind::Pca { .. }
1034        | LatentBasisKind::TensorBspline { .. } => Ok(BasisDerivativeJets {
1035            operator_resident: true,
1036            ..BasisDerivativeJets::empty()
1037        }),
1038    }
1039}