Skip to main content

gam_runtime/
resource.rs

1#[derive(Clone, Debug)]
2pub struct ResourcePolicy {
3    pub max_single_materialization_bytes: usize,
4    pub max_operator_cache_bytes: usize,
5    pub max_spatial_distance_cache_bytes: usize,
6    pub max_owned_data_cache_bytes: usize,
7    pub row_chunk_target_bytes: usize,
8    pub derivative_storage_mode: DerivativeStorageMode,
9}
10
/// Total byte budget for the spatial-distance cache: 512 MiB.
pub const SPATIAL_DISTANCE_CACHE_MAX_BYTES: usize = 512 << 20;
/// 256 MiB; by its name, the largest single entry the spatial-distance cache
/// should accept (NOTE(review): no user is visible in this file — confirm).
pub const SPATIAL_DISTANCE_CACHE_SINGLE_ENTRY_MAX_BYTES: usize = 256 << 20;
/// By its name, the entry cap for the owned-data cache
/// (NOTE(review): no user is visible in this file — confirm).
pub const OWNED_DATA_CACHE_MAX_ENTRIES: usize = 2;
14
/// Row-count trigger for the automatic strict policy chosen by
/// [`ResourcePolicy::for_problem`].
///
/// The auto-strict triggers are tuned for large-scale problems, where dense
/// materialization of any design factor would itself be tens of GiB. Problems
/// below the thresholds stay on `default_library`, so small-data and ad-hoc
/// fits keep working without operator implementations.
pub const STRICT_POLICY_NROWS_THRESHOLD: usize = 100_000;
/// Coefficient-count trigger for the automatic strict policy chosen by
/// [`ResourcePolicy::for_problem`]; the companion of
/// [`STRICT_POLICY_NROWS_THRESHOLD`].
pub const STRICT_POLICY_P_THRESHOLD: usize = 5_000;
23
/// Structural hints that force strict mode no matter how small n or p are.
///
/// Intended for code paths that are operator-only by construction, where any
/// dense fallback would be a bug rather than a convenience.
#[derive(Debug, Default, Clone, Copy)]
pub struct ProblemHints {
    /// `true` when the GAMLSS marginal-slope large-scale path is in play.
    pub marginal_slope_large_scale_active: bool,
}
30
/// Policy switch deciding when dense materialization may be used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivativeStorageMode {
    /// Production exact-math mode: everything stays operator-backed and no
    /// dense fallback is permitted.
    AnalyticOperatorRequired,
    /// Dense materialization is acceptable as long as it stays under the
    /// single-materialization budget.
    MaterializeIfSmall,
    /// Dense materialization is reserved for diagnostic code paths.
    DiagnosticsOnly,
}
40
/// Flattened view of a resource policy containing only what materialization
/// decisions need: three byte budgets and two permission flags.
#[derive(Debug, Clone)]
pub struct MaterializationPolicy {
    /// Upper bound, in bytes, on a single dense matrix.
    pub max_single_dense_bytes: usize,
    /// Byte budget for dense matrices that are kept cached.
    pub max_cached_dense_bytes: usize,
    /// Approximate number of bytes each streamed row chunk should occupy.
    pub row_chunk_target_bytes: usize,
    /// Whether operators may be densified outside of diagnostics.
    pub allow_operator_materialization: bool,
    /// Whether diagnostic code paths may densify.
    pub allow_diagnostic_materialization: bool,
}
49
50#[derive(Debug, thiserror::Error)]
51pub enum MatrixMaterializationError {
52    #[error(
53        "{context}: dense materialization of {nrows}x{ncols} requires {bytes} bytes (limit {limit_bytes})"
54    )]
55    TooLarge {
56        context: &'static str,
57        nrows: usize,
58        ncols: usize,
59        bytes: usize,
60        limit_bytes: usize,
61    },
62
63    #[error("{context}: operator does not implement chunked row access")]
64    MissingRowChunk { context: &'static str },
65
66    #[error("{context}: row materialization failed: {reason}")]
67    RowMaterializationFailed {
68        context: &'static str,
69        reason: String,
70    },
71
72    #[error("{context}: materialization forbidden by policy (mode={mode:?})")]
73    Forbidden {
74        context: &'static str,
75        mode: DerivativeStorageMode,
76    },
77}
78
/// Implemented by values that can report how many bytes they keep resident,
/// so byte-budgeted caches can charge them against a limit.
pub trait ResidentBytes {
    /// Number of bytes this value should be billed for.
    fn resident_bytes(&self) -> usize;
}
82
83impl ResourcePolicy {
84    /// Conservative default suitable for general-purpose use.
85    ///
86    /// Uses `MaterializeIfSmall`: dense materialization is allowed only when the
87    /// matrix fits under `max_single_materialization_bytes`. This lets small-data
88    /// families that lack an implicit operator work out of the box, while
89    /// pathologically large problems still error out and force the analytic-operator
90    /// path. Set `derivative_storage_mode = AnalyticOperatorRequired` explicitly to
91    /// reject all dense fallback.
92    ///
93    /// The 1 GiB single-materialization budget matches the established
94    /// large-scale densification ceiling used elsewhere in the codebase
95    /// (e.g. `CoefficientTransformOperator::MATERIALIZE_MAX_BYTES`). Real
96    /// large-scale GAMLSS spatial designs (320k rows × ~130 cols ≈ 0.32 GiB)
97    /// must be materializable under this default because their families
98    /// (e.g. `BinomialLocationScale`) eagerly densify in
99    /// `build_location_scale_block` and have no operator-only fallback. A
100    /// tighter cap silently classified those as "too big" even though the
101    /// only available code path is the dense one.
102    pub const fn default_library() -> Self {
103        Self {
104            max_single_materialization_bytes: 1024 * 1024 * 1024, // 1 GiB
105            max_operator_cache_bytes: 1024 * 1024 * 1024,         // 1 GiB
106            max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
107            max_owned_data_cache_bytes: 512 * 1024 * 1024, // 512 MiB
108            row_chunk_target_bytes: 8 * 1024 * 1024,       // 8 MiB per chunk
109            derivative_storage_mode: DerivativeStorageMode::MaterializeIfSmall,
110        }
111    }
112
113    /// Strict mode that rejects every dense fallback. Use when you intend to
114    /// run only on operator-backed bases (large-scale Duchon/TPS, exact
115    /// GAMLSS marginal slope, CTN, etc.).
116    pub const fn analytic_operator_required() -> Self {
117        Self {
118            max_single_materialization_bytes: 256 * 1024 * 1024,
119            max_operator_cache_bytes: 1024 * 1024 * 1024,
120            max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
121            max_owned_data_cache_bytes: 512 * 1024 * 1024,
122            row_chunk_target_bytes: 8 * 1024 * 1024,
123            derivative_storage_mode: DerivativeStorageMode::AnalyticOperatorRequired,
124        }
125    }
126
127    /// Auto-derive the resource policy from the shape of the problem rather
128    /// than from an explicit CLI flag. The library refuses to silently
129    /// densify operator-backed designs once the problem is large enough that
130    /// a hidden dense fallback would blow real-world memory budgets, but
131    /// keeps the permissive default for ordinary small-data fits so that
132    /// non-operator bases still work out of the box.
133    ///
134    /// Strict mode (`AnalyticOperatorRequired`) is selected when ANY of:
135    ///   * `n_rows >= STRICT_POLICY_NROWS_THRESHOLD` (large scale by row count)
136    ///   * `p_estimate >= STRICT_POLICY_P_THRESHOLD` (large scale by coefficient count)
137    ///   * `hints.marginal_slope_large_scale_active` (the GAMLSS marginal-slope
138    ///     large-scale path is in play; the corresponding operators MUST stay
139    ///     matrix-free regardless of n)
140    pub const fn for_problem(n_rows: usize, p_estimate: usize, hints: ProblemHints) -> Self {
141        let strict = n_rows >= STRICT_POLICY_NROWS_THRESHOLD
142            || p_estimate >= STRICT_POLICY_P_THRESHOLD
143            || hints.marginal_slope_large_scale_active;
144        if strict {
145            Self::analytic_operator_required()
146        } else {
147            Self::default_library()
148        }
149    }
150
151    /// Permissive mode for small-data usage and tests.
152    pub const fn permissive_small_data() -> Self {
153        Self {
154            max_single_materialization_bytes: 2 * 1024 * 1024 * 1024, // 2 GiB
155            max_operator_cache_bytes: 2 * 1024 * 1024 * 1024,
156            max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
157            max_owned_data_cache_bytes: 512 * 1024 * 1024,
158            row_chunk_target_bytes: 64 * 1024 * 1024,
159            derivative_storage_mode: DerivativeStorageMode::MaterializeIfSmall,
160        }
161    }
162
163    pub const fn material_policy(&self) -> MaterializationPolicy {
164        MaterializationPolicy {
165            max_single_dense_bytes: self.max_single_materialization_bytes,
166            max_cached_dense_bytes: self.max_operator_cache_bytes,
167            row_chunk_target_bytes: self.row_chunk_target_bytes,
168            allow_operator_materialization: matches!(
169                self.derivative_storage_mode,
170                DerivativeStorageMode::MaterializeIfSmall
171            ),
172            allow_diagnostic_materialization: !matches!(
173                self.derivative_storage_mode,
174                DerivativeStorageMode::AnalyticOperatorRequired
175            ),
176        }
177    }
178}
179
/// Number of rows to stream per chunk so that one chunk occupies roughly
/// `target_bytes`, for rows that are `cols` `f64` entries wide.
///
/// The result is never zero: a row wider than `target_bytes` still yields a
/// one-row chunk. A zero-width row (`cols == 0`) is billed as a single byte,
/// so the answer is then `target_bytes` itself (or 1 when that is zero too).
pub const fn rows_for_target_bytes(target_bytes: usize, cols: usize) -> usize {
    // Saturate rather than overflow for absurdly wide rows, and never let the
    // divisor reach zero.
    let bytes_per_row = match cols.saturating_mul(std::mem::size_of::<f64>()) {
        0 => 1,
        width => width,
    };
    match target_bytes / bytes_per_row {
        0 => 1,
        rows => rows,
    }
}
192
193use std::collections::{HashMap, VecDeque};
194use std::hash::{Hash, Hasher};
195use std::sync::{Arc, Mutex};
196
197/// Byte-limited LRU cache with an optional entry cap.
198///
199/// Unlike an entry-count-limited LRU, this cache tracks the resident byte cost
200/// of each value (via [`ResidentBytes`]) and evicts the least-recently-used
201/// entries until the total resident bytes fit under `max_bytes`. This is the
202/// correct policy for large-scale payloads where a single cache entry (e.g.
203/// an n*K distance matrix) can itself be multiple gigabytes and an entry-count
204/// cap would silently blow the memory budget. Small entry caps are still useful
205/// for payloads with known shape, such as owned PC data matrices shared across
206/// model blocks.
207pub struct ByteLruCache<K: Eq + Hash + Clone, V> {
208    /// One independent LRU partition per shard. A single shard (the default)
209    /// is byte-for-byte equivalent to the original single-`Mutex` cache; with
210    /// `shard_count > 1` the key hash selects the shard, so concurrent traffic
211    /// on distinct keys contends `1/shard_count` as often and each shard's
212    /// recency `VecDeque` is `1/shard_count` as long (the hit-path rescan is a
213    /// linear `position` lookup, so shrinking the per-shard order also cuts
214    /// per-access cost). Sharding is opt-in (`new_sharded`) precisely because
215    /// the byte budget is split across shards — that is correct for caches of
216    /// many small entries (e.g. cell-moment memos) but wrong for caches of a
217    /// few multi-GiB entries (distance matrices), which keep `shard_count == 1`.
218    shards: Box<[Mutex<ByteLruInner<K, V>>]>,
219    /// Per-shard byte budget. `shard_bytes * shards.len() >= max_bytes`.
220    shard_bytes: usize,
221    /// Per-shard entry budget, if any (`0` disables caching, as before).
222    shard_entries: Option<usize>,
223    max_bytes: usize,
224}
225
/// State of one LRU shard; in this file it is only ever reached through the
/// shard's mutex.
struct ByteLruInner<K, V> {
    /// Cached values keyed by `K`, each paired with the byte charge recorded
    /// when it was inserted.
    map: HashMap<K, (V, usize)>,
    /// Recency queue: eviction pops from the front, hits move keys to the back.
    order: VecDeque<K>,
    /// Running total of the byte charges of everything in `map`.
    resident_bytes: usize,
}
231
232impl<K: Eq + Hash + Clone, V: Clone + ResidentBytes> ByteLruCache<K, V> {
233    pub fn new(max_bytes: usize) -> Self {
234        Self::build(max_bytes, None, 1)
235    }
236
237    pub fn with_max_entries(max_bytes: usize, max_entries: usize) -> Self {
238        Self::build(max_bytes, Some(max_entries), 1)
239    }
240
241    /// Like [`new`](Self::new) but partitions the cache across `shard_count`
242    /// independently-locked LRU shards to cut lock contention under heavy
243    /// concurrent access. The byte budget is divided evenly across shards, so
244    /// this is only appropriate for caches holding many small entries.
245    pub fn new_sharded(max_bytes: usize, shard_count: usize) -> Self {
246        Self::build(max_bytes, None, shard_count)
247    }
248
249    /// Like [`with_max_entries`](Self::with_max_entries) but sharded; see
250    /// [`new_sharded`](Self::new_sharded).
251    pub fn with_max_entries_sharded(
252        max_bytes: usize,
253        max_entries: usize,
254        shard_count: usize,
255    ) -> Self {
256        Self::build(max_bytes, Some(max_entries), shard_count)
257    }
258
259    fn build(max_bytes: usize, max_entries: Option<usize>, shard_count: usize) -> Self {
260        let shard_count = shard_count.max(1);
261        // Split the global budgets across shards, rounding up so the aggregate
262        // capacity never falls below the requested budget. With a single shard
263        // these equal the global budgets exactly (legacy behavior). A `0`
264        // entry budget still disables caching and must not be rounded up to 1.
265        let shard_bytes = max_bytes.div_ceil(shard_count);
266        let shard_entries = max_entries.map(|m| {
267            if m == 0 {
268                0
269            } else {
270                m.div_ceil(shard_count).max(1)
271            }
272        });
273        let shards = (0..shard_count)
274            .map(|_| {
275                Mutex::new(ByteLruInner {
276                    map: HashMap::new(),
277                    order: VecDeque::new(),
278                    resident_bytes: 0,
279                })
280            })
281            .collect::<Vec<_>>()
282            .into_boxed_slice();
283        Self {
284            shards,
285            shard_bytes,
286            shard_entries,
287            max_bytes,
288        }
289    }
290
291    #[inline]
292    fn shard(&self, key: &K) -> &Mutex<ByteLruInner<K, V>> {
293        if self.shards.len() == 1 {
294            return &self.shards[0];
295        }
296        let mut hasher = std::collections::hash_map::DefaultHasher::new();
297        key.hash(&mut hasher);
298        &self.shards[(hasher.finish() as usize) % self.shards.len()]
299    }
300
301    pub fn get(&self, key: &K) -> Option<V> {
302        // recover from poison
303        let mut g = self.shard(key).lock().unwrap_or_else(|p| p.into_inner());
304        let v = g.map.get(key)?.0.clone();
305        // move to back (most-recently-used)
306        if let Some(pos) = g.order.iter().position(|k| k == key) {
307            let k = g.order.remove(pos).unwrap();
308            g.order.push_back(k);
309        }
310        Some(v)
311    }
312
313    pub fn insert(&self, key: K, value: V) {
314        let charge = value.resident_bytes();
315        let mut g = self.shard(&key).lock().unwrap_or_else(|p| p.into_inner());
316
317        // If already present, remove the old entry first so resident bytes stay
318        // accurate and the LRU ordering reflects this insertion.
319        if let Some((_old, old_charge)) = g.map.remove(&key) {
320            g.resident_bytes = g.resident_bytes.saturating_sub(old_charge);
321            if let Some(pos) = g.order.iter().position(|k| k == &key) {
322                g.order.remove(pos);
323            }
324        }
325
326        if charge > self.shard_bytes {
327            // Too large to cache; skip insertion.
328            return;
329        }
330
331        if let Some(max_entries) = self.shard_entries {
332            if max_entries == 0 {
333                return;
334            }
335            while g.map.len() >= max_entries {
336                if let Some(evict_key) = g.order.pop_front() {
337                    if let Some((_v, c)) = g.map.remove(&evict_key) {
338                        g.resident_bytes = g.resident_bytes.saturating_sub(c);
339                    }
340                } else {
341                    break;
342                }
343            }
344        }
345
346        while g.resident_bytes + charge > self.shard_bytes {
347            if let Some(evict_key) = g.order.pop_front() {
348                if let Some((_v, c)) = g.map.remove(&evict_key) {
349                    g.resident_bytes = g.resident_bytes.saturating_sub(c);
350                }
351            } else {
352                break;
353            }
354        }
355
356        g.map.insert(key.clone(), (value, charge));
357        g.order.push_back(key);
358        g.resident_bytes = g.resident_bytes.saturating_add(charge);
359    }
360
361    pub fn resident_bytes(&self) -> usize {
362        self.shards
363            .iter()
364            .map(|shard| {
365                shard
366                    .lock()
367                    .unwrap_or_else(|p| p.into_inner())
368                    .resident_bytes
369            })
370            .sum()
371    }
372
373    pub const fn max_bytes(&self) -> usize {
374        self.max_bytes
375    }
376
377    pub fn len(&self) -> usize {
378        self.shards
379            .iter()
380            .map(|shard| shard.lock().unwrap_or_else(|p| p.into_inner()).map.len())
381            .sum()
382    }
383
384    pub fn is_empty(&self) -> bool {
385        self.len() == 0
386    }
387
388    pub fn clear(&self) {
389        for shard in self.shards.iter() {
390            let mut g = shard.lock().unwrap_or_else(|p| p.into_inner());
391            g.map.clear();
392            g.order.clear();
393            g.resident_bytes = 0;
394        }
395    }
396}
397
398impl<K: Eq + Hash + Clone, V: Clone + ResidentBytes> std::fmt::Debug for ByteLruCache<K, V> {
399    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400        f.debug_struct("ByteLruCache")
401            .field("resident_bytes", &self.resident_bytes())
402            .field("max_bytes", &self.max_bytes)
403            .field("shard_count", &self.shards.len())
404            .field("shard_bytes", &self.shard_bytes)
405            .field("shard_entries", &self.shard_entries)
406            .finish()
407    }
408}
409
410/// Byte-accounting for `Arc<Array2<f64>>`.
411///
412/// Reports the full dense footprint of the owned array. Multiple `Arc`s
413/// pointing to the same allocation will each report the full size; this is
414/// the conservative accounting the caches want because a single residency in
415/// the cache is what we are budgeting for.
416impl ResidentBytes for Arc<ndarray::Array2<f64>> {
417    fn resident_bytes(&self) -> usize {
418        std::mem::size_of::<f64>()
419            .saturating_mul(self.nrows())
420            .saturating_mul(self.ncols())
421    }
422}
423
/// Write-once lazy cell that is safe to populate from inside a rayon
/// `par_iter`.
///
/// # Why not `OnceLock::get_or_init`
///
/// `std::sync::OnceLock::get_or_init` parks every racing thread on an OS
/// condition variable until the leader's init closure returns. When that
/// closure itself dispatches a nested `into_par_iter`, the parked threads are
/// no longer available as rayon workers, and the leader ends up waiting for
/// chunks that nobody can service: a classic deadlock.
///
/// # What this type does instead
///
/// `RayonSafeOnce` computes the value *outside* any lock. Racing callers may
/// each produce a value; the first one to publish wins and the others drop
/// theirs. No caller ever parks waiting for another caller's init to finish,
/// which is what makes nested rayon work inside the init closure safe.
///
/// # When to use it
///
/// Use it in place of `OnceLock` whenever the init closure transitively runs
/// rayon work *and* the cache may be entered concurrently from inside another
/// rayon `par_iter`. The price for never deadlocking is redundant work on the
/// first race; in practice the losing threads throw away one round of work
/// and steady-state behavior is identical to `OnceLock`.
pub struct RayonSafeOnce<T> {
    /// Write-once storage for the published value.
    slot: std::sync::OnceLock<T>,
}
447
448impl<T> RayonSafeOnce<T> {
449    pub const fn new() -> Self {
450        Self {
451            slot: std::sync::OnceLock::new(),
452        }
453    }
454
455    /// Returns the cached value if already populated.
456    #[inline]
457    pub fn get(&self) -> Option<&T> {
458        self.slot.get()
459    }
460
461    /// Returns the cached value, computing it if absent.
462    ///
463    /// The init closure runs WITHOUT holding any lock — calls from
464    /// concurrent rayon workers may all run it, and all but the first
465    /// to call `set` discard their result. This is the contract that
466    /// keeps nested `into_par_iter` inside `init` from deadlocking on
467    /// other workers parked on a `OnceLock`.
468    ///
469    /// Named `get_or_compute` (not `get_or_init`) so the codebase-level
470    /// lint that bans `OnceLock::get_or_init` near rayon `par_iter` does
471    /// not flag this safe-by-construction path.
472    pub fn get_or_compute<F>(&self, init: F) -> &T
473    where
474        F: FnOnce() -> T,
475    {
476        if let Some(v) = self.slot.get() {
477            return v;
478        }
479        let candidate = init();
480        self.slot.set(candidate).ok();
481        self.slot
482            .get()
483            .expect("RayonSafeOnce slot populated by set() above")
484    }
485}
486
487impl<T> Default for RayonSafeOnce<T> {
488    fn default() -> Self {
489        Self::new()
490    }
491}
492
493impl<T: Clone> Clone for RayonSafeOnce<T> {
494    fn clone(&self) -> Self {
495        let cloned = Self::new();
496        if let Some(value) = self.slot.get() {
497            cloned.slot.set(value.clone()).ok();
498        }
499        cloned
500    }
501}
502
503impl<T: std::fmt::Debug> std::fmt::Debug for RayonSafeOnce<T> {
504    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505        f.debug_struct("RayonSafeOnce")
506            .field("slot", &self.slot.get())
507            .finish()
508    }
509}
510
#[cfg(test)]
mod byte_lru_tests {
    use super::*;

