// nodedb_cluster/distributed_graph/pattern_match.rs
//! Distributed pattern matching — cross-shard scatter-gather for MATCH.
//!
//! When a MATCH pattern encounters a ghost edge (destination on another shard),
//! the partial binding row is sent to the target shard for continuation.
//! The target shard resumes pattern expansion from the ghost destination node.
//!
//! Protocol:
//! 1. Coordinator broadcasts MATCH query to all shards.
//! 2. Each shard executes the pattern locally on its CSR.
//! 3. When a triple crosses a shard boundary (ghost edge):
//!    - The shard packages the partial binding row + remaining pattern triples.
//!    - Sends a `PatternContinuation` to the target shard.
//! 4. Target shard resumes execution with the partial bindings.
//! 5. Results from all shards are merged by the coordinator.
//! 6. Iterate until no new continuations are pending.

17use std::collections::HashMap;
18
19use serde::{Deserialize, Serialize};
20
21/// A partial MATCH result that needs continuation on another shard.
22///
23/// Contains the current variable bindings and the index of the next
24/// triple to execute in the pattern chain. The target shard resumes
25/// from `next_triple_idx` using the provided bindings.
26#[derive(
27    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
28)]
29pub struct PatternContinuation {
30    /// Target shard that should continue execution.
31    pub target_shard: u16,
32    /// Source shard that generated this continuation.
33    pub source_shard: u16,
34    /// Current variable bindings (node_name → value).
35    pub bindings: HashMap<String, String>,
36    /// Index of the next triple to execute in the chain.
37    pub next_triple_idx: usize,
38    /// The ghost node name that the target shard should start from.
39    pub start_node: String,
40    /// The binding variable name for the start node.
41    pub start_binding: String,
42}
43
44/// Coordinator response from a shard for distributed MATCH.
45#[derive(
46    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
47)]
48pub struct ShardMatchResult {
49    /// Shard that produced these results.
50    pub shard_id: u16,
51    /// Completed binding rows (fully matched patterns).
52    pub completed_rows: Vec<HashMap<String, String>>,
53    /// Partial rows that need continuation on other shards.
54    pub continuations: Vec<PatternContinuation>,
55}
56
57/// Coordinator state for distributed MATCH execution.
58///
59/// Tracks pending continuations and completed results across rounds
60/// of scatter-gather until all continuations are resolved.
61#[derive(Debug)]
62pub struct DistributedMatchCoordinator {
63    /// Completed result rows from all shards.
64    pub completed: Vec<HashMap<String, String>>,
65    /// Pending continuations grouped by target shard.
66    pub pending: HashMap<u16, Vec<PatternContinuation>>,
67    /// Round counter (for debugging / max-round termination).
68    pub round: u32,
69    /// Maximum rounds before forced termination (prevent infinite loops).
70    pub max_rounds: u32,
71}
72
73impl DistributedMatchCoordinator {
74    pub fn new(max_rounds: u32) -> Self {
75        Self {
76            completed: Vec::new(),
77            pending: HashMap::new(),
78            round: 0,
79            max_rounds,
80        }
81    }
82
83    /// Ingest results from a shard.
84    pub fn add_shard_result(&mut self, result: ShardMatchResult) {
85        self.completed.extend(result.completed_rows);
86        for cont in result.continuations {
87            self.pending
88                .entry(cont.target_shard)
89                .or_default()
90                .push(cont);
91        }
92    }
93
94    /// Check if there are pending continuations to dispatch.
95    pub fn has_pending(&self) -> bool {
96        !self.pending.is_empty()
97    }
98
99    /// Take all pending continuations for a target shard.
100    pub fn take_pending(&mut self, shard_id: u16) -> Vec<PatternContinuation> {
101        self.pending.remove(&shard_id).unwrap_or_default()
102    }
103
104    /// Take all pending continuations, grouped by target shard.
105    pub fn take_all_pending(&mut self) -> HashMap<u16, Vec<PatternContinuation>> {
106        std::mem::take(&mut self.pending)
107    }
108
109    /// Advance to next round. Returns `false` if max rounds reached.
110    pub fn advance(&mut self) -> bool {
111        self.round += 1;
112        self.round < self.max_rounds
113    }
114
115    /// Total completed rows.
116    pub fn result_count(&self) -> usize {
117        self.completed.len()
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip a continuation through MessagePack and spot-check fields.
    #[test]
    fn pattern_continuation_serde() {
        let cont = PatternContinuation {
            target_shard: 2,
            source_shard: 0,
            bindings: [("a".into(), "alice".into())].into_iter().collect(),
            next_triple_idx: 1,
            start_node: "bob".into(),
            start_binding: "b".into(),
        };
        let bytes = zerompk::to_msgpack_vec(&cont).unwrap();
        let decoded: PatternContinuation = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.target_shard, 2);
        assert_eq!(decoded.start_node, "bob");
        assert_eq!(decoded.bindings["a"], "alice");
    }

    /// Round-trip a shard result with one completed row and no continuations.
    #[test]
    fn shard_match_result_serde() {
        let result = ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![
                [("a".into(), "alice".into()), ("b".into(), "bob".into())]
                    .into_iter()
                    .collect(),
            ],
            continuations: vec![],
        };
        let bytes = zerompk::to_msgpack_vec(&result).unwrap();
        let decoded: ShardMatchResult = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(decoded.completed_rows.len(), 1);
    }

    /// Completed rows accumulate; continuations queue under the target shard
    /// and are drained by `take_pending`.
    #[test]
    fn coordinator_collects_results() {
        let mut coord = DistributedMatchCoordinator::new(10);

        coord.add_shard_result(ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![[("a".into(), "alice".into())].into_iter().collect()],
            continuations: vec![PatternContinuation {
                target_shard: 1,
                source_shard: 0,
                bindings: [("a".into(), "alice".into())].into_iter().collect(),
                next_triple_idx: 1,
                start_node: "bob".into(),
                start_binding: "b".into(),
            }],
        });

        assert_eq!(coord.result_count(), 1);
        assert!(coord.has_pending());
        assert_eq!(coord.take_pending(1).len(), 1);
        assert!(!coord.has_pending());
    }

    /// Two-round scatter-gather: a continuation from round 1 is resolved
    /// by the target shard in round 2, merging into three total rows.
    #[test]
    fn coordinator_multi_round() {
        let mut coord = DistributedMatchCoordinator::new(5);

        // Round 1: shard 0 produces 2 completed + 1 continuation.
        coord.add_shard_result(ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![
                [("x".into(), "1".into())].into_iter().collect(),
                [("x".into(), "2".into())].into_iter().collect(),
            ],
            continuations: vec![PatternContinuation {
                target_shard: 1,
                source_shard: 0,
                bindings: HashMap::new(),
                next_triple_idx: 0,
                start_node: "n".into(),
                start_binding: "a".into(),
            }],
        });

        assert!(coord.advance()); // Round 1 → 2.

        // Round 2: shard 1 completes the continuation.
        let pending = coord.take_all_pending();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[&1].len(), 1);

        coord.add_shard_result(ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![[("x".into(), "3".into())].into_iter().collect()],
            continuations: vec![],
        });

        assert!(!coord.has_pending());
        assert_eq!(coord.result_count(), 3);
    }

    /// `advance` returns `false` exactly when the round counter hits the cap.
    #[test]
    fn coordinator_max_rounds() {
        let mut coord = DistributedMatchCoordinator::new(2);
        assert!(coord.advance()); // round 1
        assert!(!coord.advance()); // round 2 = max
    }

    /// A fresh coordinator has no pending work and no results.
    #[test]
    fn coordinator_no_pending_initially() {
        let coord = DistributedMatchCoordinator::new(10);
        assert!(!coord.has_pending());
        assert_eq!(coord.result_count(), 0);
    }
}