use grafeo_common::types::LogicalType;
use super::{Operator, OperatorResult};
/// Operator that concatenates the output of several child operators.
///
/// Chunks are forwarded unchanged: every chunk of the first input, then every
/// chunk of the second, and so on. No deduplication is performed, so the
/// result has bag semantics (like SQL `UNION ALL`).
pub struct UnionOperator {
    /// Child operators, drained strictly in vector order.
    inputs: Vec<Box<dyn Operator>>,
    /// Index into `inputs` of the child currently being drained; reaches
    /// `inputs.len()` once every child has reported exhaustion.
    current_input: usize,
    /// Column types reported for this operator's output. Stored as given;
    /// nothing here checks it against the chunks the inputs produce.
    output_schema: Vec<LogicalType>,
}
impl UnionOperator {
    /// Creates a union that drains `inputs` one after another, in vector order.
    ///
    /// `output_schema` is stored as-is and returned by
    /// [`output_schema`](UnionOperator::output_schema). It is not checked against
    /// the chunks the inputs actually produce. An empty `inputs` vector is
    /// allowed; the resulting operator simply yields no chunks.
    #[must_use]
    pub fn new(inputs: Vec<Box<dyn Operator>>, output_schema: Vec<LogicalType>) -> Self {
        Self {
            inputs,
            current_input: 0,
            output_schema,
        }
    }

    /// Returns the column types this operator was constructed with.
    #[must_use]
    pub fn output_schema(&self) -> &[LogicalType] {
        &self.output_schema
    }
}
impl Operator for UnionOperator {
    /// Yields the next chunk from the first input that still has data.
    ///
    /// Inputs are visited strictly in order. Once an input returns `None`,
    /// the cursor moves past it, and it is not polled again until `reset`.
    /// An error from an input is propagated as-is and leaves the cursor
    /// pointing at that input.
    fn next(&mut self) -> OperatorResult {
        while let Some(child) = self.inputs.get_mut(self.current_input) {
            match child.next()? {
                Some(chunk) => return Ok(Some(chunk)),
                None => self.current_input += 1,
            }
        }
        Ok(None)
    }

    /// Resets every child operator and rewinds the cursor to the first input.
    fn reset(&mut self) {
        self.inputs.iter_mut().for_each(|child| child.reset());
        self.current_input = 0;
    }

    /// Static display name of this operator.
    fn name(&self) -> &'static str {
        "Union"
    }

    /// Type-erases the boxed operator so callers can downcast it back to
    /// `UnionOperator`.
    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::DataChunk;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test double that yields one single-column `Int64` chunk per entry of
    /// `batches`, in order.
    ///
    /// Chunks are rebuilt from the stored values on every call instead of being
    /// moved out of a buffer. After `reset`, the mock therefore replays the same
    /// data rather than empty placeholder chunks.
    struct MockOperator {
        batches: Vec<Vec<i64>>,
        position: usize,
    }

    impl MockOperator {
        fn new(batches: Vec<Vec<i64>>) -> Self {
            Self {
                batches,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            match self.batches.get(self.position) {
                Some(values) => {
                    let chunk = create_int_chunk(values);
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }

        fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
            self
        }
    }

    /// Builds a single-column `Int64` chunk containing `values` in order.
    fn create_int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &v in values {
            builder.column_mut(0).unwrap().push_int64(v);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Drains `union` and returns every selected value from column 0, in order.
    fn collect_ints(union: &mut UnionOperator) -> Vec<i64> {
        let mut results = Vec::new();
        while let Some(chunk) = union.next().unwrap() {
            for row in chunk.selected_indices() {
                results.push(chunk.column(0).unwrap().get_int64(row).unwrap());
            }
        }
        results
    }

    #[test]
    fn test_union_two_inputs() {
        let input1 = MockOperator::new(vec![vec![1, 2]]);
        let input2 = MockOperator::new(vec![vec![3, 4]]);
        let mut union = UnionOperator::new(
            vec![Box::new(input1), Box::new(input2)],
            vec![LogicalType::Int64],
        );
        assert_eq!(collect_ints(&mut union), vec![1, 2, 3, 4]);
        // An exhausted union keeps reporting exhaustion.
        assert!(union.next().unwrap().is_none());
    }

    #[test]
    fn test_union_three_inputs() {
        let input1 = MockOperator::new(vec![vec![1]]);
        let input2 = MockOperator::new(vec![vec![2]]);
        let input3 = MockOperator::new(vec![vec![3]]);
        let mut union = UnionOperator::new(
            vec![Box::new(input1), Box::new(input2), Box::new(input3)],
            vec![LogicalType::Int64],
        );
        assert_eq!(collect_ints(&mut union), vec![1, 2, 3]);
    }

    #[test]
    fn test_union_empty_input() {
        let input1 = MockOperator::new(vec![vec![1, 2]]);
        let input2 = MockOperator::new(vec![]);
        let input3 = MockOperator::new(vec![vec![3]]);
        let mut union = UnionOperator::new(
            vec![Box::new(input1), Box::new(input2), Box::new(input3)],
            vec![LogicalType::Int64],
        );
        assert_eq!(collect_ints(&mut union), vec![1, 2, 3]);
    }

    #[test]
    fn test_union_no_inputs() {
        let mut union = UnionOperator::new(Vec::new(), vec![LogicalType::Int64]);
        assert!(union.next().unwrap().is_none());
    }

    #[test]
    fn test_union_reset() {
        let input1 = MockOperator::new(vec![vec![1]]);
        let input2 = MockOperator::new(vec![vec![2]]);
        let mut union = UnionOperator::new(
            vec![Box::new(input1), Box::new(input2)],
            vec![LogicalType::Int64],
        );

        // Full drain, then reset: the same values must be replayed.
        assert_eq!(collect_ints(&mut union), vec![1, 2]);
        union.reset();
        assert_eq!(collect_ints(&mut union), vec![1, 2]);

        // Partial drain (cursor now past the first input), then reset.
        union.reset();
        assert!(union.next().unwrap().is_some());
        assert!(union.next().unwrap().is_some());
        union.reset();
        assert_eq!(collect_ints(&mut union), vec![1, 2]);
    }

    #[test]
    fn test_union_into_any() {
        let left = MockOperator::new(vec![]);
        let right = MockOperator::new(vec![]);
        let op = UnionOperator::new(
            vec![Box::new(left), Box::new(right)],
            vec![LogicalType::Int64],
        );
        let any = Box::new(op).into_any();
        assert!(any.downcast::<UnionOperator>().is_ok());
    }
}