#![deny(clippy::disallowed_types)]
use std::collections::BTreeMap;
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, RecordBatch};
use arrow::compute::concat_batches;
use arrow::datatypes::{Field, Schema, SchemaRef};
use rustc_hash::{FxHashMap, FxHashSet};
use laminar_sql::translator::{StreamJoinConfig, StreamJoinType};
use crate::aggregate_state::JoinStateCheckpoint;
use crate::error::DbError;
use crate::key_column::{extract_column_as_timestamps, extract_key_column, KeyColumn};
/// Number of buffered batches on one side above which `evict_before`
/// triggers a compaction pass to reclaim memory held by evicted rows.
const COMPACTION_THRESHOLD: usize = 32;
/// Per-side row index: key hash -> (event timestamp -> list of
/// (batch index, row index) locations within `SideState::batches`).
type SideIndex = FxHashMap<u64, BTreeMap<i64, Vec<(usize, usize)>>>;
/// Buffered rows for one input side of the interval join, plus a
/// hash/time index over them for bounded time-range probing.
pub(crate) struct SideState {
    // Raw buffered batches; rows evicted from `index` linger here until compaction.
    batches: Vec<RecordBatch>,
    // key hash -> timestamp -> (batch_idx, row_idx) locations into `batches`.
    index: SideIndex,
    // Number of rows currently referenced by `index` (null-key rows excluded).
    row_count: usize,
}
impl SideState {
    /// Creates an empty side buffer with no batches and an empty index.
    fn new() -> Self {
        Self {
            batches: Vec::new(),
            index: FxHashMap::default(),
            row_count: 0,
        }
    }

    /// Buffers `batch` and indexes every row with a non-null key under
    /// (key hash, timestamp).
    ///
    /// Rows whose key hashes to `None` (null keys) stay in the stored batch
    /// but are never indexed, so they can never match or be emitted.
    ///
    /// # Errors
    /// Fails if the key or time column cannot be extracted from `batch`.
    fn add_batch(
        &mut self,
        batch: &RecordBatch,
        key_col_name: &str,
        time_col_name: &str,
    ) -> Result<(), DbError> {
        if batch.num_rows() == 0 {
            return Ok(());
        }
        // Index entries refer to the position this batch will occupy.
        let batch_idx = self.batches.len();
        let keys = extract_key_column(batch, key_col_name)?;
        let timestamps = extract_column_as_timestamps(batch, time_col_name)?;
        let mut indexed_rows = 0usize;
        for (row_idx, &ts) in timestamps.iter().enumerate() {
            if let Some(key_hash) = keys.hash_at(row_idx) {
                self.index
                    .entry(key_hash)
                    .or_default()
                    .entry(ts)
                    .or_default()
                    .push((batch_idx, row_idx));
                indexed_rows += 1;
            }
        }
        // row_count tracks only indexed (non-null-key) rows.
        self.row_count += indexed_rows;
        self.batches.push(batch.clone());
        Ok(())
    }

    /// Handles a CDC delete: removes every indexed row at (`key_hash`, `ts`)
    /// whose full key equals the deleting row's key (hash collisions are
    /// resolved by comparing actual key values).
    fn remove_by_key_ts(
        &mut self,
        key_hash: u64,
        ts: i64,
        delete_key: &KeyColumn<'_>,
        delete_row: usize,
        key_col_name: &str,
    ) {
        let Some(btree) = self.index.get_mut(&key_hash) else {
            return;
        };
        let Some(entries) = btree.get_mut(&ts) else {
            return;
        };
        let before = entries.len();
        // Keep a row when extraction fails (conservative) or keys differ.
        entries.retain(|&(batch_idx, row_idx)| {
            extract_key_column(&self.batches[batch_idx], key_col_name).map_or(true, |stored_key| {
                !delete_key.keys_equal(delete_row, &stored_key, row_idx)
            })
        });
        let removed = before - entries.len();
        self.row_count = self.row_count.saturating_sub(removed);
        // Prune now-empty index levels so stale keys do not accumulate.
        if entries.is_empty() {
            btree.remove(&ts);
        }
        if btree.is_empty() {
            self.index.remove(&key_hash);
        }
    }

    /// Drops all index entries with timestamps strictly below `cutoff`.
    ///
    /// Evicted rows remain physically present in `batches` until the batch
    /// count exceeds `COMPACTION_THRESHOLD`, at which point `compact`
    /// rewrites the buffer to contain only live rows.
    fn evict_before(&mut self, cutoff: i64, key_col: &str, time_col: &str) -> Result<(), DbError> {
        for btree in self.index.values_mut() {
            // split_off keeps entries >= cutoff; what's left behind is evicted.
            let keep = btree.split_off(&cutoff);
            for entries in btree.values() {
                self.row_count = self.row_count.saturating_sub(entries.len());
            }
            *btree = keep;
        }
        self.index.retain(|_, btree| !btree.is_empty());
        if self.batches.len() > COMPACTION_THRESHOLD {
            self.compact(key_col, time_col)?;
        }
        Ok(())
    }

