reddb_server/storage/query/executors/agg_spill.rs

//! Hash-aggregation spill helper — Phase 4 P4 building block.
//!
//! Provides a `SpilledHashAgg` data structure that holds an
//! in-memory hash table plus zero or more on-disk batch files.
//! When the in-memory table exceeds `mem_limit_bytes`, callers
//! invoke `spill_partition` to write a batch to a temporary file
//! and free the corresponding entries from the hash map. The
//! `drain` step reads all spilled batches back, merges them with
//! whatever is still in memory, and produces the final aggregated
//! output.
//!
//! Mirrors PostgreSQL's `nodeAgg.c::hashagg_spill_*` family
//! modulo features we don't have:
//!
//! - No tape-based recursion: PG does N-way repartitioning when a
//!   spilled batch itself doesn't fit. The Week 4 version here just
//!   rewinds and reads each batch back in full. If a single batch
//!   exceeds memory we return an error so the caller can switch to
//!   sort-based aggregation.
//! - No parallel spill: single producer, single consumer.
//! - No on-disk hash format: each spill batch is a plain
//!   serialised `Vec<(GroupKey, AggState)>` — small overhead but
//!   simple to read back.
//!
//! The module is **not yet wired** into `executors/aggregation.rs`.
//! Wiring happens in a follow-up commit when the aggregation
//! executor learns to track its current memory footprint and
//! call `spill_partition` from inside its insert loop.
//!
//! ## Type parameters
//!
//! - `K` — the group key. Must be `Hash + Eq + Clone + SpillCodec`
//!   so it can both index the hash map and round-trip through
//!   the spill file.
//! - `S` — the aggregation state per group. Must be `Clone +
//!   Mergeable + SpillCodec` so spilled batches can be combined
//!   with the in-memory state during drain.
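//!
//! ## Example
//!
//! A minimal usage sketch (the spill directory and `rows` input
//! are placeholders; `CountState` is defined at the bottom of
//! this module):
//!
//! ```ignore
//! let mut agg: SpilledHashAgg<String, CountState> =
//!     SpilledHashAgg::new("/tmp/spill", 64 * 1024 * 1024, 128);
//! for row in rows {
//!     // COUNT(*) per group: every row contributes a count of 1.
//!     agg.accumulate(row.group_key, CountState(1))?;
//! }
//! let groups = agg.drain()?; // reads spilled batches back and merges
//! ```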
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::hash::Hash;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};

/// Trait implemented by any aggregation state that can absorb
/// another value of its own type. Used by the drain step to merge
/// spilled batches back into the in-memory table.
///
/// Implementors:
/// - SUM:    `lhs += rhs`
/// - COUNT:  `lhs += rhs`
/// - MIN:    `lhs = min(lhs, rhs)`
/// - MAX:    `lhs = max(lhs, rhs)`
/// - AVG:    pair `(sum, count)` → element-wise add
/// - STDDEV: triple `(n, mean, M2)` → Welford's parallel formula
///   (see the sketch below)
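///
/// # Example
///
/// A sketch of the STDDEV merge under Welford's parallel update
/// (`VarState` is hypothetical and not defined in this module):
///
/// ```ignore
/// struct VarState { n: u64, mean: f64, m2: f64 }
///
/// impl Mergeable for VarState {
///     fn merge(&mut self, other: Self) {
///         // Combine two partial (n, mean, M2) triples into one.
///         let n = self.n + other.n;
///         if n == 0 { return; }
///         let delta = other.mean - self.mean;
///         self.mean += delta * other.n as f64 / n as f64;
///         self.m2 += other.m2
///             + delta * delta * (self.n as f64 * other.n as f64) / n as f64;
///         self.n = n;
///     }
/// }
/// ```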
pub trait Mergeable {
    /// Combine `other` into `self`, leaving `self` as the merged
    /// result and consuming `other`.
    fn merge(&mut self, other: Self);
}

/// Errors raised by the spill helper.
#[derive(Debug)]
pub enum SpillError {
    /// I/O failure writing or reading a spill batch.
    Io(std::io::Error),
    /// A single spill batch exceeds the configured memory limit
    /// even after offloading. Caller should fall back to
    /// sort-based aggregation.
    BatchTooLarge { size: usize, limit: usize },
    /// Encoding / decoding of a key / state failed during
    /// round-trip.
    Codec(String),
}

impl From<std::io::Error> for SpillError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}

impl std::fmt::Display for SpillError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(e) => write!(f, "spill i/o: {e}"),
            Self::BatchTooLarge { size, limit } => {
                write!(f, "spill batch {size} bytes exceeds limit {limit}")
            }
            Self::Codec(msg) => write!(f, "spill codec: {msg}"),
        }
    }
}

impl std::error::Error for SpillError {}

/// Trait for serialising a key or state into a flat byte
/// representation. Implementations should be deterministic so the
/// spill file is byte-equal across runs (helpful for debugging).
///
/// The impls below use raw little-endian encoding for fixed-width
/// numerics and length-prefixed bytes for strings; the helper
/// doesn't require a serde crate because reddb deliberately
/// avoids large transitive deps.
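///
/// # Example
///
/// A sketch of a codec for a hypothetical two-column group key,
/// composed from the primitive impls below:
///
/// ```ignore
/// struct PairKey(u64, String);
///
/// impl SpillCodec for PairKey {
///     fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
///         Ok(self.0.encode(w)? + self.1.encode(w)?)
///     }
///     fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
///         let a = match u64::decode(r)? {
///             Some(v) => v,
///             None => return Ok(None), // clean EOF between entries
///         };
///         match String::decode(r)? {
///             Some(b) => Ok(Some(PairKey(a, b))),
///             None => Err(SpillError::Codec("key truncated".into())),
///         }
///     }
/// }
/// ```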
pub trait SpillCodec: Sized {
    /// Encode `self` into the writer. Returns the number of bytes
    /// written so the caller can track per-batch size.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError>;
    /// Decode a fresh value from the reader. Reads exactly one
    /// element; returns `Ok(None)` on a clean end-of-file so
    /// drain loops can terminate naturally.
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError>;
}

/// Implementation strategy for fixed-size primitive types. The
/// code uses raw little-endian writes so we don't depend on
/// `bincode` / `serde` from this module — keeps the dep graph
/// flat.
impl SpillCodec for u64 {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        w.write_all(&self.to_le_bytes())?;
        Ok(8)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut buf = [0u8; 8];
        match r.read_exact(&mut buf) {
            Ok(()) => Ok(Some(u64::from_le_bytes(buf))),
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
            Err(e) => Err(SpillError::Io(e)),
        }
    }
}

impl SpillCodec for i64 {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        w.write_all(&self.to_le_bytes())?;
        Ok(8)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut buf = [0u8; 8];
        match r.read_exact(&mut buf) {
            Ok(()) => Ok(Some(i64::from_le_bytes(buf))),
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
            Err(e) => Err(SpillError::Io(e)),
        }
    }
}

impl SpillCodec for f64 {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        w.write_all(&self.to_le_bytes())?;
        Ok(8)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut buf = [0u8; 8];
        match r.read_exact(&mut buf) {
            Ok(()) => Ok(Some(f64::from_le_bytes(buf))),
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => Ok(None),
            Err(e) => Err(SpillError::Io(e)),
        }
    }
}

impl SpillCodec for String {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let bytes = self.as_bytes();
        let len = bytes.len() as u32;
        w.write_all(&len.to_le_bytes())?;
        w.write_all(bytes)?;
        Ok(4 + bytes.len())
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut lenbuf = [0u8; 4];
        match r.read_exact(&mut lenbuf) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(SpillError::Io(e)),
        }
        let len = u32::from_le_bytes(lenbuf) as usize;
        let mut buf = vec![0u8; len];
        r.read_exact(&mut buf)?;
        String::from_utf8(buf)
            .map(Some)
            .map_err(|e| SpillError::Codec(format!("invalid utf-8: {e}")))
    }
}

/// Hash aggregation table with optional spill-to-disk overflow.
///
/// Owns the in-memory `HashMap<K, S>` plus a list of `PathBuf`s
/// pointing at spilled batch files. The caller drives the
/// lifecycle by calling `accumulate` for each input row and
/// `drain` once the input is exhausted.
pub struct SpilledHashAgg<K, S>
where
    K: Hash + Eq + Clone + SpillCodec,
    S: Clone + Mergeable + SpillCodec,
{
    /// In-memory hash table. Cleared after each spill.
    table: HashMap<K, S>,
    /// Estimated bytes-per-(key, state) pair. Used to compute
    /// when to spill — a rough proxy for actual heap usage.
    /// Callers can tune by passing a more accurate value.
    avg_entry_bytes: usize,
    /// Soft limit on `table.len() * avg_entry_bytes`. Crossing
    /// this triggers a spill.
    mem_limit_bytes: usize,
    /// Directory where spill batches land. Each batch is a single
    /// file named `spill_{seq}.bin`.
    spill_dir: PathBuf,
    /// List of spilled batch paths in order of creation.
    spilled_batches: Vec<PathBuf>,
    /// Monotonic batch counter for unique filenames.
    next_seq: u64,
    /// Total bytes spilled across all batches — diagnostic.
    pub total_spilled_bytes: u64,
    /// Number of times `spill_partition` was called — diagnostic.
    pub spill_count: u64,
}

impl<K, S> SpilledHashAgg<K, S>
where
    K: Hash + Eq + Clone + SpillCodec,
    S: Clone + Mergeable + SpillCodec,
{
    /// Create a new spillable hash aggregator. `spill_dir` must
    /// exist and be writable; the helper does NOT create it.
    /// `mem_limit_bytes == 0` disables spilling entirely (useful
    /// for tests that want to exercise the in-memory path).
    pub fn new(
        spill_dir: impl AsRef<Path>,
        mem_limit_bytes: usize,
        avg_entry_bytes: usize,
    ) -> Self {
        Self {
            table: HashMap::new(),
            avg_entry_bytes,
            mem_limit_bytes,
            spill_dir: spill_dir.as_ref().to_path_buf(),
            spilled_batches: Vec::new(),
            next_seq: 0,
            total_spilled_bytes: 0,
            spill_count: 0,
        }
    }

    /// Insert or update an aggregation state for the given key.
    /// `accumulate` triggers a spill when the in-memory table's
    /// estimated footprint exceeds the configured limit. Returns
    /// `Ok(())` on success; it fails only if a triggered spill
    /// hits an I/O error.
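    ///
    /// A sketch of the per-key merge semantics (`agg` is a
    /// `SpilledHashAgg<String, CountState>`):
    ///
    /// ```ignore
    /// agg.accumulate("red".into(), CountState(1))?;
    /// agg.accumulate("red".into(), CountState(1))?;  // merges: count == 2
    /// agg.accumulate("blue".into(), CountState(1))?; // fresh group
    /// ```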
    pub fn accumulate(&mut self, key: K, increment: S) -> Result<(), SpillError> {
        match self.table.get_mut(&key) {
            Some(existing) => existing.merge(increment),
            None => {
                self.table.insert(key, increment);
                if self.should_spill() {
                    self.spill_partition()?;
                }
            }
        }
        Ok(())
    }

    /// Returns true when the current in-memory footprint exceeds
    /// `mem_limit_bytes`. Cheap O(1) check using the estimated
    /// per-entry size; callers should keep `avg_entry_bytes`
    /// in sync with reality if precision matters.
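    ///
    /// Worked example with the defaults from
    /// `spilled_hash_agg_default` (64 MiB limit, 128 B/entry):
    /// 64 MiB = 67_108_864 B and 67_108_864 / 128 = 524_288, so
    /// the insert that takes the table to 524_289 groups
    /// (524_289 * 128 = 67_108_992 B) triggers the spill.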
    fn should_spill(&self) -> bool {
        if self.mem_limit_bytes == 0 {
            return false;
        }
        let estimated = self.table.len().saturating_mul(self.avg_entry_bytes);
        estimated > self.mem_limit_bytes
    }

    /// Write the entire in-memory table to a new spill batch file
    /// and clear the table. Updates the spill diagnostics. Caller
    /// is free to keep accumulating after this returns — the
    /// batch will be merged back during `drain`.
    pub fn spill_partition(&mut self) -> Result<(), SpillError> {
        if self.table.is_empty() {
            return Ok(());
        }
        let path = self.spill_dir.join(format!("spill_{}.bin", self.next_seq));
        self.next_seq += 1;
        let file = OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&path)?;
        let mut writer = BufWriter::new(file);
        let mut bytes_written = 0usize;
        // Drain so we don't hold both copies in memory while
        // writing — the file is the canonical store after this.
        for (k, s) in self.table.drain() {
            bytes_written += k.encode(&mut writer)?;
            bytes_written += s.encode(&mut writer)?;
        }
        writer.flush()?;
        self.total_spilled_bytes += bytes_written as u64;
        self.spill_count += 1;
        self.spilled_batches.push(path);
        Ok(())
    }

    /// Consume the aggregator and return the final merged state
    /// for every group. Streams each spilled batch back entry by
    /// entry, merges it into the table the accumulator left in
    /// place, and yields the unified set.
    ///
    /// Memory profile: batches are decoded one entry at a time,
    /// so the peak is the running merge table plus one decoded
    /// entry. As a guard, any batch whose on-disk size alone
    /// exceeds `mem_limit_bytes` is rejected with `BatchTooLarge`
    /// so the caller can switch strategies.
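    ///
    /// A sketch of the intended call-site fallback
    /// (`sort_based_group_by` is a hypothetical placeholder):
    ///
    /// ```ignore
    /// let groups = match agg.drain() {
    ///     Ok(groups) => groups,
    ///     Err(SpillError::BatchTooLarge { .. }) => sort_based_group_by(input)?,
    ///     Err(e) => return Err(e.into()),
    /// };
    /// ```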
    pub fn drain(mut self) -> Result<HashMap<K, S>, SpillError> {
        // The current `table` is the most recent in-memory chunk
        // that hasn't been spilled — start the merge from it.
        let mut merged = std::mem::take(&mut self.table);
        for path in self.spilled_batches.drain(..) {
            let file = File::open(&path)?;
            let metadata = file.metadata()?;
            if self.mem_limit_bytes > 0 && (metadata.len() as usize) > self.mem_limit_bytes {
                return Err(SpillError::BatchTooLarge {
                    size: metadata.len() as usize,
                    limit: self.mem_limit_bytes,
                });
            }
            let mut reader = BufReader::new(file);
            loop {
                let key = match K::decode(&mut reader)? {
                    Some(k) => k,
                    None => break,
                };
                let state = match S::decode(&mut reader)? {
                    Some(s) => s,
                    None => {
                        return Err(SpillError::Codec(
                            "spill batch ended mid-entry: state missing".to_string(),
                        ))
                    }
                };
                match merged.get_mut(&key) {
                    Some(existing) => existing.merge(state),
                    None => {
                        merged.insert(key, state);
                    }
                }
            }
            // Best-effort cleanup — ignore errors so a missing
            // file doesn't hide a successful merge.
            let _ = std::fs::remove_file(&path);
        }
        Ok(merged)
    }

    /// Number of spill batches currently on disk. Diagnostic
    /// hook for tests / metrics.
    pub fn spilled_batch_count(&self) -> usize {
        self.spilled_batches.len()
    }

    /// Number of groups currently held in memory.
    pub fn in_memory_groups(&self) -> usize {
        self.table.len()
    }
}

impl<K, S> Drop for SpilledHashAgg<K, S>
where
    K: Hash + Eq + Clone + SpillCodec,
    S: Clone + Mergeable + SpillCodec,
{
    fn drop(&mut self) {
        // Clean up any spill files left behind if the caller
        // never called drain. Best-effort — failures are silent
        // so Drop doesn't panic.
        for path in self.spilled_batches.drain(..) {
            let _ = std::fs::remove_file(&path);
        }
    }
}

// ────────────────────────────────────────────────────────────────
// Convenience Mergeable implementations for the common scalar
// aggregates. The aggregation executor wires its concrete state
// types through these so callers don't need to spell out the
// merge logic at every call site.
// ────────────────────────────────────────────────────────────────
/// SUM state — running total. Generic over any numeric type that
/// supports `+=`.
#[derive(Debug, Clone, Copy)]
pub struct SumState<T>(pub T);

impl Mergeable for SumState<i64> {
    fn merge(&mut self, other: Self) {
        self.0 = self.0.saturating_add(other.0);
    }
}
impl SpillCodec for SumState<i64> {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        self.0.encode(w)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        Ok(i64::decode(r)?.map(SumState))
    }
}

impl Mergeable for SumState<f64> {
    fn merge(&mut self, other: Self) {
        self.0 += other.0;
    }
}
impl SpillCodec for SumState<f64> {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        self.0.encode(w)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        Ok(f64::decode(r)?.map(SumState))
    }
}

/// COUNT state — monotonic non-negative counter.
#[derive(Debug, Clone, Copy)]
pub struct CountState(pub u64);

impl Mergeable for CountState {
    fn merge(&mut self, other: Self) {
        self.0 = self.0.saturating_add(other.0);
    }
}
impl SpillCodec for CountState {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        self.0.encode(w)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        Ok(u64::decode(r)?.map(CountState))
    }
}

/// MIN/MAX state — wraps a numeric and merges via comparison.
/// Two distinct types so the type system enforces direction.
#[derive(Debug, Clone, Copy)]
pub struct MinState<T>(pub T);
#[derive(Debug, Clone, Copy)]
pub struct MaxState<T>(pub T);

impl Mergeable for MinState<i64> {
    fn merge(&mut self, other: Self) {
        if other.0 < self.0 {
            self.0 = other.0;
        }
    }
}
impl SpillCodec for MinState<i64> {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        self.0.encode(w)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        Ok(i64::decode(r)?.map(MinState))
    }
}

impl Mergeable for MaxState<i64> {
    fn merge(&mut self, other: Self) {
        if other.0 > self.0 {
            self.0 = other.0;
        }
    }
}
impl SpillCodec for MaxState<i64> {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        self.0.encode(w)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        Ok(i64::decode(r)?.map(MaxState))
    }
}

/// AVG state — pair (sum, count). Final value is sum / count.
#[derive(Debug, Clone, Copy)]
pub struct AvgState {
    pub sum: f64,
    pub count: u64,
}

impl Mergeable for AvgState {
    fn merge(&mut self, other: Self) {
        self.sum += other.sum;
        self.count += other.count;
    }
}
impl SpillCodec for AvgState {
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let a = self.sum.encode(w)?;
        let b = self.count.encode(w)?;
        Ok(a + b)
    }
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let sum = match f64::decode(r)? {
            Some(v) => v,
            None => return Ok(None),
        };
        let count = match u64::decode(r)? {
            Some(v) => v,
            None => {
                return Err(SpillError::Codec(
                    "AvgState ended after sum: count missing".to_string(),
                ))
            }
        };
        Ok(Some(AvgState { sum, count }))
    }
}

impl AvgState {
    /// Final average value. Returns `None` for an empty state to
    /// distinguish "no rows" from `0.0`.
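    ///
    /// E.g. `AvgState { sum: 6.0, count: 3 }.finalize() == Some(2.0)`,
    /// and `AvgState { sum: 0.0, count: 0 }.finalize() == None`.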
    pub fn finalize(self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum / self.count as f64)
        }
    }
}

/// Phase 3.3 wiring entry point. Builds a `SpilledHashAgg` with
/// production-grade defaults (`mem_limit_bytes = 64 MiB`,
/// `avg_entry_bytes = 128`) spilling to a process-unique
/// `reddb-spill-{pid}-{seq}` directory under the system temp dir.
/// Intended for `executors/aggregation.rs::execute_group_by` once
/// the input row count exceeds the in-memory threshold.
///
/// The caller is expected to feed every row via `accumulate` and
/// then call `drain` to materialise the merged result. The helper
/// returns the constructed aggregator so the caller can wire it
/// into its existing per-row loop without re-implementing the
/// spill bookkeeping.
///
/// Spill files land in a process-unique subdirectory so concurrent
/// queries don't collide; the files are removed on `drain` or
/// `Drop`, while the empty directory is left to the OS tmp cleaner.
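///
/// # Example
///
/// A sketch under the defaults above (`rows` is a placeholder
/// input; errors are propagated with `?`):
///
/// ```ignore
/// let mut agg = spilled_hash_agg_default::<String, SumState<i64>>()?;
/// for row in rows {
///     agg.accumulate(row.group_key, SumState(row.amount))?;
/// }
/// // Diagnostics must be read before `drain` consumes the aggregator.
/// eprintln!("spills: {} ({} bytes)", agg.spill_count, agg.total_spilled_bytes);
/// let sums = agg.drain()?;
/// ```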
pub fn spilled_hash_agg_default<K, S>() -> std::io::Result<SpilledHashAgg<K, S>>
where
    K: std::hash::Hash + Eq + Clone + SpillCodec,
    S: Clone + Mergeable + SpillCodec,
{
    use std::sync::atomic::{AtomicU64, Ordering};
    static SEQ: AtomicU64 = AtomicU64::new(0);
    let seq = SEQ.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let dir = std::env::temp_dir().join(format!("reddb-spill-{pid}-{seq}"));
    std::fs::create_dir_all(&dir)?;
    Ok(SpilledHashAgg::new(
        dir,
        64 * 1024 * 1024, // 64 MiB soft limit
        128,              // avg bytes per (key, state) — tuned for SUM/COUNT
    ))
}