grafeo_core/execution/operators/
project.rs1use super::filter::{ExpressionPredicate, FilterExpression};
4use super::{Operator, OperatorError, OperatorResult};
5use crate::execution::DataChunk;
6use crate::graph::lpg::LpgStore;
7use grafeo_common::types::{LogicalType, Value};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub enum ProjectExpr {
13 Column(usize),
15 Constant(Value),
17 PropertyAccess {
19 column: usize,
21 property: String,
23 },
24 EdgeType {
26 column: usize,
28 },
29 Expression {
31 expr: FilterExpression,
33 variable_columns: HashMap<String, usize>,
35 },
36}
37
38pub struct ProjectOperator {
40 child: Box<dyn Operator>,
42 projections: Vec<ProjectExpr>,
44 output_types: Vec<LogicalType>,
46 store: Option<Arc<LpgStore>>,
48}
49
50impl ProjectOperator {
51 pub fn new(
53 child: Box<dyn Operator>,
54 projections: Vec<ProjectExpr>,
55 output_types: Vec<LogicalType>,
56 ) -> Self {
57 assert_eq!(projections.len(), output_types.len());
58 Self {
59 child,
60 projections,
61 output_types,
62 store: None,
63 }
64 }
65
66 pub fn with_store(
68 child: Box<dyn Operator>,
69 projections: Vec<ProjectExpr>,
70 output_types: Vec<LogicalType>,
71 store: Arc<LpgStore>,
72 ) -> Self {
73 assert_eq!(projections.len(), output_types.len());
74 Self {
75 child,
76 projections,
77 output_types,
78 store: Some(store),
79 }
80 }
81
82 pub fn select_columns(
84 child: Box<dyn Operator>,
85 columns: Vec<usize>,
86 types: Vec<LogicalType>,
87 ) -> Self {
88 let projections = columns.into_iter().map(ProjectExpr::Column).collect();
89 Self::new(child, projections, types)
90 }
91}
92
93impl Operator for ProjectOperator {
94 fn next(&mut self) -> OperatorResult {
95 let Some(input) = self.child.next()? else {
97 return Ok(None);
98 };
99
100 let mut output = DataChunk::with_capacity(&self.output_types, input.row_count());
102
103 for (i, proj) in self.projections.iter().enumerate() {
105 match proj {
106 ProjectExpr::Column(col_idx) => {
107 let input_col = input.column(*col_idx).ok_or_else(|| {
109 OperatorError::ColumnNotFound(format!("Column {col_idx}"))
110 })?;
111
112 let output_col = output.column_mut(i).unwrap();
113
114 for row in input.selected_indices() {
116 if let Some(value) = input_col.get_value(row) {
117 output_col.push_value(value);
118 }
119 }
120 }
121 ProjectExpr::Constant(value) => {
122 let output_col = output.column_mut(i).unwrap();
124 for _ in input.selected_indices() {
125 output_col.push_value(value.clone());
126 }
127 }
128 ProjectExpr::PropertyAccess { column, property } => {
129 let input_col = input
131 .column(*column)
132 .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
133
134 let output_col = output.column_mut(i).unwrap();
135
136 let store = self.store.as_ref().ok_or_else(|| {
137 OperatorError::Execution("Store required for property access".to_string())
138 })?;
139
140 for row in input.selected_indices() {
142 let value = if let Some(node_id) = input_col.get_node_id(row) {
144 store
145 .get_node(node_id)
146 .and_then(|node| node.get_property(property).cloned())
147 .unwrap_or(Value::Null)
148 } else if let Some(edge_id) = input_col.get_edge_id(row) {
149 store
150 .get_edge(edge_id)
151 .and_then(|edge| edge.get_property(property).cloned())
152 .unwrap_or(Value::Null)
153 } else {
154 Value::Null
155 };
156 output_col.push_value(value);
157 }
158 }
159 ProjectExpr::EdgeType { column } => {
160 let input_col = input
162 .column(*column)
163 .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
164
165 let output_col = output.column_mut(i).unwrap();
166
167 let store = self.store.as_ref().ok_or_else(|| {
168 OperatorError::Execution("Store required for edge type access".to_string())
169 })?;
170
171 for row in input.selected_indices() {
172 let value = if let Some(edge_id) = input_col.get_edge_id(row) {
173 store.edge_type(edge_id).map_or(Value::Null, Value::String)
174 } else {
175 Value::Null
176 };
177 output_col.push_value(value);
178 }
179 }
180 ProjectExpr::Expression {
181 expr,
182 variable_columns,
183 } => {
184 let output_col = output.column_mut(i).unwrap();
185
186 let store = self.store.as_ref().ok_or_else(|| {
187 OperatorError::Execution(
188 "Store required for expression evaluation".to_string(),
189 )
190 })?;
191
192 let evaluator = ExpressionPredicate::new(
194 expr.clone(),
195 variable_columns.clone(),
196 Arc::clone(store),
197 );
198
199 for row in input.selected_indices() {
200 let value = evaluator.eval_at(&input, row).unwrap_or(Value::Null);
201 output_col.push_value(value);
202 }
203 }
204 }
205 }
206
207 output.set_count(input.row_count());
208 Ok(Some(output))
209 }
210
211 fn reset(&mut self) {
212 self.child.reset();
213 }
214
215 fn name(&self) -> &'static str {
216 "Project"
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;
    use grafeo_common::types::Value;

    /// Test double that yields a fixed list of chunks, then `None`.
    ///
    /// Each chunk is moved out (replaced by `DataChunk::empty()`) when yielded,
    /// so after `reset` the mock replays empty chunks rather than the originals.
    struct MockScanOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.position < self.chunks.len() {
                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
                self.position += 1;
                Ok(Some(chunk))
            } else {
                Ok(None)
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    /// Builds a mock scan that yields one chunk with a single `Int64` column.
    fn int64_scan(values: &[i64]) -> MockScanOperator {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &value in values {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        MockScanOperator {
            chunks: vec![builder.finish()],
            position: 0,
        }
    }

    #[test]
    fn test_project_select_columns() {
        let mut builder =
            DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String, LogicalType::Int64]);

        builder.column_mut(0).unwrap().push_int64(1);
        builder.column_mut(1).unwrap().push_string("hello");
        builder.column_mut(2).unwrap().push_int64(100);
        builder.advance_row();

        builder.column_mut(0).unwrap().push_int64(2);
        builder.column_mut(1).unwrap().push_string("world");
        builder.column_mut(2).unwrap().push_int64(200);
        builder.advance_row();

        let chunk = builder.finish();

        let mock_scan = MockScanOperator {
            chunks: vec![chunk],
            position: 0,
        };

        let mut project = ProjectOperator::select_columns(
            Box::new(mock_scan),
            vec![2, 0],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);

        // Output column 0 is input column 2; output column 1 is input column 0.
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(100));
        assert_eq!(result.column(0).unwrap().get_int64(1), Some(200));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(1));
        assert_eq!(result.column(1).unwrap().get_int64(1), Some(2));

        // The child is exhausted after its single chunk.
        assert!(project.next().unwrap().is_none());
    }

    #[test]
    fn test_project_constant() {
        let mut project = ProjectOperator::new(
            Box::new(int64_scan(&[1, 2])),
            vec![
                ProjectExpr::Column(0),
                ProjectExpr::Constant(Value::String("constant".into())),
            ],
            vec![LogicalType::Int64, LogicalType::String],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(1));
        assert_eq!(result.column(0).unwrap().get_int64(1), Some(2));
        assert_eq!(result.column(1).unwrap().get_string(0), Some("constant"));
        assert_eq!(result.column(1).unwrap().get_string(1), Some("constant"));
    }

    #[test]
    fn test_project_missing_column_is_error() {
        let mut project = ProjectOperator::select_columns(
            Box::new(int64_scan(&[1])),
            vec![3],
            vec![LogicalType::Int64],
        );

        assert!(matches!(
            project.next(),
            Err(OperatorError::ColumnNotFound(_))
        ));
    }

    #[test]
    fn test_property_access_without_store_is_error() {
        let mut project = ProjectOperator::new(
            Box::new(int64_scan(&[1])),
            vec![ProjectExpr::PropertyAccess {
                column: 0,
                property: "name".to_string(),
            }],
            vec![LogicalType::String],
        );

        assert!(matches!(project.next(), Err(OperatorError::Execution(_))));
    }

    #[test]
    fn test_edge_type_without_store_is_error() {
        let mut project = ProjectOperator::new(
            Box::new(int64_scan(&[1])),
            vec![ProjectExpr::EdgeType { column: 0 }],
            vec![LogicalType::String],
        );

        assert!(matches!(project.next(), Err(OperatorError::Execution(_))));
    }
}