grafeo_core/execution/operators/push/
project.rs1use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
6use crate::execution::vector::ValueVector;
7use grafeo_common::types::Value;
8
9pub trait ProjectExpression: Send + Sync {
11 fn evaluate(&self, chunk: &DataChunk, row: usize) -> Value;
13
14 fn evaluate_batch(&self, chunk: &DataChunk) -> ValueVector {
16 let mut result = ValueVector::new();
17 for i in chunk.selected_indices() {
18 result.push(self.evaluate(chunk, i));
19 }
20 result
21 }
22}
23
/// Projection expression that reads a single input column by position.
pub struct ColumnExpr {
    /// Zero-based index of the input column to read.
    column: usize,
}
28
29impl ColumnExpr {
30 pub fn new(column: usize) -> Self {
32 Self { column }
33 }
34}
35
36impl ProjectExpression for ColumnExpr {
37 fn evaluate(&self, chunk: &DataChunk, row: usize) -> Value {
38 chunk
39 .column(self.column)
40 .and_then(|c| c.get_value(row))
41 .unwrap_or(Value::Null)
42 }
43}
44
45pub struct ConstantExpr {
47 value: Value,
48}
49
50impl ConstantExpr {
51 pub fn new(value: Value) -> Self {
53 Self { value }
54 }
55}
56
57impl ProjectExpression for ConstantExpr {
58 fn evaluate(&self, _chunk: &DataChunk, _row: usize) -> Value {
59 self.value.clone()
60 }
61}
62
/// Binary arithmetic operator applied by [`BinaryExpr`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so operators can be compared and used
/// as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArithOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Remainder (`%`).
    Mod,
}
77
78pub struct BinaryExpr {
80 left: Box<dyn ProjectExpression>,
81 right: Box<dyn ProjectExpression>,
82 op: ArithOp,
83}
84
85impl BinaryExpr {
86 pub fn new(
88 left: Box<dyn ProjectExpression>,
89 right: Box<dyn ProjectExpression>,
90 op: ArithOp,
91 ) -> Self {
92 Self { left, right, op }
93 }
94}
95
96impl ProjectExpression for BinaryExpr {
97 fn evaluate(&self, chunk: &DataChunk, row: usize) -> Value {
98 let left_val = self.left.evaluate(chunk, row);
99 let right_val = self.right.evaluate(chunk, row);
100
101 match (&left_val, &right_val) {
102 (Value::Int64(l), Value::Int64(r)) => match self.op {
103 ArithOp::Add => Value::Int64(l.wrapping_add(*r)),
104 ArithOp::Sub => Value::Int64(l.wrapping_sub(*r)),
105 ArithOp::Mul => Value::Int64(l.wrapping_mul(*r)),
106 ArithOp::Div => {
107 if *r == 0 {
108 Value::Null
109 } else {
110 Value::Int64(l / r)
111 }
112 }
113 ArithOp::Mod => {
114 if *r == 0 {
115 Value::Null
116 } else {
117 Value::Int64(l % r)
118 }
119 }
120 },
121 (Value::Float64(l), Value::Float64(r)) => match self.op {
122 ArithOp::Add => Value::Float64(l + r),
123 ArithOp::Sub => Value::Float64(l - r),
124 ArithOp::Mul => Value::Float64(l * r),
125 ArithOp::Div => Value::Float64(l / r),
126 ArithOp::Mod => Value::Float64(l % r),
127 },
128 _ => Value::Null,
129 }
130 }
131}
132
133pub struct ProjectPushOperator {
137 expressions: Vec<Box<dyn ProjectExpression>>,
138}
139
140impl ProjectPushOperator {
141 pub fn new(expressions: Vec<Box<dyn ProjectExpression>>) -> Self {
143 Self { expressions }
144 }
145
146 pub fn select_columns(columns: &[usize]) -> Self {
148 let expressions: Vec<Box<dyn ProjectExpression>> = columns
149 .iter()
150 .map(|&c| Box::new(ColumnExpr::new(c)) as Box<dyn ProjectExpression>)
151 .collect();
152 Self { expressions }
153 }
154}
155
156impl PushOperator for ProjectPushOperator {
157 fn push(&mut self, chunk: DataChunk, sink: &mut dyn Sink) -> Result<bool, OperatorError> {
158 if chunk.is_empty() {
159 return Ok(true);
160 }
161
162 let columns: Vec<ValueVector> = self
164 .expressions
165 .iter()
166 .map(|expr| expr.evaluate_batch(&chunk))
167 .collect();
168
169 let projected = DataChunk::new(columns);
170
171 sink.consume(projected)
173 }
174
175 fn finalize(&mut self, _sink: &mut dyn Sink) -> Result<(), OperatorError> {
176 Ok(())
178 }
179
180 fn preferred_chunk_size(&self) -> ChunkSizeHint {
181 ChunkSizeHint::Default
182 }
183
184 fn name(&self) -> &'static str {
185 "ProjectPush"
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::sink::CollectorSink;

    /// Builds a two-column chunk of `Int64` values.
    fn create_test_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
        let vec1 = ValueVector::from_values(&v1);
        let vec2 = ValueVector::from_values(&v2);
        DataChunk::new(vec![vec1, vec2])
    }

    #[test]
    fn test_project_select_columns() {
        let mut project = ProjectPushOperator::select_columns(&[1, 0]);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[1, 2, 3], &[10, 20, 30]);
        project.push(chunk, &mut sink).unwrap();
        project.finalize(&mut sink).unwrap();

        assert_eq!(sink.row_count(), 3);
        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);

        // Output columns are swapped relative to the input: verify every cell.
        let first = chunks[0].column(0).unwrap();
        let second = chunks[0].column(1).unwrap();
        let expected: [(i64, i64); 3] = [(10, 1), (20, 2), (30, 3)];
        for (row, &(a, b)) in expected.iter().enumerate() {
            assert_eq!(first.get_value(row), Some(Value::Int64(a)));
            assert_eq!(second.get_value(row), Some(Value::Int64(b)));
        }
    }

    #[test]
    fn test_project_constant() {
        let expressions: Vec<Box<dyn ProjectExpression>> =
            vec![Box::new(ConstantExpr::new(Value::Int64(42)))];
        let mut project = ProjectPushOperator::new(expressions);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[1, 2, 3], &[10, 20, 30]);
        project.push(chunk, &mut sink).unwrap();
        project.finalize(&mut sink).unwrap();

        assert_eq!(sink.row_count(), 3);
        let chunks = sink.into_chunks();
        let col = chunks[0].column(0).unwrap();
        assert_eq!(col.get_value(0), Some(Value::Int64(42)));
        assert_eq!(col.get_value(1), Some(Value::Int64(42)));
        assert_eq!(col.get_value(2), Some(Value::Int64(42)));
    }

    #[test]
    fn test_project_arithmetic() {
        let expressions: Vec<Box<dyn ProjectExpression>> = vec![Box::new(BinaryExpr::new(
            Box::new(ColumnExpr::new(0)),
            Box::new(ColumnExpr::new(1)),
            ArithOp::Add,
        ))];
        let mut project = ProjectPushOperator::new(expressions);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[1, 2, 3], &[10, 20, 30]);
        project.push(chunk, &mut sink).unwrap();
        project.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        let col = chunks[0].column(0).unwrap();
        assert_eq!(col.get_value(0), Some(Value::Int64(11)));
        assert_eq!(col.get_value(1), Some(Value::Int64(22)));
        assert_eq!(col.get_value(2), Some(Value::Int64(33)));
    }

    #[test]
    fn test_binary_div_and_mod_by_zero_yield_null() {
        let chunk = create_test_chunk(&[10, 7], &[0, 2]);
        let div = BinaryExpr::new(
            Box::new(ColumnExpr::new(0)),
            Box::new(ColumnExpr::new(1)),
            ArithOp::Div,
        );
        let rem = BinaryExpr::new(
            Box::new(ColumnExpr::new(0)),
            Box::new(ColumnExpr::new(1)),
            ArithOp::Mod,
        );

        assert_eq!(div.evaluate(&chunk, 0), Value::Null);
        assert_eq!(div.evaluate(&chunk, 1), Value::Int64(3));
        assert_eq!(rem.evaluate(&chunk, 0), Value::Null);
        assert_eq!(rem.evaluate(&chunk, 1), Value::Int64(1));
    }

    #[test]
    fn test_binary_mismatched_types_yield_null() {
        let chunk = create_test_chunk(&[1], &[2]);
        let expr = BinaryExpr::new(
            Box::new(ColumnExpr::new(0)),
            Box::new(ConstantExpr::new(Value::Float64(1.5))),
            ArithOp::Add,
        );
        assert_eq!(expr.evaluate(&chunk, 0), Value::Null);
    }

    #[test]
    fn test_column_expr_missing_column_is_null() {
        let chunk = create_test_chunk(&[1], &[2]);
        assert_eq!(ColumnExpr::new(5).evaluate(&chunk, 0), Value::Null);
    }

    #[test]
    fn test_project_empty_chunk_is_skipped() {
        let mut project = ProjectPushOperator::select_columns(&[0]);
        let mut sink = CollectorSink::new();

        let chunk = create_test_chunk(&[], &[]);
        assert!(project.push(chunk, &mut sink).unwrap());
        assert_eq!(sink.row_count(), 0);
    }
}