use std::cmp::Ordering;
use std::collections::HashMap;
use arcstr::ArcStr;
use grafeo_common::types::{LogicalType, Value};
use super::{Operator, OperatorError, OperatorResult};
use crate::execution::chunk::DataChunkBuilder;
use crate::execution::{DataChunk, ValueVector};
/// Kind of join performed by [`HashJoinOperator`] and [`NestedLoopJoinOperator`].
///
/// Not every operator honours every variant: the hash join special-cases
/// `Left`, `Right`, `Full`, `Semi` and `Anti`, while the nested-loop join only
/// special-cases `Left` (any other variant emits matching pairs only there).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JoinType {
/// Emit only pairs of rows whose keys / condition match.
Inner,
/// Additionally emit unmatched rows of the left input (the probe side in the
/// hash join), with the other side's columns set to NULL.
Left,
/// Additionally emit unmatched rows of the right input (the build side in the
/// hash join), with the other side's columns set to NULL.
Right,
/// `Left` and `Right` combined: unmatched rows of both inputs are emitted.
Full,
/// Cartesian product of both inputs; the nested-loop join produces it when it
/// is given no condition. The hash join does not special-case this variant.
Cross,
/// Emit each probe row that has at least one match, with probe columns only.
Semi,
/// Emit each probe row that has no match, with probe columns only.
Anti,
}
/// Hashable, totally ordered representation of a [`Value`] used as a join key.
///
/// Produced by [`HashKey::from_value`]. Multi-column keys and compound values
/// (lists, maps, paths, durations, vectors) are represented as `Composite`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum HashKey {
/// NULL, or a value kind that has no dedicated key representation.
Null,
/// Boolean key.
Bool(bool),
/// 64-bit integer key. Also carries the numeric encoding of floats (bit
/// pattern), temporal values and CRDT counter totals.
Int64(i64),
/// String key.
String(ArcStr),
/// Raw byte-string key.
Bytes(Vec<u8>),
/// Ordered sequence of sub-keys (multi-column keys and compound values).
Composite(Vec<HashKey>),
}
impl Ord for HashKey {
    /// Total order over keys: first by variant (`Null` < `Bool` < `Int64` <
    /// `String` < `Bytes` < `Composite`), then by payload within a variant.
    fn cmp(&self, other: &Self) -> Ordering {
        /// Position of each variant in the cross-variant order.
        ///
        /// Deliberately exhaustive: a new variant must be ranked here and must
        /// also receive a same-variant arm in the match below.
        fn variant_rank(key: &HashKey) -> u8 {
            match key {
                HashKey::Null => 0,
                HashKey::Bool(_) => 1,
                HashKey::Int64(_) => 2,
                HashKey::String(_) => 3,
                HashKey::Bytes(_) => 4,
                HashKey::Composite(_) => 5,
            }
        }

        match (self, other) {
            (HashKey::Bool(lhs), HashKey::Bool(rhs)) => lhs.cmp(rhs),
            (HashKey::Int64(lhs), HashKey::Int64(rhs)) => lhs.cmp(rhs),
            (HashKey::String(lhs), HashKey::String(rhs)) => lhs.cmp(rhs),
            (HashKey::Bytes(lhs), HashKey::Bytes(rhs)) => lhs.cmp(rhs),
            // Lexicographic, recursing into this same ordering element by element.
            (HashKey::Composite(lhs), HashKey::Composite(rhs)) => lhs.cmp(rhs),
            // Two NULLs (equal rank -> Equal) or two different variants.
            _ => variant_rank(self).cmp(&variant_rank(other)),
        }
    }
}
impl PartialOrd for HashKey {
    /// Delegates to the total order of [`Ord::cmp`]; never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl HashKey {
    /// Converts a [`Value`] into a [`HashKey`] suitable for hashing and equality.
    ///
    /// Several value kinds share the `Int64` representation: floats use their
    /// IEEE-754 bit pattern, temporal values their numeric encoding
    /// (microseconds, days or nanoseconds), and CRDT counters their net total.
    /// Keys of different value kinds can therefore compare equal. Value kinds
    /// without a dedicated mapping become [`HashKey::Null`].
    ///
    /// `-0.0` is normalised to `+0.0` so both zeros produce the same key,
    /// consistent with IEEE-754 equality. Counter totals wrap on overflow
    /// instead of panicking in debug builds.
    pub fn from_value(value: &Value) -> Self {
        match value {
            Value::Null => HashKey::Null,
            Value::Bool(b) => HashKey::Bool(*b),
            Value::Int64(i) => HashKey::Int64(*i),
            Value::Float64(f) => HashKey::Int64(Self::float64_key_bits(*f)),
            Value::String(s) => HashKey::String(s.clone()),
            Value::Bytes(b) => HashKey::Bytes(b.to_vec()),
            Value::Timestamp(t) => HashKey::Int64(t.as_micros()),
            #[allow(clippy::cast_possible_wrap)]
            Value::Date(d) => HashKey::Int64(d.as_days() as i64),
            #[allow(clippy::cast_possible_wrap)]
            Value::Time(t) => HashKey::Int64(t.as_nanos() as i64),
            Value::Duration(d) => HashKey::Composite(vec![
                HashKey::Int64(d.months()),
                HashKey::Int64(d.days()),
                HashKey::Int64(d.nanos()),
            ]),
            Value::ZonedDatetime(zdt) => HashKey::Int64(zdt.as_timestamp().as_micros()),
            Value::List(items) => {
                HashKey::Composite(items.iter().map(HashKey::from_value).collect())
            }
            Value::Map(map) => {
                // Each entry becomes a (key, value) pair, in the map's iteration order.
                let keys: Vec<_> = map
                    .iter()
                    .map(|(k, v)| {
                        HashKey::Composite(vec![
                            HashKey::String(ArcStr::from(k.as_str())),
                            HashKey::from_value(v),
                        ])
                    })
                    .collect();
                HashKey::Composite(keys)
            }
            Value::Vector(v) => HashKey::Composite(
                v.iter()
                    .map(|f| HashKey::Int64(f.to_bits() as i64))
                    .collect(),
            ),
            Value::Path { nodes, edges } => {
                // Nodes first, then edges, flattened into one composite key.
                let mut parts: Vec<_> = nodes.iter().map(HashKey::from_value).collect();
                parts.extend(edges.iter().map(HashKey::from_value));
                HashKey::Composite(parts)
            }
            Value::GCounter(counts) => {
                HashKey::Int64(Self::wrapping_counter_total(counts.values().copied()))
            }
            Value::OnCounter { pos, neg } => {
                let p = Self::wrapping_counter_total(pos.values().copied());
                let n = Self::wrapping_counter_total(neg.values().copied());
                HashKey::Int64(p.wrapping_sub(n))
            }
            _ => HashKey::Null,
        }
    }

