use std::collections::HashMap;
use roaring::RoaringBitmap;
use sparrowdb_common::Result;
use sparrowdb_storage::csr::CsrForward;
/// Two-hop traversal operator over a borrowed forward CSR adjacency structure.
///
/// The operator holds nothing but the CSR reference; all per-query state is
/// built inside the traversal methods.
pub struct AspJoin<'g> {
    csr: &'g CsrForward,
}
impl<'a> AspJoin<'a> {
    /// Creates a join operator that reads adjacency from `csr`.
    pub fn new(csr: &'a CsrForward) -> Self {
        AspJoin { csr }
    }

    /// Returns every distinct slot reachable from `src_slot` by a path of
    /// exactly two forward edges, sorted ascending.
    ///
    /// Nothing is filtered out of the result: `src_slot` itself, or one of its
    /// direct neighbors, is included whenever some two-edge path ends there.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` if a direct neighbor of `src_slot` does
    /// not fit in a `u32`, the key type of the RoaringBitmap semijoin filter.
    pub fn two_hop(&self, src_slot: u64) -> Result<Vec<u64>> {
        let direct = self.csr.neighbors(src_slot);
        if direct.is_empty() {
            return Ok(vec![]);
        }
        let (_, hash) = self.accumulate(direct)?;

        // Probe: union all accumulated out-neighbor lists. Sorting then
        // deduplicating collapses slots reached through several intermediates.
        let mut result: Vec<u64> = hash.into_values().flatten().collect();
        result.sort_unstable();
        result.dedup();
        Ok(result)
    }

    /// Returns the two-hop neighborhood of `src_slot` in factorized form: one
    /// [`TwoHopGroup`] per distinct direct neighbor that has at least one
    /// out-neighbor, holding that neighbor's out-neighbor list.
    ///
    /// Groups appear in the order their intermediate node first occurs in
    /// `src_slot`'s adjacency list. If an intermediate node occurs several times
    /// (parallel edges), its `fof_slots` hold its out-neighbor list once per
    /// occurrence. `total_count` is the number of slots across all groups'
    /// `fof_slots`; every group's `multiplicity` is 1.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` if a direct neighbor of `src_slot` does
    /// not fit in a `u32`, the key type of the RoaringBitmap semijoin filter.
    pub fn two_hop_factorized(&self, src_slot: u64) -> Result<TwoHopChunk> {
        let direct = self.csr.neighbors(src_slot);
        if direct.is_empty() {
            return Ok(TwoHopChunk {
                groups: vec![],
                total_count: 0,
            });
        }
        let (keys, mut hash) = self.accumulate(direct)?;

        let mut groups = Vec::with_capacity(hash.len());
        let mut total_count = 0u64;
        // Walk keys in adjacency order rather than iterating the HashMap, whose
        // order is unspecified. `remove` moves each list out without copying and
        // yields it only once, so a repeated neighbor still forms a single group.
        for key in keys {
            if let Some(fof_slots) = hash.remove(&key) {
                total_count += fof_slots.len() as u64;
                groups.push(TwoHopGroup {
                    mid_slot: u64::from(key),
                    fof_slots,
                    multiplicity: 1,
                });
            }
        }
        Ok(TwoHopChunk {
            groups,
            total_count,
        })
    }

    /// Filter-and-accumulate phase shared by [`Self::two_hop`] and
    /// [`Self::two_hop_factorized`].
    ///
    /// Returns the direct neighbors converted to `u32` keys (adjacency order,
    /// duplicates kept) and a table mapping each key that has at least one
    /// out-neighbor to its out-neighbor list. A key that occurs several times in
    /// `direct` has its list appended once per occurrence.
    fn accumulate(&self, direct: &[u64]) -> Result<(Vec<u32>, HashMap<u32, Vec<u64>>)> {
        // Convert every slot before doing any work, so an out-of-range slot
        // fails the whole join instead of being truncated by an `as` cast.
        let keys = direct
            .iter()
            .map(|&mid| Self::filter_key(mid))
            .collect::<Result<Vec<u32>>>()?;

        let mut filter = RoaringBitmap::new();
        for &key in &keys {
            filter.insert(key);
        }

        let mut hash: HashMap<u32, Vec<u64>> = HashMap::new();
        for (&mid, &key) in direct.iter().zip(&keys) {
            // Every key was inserted above, so with the current filter this
            // check never rejects; it is kept as the semijoin membership test.
            if !filter.contains(key) {
                continue;
            }
            let fof_list = self.csr.neighbors(mid);
            if !fof_list.is_empty() {
                hash.entry(key).or_default().extend_from_slice(fof_list);
            }
        }
        Ok((keys, hash))
    }

