
gam/resource.rs

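/// Memory budgets plus the derivative-storage mode: together these decide
/// whether a design factor may ever be densely materialized, how large each
/// cache may grow, and how many bytes a streamed row chunk should target.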
#[derive(Clone, Debug)]
pub struct ResourcePolicy {
    pub max_single_materialization_bytes: usize,
    pub max_operator_cache_bytes: usize,
    pub max_spatial_distance_cache_bytes: usize,
    pub max_owned_data_cache_bytes: usize,
    pub row_chunk_target_bytes: usize,
    pub derivative_storage_mode: DerivativeStorageMode,
}

pub const SPATIAL_DISTANCE_CACHE_MAX_BYTES: usize = 512 * 1024 * 1024;
pub const SPATIAL_DISTANCE_CACHE_SINGLE_ENTRY_MAX_BYTES: usize = 256 * 1024 * 1024;
pub const OWNED_DATA_CACHE_MAX_ENTRIES: usize = 2;

/// Auto-strict triggers for [`ResourcePolicy::for_problem`].
///
/// Tuned for biobank-scale problems where dense materialization of any
/// design factor would itself be tens of GiB. Below these thresholds we
/// stay on `default_library` so small-data and ad-hoc fits keep working
/// without operator implementations.
pub const STRICT_POLICY_NROWS_THRESHOLD: usize = 100_000;
pub const STRICT_POLICY_P_THRESHOLD: usize = 5_000;

/// Hints that flip strict mode on regardless of n/p — used when a code path
/// is structurally operator-only and any dense fallback would be a bug.
#[derive(Clone, Copy, Debug, Default)]
pub struct ProblemHints {
    pub marginal_slope_biobank_active: bool,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DerivativeStorageMode {
    /// Production exact-math: operator-backed, no dense fallback.
    AnalyticOperatorRequired,
    /// Allow dense materialization if under the single-materialization budget.
    MaterializeIfSmall,
    /// Dense materialization only permitted for diagnostic code paths.
    DiagnosticsOnly,
}

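/// Call-site view of a [`ResourcePolicy`], produced by
/// [`ResourcePolicy::material_policy`]: the byte budgets, the row-chunk
/// target, and the two permission flags derived from the storage mode.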
#[derive(Clone, Debug)]
pub struct MaterializationPolicy {
    pub max_single_dense_bytes: usize,
    pub max_cached_dense_bytes: usize,
    pub row_chunk_target_bytes: usize,
    pub allow_operator_materialization: bool,
    pub allow_diagnostic_materialization: bool,
}

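/// Why a dense materialization request was refused: the matrix would exceed
/// the byte limit, the operator cannot stream row chunks, or the active
/// [`DerivativeStorageMode`] forbids materialization outright.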
#[derive(Debug, thiserror::Error)]
pub enum MatrixMaterializationError {
    #[error(
        "{context}: dense materialization of {nrows}x{ncols} requires {bytes} bytes (limit {limit_bytes})"
    )]
    TooLarge {
        context: &'static str,
        nrows: usize,
        ncols: usize,
        bytes: usize,
        limit_bytes: usize,
    },

    #[error("{context}: operator does not implement chunked row access")]
    MissingRowChunk { context: &'static str },

    #[error("{context}: materialization forbidden by policy (mode={mode:?})")]
    Forbidden {
        context: &'static str,
        mode: DerivativeStorageMode,
    },
}

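/// Byte charge a cached value reports against a [`ByteLruCache`] budget.
/// Implementors should return the resident footprint of the owned payload.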
pub trait ResidentBytes {
    fn resident_bytes(&self) -> usize;
}

impl ResourcePolicy {
    /// Conservative default suitable for general-purpose use.
    ///
    /// Uses `MaterializeIfSmall`: dense materialization is allowed only when the
    /// matrix fits under `max_single_materialization_bytes`. This lets small-data
    /// families that lack an implicit operator work out of the box, while
    /// biobank-scale problems error out and force the analytic-operator path.
    /// Set `derivative_storage_mode = AnalyticOperatorRequired` explicitly to
    /// reject all dense fallback.
    pub fn default_library() -> Self {
        Self {
            max_single_materialization_bytes: 256 * 1024 * 1024, // 256 MiB
            max_operator_cache_bytes: 1024 * 1024 * 1024,        // 1 GiB
            max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
            max_owned_data_cache_bytes: 512 * 1024 * 1024, // 512 MiB
            row_chunk_target_bytes: 8 * 1024 * 1024,       // 8 MiB per chunk
            derivative_storage_mode: DerivativeStorageMode::MaterializeIfSmall,
        }
    }

    /// Strict mode that rejects every dense fallback. Use when you intend to
    /// run only on operator-backed bases (biobank-scale Duchon/TPS, exact
    /// GAMLSS marginal slope, CTN, etc.).
    pub fn analytic_operator_required() -> Self {
        Self {
            derivative_storage_mode: DerivativeStorageMode::AnalyticOperatorRequired,
            ..Self::default_library()
        }
    }

    /// Auto-derive the resource policy from the shape of the problem rather
    /// than from an explicit CLI flag. The library refuses to silently
    /// densify operator-backed designs once the problem is large enough that
    /// a hidden dense fallback would blow real-world memory budgets, but
    /// keeps the permissive default for ordinary small-data fits so that
    /// non-operator bases still work out of the box.
    ///
    /// Strict mode (`AnalyticOperatorRequired`) is selected when ANY of:
    ///   * `n_rows >= STRICT_POLICY_NROWS_THRESHOLD` (biobank scale by row count)
    ///   * `p_estimate >= STRICT_POLICY_P_THRESHOLD` (biobank scale by coefficient count)
    ///   * `hints.marginal_slope_biobank_active` (the GAMLSS marginal-slope
    ///     biobank path is in play; the corresponding operators MUST stay
    ///     matrix-free regardless of n)
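    ///
    /// A minimal sketch of the selection rule (fenced `ignore`; the
    /// `gam::resource` import path is an assumption about how the crate
    /// exposes this module):
    ///
    /// ```ignore
    /// use gam::resource::{DerivativeStorageMode, ProblemHints, ResourcePolicy};
    ///
    /// // 500_000 rows exceeds STRICT_POLICY_NROWS_THRESHOLD, so strict mode wins.
    /// let big = ResourcePolicy::for_problem(500_000, 200, ProblemHints::default());
    /// assert_eq!(
    ///     big.derivative_storage_mode,
    ///     DerivativeStorageMode::AnalyticOperatorRequired
    /// );
    ///
    /// // A small ad-hoc fit keeps the permissive default.
    /// let small = ResourcePolicy::for_problem(10_000, 50, ProblemHints::default());
    /// assert_eq!(
    ///     small.derivative_storage_mode,
    ///     DerivativeStorageMode::MaterializeIfSmall
    /// );
    /// ```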
    pub fn for_problem(n_rows: usize, p_estimate: usize, hints: ProblemHints) -> Self {
        let strict = n_rows >= STRICT_POLICY_NROWS_THRESHOLD
            || p_estimate >= STRICT_POLICY_P_THRESHOLD
            || hints.marginal_slope_biobank_active;
        if strict {
            Self::analytic_operator_required()
        } else {
            Self::default_library()
        }
    }

    /// Permissive mode for small-data usage and tests.
    pub fn permissive_small_data() -> Self {
        Self {
            max_single_materialization_bytes: 2 * 1024 * 1024 * 1024, // 2 GiB
            max_operator_cache_bytes: 2 * 1024 * 1024 * 1024,
            max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
            max_owned_data_cache_bytes: 512 * 1024 * 1024,
            row_chunk_target_bytes: 64 * 1024 * 1024,
            derivative_storage_mode: DerivativeStorageMode::MaterializeIfSmall,
        }
    }

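    /// Projects the policy into the flat flag set consumed at materialization
    /// call sites: operator materialization is only allowed under
    /// `MaterializeIfSmall`, and diagnostic materialization under any mode
    /// except `AnalyticOperatorRequired`.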
    pub fn material_policy(&self) -> MaterializationPolicy {
        MaterializationPolicy {
            max_single_dense_bytes: self.max_single_materialization_bytes,
            max_cached_dense_bytes: self.max_operator_cache_bytes,
            row_chunk_target_bytes: self.row_chunk_target_bytes,
            allow_operator_materialization: matches!(
                self.derivative_storage_mode,
                DerivativeStorageMode::MaterializeIfSmall
            ),
            allow_diagnostic_materialization: !matches!(
                self.derivative_storage_mode,
                DerivativeStorageMode::AnalyticOperatorRequired
            ),
        }
    }
}

/// Returns how many rows to stream per chunk so that each chunk uses approximately
/// `target_bytes` given a row width of `cols` f64 entries.
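///
/// Illustrative arithmetic (fenced `ignore`; the `gam::resource` import path
/// is an assumption):
///
/// ```ignore
/// use gam::resource::rows_for_target_bytes;
///
/// // 1_000 f64 columns is 8_000 bytes per row, so an 8 MiB chunk holds 1_048 rows.
/// assert_eq!(rows_for_target_bytes(8 * 1024 * 1024, 1_000), 1_048);
/// // Degenerate inputs are clamped so at least one row is always streamed.
/// assert_eq!(rows_for_target_bytes(0, 1_000), 1);
/// ```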
pub fn rows_for_target_bytes(target_bytes: usize, cols: usize) -> usize {
    let bytes_per_row = cols.saturating_mul(std::mem::size_of::<f64>()).max(1);
    (target_bytes / bytes_per_row).max(1)
}

use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::sync::{Arc, Mutex};

/// Byte-limited LRU cache with an optional entry cap.
///
/// Unlike an entry-count-limited LRU, this cache tracks the resident byte cost
/// of each value (via [`ResidentBytes`]) and evicts the least-recently-used
/// entries until the total resident bytes fit under `max_bytes`. This is the
/// correct policy for biobank-scale payloads where a single cache entry (e.g.
/// an n*K distance matrix) can itself be multiple gigabytes and an entry-count
/// cap would silently blow the memory budget. Small entry caps are still useful
/// for payloads with known shape, such as owned PC data matrices shared across
/// model blocks.
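///
/// A usage sketch with the `Arc<Array2<f64>>` accounting implemented further
/// down this module (fenced `ignore`; the `gam::resource` path is assumed):
///
/// ```ignore
/// use std::sync::Arc;
///
/// use gam::resource::ByteLruCache;
/// use ndarray::Array2;
///
/// // 1 MiB budget; each 128x128 f64 block charges 128 * 128 * 8 = 131_072 bytes.
/// let cache: ByteLruCache<String, Arc<Array2<f64>>> = ByteLruCache::new(1024 * 1024);
/// cache.insert("block_a".to_string(), Arc::new(Array2::zeros((128, 128))));
/// assert_eq!(cache.resident_bytes(), 131_072);
/// assert!(cache.get(&"block_a".to_string()).is_some());
/// ```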
pub struct ByteLruCache<K: Eq + Hash + Clone, V> {
    inner: Mutex<ByteLruInner<K, V>>,
    max_bytes: usize,
    max_entries: Option<usize>,
}

struct ByteLruInner<K, V> {
    map: HashMap<K, (V, usize)>, // (value, byte_charge)
    order: VecDeque<K>,
    resident_bytes: usize,
}

impl<K: Eq + Hash + Clone, V: Clone + ResidentBytes> ByteLruCache<K, V> {
    pub fn new(max_bytes: usize) -> Self {
        Self {
            inner: Mutex::new(ByteLruInner {
                map: HashMap::new(),
                order: VecDeque::new(),
                resident_bytes: 0,
            }),
            max_bytes,
            max_entries: None,
        }
    }

    pub fn with_max_entries(max_bytes: usize, max_entries: usize) -> Self {
        Self {
            inner: Mutex::new(ByteLruInner {
                map: HashMap::new(),
                order: VecDeque::new(),
                resident_bytes: 0,
            }),
            max_bytes,
            max_entries: Some(max_entries),
        }
    }

    pub fn get(&self, key: &K) -> Option<V> {
        let mut g = self.inner.lock().unwrap();
        let v = g.map.get(key)?.0.clone();
        // move to back (most-recently-used)
        if let Some(pos) = g.order.iter().position(|k| k == key) {
            let k = g.order.remove(pos).unwrap();
            g.order.push_back(k);
        }
        Some(v)
    }

    pub fn insert(&self, key: K, value: V) {
        let charge = value.resident_bytes();
        let mut g = self.inner.lock().unwrap();

        // If already present, remove the old entry first so resident bytes stay
        // accurate and the LRU ordering reflects this insertion.
        if let Some((_old, old_charge)) = g.map.remove(&key) {
            g.resident_bytes = g.resident_bytes.saturating_sub(old_charge);
            if let Some(pos) = g.order.iter().position(|k| k == &key) {
                g.order.remove(pos);
            }
        }

        if charge > self.max_bytes {
            // Too large to cache; skip insertion.
            return;
        }

        if let Some(max_entries) = self.max_entries {
            if max_entries == 0 {
                return;
            }
            while g.map.len() >= max_entries {
                if let Some(evict_key) = g.order.pop_front() {
                    if let Some((_v, c)) = g.map.remove(&evict_key) {
                        g.resident_bytes = g.resident_bytes.saturating_sub(c);
                    }
                } else {
                    break;
                }
            }
        }

        while g.resident_bytes + charge > self.max_bytes {
            if let Some(evict_key) = g.order.pop_front() {
                if let Some((_v, c)) = g.map.remove(&evict_key) {
                    g.resident_bytes = g.resident_bytes.saturating_sub(c);
                }
            } else {
                break;
            }
        }

        g.map.insert(key.clone(), (value, charge));
        g.order.push_back(key);
        g.resident_bytes = g.resident_bytes.saturating_add(charge);
    }

    pub fn resident_bytes(&self) -> usize {
        self.inner.lock().unwrap().resident_bytes
    }

    pub fn max_bytes(&self) -> usize {
        self.max_bytes
    }

    pub fn len(&self) -> usize {
        self.inner.lock().unwrap().map.len()
    }

    pub fn clear(&self) {
        let mut g = self.inner.lock().unwrap();
        g.map.clear();
        g.order.clear();
        g.resident_bytes = 0;
    }
}

impl<K: Eq + Hash + Clone, V: Clone + ResidentBytes> std::fmt::Debug for ByteLruCache<K, V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ByteLruCache")
            .field("resident_bytes", &self.resident_bytes())
            .field("max_bytes", &self.max_bytes)
            .field("max_entries", &self.max_entries)
            .finish()
    }
}

/// Byte-accounting for `Arc<Array2<f64>>`.
///
/// Reports the full dense footprint of the owned array. Multiple `Arc`s
/// pointing to the same allocation will each report the full size; this is
/// the conservative accounting the caches want because a single residency in
/// the cache is what we are budgeting for.
impl ResidentBytes for Arc<ndarray::Array2<f64>> {
    fn resident_bytes(&self) -> usize {
        std::mem::size_of::<f64>()
            .saturating_mul(self.nrows())
            .saturating_mul(self.ncols())
    }
}

/// Lazy-init cache safe to call from inside rayon par_iter.
///
/// `std::sync::OnceLock::get_or_init` parks racing threads on an OS
/// condition variable until the leader's init closure finishes. If the
/// leader's init closure itself dispatches a nested `into_par_iter`, the
/// parked threads are now unavailable as rayon workers, and the leader
/// blocks waiting for chunks that no one can service. Classic deadlock.
///
/// `RayonSafeOnce` removes the trap by computing the value *outside* any
/// lock. Concurrent racers may produce duplicate values; the first to
/// publish wins, the rest drop their result. No thread ever parks waiting
/// for another thread's init to finish, so nested rayon par_iter inside
/// the init closure is safe.
///
/// Use this in place of `OnceLock` whenever the init closure transitively
/// runs rayon work *and* the cache may be entered concurrently from
/// inside another rayon par_iter. The redundant-work cost on first race
/// is the price for never deadlocking; in practice the loser threads
/// throw away one round of work and steady-state is identical to
/// `OnceLock`.
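///
/// A minimal sketch of the intended pattern (fenced `ignore`; the
/// `gam::resource` path and a `rayon` dependency are assumptions):
///
/// ```ignore
/// use gam::resource::RayonSafeOnce;
/// use rayon::prelude::*;
///
/// let once: RayonSafeOnce<Vec<f64>> = RayonSafeOnce::new();
/// // Safe even though the init closure itself fans out over rayon.
/// let total: f64 = (0..8usize)
///     .into_par_iter()
///     .map(|_| {
///         let table = once.get_or_init(|| {
///             (0..1_000usize).into_par_iter().map(|i| i as f64).collect()
///         });
///         table.iter().sum::<f64>()
///     })
///     .sum();
/// assert_eq!(total, 8.0 * 499_500.0);
/// ```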
pub struct RayonSafeOnce<T> {
    slot: std::sync::OnceLock<T>,
}

impl<T> RayonSafeOnce<T> {
    pub const fn new() -> Self {
        Self {
            slot: std::sync::OnceLock::new(),
        }
    }

    /// Returns the cached value if already populated.
    #[inline]
    pub fn get(&self) -> Option<&T> {
        self.slot.get()
    }

    /// Returns the cached value, computing it if absent.
    ///
    /// The init closure runs WITHOUT holding any lock — calls from
    /// concurrent rayon workers may all run it, and all but the first
    /// to call `set` discard their result. This is the contract that
    /// keeps nested `into_par_iter` inside `init` from deadlocking on
    /// other workers parked on a `OnceLock`.
    pub fn get_or_init<F>(&self, init: F) -> &T
    where
        F: FnOnce() -> T,
    {
        if let Some(v) = self.slot.get() {
            return v;
        }
        let candidate = init();
        let _ = self.slot.set(candidate);
        self.slot
            .get()
            .expect("RayonSafeOnce slot populated by set() above")
    }

    /// Fallible variant of `get_or_init`.
    pub fn get_or_try_init<F, E>(&self, init: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        if let Some(v) = self.slot.get() {
            return Ok(v);
        }
        let candidate = init()?;
        let _ = self.slot.set(candidate);
        Ok(self
            .slot
            .get()
            .expect("RayonSafeOnce slot populated by set() above"))
    }
}

impl<T> Default for RayonSafeOnce<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Clone> Clone for RayonSafeOnce<T> {
    fn clone(&self) -> Self {
        let cloned = Self::new();
        if let Some(value) = self.slot.get() {
            let _ = cloned.slot.set(value.clone());
        }
        cloned
    }
}

impl<T: std::fmt::Debug> std::fmt::Debug for RayonSafeOnce<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RayonSafeOnce")
            .field("slot", &self.slot.get())
            .finish()
    }
}