Skip to main content

grafeo_core/execution/operators/
project.rs

1//! Project operator for selecting and transforming columns.
2
3use super::filter::{ExpressionPredicate, FilterExpression};
4use super::{Operator, OperatorError, OperatorResult};
5use crate::execution::DataChunk;
6use crate::graph::lpg::LpgStore;
7use grafeo_common::types::{LogicalType, Value};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11/// A projection expression.
12pub enum ProjectExpr {
13    /// Reference to an input column.
14    Column(usize),
15    /// A constant value.
16    Constant(Value),
17    /// Property access on a node/edge column.
18    PropertyAccess {
19        /// The column containing the node or edge ID.
20        column: usize,
21        /// The property name to access.
22        property: String,
23    },
24    /// Edge type accessor (for type(r) function).
25    EdgeType {
26        /// The column containing the edge ID.
27        column: usize,
28    },
29    /// Full expression evaluation (for CASE WHEN, etc.).
30    Expression {
31        /// The filter expression to evaluate.
32        expr: FilterExpression,
33        /// Variable name to column index mapping.
34        variable_columns: HashMap<String, usize>,
35    },
36}
37
38/// A project operator that selects and transforms columns.
39pub struct ProjectOperator {
40    /// Child operator to read from.
41    child: Box<dyn Operator>,
42    /// Projection expressions.
43    projections: Vec<ProjectExpr>,
44    /// Output column types.
45    output_types: Vec<LogicalType>,
46    /// Optional store for property access.
47    store: Option<Arc<LpgStore>>,
48}
49
50impl ProjectOperator {
51    /// Creates a new project operator.
52    pub fn new(
53        child: Box<dyn Operator>,
54        projections: Vec<ProjectExpr>,
55        output_types: Vec<LogicalType>,
56    ) -> Self {
57        assert_eq!(projections.len(), output_types.len());
58        Self {
59            child,
60            projections,
61            output_types,
62            store: None,
63        }
64    }
65
66    /// Creates a new project operator with store access for property lookups.
67    pub fn with_store(
68        child: Box<dyn Operator>,
69        projections: Vec<ProjectExpr>,
70        output_types: Vec<LogicalType>,
71        store: Arc<LpgStore>,
72    ) -> Self {
73        assert_eq!(projections.len(), output_types.len());
74        Self {
75            child,
76            projections,
77            output_types,
78            store: Some(store),
79        }
80    }
81
82    /// Creates a project operator that selects specific columns.
83    pub fn select_columns(
84        child: Box<dyn Operator>,
85        columns: Vec<usize>,
86        types: Vec<LogicalType>,
87    ) -> Self {
88        let projections = columns.into_iter().map(ProjectExpr::Column).collect();
89        Self::new(child, projections, types)
90    }
91}
92
93impl Operator for ProjectOperator {
94    fn next(&mut self) -> OperatorResult {
95        // Get next chunk from child
96        let input = match self.child.next()? {
97            Some(c) => c,
98            None => return Ok(None),
99        };
100
101        // Create output chunk
102        let mut output = DataChunk::with_capacity(&self.output_types, input.row_count());
103
104        // Evaluate each projection
105        for (i, proj) in self.projections.iter().enumerate() {
106            match proj {
107                ProjectExpr::Column(col_idx) => {
108                    // Copy column from input to output
109                    let input_col = input.column(*col_idx).ok_or_else(|| {
110                        OperatorError::ColumnNotFound(format!("Column {col_idx}"))
111                    })?;
112
113                    let output_col = output.column_mut(i).unwrap();
114
115                    // Copy selected rows
116                    for row in input.selected_indices() {
117                        if let Some(value) = input_col.get_value(row) {
118                            output_col.push_value(value);
119                        }
120                    }
121                }
122                ProjectExpr::Constant(value) => {
123                    // Push constant for each row
124                    let output_col = output.column_mut(i).unwrap();
125                    for _ in input.selected_indices() {
126                        output_col.push_value(value.clone());
127                    }
128                }
129                ProjectExpr::PropertyAccess { column, property } => {
130                    // Access property from node/edge in the specified column
131                    let input_col = input
132                        .column(*column)
133                        .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
134
135                    let output_col = output.column_mut(i).unwrap();
136
137                    let store = self.store.as_ref().ok_or_else(|| {
138                        OperatorError::Execution("Store required for property access".to_string())
139                    })?;
140
141                    // Extract property for each row
142                    for row in input.selected_indices() {
143                        // Try to get node ID first, then edge ID
144                        let value = if let Some(node_id) = input_col.get_node_id(row) {
145                            store
146                                .get_node(node_id)
147                                .and_then(|node| node.get_property(property).cloned())
148                                .unwrap_or(Value::Null)
149                        } else if let Some(edge_id) = input_col.get_edge_id(row) {
150                            store
151                                .get_edge(edge_id)
152                                .and_then(|edge| edge.get_property(property).cloned())
153                                .unwrap_or(Value::Null)
154                        } else {
155                            Value::Null
156                        };
157                        output_col.push_value(value);
158                    }
159                }
160                ProjectExpr::EdgeType { column } => {
161                    // Get edge type string from an edge column
162                    let input_col = input
163                        .column(*column)
164                        .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
165
166                    let output_col = output.column_mut(i).unwrap();
167
168                    let store = self.store.as_ref().ok_or_else(|| {
169                        OperatorError::Execution("Store required for edge type access".to_string())
170                    })?;
171
172                    for row in input.selected_indices() {
173                        let value = if let Some(edge_id) = input_col.get_edge_id(row) {
174                            store
175                                .edge_type(edge_id)
176                                .map(Value::String)
177                                .unwrap_or(Value::Null)
178                        } else {
179                            Value::Null
180                        };
181                        output_col.push_value(value);
182                    }
183                }
184                ProjectExpr::Expression {
185                    expr,
186                    variable_columns,
187                } => {
188                    let output_col = output.column_mut(i).unwrap();
189
190                    let store = self.store.as_ref().ok_or_else(|| {
191                        OperatorError::Execution(
192                            "Store required for expression evaluation".to_string(),
193                        )
194                    })?;
195
196                    // Use the ExpressionPredicate for expression evaluation
197                    let evaluator = ExpressionPredicate::new(
198                        expr.clone(),
199                        variable_columns.clone(),
200                        Arc::clone(store),
201                    );
202
203                    for row in input.selected_indices() {
204                        let value = evaluator.eval_at(&input, row).unwrap_or(Value::Null);
205                        output_col.push_value(value);
206                    }
207                }
208            }
209        }
210
211        output.set_count(input.row_count());
212        Ok(Some(output))
213    }
214
215    fn reset(&mut self) {
216        self.child.reset();
217    }
218
219    fn name(&self) -> &'static str {
220        "Project"
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Child operator that replays a fixed list of chunks, then reports exhaustion.
    struct MockScanOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockScanOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.position < self.chunks.len() {
                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
                self.position += 1;
                Ok(Some(chunk))
            } else {
                Ok(None)
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Two rows of `[Int64, String, Int64]`: (1, "hello", 100) and (2, "world", 200).
    fn three_column_chunk() -> DataChunk {
        let mut builder =
            DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String, LogicalType::Int64]);

        builder.column_mut(0).unwrap().push_int64(1);
        builder.column_mut(1).unwrap().push_string("hello");
        builder.column_mut(2).unwrap().push_int64(100);
        builder.advance_row();

        builder.column_mut(0).unwrap().push_int64(2);
        builder.column_mut(1).unwrap().push_string("world");
        builder.column_mut(2).unwrap().push_int64(200);
        builder.advance_row();

        builder.finish()
    }

    /// Two rows of a single `Int64` column: 1 and 2.
    fn int_chunk() -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        builder.column_mut(0).unwrap().push_int64(1);
        builder.advance_row();
        builder.column_mut(0).unwrap().push_int64(2);
        builder.advance_row();
        builder.finish()
    }

    #[test]
    fn test_project_select_columns() {
        // Select columns 2 and 0 (reordering, dropping the string column).
        let mut project = ProjectOperator::select_columns(
            Box::new(MockScanOperator::new(vec![three_column_chunk()])),
            vec![2, 0],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);

        // Values are reordered, and every row is carried over.
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(100));
        assert_eq!(result.column(0).unwrap().get_int64(1), Some(200));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(1));
        assert_eq!(result.column(1).unwrap().get_int64(1), Some(2));

        // The single input chunk has been consumed.
        assert!(project.next().unwrap().is_none());
    }

    #[test]
    fn test_project_constant() {
        let mut project = ProjectOperator::new(
            Box::new(MockScanOperator::new(vec![int_chunk()])),
            vec![
                ProjectExpr::Column(0),
                ProjectExpr::Constant(Value::String("constant".into())),
            ],
            vec![LogicalType::Int64, LogicalType::String],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(1));
        assert_eq!(result.column(0).unwrap().get_int64(1), Some(2));
        assert_eq!(result.column(1).unwrap().get_string(0), Some("constant"));
        assert_eq!(result.column(1).unwrap().get_string(1), Some("constant"));
    }

    #[test]
    fn test_project_empty_child() {
        let mut project = ProjectOperator::select_columns(
            Box::new(MockScanOperator::new(Vec::new())),
            vec![0],
            vec![LogicalType::Int64],
        );

        assert!(project.next().unwrap().is_none());
    }

    #[test]
    fn test_project_missing_column_errors() {
        let mut project = ProjectOperator::select_columns(
            Box::new(MockScanOperator::new(vec![three_column_chunk()])),
            vec![5],
            vec![LogicalType::Int64],
        );

        assert!(matches!(
            project.next(),
            Err(OperatorError::ColumnNotFound(_))
        ));
    }

    #[test]
    fn test_property_access_requires_store() {
        let mut project = ProjectOperator::new(
            Box::new(MockScanOperator::new(vec![int_chunk()])),
            vec![ProjectExpr::PropertyAccess {
                column: 0,
                property: "name".to_string(),
            }],
            vec![LogicalType::String],
        );

        assert!(matches!(project.next(), Err(OperatorError::Execution(_))));
    }

    #[test]
    fn test_edge_type_requires_store() {
        let mut project = ProjectOperator::new(
            Box::new(MockScanOperator::new(vec![int_chunk()])),
            vec![ProjectExpr::EdgeType { column: 0 }],
            vec![LogicalType::String],
        );

        assert!(matches!(project.next(), Err(OperatorError::Execution(_))));
    }

    #[test]
    #[should_panic]
    fn test_mismatched_output_types_panics() {
        let _ = ProjectOperator::new(
            Box::new(MockScanOperator::new(Vec::new())),
            vec![ProjectExpr::Column(0)],
            Vec::new(),
        );
    }
}