grafeo_core/execution/operators/
project.rs1use super::filter::{ExpressionPredicate, FilterExpression};
4use super::{Operator, OperatorError, OperatorResult};
5use crate::execution::DataChunk;
6use crate::graph::lpg::LpgStore;
7use grafeo_common::types::{LogicalType, Value};
8use std::collections::HashMap;
9use std::sync::Arc;
10
11pub enum ProjectExpr {
13 Column(usize),
15 Constant(Value),
17 PropertyAccess {
19 column: usize,
21 property: String,
23 },
24 EdgeType {
26 column: usize,
28 },
29 Expression {
31 expr: FilterExpression,
33 variable_columns: HashMap<String, usize>,
35 },
36}
37
38pub struct ProjectOperator {
40 child: Box<dyn Operator>,
42 projections: Vec<ProjectExpr>,
44 output_types: Vec<LogicalType>,
46 store: Option<Arc<LpgStore>>,
48}
49
50impl ProjectOperator {
51 pub fn new(
53 child: Box<dyn Operator>,
54 projections: Vec<ProjectExpr>,
55 output_types: Vec<LogicalType>,
56 ) -> Self {
57 assert_eq!(projections.len(), output_types.len());
58 Self {
59 child,
60 projections,
61 output_types,
62 store: None,
63 }
64 }
65
66 pub fn with_store(
68 child: Box<dyn Operator>,
69 projections: Vec<ProjectExpr>,
70 output_types: Vec<LogicalType>,
71 store: Arc<LpgStore>,
72 ) -> Self {
73 assert_eq!(projections.len(), output_types.len());
74 Self {
75 child,
76 projections,
77 output_types,
78 store: Some(store),
79 }
80 }
81
82 pub fn select_columns(
84 child: Box<dyn Operator>,
85 columns: Vec<usize>,
86 types: Vec<LogicalType>,
87 ) -> Self {
88 let projections = columns.into_iter().map(ProjectExpr::Column).collect();
89 Self::new(child, projections, types)
90 }
91}
92
93impl Operator for ProjectOperator {
94 fn next(&mut self) -> OperatorResult {
95 let input = match self.child.next()? {
97 Some(c) => c,
98 None => return Ok(None),
99 };
100
101 let mut output = DataChunk::with_capacity(&self.output_types, input.row_count());
103
104 for (i, proj) in self.projections.iter().enumerate() {
106 match proj {
107 ProjectExpr::Column(col_idx) => {
108 let input_col = input.column(*col_idx).ok_or_else(|| {
110 OperatorError::ColumnNotFound(format!("Column {col_idx}"))
111 })?;
112
113 let output_col = output.column_mut(i).unwrap();
114
115 for row in input.selected_indices() {
117 if let Some(value) = input_col.get_value(row) {
118 output_col.push_value(value);
119 }
120 }
121 }
122 ProjectExpr::Constant(value) => {
123 let output_col = output.column_mut(i).unwrap();
125 for _ in input.selected_indices() {
126 output_col.push_value(value.clone());
127 }
128 }
129 ProjectExpr::PropertyAccess { column, property } => {
130 let input_col = input
132 .column(*column)
133 .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
134
135 let output_col = output.column_mut(i).unwrap();
136
137 let store = self.store.as_ref().ok_or_else(|| {
138 OperatorError::Execution("Store required for property access".to_string())
139 })?;
140
141 for row in input.selected_indices() {
143 let value = if let Some(node_id) = input_col.get_node_id(row) {
145 store
146 .get_node(node_id)
147 .and_then(|node| node.get_property(property).cloned())
148 .unwrap_or(Value::Null)
149 } else if let Some(edge_id) = input_col.get_edge_id(row) {
150 store
151 .get_edge(edge_id)
152 .and_then(|edge| edge.get_property(property).cloned())
153 .unwrap_or(Value::Null)
154 } else {
155 Value::Null
156 };
157 output_col.push_value(value);
158 }
159 }
160 ProjectExpr::EdgeType { column } => {
161 let input_col = input
163 .column(*column)
164 .ok_or_else(|| OperatorError::ColumnNotFound(format!("Column {column}")))?;
165
166 let output_col = output.column_mut(i).unwrap();
167
168 let store = self.store.as_ref().ok_or_else(|| {
169 OperatorError::Execution("Store required for edge type access".to_string())
170 })?;
171
172 for row in input.selected_indices() {
173 let value = if let Some(edge_id) = input_col.get_edge_id(row) {
174 store
175 .edge_type(edge_id)
176 .map(Value::String)
177 .unwrap_or(Value::Null)
178 } else {
179 Value::Null
180 };
181 output_col.push_value(value);
182 }
183 }
184 ProjectExpr::Expression {
185 expr,
186 variable_columns,
187 } => {
188 let output_col = output.column_mut(i).unwrap();
189
190 let store = self.store.as_ref().ok_or_else(|| {
191 OperatorError::Execution(
192 "Store required for expression evaluation".to_string(),
193 )
194 })?;
195
196 let evaluator = ExpressionPredicate::new(
198 expr.clone(),
199 variable_columns.clone(),
200 Arc::clone(store),
201 );
202
203 for row in input.selected_indices() {
204 let value = evaluator.eval_at(&input, row).unwrap_or(Value::Null);
205 output_col.push_value(value);
206 }
207 }
208 }
209 }
210
211 output.set_count(input.row_count());
212 Ok(Some(output))
213 }
214
215 fn reset(&mut self) {
216 self.child.reset();
217 }
218
219 fn name(&self) -> &'static str {
220 "Project"
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;
    use grafeo_common::types::Value;

    /// Test double that yields a fixed list of chunks in order, then `None`.
    ///
    /// Each chunk is moved out when returned (leaving `DataChunk::empty()` in
    /// its slot), so after `reset` the mock replays empty chunks rather than
    /// the original data.
    struct MockScanOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl Operator for MockScanOperator {
        fn next(&mut self) -> OperatorResult {
            if self.position < self.chunks.len() {
                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
                self.position += 1;
                Ok(Some(chunk))
            } else {
                Ok(None)
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "MockScan"
        }
    }

    #[test]
    fn test_project_select_columns() {
        let mut builder =
            DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String, LogicalType::Int64]);

        builder.column_mut(0).unwrap().push_int64(1);
        builder.column_mut(1).unwrap().push_string("hello");
        builder.column_mut(2).unwrap().push_int64(100);
        builder.advance_row();

        builder.column_mut(0).unwrap().push_int64(2);
        builder.column_mut(1).unwrap().push_string("world");
        builder.column_mut(2).unwrap().push_int64(200);
        builder.advance_row();

        let chunk = builder.finish();

        let mock_scan = MockScanOperator {
            chunks: vec![chunk],
            position: 0,
        };

        let mut project = ProjectOperator::select_columns(
            Box::new(mock_scan),
            vec![2, 0],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);

        // Output column 0 is input column 2; output column 1 is input column 0.
        let first = result.column(0).unwrap();
        assert_eq!(first.get_int64(0), Some(100));
        assert_eq!(first.get_int64(1), Some(200));

        let second = result.column(1).unwrap();
        assert_eq!(second.get_int64(0), Some(1));
        assert_eq!(second.get_int64(1), Some(2));

        // The child had a single chunk, so the projection is now exhausted.
        assert!(project.next().unwrap().is_none());
    }

    #[test]
    fn test_project_constant() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        builder.column_mut(0).unwrap().push_int64(1);
        builder.advance_row();
        builder.column_mut(0).unwrap().push_int64(2);
        builder.advance_row();

        let chunk = builder.finish();

        let mock_scan = MockScanOperator {
            chunks: vec![chunk],
            position: 0,
        };

        let mut project = ProjectOperator::new(
            Box::new(mock_scan),
            vec![
                ProjectExpr::Column(0),
                ProjectExpr::Constant(Value::String("constant".into())),
            ],
            vec![LogicalType::Int64, LogicalType::String],
        );

        let result = project.next().unwrap().unwrap();

        assert_eq!(result.column_count(), 2);
        assert_eq!(result.row_count(), 2);

        // The pass-through column keeps the input values.
        let passthrough = result.column(0).unwrap();
        assert_eq!(passthrough.get_int64(0), Some(1));
        assert_eq!(passthrough.get_int64(1), Some(2));

        // The constant is repeated once per input row.
        let constant = result.column(1).unwrap();
        assert_eq!(constant.get_string(0), Some("constant"));
        assert_eq!(constant.get_string(1), Some("constant"));

        assert!(project.next().unwrap().is_none());
    }
}