use std::sync::Arc;
use arrow::array::RecordBatch;
use async_trait::async_trait;
use datafusion::prelude::SessionContext;
use laminar_sql::translator::AsofJoinTranslatorConfig;
use crate::asof_batch::{execute_asof_join_with_state, AsofBufferCheckpoint, AsofRightBuffer};
use crate::error::DbError;
use crate::operator_graph::{GraphOperator, OperatorCheckpoint};
use crate::sql_analysis::CompiledPostProjection;
/// Streaming ASOF join operator: buffers right-side rows across processing
/// cycles and joins each incoming left batch against that buffer.
pub(crate) struct AsofJoinOperator {
    // Operator name; used to prefix error messages.
    op_name: Arc<str>,
    // Join configuration (key/time columns, direction, tolerance, join type).
    config: AsofJoinTranslatorConfig,
    // Optional SQL projection applied to the joined output.
    projection_sql: Option<Arc<str>>,
    // DataFusion session used to run the post-projection.
    ctx: SessionContext,
    // Lazily compiled form of `projection_sql`; populated on first use.
    compiled_post_proj: Option<CompiledPostProjection>,
    // Set when compiling the post-projection failed, to avoid retrying.
    post_proj_compile_failed: bool,
    // Right-side rows retained across cycles for cross-cycle matching.
    right_buffer: AsofRightBuffer,
    // Highest eviction cutoff already applied to `right_buffer`;
    // starts at i64::MIN so the first real watermark triggers eviction.
    last_evicted_watermark: i64,
}
impl AsofJoinOperator {
    /// Builds an operator with an empty right-side buffer and no compiled
    /// post-projection; the eviction watermark starts at `i64::MIN`.
    pub(crate) fn new(
        name: &str,
        config: AsofJoinTranslatorConfig,
        projection_sql: Option<Arc<str>>,
        ctx: SessionContext,
    ) -> Self {
        Self {
            op_name: Arc::from(name),
            compiled_post_proj: None,
            post_proj_compile_failed: false,
            right_buffer: AsofRightBuffer::default(),
            last_evicted_watermark: i64::MIN,
            config,
            projection_sql,
            ctx,
        }
    }

    /// Runs the optional post-projection SQL over `batches`, caching the
    /// compiled projection (and any compile failure) across calls.
    async fn apply_projection(
        &mut self,
        batches: Vec<RecordBatch>,
    ) -> Result<Vec<RecordBatch>, DbError> {
        let sql = self.projection_sql.as_deref();
        super::apply_post_projection(
            &self.ctx,
            &self.op_name,
            "__asof_tmp",
            sql,
            &mut self.compiled_post_proj,
            &mut self.post_proj_compile_failed,
            batches,
        )
        .await
    }
}
#[async_trait]
impl GraphOperator for AsofJoinOperator {
    /// One processing cycle: ingest right-side rows into the buffer, evict
    /// rows the left watermark can no longer reach, then join any left rows
    /// and apply the optional post-projection.
    async fn process(
        &mut self,
        inputs: &[Vec<RecordBatch>],
        watermarks: &[i64],
    ) -> Result<Vec<RecordBatch>, DbError> {
        let left: &[RecordBatch] = match inputs.first() {
            Some(v) => v,
            None => &[],
        };
        let right: &[RecordBatch] = match inputs.get(1) {
            Some(v) => v,
            None => &[],
        };

        // Right rows are buffered even when the left side is empty so that
        // left batches arriving in later cycles can still match them.
        self.right_buffer.ingest(
            right,
            &self.config.key_column,
            &self.config.right_time_column,
        )?;

        // Without a tolerance, every buffered row remains reachable forever
        // (the saturating subtraction below then keeps the cutoff at MIN).
        let lookback_ms = match self.config.tolerance {
            Some(d) => i64::try_from(d.as_millis()).unwrap_or(i64::MAX),
            None => i64::MAX,
        };
        let left_wm = *watermarks.first().unwrap_or(&i64::MIN);
        let cutoff = left_wm.saturating_sub(lookback_ms);
        if cutoff > self.last_evicted_watermark {
            self.right_buffer.evict_before(cutoff)?;
            self.last_evicted_watermark = cutoff;
        }

        if left.is_empty() {
            return Ok(Vec::new());
        }
        let joined = execute_asof_join_with_state(left, &self.right_buffer, &self.config)?;
        if joined.num_rows() == 0 {
            return Ok(Vec::new());
        }
        self.apply_projection(vec![joined]).await
    }

