// alopex_sql/executor/memory.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use crate::executor::{ExecutorError, Result};
5use crate::storage::SqlValue;
6
/// What a query should do once its tracked memory exceeds the configured limit.
#[derive(Clone, Debug)]
pub enum SpillPolicy {
    /// Exceeding the limit is reported as an error.
    FailFast,
    /// Exceeding the limit is tolerated; `directory` is the location made
    /// available for spill files.
    SpillToDisk { directory: PathBuf },
}
12
/// Receiver for spill statistics.
///
/// Implementors must be `Send + Sync` so a sink can be shared behind an
/// `Arc<dyn SpillMetricsSink>`.
pub trait SpillMetricsSink: Send + Sync {
    /// Reports a spill of `bytes` bytes spread over `files` files.
    fn record_spill(&self, bytes: u64, files: u64);
}
16
17#[derive(Clone)]
18pub struct MemoryPolicy {
19    limit_bytes: Option<u64>,
20    spill_policy: SpillPolicy,
21    metrics: Option<Arc<dyn SpillMetricsSink>>,
22}
23
24impl MemoryPolicy {
25    pub fn new(limit_bytes: Option<u64>, spill_policy: SpillPolicy) -> Self {
26        Self {
27            limit_bytes,
28            spill_policy,
29            metrics: None,
30        }
31    }
32
33    pub fn limit_bytes(&self) -> Option<u64> {
34        self.limit_bytes
35    }
36
37    pub fn spill_policy(&self) -> &SpillPolicy {
38        &self.spill_policy
39    }
40
41    pub fn with_metrics(mut self, metrics: Arc<dyn SpillMetricsSink>) -> Self {
42        self.metrics = Some(metrics);
43        self
44    }
45
46    pub fn spill_directory(&self) -> Option<&Path> {
47        match &self.spill_policy {
48            SpillPolicy::SpillToDisk { directory } => Some(directory.as_path()),
49            SpillPolicy::FailFast => None,
50        }
51    }
52
53    pub fn record_spill(&self, bytes: u64, files: u64) {
54        if let Some(metrics) = &self.metrics {
55            metrics.record_spill(bytes, files);
56        }
57    }
58
59    pub fn over_limit(&self, used_bytes: u64) -> bool {
60        self.limit_bytes
61            .map(|limit| used_bytes > limit)
62            .unwrap_or(false)
63    }
64
65    pub fn enforce(&self, used_bytes: u64) -> Result<()> {
66        let Some(limit) = self.limit_bytes else {
67            return Ok(());
68        };
69        if used_bytes <= limit {
70            return Ok(());
71        }
72        match &self.spill_policy {
73            SpillPolicy::FailFast => Err(ExecutorError::ResourceExhausted {
74                message: format!("query memory limit exceeded: {used_bytes} bytes (limit {limit})"),
75            }),
76            SpillPolicy::SpillToDisk { .. } => Ok(()),
77        }
78    }
79}
80
81impl std::fmt::Debug for MemoryPolicy {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("MemoryPolicy")
84            .field("limit_bytes", &self.limit_bytes)
85            .field("spill_policy", &self.spill_policy)
86            .finish()
87    }
88}
89
90#[derive(Clone, Debug)]
91pub struct MemoryTracker {
92    policy: MemoryPolicy,
93    used_bytes: u64,
94}
95
96impl MemoryTracker {
97    pub fn new(policy: MemoryPolicy) -> Self {
98        Self {
99            policy,
100            used_bytes: 0,
101        }
102    }
103
104    pub fn used_bytes(&self) -> u64 {
105        self.used_bytes
106    }
107
108    pub fn policy(&self) -> &MemoryPolicy {
109        &self.policy
110    }
111
112    pub fn over_limit(&self) -> bool {
113        self.policy.over_limit(self.used_bytes)
114    }
115
116    pub fn reset(&mut self) {
117        self.used_bytes = 0;
118    }
119
120    pub fn add_bytes(&mut self, bytes: u64) -> Result<()> {
121        self.used_bytes = self.used_bytes.saturating_add(bytes);
122        self.policy.enforce(self.used_bytes)
123    }
124
125    pub fn add_row(&mut self, row: &[SqlValue]) -> Result<()> {
126        self.add_values(row)
127    }
128
129    pub fn add_values(&mut self, values: &[SqlValue]) -> Result<()> {
130        let bytes: u64 = values.iter().map(estimate_value_bytes).sum();
131        self.add_bytes(bytes)
132    }
133
134    pub fn add_value(&mut self, value: &SqlValue) -> Result<()> {
135        self.add_bytes(estimate_value_bytes(value))
136    }
137}
138
139fn estimate_value_bytes(value: &SqlValue) -> u64 {
140    match value {
141        SqlValue::Null => 0,
142        SqlValue::Integer(_) => 4,
143        SqlValue::BigInt(_) => 8,
144        SqlValue::Float(_) => 4,
145        SqlValue::Double(_) => 8,
146        SqlValue::Text(text) => text.len() as u64,
147        SqlValue::Blob(blob) => blob.len() as u64,
148        SqlValue::Boolean(_) => 1,
149        SqlValue::Timestamp(_) => 8,
150        SqlValue::Vector(values) => values.len() as u64 * 4,
151    }
152}