1#[derive(Clone, Debug)]
2pub struct ResourcePolicy {
3 pub max_single_materialization_bytes: usize,
4 pub max_operator_cache_bytes: usize,
5 pub max_spatial_distance_cache_bytes: usize,
6 pub max_owned_data_cache_bytes: usize,
7 pub row_chunk_target_bytes: usize,
8 pub derivative_storage_mode: DerivativeStorageMode,
9}
10
/// Spatial-distance cache budget (512 MiB). Every `ResourcePolicy` preset in
/// this file uses it for `max_spatial_distance_cache_bytes`.
pub const SPATIAL_DISTANCE_CACHE_MAX_BYTES: usize = 512 << 20;
/// 256 MiB.
/// NOTE(review): not referenced in this section; by its name, the largest
/// single entry the spatial-distance cache should accept — confirm at the use site.
pub const SPATIAL_DISTANCE_CACHE_SINGLE_ENTRY_MAX_BYTES: usize = 256 << 20;
/// NOTE(review): not referenced in this section; by its name, the entry cap
/// of the owned-data cache — confirm at the use site.
pub const OWNED_DATA_CACHE_MAX_ENTRIES: usize = 2;

/// Row count at or above which `ResourcePolicy::for_problem` picks the strict preset.
pub const STRICT_POLICY_NROWS_THRESHOLD: usize = 100_000;
/// Estimated `p` at or above which `ResourcePolicy::for_problem` picks the strict preset.
pub const STRICT_POLICY_P_THRESHOLD: usize = 5_000;
23
/// Extra, size-independent inputs to `ResourcePolicy::for_problem`.
///
/// `Default` yields all hints switched off.
#[derive(Debug, Default, Clone, Copy)]
pub struct ProblemHints {
    /// When `true`, `ResourcePolicy::for_problem` selects the strict
    /// (analytic-operator-required) preset regardless of the problem size.
    pub marginal_slope_large_scale_active: bool,
}
30
/// How derivative operators may be stored; interpreted by
/// `ResourcePolicy::material_policy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivativeStorageMode {
    /// Neither operator nor diagnostic materialization is allowed.
    AnalyticOperatorRequired,
    /// Both operator and diagnostic materialization are allowed (subject to
    /// the byte limits of the policy).
    MaterializeIfSmall,
    /// Only diagnostic materialization is allowed.
    DiagnosticsOnly,
}
40
/// Materialization limits and permissions, as produced by
/// `ResourcePolicy::material_policy`.
#[derive(Debug, Clone)]
pub struct MaterializationPolicy {
    /// Byte limit for a single dense materialization.
    pub max_single_dense_bytes: usize,
    /// Byte limit for cached dense data.
    pub max_cached_dense_bytes: usize,
    /// Target size in bytes of one row chunk.
    pub row_chunk_target_bytes: usize,
    /// `true` only when the storage mode is `MaterializeIfSmall`.
    pub allow_operator_materialization: bool,
    /// `true` unless the storage mode is `AnalyticOperatorRequired`.
    pub allow_diagnostic_materialization: bool,
}
49
50#[derive(Debug, thiserror::Error)]
51pub enum MatrixMaterializationError {
52 #[error(
53 "{context}: dense materialization of {nrows}x{ncols} requires {bytes} bytes (limit {limit_bytes})"
54 )]
55 TooLarge {
56 context: &'static str,
57 nrows: usize,
58 ncols: usize,
59 bytes: usize,
60 limit_bytes: usize,
61 },
62
63 #[error("{context}: operator does not implement chunked row access")]
64 MissingRowChunk { context: &'static str },
65
66 #[error("{context}: row materialization failed: {reason}")]
67 RowMaterializationFailed {
68 context: &'static str,
69 reason: String,
70 },
71
72 #[error("{context}: materialization forbidden by policy (mode={mode:?})")]
73 Forbidden {
74 context: &'static str,
75 mode: DerivativeStorageMode,
76 },
77}
78
/// Types that can report how many bytes they keep resident.
///
/// `ByteLruCache::insert` calls this once per inserted value and charges the
/// result against the cache's byte budget.
pub trait ResidentBytes {
    /// Number of bytes this value should be charged for.
    fn resident_bytes(&self) -> usize;
}
82
83impl ResourcePolicy {
84 pub const fn default_library() -> Self {
103 Self {
104 max_single_materialization_bytes: 1024 * 1024 * 1024, max_operator_cache_bytes: 1024 * 1024 * 1024, max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
107 max_owned_data_cache_bytes: 512 * 1024 * 1024, row_chunk_target_bytes: 8 * 1024 * 1024, derivative_storage_mode: DerivativeStorageMode::MaterializeIfSmall,
110 }
111 }
112
113 pub const fn analytic_operator_required() -> Self {
117 Self {
118 max_single_materialization_bytes: 256 * 1024 * 1024,
119 max_operator_cache_bytes: 1024 * 1024 * 1024,
120 max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
121 max_owned_data_cache_bytes: 512 * 1024 * 1024,
122 row_chunk_target_bytes: 8 * 1024 * 1024,
123 derivative_storage_mode: DerivativeStorageMode::AnalyticOperatorRequired,
124 }
125 }
126
127 pub const fn for_problem(n_rows: usize, p_estimate: usize, hints: ProblemHints) -> Self {
141 let strict = n_rows >= STRICT_POLICY_NROWS_THRESHOLD
142 || p_estimate >= STRICT_POLICY_P_THRESHOLD
143 || hints.marginal_slope_large_scale_active;
144 if strict {
145 Self::analytic_operator_required()
146 } else {
147 Self::default_library()
148 }
149 }
150
151 pub const fn permissive_small_data() -> Self {
153 Self {
154 max_single_materialization_bytes: 2 * 1024 * 1024 * 1024, max_operator_cache_bytes: 2 * 1024 * 1024 * 1024,
156 max_spatial_distance_cache_bytes: SPATIAL_DISTANCE_CACHE_MAX_BYTES,
157 max_owned_data_cache_bytes: 512 * 1024 * 1024,
158 row_chunk_target_bytes: 64 * 1024 * 1024,
159 derivative_storage_mode: DerivativeStorageMode::MaterializeIfSmall,
160 }
161 }
162
163 pub const fn material_policy(&self) -> MaterializationPolicy {
164 MaterializationPolicy {
165 max_single_dense_bytes: self.max_single_materialization_bytes,
166 max_cached_dense_bytes: self.max_operator_cache_bytes,
167 row_chunk_target_bytes: self.row_chunk_target_bytes,
168 allow_operator_materialization: matches!(
169 self.derivative_storage_mode,
170 DerivativeStorageMode::MaterializeIfSmall
171 ),
172 allow_diagnostic_materialization: !matches!(
173 self.derivative_storage_mode,
174 DerivativeStorageMode::AnalyticOperatorRequired
175 ),
176 }
177 }
178}
179
/// Number of `f64` rows of width `cols` that fit in `target_bytes`.
///
/// The result is never zero: at least one row is always returned, even when a
/// single row is larger than `target_bytes`. A zero-width row is billed as one
/// byte, so `cols == 0` yields `target_bytes` rows (or 1 when that is zero).
/// The per-row byte count saturates instead of overflowing.
pub const fn rows_for_target_bytes(target_bytes: usize, cols: usize) -> usize {
    // `Ord::max` is not usable in a `const fn`, hence the explicit matches.
    let row_bytes = match cols.saturating_mul(std::mem::size_of::<f64>()) {
        0 => 1,
        width => width,
    };
    match target_bytes / row_bytes {
        0 => 1,
        rows => rows,
    }
}
192
193use std::collections::{HashMap, VecDeque};
194use std::hash::{Hash, Hasher};
195use std::sync::{Arc, Mutex};
196
197pub struct ByteLruCache<K: Eq + Hash + Clone, V> {
208 shards: Box<[Mutex<ByteLruInner<K, V>>]>,
219 shard_bytes: usize,
221 shard_entries: Option<usize>,
223 max_bytes: usize,
224}
225
/// State of one cache shard; always accessed under that shard's mutex.
struct ByteLruInner<K, V> {
    /// Cached values together with the byte charge recorded at insertion.
    map: HashMap<K, (V, usize)>,
    /// Recency queue: the front is evicted first, the back was used most recently.
    order: VecDeque<K>,
    /// Running total of the charges of all entries in `map`.
    resident_bytes: usize,
}
231
232impl<K: Eq + Hash + Clone, V: Clone + ResidentBytes> ByteLruCache<K, V> {
233 pub fn new(max_bytes: usize) -> Self {
234 Self::build(max_bytes, None, 1)
235 }
236
237 pub fn with_max_entries(max_bytes: usize, max_entries: usize) -> Self {
238 Self::build(max_bytes, Some(max_entries), 1)
239 }
240
241 pub fn new_sharded(max_bytes: usize, shard_count: usize) -> Self {
246 Self::build(max_bytes, None, shard_count)
247 }
248
249 pub fn with_max_entries_sharded(
252 max_bytes: usize,
253 max_entries: usize,
254 shard_count: usize,
255 ) -> Self {
256 Self::build(max_bytes, Some(max_entries), shard_count)
257 }
258
259 fn build(max_bytes: usize, max_entries: Option<usize>, shard_count: usize) -> Self {
260 let shard_count = shard_count.max(1);
261 let shard_bytes = max_bytes.div_ceil(shard_count);
266 let shard_entries = max_entries.map(|m| {
267 if m == 0 {
268 0
269 } else {
270 m.div_ceil(shard_count).max(1)
271 }
272 });
273 let shards = (0..shard_count)
274 .map(|_| {
275 Mutex::new(ByteLruInner {
276 map: HashMap::new(),
277 order: VecDeque::new(),
278 resident_bytes: 0,
279 })
280 })
281 .collect::<Vec<_>>()
282 .into_boxed_slice();
283 Self {
284 shards,
285 shard_bytes,
286 shard_entries,
287 max_bytes,
288 }
289 }
290
291 #[inline]
292 fn shard(&self, key: &K) -> &Mutex<ByteLruInner<K, V>> {
293 if self.shards.len() == 1 {
294 return &self.shards[0];
295 }
296 let mut hasher = std::collections::hash_map::DefaultHasher::new();
297 key.hash(&mut hasher);
298 &self.shards[(hasher.finish() as usize) % self.shards.len()]
299 }
300
301 pub fn get(&self, key: &K) -> Option<V> {
302 let mut g = self.shard(key).lock().unwrap_or_else(|p| p.into_inner());
304 let v = g.map.get(key)?.0.clone();
305 if let Some(pos) = g.order.iter().position(|k| k == key) {
307 let k = g.order.remove(pos).unwrap();
308 g.order.push_back(k);
309 }
310 Some(v)
311 }
312
313 pub fn insert(&self, key: K, value: V) {
314 let charge = value.resident_bytes();
315 let mut g = self.shard(&key).lock().unwrap_or_else(|p| p.into_inner());
316
317 if let Some((_old, old_charge)) = g.map.remove(&key) {
320 g.resident_bytes = g.resident_bytes.saturating_sub(old_charge);
321 if let Some(pos) = g.order.iter().position(|k| k == &key) {
322 g.order.remove(pos);
323 }
324 }
325
326 if charge > self.shard_bytes {
327 return;
329 }
330
331 if let Some(max_entries) = self.shard_entries {
332 if max_entries == 0 {
333 return;
334 }
335 while g.map.len() >= max_entries {
336 if let Some(evict_key) = g.order.pop_front() {
337 if let Some((_v, c)) = g.map.remove(&evict_key) {
338 g.resident_bytes = g.resident_bytes.saturating_sub(c);
339 }
340 } else {
341 break;
342 }
343 }
344 }
345
346 while g.resident_bytes + charge > self.shard_bytes {
347 if let Some(evict_key) = g.order.pop_front() {
348 if let Some((_v, c)) = g.map.remove(&evict_key) {
349 g.resident_bytes = g.resident_bytes.saturating_sub(c);
350 }
351 } else {
352 break;
353 }
354 }
355
356 g.map.insert(key.clone(), (value, charge));
357 g.order.push_back(key);
358 g.resident_bytes = g.resident_bytes.saturating_add(charge);
359 }
360
361 pub fn resident_bytes(&self) -> usize {
362 self.shards
363 .iter()
364 .map(|shard| {
365 shard
366 .lock()
367 .unwrap_or_else(|p| p.into_inner())
368 .resident_bytes
369 })
370 .sum()
371 }
372
373 pub const fn max_bytes(&self) -> usize {
374 self.max_bytes
375 }
376
377 pub fn len(&self) -> usize {
378 self.shards
379 .iter()
380 .map(|shard| shard.lock().unwrap_or_else(|p| p.into_inner()).map.len())
381 .sum()
382 }
383
384 pub fn is_empty(&self) -> bool {
385 self.len() == 0
386 }
387
388 pub fn clear(&self) {
389 for shard in self.shards.iter() {
390 let mut g = shard.lock().unwrap_or_else(|p| p.into_inner());
391 g.map.clear();
392 g.order.clear();
393 g.resident_bytes = 0;
394 }
395 }
396}
397
398impl<K: Eq + Hash + Clone, V: Clone + ResidentBytes> std::fmt::Debug for ByteLruCache<K, V> {
399 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
400 f.debug_struct("ByteLruCache")
401 .field("resident_bytes", &self.resident_bytes())
402 .field("max_bytes", &self.max_bytes)
403 .field("shard_count", &self.shards.len())
404 .field("shard_bytes", &self.shard_bytes)
405 .field("shard_entries", &self.shard_entries)
406 .finish()
407 }
408}
409
410impl ResidentBytes for Arc<ndarray::Array2<f64>> {
417 fn resident_bytes(&self) -> usize {
418 std::mem::size_of::<f64>()
419 .saturating_mul(self.nrows())
420 .saturating_mul(self.ncols())
421 }
422}
423
/// Write-once cell backed by `std::sync::OnceLock`.
///
/// Unlike `OnceLock::get_or_init`, `get_or_compute` runs the initializer
/// before touching the slot, so concurrent first callers may each compute a
/// value and only the first stored one is kept.
/// NOTE(review): the name suggests this is meant to avoid blocking inside
/// rayon worker threads — confirm against the callers.
pub struct RayonSafeOnce<T> {
    /// Holds the value once one has been stored.
    slot: std::sync::OnceLock<T>,
}
447
448impl<T> RayonSafeOnce<T> {
449 pub const fn new() -> Self {
450 Self {
451 slot: std::sync::OnceLock::new(),
452 }
453 }
454
455 #[inline]
457 pub fn get(&self) -> Option<&T> {
458 self.slot.get()
459 }
460
461 pub fn get_or_compute<F>(&self, init: F) -> &T
473 where
474 F: FnOnce() -> T,
475 {
476 if let Some(v) = self.slot.get() {
477 return v;
478 }
479 let candidate = init();
480 self.slot.set(candidate).ok();
481 self.slot
482 .get()
483 .expect("RayonSafeOnce slot populated by set() above")
484 }
485}
486
487impl<T> Default for RayonSafeOnce<T> {
488 fn default() -> Self {
489 Self::new()
490 }
491}
492
493impl<T: Clone> Clone for RayonSafeOnce<T> {
494 fn clone(&self) -> Self {
495 let cloned = Self::new();
496 if let Some(value) = self.slot.get() {
497 cloned.slot.set(value.clone()).ok();
498 }
499 cloned
500 }
501}
502
503impl<T: std::fmt::Debug> std::fmt::Debug for RayonSafeOnce<T> {
504 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505 f.debug_struct("RayonSafeOnce")
506 .field("slot", &self.slot.get())
507 .finish()
508 }
509}
510
#[cfg(test)]
mod byte_lru_tests {
    use super::*;