    /// Returns the key for the value at `row` of `column`, or `None` if the
    /// column yields no value for that row.
    pub fn from_column(column: &ValueVector, row: usize) -> Option<Self> {
        column.get_value(row).map(|v| Self::from_value(&v))
    }

    /// Bit pattern of `f` reinterpreted as `i64`, with `-0.0` mapped to `+0.0`.
    #[allow(clippy::cast_possible_wrap)]
    fn float64_key_bits(f: f64) -> i64 {
        // `-0.0 == 0.0` is true, so this branch catches both signed zeros.
        let canonical = if f == 0.0 { 0.0 } else { f };
        canonical.to_bits() as i64
    }

    /// Sum of `counts`, each reinterpreted as `i64`, wrapping on overflow.
    #[allow(clippy::cast_possible_wrap)]
    fn wrapping_counter_total(counts: impl Iterator<Item = u64>) -> i64 {
        counts.fold(0_i64, |total, count| total.wrapping_add(count as i64))
    }
}
/// Equi-join that buffers the build input in a hash table and streams the probe input.
///
/// For non-semi/anti joins, output rows carry the probe columns first, followed
/// by the build columns.
pub struct HashJoinOperator {
/// Streaming input whose rows are looked up in the hash table.
probe_side: Box<dyn Operator>,
/// Input drained into `build_chunks` / `hash_table` on the first `next` call.
build_side: Box<dyn Operator>,
/// Column indices forming the join key on the probe side.
probe_keys: Vec<usize>,
/// Column indices forming the join key on the build side.
build_keys: Vec<usize>,
/// Join semantics to apply.
join_type: JoinType,
/// Column types of the produced chunks.
output_schema: Vec<LogicalType>,
/// Join key -> every `(build chunk index, row index)` carrying that key.
hash_table: HashMap<HashKey, Vec<(usize, usize)>>,
/// Every chunk read from the build side, addressed by the chunk index in `hash_table`.
build_chunks: Vec<DataChunk>,
/// Whether the build phase has finished.
build_complete: bool,
/// Probe chunk currently being processed, if any.
current_probe_chunk: Option<DataChunk>,
/// Position within the current probe chunk's selected rows (not a raw row index).
current_probe_row: usize,
/// Next entry of `current_matches` to emit; lets one probe row's matches span
/// several output chunks.
current_match_position: usize,
/// Build-side matches of the probe row currently being emitted.
current_matches: Vec<(usize, usize)>,
/// Per-row "matched" flags for the current probe chunk (Left/Full joins).
/// NOTE(review): written while probing but never read when producing output;
/// unmatched probe rows are detected via an empty `current_matches` instead.
probe_matched: Vec<bool>,
/// Per build chunk, per row "matched" flags (Right/Full joins), consumed when
/// unmatched build rows are emitted after the probe side is exhausted.
build_matched: Vec<Vec<bool>>,
/// Set once the probe side is exhausted and unmatched build rows are being emitted.
emitting_unmatched: bool,
/// Build chunk at which the unmatched-build phase resumes.
unmatched_chunk_idx: usize,
/// Position within the current build chunk at which the unmatched-build phase resumes.
unmatched_row_idx: usize,
}
impl HashJoinOperator {
pub fn new(
probe_side: Box<dyn Operator>,
build_side: Box<dyn Operator>,
probe_keys: Vec<usize>,
build_keys: Vec<usize>,
join_type: JoinType,
output_schema: Vec<LogicalType>,
) -> Self {
Self {
probe_side,
build_side,
probe_keys,
build_keys,
join_type,
output_schema,
hash_table: HashMap::new(),
build_chunks: Vec::new(),
build_complete: false,
current_probe_chunk: None,
current_probe_row: 0,
current_match_position: 0,
current_matches: Vec::new(),
probe_matched: Vec::new(),
build_matched: Vec::new(),
emitting_unmatched: false,
unmatched_chunk_idx: 0,
unmatched_row_idx: 0,
}
}
fn build_hash_table(&mut self) -> Result<(), OperatorError> {
while let Some(chunk) = self.build_side.next()? {
let chunk_idx = self.build_chunks.len();
if matches!(self.join_type, JoinType::Right | JoinType::Full) {
self.build_matched.push(vec![false; chunk.row_count()]);
}
for row in chunk.selected_indices() {
let key = self.extract_key(&chunk, row, &self.build_keys)?;
if matches!(key, HashKey::Null)
&& !matches!(
self.join_type,
JoinType::Left | JoinType::Right | JoinType::Full
)
{
continue;
}
self.hash_table
.entry(key)
.or_default()
.push((chunk_idx, row));
}
self.build_chunks.push(chunk);
}
self.build_complete = true;
Ok(())
}
fn extract_key(
&self,
chunk: &DataChunk,
row: usize,
key_columns: &[usize],
) -> Result<HashKey, OperatorError> {
if key_columns.len() == 1 {
let col = chunk.column(key_columns[0]).ok_or_else(|| {
OperatorError::ColumnNotFound(format!("column {}", key_columns[0]))
})?;
Ok(HashKey::from_column(col, row).unwrap_or(HashKey::Null))
} else {
let keys: Vec<HashKey> = key_columns
.iter()
.map(|&col_idx| {
chunk
.column(col_idx)
.and_then(|col| HashKey::from_column(col, row))
.unwrap_or(HashKey::Null)
})
.collect();
Ok(HashKey::Composite(keys))
}
}
fn produce_output_row(
&self,
builder: &mut DataChunkBuilder,
probe_chunk: &DataChunk,
probe_row: usize,
build_chunk: Option<&DataChunk>,
build_row: Option<usize>,
) -> Result<(), OperatorError> {
let probe_col_count = probe_chunk.column_count();
for col_idx in 0..probe_col_count {
let src_col = probe_chunk
.column(col_idx)
.ok_or_else(|| OperatorError::ColumnNotFound(format!("probe column {col_idx}")))?;
let dst_col = builder
.column_mut(col_idx)
.ok_or_else(|| OperatorError::ColumnNotFound(format!("output column {col_idx}")))?;
if let Some(value) = src_col.get_value(probe_row) {
dst_col.push_value(value);
} else {
dst_col.push_value(Value::Null);
}
}
match (build_chunk, build_row) {
(Some(chunk), Some(row)) => {
for col_idx in 0..chunk.column_count() {
let src_col = chunk.column(col_idx).ok_or_else(|| {
OperatorError::ColumnNotFound(format!("build column {col_idx}"))
})?;
let dst_col =
builder
.column_mut(probe_col_count + col_idx)
.ok_or_else(|| {
OperatorError::ColumnNotFound(format!(
"output column {}",
probe_col_count + col_idx
))
})?;
if let Some(value) = src_col.get_value(row) {
dst_col.push_value(value);
} else {
dst_col.push_value(Value::Null);
}
}
}
_ => {
if !self.build_chunks.is_empty() {
let build_col_count = self.build_chunks[0].column_count();
for col_idx in 0..build_col_count {
let dst_col =
builder
.column_mut(probe_col_count + col_idx)
.ok_or_else(|| {
OperatorError::ColumnNotFound(format!(
"output column {}",
probe_col_count + col_idx
))
})?;
dst_col.push_value(Value::Null);
}
}
}
}
builder.advance_row();
Ok(())
}
fn get_next_probe_chunk(&mut self) -> Result<bool, OperatorError> {
let chunk = self.probe_side.next()?;
if let Some(ref c) = chunk {
if matches!(self.join_type, JoinType::Left | JoinType::Full) {
self.probe_matched = vec![false; c.row_count()];
}
}
let has_chunk = chunk.is_some();
self.current_probe_chunk = chunk;
self.current_probe_row = 0;
Ok(has_chunk)
}
fn emit_unmatched_build(&mut self) -> OperatorResult {
if self.build_matched.is_empty() {
return Ok(None);
}
let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
let probe_col_count = if !self.build_chunks.is_empty() {
self.output_schema.len() - self.build_chunks[0].column_count()
} else {
0
};
while self.unmatched_chunk_idx < self.build_chunks.len() {
let chunk = &self.build_chunks[self.unmatched_chunk_idx];
let matched = &self.build_matched[self.unmatched_chunk_idx];
while self.unmatched_row_idx < matched.len() {
if !matched[self.unmatched_row_idx] {
for col_idx in 0..probe_col_count {
if let Some(dst_col) = builder.column_mut(col_idx) {
dst_col.push_value(Value::Null);
}
}
for col_idx in 0..chunk.column_count() {
if let (Some(src_col), Some(dst_col)) = (
chunk.column(col_idx),
builder.column_mut(probe_col_count + col_idx),
) {
if let Some(value) = src_col.get_value(self.unmatched_row_idx) {
dst_col.push_value(value);
} else {
dst_col.push_value(Value::Null);
}
}
}
builder.advance_row();
if builder.is_full() {
self.unmatched_row_idx += 1;
return Ok(Some(builder.finish()));
}
}
self.unmatched_row_idx += 1;
}
self.unmatched_chunk_idx += 1;
self.unmatched_row_idx = 0;
}
if builder.row_count() > 0 {
Ok(Some(builder.finish()))
} else {
Ok(None)
}
}
}
impl Operator for HashJoinOperator {
fn next(&mut self) -> OperatorResult {
if !self.build_complete {
self.build_hash_table()?;
}
if self.emitting_unmatched {
return self.emit_unmatched_build();
}
let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
loop {
if self.current_probe_chunk.is_none() && !self.get_next_probe_chunk()? {
if matches!(self.join_type, JoinType::Right | JoinType::Full) {
self.emitting_unmatched = true;
return self.emit_unmatched_build();
}
return if builder.row_count() > 0 {
Ok(Some(builder.finish()))
} else {
Ok(None)
};
}
let probe_chunk = self
.current_probe_chunk
.as_ref()
.expect("probe chunk is Some: guard at line 396 ensures this");
let probe_rows: Vec<usize> = probe_chunk.selected_indices().collect();
while self.current_probe_row < probe_rows.len() {
let probe_row = probe_rows[self.current_probe_row];
if self.current_matches.is_empty() && self.current_match_position == 0 {
let key = self.extract_key(probe_chunk, probe_row, &self.probe_keys)?;
match self.join_type {
JoinType::Semi => {
if self.hash_table.contains_key(&key) {
for col_idx in 0..probe_chunk.column_count() {
if let (Some(src_col), Some(dst_col)) =
(probe_chunk.column(col_idx), builder.column_mut(col_idx))
&& let Some(value) = src_col.get_value(probe_row)
{
dst_col.push_value(value);
}
}
builder.advance_row();
}
self.current_probe_row += 1;
continue;
}
JoinType::Anti => {
if !self.hash_table.contains_key(&key) {
for col_idx in 0..probe_chunk.column_count() {
if let (Some(src_col), Some(dst_col)) =
(probe_chunk.column(col_idx), builder.column_mut(col_idx))
&& let Some(value) = src_col.get_value(probe_row)
{
dst_col.push_value(value);
}
}
builder.advance_row();
}
self.current_probe_row += 1;
continue;
}
_ => {
self.current_matches =
self.hash_table.get(&key).cloned().unwrap_or_default();
}
}
}
if self.current_matches.is_empty() {
if matches!(self.join_type, JoinType::Left | JoinType::Full) {
self.produce_output_row(&mut builder, probe_chunk, probe_row, None, None)?;
}
self.current_probe_row += 1;
self.current_match_position = 0;
} else {
while self.current_match_position < self.current_matches.len() {
let (build_chunk_idx, build_row) =
self.current_matches[self.current_match_position];
let build_chunk = &self.build_chunks[build_chunk_idx];
if matches!(self.join_type, JoinType::Left | JoinType::Full)
&& probe_row < self.probe_matched.len()
{
self.probe_matched[probe_row] = true;
}
if matches!(self.join_type, JoinType::Right | JoinType::Full)
&& build_chunk_idx < self.build_matched.len()
&& build_row < self.build_matched[build_chunk_idx].len()
{
self.build_matched[build_chunk_idx][build_row] = true;
}
self.produce_output_row(
&mut builder,
probe_chunk,
probe_row,
Some(build_chunk),
Some(build_row),
)?;
self.current_match_position += 1;
if builder.is_full() {
return Ok(Some(builder.finish()));
}
}
self.current_probe_row += 1;
self.current_matches.clear();
self.current_match_position = 0;
}
if builder.is_full() {
return Ok(Some(builder.finish()));
}
}
self.current_probe_chunk = None;
self.current_probe_row = 0;
if builder.row_count() > 0 {
return Ok(Some(builder.finish()));
}
}
}
fn reset(&mut self) {
self.probe_side.reset();
self.build_side.reset();
self.hash_table.clear();
self.build_chunks.clear();
self.build_complete = false;
self.current_probe_chunk = None;
self.current_probe_row = 0;
self.current_match_position = 0;
self.current_matches.clear();
self.probe_matched.clear();
self.build_matched.clear();
self.emitting_unmatched = false;
self.unmatched_chunk_idx = 0;
self.unmatched_row_idx = 0;
}
fn name(&self) -> &'static str {
"HashJoin"
}
fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
self
}
}
/// Join that compares every left row with every right row, optionally filtered
/// by a [`JoinCondition`].
///
/// The right input is materialized in memory; the left input is streamed.
pub struct NestedLoopJoinOperator {
/// Streamed (outer) input.
left: Box<dyn Operator>,
/// Input materialized into `right_chunks` on the first `next` call.
right: Box<dyn Operator>,
/// Predicate deciding whether a left/right row pair is emitted; `None` accepts every pair.
condition: Option<Box<dyn JoinCondition>>,
/// Join semantics; only `Left` receives outer-join treatment in this operator.
join_type: JoinType,
/// Column types of the produced chunks (left columns, then right columns).
output_schema: Vec<LogicalType>,
/// All chunks read from the right input.
right_chunks: Vec<DataChunk>,
/// Whether `right_chunks` has been filled.
right_materialized: bool,
/// Left chunk currently being processed, if any.
current_left_chunk: Option<DataChunk>,
/// Position within the current left chunk's selected rows.
current_left_row: usize,
/// Index into `right_chunks` of the chunk currently being scanned.
current_right_chunk: usize,
/// Whether the current left row has matched at least one right row so far.
current_left_matched: bool,
/// Position within the current right chunk's selected rows.
current_right_row: usize,
}
/// Predicate evaluated by [`NestedLoopJoinOperator`] for a pair of rows.
pub trait JoinCondition: Send + Sync {
/// Returns `true` when `left_row` of `left_chunk` should be joined with
/// `right_row` of `right_chunk`. The nested-loop join passes row indices as
/// yielded by each chunk's `selected_indices()`.
fn evaluate(
&self,
left_chunk: &DataChunk,
left_row: usize,
right_chunk: &DataChunk,
right_row: usize,
) -> bool;
}
/// [`JoinCondition`] that compares one column of each input for equality.
pub struct EqualityCondition {
/// Column index read from the left chunk.
left_column: usize,
/// Column index read from the right chunk.
right_column: usize,
}
impl EqualityCondition {
pub fn new(left_column: usize, right_column: usize) -> Self {
Self {
left_column,
right_column,
}
}
}
impl JoinCondition for EqualityCondition {
fn evaluate(
&self,
left_chunk: &DataChunk,
left_row: usize,
right_chunk: &DataChunk,
right_row: usize,
) -> bool {
let left_val = left_chunk
.column(self.left_column)
.and_then(|c| c.get_value(left_row));
let right_val = right_chunk
.column(self.right_column)
.and_then(|c| c.get_value(right_row));
match (left_val, right_val) {
(Some(l), Some(r)) => l == r,
_ => false,
}
}
}
impl NestedLoopJoinOperator {
    /// Creates a nested-loop join of `left` against `right`.
    ///
    /// `condition` filters row pairs (`None` keeps every pair); `output_schema`
    /// lists the left columns followed by the right columns. No input is read
    /// until the first call to `next`.
    pub fn new(
        left: Box<dyn Operator>,
        right: Box<dyn Operator>,
        condition: Option<Box<dyn JoinCondition>>,
        join_type: JoinType,
        output_schema: Vec<LogicalType>,
    ) -> Self {
        Self {
            left,
            right,
            condition,
            join_type,
            output_schema,
            right_chunks: Vec::new(),
            right_materialized: false,
            current_left_chunk: None,
            current_left_row: 0,
            current_left_matched: false,
            current_right_chunk: 0,
            current_right_row: 0,
        }
    }

