Skip to main content

grafeo_core/execution/operators/
union.rs

1//! Union operator for combining multiple result sets.
2//!
3//! The union operator concatenates results from multiple input operators,
4//! producing all rows from each input in sequence.
5
6use grafeo_common::types::LogicalType;
7
8use super::{Operator, OperatorResult};
9
10/// Union operator that combines results from multiple inputs.
11///
12/// This produces all rows from all inputs, in order. It does not
13/// remove duplicates (use DISTINCT after UNION for UNION DISTINCT).
14pub struct UnionOperator {
15    /// Input operators.
16    inputs: Vec<Box<dyn Operator>>,
17    /// Current input index.
18    current_input: usize,
19    /// Output schema.
20    output_schema: Vec<LogicalType>,
21}
22
23impl UnionOperator {
24    /// Creates a new union operator.
25    ///
26    /// # Arguments
27    /// * `inputs` - The input operators to union.
28    /// * `output_schema` - The schema of the output (should match all inputs).
29    pub fn new(inputs: Vec<Box<dyn Operator>>, output_schema: Vec<LogicalType>) -> Self {
30        Self {
31            inputs,
32            current_input: 0,
33            output_schema,
34        }
35    }
36
37    /// Returns the output schema.
38    #[must_use]
39    pub fn output_schema(&self) -> &[LogicalType] {
40        &self.output_schema
41    }
42}
43
44impl Operator for UnionOperator {
45    fn next(&mut self) -> OperatorResult {
46        // Process inputs in order
47        while self.current_input < self.inputs.len() {
48            if let Some(chunk) = self.inputs[self.current_input].next()? {
49                return Ok(Some(chunk));
50            }
51            // Move to next input when current is exhausted
52            self.current_input += 1;
53        }
54
55        Ok(None)
56    }
57
58    fn reset(&mut self) {
59        for input in &mut self.inputs {
60            input.reset();
61        }
62        self.current_input = 0;
63    }
64
65    fn name(&self) -> &'static str {
66        "Union"
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::DataChunk;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test double that hands out a fixed list of chunks, one per call.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            let position = 0;
            Self { chunks, position }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    // Move the chunk out, leaving an empty placeholder behind.
                    let chunk = std::mem::replace(slot, DataChunk::empty());
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds a single-column Int64 chunk holding `values` in order.
    fn create_int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        values.iter().copied().for_each(|value| {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        });
        builder.finish()
    }

    /// Wraps each chunk list in a `MockOperator` and unions them with a
    /// single Int64 output column.
    fn union_over(inputs: Vec<Vec<DataChunk>>) -> UnionOperator {
        let sources = inputs
            .into_iter()
            .map(|chunks| Box::new(MockOperator::new(chunks)) as Box<dyn Operator>)
            .collect();
        UnionOperator::new(sources, vec![LogicalType::Int64])
    }

    /// Drains `op` and gathers column 0 of every selected row.
    fn collect_int64(op: &mut UnionOperator) -> Vec<i64> {
        let mut values = Vec::new();
        for chunk in std::iter::from_fn(|| op.next().unwrap()) {
            for row in chunk.selected_indices() {
                values.push(chunk.column(0).unwrap().get_int64(row).unwrap());
            }
        }
        values
    }

    /// Drains `op` and reports how many chunks it produced.
    fn count_chunks(op: &mut UnionOperator) -> usize {
        std::iter::from_fn(|| op.next().unwrap()).count()
    }

    #[test]
    fn test_union_two_inputs() {
        let mut op = union_over(vec![
            vec![create_int_chunk(&[1, 2])],
            vec![create_int_chunk(&[3, 4])],
        ]);

        assert_eq!(collect_int64(&mut op), vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_union_three_inputs() {
        let mut op = union_over(vec![
            vec![create_int_chunk(&[1])],
            vec![create_int_chunk(&[2])],
            vec![create_int_chunk(&[3])],
        ]);

        assert_eq!(collect_int64(&mut op), vec![1, 2, 3]);
    }

    #[test]
    fn test_union_empty_input() {
        // The middle input yields nothing and must simply be skipped.
        let mut op = union_over(vec![
            vec![create_int_chunk(&[1, 2])],
            vec![],
            vec![create_int_chunk(&[3])],
        ]);

        assert_eq!(collect_int64(&mut op), vec![1, 2, 3]);
    }

    #[test]
    fn test_union_reset() {
        let mut op = union_over(vec![
            vec![create_int_chunk(&[1])],
            vec![create_int_chunk(&[2])],
        ]);

        // First pass: one chunk per input.
        assert_eq!(count_chunks(&mut op), 2);

        // After a reset the mocks replay their slots, which now hold the
        // empty placeholders left by `next`, so only the count is checked.
        op.reset();
        assert_eq!(count_chunks(&mut op), 2);
    }
}