1use std::collections::hash_map::DefaultHasher;
14use std::collections::{HashMap, VecDeque};
15use std::hash::{Hash, Hasher};
16use std::path::PathBuf;
17use std::sync::{Arc, Mutex, OnceLock};
18
19use ndarray::{Array1, Array2, ArrayView2};
20
21use gam_terms::basis::{DuchonNullspaceOrder, MaternNu, RadialScalarKind};
22use crate::estimate::EstimationError;
23use crate::estimate::reml::DirectionalHyperParam;
24pub use gam_problem::LatentRetractionRegistry;
25use gam_terms::latent::{
26 AuxPriorFamily, AuxPriorStrength, LatentCoordValues, LatentIdMode, LatentManifold,
27};
28use gam_terms::smooth::{TermCollectionDesign, TermCollectionSpec};
29use gam_runtime::warm_start::{Fingerprint, Fingerprinter};
30
/// Default number of entries kept by a per-instance [`LatentDesignCache`].
const DEFAULT_LATENT_CACHE_CAPACITY: usize = 4;
/// Default number of entries kept by the process-wide persistent cache.
const DEFAULT_PERSISTENT_LATENT_CACHE_CAPACITY: usize = 16;
/// Default byte budget of the process-wide persistent cache (1 GiB).
const DEFAULT_PERSISTENT_LATENT_CACHE_BYTE_BUDGET: usize = 1 << 30;
34
/// Process-wide persistent latent design cache.
///
/// The cache is created on first use with the default capacity and is
/// guarded by a mutex; a poisoned mutex is reported to callers as an error.
static PERSISTENT_LATENT_DESIGN_CACHE: OnceLock<Mutex<PersistentLatentDesignCache>> =
    OnceLock::new();
37
/// Hash-and-length summary of a flat latent coordinate buffer.
#[derive(Clone, Debug)]
pub(crate) struct LatentFingerprint {
    /// `DefaultHasher` digest of the buffer length followed by the bit
    /// pattern of every value.
    pub(crate) hash: u64,
    /// Number of values in the fingerprinted buffer.
    pub(crate) len: usize,
}