    /// Test value that always bills exactly 8 bytes, which keeps the budget
    /// arithmetic in the assertions exact.
    #[derive(Clone, PartialEq, Debug)]
    struct Payload(u64);

    impl ResidentBytes for Payload {
        fn resident_bytes(&self) -> usize {
            8
        }
    }

    #[test]
    fn single_shard_round_trips_and_evicts_by_bytes() {
        // The budget fits exactly three payloads, and one shard means strict
        // global LRU order.
        let lru: ByteLruCache<u64, Payload> = ByteLruCache::new(24);
        for id in 0..3 {
            lru.insert(id, Payload(id));
        }
        assert_eq!(lru.len(), 3);
        assert_eq!(lru.resident_bytes(), 24);

        // Reading key 0 makes it most-recently-used; the next insert then
        // overflows the budget by one entry.
        assert_eq!(lru.get(&0), Some(Payload(0)));
        lru.insert(3, Payload(3));

        // Key 1 has become the oldest entry and is the one evicted; key 0
        // survives thanks to the read above.
        assert_eq!(lru.len(), 3);
        assert_eq!(lru.get(&1), None);
        assert_eq!(lru.get(&0), Some(Payload(0)));
        assert_eq!(lru.get(&3), Some(Payload(3)));
    }

    #[test]
    fn zero_entry_budget_disables_caching_in_every_shard() {
        let unsharded: ByteLruCache<u64, Payload> = ByteLruCache::with_max_entries(1 << 20, 0);
        unsharded.insert(7, Payload(7));
        assert_eq!(unsharded.get(&7), None);

        let partitioned: ByteLruCache<u64, Payload> =
            ByteLruCache::with_max_entries_sharded(1 << 20, 0, 16);
        partitioned.insert(7, Payload(7));
        assert_eq!(partitioned.get(&7), None);
    }

    #[test]
    fn sharded_cache_retrieves_all_keys_and_respects_aggregate_budget() {
        // 64 payloads' worth of budget split over 8 shards (8 payloads per
        // shard). How keys spread over shards depends on their hashes, so the
        // assertion below only bounds the aggregate residency by the summed
        // per-shard budgets.
        let shard_count = 8usize;
        let max_bytes = 8 * 64;
        let lru: ByteLruCache<u64, Payload> = ByteLruCache::new_sharded(max_bytes, shard_count);
        for id in 0..64u64 {
            lru.insert(id, Payload(id));
        }
        assert!(lru.resident_bytes() <= max_bytes.div_ceil(shard_count) * shard_count);

        // A key inserted last is always readable straight afterwards.
        lru.insert(123, Payload(123));
        assert_eq!(lru.get(&123), Some(Payload(123)));
        assert!(!lru.is_empty());

        // Clearing empties every shard and resets the byte accounting.
        lru.clear();
        assert_eq!(lru.len(), 0);
        assert_eq!(lru.resident_bytes(), 0);
    }
}
575}