    /// Snapshots the right-side buffer plus the eviction watermark as JSON.
    fn checkpoint(&mut self) -> Result<Option<OperatorCheckpoint>, DbError> {
        let snapshot = self
            .right_buffer
            .snapshot_checkpoint(self.last_evicted_watermark)?;
        match serde_json::to_vec(&snapshot) {
            Ok(data) => Ok(Some(OperatorCheckpoint { data })),
            Err(e) => Err(DbError::Pipeline(format!(
                "ASOF join [{}]: checkpoint serialization: {e}",
                self.op_name
            ))),
        }
    }

    /// Rebuilds the buffer and eviction watermark from a JSON checkpoint.
    fn restore(&mut self, checkpoint: OperatorCheckpoint) -> Result<(), DbError> {
        let decoded: AsofBufferCheckpoint =
            serde_json::from_slice(&checkpoint.data).map_err(|e| {
                DbError::Pipeline(format!(
                    "ASOF join [{}]: checkpoint deserialization: {e}",
                    self.op_name
                ))
            })?;
        let (restored_buffer, restored_wm) = AsofRightBuffer::from_checkpoint(&decoded)?;
        self.right_buffer = restored_buffer;
        self.last_evicted_watermark = restored_wm;
        Ok(())
    }

    /// State size is dominated by the buffered right-side rows.
    fn estimated_state_bytes(&self) -> usize {
        self.right_buffer.estimated_size_bytes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use laminar_sql::parser::join_parser::AsofSqlDirection;
    use laminar_sql::translator::AsofSqlJoinType;

    /// Left-side fixture: AAPL trade at t=100, GOOG trade at t=150.
    fn trades_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("trade_ts", DataType::Int64, false),
            Field::new("price", DataType::Float64, false),
        ]));
        let columns: Vec<Arc<dyn arrow::array::Array>> = vec![
            Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
            Arc::new(Int64Array::from(vec![100, 150])),
            Arc::new(Float64Array::from(vec![150.0, 2800.0])),
        ];
        RecordBatch::try_new(schema, columns).unwrap()
    }

    /// Right-side fixture: AAPL quote at t=90, GOOG quote at t=140 — each
    /// strictly before its matching trade for a backward join.
    fn quotes_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("symbol", DataType::Utf8, false),
            Field::new("quote_ts", DataType::Int64, false),
            Field::new("bid", DataType::Float64, false),
        ]));
        let columns: Vec<Arc<dyn arrow::array::Array>> = vec![
            Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
            Arc::new(Int64Array::from(vec![90, 140])),
            Arc::new(Float64Array::from(vec![149.0, 2790.0])),
        ];
        RecordBatch::try_new(schema, columns).unwrap()
    }

    /// Backward LEFT ASOF join of trades on quotes keyed by symbol,
    /// with no tolerance unless a test overrides it.
    fn test_config() -> AsofJoinTranslatorConfig {
        AsofJoinTranslatorConfig {
            left_table: "trades".to_string(),
            right_table: "quotes".to_string(),
            key_column: "symbol".to_string(),
            left_time_column: "trade_ts".to_string(),
            right_time_column: "quote_ts".to_string(),
            direction: AsofSqlDirection::Backward,
            tolerance: None,
            join_type: AsofSqlJoinType::Left,
        }
    }

    #[tokio::test]
    async fn test_basic_asof_join() {
        let ctx = laminar_sql::create_session_context();
        let mut op = AsofJoinOperator::new("test_asof", test_config(), None, ctx);
        let out = op
            .process(&[vec![trades_batch()], vec![quotes_batch()]], &[0, 0])
            .await
            .unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].num_rows(), 2);
    }

    #[tokio::test]
    async fn test_cross_cycle_match() {
        let ctx = laminar_sql::create_session_context();
        let mut op = AsofJoinOperator::new("test_asof", test_config(), None, ctx);

        // Cycle 1: quotes only — nothing to join yet, but they get buffered.
        let out = op
            .process(&[vec![], vec![quotes_batch()]], &[0, 0])
            .await
            .unwrap();
        assert!(out.is_empty());

        // Cycle 2: trades arrive and match the previously buffered quotes.
        let out = op
            .process(&[vec![trades_batch()], vec![]], &[0, 0])
            .await
            .unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].num_rows(), 2);
    }

    #[tokio::test]
    async fn test_eviction_on_watermark_advance() {
        let mut config = test_config();
        config.tolerance = Some(std::time::Duration::from_millis(50));
        let ctx = laminar_sql::create_session_context();
        let mut op = AsofJoinOperator::new("test_asof", config, None, ctx);

        // Buffer the quotes, then advance the watermark far enough that the
        // 50 ms tolerance evicts all of them (cutoff 200 - 50 = 150 > 140).
        op.process(&[vec![], vec![quotes_batch()]], &[0, 0])
            .await
            .unwrap();
        op.process(&[vec![], vec![]], &[200, 200]).await.unwrap();

        let out = op
            .process(&[vec![trades_batch()], vec![]], &[200, 200])
            .await
            .unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].num_rows(), 2);

        // Columns past the three left-side ones come from the evicted right
        // side and must be null in both rows.
        let right_start = 3;
        for col_idx in right_start..out[0].num_columns() {
            assert!(
                out[0].column(col_idx).is_null(0),
                "col {col_idx} row 0 should be null"
            );
            assert!(
                out[0].column(col_idx).is_null(1),
                "col {col_idx} row 1 should be null"
            );
        }
    }

    #[tokio::test]
    async fn test_checkpoint_roundtrip() {
        let ctx = laminar_sql::create_session_context();
        let mut op = AsofJoinOperator::new("test_asof", test_config(), None, ctx.clone());
        op.process(&[vec![], vec![quotes_batch()]], &[0, 0])
            .await
            .unwrap();

        let cp = op.checkpoint().unwrap().expect("should have state");
        assert!(!cp.data.is_empty());

        // A fresh operator restored from the checkpoint joins as if it had
        // ingested the quotes itself.
        let mut restored = AsofJoinOperator::new("test_asof", test_config(), None, ctx);
        restored.restore(cp).unwrap();
        let out = restored
            .process(&[vec![trades_batch()], vec![]], &[0, 0])
            .await
            .unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].num_rows(), 2);
    }

    #[tokio::test]
    async fn test_empty_left() {
        let ctx = laminar_sql::create_session_context();
        let mut op = AsofJoinOperator::new("test_asof", test_config(), None, ctx);
        let out = op
            .process(&[vec![], vec![quotes_batch()]], &[0, 0])
            .await
            .unwrap();
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn test_empty_inputs() {
        let ctx = laminar_sql::create_session_context();
        let mut op = AsofJoinOperator::new("test_asof", test_config(), None, ctx);
        let out = op.process(&[], &[0]).await.unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn test_name() {
        let ctx = laminar_sql::create_session_context();
        let op = AsofJoinOperator::new("my_asof_query", test_config(), None, ctx);
        assert_eq!(&*op.op_name, "my_asof_query");
    }

    #[test]
    fn test_estimated_state_bytes_starts_zero() {
        let ctx = laminar_sql::create_session_context();
        let op = AsofJoinOperator::new("test_asof", test_config(), None, ctx);
        assert_eq!(op.estimated_state_bytes(), 0);
    }
}