Skip to main content

grafeo_core/execution/operators/
expand.rs

1//! Expand operator for relationship traversal.
2
3use super::{Operator, OperatorError, OperatorResult};
4use crate::execution::DataChunk;
5use crate::graph::Direction;
6use crate::graph::lpg::LpgStore;
7use grafeo_common::types::{EdgeId, EpochId, LogicalType, NodeId, TxId};
8use std::sync::Arc;
9
10/// An expand operator that traverses edges from source nodes.
11///
12/// For each input row containing a source node, this operator produces
13/// output rows for each neighbor connected via matching edges.
14pub struct ExpandOperator {
15    /// The store to traverse.
16    store: Arc<LpgStore>,
17    /// Input operator providing source nodes.
18    input: Box<dyn Operator>,
19    /// Index of the source node column in input.
20    source_column: usize,
21    /// Direction of edge traversal.
22    direction: Direction,
23    /// Optional edge type filter.
24    edge_type: Option<String>,
25    /// Chunk capacity.
26    chunk_capacity: usize,
27    /// Current input chunk being processed.
28    current_input: Option<DataChunk>,
29    /// Current row index in the input chunk.
30    current_row: usize,
31    /// Current edge iterator for the current row.
32    current_edges: Vec<(NodeId, EdgeId)>,
33    /// Current edge index.
34    current_edge_idx: usize,
35    /// Whether the operator is exhausted.
36    exhausted: bool,
37    /// Transaction ID for MVCC visibility (None = use current epoch).
38    tx_id: Option<TxId>,
39    /// Epoch for version visibility.
40    viewing_epoch: Option<EpochId>,
41}
42
43impl ExpandOperator {
44    /// Creates a new expand operator.
45    pub fn new(
46        store: Arc<LpgStore>,
47        input: Box<dyn Operator>,
48        source_column: usize,
49        direction: Direction,
50        edge_type: Option<String>,
51    ) -> Self {
52        Self {
53            store,
54            input,
55            source_column,
56            direction,
57            edge_type,
58            chunk_capacity: 2048,
59            current_input: None,
60            current_row: 0,
61            current_edges: Vec::new(),
62            current_edge_idx: 0,
63            exhausted: false,
64            tx_id: None,
65            viewing_epoch: None,
66        }
67    }
68
69    /// Sets the chunk capacity.
70    pub fn with_chunk_capacity(mut self, capacity: usize) -> Self {
71        self.chunk_capacity = capacity;
72        self
73    }
74
75    /// Sets the transaction context for MVCC visibility.
76    ///
77    /// When set, the expand will only traverse visible edges and nodes.
78    pub fn with_tx_context(mut self, epoch: EpochId, tx_id: Option<TxId>) -> Self {
79        self.viewing_epoch = Some(epoch);
80        self.tx_id = tx_id;
81        self
82    }
83
84    /// Loads the next input chunk.
85    fn load_next_input(&mut self) -> Result<bool, OperatorError> {
86        match self.input.next() {
87            Ok(Some(mut chunk)) => {
88                // Flatten the chunk if it has a selection vector so we can use direct indexing
89                chunk.flatten();
90                self.current_input = Some(chunk);
91                self.current_row = 0;
92                self.current_edges.clear();
93                self.current_edge_idx = 0;
94                Ok(true)
95            }
96            Ok(None) => {
97                self.exhausted = true;
98                Ok(false)
99            }
100            Err(e) => Err(e),
101        }
102    }
103
104    /// Loads edges for the current row.
105    fn load_edges_for_current_row(&mut self) -> Result<bool, OperatorError> {
106        let chunk = match &self.current_input {
107            Some(c) => c,
108            None => return Ok(false),
109        };
110
111        if self.current_row >= chunk.row_count() {
112            return Ok(false);
113        }
114
115        let col = chunk.column(self.source_column).ok_or_else(|| {
116            OperatorError::ColumnNotFound(format!("Column {} not found", self.source_column))
117        })?;
118
119        let source_id = col
120            .get_node_id(self.current_row)
121            .ok_or_else(|| OperatorError::Execution("Expected node ID in source column".into()))?;
122
123        // Get visibility context
124        let epoch = self.viewing_epoch;
125        let tx = self.tx_id.unwrap_or(TxId::SYSTEM);
126
127        // Get edges from this node
128        let edges: Vec<(NodeId, EdgeId)> = self
129            .store
130            .edges_from(source_id, self.direction)
131            .filter(|(target_id, edge_id)| {
132                // Filter by edge type if specified
133                let type_matches = if let Some(ref filter_type) = self.edge_type {
134                    if let Some(edge_type) = self.store.edge_type(*edge_id) {
135                        edge_type.as_ref() == filter_type.as_str()
136                    } else {
137                        false
138                    }
139                } else {
140                    true
141                };
142
143                if !type_matches {
144                    return false;
145                }
146
147                // Filter by visibility if we have tx context
148                if let Some(epoch) = epoch {
149                    // Check if edge and target node are visible
150                    let edge_visible = self.store.get_edge_versioned(*edge_id, epoch, tx).is_some();
151                    let target_visible = self
152                        .store
153                        .get_node_versioned(*target_id, epoch, tx)
154                        .is_some();
155                    edge_visible && target_visible
156                } else {
157                    true
158                }
159            })
160            .collect();
161
162        self.current_edges = edges;
163        self.current_edge_idx = 0;
164        Ok(true)
165    }
166}
167
168impl Operator for ExpandOperator {
169    fn next(&mut self) -> OperatorResult {
170        if self.exhausted {
171            return Ok(None);
172        }
173
174        // Build output schema: preserve all input columns + edge + target
175        // We need to build this dynamically based on input schema
176        if self.current_input.is_none() {
177            if !self.load_next_input()? {
178                return Ok(None);
179            }
180            self.load_edges_for_current_row()?;
181        }
182        let input_chunk = self.current_input.as_ref().expect("input loaded above");
183
184        // Build schema: [input_columns..., edge, target]
185        let input_col_count = input_chunk.column_count();
186        let mut schema: Vec<LogicalType> = (0..input_col_count)
187            .map(|i| {
188                input_chunk
189                    .column(i)
190                    .map_or(LogicalType::Any, |c| c.data_type().clone())
191            })
192            .collect();
193        schema.push(LogicalType::Edge);
194        schema.push(LogicalType::Node);
195
196        let mut chunk = DataChunk::with_capacity(&schema, self.chunk_capacity);
197        let mut count = 0;
198
199        while count < self.chunk_capacity {
200            // If we need a new input chunk
201            if self.current_input.is_none() {
202                if !self.load_next_input()? {
203                    break;
204                }
205                self.load_edges_for_current_row()?;
206            }
207
208            // If we've exhausted edges for current row, move to next row
209            while self.current_edge_idx >= self.current_edges.len() {
210                self.current_row += 1;
211
212                // If we've exhausted the current input chunk, get next one
213                if self.current_row >= self.current_input.as_ref().map_or(0, |c| c.row_count()) {
214                    self.current_input = None;
215                    if !self.load_next_input()? {
216                        // No more input chunks
217                        if count > 0 {
218                            chunk.set_count(count);
219                            return Ok(Some(chunk));
220                        }
221                        return Ok(None);
222                    }
223                }
224
225                self.load_edges_for_current_row()?;
226            }
227
228            // Get the current edge
229            let (target_id, edge_id) = self.current_edges[self.current_edge_idx];
230
231            // Copy all input columns to output
232            let input = self.current_input.as_ref().unwrap();
233            for col_idx in 0..input_col_count {
234                if let Some(input_col) = input.column(col_idx) {
235                    if let Some(output_col) = chunk.column_mut(col_idx) {
236                        // Use copy_row_to which preserves NodeId/EdgeId types
237                        input_col.copy_row_to(self.current_row, output_col);
238                    }
239                }
240            }
241
242            // Add edge column
243            if let Some(col) = chunk.column_mut(input_col_count) {
244                col.push_edge_id(edge_id);
245            }
246
247            // Add target node column
248            if let Some(col) = chunk.column_mut(input_col_count + 1) {
249                col.push_node_id(target_id);
250            }
251
252            count += 1;
253            self.current_edge_idx += 1;
254        }
255
256        if count > 0 {
257            chunk.set_count(count);
258            Ok(Some(chunk))
259        } else {
260            Ok(None)
261        }
262    }
263
264    fn reset(&mut self) {
265        self.input.reset();
266        self.current_input = None;
267        self.current_row = 0;
268        self.current_edges.clear();
269        self.current_edge_idx = 0;
270        self.exhausted = false;
271    }
272
273    fn name(&self) -> &'static str {
274        "Expand"
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::operators::ScanOperator;

    /// Runs `expand` to completion and returns every chunk it produced.
    ///
    /// Panics on an operator error, so a failing expand cannot pass for a
    /// short result set.
    fn drain(expand: &mut ExpandOperator) -> Vec<DataChunk> {
        let mut chunks = Vec::new();
        while let Some(chunk) = expand.next().expect("expand should not fail") {
            chunks.push(chunk);
        }
        chunks
    }

    #[test]
    fn test_expand_outgoing() {
        let store = Arc::new(LpgStore::new());

        // Create nodes
        let alice = store.create_node(&["Person"]);
        let bob = store.create_node(&["Person"]);
        let charlie = store.create_node(&["Person"]);

        // Create edges: Alice -> Bob, Alice -> Charlie
        store.create_edge(alice, bob, "KNOWS");
        store.create_edge(alice, charlie, "KNOWS");

        // Scan by the `Person` label; Alice is the only node with outgoing
        // edges, so she is the only source expected in the output.
        let scan = Box::new(ScanOperator::with_label(Arc::clone(&store), "Person"));

        let mut expand = ExpandOperator::new(
            Arc::clone(&store),
            scan,
            0, // source column
            Direction::Outgoing,
            None,
        );

        // Collect all results as (source, edge, target)
        let mut results = Vec::new();
        for chunk in drain(&mut expand) {
            for i in 0..chunk.row_count() {
                let src = chunk.column(0).unwrap().get_node_id(i).unwrap();
                let edge = chunk.column(1).unwrap().get_edge_id(i).unwrap();
                let dst = chunk.column(2).unwrap().get_node_id(i).unwrap();
                results.push((src, edge, dst));
            }
        }

        // Alice -> Bob, Alice -> Charlie
        assert_eq!(results.len(), 2);

        // All source nodes should be Alice
        for (src, _, _) in &results {
            assert_eq!(*src, alice);
        }

        // Target nodes should be Bob and Charlie
        let targets: Vec<NodeId> = results.iter().map(|(_, _, dst)| *dst).collect();
        assert!(targets.contains(&bob));
        assert!(targets.contains(&charlie));
    }

    #[test]
    fn test_expand_with_edge_type_filter() {
        let store = Arc::new(LpgStore::new());

        let alice = store.create_node(&["Person"]);
        let bob = store.create_node(&["Person"]);
        let company = store.create_node(&["Company"]);

        store.create_edge(alice, bob, "KNOWS");
        store.create_edge(alice, company, "WORKS_AT");

        let scan = Box::new(ScanOperator::with_label(Arc::clone(&store), "Person"));

        let mut expand = ExpandOperator::new(
            Arc::clone(&store),
            scan,
            0,
            Direction::Outgoing,
            Some("KNOWS".to_string()),
        );

        let mut results = Vec::new();
        for chunk in drain(&mut expand) {
            for i in 0..chunk.row_count() {
                let dst = chunk.column(2).unwrap().get_node_id(i).unwrap();
                results.push(dst);
            }
        }

        // Only KNOWS edges should be followed
        assert_eq!(results.len(), 1);
        assert_eq!(results[0], bob);
    }

    #[test]
    fn test_expand_incoming() {
        let store = Arc::new(LpgStore::new());

        let alice = store.create_node(&["Person"]);
        let bob = store.create_node(&["Person"]);

        store.create_edge(alice, bob, "KNOWS");

        // Scan by the `Person` label; Bob is the only node with an incoming
        // edge, so he is the only source expected in the output.
        let scan = Box::new(ScanOperator::with_label(Arc::clone(&store), "Person"));

        let mut expand =
            ExpandOperator::new(Arc::clone(&store), scan, 0, Direction::Incoming, None);

        let mut results = Vec::new();
        for chunk in drain(&mut expand) {
            for i in 0..chunk.row_count() {
                let src = chunk.column(0).unwrap().get_node_id(i).unwrap();
                let dst = chunk.column(2).unwrap().get_node_id(i).unwrap();
                results.push((src, dst));
            }
        }

        // Bob <- Alice (Bob's incoming edge from Alice)
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, bob); // source in the expand is Bob
        assert_eq!(results[0].1, alice); // target is Alice (who points to Bob)
    }

    #[test]
    fn test_expand_no_edges() {
        let store = Arc::new(LpgStore::new());

        store.create_node(&["Person"]);

        let scan = Box::new(ScanOperator::with_label(Arc::clone(&store), "Person"));

        let mut expand =
            ExpandOperator::new(Arc::clone(&store), scan, 0, Direction::Outgoing, None);

        // A node without edges yields no rows at all.
        let result = expand.next().unwrap();
        assert!(result.is_none());
    }
}