// sparrowdb_execution/join.rs
use std::collections::HashMap;

use roaring::RoaringBitmap;
use sparrowdb_common::Result;
use sparrowdb_storage::csr::CsrForward;

19pub struct AspJoin<'a> {
21 csr: &'a CsrForward,
22}
23
24impl<'a> AspJoin<'a> {
25 pub fn new(csr: &'a CsrForward) -> Self {
26 AspJoin { csr }
27 }
28
29 pub fn two_hop(&self, src_slot: u64) -> Result<Vec<u64>> {
35 let direct = self.csr.neighbors(src_slot);
37
38 if direct.is_empty() {
39 return Ok(vec![]);
40 }
41
42 let mut filter = RoaringBitmap::new();
46 for &mid in direct {
47 let mid32 = u32::try_from(mid).map_err(|_| {
48 sparrowdb_common::Error::InvalidArgument(format!(
49 "node slot {mid} exceeds u32::MAX; cannot use RoaringBitmap semijoin filter"
50 ))
51 })?;
52 filter.insert(mid32);
53 }
54
55 let mut hash: HashMap<u32, Vec<u64>> = HashMap::new();
57 for &mid in direct {
58 if !filter.contains(mid as u32) {
60 continue;
61 }
62 let fof_list = self.csr.neighbors(mid);
63 hash.entry(mid as u32)
64 .or_default()
65 .extend_from_slice(fof_list);
66 }
67
68 let mut fof_set: std::collections::HashSet<u64> = std::collections::HashSet::new();
70 for &mid in direct {
71 if let Some(fof_list) = hash.get(&(mid as u32)) {
72 for &fof in fof_list {
73 fof_set.insert(fof);
74 }
75 }
76 }
77
78 let mut result: Vec<u64> = fof_set.into_iter().collect();
79 result.sort_unstable();
80 Ok(result)
81 }
82
83 pub fn two_hop_factorized(&self, src_slot: u64) -> Result<TwoHopChunk> {
91 let direct = self.csr.neighbors(src_slot);
92 if direct.is_empty() {
93 return Ok(TwoHopChunk {
94 groups: vec![],
95 total_count: 0,
96 });
97 }
98
99 let mut filter = RoaringBitmap::new();
101 for &mid in direct {
102 if mid <= u32::MAX as u64 {
103 filter.insert(mid as u32);
104 }
105 }
106
107 let mut hash: HashMap<u32, Vec<u64>> = HashMap::new();
109 for &mid in direct {
110 if !filter.contains(mid as u32) {
111 continue;
112 }
113 let fof_list = self.csr.neighbors(mid);
114 if !fof_list.is_empty() {
115 hash.entry(mid as u32)
116 .or_default()
117 .extend_from_slice(fof_list);
118 }
119 }
120
121 let mut groups = Vec::new();
123 let mut total_count = 0u64;
124
125 for (&mid, fof_list) in &hash {
126 let count = fof_list.len() as u64;
129 total_count += count;
130 groups.push(TwoHopGroup {
131 mid_slot: mid as u64,
132 fof_slots: fof_list.clone(),
133 multiplicity: 1,
134 });
135 }
136
137 Ok(TwoHopChunk {
138 groups,
139 total_count,
140 })
141 }
142}
143
/// Two-hop result grouped by intermediate node.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TwoHopChunk {
    /// One entry per intermediate node.
    pub groups: Vec<TwoHopGroup>,
    /// Row count recorded by whoever built the chunk; returned as-is by
    /// [`TwoHopChunk::logical_row_count`].
    pub total_count: u64,
}

impl TwoHopChunk {
    /// Returns the stored `total_count`.
    pub fn logical_row_count(&self) -> u64 {
        self.total_count
    }
}

/// The second-hop slots collected for a single intermediate node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TwoHopGroup {
    /// Slot of the intermediate node.
    pub mid_slot: u64,
    /// Slots listed as neighbors of `mid_slot`.
    pub fof_slots: Vec<u64>,
    /// Multiplicity attached to this group.
    pub multiplicity: u64,
}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture graph with edges 0→1, 0→2, 1→3, 2→3, 2→4, built with `5` as
    /// the first `CsrForward::build` argument (presumably the node count).
    fn social_graph() -> CsrForward {
        let edges = vec![(0, 1), (0, 2), (1, 3), (2, 3), (2, 4)];
        CsrForward::build(5, &edges)
    }

    /// Node 0 reaches 3 via both 1 and 2, and 4 via 2; 3 must appear once.
    #[test]
    fn two_hop_alice_fof() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let fof = join.two_hop(0).unwrap();
        assert_eq!(fof, vec![3, 4]);
    }

    /// Node 0 has no outgoing edge in this graph, so the result is empty.
    #[test]
    fn two_hop_no_friends() {
        let csr = CsrForward::build(3, &[(1, 2)]);
        let join = AspJoin::new(&csr);
        let fof = join.two_hop(0).unwrap();
        assert!(fof.is_empty());
    }

    /// Paths 0→1→3, 0→2→3 and 0→2→4 give three logical rows in two groups.
    #[test]
    fn two_hop_factorized_count() {
        let csr = social_graph();
        let join = AspJoin::new(&csr);
        let chunk = join.two_hop_factorized(0).unwrap();
        assert_eq!(chunk.logical_row_count(), 3);

        // Group order is not relied upon here, so compare sorted mid slots.
        let mut mids: Vec<u64> = chunk.groups.iter().map(|g| g.mid_slot).collect();
        mids.sort_unstable();
        assert_eq!(mids, vec![1, 2]);

        // The reported row count must match the slots actually listed.
        let listed: u64 = chunk.groups.iter().map(|g| g.fof_slots.len() as u64).sum();
        assert_eq!(listed, chunk.logical_row_count());
    }
}