Skip to main content

sparrowdb_execution/
join.rs

1//! Binary ASP-Join for 2-hop traversals.
2//!
3//! Implements the factorized join contract from spec Section 13.3.
4//!
5//! Algorithm:
6//! 1. Collect probe-side (src node's direct neighbors).
7//! 2. Build a Roaring semijoin filter from the probe keys.
8//! 3. For build-side, scan mid-node neighbors — only those admitted by filter.
9//! 4. Build hash state: {mid_node_slot → Vec<dst_slot>}.
10//! 5. Re-probe: for each src neighbor, look up in hash to get fof set.
11//! 6. Propagate multiplicity without materializing the full Cartesian product.
12
13use std::collections::HashMap;
14
15use roaring::RoaringBitmap;
16use sparrowdb_common::Result;
17use sparrowdb_storage::csr::CsrForward;
18
19/// Binary ASP-Join: 2-hop traversal over a CSR graph.
20pub struct AspJoin<'a> {
21    csr: &'a CsrForward,
22}
23
24impl<'a> AspJoin<'a> {
25    pub fn new(csr: &'a CsrForward) -> Self {
26        AspJoin { csr }
27    }
28
29    /// Compute 2-hop friends-of-friends for `src_slot`.
30    ///
31    /// Returns the deduplicated set of fof node slots.
32    /// Does NOT exclude direct friends of `src_slot` — that filtering is
33    /// handled separately in the WHERE NOT clause at the planner level.
34    pub fn two_hop(&self, src_slot: u64) -> Result<Vec<u64>> {
35        // Step 1: probe side — direct neighbors of src.
36        let direct = self.csr.neighbors(src_slot);
37
38        if direct.is_empty() {
39            return Ok(vec![]);
40        }
41
42        // Step 2: build semijoin filter from direct neighbors.
43        // RoaringBitmap only supports u32 keys — return an error rather than
44        // silently dropping nodes whose slot exceeds u32::MAX.
45        let mut filter = RoaringBitmap::new();
46        for &mid in direct {
47            let mid32 = u32::try_from(mid).map_err(|_| {
48                sparrowdb_common::Error::InvalidArgument(format!(
49                    "node slot {mid} exceeds u32::MAX; cannot use RoaringBitmap semijoin filter"
50                ))
51            })?;
52            filter.insert(mid32);
53        }
54
55        // Step 3 & 4: for each mid node admitted by the filter, collect fof.
56        let mut hash: HashMap<u32, Vec<u64>> = HashMap::new();
57        for &mid in direct {
58            // Safety: all mids were validated as u32 in the filter step above.
59            if !filter.contains(mid as u32) {
60                continue;
61            }
62            let fof_list = self.csr.neighbors(mid);
63            hash.entry(mid as u32)
64                .or_default()
65                .extend_from_slice(fof_list);
66        }
67
68        // Step 5: re-probe — collect all fof nodes, deduplicate.
69        let mut fof_set: std::collections::HashSet<u64> = std::collections::HashSet::new();
70        for &mid in direct {
71            if let Some(fof_list) = hash.get(&(mid as u32)) {
72                for &fof in fof_list {
73                    fof_set.insert(fof);
74                }
75            }
76        }
77
78        let mut result: Vec<u64> = fof_set.into_iter().collect();
79        result.sort_unstable();
80        Ok(result)
81    }
82
83    /// Compute 2-hop in factorized form: preserves multiplicity without
84    /// materializing a flat list.
85    ///
86    /// Returns a FactorizedChunk where:
87    /// - Each VectorGroup represents one mid-node with its fof set.
88    /// - Multiplicity is preserved per group.
89    /// - `logical_row_count()` returns the total count of (mid, fof) pairs.
90    pub fn two_hop_factorized(&self, src_slot: u64) -> Result<TwoHopChunk> {
91        let direct = self.csr.neighbors(src_slot);
92        if direct.is_empty() {
93            return Ok(TwoHopChunk {
94                groups: vec![],
95                total_count: 0,
96            });
97        }
98
99        // Build semijoin filter.
100        let mut filter = RoaringBitmap::new();
101        for &mid in direct {
102            if mid <= u32::MAX as u64 {
103                filter.insert(mid as u32);
104            }
105        }
106
107        // Build hash state: mid → [fof slots].
108        let mut hash: HashMap<u32, Vec<u64>> = HashMap::new();
109        for &mid in direct {
110            if !filter.contains(mid as u32) {
111                continue;
112            }
113            let fof_list = self.csr.neighbors(mid);
114            if !fof_list.is_empty() {
115                hash.entry(mid as u32)
116                    .or_default()
117                    .extend_from_slice(fof_list);
118            }
119        }
120
121        // Re-probe in factorized form: one group per (src, mid) pair.
122        let mut groups = Vec::new();
123        let mut total_count = 0u64;
124
125        for (&mid, fof_list) in &hash {
126            // Each VectorGroup represents one mid-node with all its fof neighbors.
127            // Multiplicity = 1 because each mid produces exactly these fof nodes.
128            let count = fof_list.len() as u64;
129            total_count += count;
130            groups.push(TwoHopGroup {
131                mid_slot: mid as u64,
132                fof_slots: fof_list.clone(),
133                multiplicity: 1,
134            });
135        }
136
137        Ok(TwoHopChunk {
138            groups,
139            total_count,
140        })
141    }
142}
143
/// A factorized 2-hop chunk: each group is one `(mid, [fof...])` set.
///
/// Rows stay grouped by mid-node instead of being flattened into individual
/// `(mid, fof)` pairs.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TwoHopChunk {
    /// The groups making up this chunk, one `(mid, [fof...])` set each.
    pub groups: Vec<TwoHopGroup>,
    /// Precomputed logical row count returned by
    /// [`TwoHopChunk::logical_row_count`].
    pub total_count: u64,
}

