use std::sync::Arc;
use arrow::array::RecordBatch;
use async_trait::async_trait;
use datafusion::prelude::SessionContext;
use laminar_sql::translator::StreamJoinConfig;
use crate::aggregate_state::JoinStateCheckpoint;
use crate::error::DbError;
use crate::interval_join::{execute_interval_join_cycle, IntervalJoinState};
use crate::operator_graph::{GraphOperator, OperatorCheckpoint};
use crate::stream_executor::CompiledPostProjection;
pub(crate) struct IntervalJoinOperator {
op_name: Arc<str>,
config: StreamJoinConfig,
state: IntervalJoinState,
projection_sql: Option<Arc<str>>,
ctx: SessionContext,
compiled_post_proj: Option<CompiledPostProjection>,
post_proj_compile_failed: bool,
}
impl IntervalJoinOperator {
pub(crate) fn new(
name: &str,
config: StreamJoinConfig,
projection_sql: Option<Arc<str>>,
ctx: SessionContext,
) -> Self {
Self {
op_name: Arc::from(name),
config,
state: IntervalJoinState::new(),
projection_sql,
ctx,
compiled_post_proj: None,
post_proj_compile_failed: false,
}
}
async fn apply_projection(
&mut self,
join_result: Vec<RecordBatch>,
) -> Result<Vec<RecordBatch>, DbError> {
super::apply_post_projection(
&self.ctx,
&self.op_name,
"__interval_tmp",
self.projection_sql.as_deref(),
&mut self.compiled_post_proj,
&mut self.post_proj_compile_failed,
join_result,
)
.await
}
}
#[async_trait]
impl GraphOperator for IntervalJoinOperator {
async fn process(
&mut self,
inputs: &[Vec<RecordBatch>],
watermark: i64,
) -> Result<Vec<RecordBatch>, DbError> {
let left_batches = inputs.first().map_or(&[][..], Vec::as_slice);
let right_batches = inputs.get(1).map_or(&[][..], Vec::as_slice);
let join_result = execute_interval_join_cycle(
&mut self.state,
left_batches,
right_batches,
&self.config,
watermark,
)?;
self.apply_projection(join_result).await
}
fn checkpoint(&mut self) -> Result<Option<OperatorCheckpoint>, DbError> {
let cp = self.state.snapshot_checkpoint(
&self.config.left_key,
&self.config.left_time_column,
&self.config.right_key,
&self.config.right_time_column,
)?;
let data = serde_json::to_vec(&cp).map_err(|e| {
DbError::Pipeline(format!(
"interval join [{}]: checkpoint serialization: {e}",
self.op_name
))
})?;
Ok(Some(OperatorCheckpoint { data }))
}
fn restore(&mut self, checkpoint: OperatorCheckpoint) -> Result<(), DbError> {
let cp: JoinStateCheckpoint = serde_json::from_slice(&checkpoint.data).map_err(|e| {
DbError::Pipeline(format!(
"interval join [{}]: checkpoint deserialization: {e}",
self.op_name
))
})?;
self.state = IntervalJoinState::from_checkpoint(
&cp,
&self.config.left_key,
&self.config.left_time_column,
&self.config.right_key,
&self.config.right_time_column,
)?;
Ok(())
}
fn estimated_state_bytes(&self) -> usize {
self.state.estimated_size_bytes()
}
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Float64Array, Int64Array, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use laminar_sql::translator::StreamJoinType;
use std::time::Duration;
fn test_config() -> StreamJoinConfig {
StreamJoinConfig {
left_key: "id".to_string(),
right_key: "id".to_string(),
left_time_column: "ts".to_string(),
right_time_column: "ts".to_string(),
left_table: "left_stream".to_string(),
right_table: "right_stream".to_string(),
time_bound: Duration::from_millis(100),
join_type: StreamJoinType::Inner,
}
}
fn left_batch(ids: &[&str], timestamps: &[i64], values: &[f64]) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("ts", DataType::Int64, false),
Field::new("price", DataType::Float64, false),
]));
RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(ids.to_vec())),
Arc::new(Int64Array::from(timestamps.to_vec())),
Arc::new(Float64Array::from(values.to_vec())),
],
)
.unwrap()
}
fn right_batch(ids: &[&str], timestamps: &[i64], amounts: &[f64]) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Utf8, false),
Field::new("ts", DataType::Int64, false),
Field::new("amount", DataType::Float64, false),
]));
RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(ids.to_vec())),
Arc::new(Int64Array::from(timestamps.to_vec())),
Arc::new(Float64Array::from(amounts.to_vec())),
],
)
.unwrap()
}
#[tokio::test]
async fn test_basic_interval_join() {
let ctx = laminar_sql::create_session_context();
let mut op = IntervalJoinOperator::new("test_interval", test_config(), None, ctx);
let left = left_batch(&["A", "B"], &[100, 200], &[10.0, 20.0]);
let right = right_batch(&["A", "B"], &[110, 250], &[1.0, 2.0]);
let result = op.process(&[vec![left], vec![right]], 0).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].num_rows(), 2);
}
#[tokio::test]
async fn test_cross_cycle_matching() {
let ctx = laminar_sql::create_session_context();
let mut op = IntervalJoinOperator::new("test_interval", test_config(), None, ctx);
let left = left_batch(&["A"], &[100], &[10.0]);
let result = op.process(&[vec![left], vec![]], 0).await.unwrap();
assert!(result.is_empty());
let right = right_batch(&["A"], &[150], &[1.0]);
let result = op.process(&[vec![], vec![right]], 0).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].num_rows(), 1);
}
#[tokio::test]
async fn test_empty_inputs() {
let ctx = laminar_sql::create_session_context();
let mut op = IntervalJoinOperator::new("test_interval", test_config(), None, ctx);
let result = op.process(&[], 0).await.unwrap();
assert!(result.is_empty());
}
#[tokio::test]
async fn test_checkpoint_roundtrip() {
let ctx = laminar_sql::create_session_context();
let mut op = IntervalJoinOperator::new("test_interval", test_config(), None, ctx.clone());
let left = left_batch(&["A"], &[100], &[10.0]);
let right = right_batch(&["A"], &[110], &[1.0]);
let _ = op.process(&[vec![left], vec![right]], 50).await.unwrap();
let cp = op.checkpoint().unwrap().expect("should have state");
assert!(!cp.data.is_empty());
let mut op2 = IntervalJoinOperator::new("test_interval", test_config(), None, ctx);
op2.restore(cp).unwrap();
let right2 = right_batch(&["A"], &[120], &[2.0]);
let result = op2.process(&[vec![], vec![right2]], 50).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].num_rows(), 1);
}
#[test]
fn test_estimated_state_bytes() {
let ctx = laminar_sql::create_session_context();
let op = IntervalJoinOperator::new("test_interval", test_config(), None, ctx);
assert_eq!(op.estimated_state_bytes(), 0);
}
#[test]
fn test_name() {
let ctx = laminar_sql::create_session_context();
let op = IntervalJoinOperator::new("my_interval_join", test_config(), None, ctx);
assert_eq!(&*op.op_name, "my_interval_join");
}
}