    /// Converts a node slot into the `u32` key space of the RoaringBitmap
    /// semijoin filter.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArgument` if `slot` exceeds `u32::MAX`.
    fn filter_key(slot: u64) -> Result<u32> {
        u32::try_from(slot).map_err(|_| {
            sparrowdb_common::Error::InvalidArgument(format!(
                "node slot {slot} exceeds u32::MAX; cannot use RoaringBitmap semijoin filter"
            ))
        })
    }
}
/// Factorized two-hop result, as produced by `AspJoin::two_hop_factorized`:
/// the reachable slots grouped by the intermediate node they were reached
/// through.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TwoHopChunk {
    /// One group per distinct intermediate node that has out-neighbors.
    pub groups: Vec<TwoHopGroup>,
    /// Number of slots across all groups' `fof_slots`, as computed by
    /// `AspJoin::two_hop_factorized`.
    pub total_count: u64,
}

impl TwoHopChunk {
    /// Returns the stored `total_count`; it is not recomputed from `groups`.
    pub fn logical_row_count(&self) -> u64 {
        self.total_count
    }
}

/// The out-neighbors reached through one intermediate node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TwoHopGroup {
    /// Slot of the intermediate (directly adjacent) node.
    pub mid_slot: u64,
    /// Out-neighbor slots of `mid_slot`.
    pub fof_slots: Vec<u64>,
    /// Set to 1 by `AspJoin::two_hop_factorized`; `logical_row_count` does not
    /// read it.
    pub multiplicity: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small directed graph: 0→1, 0→2, 1→3, 2→3, 2→4.
    fn social_graph() -> CsrForward {
        let edges = vec![(0, 1), (0, 2), (1, 3), (2, 3), (2, 4)];
        CsrForward::build(5, &edges)
    }

    #[test]
    fn two_hop_alice_fof() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let fof = join.two_hop(0).unwrap();
        // 3 is reachable through both 1 and 2 but is reported once.
        assert_eq!(fof, vec![3, 4]);
    }

    #[test]
    fn two_hop_no_friends() {
        let csr = CsrForward::build(3, &[(1, 2)]);
        let join = AspJoin::new(&csr);
        let fof = join.two_hop(0).unwrap();
        assert!(fof.is_empty());
    }

    #[test]
    fn two_hop_includes_source_when_reachable() {
        // 0→1→0: the source itself is a two-hop destination and is not excluded.
        let csr = CsrForward::build(2, &[(0, 1), (1, 0)]);
        let join = AspJoin::new(&csr);
        assert_eq!(join.two_hop(0).unwrap(), vec![0]);
    }

    #[test]
    fn two_hop_factorized_count() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        assert_eq!(chunk.logical_row_count(), 3);
    }

    #[test]
    fn two_hop_factorized_groups() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        // Compare an order-independent view of the groups.
        let mut groups: Vec<(u64, Vec<u64>, u64)> = chunk
            .groups
            .iter()
            .map(|g| {
                let mut fof = g.fof_slots.clone();
                fof.sort_unstable();
                (g.mid_slot, fof, g.multiplicity)
            })
            .collect();
        groups.sort_unstable();
        assert_eq!(groups, vec![(1, vec![3], 1), (2, vec![3, 4], 1)]);
    }

    #[test]
    fn two_hop_factorized_no_friends() {
        let csr = CsrForward::build(3, &[(1, 2)]);
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        assert!(chunk.groups.is_empty());
        assert_eq!(chunk.logical_row_count(), 0);
    }
}