    /// Reads the whole right input into `right_chunks` and marks it materialized.
    ///
    /// # Errors
    /// Propagates the first error of the right input; chunks read before the
    /// error remain in `right_chunks`.
    fn materialize_right(&mut self) -> Result<(), OperatorError> {
        while let Some(next_chunk) = self.right.next()? {
            self.right_chunks.push(next_chunk);
        }
        self.right_materialized = true;
        Ok(())
    }

    /// Copies `row` of every column of `chunk` into the output columns starting
    /// at `offset`, writing NULL where the source has no value. Output or
    /// source columns that do not exist are skipped.
    fn copy_columns(builder: &mut DataChunkBuilder, chunk: &DataChunk, row: usize, offset: usize) {
        for col_idx in 0..chunk.column_count() {
            let source = chunk.column(col_idx);
            let target = builder.column_mut(offset + col_idx);
            if let (Some(src), Some(dst)) = (source, target) {
                dst.push_value(src.get_value(row).unwrap_or(Value::Null));
            }
        }
    }

    /// Appends the concatenation of `left_row` and `right_row` as one output row.
    fn produce_row(
        &self,
        builder: &mut DataChunkBuilder,
        left_chunk: &DataChunk,
        left_row: usize,
        right_chunk: &DataChunk,
        right_row: usize,
    ) {
        let right_offset = left_chunk.column_count();
        Self::copy_columns(builder, left_chunk, left_row, 0);
        Self::copy_columns(builder, right_chunk, right_row, right_offset);
        builder.advance_row();
    }

