alopex_sql/executor/
memory.rs1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use crate::executor::{ExecutorError, Result};
5use crate::storage::SqlValue;
6
/// What to do when a query's tracked memory goes past its configured limit.
///
/// Consulted by `MemoryPolicy::enforce` and `MemoryPolicy::spill_directory`.
#[derive(Debug, Clone)]
pub enum SpillPolicy {
    /// Going over the limit is reported as an error.
    FailFast,
    /// Going over the limit is tolerated.
    SpillToDisk {
        /// Directory associated with spilling (presumably where spill files
        /// are written -- confirm against the operators that use it).
        directory: PathBuf,
    },
}
12
/// Receiver for spill statistics.
///
/// Implementors must be `Send + Sync`, so a sink can be shared behind an
/// `Arc` across threads.
pub trait SpillMetricsSink: Send + Sync {
    /// Reports one spill event.
    ///
    /// `bytes` and `files` are forwarded unchanged from the caller
    /// (presumably the number of bytes spilled and spill files produced).
    fn record_spill(&self, bytes: u64, files: u64);
}
16
17#[derive(Clone)]
18pub struct MemoryPolicy {
19 limit_bytes: Option<u64>,
20 spill_policy: SpillPolicy,
21 metrics: Option<Arc<dyn SpillMetricsSink>>,
22}
23
24impl MemoryPolicy {
25 pub fn new(limit_bytes: Option<u64>, spill_policy: SpillPolicy) -> Self {
26 Self {
27 limit_bytes,
28 spill_policy,
29 metrics: None,
30 }
31 }
32
33 pub fn limit_bytes(&self) -> Option<u64> {
34 self.limit_bytes
35 }
36
37 pub fn spill_policy(&self) -> &SpillPolicy {
38 &self.spill_policy
39 }
40
41 pub fn with_metrics(mut self, metrics: Arc<dyn SpillMetricsSink>) -> Self {
42 self.metrics = Some(metrics);
43 self
44 }
45
46 pub fn spill_directory(&self) -> Option<&Path> {
47 match &self.spill_policy {
48 SpillPolicy::SpillToDisk { directory } => Some(directory.as_path()),
49 SpillPolicy::FailFast => None,
50 }
51 }
52
53 pub fn record_spill(&self, bytes: u64, files: u64) {
54 if let Some(metrics) = &self.metrics {
55 metrics.record_spill(bytes, files);
56 }
57 }
58
59 pub fn over_limit(&self, used_bytes: u64) -> bool {
60 self.limit_bytes
61 .map(|limit| used_bytes > limit)
62 .unwrap_or(false)
63 }
64
65 pub fn enforce(&self, used_bytes: u64) -> Result<()> {
66 let Some(limit) = self.limit_bytes else {
67 return Ok(());
68 };
69 if used_bytes <= limit {
70 return Ok(());
71 }
72 match &self.spill_policy {
73 SpillPolicy::FailFast => Err(ExecutorError::ResourceExhausted {
74 message: format!("query memory limit exceeded: {used_bytes} bytes (limit {limit})"),
75 }),
76 SpillPolicy::SpillToDisk { .. } => Ok(()),
77 }
78 }
79}
80
81impl std::fmt::Debug for MemoryPolicy {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("MemoryPolicy")
84 .field("limit_bytes", &self.limit_bytes)
85 .field("spill_policy", &self.spill_policy)
86 .finish()
87 }
88}
89
90#[derive(Clone, Debug)]
91pub struct MemoryTracker {
92 policy: MemoryPolicy,
93 used_bytes: u64,
94}
95
96impl MemoryTracker {
97 pub fn new(policy: MemoryPolicy) -> Self {
98 Self {
99 policy,
100 used_bytes: 0,
101 }
102 }
103
104 pub fn used_bytes(&self) -> u64 {
105 self.used_bytes
106 }
107
108 pub fn policy(&self) -> &MemoryPolicy {
109 &self.policy
110 }
111
112 pub fn over_limit(&self) -> bool {
113 self.policy.over_limit(self.used_bytes)
114 }
115
116 pub fn reset(&mut self) {
117 self.used_bytes = 0;
118 }
119
120 pub fn add_bytes(&mut self, bytes: u64) -> Result<()> {
121 self.used_bytes = self.used_bytes.saturating_add(bytes);
122 self.policy.enforce(self.used_bytes)
123 }
124
125 pub fn add_row(&mut self, row: &[SqlValue]) -> Result<()> {
126 self.add_values(row)
127 }
128
129 pub fn add_values(&mut self, values: &[SqlValue]) -> Result<()> {
130 let bytes: u64 = values.iter().map(estimate_value_bytes).sum();
131 self.add_bytes(bytes)
132 }
133
134 pub fn add_value(&mut self, value: &SqlValue) -> Result<()> {
135 self.add_bytes(estimate_value_bytes(value))
136 }
137}
138
139fn estimate_value_bytes(value: &SqlValue) -> u64 {
140 match value {
141 SqlValue::Null => 0,
142 SqlValue::Integer(_) => 4,
143 SqlValue::BigInt(_) => 8,
144 SqlValue::Float(_) => 4,
145 SqlValue::Double(_) => 8,
146 SqlValue::Text(text) => text.len() as u64,
147 SqlValue::Blob(blob) => blob.len() as u64,
148 SqlValue::Boolean(_) => 1,
149 SqlValue::Timestamp(_) => 8,
150 SqlValue::Vector(values) => values.len() as u64 * 4,
151 }
152}