use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::hash::Hash;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
/// Partial aggregation state that can absorb another partial state for the
/// same group.
///
/// `SpilledHashAgg` calls `merge` when an increment arrives for a key that is
/// already in memory, and again when spilled batches are read back.
pub trait Mergeable {
    /// Folds `other` into `self`, consuming `other`.
    fn merge(&mut self, other: Self);
}
/// Everything that can go wrong while spilling aggregation state to disk or
/// reading it back.
#[derive(Debug)]
pub enum SpillError {
    /// A filesystem or stream operation failed.
    Io(std::io::Error),
    /// A spill batch file is larger than the configured memory limit (bytes).
    BatchTooLarge { size: usize, limit: usize },
    /// Spilled bytes could not be encoded or decoded.
    Codec(String),
}

impl From<std::io::Error> for SpillError {
    fn from(err: std::io::Error) -> Self {
        SpillError::Io(err)
    }
}

impl std::fmt::Display for SpillError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SpillError::Io(err) => write!(f, "spill i/o: {err}"),
            SpillError::Codec(msg) => write!(f, "spill codec: {msg}"),
            SpillError::BatchTooLarge { size, limit } => {
                write!(f, "spill batch {size} bytes exceeds limit {limit}")
            }
        }
    }
}

impl std::error::Error for SpillError {}
/// Binary serialization used to write aggregation keys and states into spill
/// batch files and read them back.
pub trait SpillCodec: Sized {
    /// Writes `self` to `w` and returns the number of bytes written.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError>;
    /// Reads one value from `r`.
    ///
    /// `Ok(None)` signals that the reader is out of data. `SpilledHashAgg::drain`
    /// stops reading a batch when a key decodes to `None`, and treats `None`
    /// for a state as a truncated entry.
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError>;
}
impl SpillCodec for u64 {
    /// Writes the value as 8 little-endian bytes.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        w.write_all(&self.to_le_bytes())?;
        Ok(8)
    }

    /// Reads 8 little-endian bytes.
    ///
    /// Returns `Ok(None)` only when the reader is exhausted before the first
    /// byte. A stream that ends part-way through the value is reported as
    /// [`SpillError::Codec`] instead of being mistaken for a clean end of input.
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut buf = [0u8; 8];
        // `read_exact` reports "no bytes" and "some bytes" identically as
        // `UnexpectedEof`, so probe one byte first to tell the two apart.
        match r.read_exact(&mut buf[..1]) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(SpillError::Io(e)),
        }
        r.read_exact(&mut buf[1..]).map_err(|e| {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                SpillError::Codec("truncated u64: stream ended mid-value".to_string())
            } else {
                SpillError::Io(e)
            }
        })?;
        Ok(Some(u64::from_le_bytes(buf)))
    }
}
impl SpillCodec for i64 {
    /// Writes the value as 8 little-endian bytes (two's complement).
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        w.write_all(&self.to_le_bytes())?;
        Ok(8)
    }

    /// Reads 8 little-endian bytes.
    ///
    /// Returns `Ok(None)` only when the reader is exhausted before the first
    /// byte. A stream that ends part-way through the value is reported as
    /// [`SpillError::Codec`] instead of being mistaken for a clean end of input.
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut buf = [0u8; 8];
        // `read_exact` reports "no bytes" and "some bytes" identically as
        // `UnexpectedEof`, so probe one byte first to tell the two apart.
        match r.read_exact(&mut buf[..1]) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(SpillError::Io(e)),
        }
        r.read_exact(&mut buf[1..]).map_err(|e| {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                SpillError::Codec("truncated i64: stream ended mid-value".to_string())
            } else {
                SpillError::Io(e)
            }
        })?;
        Ok(Some(i64::from_le_bytes(buf)))
    }
}
impl SpillCodec for f64 {
    /// Writes the IEEE-754 bit pattern as 8 little-endian bytes, so values
    /// (including NaN payloads) round-trip bit-for-bit.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        w.write_all(&self.to_le_bytes())?;
        Ok(8)
    }

    /// Reads 8 little-endian bytes.
    ///
    /// Returns `Ok(None)` only when the reader is exhausted before the first
    /// byte. A stream that ends part-way through the value is reported as
    /// [`SpillError::Codec`] instead of being mistaken for a clean end of input.
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut buf = [0u8; 8];
        // `read_exact` reports "no bytes" and "some bytes" identically as
        // `UnexpectedEof`, so probe one byte first to tell the two apart.
        match r.read_exact(&mut buf[..1]) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(SpillError::Io(e)),
        }
        r.read_exact(&mut buf[1..]).map_err(|e| {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                SpillError::Codec("truncated f64: stream ended mid-value".to_string())
            } else {
                SpillError::Io(e)
            }
        })?;
        Ok(Some(f64::from_le_bytes(buf)))
    }
}
impl SpillCodec for String {
    /// Writes a `u32` little-endian byte length followed by the UTF-8 bytes.
    ///
    /// # Errors
    /// Returns [`SpillError::Codec`] when the string is longer than
    /// `u32::MAX` bytes. Its length cannot be represented in the prefix, and
    /// writing a truncated prefix would corrupt every entry after it.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let bytes = self.as_bytes();
        let len = u32::try_from(bytes.len()).map_err(|_| {
            SpillError::Codec(format!(
                "string of {} bytes exceeds the u32 length prefix",
                bytes.len()
            ))
        })?;
        w.write_all(&len.to_le_bytes())?;
        w.write_all(bytes)?;
        Ok(4 + bytes.len())
    }

    /// Reads a length-prefixed UTF-8 string.
    ///
    /// Returns `Ok(None)` only when the reader is exhausted before the first
    /// byte of the prefix.
    ///
    /// # Errors
    /// Returns [`SpillError::Codec`] when the prefix or the body is cut short,
    /// or when the body is not valid UTF-8.
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let mut lenbuf = [0u8; 4];
        // Probe one byte first: `read_exact` cannot distinguish "nothing left"
        // from "prefix cut short", and only the former is a clean end.
        match r.read_exact(&mut lenbuf[..1]) {
            Ok(()) => {}
            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
            Err(e) => return Err(SpillError::Io(e)),
        }
        r.read_exact(&mut lenbuf[1..]).map_err(|e| {
            if e.kind() == std::io::ErrorKind::UnexpectedEof {
                SpillError::Codec("truncated string length prefix".to_string())
            } else {
                SpillError::Io(e)
            }
        })?;
        let len = u32::from_le_bytes(lenbuf);
        // Read through `take` instead of pre-allocating `len` bytes, so a
        // corrupt prefix cannot force a multi-GiB allocation up front.
        let mut buf = Vec::new();
        r.by_ref().take(u64::from(len)).read_to_end(&mut buf)?;
        if buf.len() != len as usize {
            return Err(SpillError::Codec(format!(
                "truncated string: expected {len} bytes, found {}",
                buf.len()
            )));
        }
        String::from_utf8(buf)
            .map(Some)
            .map_err(|e| SpillError::Codec(format!("invalid utf-8: {e}")))
    }
}
/// Hash aggregation that writes its in-memory table out to batch files in
/// `spill_dir` whenever the estimated table size exceeds `mem_limit_bytes`,
/// and merges those batches back when drained.
pub struct SpilledHashAgg<K, S>
where
    K: Hash + Eq + Clone + SpillCodec,
    S: Clone + Mergeable + SpillCodec,
{
    // In-memory partial aggregates, one state per group key.
    table: HashMap<K, S>,
    // Assumed memory cost of one group; the table size is estimated as
    // `table.len() * avg_entry_bytes`.
    avg_entry_bytes: usize,
    // Estimated-size threshold that triggers a spill; 0 disables spilling.
    mem_limit_bytes: usize,
    // Directory in which batch files are created.
    spill_dir: PathBuf,
    // Batch files written and not yet merged back or deleted.
    spilled_batches: Vec<PathBuf>,
    // Sequence number used to name the next batch file (`spill_{seq}.bin`).
    next_seq: u64,
    /// Total encoded bytes written across all spills.
    pub total_spilled_bytes: u64,
    /// Number of batch files written.
    pub spill_count: u64,
}
impl<K, S> SpilledHashAgg<K, S>
where
K: Hash + Eq + Clone + SpillCodec,
S: Clone + Mergeable + SpillCodec,
{
pub fn new(
spill_dir: impl AsRef<Path>,
mem_limit_bytes: usize,
avg_entry_bytes: usize,
) -> Self {
Self {
table: HashMap::new(),
avg_entry_bytes,
mem_limit_bytes,
spill_dir: spill_dir.as_ref().to_path_buf(),
spilled_batches: Vec::new(),
next_seq: 0,
total_spilled_bytes: 0,
spill_count: 0,
}
}
pub fn accumulate(&mut self, key: K, increment: S) -> Result<(), SpillError> {
match self.table.get_mut(&key) {
Some(existing) => existing.merge(increment),
None => {
self.table.insert(key, increment);
if self.should_spill() {
self.spill_partition()?;
}
}
}
Ok(())
}
fn should_spill(&self) -> bool {
if self.mem_limit_bytes == 0 {
return false;
}
let estimated = self.table.len().saturating_mul(self.avg_entry_bytes);
estimated > self.mem_limit_bytes
}
pub fn spill_partition(&mut self) -> Result<(), SpillError> {
if self.table.is_empty() {
return Ok(());
}
let path = self.spill_dir.join(format!("spill_{}.bin", self.next_seq));
self.next_seq += 1;
let file = OpenOptions::new()
.write(true)
.create_new(true)
.open(&path)?;
let mut writer = BufWriter::new(file);
let mut bytes_written = 0usize;
for (k, s) in self.table.drain() {
bytes_written += k.encode(&mut writer)?;
bytes_written += s.encode(&mut writer)?;
}
writer.flush()?;
self.total_spilled_bytes += bytes_written as u64;
self.spill_count += 1;
self.spilled_batches.push(path);
Ok(())
}
pub fn drain(mut self) -> Result<HashMap<K, S>, SpillError> {
let mut merged = std::mem::take(&mut self.table);
for path in self.spilled_batches.drain(..) {
let file = File::open(&path)?;
let metadata = file.metadata()?;
if self.mem_limit_bytes > 0 && (metadata.len() as usize) > self.mem_limit_bytes {
return Err(SpillError::BatchTooLarge {
size: metadata.len() as usize,
limit: self.mem_limit_bytes,
});
}
let mut reader = BufReader::new(file);
loop {
let key = match K::decode(&mut reader)? {
Some(k) => k,
None => break,
};
let state = match S::decode(&mut reader)? {
Some(s) => s,
None => {
return Err(SpillError::Codec(
"spill batch ended mid-entry: state missing".to_string(),
))
}
};
match merged.get_mut(&key) {
Some(existing) => existing.merge(state),
None => {
merged.insert(key, state);
}
}
}
let _ = std::fs::remove_file(&path);
}
Ok(merged)
}
pub fn spilled_batch_count(&self) -> usize {
self.spilled_batches.len()
}
pub fn in_memory_groups(&self) -> usize {
self.table.len()
}
}
impl<K, S> Drop for SpilledHashAgg<K, S>
where
    K: Hash + Eq + Clone + SpillCodec,
    S: Clone + Mergeable + SpillCodec,
{
    /// Deletes every batch file still tracked by the aggregator.
    ///
    /// Removal is best-effort: a destructor has no way to report failures, so
    /// they are ignored. Only the files are removed, not `spill_dir`.
    fn drop(&mut self) {
        let leftovers = std::mem::take(&mut self.spilled_batches);
        leftovers.iter().for_each(|batch| {
            let _ = std::fs::remove_file(batch);
        });
    }
}
/// Running sum for one group.
#[derive(Debug, Clone, Copy)]
pub struct SumState<T>(pub T);

