use std::collections::{HashMap, HashSet};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use sparrowdb_common::{Error, Result};
use sparrowdb_storage::csr::CsrForward;
use tempfile::NamedTempFile;
/// Default spill threshold used by [`SpillingHashJoin::new`].
///
/// `two_hop` sums the neighbor-list lengths of the source's direct neighbors;
/// when that sum is strictly greater than the threshold the join partitions
/// its intermediate pairs into temp files instead of staying in memory.
pub const SPILL_THRESHOLD: usize = 500_000;
/// Default number of temp-file partitions used by [`SpillingHashJoin::new`]
/// when the spilling path is taken.
const NUM_PARTITIONS: usize = 16;
/// Two-hop ("friend of friend") join over a borrowed [`CsrForward`] that can
/// spill its intermediate `(mid, fof)` pairs to temporary files.
///
/// Instances are created through `new` (default tuning) or `with_thresholds`.
pub struct SpillingHashJoin<'a> {
// Forward adjacency structure; its `neighbors` method is queried for both hops.
csr: &'a CsrForward,
// `two_hop` spills when the summed neighbor-list lengths of the direct
// neighbors are strictly greater than this value.
spill_threshold: usize,
// Number of temp-file partitions used when spilling; both constructors keep
// this at 1 or more.
num_partitions: usize,
}
impl<'a> SpillingHashJoin<'a> {
    /// Creates a join over `csr` using the default [`SPILL_THRESHOLD`] and
    /// [`NUM_PARTITIONS`].
    pub fn new(csr: &'a CsrForward) -> Self {
        Self::with_thresholds(csr, SPILL_THRESHOLD, NUM_PARTITIONS)
    }

    /// Creates a join over `csr` with explicit tuning values.
    ///
    /// `spill_threshold` is the largest friend-of-friend estimate that is still
    /// processed in memory. `num_partitions` is the number of temp files used
    /// when spilling; it is clamped to at least 1 so the partition index
    /// computation never divides by zero.
    pub fn with_thresholds(
        csr: &'a CsrForward,
        spill_threshold: usize,
        num_partitions: usize,
    ) -> Self {
        SpillingHashJoin {
            csr,
            spill_threshold,
            num_partitions: num_partitions.max(1),
        }
    }

    /// Returns the sorted, de-duplicated union of the neighbor lists of every
    /// direct neighbor of `src_slot`.
    ///
    /// The source slot and its direct neighbors are not filtered out of the
    /// result. An empty vector is returned when `src_slot` has no neighbors.
    ///
    /// # Errors
    ///
    /// Returns `Error::Io` if the spilling path fails to create, write, seek
    /// or read one of its temporary partition files.
    pub fn two_hop(&self, src_slot: u64) -> Result<Vec<u64>> {
        let direct = self.csr.neighbors(src_slot);
        if direct.is_empty() {
            return Ok(Vec::new());
        }

        // Upper bound on the number of (mid, fof) pairs, counted before
        // de-duplication.
        let total_fof_estimate: usize = direct
            .iter()
            .map(|&mid| self.csr.neighbors(mid).len())
            .sum();

        if total_fof_estimate > self.spill_threshold {
            self.two_hop_spilling(direct)
        } else {
            self.two_hop_in_memory(direct)
        }
    }

    /// In-memory path: unions the neighbor lists of `direct` and returns them
    /// sorted ascending.
    fn two_hop_in_memory(&self, direct: &[u64]) -> Result<Vec<u64>> {
        // Stream every neighbor straight into the set; staging the lists in a
        // per-`mid` map first would only copy the same data a second time.
        let fof_set: HashSet<u64> = direct
            .iter()
            .flat_map(|&mid| self.csr.neighbors(mid).iter().copied())
            .collect();

        let mut result: Vec<u64> = fof_set.into_iter().collect();
        result.sort_unstable();
        Ok(result)
    }