impl LatentFingerprint {
    /// Fingerprints `flat` by hashing its length and then the raw bits of
    /// each value, in order.
    ///
    /// Values are hashed by bit pattern, so `0.0` and `-0.0` feed different
    /// input to the hasher.
    pub(crate) fn from_flat(flat: &[f64]) -> Self {
        let len = flat.len();
        let mut state = DefaultHasher::new();
        len.hash(&mut state);
        flat.iter()
            .map(|value| value.to_bits())
            .for_each(|bits| bits.hash(&mut state));
        let hash = state.finish();
        Self { hash, len }
    }
}
58
/// Digest type used for every cache key component in this module; an alias
/// for the warm-start [`Fingerprint`].
pub type CacheDigest = Fingerprint;
60
61fn cache_digest_builder(namespace: &str) -> Fingerprinter {
68 let mut out = Fingerprinter::new();
69 out.write_str(namespace);
70 out
71}
72
73#[derive(Clone)]
74pub enum LatentBasisKind {
75 Matern {
77 centers: Array2<f64>,
78 length_scale: f64,
79 nu: MaternNu,
80 aniso_log_scales: Vec<f64>,
81 chunk_size: Option<usize>,
82 },
83 Duchon {
84 centers: Array2<f64>,
85 length_scale: Option<f64>,
86 power: f64,
87 nullspace_order: DuchonNullspaceOrder,
88 aniso_log_scales: Vec<f64>,
89 },
90 Sphere {
91 centers: Array2<f64>,
92 penalty_order: usize,
93 chunk_size: Option<usize>,
94 },
95 PeriodicBspline {
96 domain_start: f64,
97 period: f64,
98 degree: usize,
99 num_basis: usize,
100 chunk_size: Option<usize>,
101 },
102 TensorBspline {
103 knots: Vec<Array1<f64>>,
104 degrees: Vec<usize>,
105 chunk_size: Option<usize>,
106 },
107 Pca {
108 basis_matrix: Array2<f64>,
109 centered: bool,
110 center_mean_fingerprint: Option<u64>,
111 smooth_penalty: f64,
112 pca_basis_path: Option<PathBuf>,
113 chunk_size: usize,
114 },
115}
116
117impl LatentBasisKind {
118 fn centers(&self) -> Option<&Array2<f64>> {
119 match self {
120 Self::Matern { centers, .. }
121 | Self::Duchon { centers, .. }
122 | Self::Sphere { centers, .. } => Some(centers),
123 Self::PeriodicBspline { .. } | Self::TensorBspline { .. } => None,
124 Self::Pca { .. } => None,
125 }
126 }
127
128 fn streams_radial_cache(&self) -> bool {
129 matches!(
130 self,
131 Self::Matern {
132 chunk_size: Some(_),
133 ..
134 } | Self::Sphere {
135 chunk_size: Some(_),
136 ..
137 }
138 )
139 }
140
141 fn cache_digest(&self) -> CacheDigest {
142 let mut hasher = cache_digest_builder("latent-basis-v1");
143 match self {
144 Self::Matern {
145 centers,
146 length_scale,
147 nu,
148 aniso_log_scales,
149 chunk_size,
150 } => {
151 hasher.write_usize(0);
152 hasher.write_usize(centers.nrows());
153 hasher.write_usize(centers.ncols());
154 hasher.write_f64(*length_scale);
155 hasher.write_usize(matern_nu_signature(*nu));
156 hasher.write_f64_slice(aniso_log_scales);
157 hash_optional_usize(*chunk_size, &mut hasher);
158 hasher.write_f64_array2(centers);
159 }
160 Self::Duchon {
161 centers,
162 length_scale,
163 power,
164 nullspace_order,
165 aniso_log_scales,
166 } => {
167 hasher.write_usize(1);
168 hasher.write_usize(centers.nrows());
169 hasher.write_usize(centers.ncols());
170 hash_optional_f64(*length_scale, &mut hasher);
171 hasher.write_u64(power.to_bits());
172 hash_duchon_nullspace_order(*nullspace_order, &mut hasher);
173 hasher.write_f64_slice(aniso_log_scales);
174 hasher.write_f64_array2(centers);
175 }
176 Self::Sphere {
177 centers,
178 penalty_order,
179 chunk_size,
180 } => {
181 hasher.write_usize(2);
182 hasher.write_usize(centers.nrows());
183 hasher.write_usize(centers.ncols());
184 hasher.write_usize(*penalty_order);
185 hash_optional_usize(*chunk_size, &mut hasher);
186 hasher.write_f64_array2(centers);
187 }
188 Self::PeriodicBspline {
189 domain_start,
190 period,
191 degree,
192 num_basis,
193 chunk_size,
194 } => {
195 hasher.write_usize(3);
196 hasher.write_f64(*domain_start);
197 hasher.write_f64(*period);
198 hasher.write_usize(*degree);
199 hasher.write_usize(*num_basis);
200 hash_optional_usize(*chunk_size, &mut hasher);
201 }
202 Self::TensorBspline {
203 knots,
204 degrees,
205 chunk_size,
206 } => {
207 hasher.write_usize(4);
208 hasher.write_usize(degrees.len());
209 for °ree in degrees {
210 hasher.write_usize(degree);
211 }
212 hash_optional_usize(*chunk_size, &mut hasher);
213 hasher.write_usize(knots.len());
214 for axis_knots in knots {
215 hasher.write_f64_array1(axis_knots);
216 }
217 }
218 Self::Pca {
219 basis_matrix,
220 centered,
221 center_mean_fingerprint,
222 smooth_penalty,
223 pca_basis_path,
224 chunk_size,
225 } => {
226 hasher.write_usize(5);
227 hasher.write_u8(*centered as u8);
228 if let Some(fp) = center_mean_fingerprint {
229 hasher.write_u64(*fp);
230 }
231 hasher.write_u64(smooth_penalty.to_bits());
232 if let Some(path) = pca_basis_path {
233 hasher.write_u8(1);
234 hasher.write_bytes(path.to_string_lossy().as_bytes());
235 if let Ok(meta) = std::fs::metadata(path) {
236 hasher.write_u64(meta.len());
237 if let Ok(modified) = meta.modified()
238 && let Ok(elapsed) =
239 modified.duration_since(std::time::SystemTime::UNIX_EPOCH)
240 {
241 hasher.write_u64(elapsed.as_secs());
242 hasher.write_u64(elapsed.subsec_nanos() as u64);
243 }
244 }
245 } else {
246 hasher.write_u8(0);
247 }
248 hasher.write_usize(*chunk_size);
249 hasher.write_usize(basis_matrix.nrows());
250 hasher.write_usize(basis_matrix.ncols());
251 hasher.write_f64_array2(basis_matrix);
252 }
253 }
254 hasher.finalize()
255 }
256}
257
258pub fn pca_center_mean_fingerprint(mean: &Array1<f64>) -> u64 {
259 let mut hasher = Fingerprinter::new();
260 hasher.write_usize(mean.len());
261 for &value in mean.iter() {
262 hasher.write_f64(value);
263 }
264 hasher.finish_u64()
265}
266
267fn matern_nu_signature(nu: MaternNu) -> usize {
268 match nu {
269 MaternNu::Half => 0,
270 MaternNu::ThreeHalves => 1,
271 MaternNu::FiveHalves => 2,
272 MaternNu::SevenHalves => 3,
273 MaternNu::NineHalves => 4,
274 }
275}
276
277fn hash_duchon_nullspace_order(order: DuchonNullspaceOrder, hasher: &mut Fingerprinter) {
278 match order {
279 DuchonNullspaceOrder::Zero => {
280 hasher.write_usize(0);
281 }
282 DuchonNullspaceOrder::Linear => {
283 hasher.write_usize(1);
284 }
285 DuchonNullspaceOrder::Degree(degree) => {
286 hasher.write_usize(2);
287 hasher.write_usize(degree);
288 }
289 }
290}
291
292fn hash_optional_f64(value: Option<f64>, hasher: &mut Fingerprinter) {
293 match value {
294 Some(value) => {
295 hasher.write_bool(true);
296 hasher.write_f64(value);
297 }
298 None => {
299 hasher.write_bool(false);
300 }
301 }
302}
303
304fn hash_optional_usize(value: Option<usize>, hasher: &mut Fingerprinter) {
305 match value {
306 Some(value) => {
307 hasher.write_bool(true);
308 hasher.write_usize(value);
309 }
310 None => {
311 hasher.write_bool(false);
312 }
313 }
314}
315
316fn latent_metadata_cache_digest(latent: &LatentCoordValues) -> CacheDigest {
317 let mut hasher = cache_digest_builder("latent-cache-metadata-v1");
318 hasher.write_usize(latent.n_obs());
319 hasher.write_usize(latent.latent_dim());
320 hash_latent_manifold(latent.manifold(), &mut hasher);
321 hash_latent_id_mode(latent.id_mode(), &mut hasher);
322 hasher.finalize()
323}
324
325pub fn latent_design_context_cache_digest(
326 data: ArrayView2<'_, f64>,
327 spec: &TermCollectionSpec,
328 term_index: gam_problem::SmoothTermIdx,
329 analytic_rho_count: usize,
330 feature_cols: &[usize],
331) -> Result<CacheDigest, EstimationError> {
332 let mut hasher = cache_digest_builder("latent-design-context-v1");
333 hasher.write_usize(data.nrows());
334 hasher.write_usize(data.ncols());
335 for row in 0..data.nrows() {
336 for col in 0..data.ncols() {
337 hasher.write_f64(data[[row, col]]);
338 }
339 }
340 let spec_bytes = serde_json::to_vec(spec).map_err(|err| {
341 EstimationError::InvalidInput(format!(
342 "failed to serialize latent design cache context: {err}"
343 ))
344 })?;
345 hasher.write_usize(spec_bytes.len());
346 hasher.write_bytes(&spec_bytes);
347 hasher.write_usize(term_index.get());
348 hasher.write_usize(analytic_rho_count);
349 hasher.write_usize(feature_cols.len());
350 for &col in feature_cols {
351 hasher.write_usize(col);
352 }
353 Ok(hasher.finalize())
354}
355
356fn hash_latent_id_mode(id_mode: &LatentIdMode, hasher: &mut Fingerprinter) {
357 match id_mode {
358 LatentIdMode::AuxPrior {
359 u,
360 family,
361 strength,
362 } => {
363 hasher.write_usize(0);
364 hasher.write_f64_array2(u);
365 hash_aux_prior_family(*family, hasher);
366 hash_aux_prior_strength(*strength, hasher);
367 }
368 LatentIdMode::AuxPriorDimSelection {
369 u,
370 family,
371 strength,
372 init_log_precision,
373 } => {
374 hasher.write_usize(1);
375 hasher.write_f64_array2(u);
376 hash_aux_prior_family(*family, hasher);
377 hash_aux_prior_strength(*strength, hasher);
378 hash_optional_vector(init_log_precision.as_ref(), hasher);
379 }
380 LatentIdMode::DimSelection { init_log_precision } => {
381 hasher.write_usize(2);
382 hash_optional_vector(init_log_precision.as_ref(), hasher);
383 }
384 LatentIdMode::IsometryToReference {
385 reference,
386 strength,
387 } => {
388 hasher.write_usize(5);
389 hasher.write_f64_array2(reference);
390 hash_aux_prior_strength(*strength, hasher);
391 }
392 LatentIdMode::AuxOutcome {
393 head,
394 init_log_precision,
395 } => {
396 hasher.write_usize(4);
397 hash_behavioral_head(head, hasher);
398 hash_optional_vector(init_log_precision.as_ref(), hasher);
399 }
400 LatentIdMode::None => {
401 hasher.write_usize(3);
402 }
403 }
404}
405
406fn hash_behavioral_head(
407 head: &gam_terms::decoders::behavioral_head::BehavioralHead,
408 hasher: &mut Fingerprinter,
409) {
410 use gam_terms::decoders::behavioral_head::AuxOutcomeFamily;
411 match head.family() {
412 AuxOutcomeFamily::Binomial => hasher.write_usize(0),
413 AuxOutcomeFamily::Multinomial { n_classes } => {
414 hasher.write_usize(1);
415 hasher.write_usize(n_classes);
416 }
417 }
418 hasher.write_usize(head.n_obs());
419 hasher.write_f64(head.effective_labeled_count());
420}
421
422fn hash_aux_prior_family(family: AuxPriorFamily, hasher: &mut Fingerprinter) {
423 hasher.write_usize(match family {
424 AuxPriorFamily::Ridge => 0,
425 AuxPriorFamily::Linear => 1,
426 });
427}
428
429fn hash_aux_prior_strength(strength: AuxPriorStrength, hasher: &mut Fingerprinter) {
430 match strength {
431 AuxPriorStrength::Auto => {
432 hasher.write_usize(0);
433 }
434 AuxPriorStrength::Fixed(value) => {
435 hasher.write_usize(1);
436 hasher.write_f64(value);
437 }
438 }
439}
440
441fn hash_optional_vector(vector: Option<&Array1<f64>>, hasher: &mut Fingerprinter) {
442 match vector {
443 Some(vector) => {
444 hasher.write_bool(true);
445 hasher.write_f64_array1(vector);
446 }
447 None => {
448 hasher.write_bool(false);
449 }
450 }
451}
452
453fn hash_latent_manifold(manifold: &LatentManifold, hasher: &mut Fingerprinter) {
454 match manifold {
455 LatentManifold::Euclidean => {
456 hasher.write_usize(0);
457 }
458 LatentManifold::Circle { period } => {
459 hasher.write_usize(1);
460 hasher.write_f64(*period);
461 }
462 LatentManifold::Sphere { dim } => {
463 hasher.write_usize(2);
464 hasher.write_usize(*dim);
465 }
466 LatentManifold::Interval { lo, hi } => {
467 hasher.write_usize(3);
468 hasher.write_f64(*lo);
469 hasher.write_f64(*hi);
470 }
471 LatentManifold::Product(parts) => {
472 hasher.write_usize(4);
473 hasher.write_usize(parts.len());
474 for part in parts {
475 hash_latent_manifold(part, hasher);
476 }
477 }
478 LatentManifold::ProductWithMetric { manifolds, weights } => {
479 hasher.write_usize(5);
480 hasher.write_usize(manifolds.len());
481 for part in manifolds {
482 hash_latent_manifold(part, hasher);
483 }
484 hasher.write_f64_slice(weights);
485 }
486 }
487}
488
489#[derive(Clone)]
490pub(crate) struct RadialDistanceMatrices {
491 pub(crate) squared: Array2<f64>,
492 pub(crate) distance: Array2<f64>,
493}
494
495#[derive(Clone)]
496pub(crate) struct BasisDerivativeJets {
497 pub(crate) phi: Option<Array2<f64>>,
498 pub(crate) q: Option<Array2<f64>>,
499 pub(crate) t: Option<Array2<f64>>,
500 pub(crate) phi_r: Option<Array2<f64>>,
501 pub(crate) phi_rr: Option<Array2<f64>>,
502 pub(crate) operator_resident: bool,
503}
504
505impl BasisDerivativeJets {
506 fn empty() -> Self {
507 Self {
508 phi: None,
509 q: None,
510 t: None,
511 phi_r: None,
512 phi_rr: None,
513 operator_resident: false,
514 }
515 }
516}
517
518#[derive(Clone)]
519pub struct CachedDesign {
520 pub(crate) latent_id: u64,
521 pub(crate) fingerprint: LatentFingerprint,
522 basis_digest: CacheDigest,
523 latent_metadata_digest: CacheDigest,
524 design_context_digest: CacheDigest,
525 latent_bits: Arc<[u64]>,
526 cacheable: bool,
527 pub design: TermCollectionDesign,
528 pub hyper_dirs: Vec<DirectionalHyperParam>,
529 pub(crate) radial_distances: RadialDistanceMatrices,
530 pub(crate) basis_derivative_jets: BasisDerivativeJets,
531}
532
533pub struct ComputedLatentDesign {
534 pub design: TermCollectionDesign,
535 pub hyper_dirs: Vec<DirectionalHyperParam>,
536}
537
538pub struct LatentDesignLookup<'a> {
539 pub cached: &'a CachedDesign,
540 pub entry_id: u64,
541}
542
543#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
544struct PersistentLatentDesignKey {
545 latent_id: u64,
546 flat_hash: u64,
547 basis_digest: CacheDigest,
548 latent_metadata_digest: CacheDigest,
549 design_context_digest: CacheDigest,
550}
551
552struct PersistentLatentDesignEntry {
553 fingerprint: LatentFingerprint,
554 cached: Arc<CachedDesign>,
555 bytes: usize,
556}
557
558pub(crate) struct PersistentLatentDesignCache {
559 entries: HashMap<PersistentLatentDesignKey, PersistentLatentDesignEntry>,
560 lru: VecDeque<PersistentLatentDesignKey>,
561 capacity: usize,
562 byte_budget: usize,
563 cache_bytes: usize,
564}
565
566impl Default for PersistentLatentDesignCache {
567 fn default() -> Self {
568 Self::new(DEFAULT_PERSISTENT_LATENT_CACHE_CAPACITY)
569 }
570}
571
572impl PersistentLatentDesignCache {
573 pub(crate) fn new(capacity: usize) -> Self {
574 Self {
575 entries: HashMap::new(),
576 lru: VecDeque::new(),
577 capacity: capacity.max(1),
578 byte_budget: DEFAULT_PERSISTENT_LATENT_CACHE_BYTE_BUDGET,
579 cache_bytes: 0,
580 }
581 }
582
583 pub(crate) fn lookup(
584 &mut self,
585 latent: &LatentCoordValues,
586 basis_digest: CacheDigest,
587 latent_metadata_digest: CacheDigest,
588 design_context_digest: CacheDigest,
589 fingerprint: &LatentFingerprint,
590 ) -> Result<Option<Arc<CachedDesign>>, EstimationError> {
591 let key = PersistentLatentDesignKey {
592 latent_id: latent.latent_id(),
593 flat_hash: fingerprint.hash,
594 basis_digest,
595 latent_metadata_digest,
596 design_context_digest,
597 };
598 let Some(entry) = self.entries.get(&key) else {
599 return Ok(None);
600 };
601 let cached = entry.cached.clone();
602 let entry_fingerprint = entry.fingerprint.clone();
603 self.touch(key);
604 if entry_fingerprint.len != fingerprint.len {
605 return Ok(None);
606 }
607 if entry_fingerprint.hash == fingerprint.hash
608 && cached.cacheable
609 && cached.basis_digest == basis_digest
610 && cached.latent_metadata_digest == latent_metadata_digest
611 && cached.design_context_digest == design_context_digest
612 && latent_bits_match(latent, &cached.latent_bits)
613 {
614 return Ok(Some(cached));
615 }
616 Ok(None)
617 }
618
619 pub(crate) fn insert(&mut self, cached: Arc<CachedDesign>) {
620 if !cached.cacheable {
621 return;
622 }
623 let bytes = cached.resident_byte_count();
624 if bytes > self.byte_budget {
625 return;
626 }
627 let key = PersistentLatentDesignKey {
628 latent_id: cached.latent_id,
629 flat_hash: cached.fingerprint.hash,
630 basis_digest: cached.basis_digest,
631 latent_metadata_digest: cached.latent_metadata_digest,
632 design_context_digest: cached.design_context_digest,
633 };
634 let entry = PersistentLatentDesignEntry {
635 fingerprint: cached.fingerprint.clone(),
636 cached,
637 bytes,
638 };
639 if let Some(old) = self.entries.insert(key, entry) {
640 self.cache_bytes = self.cache_bytes.saturating_sub(old.bytes);
641 }
642 self.cache_bytes = self.cache_bytes.saturating_add(bytes);
643 self.touch(key);
644 self.evict_to_limits();
645 }
646
647 fn evict_to_limits(&mut self) {
648 while self.entries.len() > self.capacity || self.cache_bytes > self.byte_budget {
649 let Some(evicted) = self.lru.pop_front() else {
650 break;
651 };
652 if let Some(entry) = self.entries.remove(&evicted) {
653 self.cache_bytes = self.cache_bytes.saturating_sub(entry.bytes);
654 }
655 }
656 }
657
658 fn touch(&mut self, key: PersistentLatentDesignKey) {
659 if let Some(index) = self.lru.iter().position(|queued| *queued == key) {
660 self.lru.remove(index);
661 }
662 self.lru.push_back(key);
663 }
664}
665
666pub struct LatentDesignCache {
667 entries: Vec<LatentDesignCacheEntry>,
668 capacity: usize,
669 clock: u64,
670 iteration: u64,
671 next_entry_id: u64,
672}
673
674struct LatentDesignCacheEntry {
675 id: u64,
676 cached: Arc<CachedDesign>,
677 last_used: u64,
678 iteration: u64,
679}
680
681impl Default for LatentDesignCache {
682 fn default() -> Self {
683 Self::new(DEFAULT_LATENT_CACHE_CAPACITY)
684 }
685}
686
687impl LatentDesignCache {
688 pub(crate) fn new(capacity: usize) -> Self {
689 Self {
690 entries: Vec::new(),
691 capacity: capacity.max(1),
692 clock: 0,
693 iteration: 0,
694 next_entry_id: 0,
695 }
696 }
697
698 pub fn invalidate(&mut self) {
699 self.entries.clear();
700 }
701
702 pub fn invalidate_all(&mut self) {
703 self.entries.clear();
704 self.clock = self.clock.wrapping_add(1);
705 self.iteration = self.iteration.wrapping_add(1);
706 }
707
708 pub fn lookup_or_compute<F>(
709 &mut self,
710 latent: Arc<LatentCoordValues>,
711 basis_kind: LatentBasisKind,
712 design_context_digest: CacheDigest,
713 compute: F,
714 ) -> Result<LatentDesignLookup<'_>, EstimationError>
715 where
716 F: FnOnce() -> Result<ComputedLatentDesign, EstimationError>,
717 {
718 self.iteration = self.iteration.wrapping_add(1);
719 self.clock = self.clock.wrapping_add(1);
720 let flat = latent.as_flat();
721 let flat_slice = flat
722 .as_slice()
723 .expect("LatentCoordValues flat storage must be contiguous");
724 let fingerprint = LatentFingerprint::from_flat(flat_slice);
725 let basis_digest = basis_kind.cache_digest();
726 let latent_metadata_digest = latent_metadata_cache_digest(&latent);
727 let cacheable = flat_slice.iter().all(|value| value.is_finite());
728 if cacheable
729 && let Some(index) = self.find_entry(
730 &latent,
731 basis_digest,
732 latent_metadata_digest,
733 design_context_digest,
734 )
735 {
736 self.entries[index].last_used = self.clock;
737 return Ok(LatentDesignLookup {
738 cached: self.entries[index].cached.as_ref(),
739 entry_id: self.entries[index].id,
740 });
741 }
742 if cacheable
743 && let Some(cached) = lookup_persistent_latent_design(
744 &latent,
745 basis_digest,
746 latent_metadata_digest,
747 design_context_digest,
748 &fingerprint,
749 )?
750 {
751 let id = self.next_entry_id;
752 self.next_entry_id = self.next_entry_id.wrapping_add(1);
753 self.insert(cached, id);
754 return self.lookup_inserted(id);
755 }
756
757 let computed = compute()?;
758 let radial_distances = if basis_kind.streams_radial_cache() {
759 RadialDistanceMatrices {
760 squared: Array2::<f64>::zeros((0, 0)),
761 distance: Array2::<f64>::zeros((0, 0)),
762 }
763 } else {
764 match basis_kind.centers() {
765 Some(centers) => build_radial_distances(&latent, centers)?,
766 None => RadialDistanceMatrices {
767 squared: Array2::<f64>::zeros((0, 0)),
768 distance: Array2::<f64>::zeros((0, 0)),
769 },
770 }
771 };
772 let basis_derivative_jets = build_basis_derivative_jets(&basis_kind, &radial_distances)?;
773 let id = self.next_entry_id;
774 self.next_entry_id = self.next_entry_id.wrapping_add(1);
775 let entry = Arc::new(CachedDesign {
776 latent_id: latent.latent_id(),
777 fingerprint,
778 basis_digest,
779 latent_metadata_digest,
780 design_context_digest,
781 latent_bits: latent_bits(&latent),
782 cacheable,
783 design: computed.design,
784 hyper_dirs: computed.hyper_dirs,
785 radial_distances,
786 basis_derivative_jets,
787 });
788 if cacheable {
789 insert_persistent_latent_design(Arc::clone(&entry))?;
790 }
791 self.insert(entry, id);
792 self.lookup_inserted(id)
793 }
794
795 fn find_entry(
796 &mut self,
797 latent: &LatentCoordValues,
798 basis_digest: CacheDigest,
799 latent_metadata_digest: CacheDigest,
800 design_context_digest: CacheDigest,
801 ) -> Option<usize> {
802 self.entries.iter().position(|entry| {
803 entry.cached.cacheable
804 && entry.cached.basis_digest == basis_digest
805 && entry.cached.latent_metadata_digest == latent_metadata_digest
806 && entry.cached.design_context_digest == design_context_digest
807 && entry.cached.latent_id == latent.latent_id()
808 && latent_bits_match(latent, &entry.cached.latent_bits)
809 })
810 }
811
812 fn lookup_inserted(&self, id: u64) -> Result<LatentDesignLookup<'_>, EstimationError> {
813 let Some(index) = self.entries.iter().position(|entry| entry.id == id) else {
814 return Err(EstimationError::InvalidInput(
815 "inserted latent design cache entry missing".to_string(),
816 ));
817 };
818 Ok(LatentDesignLookup {
819 cached: self.entries[index].cached.as_ref(),
820 entry_id: self.entries[index].id,
821 })
822 }
823
824 fn insert(&mut self, cached: Arc<CachedDesign>, id: u64) {
825 self.entries.push(LatentDesignCacheEntry {
826 id,
827 cached,
828 last_used: self.clock,
829 iteration: self.iteration,
830 });
831 while self.entries.len() > self.capacity {
832 if let Some(evict_index) = self
833 .entries
834 .iter()
835 .enumerate()
836 .min_by_key(|(_, entry)| (entry.last_used, entry.iteration))
837 .map(|(index, _)| index)
838 {
839 self.entries.remove(evict_index);
840 } else {
841 break;
842 }
843 }
844 }
845}
846
847impl CachedDesign {
848 fn resident_byte_count(&self) -> usize {
849 self.resident_scalar_count()
850 .saturating_mul(std::mem::size_of::<f64>())
851 .saturating_add(
852 self.hyper_dirs
853 .iter()
854 .map(DirectionalHyperParam::resident_byte_count)
855 .sum::<usize>(),
856 )
857 }
858
859 fn resident_scalar_count(&self) -> usize {
860 let mut count = self
861 .design
862 .design
863 .nrows()
864 .saturating_mul(self.design.design.ncols());
865 count = count.saturating_add(
866 self.design
867 .coefficient_lower_bounds
868 .as_ref()
869 .map_or(0, |values| values.len()),
870 );
871 count = count.saturating_add(self.radial_distances.squared.len());
872 count = count.saturating_add(self.radial_distances.distance.len());
873 count = count.saturating_add(
874 self.basis_derivative_jets
875 .phi
876 .as_ref()
877 .map_or(0, |values| values.len()),
878 );
879 count = count.saturating_add(
880 self.basis_derivative_jets
881 .q
882 .as_ref()
883 .map_or(0, |values| values.len()),
884 );
885 count = count.saturating_add(
886 self.basis_derivative_jets
887 .t
888 .as_ref()
889 .map_or(0, |values| values.len()),
890 );
891 count = count.saturating_add(
892 self.basis_derivative_jets
893 .phi_r
894 .as_ref()
895 .map_or(0, |values| values.len()),
896 );
897 count = count.saturating_add(
898 self.basis_derivative_jets
899 .phi_rr
900 .as_ref()
901 .map_or(0, |values| values.len()),
902 );
903 count.saturating_add(usize::from(self.basis_derivative_jets.operator_resident))
904 }
905}
906
907fn lookup_persistent_latent_design(
908 latent: &LatentCoordValues,
909 basis_digest: CacheDigest,
910 latent_metadata_digest: CacheDigest,
911 design_context_digest: CacheDigest,
912 fingerprint: &LatentFingerprint,
913) -> Result<Option<Arc<CachedDesign>>, EstimationError> {
914 let cache = PERSISTENT_LATENT_DESIGN_CACHE
915 .get_or_init(|| Mutex::new(PersistentLatentDesignCache::default()));
916 let mut guard = cache.lock().map_err(|_| {
917 EstimationError::InvalidInput("persistent latent design cache mutex poisoned".to_string())
918 })?;
919 guard.lookup(
920 latent,
921 basis_digest,
922 latent_metadata_digest,
923 design_context_digest,
924 fingerprint,
925 )
926}
927
928fn insert_persistent_latent_design(cached: Arc<CachedDesign>) -> Result<(), EstimationError> {
929 let cache = PERSISTENT_LATENT_DESIGN_CACHE
930 .get_or_init(|| Mutex::new(PersistentLatentDesignCache::default()));
931 let mut guard = cache.lock().map_err(|_| {
932 EstimationError::InvalidInput("persistent latent design cache mutex poisoned".to_string())
933 })?;
934 guard.insert(cached);
935 Ok(())
936}
937
938fn latent_bits(latent: &LatentCoordValues) -> Arc<[u64]> {
939 latent
940 .as_flat()
941 .iter()
942 .map(|value| value.to_bits())
943 .collect::<Vec<_>>()
944 .into()
945}
946
947fn latent_bits_match(latent: &LatentCoordValues, cached_bits: &[u64]) -> bool {
948 latent.as_flat().len() == cached_bits.len()
949 && latent
950 .as_flat()
951 .iter()
952 .zip(cached_bits.iter())
953 .all(|(value, bits)| value.to_bits() == *bits)
954}
955
956fn build_radial_distances(
957 latent: &LatentCoordValues,
958 centers: &Array2<f64>,
959) -> Result<RadialDistanceMatrices, EstimationError> {
960 let t = latent.as_matrix();
961 if t.ncols() != centers.ncols() {
962 return Err(EstimationError::InvalidInput(format!(
963 "latent design cache center dimension mismatch: latent d={}, centers d={}",
964 t.ncols(),
965 centers.ncols()
966 )));
967 }
968 let mut squared = Array2::<f64>::zeros((t.nrows(), centers.nrows()));
969 let mut distance = Array2::<f64>::zeros((t.nrows(), centers.nrows()));
970 for row in 0..t.nrows() {
971 for center in 0..centers.nrows() {
972 let mut r2 = 0.0_f64;
973 for axis in 0..t.ncols() {
974 let delta = t[[row, axis]] - centers[[center, axis]];
975 r2 += delta * delta;
976 }
977 squared[[row, center]] = r2;
978 distance[[row, center]] = r2.sqrt();
979 }
980 }
981 Ok(RadialDistanceMatrices { squared, distance })
982}
983
984fn build_basis_derivative_jets(
985 basis_kind: &LatentBasisKind,
986 distances: &RadialDistanceMatrices,
987) -> Result<BasisDerivativeJets, EstimationError> {
988 match basis_kind {
989 LatentBasisKind::Matern {
990 length_scale,
991 nu,
992 chunk_size,
993 ..
994 } => {
995 if chunk_size.is_some() {
996 return Ok(BasisDerivativeJets {
997 operator_resident: true,
998 ..BasisDerivativeJets::empty()
999 });
1000 }
1001 let radial = RadialScalarKind::Matern {
1002 length_scale: *length_scale,
1003 nu: *nu,
1004 };
1005 let mut phi = Array2::<f64>::zeros(distances.distance.raw_dim());
1006 let mut q = Array2::<f64>::zeros(distances.distance.raw_dim());
1007 let mut t = Array2::<f64>::zeros(distances.distance.raw_dim());
1008 for row in 0..distances.distance.nrows() {
1009 for center in 0..distances.distance.ncols() {
1010 let (phi_value, q_value, t_value) = radial
1011 .eval_design_triplet(distances.distance[[row, center]])
1012 .map_err(EstimationError::from)?;
1013 phi[[row, center]] = phi_value;
1014 q[[row, center]] = q_value;
1015 t[[row, center]] = t_value;
1016 }
1017 }
1018 Ok(BasisDerivativeJets {
1019 phi: Some(phi),
1020 q: Some(q),
1021 t: Some(t),
1022 phi_r: None,
1023 phi_rr: None,
1024 operator_resident: false,
1025 })
1026 }
1027 LatentBasisKind::Duchon { .. } => Ok(BasisDerivativeJets {
1028 operator_resident: true,
1029 ..BasisDerivativeJets::empty()
1030 }),
1031 LatentBasisKind::Sphere { .. }
1032 | LatentBasisKind::PeriodicBspline { .. }
1033 | LatentBasisKind::Pca { .. }
1034 | LatentBasisKind::TensorBspline { .. } => Ok(BasisDerivativeJets {
1035 operator_resident: true,
1036 ..BasisDerivativeJets::empty()
1037 }),
1038 }
1039}