impl Mergeable for SumState<i64> {
    /// Adds `other`, clamping at `i64::MIN`/`i64::MAX` instead of overflowing.
    fn merge(&mut self, other: Self) {
        let SumState(rhs) = other;
        self.0 = i64::saturating_add(self.0, rhs);
    }
}

impl SpillCodec for SumState<i64> {
    /// Encoded exactly like the inner `i64`.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let SumState(value) = *self;
        value.encode(w)
    }

    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        i64::decode(r).map(|maybe| maybe.map(SumState))
    }
}

impl Mergeable for SumState<f64> {
    /// Plain IEEE-754 addition; no compensation for rounding error.
    fn merge(&mut self, other: Self) {
        let SumState(rhs) = other;
        self.0 = self.0 + rhs;
    }
}

impl SpillCodec for SumState<f64> {
    /// Encoded exactly like the inner `f64`.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let SumState(value) = *self;
        value.encode(w)
    }

    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        f64::decode(r).map(|maybe| maybe.map(SumState))
    }
}
/// Running row count for one group.
#[derive(Debug, Clone, Copy)]
pub struct CountState(pub u64);

impl Mergeable for CountState {
    /// Adds the other count, saturating at `u64::MAX`.
    fn merge(&mut self, other: Self) {
        let CountState(extra) = other;
        self.0 = u64::saturating_add(self.0, extra);
    }
}

