nodedb_cluster/distributed_graph/
pattern_match.rs1use std::collections::HashMap;
18
19use serde::{Deserialize, Serialize};
20
21#[derive(
27 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
28)]
29pub struct PatternContinuation {
30 pub target_shard: u16,
32 pub source_shard: u16,
34 pub bindings: HashMap<String, String>,
36 pub next_triple_idx: usize,
38 pub start_node: String,
40 pub start_binding: String,
42}
43
44#[derive(
46 Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
47)]
48pub struct ShardMatchResult {
49 pub shard_id: u16,
51 pub completed_rows: Vec<HashMap<String, String>>,
53 pub continuations: Vec<PatternContinuation>,
55}
56
57#[derive(Debug)]
62pub struct DistributedMatchCoordinator {
63 pub completed: Vec<HashMap<String, String>>,
65 pub pending: HashMap<u16, Vec<PatternContinuation>>,
67 pub round: u32,
69 pub max_rounds: u32,
71}
72
73impl DistributedMatchCoordinator {
74 pub fn new(max_rounds: u32) -> Self {
75 Self {
76 completed: Vec::new(),
77 pending: HashMap::new(),
78 round: 0,
79 max_rounds,
80 }
81 }
82
83 pub fn add_shard_result(&mut self, result: ShardMatchResult) {
85 self.completed.extend(result.completed_rows);
86 for cont in result.continuations {
87 self.pending
88 .entry(cont.target_shard)
89 .or_default()
90 .push(cont);
91 }
92 }
93
94 pub fn has_pending(&self) -> bool {
96 !self.pending.is_empty()
97 }
98
99 pub fn take_pending(&mut self, shard_id: u16) -> Vec<PatternContinuation> {
101 self.pending.remove(&shard_id).unwrap_or_default()
102 }
103
104 pub fn take_all_pending(&mut self) -> HashMap<u16, Vec<PatternContinuation>> {
106 std::mem::take(&mut self.pending)
107 }
108
109 pub fn advance(&mut self) -> bool {
111 self.round += 1;
112 self.round < self.max_rounds
113 }
114
115 pub fn result_count(&self) -> usize {
117 self.completed.len()
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one binding row from `(name, value)` pairs.
    fn row(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    /// Builds a continuation originating from shard 0.
    fn continuation_from_shard_zero(
        target_shard: u16,
        bindings: HashMap<String, String>,
        next_triple_idx: usize,
        start_node: &str,
        start_binding: &str,
    ) -> PatternContinuation {
        PatternContinuation {
            target_shard,
            source_shard: 0,
            bindings,
            next_triple_idx,
            start_node: start_node.to_string(),
            start_binding: start_binding.to_string(),
        }
    }

    /// A continuation survives a MessagePack encode/decode round trip.
    #[test]
    fn pattern_continuation_serde() {
        let original =
            continuation_from_shard_zero(2, row(&[("a", "alice")]), 1, "bob", "b");

        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let roundtripped: PatternContinuation = zerompk::from_msgpack(&encoded).unwrap();

        assert_eq!(roundtripped.target_shard, 2);
        assert_eq!(roundtripped.start_node, "bob");
        assert_eq!(
            roundtripped.bindings.get("a").map(String::as_str),
            Some("alice")
        );
    }

    /// A shard result survives a MessagePack encode/decode round trip.
    #[test]
    fn shard_match_result_serde() {
        let original = ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![row(&[("a", "alice"), ("b", "bob")])],
            continuations: Vec::new(),
        };

        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let roundtripped: ShardMatchResult = zerompk::from_msgpack(&encoded).unwrap();

        assert_eq!(roundtripped.completed_rows.len(), 1);
    }

    /// Completed rows are counted and continuations are queued per target.
    #[test]
    fn coordinator_collects_results() {
        let mut coordinator = DistributedMatchCoordinator::new(10);

        let from_shard_zero = ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![row(&[("a", "alice")])],
            continuations: vec![continuation_from_shard_zero(
                1,
                row(&[("a", "alice")]),
                1,
                "bob",
                "b",
            )],
        };
        coordinator.add_shard_result(from_shard_zero);

        assert_eq!(coordinator.result_count(), 1);
        assert!(coordinator.has_pending());

        let for_shard_one = coordinator.take_pending(1);
        assert_eq!(for_shard_one.len(), 1);
        assert!(!coordinator.has_pending());
    }

    /// Results accumulate across rounds once pending work has been drained.
    #[test]
    fn coordinator_multi_round() {
        let mut coordinator = DistributedMatchCoordinator::new(5);

        // First round: shard 0 reports two rows and one continuation.
        let first_round = ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![row(&[("x", "1")]), row(&[("x", "2")])],
            continuations: vec![continuation_from_shard_zero(
                1,
                HashMap::new(),
                0,
                "n",
                "a",
            )],
        };
        coordinator.add_shard_result(first_round);

        assert!(coordinator.advance());
        let drained = coordinator.take_all_pending();
        assert_eq!(drained.len(), 1);
        assert_eq!(drained.get(&1).map(Vec::len), Some(1));

        // Second round: shard 1 finishes with one more row and no follow-ups.
        let second_round = ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![row(&[("x", "3")])],
            continuations: Vec::new(),
        };
        coordinator.add_shard_result(second_round);

        assert!(!coordinator.has_pending());
        assert_eq!(coordinator.result_count(), 3);
    }

    /// With a budget of two, the first advance succeeds and the second fails.
    #[test]
    fn coordinator_max_rounds() {
        let mut coordinator = DistributedMatchCoordinator::new(2);

        let first = coordinator.advance();
        let second = coordinator.advance();

        assert!(first);
        assert!(!second);
    }

    /// A fresh coordinator has neither pending work nor results.
    #[test]
    fn coordinator_no_pending_initially() {
        let coordinator = DistributedMatchCoordinator::new(10);

        assert_eq!(coordinator.result_count(), 0);
        assert!(!coordinator.has_pending());
    }
}