impl TwoHopChunk {
    /// Total logical row count: the number of `(mid, fof)` rows this chunk
    /// represents (fof set sizes summed over groups, weighted by each group's
    /// multiplicity).
    ///
    /// Returns the stored `total_count` as-is; it is not recomputed from
    /// `groups`.
    #[must_use]
    pub fn logical_row_count(&self) -> u64 {
        self.total_count
    }
}

/// One group in a factorized 2-hop chunk: a mid-node and the fof slots
/// reached through it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TwoHopGroup {
    /// Slot of the intermediate (first-hop) node.
    pub mid_slot: u64,
    /// Slots reached in one further hop from `mid_slot`.
    pub fof_slots: Vec<u64>,
    /// Number of times this group's `(mid, fof)` rows are counted logically.
    pub multiplicity: u64,
}
163
#[cfg(test)]
mod tests {
    use super::*;

    fn social_graph() -> CsrForward {
        // Alice=0, Bob=1, Carol=2, Dave=3, Eve=4
        // Alice->Bob, Alice->Carol, Bob->Dave, Carol->Dave, Carol->Eve
        let edges = vec![(0, 1), (0, 2), (1, 3), (2, 3), (2, 4)];
        CsrForward::build(5, &edges)
    }

    /// Flattens a chunk into sorted `(mid, sorted fofs, multiplicity)` triples
    /// so assertions do not depend on group or neighbor order.
    fn normalized(chunk: &TwoHopChunk) -> Vec<(u64, Vec<u64>, u64)> {
        let mut out: Vec<(u64, Vec<u64>, u64)> = chunk
            .groups
            .iter()
            .map(|g| {
                let mut fofs = g.fof_slots.clone();
                fofs.sort_unstable();
                (g.mid_slot, fofs, g.multiplicity)
            })
            .collect();
        out.sort_unstable();
        out
    }

    #[test]
    fn two_hop_alice_fof() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let fof = join.two_hop(0).unwrap();
        assert_eq!(fof, vec![3, 4]); // Dave and Eve
    }

    #[test]
    fn two_hop_no_friends() {
        let csr = CsrForward::build(3, &[(1, 2)]);
        let join = AspJoin::new(&csr);
        // Node 0 has no friends.
        let fof = join.two_hop(0).unwrap();
        assert!(fof.is_empty());
    }

    #[test]
    fn two_hop_includes_src_on_cycle() {
        // 0 -> 1 -> 0: the only 2-hop path from 0 leads back to 0 itself,
        // and two_hop does not exclude the source.
        let csr = CsrForward::build(2, &[(0, 1), (1, 0)]);
        let join = AspJoin::new(&csr);
        assert_eq!(join.two_hop(0).unwrap(), vec![0]);
    }

    #[test]
    fn two_hop_factorized_count() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        // Bob->Dave (1), Carol->Dave (1), Carol->Eve (1) = 3 logical rows
        assert_eq!(chunk.logical_row_count(), 3);
    }

    #[test]
    fn two_hop_factorized_groups() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        // Bob -> {Dave}, Carol -> {Dave, Eve}; each mid reached once.
        assert_eq!(
            normalized(&chunk),
            vec![(1, vec![3], 1), (2, vec![3, 4], 1)]
        );
    }

    #[test]
    fn two_hop_factorized_no_friends() {
        let csr = CsrForward::build(3, &[(1, 2)]);
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        assert!(chunk.groups.is_empty());
        assert_eq!(chunk.logical_row_count(), 0);
    }

    #[test]
    fn two_hop_factorized_skips_leaf_mids() {
        // 0 -> 1 (1 has no out-edges) and 0 -> 2 -> 3: only mid 2 forms a group.
        let csr = CsrForward::build(4, &[(0, 1), (0, 2), (2, 3)]);
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        assert_eq!(normalized(&chunk), vec![(2, vec![3], 1)]);
        assert_eq!(chunk.logical_row_count(), 1);
    }
}