// Source: nodedb_cluster/distributed_graph/pattern_match.rs
// SPDX-License-Identifier: BUSL-1.1

//! Distributed pattern matching — cross-shard scatter-gather for MATCH.
//!
//! When a MATCH pattern encounters a ghost edge (destination on another shard),
//! the partial binding row is sent to the target shard for continuation.
//! The target shard resumes pattern expansion from the ghost destination node.
//!
//! Protocol:
//! 1. Coordinator broadcasts MATCH query to all shards.
//! 2. Each shard executes the pattern locally on its CSR.
//! 3. When a triple crosses a shard boundary (ghost edge):
//!    - The shard packages the partial binding row + remaining pattern triples.
//!    - Sends a `PatternContinuation` to the target shard.
//! 4. Target shard resumes execution with the partial bindings.
//! 5. Results from all shards are merged by the coordinator.
//! 6. Iterate until no new continuations are pending.
19use std::collections::HashMap;
20
21use serde::{Deserialize, Serialize};
22
23/// A partial MATCH result that needs continuation on another shard.
24///
25/// Contains the current variable bindings and the index of the next
26/// triple to execute in the pattern chain. The target shard resumes
27/// from `next_triple_idx` using the provided bindings.
28#[derive(
29    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
30)]
31pub struct PatternContinuation {
32    /// Target shard that should continue execution.
33    pub target_shard: u32,
34    /// Source shard that generated this continuation.
35    pub source_shard: u32,
36    /// Current variable bindings (node_name → value).
37    pub bindings: HashMap<String, String>,
38    /// Index of the next triple to execute in the chain.
39    pub next_triple_idx: usize,
40    /// The ghost node name that the target shard should start from.
41    pub start_node: String,
42    /// The binding variable name for the start node.
43    pub start_binding: String,
44}
45
46/// Coordinator response from a shard for distributed MATCH.
47#[derive(
48    Debug, Clone, Serialize, Deserialize, zerompk::ToMessagePack, zerompk::FromMessagePack,
49)]
50pub struct ShardMatchResult {
51    /// Shard that produced these results.
52    pub shard_id: u32,
53    /// Completed binding rows (fully matched patterns).
54    pub completed_rows: Vec<HashMap<String, String>>,
55    /// Partial rows that need continuation on other shards.
56    pub continuations: Vec<PatternContinuation>,
57}
58
59/// Coordinator state for distributed MATCH execution.
60///
61/// Tracks pending continuations and completed results across rounds
62/// of scatter-gather until all continuations are resolved.
63#[derive(Debug)]
64pub struct DistributedMatchCoordinator {
65    /// Completed result rows from all shards.
66    pub completed: Vec<HashMap<String, String>>,
67    /// Pending continuations grouped by target shard.
68    pub pending: HashMap<u32, Vec<PatternContinuation>>,
69    /// Round counter (for debugging / max-round termination).
70    pub round: u32,
71    /// Maximum rounds before forced termination (prevent infinite loops).
72    pub max_rounds: u32,
73}
74
75impl DistributedMatchCoordinator {
76    pub fn new(max_rounds: u32) -> Self {
77        Self {
78            completed: Vec::new(),
79            pending: HashMap::new(),
80            round: 0,
81            max_rounds,
82        }
83    }
84
85    /// Ingest results from a shard.
86    pub fn add_shard_result(&mut self, result: ShardMatchResult) {
87        self.completed.extend(result.completed_rows);
88        for cont in result.continuations {
89            self.pending
90                .entry(cont.target_shard)
91                .or_default()
92                .push(cont);
93        }
94    }
95
96    /// Check if there are pending continuations to dispatch.
97    pub fn has_pending(&self) -> bool {
98        !self.pending.is_empty()
99    }
100
101    /// Take all pending continuations for a target shard.
102    pub fn take_pending(&mut self, shard_id: u32) -> Vec<PatternContinuation> {
103        self.pending.remove(&shard_id).unwrap_or_default()
104    }
105
106    /// Take all pending continuations, grouped by target shard.
107    pub fn take_all_pending(&mut self) -> HashMap<u32, Vec<PatternContinuation>> {
108        std::mem::take(&mut self.pending)
109    }
110
111    /// Advance to next round. Returns `false` if max rounds reached.
112    pub fn advance(&mut self) -> bool {
113        self.round += 1;
114        self.round < self.max_rounds
115    }
116
117    /// Total completed rows.
118    pub fn result_count(&self) -> usize {
119        self.completed.len()
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a bindings map holding a single key/value pair.
    fn one_binding(key: &str, value: &str) -> HashMap<String, String> {
        HashMap::from([(key.to_string(), value.to_string())])
    }

    #[test]
    fn pattern_continuation_serde() {
        let original = PatternContinuation {
            target_shard: 2,
            source_shard: 0,
            bindings: one_binding("a", "alice"),
            next_triple_idx: 1,
            start_node: "bob".to_string(),
            start_binding: "b".to_string(),
        };
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let roundtripped: PatternContinuation = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(roundtripped.target_shard, 2);
        assert_eq!(roundtripped.start_node, "bob");
        assert_eq!(roundtripped.bindings["a"], "alice");
    }

    #[test]
    fn shard_match_result_serde() {
        let row = HashMap::from([
            ("a".to_string(), "alice".to_string()),
            ("b".to_string(), "bob".to_string()),
        ]);
        let original = ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![row],
            continuations: Vec::new(),
        };
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let roundtripped: ShardMatchResult = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(roundtripped.completed_rows.len(), 1);
    }

    #[test]
    fn coordinator_collects_results() {
        let mut coordinator = DistributedMatchCoordinator::new(10);

        coordinator.add_shard_result(ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![one_binding("a", "alice")],
            continuations: vec![PatternContinuation {
                target_shard: 1,
                source_shard: 0,
                bindings: one_binding("a", "alice"),
                next_triple_idx: 1,
                start_node: "bob".to_string(),
                start_binding: "b".to_string(),
            }],
        });

        assert_eq!(coordinator.result_count(), 1);
        assert!(coordinator.has_pending());
        assert_eq!(coordinator.take_pending(1).len(), 1);
        assert!(!coordinator.has_pending());
    }

    #[test]
    fn coordinator_multi_round() {
        let mut coordinator = DistributedMatchCoordinator::new(5);

        // Round 1: shard 0 reports two finished rows and one cross-shard hop.
        coordinator.add_shard_result(ShardMatchResult {
            shard_id: 0,
            completed_rows: vec![one_binding("x", "1"), one_binding("x", "2")],
            continuations: vec![PatternContinuation {
                target_shard: 1,
                source_shard: 0,
                bindings: HashMap::new(),
                next_triple_idx: 0,
                start_node: "n".to_string(),
                start_binding: "a".to_string(),
            }],
        });

        assert!(coordinator.advance()); // Round 1 → 2.

        // Round 2: shard 1 resolves the outstanding continuation.
        let dispatched = coordinator.take_all_pending();
        assert_eq!(dispatched.len(), 1);
        assert_eq!(dispatched[&1].len(), 1);

        coordinator.add_shard_result(ShardMatchResult {
            shard_id: 1,
            completed_rows: vec![one_binding("x", "3")],
            continuations: Vec::new(),
        });

        assert!(!coordinator.has_pending());
        assert_eq!(coordinator.result_count(), 3);
    }

    #[test]
    fn coordinator_max_rounds() {
        let mut coordinator = DistributedMatchCoordinator::new(2);
        assert!(coordinator.advance()); // Round 1: still below the cap.
        assert!(!coordinator.advance()); // Round 2: cap reached.
    }

    #[test]
    fn coordinator_no_pending_initially() {
        let fresh = DistributedMatchCoordinator::new(10);
        assert!(!fresh.has_pending());
        assert_eq!(fresh.result_count(), 0);
    }
}