Skip to main content

grafeo_core/execution/operators/
project.rs

1//! Project operator for selecting and transforming columns.
2
3use super::filter::{ExpressionPredicate, FilterExpression};
4use super::{Operator, OperatorError, OperatorResult};
5use crate::execution::DataChunk;
6use crate::graph::lpg::LpgStore;
7use grafeo_common::types::{LogicalType, Value};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// A projection expression.
12pub enum ProjectExpr {
13    /// Reference to an input column.
14    Column(usize),
15    /// A constant value.
16    Constant(Value),
17    /// Property access on a node/edge column.
18    PropertyAccess {
19        /// The column containing the node or edge ID.
20        column: usize,
21        /// The property name to access.
22        property: String,
23    },
24    /// Edge type accessor (for type(r) function).
25    EdgeType {
26        /// The column containing the edge ID.
27        column: usize,
28    },
29    /// Full expression evaluation (for CASE WHEN, etc.).
30    Expression {
31        /// The filter expression to evaluate.
32        expr: FilterExpression,
33        /// Variable name to column index mapping.
34        variable_columns: HashMap<String, usize>,
35    },
36}
37
38/// A project operator that selects and transforms columns.
39pub struct ProjectOperator {
40    /// Child operator to read from.
41    child: Box<dyn Operator>,
42    /// Projection expressions.
43    projections: Vec<ProjectExpr>,
44    /// Output column types.
45    output_types: Vec<LogicalType>,
46    /// Optional store for property access.
47    store: Option<Arc<LpgStore>>,
48}
49
50impl ProjectOperator {
51    /// Creates a new project operator.
52    pub fn new(
53        child: Box<dyn Operator>,
54        projections: Vec<ProjectExpr>,
55        output_types: Vec<LogicalType>,
56    ) -> Self {
57        assert_eq!(projections.len(), output_types.len());
58        Self {
59            child,
60            projections,
61            output_types,
62            store: None,
63        }
64    }
65
66    /// Creates a new project operator with store access for property lookups.
67    pub fn with_store(
68        child: Box<dyn Operator>,
69        projections: Vec<ProjectExpr>,
70        output_types: Vec<LogicalType>,
71        store: Arc<LpgStore>,
72    ) -> Self {
73        assert_eq!(projections.len(), output_types.len());
74        Self {
75            child,
76            projections,
77            output_types,
78            store: Some(store),
79        }
80    }
81
82    /// Creates a project operator that selects specific columns.
83    pub fn select_columns(
84        child: Box<dyn Operator>,
85        columns: Vec<usize>,
86        types: Vec<LogicalType>,
87    ) -> Self {
88        let projections = columns.into_iter().map(ProjectExpr::Column).collect();
89        Self::new(child, projections, types)
90    }
91}
92
93impl Operator for ProjectOperator {
94    fn next(&mut self) -> OperatorResult {
95        // Get next chunk from child
96        let Some(input) = self.child.next()? else {
97            return Ok(None);
98        };
99
100        // Create output chunk
101        let mut output = DataChunk::with_capacity(&self.output_types, input.row_count());
102
103        // Evaluate each projection
104        for (i, proj) in self.projections.iter().enumerate() {
105            match proj {
106                ProjectExpr::Column(col_idx) => {
107                    // Copy column from input to output
108                    let input_col = input.column(*col_idx).ok_or_else(|| {
109                        OperatorError::ColumnNotFound(format!("Column {col_idx}"))
110                    })?;
111
112                    let output_col = output.column_mut(i).unwrap();
113
114                    // Copy selected rows
115                    for row in input.selected_indices() {
116                        if let Some(value) = input_col.get_value(row) {
117                            output_col.push_value(value);
118                        }
119                    }
120                }
121                ProjectExpr::Constant(value) => {
122                    // Push constant for each row
123                    let output_col = output.column_mut(i).unwrap();
124                    for _ in input.selected_indices() {
125                        output_col.push_value(value.clone());
126                    }
127                }
128                ProjectExpr::PropertyAccess { column, property } => {
129                    // Access property from node/edge in the specified column
130                    let input_col = input
131                        .column(*column)
132                        .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
133
134                    let output_col = output.column_mut(i).unwrap();
135
136                    let store = self.store.as_ref().ok_or_else(|| {
137                        OperatorError::Execution("Store required for property access".to_string())
138                    })?;
139
140                    // Extract property for each row
141                    for row in input.selected_indices() {
142                        // Try to get node ID first, then edge ID
143                        let value = if let Some(node_id) = input_col.get_node_id(row) {
144                            store
145                                .get_node(node_id)
146                                .and_then(|node| node.get_property(property).cloned())
147                                .unwrap_or(Value::Null)
148                        } else if let Some(edge_id) = input_col.get_edge_id(row) {
149                            store
150                                .get_edge(edge_id)
151                                .and_then(|edge| edge.get_property(property).cloned())
152                                .unwrap_or(Value::Null)
153                        } else {
154                            Value::Null
155                        };
156                        output_col.push_value(value);
157                    }
158                }
159                ProjectExpr::EdgeType { column } => {
160                    // Get edge type string from an edge column
161                    let input_col = input
162                        .column(*column)
163                        .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
164
165                    let output_col = output.column_mut(i).unwrap();
166
167                    let store = self.store.as_ref().ok_or_else(|| {
168                        OperatorError::Execution("Store required for edge type access".to_string())
169                    })?;
170
171                    for row in input.selected_indices() {
172                        let value = if let Some(edge_id) = input_col.get_edge_id(row) {
173                            store.edge_type(edge_id).map_or(Value::Null, Value::String)
174                        } else {
175                            Value::Null
176                        };
177                        output_col.push_value(value);
178                    }
179                }
180                ProjectExpr::Expression {
181                    expr,
182                    variable_columns,
183                } => {
184                    let output_col = output.column_mut(i).unwrap();
185
186                    let store = self.store.as_ref().ok_or_else(|| {
187                        OperatorError::Execution(
188                            "Store required for expression evaluation".to_string(),
189                        )
190                    })?;
191
192                    // Use the ExpressionPredicate for expression evaluation
193                    let evaluator = ExpressionPredicate::new(
194                        expr.clone(),
195                        variable_columns.clone(),
196                        Arc::clone(store),
197                    );
198
199                    for row in input.selected_indices() {
200                        let value = evaluator.eval_at(&input, row).unwrap_or(Value::Null);
201                        output_col.push_value(value);
202                    }
203                }
204            }
205        }
206
207        output.set_count(input.row_count());
208        Ok(Some(output))
209    }
210
211    fn reset(&mut self) {
212        self.child.reset();
213    }
214
215    fn name(&self) -> &'static str {
216        "Project"
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;
    use grafeo_common::types::Value;

    /// Test-only source operator that yields pre-built chunks in order, then `None`.
    struct MockScanOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockScanOperator {
        /// Builds a scan that yields `chunk` exactly once.
        fn single(chunk: DataChunk) -> Self {
            Self {
                chunks: vec![chunk],
                position: 0,
            }
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            let Some(slot) = self.chunks.get_mut(self.position) else {
                return Ok(None);
            };
            // Move the chunk out, leaving an empty placeholder behind.
            let chunk = std::mem::replace(slot, DataChunk::empty());
            self.position += 1;
            Ok(Some(chunk))
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    #[test]
    fn test_project_select_columns() {
        // Input layout: [Int64, String, Int64].
        let mut builder =
            DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String, LogicalType::Int64]);
        for (id, label, amount) in [(1, "hello", 100), (2, "world", 200)] {
            builder.column_mut(0).unwrap().push_int64(id);
            builder.column_mut(1).unwrap().push_string(label);
            builder.column_mut(2).unwrap().push_int64(amount);
            builder.advance_row();
        }

        // Keep columns 2 and 0, swapping their order.
        let mut project = ProjectOperator::select_columns(
            Box::new(MockScanOperator::single(builder.finish())),
            vec![2, 0],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let projected = project.next().unwrap().unwrap();

        assert_eq!(projected.column_count(), 2);
        assert_eq!(projected.row_count(), 2);
        assert_eq!(projected.column(0).unwrap().get_int64(0), Some(100));
        assert_eq!(projected.column(1).unwrap().get_int64(0), Some(1));
    }

    #[test]
    fn test_project_constant() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for n in [1, 2] {
            builder.column_mut(0).unwrap().push_int64(n);
            builder.advance_row();
        }

        let projections = vec![
            ProjectExpr::Column(0),
            ProjectExpr::Constant(Value::String("constant".into())),
        ];
        let mut project = ProjectOperator::new(
            Box::new(MockScanOperator::single(builder.finish())),
            projections,
            vec![LogicalType::Int64, LogicalType::String],
        );

        let projected = project.next().unwrap().unwrap();

        assert_eq!(projected.column_count(), 2);
        // The constant must be repeated on every row.
        for row in 0..2 {
            assert_eq!(projected.column(1).unwrap().get_string(row), Some("constant"));
        }
    }
}