grafeo_core/execution/operators/push/
project.rs1use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
6use crate::execution::vector::ValueVector;
7use grafeo_common::types::Value;
8
9pub trait ProjectExpression: Send + Sync {
11 fn evaluate(&self, chunk: &DataChunk, row: usize) -> Value;
13
14 fn evaluate_batch(&self, chunk: &DataChunk) -> ValueVector {
16 let mut result = ValueVector::new();
17 for i in chunk.selected_indices() {
18 result.push(self.evaluate(chunk, i));
19 }
20 result
21 }
22}
23
24pub struct ColumnExpr {
26 column: usize,
27}
28
29impl ColumnExpr {
30 pub fn new(column: usize) -> Self {
32 Self { column }
33 }
34}
35
36impl ProjectExpression for ColumnExpr {
37 fn evaluate(&self, chunk: &DataChunk, row: usize) -> Value {
38 chunk
39 .column(self.column)
40 .and_then(|c| c.get_value(row))
41 .unwrap_or(Value::Null)
42 }
43}
44
45pub struct ConstantExpr {
47 value: Value,
48}
49
50impl ConstantExpr {
51 pub fn new(value: Value) -> Self {
53 Self { value }
54 }
55}
56
57impl ProjectExpression for ConstantExpr {
58 fn evaluate(&self, _chunk: &DataChunk, _row: usize) -> Value {
59 self.value.clone()
60 }
61}
62
/// Binary arithmetic operator applied by [`BinaryExpr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ArithOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Remainder (`%`).
    Mod,
}
78
79pub struct BinaryExpr {
81 left: Box<dyn ProjectExpression>,
82 right: Box<dyn ProjectExpression>,
83 op: ArithOp,
84}
85
86impl BinaryExpr {
87 pub fn new(
89 left: Box<dyn ProjectExpression>,
90 right: Box<dyn ProjectExpression>,
91 op: ArithOp,
92 ) -> Self {
93 Self { left, right, op }
94 }
95}
96
97impl ProjectExpression for BinaryExpr {
98 fn evaluate(&self, chunk: &DataChunk, row: usize) -> Value {
99 let left_val = self.left.evaluate(chunk, row);
100 let right_val = self.right.evaluate(chunk, row);
101
102 match (&left_val, &right_val) {
103 (Value::Int64(l), Value::Int64(r)) => match self.op {
104 ArithOp::Add => Value::Int64(l.wrapping_add(*r)),
105 ArithOp::Sub => Value::Int64(l.wrapping_sub(*r)),
106 ArithOp::Mul => Value::Int64(l.wrapping_mul(*r)),
107 ArithOp::Div => {
108 if *r == 0 {
109 Value::Null
110 } else {
111 Value::Int64(l / r)
112 }
113 }
114 ArithOp::Mod => {
115 if *r == 0 {
116 Value::Null
117 } else {
118 Value::Int64(l % r)
119 }
120 }
121 },
122 (Value::Float64(l), Value::Float64(r)) => match self.op {
123 ArithOp::Add => Value::Float64(l + r),
124 ArithOp::Sub => Value::Float64(l - r),
125 ArithOp::Mul => Value::Float64(l * r),
126 ArithOp::Div => Value::Float64(l / r),
127 ArithOp::Mod => Value::Float64(l % r),
128 },
129 _ => Value::Null,
130 }
131 }
132}
133
134pub struct ProjectPushOperator {
138 expressions: Vec<Box<dyn ProjectExpression>>,
139}
140
141impl ProjectPushOperator {
142 pub fn new(expressions: Vec<Box<dyn ProjectExpression>>) -> Self {
144 Self { expressions }
145 }
146
147 pub fn select_columns(columns: &[usize]) -> Self {
149 let expressions: Vec<Box<dyn ProjectExpression>> = columns
150 .iter()
151 .map(|&c| Box::new(ColumnExpr::new(c)) as Box<dyn ProjectExpression>)
152 .collect();
153 Self { expressions }
154 }
155}
156
157impl PushOperator for ProjectPushOperator {
158 fn push(&mut self, chunk: DataChunk, sink: &mut dyn Sink) -> Result<bool, OperatorError> {
159 if chunk.is_empty() {
160 return Ok(true);
161 }
162
163 let columns: Vec<ValueVector> = self
165 .expressions
166 .iter()
167 .map(|expr| expr.evaluate_batch(&chunk))
168 .collect();
169
170 let projected = DataChunk::new(columns);
171
172 sink.consume(projected)
174 }
175
176 fn finalize(&mut self, _sink: &mut dyn Sink) -> Result<(), OperatorError> {
177 Ok(())
179 }
180
181 fn preferred_chunk_size(&self) -> ChunkSizeHint {
182 ChunkSizeHint::Default
183 }
184
185 fn name(&self) -> &'static str {
186 "ProjectPush"
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::sink::CollectorSink;

    /// Builds a two-column chunk of `Int64` values.
    fn create_test_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
        let vec1 = ValueVector::from_values(&v1);
        let vec2 = ValueVector::from_values(&v2);
        DataChunk::new(vec![vec1, vec2])
    }

    #[test]
    fn test_project_select_columns() {
        // Swap the two input columns.
        let mut project = ProjectPushOperator::select_columns(&[1, 0]);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[1, 2, 3], &[10, 20, 30]);
        project.push(chunk, &mut sink).unwrap();
        project.finalize(&mut sink).unwrap();

        assert_eq!(sink.row_count(), 3);
        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);

        let first = chunks[0].column(0).unwrap();
        let second = chunks[0].column(1).unwrap();
        for (row, &(a, b)) in [(10, 1), (20, 2), (30, 3)].iter().enumerate() {
            assert_eq!(first.get_value(row), Some(Value::Int64(a)));
            assert_eq!(second.get_value(row), Some(Value::Int64(b)));
        }
    }

    #[test]
    fn test_project_constant() {
        let expressions: Vec<Box<dyn ProjectExpression>> =
            vec![Box::new(ConstantExpr::new(Value::Int64(42)))];
        let mut project = ProjectPushOperator::new(expressions);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[1, 2, 3], &[10, 20, 30]);
        project.push(chunk, &mut sink).unwrap();
        project.finalize(&mut sink).unwrap();

        assert_eq!(sink.row_count(), 3);
        let chunks = sink.into_chunks();
        let col = chunks[0].column(0).unwrap();
        assert_eq!(col.get_value(0), Some(Value::Int64(42)));
        assert_eq!(col.get_value(1), Some(Value::Int64(42)));
        assert_eq!(col.get_value(2), Some(Value::Int64(42)));
    }

    #[test]
    fn test_project_arithmetic() {
        let expressions: Vec<Box<dyn ProjectExpression>> = vec![Box::new(BinaryExpr::new(
            Box::new(ColumnExpr::new(0)),
            Box::new(ColumnExpr::new(1)),
            ArithOp::Add,
        ))];
        let mut project = ProjectPushOperator::new(expressions);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[1, 2, 3], &[10, 20, 30]);
        project.push(chunk, &mut sink).unwrap();
        project.finalize(&mut sink).unwrap();

        assert_eq!(sink.row_count(), 3);
        let chunks = sink.into_chunks();
        let col = chunks[0].column(0).unwrap();
        assert_eq!(col.get_value(0), Some(Value::Int64(11)));
        assert_eq!(col.get_value(1), Some(Value::Int64(22)));
        assert_eq!(col.get_value(2), Some(Value::Int64(33)));
    }

    #[test]
    fn test_project_integer_division_by_zero_is_null() {
        let expressions: Vec<Box<dyn ProjectExpression>> = vec![Box::new(BinaryExpr::new(
            Box::new(ColumnExpr::new(0)),
            Box::new(ColumnExpr::new(1)),
            ArithOp::Div,
        ))];
        let mut project = ProjectPushOperator::new(expressions);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[7, 10], &[2, 0]);
        project.push(chunk, &mut sink).unwrap();
        project.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        let col = chunks[0].column(0).unwrap();
        assert_eq!(col.get_value(0), Some(Value::Int64(3)));
        // Accept either representation of a NULL slot the vector may report.
        assert!(matches!(col.get_value(1), None | Some(Value::Null)));
    }
}