Skip to main content

grafeo_core/execution/operators/
project.rs

1//! Project operator for selecting and transforming columns.
2
3use super::filter::{ExpressionPredicate, FilterExpression};
4use super::{Operator, OperatorError, OperatorResult};
5use crate::execution::DataChunk;
6use crate::graph::lpg::LpgStore;
7use grafeo_common::types::{LogicalType, Value};
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// A single projection: how to compute one output column from an input chunk.
///
/// [`ProjectOperator`] evaluates each variant once per selected input row.
/// The `PropertyAccess`, `EdgeType` and `Expression` variants read from the
/// graph store, so they only work on an operator built with
/// [`ProjectOperator::with_store`]; otherwise execution returns an error.
pub enum ProjectExpr {
    /// Copy the value of the input column at this index.
    Column(usize),
    /// Emit this same value for every selected row.
    Constant(Value),
    /// Read a property from the node or edge referenced by a column.
    ///
    /// The row is tried as a node ID first, then as an edge ID. Rows that hold
    /// neither, or whose entity or property is missing, yield `Value::Null`.
    PropertyAccess {
        /// Index of the input column holding node or edge IDs.
        column: usize,
        /// Name of the property to read.
        property: String,
    },
    /// Edge type lookup (backs the `type(r)` function).
    ///
    /// Yields the type as `Value::String`, or `Value::Null` when the row does
    /// not hold an edge ID or the store reports no type for it.
    EdgeType {
        /// Index of the input column holding edge IDs.
        column: usize,
    },
    /// General expression evaluation (e.g. `CASE WHEN ...`).
    ///
    /// Evaluated per row through `ExpressionPredicate::eval_at`; rows for which
    /// it does not produce a value yield `Value::Null`.
    Expression {
        /// The expression to evaluate.
        expr: FilterExpression,
        /// Maps variable names used in `expr` to input column indices.
        variable_columns: HashMap<String, usize>,
    },
}
37
/// Operator that, on each `next()` call, reads one chunk from its child and
/// builds a new chunk whose columns are computed by [`ProjectExpr`]s.
///
/// Output column `i` is produced by `projections[i]` and has type
/// `output_types[i]`; the constructors assert that both lists have the same
/// length.
pub struct ProjectOperator {
    /// Upstream operator supplying input chunks (also reset by `reset()`).
    child: Box<dyn Operator>,
    /// One expression per output column, in output order.
    projections: Vec<ProjectExpr>,
    /// Type of each output column, used to allocate the output chunk.
    output_types: Vec<LogicalType>,
    /// Graph store used by property, edge-type and expression projections;
    /// when `None`, those projections fail with an execution error.
    store: Option<Arc<LpgStore>>,
}
49
50impl ProjectOperator {
51    /// Creates a new project operator.
52    pub fn new(
53        child: Box<dyn Operator>,
54        projections: Vec<ProjectExpr>,
55        output_types: Vec<LogicalType>,
56    ) -> Self {
57        assert_eq!(projections.len(), output_types.len());
58        Self {
59            child,
60            projections,
61            output_types,
62            store: None,
63        }
64    }
65
66    /// Creates a new project operator with store access for property lookups.
67    pub fn with_store(
68        child: Box<dyn Operator>,
69        projections: Vec<ProjectExpr>,
70        output_types: Vec<LogicalType>,
71        store: Arc<LpgStore>,
72    ) -> Self {
73        assert_eq!(projections.len(), output_types.len());
74        Self {
75            child,
76            projections,
77            output_types,
78            store: Some(store),
79        }
80    }
81
82    /// Creates a project operator that selects specific columns.
83    pub fn select_columns(
84        child: Box<dyn Operator>,
85        columns: Vec<usize>,
86        types: Vec<LogicalType>,
87    ) -> Self {
88        let projections = columns.into_iter().map(ProjectExpr::Column).collect();
89        Self::new(child, projections, types)
90    }
91}
92
93impl Operator for ProjectOperator {
94    fn next(&mut self) -> OperatorResult {
95        // Get next chunk from child
96        let Some(input) = self.child.next()? else {
97            return Ok(None);
98        };
99
100        // Create output chunk
101        let mut output = DataChunk::with_capacity(&self.output_types, input.row_count());
102
103        // Evaluate each projection
104        for (i, proj) in self.projections.iter().enumerate() {
105            match proj {
106                ProjectExpr::Column(col_idx) => {
107                    // Copy column from input to output
108                    let input_col = input.column(*col_idx).ok_or_else(|| {
109                        OperatorError::ColumnNotFound(format!("Column {col_idx}"))
110                    })?;
111
112                    let output_col = output.column_mut(i).unwrap();
113
114                    // Copy selected rows
115                    for row in input.selected_indices() {
116                        if let Some(value) = input_col.get_value(row) {
117                            output_col.push_value(value);
118                        }
119                    }
120                }
121                ProjectExpr::Constant(value) => {
122                    // Push constant for each row
123                    let output_col = output.column_mut(i).unwrap();
124                    for _ in input.selected_indices() {
125                        output_col.push_value(value.clone());
126                    }
127                }
128                ProjectExpr::PropertyAccess { column, property } => {
129                    // Access property from node/edge in the specified column
130                    let input_col = input
131                        .column(*column)
132                        .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
133
134                    let output_col = output.column_mut(i).unwrap();
135
136                    let store = self.store.as_ref().ok_or_else(|| {
137                        OperatorError::Execution("Store required for property access".to_string())
138                    })?;
139
140                    // Extract property for each row
141                    for row in input.selected_indices() {
142                        // Try to get node ID first, then edge ID
143                        let value = if let Some(node_id) = input_col.get_node_id(row) {
144                            store
145                                .get_node(node_id)
146                                .and_then(|node| node.get_property(property).cloned())
147                                .unwrap_or(Value::Null)
148                        } else if let Some(edge_id) = input_col.get_edge_id(row) {
149                            store
150                                .get_edge(edge_id)
151                                .and_then(|edge| edge.get_property(property).cloned())
152                                .unwrap_or(Value::Null)
153                        } else {
154                            Value::Null
155                        };
156                        output_col.push_value(value);
157                    }
158                }
159                ProjectExpr::EdgeType { column } => {
160                    // Get edge type string from an edge column
161                    let input_col = input
162                        .column(*column)
163                        .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
164
165                    let output_col = output.column_mut(i).unwrap();
166
167                    let store = self.store.as_ref().ok_or_else(|| {
168                        OperatorError::Execution("Store required for edge type access".to_string())
169                    })?;
170
171                    for row in input.selected_indices() {
172                        let value = if let Some(edge_id) = input_col.get_edge_id(row) {
173                            store
174                                .edge_type(edge_id)
175                                .map(Value::String)
176                                .unwrap_or(Value::Null)
177                        } else {
178                            Value::Null
179                        };
180                        output_col.push_value(value);
181                    }
182                }
183                ProjectExpr::Expression {
184                    expr,
185                    variable_columns,
186                } => {
187                    let output_col = output.column_mut(i).unwrap();
188
189                    let store = self.store.as_ref().ok_or_else(|| {
190                        OperatorError::Execution(
191                            "Store required for expression evaluation".to_string(),
192                        )
193                    })?;
194
195                    // Use the ExpressionPredicate for expression evaluation
196                    let evaluator = ExpressionPredicate::new(
197                        expr.clone(),
198                        variable_columns.clone(),
199                        Arc::clone(store),
200                    );
201
202                    for row in input.selected_indices() {
203                        let value = evaluator.eval_at(&input, row).unwrap_or(Value::Null);
204                        output_col.push_value(value);
205                    }
206                }
207            }
208        }
209
210        output.set_count(input.row_count());
211        Ok(Some(output))
212    }
213
214    fn reset(&mut self) {
215        self.child.reset();
216    }
217
218    fn name(&self) -> &'static str {
219        "Project"
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;
    use grafeo_common::types::Value;

    /// Test double that yields its chunks one per `next()` call, then `None`.
    ///
    /// Chunks are moved out as they are returned, so after `reset()` the scan
    /// replays empty chunks rather than the original data.
    struct MockScanOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockScanOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.position < self.chunks.len() {
                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
                self.position += 1;
                Ok(Some(chunk))
            } else {
                Ok(None)
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Builds a two-row chunk with columns `[Int64, String, Int64]`:
    /// `(1, "hello", 100)` and `(2, "world", 200)`.
    fn three_column_chunk() -> DataChunk {
        let mut builder =
            DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String, LogicalType::Int64]);
        for (id, label, amount) in [(1, "hello", 100), (2, "world", 200)] {
            builder.column_mut(0).unwrap().push_int64(id);
            builder.column_mut(1).unwrap().push_string(label);
            builder.column_mut(2).unwrap().push_int64(amount);
            builder.advance_row();
        }
        builder.finish()
    }

    #[test]
    fn test_project_select_columns() {
        // Select columns 2 and 0, i.e. reorder and drop the string column.
        let mut project = ProjectOperator::select_columns(
            Box::new(MockScanOperator::new(vec![three_column_chunk()])),
            vec![2, 0],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);

        // Values come out in projection order for every row, not just the first.
        let first = result.column(0).unwrap();
        let second = result.column(1).unwrap();
        assert_eq!(first.get_int64(0), Some(100));
        assert_eq!(second.get_int64(0), Some(1));
        assert_eq!(first.get_int64(1), Some(200));
        assert_eq!(second.get_int64(1), Some(2));

        // The child produced a single chunk, so the projection is now exhausted.
        assert!(project.next().unwrap().is_none());
    }

    #[test]
    fn test_project_constant() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for id in [1, 2] {
            builder.column_mut(0).unwrap().push_int64(id);
            builder.advance_row();
        }
        let chunk = builder.finish();

        let mut project = ProjectOperator::new(
            Box::new(MockScanOperator::new(vec![chunk])),
            vec![
                ProjectExpr::Column(0),
                ProjectExpr::Constant(Value::String("constant".into())),
            ],
            vec![LogicalType::Int64, LogicalType::String],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(1));
        assert_eq!(result.column(0).unwrap().get_int64(1), Some(2));
        assert_eq!(result.column(1).unwrap().get_string(0), Some("constant"));
        assert_eq!(result.column(1).unwrap().get_string(1), Some("constant"));
    }

    #[test]
    fn test_project_missing_column_is_an_error() {
        let mut project = ProjectOperator::select_columns(
            Box::new(MockScanOperator::new(vec![three_column_chunk()])),
            vec![7],
            vec![LogicalType::Int64],
        );

        assert!(matches!(
            project.next(),
            Err(OperatorError::ColumnNotFound(_))
        ));
    }

    #[test]
    fn test_store_backed_projections_require_store() {
        let needs_store = [
            ProjectExpr::PropertyAccess {
                column: 0,
                property: "name".to_string(),
            },
            ProjectExpr::EdgeType { column: 0 },
        ];

        for proj in needs_store {
            let mut project = ProjectOperator::new(
                Box::new(MockScanOperator::new(vec![three_column_chunk()])),
                vec![proj],
                vec![LogicalType::String],
            );

            assert!(matches!(project.next(), Err(OperatorError::Execution(_))));
        }
    }

    #[test]
    fn test_project_empty_child() {
        let mut project = ProjectOperator::select_columns(
            Box::new(MockScanOperator::new(Vec::new())),
            vec![],
            vec![],
        );

        assert_eq!(project.name(), "Project");
        assert!(project.next().unwrap().is_none());
    }
}