grafeo_core/execution/operators/
project.rs1use super::filter::{ExpressionPredicate, FilterExpression};
4use super::{Operator, OperatorError, OperatorResult};
5use crate::execution::DataChunk;
6use crate::graph::lpg::LpgStore;
7use grafeo_common::types::{LogicalType, Value};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub enum ProjectExpr {
13 Column(usize),
15 Constant(Value),
17 PropertyAccess {
19 column: usize,
21 property: String,
23 },
24 EdgeType {
26 column: usize,
28 },
29 Expression {
31 expr: FilterExpression,
33 variable_columns: HashMap<String, usize>,
35 },
36}
37
38pub struct ProjectOperator {
40 child: Box<dyn Operator>,
42 projections: Vec<ProjectExpr>,
44 output_types: Vec<LogicalType>,
46 store: Option<Arc<LpgStore>>,
48}
49
50impl ProjectOperator {
51 pub fn new(
53 child: Box<dyn Operator>,
54 projections: Vec<ProjectExpr>,
55 output_types: Vec<LogicalType>,
56 ) -> Self {
57 assert_eq!(projections.len(), output_types.len());
58 Self {
59 child,
60 projections,
61 output_types,
62 store: None,
63 }
64 }
65
66 pub fn with_store(
68 child: Box<dyn Operator>,
69 projections: Vec<ProjectExpr>,
70 output_types: Vec<LogicalType>,
71 store: Arc<LpgStore>,
72 ) -> Self {
73 assert_eq!(projections.len(), output_types.len());
74 Self {
75 child,
76 projections,
77 output_types,
78 store: Some(store),
79 }
80 }
81
82 pub fn select_columns(
84 child: Box<dyn Operator>,
85 columns: Vec<usize>,
86 types: Vec<LogicalType>,
87 ) -> Self {
88 let projections = columns.into_iter().map(ProjectExpr::Column).collect();
89 Self::new(child, projections, types)
90 }
91}
92
93impl Operator for ProjectOperator {
94 fn next(&mut self) -> OperatorResult {
95 let Some(input) = self.child.next()? else {
97 return Ok(None);
98 };
99
100 let mut output = DataChunk::with_capacity(&self.output_types, input.row_count());
102
103 for (i, proj) in self.projections.iter().enumerate() {
105 match proj {
106 ProjectExpr::Column(col_idx) => {
107 let input_col = input.column(*col_idx).ok_or_else(|| {
109 OperatorError::ColumnNotFound(format!("Column {col_idx}"))
110 })?;
111
112 let output_col = output.column_mut(i).unwrap();
113
114 for row in input.selected_indices() {
116 if let Some(value) = input_col.get_value(row) {
117 output_col.push_value(value);
118 }
119 }
120 }
121 ProjectExpr::Constant(value) => {
122 let output_col = output.column_mut(i).unwrap();
124 for _ in input.selected_indices() {
125 output_col.push_value(value.clone());
126 }
127 }
128 ProjectExpr::PropertyAccess { column, property } => {
129 let input_col = input
131 .column(*column)
132 .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
133
134 let output_col = output.column_mut(i).unwrap();
135
136 let store = self.store.as_ref().ok_or_else(|| {
137 OperatorError::Execution("Store required for property access".to_string())
138 })?;
139
140 for row in input.selected_indices() {
142 let value = if let Some(node_id) = input_col.get_node_id(row) {
144 store
145 .get_node(node_id)
146 .and_then(|node| node.get_property(property).cloned())
147 .unwrap_or(Value::Null)
148 } else if let Some(edge_id) = input_col.get_edge_id(row) {
149 store
150 .get_edge(edge_id)
151 .and_then(|edge| edge.get_property(property).cloned())
152 .unwrap_or(Value::Null)
153 } else {
154 Value::Null
155 };
156 output_col.push_value(value);
157 }
158 }
159 ProjectExpr::EdgeType { column } => {
160 let input_col = input
162 .column(*column)
163 .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
164
165 let output_col = output.column_mut(i).unwrap();
166
167 let store = self.store.as_ref().ok_or_else(|| {
168 OperatorError::Execution("Store required for edge type access".to_string())
169 })?;
170
171 for row in input.selected_indices() {
172 let value = if let Some(edge_id) = input_col.get_edge_id(row) {
173 store.edge_type(edge_id).map_or(Value::Null, Value::String)
174 } else {
175 Value::Null
176 };
177 output_col.push_value(value);
178 }
179 }
180 ProjectExpr::Expression {
181 expr,
182 variable_columns,
183 } => {
184 let output_col = output.column_mut(i).unwrap();
185
186 let store = self.store.as_ref().ok_or_else(|| {
187 OperatorError::Execution(
188 "Store required for expression evaluation".to_string(),
189 )
190 })?;
191
192 let evaluator = ExpressionPredicate::new(
194 expr.clone(),
195 variable_columns.clone(),
196 Arc::clone(store),
197 );
198
199 for row in input.selected_indices() {
200 let value = evaluator.eval_at(&input, row).unwrap_or(Value::Null);
201 output_col.push_value(value);
202 }
203 }
204 }
205 }
206
207 output.set_count(input.row_count());
208 Ok(Some(output))
209 }
210
211 fn reset(&mut self) {
212 self.child.reset();
213 }
214
215 fn name(&self) -> &'static str {
216 "Project"
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;
    use grafeo_common::types::Value;

    /// Test double that hands out a fixed list of chunks, one per `next` call.
    struct MockScanOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockScanOperator {
        /// Wraps `chunks` in a boxed scan positioned at the first chunk.
        fn boxed(chunks: Vec<DataChunk>) -> Box<dyn Operator> {
            Box::new(Self {
                chunks,
                position: 0,
            })
        }
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            let Some(slot) = self.chunks.get_mut(self.position) else {
                return Ok(None);
            };
            // Move the chunk out, leaving an empty placeholder behind.
            let chunk = std::mem::replace(slot, DataChunk::empty());
            self.position += 1;
            Ok(Some(chunk))
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Builds a chunk with a single `Int64` column holding `values`, one per row.
    fn int64_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &value in values {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    #[test]
    fn test_project_select_columns() {
        let mut builder =
            DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String, LogicalType::Int64]);

        for (first, text, last) in [(1, "hello", 100), (2, "world", 200)] {
            builder.column_mut(0).unwrap().push_int64(first);
            builder.column_mut(1).unwrap().push_string(text);
            builder.column_mut(2).unwrap().push_int64(last);
            builder.advance_row();
        }

        // Keep columns 2 and 0, in that order.
        let mut project = ProjectOperator::select_columns(
            MockScanOperator::boxed(vec![builder.finish()]),
            vec![2, 0],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);

        assert_eq!(result.column(0).unwrap().get_int64(0), Some(100));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(1));
    }

    #[test]
    fn test_project_constant() {
        let scan = MockScanOperator::boxed(vec![int64_chunk(&[1, 2])]);

        let mut project = ProjectOperator::new(
            scan,
            vec![
                ProjectExpr::Column(0),
                ProjectExpr::Constant(Value::String("constant".into())),
            ],
            vec![LogicalType::Int64, LogicalType::String],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.column(1).unwrap().get_string(0), Some("constant"));
        assert_eq!(result.column(1).unwrap().get_string(1), Some("constant"));
    }

    #[test]
    fn test_project_empty_input() {
        let mut project = ProjectOperator::select_columns(
            MockScanOperator::boxed(vec![]),
            vec![0],
            vec![LogicalType::Int64],
        );

        assert!(project.next().unwrap().is_none());
    }

    #[test]
    fn test_project_column_not_found() {
        let scan = MockScanOperator::boxed(vec![int64_chunk(&[1])]);

        // Column 5 does not exist in the single-column input.
        let mut project = ProjectOperator::new(
            scan,
            vec![ProjectExpr::Column(5)],
            vec![LogicalType::Int64],
        );

        let result = project.next();
        assert!(result.is_err(), "Should fail with ColumnNotFound");
    }

    #[test]
    fn test_project_multiple_constants() {
        let scan = MockScanOperator::boxed(vec![int64_chunk(&[1])]);

        let mut project = ProjectOperator::new(
            scan,
            vec![
                ProjectExpr::Constant(Value::Int64(42)),
                ProjectExpr::Constant(Value::String("fixed".into())),
                ProjectExpr::Constant(Value::Bool(true)),
            ],
            vec![LogicalType::Int64, LogicalType::String, LogicalType::Bool],
        );

        let result = project.next().unwrap().unwrap();
        assert_eq!(result.column_count(), 3);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(42));
        assert_eq!(result.column(1).unwrap().get_string(0), Some("fixed"));
        assert_eq!(
            result.column(2).unwrap().get_value(0),
            Some(Value::Bool(true))
        );
    }

    #[test]
    fn test_project_identity() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String]);
        builder.column_mut(0).unwrap().push_int64(10);
        builder.column_mut(1).unwrap().push_string("test");
        builder.advance_row();

        let mut project = ProjectOperator::select_columns(
            MockScanOperator::boxed(vec![builder.finish()]),
            vec![0, 1],
            vec![LogicalType::Int64, LogicalType::String],
        );

        let result = project.next().unwrap().unwrap();
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
        assert_eq!(result.column(1).unwrap().get_string(0), Some("test"));
    }

    #[test]
    fn test_project_name() {
        let project = ProjectOperator::select_columns(
            MockScanOperator::boxed(vec![]),
            vec![0],
            vec![LogicalType::Int64],
        );
        assert_eq!(project.name(), "Project");
    }
}