    /// Spilling path: writes `(mid, fof)` pairs into `num_partitions` temp
    /// files keyed by `mid`, then reads the partitions back one at a time and
    /// merges their values into the result set.
    ///
    /// # Errors
    ///
    /// Returns `Error::Io` on any temp-file creation, write, flush, seek or
    /// read failure.
    fn two_hop_spilling(&self, direct: &[u64]) -> Result<Vec<u64>> {
        let np = self.num_partitions;
        let mut part_files: Vec<NamedTempFile> = (0..np)
            .map(|_| NamedTempFile::new().map_err(Error::Io))
            .collect::<Result<_>>()?;

        // Partition phase. The writers mutably borrow the files, so they live
        // in their own scope and are dropped before the read-back phase.
        {
            let mut writers: Vec<BufWriter<&mut std::fs::File>> = part_files
                .iter_mut()
                .map(|f| BufWriter::new(f.as_file_mut()))
                .collect();

            for &mid in direct {
                // Reduce in u64 first: the remainder is < np, so the cast to
                // usize cannot truncate even where usize is narrower than u64.
                let p = (mid % (np as u64)) as usize;
                for &fof in self.csr.neighbors(mid) {
                    write_u64_pair(&mut writers[p], mid, fof)?;
                }
            }

            // Flush explicitly: dropping a BufWriter discards flush errors.
            for w in &mut writers {
                w.flush().map_err(Error::Io)?;
            }
        }

        // Read-back phase: one partition at a time.
        let mut fof_set: HashSet<u64> = HashSet::new();
        for file in &mut part_files {
            file.as_file_mut()
                .seek(SeekFrom::Start(0))
                .map_err(Error::Io)?;
            let mut reader = BufReader::new(file.as_file_mut());

            // Per-partition table: intermediate node -> its spilled neighbors.
            let mut hash: HashMap<u64, Vec<u64>> = HashMap::new();
            while let Some((mid, fof)) = read_u64_pair(&mut reader)? {
                hash.entry(mid).or_default().push(fof);
            }
            fof_set.extend(hash.values().flatten().copied());
        }

        let mut result: Vec<u64> = fof_set.into_iter().collect();
        result.sort_unstable();
        Ok(result)
    }
}
fn write_u64_pair<W: Write>(w: &mut W, a: u64, b: u64) -> Result<()> {
w.write_all(&a.to_le_bytes()).map_err(Error::Io)?;
w.write_all(&b.to_le_bytes()).map_err(Error::Io)?;
Ok(())
}
fn read_u64_pair<R: Read>(r: &mut R) -> Result<Option<(u64, u64)>> {
let mut buf = [0u8; 8];
match r.read_exact(&mut buf) {
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(Error::Io(e)),
}
let a = u64::from_le_bytes(buf);
r.read_exact(&mut buf).map_err(Error::Io)?;
let b = u64::from_le_bytes(buf);
Ok(Some((a, b)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::join::AspJoin;

    /// Five-node fixture built from the edges 0→1, 0→2, 1→3, 2→3, 2→4.
    fn social_graph() -> CsrForward {
        let edges = vec![(0u64, 1u64), (0, 2), (1, 3), (2, 3), (2, 4)];
        CsrForward::build(5u64, &edges)
    }

    /// With default thresholds the join must agree with `AspJoin` on the
    /// small fixture.
    #[test]
    fn join_spill_small_graph() {
        let csr = social_graph();
        let baseline = AspJoin::new(&csr);
        let spilling = SpillingHashJoin::new(&csr);

        let expected = baseline.two_hop(0).unwrap();
        let got = spilling.two_hop(0).unwrap();
        assert_eq!(got, expected, "Alice fof mismatch");

        let expected_bob = baseline.two_hop(1).unwrap();
        let got_bob = spilling.two_hop(1).unwrap();
        assert_eq!(got_bob, expected_bob, "Bob fof mismatch");
    }

    /// Ring of 10 000 nodes, compared against `AspJoin` with the spill path
    /// forced for every source.
    #[test]
    fn join_spill_large_graph() {
        const N: u64 = 10_000;
        let edges: Vec<(u64, u64)> = (0..N).map(|i| (i, (i + 1) % N)).collect();
        let csr = CsrForward::build(N, &edges);
        let baseline = AspJoin::new(&csr);

        // Each ring node has exactly one outgoing edge, so the friend-of-friend
        // estimate is 1 for every source. `two_hop` spills only when the
        // estimate is strictly greater than the threshold, so the threshold
        // has to be 0 here; with 1 the in-memory path would be taken instead.
        let spilling = SpillingHashJoin::with_thresholds(&csr, 0, 4);

        for src in 0..N {
            let expected = baseline.two_hop(src).unwrap();
            let got = spilling.two_hop(src).unwrap();
            assert_eq!(got, expected, "ring fof mismatch for src={src}");
        }
    }

    /// A source without outgoing edges yields an empty result.
    #[test]
    fn join_spill_no_edges() {
        let csr = CsrForward::build(3u64, &[(1u64, 2u64)]);
        let spilling = SpillingHashJoin::new(&csr);
        let got = spilling.two_hop(0).unwrap();
        assert!(got.is_empty());
    }

    /// A partition count of 0 is clamped by `with_thresholds`, so the spill
    /// path (forced by threshold 0) must not panic on the modulo.
    #[test]
    fn join_spill_zero_partitions_does_not_panic() {
        let csr = CsrForward::build(3u64, &[(0u64, 1u64), (1u64, 2u64)]);
        let join = SpillingHashJoin::with_thresholds(&csr, 0, 0);
        let result = join.two_hop(0).unwrap();
        assert_eq!(result, vec![2]);
    }
}