nodedb_cluster/distributed_graph/
pattern_match.rs1use std::collections::HashMap;
20
21use serde::{Deserialize, Serialize};
22
23#[derive(
29 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
30)]
31pub struct PatternContinuation {
32 pub target_shard: u32,
34 pub source_shard: u32,
36 pub bindings: HashMap<String, String>,
38 pub next_triple_idx: usize,
40 pub start_node: String,
42 pub start_binding: String,
44}
45
46#[derive(
48 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
49)]
50pub struct ShardMatchResult {
51 pub shard_id: u32,
53 pub completed_rows: Vec<HashMap<String, String>>,
55 pub continuations: Vec<PatternContinuation>,
57}
58
59#[derive(Debug)]
64pub struct DistributedMatchCoordinator {
65 pub completed: Vec<HashMap<String, String>>,
67 pub pending: HashMap<u32, Vec<PatternContinuation>>,
69 pub round: u32,
71 pub max_rounds: u32,
73}
74
75impl DistributedMatchCoordinator {
76 pub fn new(max_rounds: u32) -> Self {
77 Self {
78 completed: Vec::new(),
79 pending: HashMap::new(),
80 round: 0,
81 max_rounds,
82 }
83 }
84
85 pub fn add_shard_result(&mut self, result: ShardMatchResult) {
87 self.completed.extend(result.completed_rows);
88 for cont in result.continuations {
89 self.pending
90 .entry(cont.target_shard)
91 .or_default()
92 .push(cont);
93 }
94 }
95
96 pub fn has_pending(&self) -> bool {
98 !self.pending.is_empty()
99 }
100
101 pub fn take_pending(&mut self, shard_id: u32) -> Vec<PatternContinuation> {
103 self.pending.remove(&shard_id).unwrap_or_default()
104 }
105
106 pub fn take_all_pending(&mut self) -> HashMap<u32, Vec<PatternContinuation>> {
108 std::mem::take(&mut self.pending)
109 }
110
111 pub fn advance(&mut self) -> bool {
113 self.round += 1;
114 self.round < self.max_rounds
115 }
116
117 pub fn result_count(&self) -> usize {
119 self.completed.len()
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a binding row from `(key, value)` string pairs.
    fn row(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// A continuation survives a MessagePack round trip.
    #[test]
    fn pattern_continuation_serde() {
        let cont = PatternContinuation {
            target_shard: 2,
            source_shard: 0,
            bindings: row(&[("a", "alice")]),
            next_triple_idx: 1,
            start_node: "bob".into(),
            start_binding: "b".into(),
        };
        let bytes = zerompk::to_msgpack_vec(&cont).unwrap();
        let decoded: PatternContinuation = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.target_shard, 2);
        assert_eq!(decoded.start_node, "bob");
        assert_eq!(decoded.bindings["a"], "alice");
    }

    /// A shard result survives a MessagePack round trip.
    #[test]
    fn shard_match_result_serde() {
        let result = ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![row(&[("a", "alice"), ("b", "bob")])],
            continuations: vec![],
        };
        let bytes = zerompk::to_msgpack_vec(&result).unwrap();
        let decoded: ShardMatchResult = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.completed_rows.len(), 1);
    }

    /// Completed rows are counted and continuations are queued per target shard.
    #[test]
    fn coordinator_collects_results() {
        let mut coord = DistributedMatchCoordinator::new(10);

        coord.add_shard_result(ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![row(&[("a", "alice")])],
            continuations: vec![PatternContinuation {
                target_shard: 1,
                source_shard: 0,
                bindings: row(&[("a", "alice")]),
                next_triple_idx: 1,
                start_node: "bob".into(),
                start_binding: "b".into(),
            }],
        });

        assert_eq!(coord.result_count(), 1);
        assert!(coord.has_pending());
        assert_eq!(coord.take_pending(1).len(), 1);
        assert!(!coord.has_pending());
    }

    /// Results from successive rounds accumulate in the same coordinator.
    #[test]
    fn coordinator_multi_round() {
        let mut coord = DistributedMatchCoordinator::new(5);

        // First result: two finished rows plus one continuation for shard 1.
        coord.add_shard_result(ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![row(&[("x", "1")]), row(&[("x", "2")])],
            continuations: vec![PatternContinuation {
                target_shard: 1,
                source_shard: 0,
                bindings: HashMap::new(),
                next_triple_idx: 0,
                start_node: "n".into(),
                start_binding: "a".into(),
            }],
        });

        // Still under the round limit, so the queued work can be drained.
        assert!(coord.advance());
        let pending = coord.take_all_pending();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[&1].len(), 1);

        // Second result: shard 1 finishes without producing new continuations.
        coord.add_shard_result(ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![row(&[("x", "3")])],
            continuations: vec![],
        });

        assert!(!coord.has_pending());
        assert_eq!(coord.result_count(), 3);
    }

    /// `advance` stops returning `true` once the round limit is reached.
    #[test]
    fn coordinator_max_rounds() {
        let mut coord = DistributedMatchCoordinator::new(2);
        // round becomes 1, still below the limit of 2.
        assert!(coord.advance());
        // round becomes 2, which hits the limit.
        assert!(!coord.advance());
    }

    /// A fresh coordinator has neither pending work nor results.
    #[test]
    fn coordinator_no_pending_initially() {
        let coord = DistributedMatchCoordinator::new(10);
        assert!(!coord.has_pending());
        assert_eq!(coord.result_count(), 0);
    }
}