    /// Appends `left_row` followed by `right_col_count` NULLs (left join, no match).
    fn produce_left_unmatched_row(
        &self,
        builder: &mut DataChunkBuilder,
        left_chunk: &DataChunk,
        left_row: usize,
        right_col_count: usize,
    ) {
        Self::copy_columns(builder, left_chunk, left_row, 0);
        let right_offset = left_chunk.column_count();
        for out_idx in right_offset..right_offset + right_col_count {
            if let Some(dst) = builder.column_mut(out_idx) {
                dst.push_value(Value::Null);
            }
        }
        builder.advance_row();
    }
}
impl Operator for NestedLoopJoinOperator {
/// Returns the next output chunk (at most 2048 rows), or `Ok(None)` once the
/// left input is exhausted.
///
/// On the first call the right input is materialized. Every selected left row is
/// then compared with every selected row of every right chunk. The scan position
/// (`current_left_row`, `current_right_chunk`, `current_right_row`,
/// `current_left_matched`) lives in `self`, so a call that fills the output chunk
/// resumes exactly where it stopped. Only `JoinType::Left` pads unmatched left
/// rows with NULLs; every other join type emits matching pairs only.
fn next(&mut self) -> OperatorResult {
if !self.right_materialized {
self.materialize_right()?;
}
// No right rows means no pairs; only a left join still has rows to emit.
if self.right_chunks.is_empty() && !matches!(self.join_type, JoinType::Left) {
return Ok(None);
}
let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
loop {
// Fetch the next left chunk and rewind the scan position.
if self.current_left_chunk.is_none() {
self.current_left_chunk = self.left.next()?;
self.current_left_row = 0;
self.current_right_chunk = 0;
self.current_right_row = 0;
if self.current_left_chunk.is_none() {
return if builder.row_count() > 0 {
Ok(Some(builder.finish()))
} else {
Ok(None)
};
}
}
let left_chunk = self
.current_left_chunk
.as_ref()
.expect("left chunk is Some: loaded in loop above");
let left_rows: Vec<usize> = left_chunk.selected_indices().collect();
// Width of the NULL padding for unmatched left rows; derived from the output
// schema when the right input produced no chunks.
let right_col_count = if !self.right_chunks.is_empty() {
self.right_chunks[0].column_count()
} else {
self.output_schema
.len()
.saturating_sub(left_chunk.column_count())
};
while self.current_left_row < left_rows.len() {
let left_row = left_rows[self.current_left_row];
// A fresh right-side scan starts for this left row. A scan interrupted by a
// full output chunk always resumes with `current_right_row >= 1`, so the
// matched flag is not reset in that case.
if self.current_right_chunk == 0 && self.current_right_row == 0 {
self.current_left_matched = false;
}
while self.current_right_chunk < self.right_chunks.len() {
let right_chunk = &self.right_chunks[self.current_right_chunk];
// NOTE(review): re-collected for every left row; hoisting would avoid the
// repeated allocation.
let right_rows: Vec<usize> = right_chunk.selected_indices().collect();
while self.current_right_row < right_rows.len() {
let right_row = right_rows[self.current_right_row];
// Without a condition every pair matches (cross join).
let matches = match &self.condition {
Some(cond) => {
cond.evaluate(left_chunk, left_row, right_chunk, right_row)
}
None => true, };
if matches {
self.current_left_matched = true;
self.produce_row(
&mut builder,
left_chunk,
left_row,
right_chunk,
right_row,
);
// Output full: step past this pair so the next call resumes after it.
if builder.is_full() {
self.current_right_row += 1;
return Ok(Some(builder.finish()));
}
}
self.current_right_row += 1;
}
self.current_right_chunk += 1;
self.current_right_row = 0;
}
// Left join: a left row without any matching right row is emitted once,
// with NULLs in the right columns.
if matches!(self.join_type, JoinType::Left) && !self.current_left_matched {
self.produce_left_unmatched_row(
&mut builder,
left_chunk,
left_row,
right_col_count,
);
if builder.is_full() {
self.current_left_row += 1;
self.current_right_chunk = 0;
self.current_right_row = 0;
return Ok(Some(builder.finish()));
}
}
// Move to the next left row and restart the right-side scan.
self.current_left_row += 1;
self.current_right_chunk = 0;
self.current_right_row = 0;
}
// Current left chunk fully consumed.
self.current_left_chunk = None;
if builder.row_count() > 0 {
return Ok(Some(builder.finish()));
}
}
}
/// Resets both inputs and discards the materialized right side and all scan state.
fn reset(&mut self) {
self.left.reset();
self.right.reset();
self.right_chunks.clear();
self.right_materialized = false;
self.current_left_chunk = None;
self.current_left_row = 0;
self.current_right_chunk = 0;
self.current_right_row = 0;
self.current_left_matched = false;
}
fn name(&self) -> &'static str {
"NestedLoopJoin"
}
fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
self
}
}
// Unit tests for the join operators and for `HashKey` conversion/ordering.
#[cfg(test)]
mod tests {
use super::*;
use crate::execution::chunk::DataChunkBuilder;
// Test input operator that yields each pre-built chunk exactly once.
struct MockOperator {
chunks: Vec<DataChunk>,
position: usize,
}
impl MockOperator {
fn new(chunks: Vec<DataChunk>) -> Self {
Self {
chunks,
position: 0,
}
}
}
impl Operator for MockOperator {
// Hands out the next chunk, leaving an empty chunk in its place.
fn next(&mut self) -> OperatorResult {
if self.position < self.chunks.len() {
let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
self.position += 1;
Ok(Some(chunk))
} else {
Ok(None)
}
}
// Only rewinds the cursor; chunks already handed out stay empty.
fn reset(&mut self) {
self.position = 0;
}
fn name(&self) -> &'static str {
"Mock"
}
fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
self
}
}
// Builds a single-column Int64 chunk containing `values` in order.
fn create_int_chunk(values: &[i64]) -> DataChunk {
let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
for &v in values {
builder.column_mut(0).unwrap().push_int64(v);
builder.advance_row();
}
builder.finish()
}
// Inner join keeps only keys present on both sides.
#[test]
fn test_hash_join_inner() {
let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
let right = MockOperator::new(vec![create_int_chunk(&[2, 3, 4, 5])]);
let output_schema = vec![LogicalType::Int64, LogicalType::Int64];
let mut join = HashJoinOperator::new(
Box::new(left),
Box::new(right),
vec![0],
vec![0],
JoinType::Inner,
output_schema,
);
let mut results = Vec::new();
while let Some(chunk) = join.next().unwrap() {
for row in chunk.selected_indices() {
let left_val = chunk.column(0).unwrap().get_int64(row).unwrap();
let right_val = chunk.column(1).unwrap().get_int64(row).unwrap();
results.push((left_val, right_val));
}
}
results.sort_unstable();
assert_eq!(results, vec![(2, 2), (3, 3), (4, 4)]);
}
// Left join keeps every probe row; unmatched ones get a NULL build column.
#[test]
fn test_hash_join_left_outer() {
let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3])]);
let right = MockOperator::new(vec![create_int_chunk(&[2, 3])]);
let output_schema = vec![LogicalType::Int64, LogicalType::Int64];
let mut join = HashJoinOperator::new(
Box::new(left),
Box::new(right),
vec![0],
vec![0],
JoinType::Left,
output_schema,
);
let mut results = Vec::new();
while let Some(chunk) = join.next().unwrap() {
for row in chunk.selected_indices() {
let left_val = chunk.column(0).unwrap().get_int64(row).unwrap();
let right_val = chunk.column(1).unwrap().get_int64(row);
results.push((left_val, right_val));
}
}
results.sort_by_key(|(l, _)| *l);
assert_eq!(results.len(), 3);
// Key 1 has no partner, so its build column reads back as None (NULL).
assert_eq!(results[0], (1, None)); assert_eq!(results[1], (2, Some(2)));
assert_eq!(results[2], (3, Some(3)));
}
// Nested-loop join without a condition produces the Cartesian product.
#[test]
fn test_nested_loop_cross_join() {
let left = MockOperator::new(vec![create_int_chunk(&[1, 2])]);
let right = MockOperator::new(vec![create_int_chunk(&[10, 20])]);
let output_schema = vec![LogicalType::Int64, LogicalType::Int64];
let mut join = NestedLoopJoinOperator::new(
Box::new(left),
Box::new(right),
None,
JoinType::Cross,
output_schema,
);
let mut results = Vec::new();
while let Some(chunk) = join.next().unwrap() {
for row in chunk.selected_indices() {
let left_val = chunk.column(0).unwrap().get_int64(row).unwrap();
let right_val = chunk.column(1).unwrap().get_int64(row).unwrap();
results.push((left_val, right_val));
}
}
results.sort_unstable();
assert_eq!(results, vec![(1, 10), (1, 20), (2, 10), (2, 20)]);
}
// Semi join emits probe rows that have a match, probe columns only.
#[test]
fn test_hash_join_semi() {
let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
let right = MockOperator::new(vec![create_int_chunk(&[2, 4])]);
let output_schema = vec![LogicalType::Int64];
let mut join = HashJoinOperator::new(
Box::new(left),
Box::new(right),
vec![0],
vec![0],
JoinType::Semi,
output_schema,
);
let mut results = Vec::new();
while let Some(chunk) = join.next().unwrap() {
for row in chunk.selected_indices() {
let val = chunk.column(0).unwrap().get_int64(row).unwrap();
results.push(val);
}
}
results.sort_unstable();
assert_eq!(results, vec![2, 4]);
}
// Anti join emits probe rows that have no match, probe columns only.
#[test]
fn test_hash_join_anti() {
let left = MockOperator::new(vec![create_int_chunk(&[1, 2, 3, 4])]);
let right = MockOperator::new(vec![create_int_chunk(&[2, 4])]);
let output_schema = vec![LogicalType::Int64];
let mut join = HashJoinOperator::new(
Box::new(left),
Box::new(right),
vec![0],
vec![0],
JoinType::Anti,
output_schema,
);
let mut results = Vec::new();
while let Some(chunk) = join.next().unwrap() {
for row in chunk.selected_indices() {
let val = chunk.column(0).unwrap().get_int64(row).unwrap();
results.push(val);
}
}
results.sort_unstable();
assert_eq!(results, vec![1, 3]);
}
// Equal maps yield equal composite keys.
#[test]
fn test_hash_key_from_map() {
use grafeo_common::types::{PropertyKey, Value};
use std::collections::BTreeMap;
use std::sync::Arc;
let mut map = BTreeMap::new();
map.insert(PropertyKey::new("key"), Value::Int64(42));
let v = Value::Map(Arc::new(map));
let key = HashKey::from_value(&v);
assert!(matches!(key, HashKey::Composite(_)));
let mut map2 = BTreeMap::new();
map2.insert(PropertyKey::new("key"), Value::Int64(42));
let v2 = Value::Map(Arc::new(map2));
assert_eq!(HashKey::from_value(&v), HashKey::from_value(&v2));
}
// An empty map becomes an empty composite key.
#[test]
fn test_hash_key_from_map_empty() {
use grafeo_common::types::Value;
use std::collections::BTreeMap;
use std::sync::Arc;
let v = Value::Map(Arc::new(BTreeMap::new()));
let key = HashKey::from_value(&v);
assert_eq!(key, HashKey::Composite(vec![]));
}
// A grow-only counter is keyed by the sum of its per-replica counts.
#[test]
fn test_hash_key_from_gcounter() {
use grafeo_common::types::Value;
use std::collections::HashMap;
use std::sync::Arc;
let mut counts = HashMap::new();
counts.insert("node-a".to_string(), 5u64);
counts.insert("node-b".to_string(), 3u64);
let v = Value::GCounter(Arc::new(counts));
assert_eq!(HashKey::from_value(&v), HashKey::Int64(8));
}
#[test]
fn test_hash_key_from_gcounter_empty() {
use grafeo_common::types::Value;
use std::collections::HashMap;
use std::sync::Arc;
let v = Value::GCounter(Arc::new(HashMap::new()));
assert_eq!(HashKey::from_value(&v), HashKey::Int64(0));
}
// A positive/negative counter is keyed by its net total (pos - neg).
#[test]
fn test_hash_key_from_oncounter() {
use grafeo_common::types::Value;
use std::collections::HashMap;
use std::sync::Arc;
let mut pos = HashMap::new();
pos.insert("node-a".to_string(), 10u64);
let mut neg = HashMap::new();
neg.insert("node-a".to_string(), 3u64);
let v = Value::OnCounter {
pos: Arc::new(pos),
neg: Arc::new(neg),
};
assert_eq!(HashKey::from_value(&v), HashKey::Int64(7));
}
#[test]
fn test_hash_key_from_oncounter_balanced() {
use grafeo_common::types::Value;
use std::collections::HashMap;
use std::sync::Arc;
let mut pos = HashMap::new();
pos.insert("r".to_string(), 5u64);
let mut neg = HashMap::new();
neg.insert("r".to_string(), 5u64);
let v = Value::OnCounter {
pos: Arc::new(pos),
neg: Arc::new(neg),
};
assert_eq!(HashKey::from_value(&v), HashKey::Int64(0));
}
// `into_any` must allow downcasting back to the concrete operator type.
#[test]
fn test_hash_join_into_any() {
let left = MockOperator::new(vec![]);
let right = MockOperator::new(vec![]);
let op = HashJoinOperator::new(
Box::new(left),
Box::new(right),
vec![0],
vec![0],
JoinType::Inner,
vec![LogicalType::Int64, LogicalType::Int64],
);
let any = Box::new(op).into_any();
assert!(any.downcast::<HashJoinOperator>().is_ok());
}
#[test]
fn test_nested_loop_join_into_any() {
let left = MockOperator::new(vec![]);
let right = MockOperator::new(vec![]);
let op = NestedLoopJoinOperator::new(
Box::new(left),
Box::new(right),
None,
JoinType::Cross,
vec![LogicalType::Int64, LogicalType::Int64],
);
let any = Box::new(op).into_any();
assert!(any.downcast::<NestedLoopJoinOperator>().is_ok());
}
// Within a variant, keys order by their payload.
#[test]
fn test_hash_key_ord_same_variant() {
use std::cmp::Ordering;
assert_eq!(HashKey::Null.cmp(&HashKey::Null), Ordering::Equal);
assert_eq!(
HashKey::Bool(false).cmp(&HashKey::Bool(true)),
Ordering::Less
);
assert_eq!(
HashKey::Bool(true).cmp(&HashKey::Bool(false)),
Ordering::Greater
);
assert_eq!(HashKey::Int64(1).cmp(&HashKey::Int64(2)), Ordering::Less);
assert_eq!(HashKey::Int64(5).cmp(&HashKey::Int64(5)), Ordering::Equal);
assert_eq!(
HashKey::String(arcstr::literal!("a")).cmp(&HashKey::String(arcstr::literal!("b"))),
Ordering::Less,
);
assert_eq!(
HashKey::Bytes(vec![1, 2]).cmp(&HashKey::Bytes(vec![1, 3])),
Ordering::Less,
);
assert_eq!(
HashKey::Composite(vec![HashKey::Int64(1)])
.cmp(&HashKey::Composite(vec![HashKey::Int64(2)])),
Ordering::Less,
);
}
// Across variants: Null < Bool < Int64 < String < Bytes < Composite.
#[test]
fn test_hash_key_ord_cross_variant() {
use std::cmp::Ordering;
assert_eq!(HashKey::Null.cmp(&HashKey::Bool(false)), Ordering::Less);
assert_eq!(HashKey::Bool(true).cmp(&HashKey::Null), Ordering::Greater);
assert_eq!(HashKey::Bool(false).cmp(&HashKey::Int64(0)), Ordering::Less);
assert_eq!(
HashKey::Int64(0).cmp(&HashKey::Bool(false)),
Ordering::Greater
);
assert_eq!(
HashKey::Int64(0).cmp(&HashKey::String(arcstr::literal!("a"))),
Ordering::Less,
);
assert_eq!(
HashKey::String(arcstr::literal!("a")).cmp(&HashKey::Int64(0)),
Ordering::Greater,
);
assert_eq!(
HashKey::String(arcstr::literal!("a")).cmp(&HashKey::Bytes(vec![1])),
Ordering::Less,
);
assert_eq!(
HashKey::Bytes(vec![1]).cmp(&HashKey::String(arcstr::literal!("a"))),
Ordering::Greater,
);
assert_eq!(
HashKey::Bytes(vec![1]).cmp(&HashKey::Composite(vec![])),
Ordering::Less,
);
assert_eq!(
HashKey::Composite(vec![]).cmp(&HashKey::Bytes(vec![1])),
Ordering::Greater,
);
}
// `partial_cmp` is total: it always returns `Some`.
#[test]
fn test_hash_key_partial_ord() {
assert!(HashKey::Null.partial_cmp(&HashKey::Int64(1)).is_some());
assert!(HashKey::Int64(1).partial_cmp(&HashKey::Int64(2)).is_some());
}
}