    /// Rewrites the buffer as a single batch containing only rows still
    /// referenced by the index, then rebuilds the index against it.
    ///
    /// # Errors
    /// Fails if Arrow concatenation or column extraction fails.
    fn compact(&mut self, key_col: &str, time_col: &str) -> Result<(), DbError> {
        // Collect every live (batch, row) location from the index.
        let mut live_rows: Vec<(usize, usize)> = Vec::with_capacity(self.row_count);
        for btree in self.index.values() {
            for entries in btree.values() {
                live_rows.extend_from_slice(entries);
            }
        }
        if live_rows.is_empty() {
            self.batches.clear();
            return Ok(());
        }
        // Sort so the compacted batch preserves original row ordering.
        live_rows.sort_unstable();
        let mut slices: Vec<RecordBatch> = Vec::with_capacity(live_rows.len());
        for &(batch_idx, row_idx) in &live_rows {
            slices.push(self.batches[batch_idx].slice(row_idx, 1));
        }
        let schema = self.batches[0].schema();
        let compacted = concat_batches(&schema, &slices)
            .map_err(|e| DbError::query_pipeline_arrow("interval join (compact)", &e))?;
        self.batches = vec![compacted];
        // Rebuild the index: all entries now point into batch 0.
        self.index.clear();
        let keys = extract_key_column(&self.batches[0], key_col)?;
        let timestamps = extract_column_as_timestamps(&self.batches[0], time_col)?;
        for (row_idx, &ts) in timestamps.iter().enumerate() {
            if let Some(key_hash) = keys.hash_at(row_idx) {
                self.index
                    .entry(key_hash)
                    .or_default()
                    .entry(ts)
                    .or_default()
                    .push((0, row_idx));
            }
        }
        // The compacted batch contains exactly the indexed rows.
        self.row_count = self.batches[0].num_rows();
        Ok(())
    }
}
/// Complete state of a streaming interval join: both side buffers, the
/// high-water marks of what has already been evicted, and the output
/// schema (discovered lazily once input schemas are observed).
pub(crate) struct IntervalJoinState {
    left: SideState,
    right: SideState,
    // Highest cutoff already evicted from the left buffer (i64::MIN = none yet).
    left_evicted_cutoff: i64,
    // Highest cutoff already evicted from the right buffer (i64::MIN = none yet).
    right_evicted_cutoff: i64,
    // Set once the relevant input schema(s) have been seen; None until then.
    output_schema: Option<SchemaRef>,
}
impl IntervalJoinState {
pub(crate) fn new() -> Self {
Self {
left: SideState::new(),
right: SideState::new(),
left_evicted_cutoff: i64::MIN,
right_evicted_cutoff: i64::MIN,
output_schema: None,
}
}
pub(crate) fn estimated_size_bytes(&self) -> usize {
let mut size = 0usize;
for b in &self.left.batches {
size += b.get_array_memory_size();
}
for b in &self.right.batches {
size += b.get_array_memory_size();
}
let index_entries: usize = self.left.index.values().map(BTreeMap::len).sum::<usize>()
+ self.right.index.values().map(BTreeMap::len).sum::<usize>();
size += index_entries * 64;
size
}
pub(crate) fn snapshot_checkpoint(
&mut self,
left_key: &str,
left_time: &str,
right_key: &str,
right_time: &str,
) -> Result<JoinStateCheckpoint, DbError> {
if !self.left.batches.is_empty() {
self.left.compact(left_key, left_time)?;
}
if !self.right.batches.is_empty() {
self.right.compact(right_key, right_time)?;
}
let mut left_batches_ipc = Vec::with_capacity(self.left.batches.len());
for batch in &self.left.batches {
if batch.num_rows() == 0 {
continue;
}
let ipc = laminar_core::serialization::serialize_batch_stream(batch).map_err(|e| {
DbError::Pipeline(format!("interval join left batch serialization: {e}"))
})?;
left_batches_ipc.push(ipc);
}
let mut right_batches_ipc = Vec::with_capacity(self.right.batches.len());
for batch in &self.right.batches {
if batch.num_rows() == 0 {
continue;
}
let ipc = laminar_core::serialization::serialize_batch_stream(batch).map_err(|e| {
DbError::Pipeline(format!("interval join right batch serialization: {e}"))
})?;
right_batches_ipc.push(ipc);
}
Ok(JoinStateCheckpoint {
left_buffer_rows: self.left.row_count as u64,
right_buffer_rows: self.right.row_count as u64,
left_batches: left_batches_ipc,
right_batches: right_batches_ipc,
last_evicted_watermark: self.left_evicted_cutoff,
last_evicted_watermark_right: self.right_evicted_cutoff,
})
}
pub(crate) fn from_checkpoint(
cp: &JoinStateCheckpoint,
left_key_col: &str,
left_time_col: &str,
right_key_col: &str,
right_time_col: &str,
) -> Result<Self, DbError> {
let mut state = Self::new();
state.left_evicted_cutoff = cp.last_evicted_watermark;
state.right_evicted_cutoff = cp.last_evicted_watermark_right;
for ipc_bytes in &cp.left_batches {
let batch =
laminar_core::serialization::deserialize_batch_stream(ipc_bytes).map_err(|e| {
DbError::Pipeline(format!("interval join left batch deserialization: {e}"))
})?;
state.left.add_batch(&batch, left_key_col, left_time_col)?;
}
for ipc_bytes in &cp.right_batches {
let batch =
laminar_core::serialization::deserialize_batch_stream(ipc_bytes).map_err(|e| {
DbError::Pipeline(format!("interval join right batch deserialization: {e}"))
})?;
state
.right
.add_batch(&batch, right_key_col, right_time_col)?;
}
Ok(state)
}
}
/// Builds the joined output schema: all left fields followed by all right
/// fields renamed with a `_{right_table}` suffix. Fields on a side that an
/// outer join may pad with nulls are forced nullable. Semi/anti joins
/// project only the left side.
fn build_output_schema(
    left_schema: &SchemaRef,
    right_schema: &SchemaRef,
    config: &StreamJoinConfig,
) -> SchemaRef {
    // A side becomes nullable when the opposite side can emit without it.
    let left_nullable = matches!(
        config.join_type,
        StreamJoinType::Right | StreamJoinType::Full
    );
    let right_nullable = matches!(
        config.join_type,
        StreamJoinType::Left | StreamJoinType::Full
    );
    let mut fields: Vec<Field> =
        Vec::with_capacity(left_schema.fields().len() + right_schema.fields().len());
    for f in left_schema.fields() {
        let field = f.as_ref().clone();
        fields.push(if left_nullable {
            field.with_nullable(true)
        } else {
            field
        });
    }
    // Semi/anti joins never emit right-side columns.
    if matches!(
        config.join_type,
        StreamJoinType::LeftSemi | StreamJoinType::LeftAnti
    ) {
        return Arc::new(Schema::new(fields));
    }
    // Suffix right-side names with the right table to avoid column clashes.
    for f in right_schema.fields() {
        let renamed = f
            .as_ref()
            .clone()
            .with_name(format!("{}_{}", f.name(), config.right_table));
        fields.push(if right_nullable {
            renamed.with_nullable(true)
        } else {
            renamed
        });
    }
    Arc::new(Schema::new(fields))
}
/// Returns every (batch_idx, row_idx) indexed under `key_hash` whose
/// timestamp lies within `probe_ts ± bound_ms` (both bounds inclusive,
/// saturating at the i64 limits).
fn probe_index(
    index: &SideIndex,
    key_hash: u64,
    probe_ts: i64,
    bound_ms: i64,
) -> Vec<(usize, usize)> {
    match index.get(&key_hash) {
        None => Vec::new(),
        Some(btree) => btree
            .range(probe_ts.saturating_sub(bound_ms)..=probe_ts.saturating_add(bound_ms))
            .flat_map(|(_, entries)| entries.iter().copied())
            .collect(),
    }
}
#[allow(clippy::too_many_lines)]
/// Executes one interval-join cycle.
///
/// Steps, in order: apply CDC deletes to buffered state, buffer and match
/// new positive rows on both sides, build the matched output batch, emit
/// null-padded unmatched rows for outer/anti variants once the opposing
/// watermark has passed them, then evict expired state.
///
/// # Errors
/// Fails on column extraction, Arrow concatenation, or changelog filtering
/// errors.
pub(crate) fn execute_interval_join_cycle(
    state: &mut IntervalJoinState,
    left_batches: &[RecordBatch],
    right_batches: &[RecordBatch],
    config: &StreamJoinConfig,
    left_watermark: i64,
    right_watermark: i64,
) -> Result<Vec<RecordBatch>, DbError> {
    // Time bound in ms; clamp to i64::MAX if the Duration overflows i64.
    let bound_ms = i64::try_from(config.time_bound.as_millis()).unwrap_or(i64::MAX);
    // Split changelog input: positives are inserts to buffer and match.
    let left_pos: Vec<RecordBatch> = left_batches
        .iter()
        .map(crate::changelog_filter::filter_positive_events)
        .collect::<Result<Vec<_>, _>>()?;
    let right_pos: Vec<RecordBatch> = right_batches
        .iter()
        .map(crate::changelog_filter::filter_positive_events)
        .collect::<Result<Vec<_>, _>>()?;
    // Apply left-side deletes against buffered left state.
    for raw_batch in left_batches {
        if let Some(neg) = crate::changelog_filter::extract_negative_events(raw_batch)? {
            let keys = extract_key_column(&neg, &config.left_key)?;
            let timestamps = extract_column_as_timestamps(&neg, &config.left_time_column)?;
            for (i, &ts) in timestamps.iter().enumerate() {
                if let Some(kh) = keys.hash_at(i) {
                    state
                        .left
                        .remove_by_key_ts(kh, ts, &keys, i, &config.left_key);
                }
            }
        }
    }
    // Apply right-side deletes against buffered right state.
    for raw_batch in right_batches {
        if let Some(neg) = crate::changelog_filter::extract_negative_events(raw_batch)? {
            let keys = extract_key_column(&neg, &config.right_key)?;
            let timestamps = extract_column_as_timestamps(&neg, &config.right_time_column)?;
            for (i, &ts) in timestamps.iter().enumerate() {
                if let Some(kh) = keys.hash_at(i) {
                    state
                        .right
                        .remove_by_key_ts(kh, ts, &keys, i, &config.right_key);
                }
            }
        }
    }
    // From here on, operate only on the positive (insert) rows.
    let left_batches = &left_pos[..];
    let right_batches = &right_pos[..];
    // Concatenate each side's new rows into (at most) one batch.
    let new_left = if left_batches.is_empty() {
        None
    } else {
        let schema = left_batches[0].schema();
        Some(
            concat_batches(&schema, left_batches)
                .map_err(|e| DbError::query_pipeline_arrow("interval join (left concat)", &e))?,
        )
    };
    let new_right = if right_batches.is_empty() {
        None
    } else {
        let schema = right_batches[0].schema();
        Some(
            concat_batches(&schema, right_batches)
                .map_err(|e| DbError::query_pipeline_arrow("interval join (right concat)", &e))?,
        )
    };
    // Remember pre-insert batch counts to distinguish old vs. new batches.
    let left_old_count = state.left.batches.len();
    let right_old_count = state.right.batches.len();
    // Add new right BEFORE probing new left, so the left probe sees both
    // old-right and new-right rows (covers new x new in one pass).
    if let Some(ref rb) = new_right {
        state
            .right
            .add_batch(rb, &config.right_key, &config.right_time_column)?;
    }
    // (left_batch, left_row, right_batch, right_row) for every match found.
    let mut match_pairs: Vec<(usize, usize, usize, usize)> = Vec::new();
    let is_semi = config.join_type == StreamJoinType::LeftSemi;
    // For semi joins: new-left rows already matched this cycle (dedup).
    let mut semi_matched: FxHashSet<usize> = FxHashSet::default();
    // Pass 1: probe each new left row against the right index.
    if let Some(ref lb) = new_left {
        let left_keys = extract_key_column(lb, &config.left_key)?;
        let left_timestamps = extract_column_as_timestamps(lb, &config.left_time_column)?;
        // The new left batch will be inserted at this index later.
        let new_left_batch_idx = left_old_count;
        for (row_idx, &left_ts) in left_timestamps.iter().enumerate() {
            if is_semi && semi_matched.contains(&row_idx) {
                continue;
            }
            let Some(key_hash) = left_keys.hash_at(row_idx) else {
                continue; };
            let candidates = probe_index(&state.right.index, key_hash, left_ts, bound_ms);
            for (r_batch, r_row) in candidates {
                if is_semi && semi_matched.contains(&row_idx) {
                    break;
                }
                // Resolve hash collisions by comparing actual key values.
                let r_key_col =
                    extract_key_column(&state.right.batches[r_batch], &config.right_key)?;
                if left_keys.keys_equal(row_idx, &r_key_col, r_row) {
                    match_pairs.push((new_left_batch_idx, row_idx, r_batch, r_row));
                    if is_semi {
                        semi_matched.insert(row_idx);
                    }
                }
            }
        }
    }
    // Pass 2: probe each new right row against the left index. The left
    // index holds only old left rows here (new left is added below), and
    // the l_batch guard keeps new x new pairs from being counted twice.
    if let Some(ref rb) = new_right {
        let right_keys = extract_key_column(rb, &config.right_key)?;
        let right_timestamps = extract_column_as_timestamps(rb, &config.right_time_column)?;
        let new_right_batch_idx = right_old_count;
        for (row_idx, &right_ts) in right_timestamps.iter().enumerate() {
            let Some(key_hash) = right_keys.hash_at(row_idx) else {
                continue; };
            let candidates = probe_index(&state.left.index, key_hash, right_ts, bound_ms);
            for (l_batch, l_row) in candidates {
                if l_batch < left_old_count {
                    let l_key_col =
                        extract_key_column(&state.left.batches[l_batch], &config.left_key)?;
                    if right_keys.keys_equal(row_idx, &l_key_col, l_row) {
                        match_pairs.push((l_batch, l_row, new_right_batch_idx, row_idx));
                    }
                }
            }
        }
    }
    // New left rows are buffered only after both probe passes.
    if let Some(ref lb) = new_left {
        state
            .left
            .add_batch(lb, &config.left_key, &config.left_time_column)?;
    }
    // Discover/refresh the output schema from whatever sides we have seen.
    {
        let left_schema = state.left.batches.first().map(RecordBatch::schema);
        let right_schema = state.right.batches.first().map(RecordBatch::schema);
        match (left_schema, right_schema) {
            (Some(ls), Some(rs)) => {
                state.output_schema = Some(build_output_schema(&ls, &rs, config));
            }
            // Semi/anti only need the left schema.
            (Some(ls), None)
                if state.output_schema.is_none()
                    && matches!(
                        config.join_type,
                        StreamJoinType::LeftSemi | StreamJoinType::LeftAnti
                    ) =>
            {
                state.output_schema =
                    Some(build_output_schema(&ls, &Arc::new(Schema::empty()), config));
            }
            // Right joins can build a schema from the right side alone.
            (None, Some(rs))
                if state.output_schema.is_none() && config.join_type == StreamJoinType::Right =>
            {
                state.output_schema =
                    Some(build_output_schema(&Arc::new(Schema::empty()), &rs, config));
            }
            _ => {}
        }
    }
    let left_only = matches!(
        config.join_type,
        StreamJoinType::LeftSemi | StreamJoinType::LeftAnti
    );
    // Build the matched-output batch; anti joins never emit matches.
    let mut result = if match_pairs.is_empty() || config.join_type == StreamJoinType::LeftAnti {
        Vec::new()
    } else {
        let output_schema = state.output_schema.as_ref().ok_or_else(|| {
            DbError::Pipeline("interval join: output schema not available".to_string())
        })?;
        let left_schema = state
            .left
            .batches
            .first()
            .map_or_else(|| Arc::new(Schema::empty()), RecordBatch::schema);
        let num_rows = match_pairs.len();
        let mut columns: Vec<ArrayRef> = Vec::with_capacity(output_schema.fields().len());
        // Gather left-side values column by column, one slice per match.
        for col_idx in 0..left_schema.fields().len() {
            let mut builder = Vec::with_capacity(num_rows);
            for &(l_batch, l_row, _, _) in &match_pairs {
                let array = state.left.batches[l_batch].column(col_idx);
                let sliced = array.slice(l_row, 1);
                builder.push(sliced);
            }
            let refs: Vec<&dyn Array> = builder.iter().map(AsRef::as_ref).collect();
            let concatenated = arrow::compute::concat(&refs)
                .map_err(|e| DbError::query_pipeline_arrow("interval join (left concat)", &e))?;
            columns.push(concatenated);
        }
        // Semi/anti project only the left side; others gather right values.
        if !left_only {
            let right_schema = state
                .right
                .batches
                .first()
                .map_or_else(|| Arc::new(Schema::empty()), RecordBatch::schema);
            for col_idx in 0..right_schema.fields().len() {
                let mut builder = Vec::with_capacity(num_rows);
                for &(_, _, r_batch, r_row) in &match_pairs {
                    let array = state.right.batches[r_batch].column(col_idx);
                    let sliced = array.slice(r_row, 1);
                    builder.push(sliced);
                }
                let refs: Vec<&dyn Array> = builder.iter().map(AsRef::as_ref).collect();
                let concatenated = arrow::compute::concat(&refs).map_err(|e| {
                    DbError::query_pipeline_arrow("interval join (right concat)", &e)
                })?;
                columns.push(concatenated);
            }
        }
        let batch = RecordBatch::try_new(output_schema.clone(), columns)
            .map_err(|e| DbError::query_pipeline_arrow("interval join (result)", &e))?;
        if batch.num_rows() > 0 {
            vec![batch]
        } else {
            Vec::new()
        }
    };
    // Emit unmatched left rows once the right watermark guarantees no
    // future right row can still fall within the bound. Must run before
    // eviction below so the rows are still present.
    if matches!(
        config.join_type,
        StreamJoinType::Left | StreamJoinType::Full | StreamJoinType::LeftAnti
    ) {
        let left_cutoff = right_watermark.saturating_sub(bound_ms);
        if left_cutoff > state.left_evicted_cutoff && !state.left.batches.is_empty() {
            let unmatched = emit_unmatched_left_rows(state, config, left_cutoff, bound_ms)?;
            if let Some(batch) = unmatched {
                result.push(batch);
            }
        }
    }
    // Mirror for unmatched right rows in right/full outer joins.
    if matches!(
        config.join_type,
        StreamJoinType::Right | StreamJoinType::Full
    ) {
        let right_cutoff = left_watermark.saturating_sub(bound_ms);
        if right_cutoff > state.right_evicted_cutoff && !state.right.batches.is_empty() {
            let unmatched = emit_unmatched_right_rows(state, config, right_cutoff, bound_ms)?;
            if let Some(batch) = unmatched {
                result.push(batch);
            }
        }
    }
    // Evict state that can no longer participate in any match.
    let left_cutoff = right_watermark.saturating_sub(bound_ms);
    if left_cutoff > state.left_evicted_cutoff {
        state
            .left
            .evict_before(left_cutoff, &config.left_key, &config.left_time_column)?;
        state.left_evicted_cutoff = left_cutoff;
    }
    let right_cutoff = left_watermark.saturating_sub(bound_ms);
    if right_cutoff > state.right_evicted_cutoff {
        state
            .right
            .evict_before(right_cutoff, &config.right_key, &config.right_time_column)?;
        state.right_evicted_cutoff = right_cutoff;
    }
    Ok(result)
}
/// Emits left rows older than `left_cutoff` that have no matching right
/// row, padded with nulls on the right side (or left-only for semi/anti).
///
/// Returns `Ok(None)` when the output schema is unknown or no unmatched
/// rows exist.
///
/// # Errors
/// Fails on key-column extraction, Arrow concat, or batch construction.
fn emit_unmatched_left_rows(
    state: &IntervalJoinState,
    config: &StreamJoinConfig,
    left_cutoff: i64,
    bound_ms: i64,
) -> Result<Option<RecordBatch>, DbError> {
    // No schema yet means nothing was ever matched/observed to pad against.
    let Some(output_schema) = state.output_schema.as_ref() else {
        return Ok(None);
    };
    let left_only = matches!(
        config.join_type,
        StreamJoinType::LeftSemi | StreamJoinType::LeftAnti
    );
    // Scan all indexed left rows below the cutoff for missing matches.
    let mut unmatched_left: Vec<(usize, usize)> = Vec::new();
    for (&key_hash, btree) in &state.left.index {
        for (&ts, entries) in btree.range(..left_cutoff) {
            // Candidates depend only on (key_hash, ts): hoisted out of the
            // per-row loop instead of re-probing for every entry.
            let candidates = probe_index(&state.right.index, key_hash, ts, bound_ms);
            for &(batch_idx, row_idx) in entries {
                let left_key =
                    extract_key_column(&state.left.batches[batch_idx], &config.left_key)?;
                // Hash candidates must also match on actual key values.
                let has_match = candidates.iter().any(|&(rb, rr)| {
                    extract_key_column(&state.right.batches[rb], &config.right_key)
                        .is_ok_and(|rk| left_key.keys_equal(row_idx, &rk, rr))
                });
                if !has_match {
                    unmatched_left.push((batch_idx, row_idx));
                }
            }
        }
    }
    if unmatched_left.is_empty() {
        return Ok(None);
    }
    let left_schema = state
        .left
        .batches
        .first()
        .map_or_else(|| Arc::new(Schema::empty()), RecordBatch::schema);
    let num_rows = unmatched_left.len();
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(output_schema.fields().len());
    // Gather left-side values for every unmatched row, column by column.
    for col_idx in 0..left_schema.fields().len() {
        let mut builder = Vec::with_capacity(num_rows);
        for &(batch_idx, row_idx) in &unmatched_left {
            let array = state.left.batches[batch_idx].column(col_idx);
            builder.push(array.slice(row_idx, 1));
        }
        let refs: Vec<&dyn Array> = builder.iter().map(AsRef::as_ref).collect();
        let concatenated = arrow::compute::concat(&refs)
            .map_err(|e| DbError::query_pipeline_arrow("interval join (unmatched left)", &e))?;
        columns.push(concatenated);
    }
    // Outer joins pad right columns with nulls; semi/anti emit left only.
    if !left_only {
        let right_schema = state
            .right
            .batches
            .first()
            .map_or_else(|| Arc::new(Schema::empty()), RecordBatch::schema);
        for col_idx in 0..right_schema.fields().len() {
            let dt = output_schema
                .field(left_schema.fields().len() + col_idx)
                .data_type();
            columns.push(arrow::array::new_null_array(dt, num_rows));
        }
    }
    RecordBatch::try_new(output_schema.clone(), columns)
        .map(|b| if b.num_rows() > 0 { Some(b) } else { None })
        .map_err(|e| DbError::query_pipeline_arrow("interval join (unmatched result)", &e))
}
/// Emits right rows older than `right_cutoff` that have no matching left
/// row, padded with nulls on the left side (right/full outer joins).
///
/// Returns `Ok(None)` when the output schema is unknown or no unmatched
/// rows exist.
///
/// # Errors
/// Fails on key-column extraction, Arrow concat, or batch construction.
fn emit_unmatched_right_rows(
    state: &IntervalJoinState,
    config: &StreamJoinConfig,
    right_cutoff: i64,
    bound_ms: i64,
) -> Result<Option<RecordBatch>, DbError> {
    // No schema yet means nothing was ever matched/observed to pad against.
    let Some(output_schema) = state.output_schema.as_ref() else {
        return Ok(None);
    };
    // Scan all indexed right rows below the cutoff for missing matches.
    let mut unmatched_right: Vec<(usize, usize)> = Vec::new();
    for (&key_hash, btree) in &state.right.index {
        for (&ts, entries) in btree.range(..right_cutoff) {
            // Candidates depend only on (key_hash, ts): hoisted out of the
            // per-row loop instead of re-probing for every entry.
            let candidates = probe_index(&state.left.index, key_hash, ts, bound_ms);
            for &(batch_idx, row_idx) in entries {
                let right_key =
                    extract_key_column(&state.right.batches[batch_idx], &config.right_key)?;
                // Hash candidates must also match on actual key values.
                let has_match = candidates.iter().any(|&(lb, lr)| {
                    extract_key_column(&state.left.batches[lb], &config.left_key)
                        .is_ok_and(|lk| right_key.keys_equal(row_idx, &lk, lr))
                });
                if !has_match {
                    unmatched_right.push((batch_idx, row_idx));
                }
            }
        }
    }
    if unmatched_right.is_empty() {
        return Ok(None);
    }
    let left_schema = state
        .left
        .batches
        .first()
        .map_or_else(|| Arc::new(Schema::empty()), RecordBatch::schema);
    let right_schema = state
        .right
        .batches
        .first()
        .map_or_else(|| Arc::new(Schema::empty()), RecordBatch::schema);
    let num_rows = unmatched_right.len();
    let mut columns: Vec<ArrayRef> = Vec::with_capacity(output_schema.fields().len());
    // Left columns are all null for unmatched right rows.
    for col_idx in 0..left_schema.fields().len() {
        let dt = output_schema.field(col_idx).data_type();
        columns.push(arrow::array::new_null_array(dt, num_rows));
    }
    // Gather right-side values for every unmatched row, column by column.
    for col_idx in 0..right_schema.fields().len() {
        let mut builder = Vec::with_capacity(num_rows);
        for &(batch_idx, row_idx) in &unmatched_right {
            let array = state.right.batches[batch_idx].column(col_idx);
            builder.push(array.slice(row_idx, 1));
        }
        let refs: Vec<&dyn Array> = builder.iter().map(AsRef::as_ref).collect();
        let concatenated = arrow::compute::concat(&refs)
            .map_err(|e| DbError::query_pipeline_arrow("interval join (unmatched right)", &e))?;
        columns.push(concatenated);
    }
    RecordBatch::try_new(output_schema.clone(), columns)
        .map(|b| if b.num_rows() > 0 { Some(b) } else { None })
        .map_err(|e| DbError::query_pipeline_arrow("interval join (unmatched right result)", &e))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::DataType;
    use laminar_sql::translator::StreamJoinType;
    use std::time::Duration;

    // Inner join on "id" within a 100ms bound; shared by most tests.
    fn make_config() -> StreamJoinConfig {
        StreamJoinConfig {
            left_key: "id".to_string(),
            right_key: "id".to_string(),
            left_time_column: "ts".to_string(),
            right_time_column: "ts".to_string(),
            left_table: "left_stream".to_string(),
            right_table: "right_stream".to_string(),
            time_bound: Duration::from_millis(100),
            join_type: StreamJoinType::Inner,
        }
    }

    // Left-side fixture: (id: Utf8, ts: Int64, price: Float64).
    fn left_batch(ids: &[&str], timestamps: &[i64], values: &[f64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("ts", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(ids.to_vec())),
                Arc::new(Int64Array::from(timestamps.to_vec())),
                Arc::new(Float64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }

    // Right-side fixture: (id: Utf8, ts: Int64, amount: Float64).
    fn right_batch(ids: &[&str], timestamps: &[i64], amounts: &[f64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("ts", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(ids.to_vec())),
                Arc::new(Int64Array::from(timestamps.to_vec())),
                Arc::new(Float64Array::from(amounts.to_vec())),
            ],
        )
        .unwrap()
    }

    // Rows arriving in the same cycle and within the bound join immediately.
    #[test]
    fn test_basic_inner_join_same_cycle() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A", "B"], &[100, 200], &[10.0, 20.0]);
        let right = right_batch(&["A", "B"], &[110, 250], &[1.0, 2.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 2);
        // 3 left + 3 suffixed right columns.
        assert_eq!(result[0].num_columns(), 6);
    }

    // A row buffered in an earlier cycle matches a later arrival.
    #[test]
    fn test_cross_cycle_matching() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let result = execute_interval_join_cycle(&mut state, &[left], &[], &config, 0, 0).unwrap();
        assert!(result.is_empty());
        let right = right_batch(&["A"], &[150], &[1.0]);
        let result = execute_interval_join_cycle(&mut state, &[], &[right], &config, 0, 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
    }

    // Rows further apart than the time bound never match.
    #[test]
    fn test_time_bound_enforcement() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let right = right_batch(&["A"], &[300], &[1.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert!(result.is_empty());
    }

    // Advancing the opposing watermark past ts + bound evicts buffered rows.
    #[test]
    fn test_eviction_on_watermark_advance() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let _ = execute_interval_join_cycle(&mut state, &[left], &[], &config, 0, 0).unwrap();
        assert_eq!(state.left.row_count, 1);
        let _ = execute_interval_join_cycle(&mut state, &[], &[], &config, 300, 300).unwrap();
        assert_eq!(state.left.row_count, 0);
    }

    // Matching is keyed: A joins A and B joins B regardless of row order.
    #[test]
    fn test_multiple_keys() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A", "B"], &[100, 100], &[10.0, 20.0]);
        let right = right_batch(&["B", "A"], &[110, 110], &[1.0, 2.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 2);
    }

    // A new-left x new-right pair is emitted exactly once in the cycle.
    #[test]
    fn test_no_double_emit() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let right = right_batch(&["A"], &[110], &[1.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
    }

    // A cycle with no input produces no output.
    #[test]
    fn test_empty_inputs() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let result = execute_interval_join_cycle(&mut state, &[], &[], &config, 0, 0).unwrap();
        assert!(result.is_empty());
    }

    // Snapshot then restore must preserve buffered state and keep matching.
    #[test]
    fn test_checkpoint_roundtrip() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let right = right_batch(&["A"], &[110], &[1.0]);
        let _ =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 50, 50).unwrap();
        let cp = state
            .snapshot_checkpoint(
                &config.left_key,
                &config.left_time_column,
                &config.right_key,
                &config.right_time_column,
            )
            .unwrap();
        assert!(cp.left_buffer_rows > 0);
        assert!(cp.right_buffer_rows > 0);
        let mut restored = IntervalJoinState::from_checkpoint(
            &cp,
            &config.left_key,
            &config.left_time_column,
            &config.right_key,
            &config.right_time_column,
        )
        .unwrap();
        // The restored left row should match a fresh right arrival.
        let right2 = right_batch(&["A"], &[120], &[2.0]);
        let result =
            execute_interval_join_cycle(&mut restored, &[], &[right2], &config, 50, 50).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
    }

    // Left-side fixture with a nullable id column.
    fn left_batch_nullable(
        ids: &[Option<&str>],
        timestamps: &[i64],
        values: &[f64],
    ) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("ts", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(ids.to_vec())),
                Arc::new(Int64Array::from(timestamps.to_vec())),
                Arc::new(Float64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }

    // Right-side fixture with a nullable id column.
    fn right_batch_nullable(
        ids: &[Option<&str>],
        timestamps: &[i64],
        amounts: &[f64],
    ) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, true),
            Field::new("ts", DataType::Int64, false),
            Field::new("amount", DataType::Float64, false),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(ids.to_vec())),
                Arc::new(Int64Array::from(timestamps.to_vec())),
                Arc::new(Float64Array::from(amounts.to_vec())),
            ],
        )
        .unwrap()
    }

    // Null keys are never indexed, so null never joins null (SQL semantics).
    #[test]
    fn test_null_key_no_match() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch_nullable(&[Some("A"), None], &[100, 100], &[10.0, 20.0]);
        let right = right_batch_nullable(&[Some("A"), None], &[110, 110], &[1.0, 2.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
    }

    // Exceeding COMPACTION_THRESHOLD batches triggers compaction into one,
    // and the surviving rows must still be matchable afterwards.
    #[test]
    fn test_compaction_frees_batches() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        for i in 0i64..40 {
            let ts = i * 10 + 1000;
            #[allow(clippy::cast_precision_loss)]
            let left = left_batch(&["A"], &[ts], &[i as f64]);
            let _ = execute_interval_join_cycle(&mut state, &[left], &[], &config, 0, 0).unwrap();
        }
        assert!(state.left.batches.len() >= 40);
        let _ = execute_interval_join_cycle(&mut state, &[], &[], &config, 1300, 1300).unwrap();
        assert_eq!(state.left.batches.len(), 1);
        assert!(state.left.row_count > 0);
        let right = right_batch(&["A"], &[1350], &[99.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[], &[right], &config, 1300, 1300).unwrap();
        assert!(!result.is_empty());
    }

    // Same as make_config but for a LEFT outer join.
    fn make_left_config() -> StreamJoinConfig {
        StreamJoinConfig {
            left_key: "id".to_string(),
            right_key: "id".to_string(),
            left_time_column: "ts".to_string(),
            right_time_column: "ts".to_string(),
            left_table: "left_stream".to_string(),
            right_table: "right_stream".to_string(),
            time_bound: Duration::from_millis(100),
            join_type: StreamJoinType::Left,
        }
    }

    // Left join: once the right watermark passes, unmatched left rows are
    // emitted with null right columns.
    #[test]
    fn test_left_join_unmatched_emitted_with_nulls() {
        let config = make_left_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let right = right_batch(&["B"], &[100], &[1.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert!(result.is_empty());
        let result = execute_interval_join_cycle(&mut state, &[], &[], &config, 0, 300).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
        assert_eq!(result[0].num_columns(), 6);
        // First right-side column (index 3) is null-padded.
        assert!(result[0].column(3).is_null(0));
    }

    // Left join: a row that matched must not be re-emitted as unmatched.
    #[test]
    fn test_left_join_matched_not_re_emitted() {
        let config = make_left_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let right = right_batch(&["A"], &[110], &[1.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
        let result = execute_interval_join_cycle(&mut state, &[], &[], &config, 0, 300).unwrap();
        assert!(result.is_empty());
    }

    // Right join: unmatched right rows emitted with null left columns.
    #[test]
    fn test_right_join_unmatched_emitted() {
        let mut config = make_config();
        config.join_type = StreamJoinType::Right;
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["B"], &[100], &[99.0]);
        let right = right_batch(&["A"], &[100], &[1.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert!(result.is_empty());
        let result = execute_interval_join_cycle(&mut state, &[], &[], &config, 300, 0).unwrap();
        let total_rows: usize = result.iter().map(RecordBatch::num_rows).sum();
        assert!(total_rows >= 1);
        // Left columns are null for the unmatched right row.
        assert!(result[0].column(0).is_null(0));
    }

    // Full join: unmatched rows from both sides are eventually emitted.
    #[test]
    fn test_full_join_unmatched_both_sides() {
        let mut config = make_config();
        config.join_type = StreamJoinType::Full;
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let right = right_batch(&["A"], &[500], &[1.0]);
        let _ = execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        let result = execute_interval_join_cycle(&mut state, &[], &[], &config, 700, 700).unwrap();
        let total_rows: usize = result.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 2);
    }

    // Semi join: a left row matching several right rows is emitted once,
    // and only with left-side columns.
    #[test]
    fn test_semi_join_dedup() {
        let mut config = make_config();
        config.join_type = StreamJoinType::LeftSemi;
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let right = right_batch(&["A", "A"], &[110, 120], &[1.0, 2.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
        assert_eq!(result[0].num_columns(), 3);
    }

    // Anti join: only left rows that never matched are emitted, after the
    // right watermark guarantees no future match can occur.
    #[test]
    fn test_anti_join_unmatched_only() {
        let mut config = make_config();
        config.join_type = StreamJoinType::LeftAnti;
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A", "B"], &[100, 200], &[10.0, 20.0]);
        let right = right_batch(&["A"], &[110], &[1.0]);
        let result =
            execute_interval_join_cycle(&mut state, &[left], &[right], &config, 0, 0).unwrap();
        assert!(result.is_empty());
        let result = execute_interval_join_cycle(&mut state, &[], &[], &config, 0, 400).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 1);
        assert_eq!(result[0].num_columns(), 3);
    }

    // A CDC delete ("D" op) removes the corresponding buffered row.
    #[test]
    fn test_cdc_delete_removes_from_state() {
        let config = make_config();
        let mut state = IntervalJoinState::new();
        let left = left_batch(&["A"], &[100], &[10.0]);
        let _ = execute_interval_join_cycle(&mut state, &[left], &[], &config, 0, 0).unwrap();
        assert_eq!(state.left.row_count, 1);
        let del_schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Utf8, false),
            Field::new("ts", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
            Field::new("_op", DataType::Utf8, false),
        ]));
        let del_batch = RecordBatch::try_new(
            del_schema,
            vec![
                Arc::new(StringArray::from(vec!["A"])),
                Arc::new(Int64Array::from(vec![100])),
                Arc::new(Float64Array::from(vec![10.0])),
                Arc::new(StringArray::from(vec!["D"])),
            ],
        )
        .unwrap();
        let _ = execute_interval_join_cycle(&mut state, &[del_batch], &[], &config, 0, 0).unwrap();
        assert_eq!(state.left.row_count, 0);
    }
}