impl SpillCodec for CountState {
    /// Encoded exactly like the inner `u64`.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let CountState(count) = *self;
        count.encode(w)
    }

    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        u64::decode(r).map(|maybe| maybe.map(CountState))
    }
}
/// Running minimum for one group.
#[derive(Debug, Clone, Copy)]
pub struct MinState<T>(pub T);

/// Running maximum for one group.
#[derive(Debug, Clone, Copy)]
pub struct MaxState<T>(pub T);

impl Mergeable for MinState<i64> {
    /// Keeps the smaller of the two values.
    fn merge(&mut self, other: Self) {
        self.0 = std::cmp::min(self.0, other.0);
    }
}

impl SpillCodec for MinState<i64> {
    /// Encoded exactly like the inner `i64`.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let MinState(value) = *self;
        value.encode(w)
    }

    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        i64::decode(r).map(|maybe| maybe.map(MinState))
    }
}

impl Mergeable for MaxState<i64> {
    /// Keeps the larger of the two values.
    fn merge(&mut self, other: Self) {
        self.0 = std::cmp::max(self.0, other.0);
    }
}

impl SpillCodec for MaxState<i64> {
    /// Encoded exactly like the inner `i64`.
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let MaxState(value) = *self;
        value.encode(w)
    }

    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        i64::decode(r).map(|maybe| maybe.map(MaxState))
    }
}
/// Partial state for an average: running sum and number of contributing values.
#[derive(Debug, Clone, Copy)]
pub struct AvgState {
    /// Sum of the contributing values.
    pub sum: f64,
    /// Number of contributing values.
    pub count: u64,
}

