Skip to main content

grafeo_core/execution/operators/
expand.rs

1//! Expand operator for relationship traversal.
2
3use super::{Operator, OperatorError, OperatorResult};
4use crate::execution::DataChunk;
5use crate::graph::Direction;
6use crate::graph::GraphStore;
7use grafeo_common::types::{EdgeId, EpochId, LogicalType, NodeId, TxId};
8use std::sync::Arc;
9
10/// An expand operator that traverses edges from source nodes.
11///
12/// For each input row containing a source node, this operator produces
13/// output rows for each neighbor connected via matching edges.
14pub struct ExpandOperator {
15    /// The store to traverse.
16    store: Arc<dyn GraphStore>,
17    /// Input operator providing source nodes.
18    input: Box<dyn Operator>,
19    /// Index of the source node column in input.
20    source_column: usize,
21    /// Direction of edge traversal.
22    direction: Direction,
23    /// Edge type filter (empty = match all types, multiple = match any).
24    edge_types: Vec<String>,
25    /// Chunk capacity.
26    chunk_capacity: usize,
27    /// Current input chunk being processed.
28    current_input: Option<DataChunk>,
29    /// Current row index in the input chunk.
30    current_row: usize,
31    /// Current edge iterator for the current row.
32    current_edges: Vec<(NodeId, EdgeId)>,
33    /// Current edge index.
34    current_edge_idx: usize,
35    /// Whether the operator is exhausted.
36    exhausted: bool,
37    /// Transaction ID for MVCC visibility (None = use current epoch).
38    tx_id: Option<TxId>,
39    /// Epoch for version visibility.
40    viewing_epoch: Option<EpochId>,
41}
42
43impl ExpandOperator {
44    /// Creates a new expand operator.
45    pub fn new(
46        store: Arc<dyn GraphStore>,
47        input: Box<dyn Operator>,
48        source_column: usize,
49        direction: Direction,
50        edge_types: Vec<String>,
51    ) -> Self {
52        Self {
53            store,
54            input,
55            source_column,
56            direction,
57            edge_types,
58            chunk_capacity: 2048,
59            current_input: None,
60            current_row: 0,
61            current_edges: Vec::with_capacity(16), // typical node degree
62            current_edge_idx: 0,
63            exhausted: false,
64            tx_id: None,
65            viewing_epoch: None,
66        }
67    }
68
69    /// Sets the chunk capacity.
70    pub fn with_chunk_capacity(mut self, capacity: usize) -> Self {
71        self.chunk_capacity = capacity;
72        self
73    }
74
75    /// Sets the transaction context for MVCC visibility.
76    ///
77    /// When set, the expand will only traverse visible edges and nodes.
78    pub fn with_tx_context(mut self, epoch: EpochId, tx_id: Option<TxId>) -> Self {
79        self.viewing_epoch = Some(epoch);
80        self.tx_id = tx_id;
81        self
82    }
83
84    /// Loads the next input chunk.
85    fn load_next_input(&mut self) -> Result<bool, OperatorError> {
86        match self.input.next() {
87            Ok(Some(mut chunk)) => {
88                // Flatten the chunk if it has a selection vector so we can use direct indexing
89                chunk.flatten();
90                self.current_input = Some(chunk);
91                self.current_row = 0;
92                self.current_edges.clear();
93                self.current_edge_idx = 0;
94                Ok(true)
95            }
96            Ok(None) => {
97                self.exhausted = true;
98                Ok(false)
99            }
100            Err(e) => Err(e),
101        }
102    }
103
104    /// Loads edges for the current row.
105    fn load_edges_for_current_row(&mut self) -> Result<bool, OperatorError> {
106        let Some(chunk) = &self.current_input else {
107            return Ok(false);
108        };
109
110        if self.current_row >= chunk.row_count() {
111            return Ok(false);
112        }
113
114        let col = chunk.column(self.source_column).ok_or_else(|| {
115            OperatorError::ColumnNotFound(format!("Column {} not found", self.source_column))
116        })?;
117
118        let source_id = col
119            .get_node_id(self.current_row)
120            .ok_or_else(|| OperatorError::Execution("Expected node ID in source column".into()))?;
121
122        // Get visibility context
123        let epoch = self.viewing_epoch;
124        let tx_id = self.tx_id;
125
126        // Get edges from this node
127        let edges: Vec<(NodeId, EdgeId)> = self
128            .store
129            .edges_from(source_id, self.direction)
130            .into_iter()
131            .filter(|(target_id, edge_id)| {
132                // Filter by edge type if specified
133                let type_matches = if self.edge_types.is_empty() {
134                    true
135                } else if let Some(actual_type) = self.store.edge_type(*edge_id) {
136                    self.edge_types
137                        .iter()
138                        .any(|t| actual_type.as_str().eq_ignore_ascii_case(t.as_str()))
139                } else {
140                    false
141                };
142
143                if !type_matches {
144                    return false;
145                }
146
147                // Filter by visibility if we have epoch context
148                if let Some(epoch) = epoch {
149                    if let Some(tx) = tx_id {
150                        // Transaction-aware visibility
151                        let edge_visible =
152                            self.store.get_edge_versioned(*edge_id, epoch, tx).is_some();
153                        let target_visible = self
154                            .store
155                            .get_node_versioned(*target_id, epoch, tx)
156                            .is_some();
157                        edge_visible && target_visible
158                    } else {
159                        // Pure epoch-based visibility (time-travel)
160                        let edge_visible = self.store.get_edge_at_epoch(*edge_id, epoch).is_some();
161                        let target_visible =
162                            self.store.get_node_at_epoch(*target_id, epoch).is_some();
163                        edge_visible && target_visible
164                    }
165                } else {
166                    true
167                }
168            })
169            .collect();
170
171        self.current_edges = edges;
172        self.current_edge_idx = 0;
173        Ok(true)
174    }
175}
176
177impl Operator for ExpandOperator {
178    fn next(&mut self) -> OperatorResult {
179        if self.exhausted {
180            return Ok(None);
181        }
182
183        // Build output schema: preserve all input columns + edge + target
184        // We need to build this dynamically based on input schema
185        if self.current_input.is_none() {
186            if !self.load_next_input()? {
187                return Ok(None);
188            }
189            self.load_edges_for_current_row()?;
190        }
191        let input_chunk = self.current_input.as_ref().expect("input loaded above");
192
193        // Build schema: [input_columns..., edge, target]
194        let input_col_count = input_chunk.column_count();
195        let mut schema: Vec<LogicalType> = (0..input_col_count)
196            .map(|i| {
197                input_chunk
198                    .column(i)
199                    .map_or(LogicalType::Any, |c| c.data_type().clone())
200            })
201            .collect();
202        schema.push(LogicalType::Edge);
203        schema.push(LogicalType::Node);
204
205        let mut chunk = DataChunk::with_capacity(&schema, self.chunk_capacity);
206        let mut count = 0;
207
208        while count < self.chunk_capacity {
209            // If we need a new input chunk
210            if self.current_input.is_none() {
211                if !self.load_next_input()? {
212                    break;
213                }
214                self.load_edges_for_current_row()?;
215            }
216
217            // If we've exhausted edges for current row, move to next row
218            while self.current_edge_idx >= self.current_edges.len() {
219                self.current_row += 1;
220
221                // If we've exhausted the current input chunk, get next one
222                if self.current_row >= self.current_input.as_ref().map_or(0, |c| c.row_count()) {
223                    self.current_input = None;
224                    if !self.load_next_input()? {
225                        // No more input chunks
226                        if count > 0 {
227                            chunk.set_count(count);
228                            return Ok(Some(chunk));
229                        }
230                        return Ok(None);
231                    }
232                }
233
234                self.load_edges_for_current_row()?;
235            }
236
237            // Get the current edge
238            let (target_id, edge_id) = self.current_edges[self.current_edge_idx];
239
240            // Copy all input columns to output
241            let input = self.current_input.as_ref().expect("input loaded above");
242            for col_idx in 0..input_col_count {
243                if let Some(input_col) = input.column(col_idx)
244                    && let Some(output_col) = chunk.column_mut(col_idx)
245                {
246                    // Use copy_row_to which preserves NodeId/EdgeId types
247                    input_col.copy_row_to(self.current_row, output_col);
248                }
249            }
250
251            // Add edge column
252            if let Some(col) = chunk.column_mut(input_col_count) {
253                col.push_edge_id(edge_id);
254            }
255
256            // Add target node column
257            if let Some(col) = chunk.column_mut(input_col_count + 1) {
258                col.push_node_id(target_id);
259            }
260
261            count += 1;
262            self.current_edge_idx += 1;
263        }
264
265        if count > 0 {
266            chunk.set_count(count);
267            Ok(Some(chunk))
268        } else {
269            Ok(None)
270        }
271    }
272
273    fn reset(&mut self) {
274        self.input.reset();
275        self.current_input = None;
276        self.current_row = 0;
277        self.current_edges.clear();
278        self.current_edge_idx = 0;
279        self.exhausted = false;
280    }
281
282    fn name(&self) -> &'static str {
283        "Expand"
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::operators::ScanOperator;
    use crate::graph::lpg::LpgStore;

    /// Creates a new `LpgStore` wrapped in an `Arc` and returns both the
    /// concrete handle (for mutation) and a trait-object handle (for operators).
    fn test_store() -> (Arc<LpgStore>, Arc<dyn GraphStore>) {
        let store = Arc::new(LpgStore::new().unwrap());
        let dyn_store: Arc<dyn GraphStore> = Arc::clone(&store) as Arc<dyn GraphStore>;
        (store, dyn_store)
    }

    /// Builds an expand operator fed by a label scan over `label`, reading
    /// the source node from column 0.
    fn expand_over_label(
        dyn_store: &Arc<dyn GraphStore>,
        label: &str,
        direction: Direction,
        edge_types: Vec<String>,
    ) -> ExpandOperator {
        let scan = Box::new(ScanOperator::with_label(Arc::clone(dyn_store), label));
        ExpandOperator::new(Arc::clone(dyn_store), scan, 0, direction, edge_types)
    }

    /// Runs the operator to completion and returns every chunk it produced.
    ///
    /// Panics on an operator error so a failing traversal surfaces with its
    /// real cause instead of looking like a short result set.
    fn drain(expand: &mut ExpandOperator) -> Vec<DataChunk> {
        let mut chunks = Vec::new();
        while let Some(chunk) = expand.next().expect("expand operator returned an error") {
            chunks.push(chunk);
        }
        chunks
    }

    /// Total number of rows across `chunks`.
    fn total_rows(chunks: &[DataChunk]) -> usize {
        chunks.iter().map(|c| c.row_count()).sum()
    }

    #[test]
    fn test_expand_outgoing() {
        let (store, dyn_store) = test_store();

        // Create nodes
        let alix = store.create_node(&["Person"]);
        let gus = store.create_node(&["Person"]);
        let vincent = store.create_node(&["Person"]);

        // Create edges: Alix -> Gus, Alix -> Vincent
        store.create_edge(alix, gus, "KNOWS");
        store.create_edge(alix, vincent, "KNOWS");

        // The scan yields every Person; only Alix has outgoing edges, so only
        // Alix contributes rows.
        let mut expand = expand_over_label(&dyn_store, "Person", Direction::Outgoing, vec![]);

        // Collect all results
        let mut results = Vec::new();
        for chunk in drain(&mut expand) {
            for i in 0..chunk.row_count() {
                let src = chunk.column(0).unwrap().get_node_id(i).unwrap();
                let edge = chunk.column(1).unwrap().get_edge_id(i).unwrap();
                let dst = chunk.column(2).unwrap().get_node_id(i).unwrap();
                results.push((src, edge, dst));
            }
        }

        // Alix -> Gus, Alix -> Vincent
        assert_eq!(results.len(), 2);

        // All source nodes should be Alix
        for (src, _, _) in &results {
            assert_eq!(*src, alix);
        }

        // Target nodes should be Gus and Vincent
        let targets: Vec<NodeId> = results.iter().map(|(_, _, dst)| *dst).collect();
        assert!(targets.contains(&gus));
        assert!(targets.contains(&vincent));
    }

    #[test]
    fn test_expand_with_edge_type_filter() {
        let (store, dyn_store) = test_store();

        let alix = store.create_node(&["Person"]);
        let gus = store.create_node(&["Person"]);
        let company = store.create_node(&["Company"]);

        store.create_edge(alix, gus, "KNOWS");
        store.create_edge(alix, company, "WORKS_AT");

        let mut expand = expand_over_label(
            &dyn_store,
            "Person",
            Direction::Outgoing,
            vec!["KNOWS".to_string()],
        );

        let mut results = Vec::new();
        for chunk in drain(&mut expand) {
            for i in 0..chunk.row_count() {
                let dst = chunk.column(2).unwrap().get_node_id(i).unwrap();
                results.push(dst);
            }
        }

        // Only KNOWS edges should be followed
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], gus);
    }

    #[test]
    fn test_expand_incoming() {
        let (store, dyn_store) = test_store();

        let alix = store.create_node(&["Person"]);
        let gus = store.create_node(&["Person"]);

        store.create_edge(alix, gus, "KNOWS");

        // The scan yields every Person; only Gus has an incoming edge.
        let mut expand = expand_over_label(&dyn_store, "Person", Direction::Incoming, vec![]);

        let mut results = Vec::new();
        for chunk in drain(&mut expand) {
            for i in 0..chunk.row_count() {
                let src = chunk.column(0).unwrap().get_node_id(i).unwrap();
                let dst = chunk.column(2).unwrap().get_node_id(i).unwrap();
                results.push((src, dst));
            }
        }

        // Gus <- Alix (Gus's incoming edge from Alix)
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, gus); // source in the expand is Gus
        assert_eq!(results[0].1, alix); // target is Alix (who points to Gus)
    }

    #[test]
    fn test_expand_no_edges() {
        let (store, dyn_store) = test_store();

        store.create_node(&["Person"]);

        let mut expand = expand_over_label(&dyn_store, "Person", Direction::Outgoing, vec![]);

        let result = expand.next().unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_expand_reset() {
        let (store, dyn_store) = test_store();

        let a = store.create_node(&["Person"]);
        let b = store.create_node(&["Person"]);
        store.create_edge(a, b, "KNOWS");

        let mut expand = expand_over_label(&dyn_store, "Person", Direction::Outgoing, vec![]);

        // First pass
        let count1 = total_rows(&drain(&mut expand));

        // Reset and run again
        expand.reset();
        let count2 = total_rows(&drain(&mut expand));

        assert_eq!(count1, count2);
        assert_eq!(count1, 1);
    }

    #[test]
    fn test_expand_name() {
        let (_store, dyn_store) = test_store();
        let expand = expand_over_label(&dyn_store, "Person", Direction::Outgoing, vec![]);
        assert_eq!(expand.name(), "Expand");
    }

    #[test]
    fn test_expand_with_chunk_capacity() {
        let (store, dyn_store) = test_store();

        let a = store.create_node(&["Person"]);
        for _ in 0..5 {
            let b = store.create_node(&["Person"]);
            store.create_edge(a, b, "KNOWS");
        }

        let mut expand = expand_over_label(&dyn_store, "Person", Direction::Outgoing, vec![])
            .with_chunk_capacity(2);

        // With capacity 2 and 5 edges from node a, we should get multiple chunks
        let chunks = drain(&mut expand);
        let chunk_count = chunks.len();
        let total = total_rows(&chunks);

        assert_eq!(total, 5);
        assert!(
            chunk_count >= 2,
            "Expected multiple chunks with small capacity"
        );
    }

    #[test]
    fn test_expand_edge_type_case_insensitive() {
        let (store, dyn_store) = test_store();

        let a = store.create_node(&["Person"]);
        let b = store.create_node(&["Person"]);
        store.create_edge(a, b, "KNOWS");

        let mut expand = expand_over_label(
            &dyn_store,
            "Person",
            Direction::Outgoing,
            vec!["knows".to_string()], // lowercase
        );

        let count = total_rows(&drain(&mut expand));

        // Should match case-insensitively
        assert_eq!(count, 1);
    }

    #[test]
    fn test_expand_multiple_source_nodes() {
        let (store, dyn_store) = test_store();

        let a = store.create_node(&["Person"]);
        let b = store.create_node(&["Person"]);
        let c = store.create_node(&["Person"]);

        store.create_edge(a, c, "KNOWS");
        store.create_edge(b, c, "KNOWS");

        let mut expand = expand_over_label(&dyn_store, "Person", Direction::Outgoing, vec![]);

        let mut results = Vec::new();
        for chunk in drain(&mut expand) {
            for i in 0..chunk.row_count() {
                let src = chunk.column(0).unwrap().get_node_id(i).unwrap();
                let dst = chunk.column(2).unwrap().get_node_id(i).unwrap();
                results.push((src, dst));
            }
        }

        // Both a->c and b->c
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_expand_empty_input() {
        let (_store, dyn_store) = test_store();

        // No nodes with this label
        let mut expand =
            expand_over_label(&dyn_store, "Nonexistent", Direction::Outgoing, vec![]);

        let result = expand.next().unwrap();
        assert!(result.is_none());
    }
}