    /// Test value that is always charged exactly 8 bytes.
    #[derive(Clone, PartialEq, Debug)]
    struct Payload(u64);

    impl ResidentBytes for Payload {
        fn resident_bytes(&self) -> usize {
            8
        }
    }

    #[test]
    fn single_shard_round_trips_and_evicts_by_bytes() {
        // A 24-byte budget holds exactly three 8-byte payloads.
        let cache: ByteLruCache<u64, Payload> = ByteLruCache::new(24);
        (0..3).for_each(|k| cache.insert(k, Payload(k)));
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.resident_bytes(), 24);

        // Touching key 0 makes key 1 the least recently used entry.
        assert_eq!(cache.get(&0), Some(Payload(0)));

        // The fourth insert must evict key 1, not the freshly touched key 0.
        cache.insert(3, Payload(3));
        assert_eq!(cache.len(), 3);
        assert_eq!(cache.get(&1), None);
        assert_eq!(cache.get(&0), Some(Payload(0)));
        assert_eq!(cache.get(&3), Some(Payload(3)));
    }

    #[test]
    fn zero_entry_budget_disables_caching_in_every_shard() {
        let unsharded: ByteLruCache<u64, Payload> = ByteLruCache::with_max_entries(1 << 20, 0);
        unsharded.insert(7, Payload(7));
        assert_eq!(unsharded.get(&7), None);

        // The zero budget must not be rounded up to one entry per shard.
        let sharded: ByteLruCache<u64, Payload> =
            ByteLruCache::with_max_entries_sharded(1 << 20, 0, 16);
        sharded.insert(7, Payload(7));
        assert_eq!(sharded.get(&7), None);
    }

    #[test]
    fn sharded_cache_retrieves_all_keys_and_respects_aggregate_budget() {
        let shard_count = 8usize;
        let max_bytes = 8 * 64;
        let cache: ByteLruCache<u64, Payload> = ByteLruCache::new_sharded(max_bytes, shard_count);
        for key in 0..64u64 {
            cache.insert(key, Payload(key));
        }

        // Per-shard budgets are rounded up, so the aggregate ceiling is
        // `ceil(max_bytes / shard_count) * shard_count`.
        let aggregate_ceiling = max_bytes.div_ceil(shard_count) * shard_count;
        assert!(cache.resident_bytes() <= aggregate_ceiling);

        // A freshly inserted key is always retrievable straight away.
        cache.insert(123, Payload(123));
        assert_eq!(cache.get(&123), Some(Payload(123)));
        assert!(!cache.is_empty());

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.resident_bytes(), 0);
    }
}
575}