impl Mergeable for AvgState {
    /// Adds the other partial sum and count.
    ///
    /// The count saturates at `u64::MAX`, consistent with `CountState`. A plain
    /// `+=` would panic in debug builds and silently wrap in release builds.
    fn merge(&mut self, other: Self) {
        self.sum += other.sum;
        self.count = self.count.saturating_add(other.count);
    }
}

impl SpillCodec for AvgState {
    /// Writes `sum` (as `f64`) followed by `count` (as `u64`).
    fn encode<W: Write>(&self, w: &mut W) -> Result<usize, SpillError> {
        let sum_bytes = self.sum.encode(w)?;
        let count_bytes = self.count.encode(w)?;
        Ok(sum_bytes + count_bytes)
    }

    /// Reads `sum` then `count`.
    ///
    /// Returns `Ok(None)` when there is no `sum` to read.
    ///
    /// # Errors
    /// Returns [`SpillError::Codec`] when a `sum` is present but the `count`
    /// after it is missing.
    fn decode<R: Read>(r: &mut R) -> Result<Option<Self>, SpillError> {
        let sum = match f64::decode(r)? {
            Some(v) => v,
            None => return Ok(None),
        };
        // Once a sum has been read, a missing count means the entry was cut short.
        let count = u64::decode(r)?.ok_or_else(|| {
            SpillError::Codec("AvgState ended after sum: count missing".to_string())
        })?;
        Ok(Some(AvgState { sum, count }))
    }
}

impl AvgState {
    /// Returns `sum / count`, or `None` when no values contributed.
    pub fn finalize(self) -> Option<f64> {
        if self.count == 0 {
            None
        } else {
            Some(self.sum / self.count as f64)
        }
    }
}
/// Creates an aggregator that spills into a fresh directory under the system
/// temp dir, using a 64 MiB memory limit and a 128-byte per-group estimate.
///
/// The directory name is `reddb-spill-{pid}-{seq}`, where `seq` comes from a
/// per-process counter. Aggregators created by this function therefore never
/// share a directory within one process.
///
/// NOTE(review): the aggregator's `Drop` removes only batch files, so the
/// directory created here is left behind in the temp dir. Confirm whether
/// cleanup happens elsewhere.
///
/// # Errors
/// Returns the I/O error if the directory cannot be created.
pub fn spilled_hash_agg_default<K, S>() -> std::io::Result<SpilledHashAgg<K, S>>
where
    K: std::hash::Hash + Eq + Clone + SpillCodec,
    S: Clone + Mergeable + SpillCodec,
{
    use std::sync::atomic::{AtomicU64, Ordering};

    const DEFAULT_MEM_LIMIT_BYTES: usize = 64 * 1024 * 1024;
    const DEFAULT_AVG_ENTRY_BYTES: usize = 128;

    static SEQ: AtomicU64 = AtomicU64::new(0);

    let spill_dir = std::env::temp_dir().join(format!(
        "reddb-spill-{}-{}",
        std::process::id(),
        SEQ.fetch_add(1, Ordering::Relaxed)
    ));
    std::fs::create_dir_all(&spill_dir)?;
    Ok(SpilledHashAgg::new(
        spill_dir,
        DEFAULT_MEM_LIMIT_BYTES,
        DEFAULT_AVG_ENTRY_BYTES,
    ))
}