#![deny(clippy::disallowed_types)]
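//! Batch ASOF join execution over Arrow `RecordBatch`es.
//!
//! The right side is indexed per join key as a sorted map of timestamps to
//! row indices; each left row probes that index according to the configured
//! direction (`Backward`, `Forward`, or `Nearest`) and optional tolerance.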
use std::collections::hash_map::DefaultHasher;
use std::collections::BTreeMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use rustc_hash::{FxHashMap, FxHashSet};
use arrow::array::{
Array, ArrayRef, Float64Array, Int64Array, RecordBatch, StringArray, TimestampMillisecondArray,
};
use arrow::compute::concat_batches;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use laminar_sql::parser::join_parser::AsofSqlDirection;
use laminar_sql::translator::{AsofJoinTranslatorConfig, AsofSqlJoinType};
use crate::error::DbError;
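/// Borrowed, typed view over a join-key column, limited to the key types the
/// batch ASOF join currently supports (Utf8 and Int64).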
enum KeyColumn<'a> {
Utf8(&'a StringArray),
Int64(&'a Int64Array),
}
impl KeyColumn<'_> {
fn is_null(&self, i: usize) -> bool {
match self {
KeyColumn::Utf8(a) => a.is_null(i),
KeyColumn::Int64(a) => a.is_null(i),
}
}
fn hash_at(&self, i: usize) -> Option<u64> {
if self.is_null(i) {
return None;
}
let mut hasher = DefaultHasher::new();
match self {
KeyColumn::Utf8(a) => a.value(i).hash(&mut hasher),
KeyColumn::Int64(a) => a.value(i).hash(&mut hasher),
}
Some(hasher.finish())
}
fn keys_equal(&self, i: usize, other: &KeyColumn<'_>, j: usize) -> bool {
if self.is_null(i) || other.is_null(j) {
return false;
}
match (self, other) {
(KeyColumn::Utf8(a), KeyColumn::Utf8(b)) => a.value(i) == b.value(j),
(KeyColumn::Int64(a), KeyColumn::Int64(b)) => a.value(i) == b.value(j),
_ => false,
}
}
}
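/// Looks up `col_name` in `batch` and returns a typed [`KeyColumn`] view,
/// erroring if the column is missing or has an unsupported data type.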
fn extract_key_column<'a>(
batch: &'a RecordBatch,
col_name: &str,
) -> Result<KeyColumn<'a>, DbError> {
let col_idx = batch
.schema()
.index_of(col_name)
.map_err(|_| DbError::Pipeline(format!("Column '{col_name}' not found")))?;
let array = batch.column(col_idx);
match array.data_type() {
DataType::Utf8 => {
let string_array = array
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| DbError::Pipeline(format!("Column '{col_name}' is not Utf8")))?;
Ok(KeyColumn::Utf8(string_array))
}
DataType::Int64 => {
let int_array = array
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| DbError::Pipeline(format!("Column '{col_name}' is not Int64")))?;
Ok(KeyColumn::Int64(int_array))
}
other => Err(DbError::Pipeline(format!(
"Unsupported key column type: {other}"
))),
}
}
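/// Executes an ASOF join over already-materialized left and right batches.
///
/// The right side is indexed per key hash as a `BTreeMap` from timestamp to
/// row indices; each left row then probes that index in the configured
/// direction and, if a tolerance is set, discards matches outside of it.
/// `Left` joins emit unmatched left rows with nulls on the right side;
/// `Inner` joins drop them.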
pub(crate) fn execute_asof_join_batch(
left_batches: &[RecordBatch],
right_batches: &[RecordBatch],
config: &AsofJoinTranslatorConfig,
) -> Result<RecordBatch, DbError> {
if left_batches.is_empty() {
let schema = if right_batches.is_empty() {
Arc::new(Schema::empty())
} else {
build_output_schema(
&Arc::new(Schema::empty()),
&right_batches[0].schema(),
config,
)
};
return Ok(RecordBatch::new_empty(schema));
}
let left_schema = left_batches[0].schema();
let left = concat_batches(&left_schema, left_batches)
.map_err(|e| DbError::query_pipeline_arrow("ASOF join (left)", &e))?;
let right_schema = if right_batches.is_empty() {
Arc::new(Schema::empty())
} else {
right_batches[0].schema()
};
let right = if right_batches.is_empty() {
RecordBatch::new_empty(right_schema.clone())
} else {
concat_batches(&right_schema, right_batches)
.map_err(|e| DbError::query_pipeline_arrow("ASOF join (right)", &e))?
};
let output_schema = build_output_schema(&left_schema, &right_schema, config);
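    // Index the right side: key hash -> (timestamp -> row indices), so each
    // left row can probe its key's timestamps via a range lookup.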
let mut right_index: FxHashMap<u64, BTreeMap<i64, Vec<usize>>> =
FxHashMap::with_capacity_and_hasher(right.num_rows(), rustc_hash::FxBuildHasher);
    let right_keys_col = if right.num_rows() > 0 {
        Some(extract_key_column(&right, &config.key_column)?)
    } else {
        None
    };
    if let Some(rk) = &right_keys_col {
        let right_timestamps = extract_column_as_timestamps(&right, &config.right_time_column)?;
        for (i, &ts) in right_timestamps.iter().enumerate() {
            if let Some(key_hash) = rk.hash_at(i) {
                right_index
                    .entry(key_hash)
                    .or_default()
                    .entry(ts)
                    .or_default()
                    .push(i);
            }
        }
    }
let left_keys_col = extract_key_column(&left, &config.key_column)?;
let left_timestamps = extract_column_as_timestamps(&left, &config.left_time_column)?;
let tolerance_ms = config
.tolerance
.map(|d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX));
let mut left_indices: Vec<usize> = Vec::with_capacity(left.num_rows());
let mut right_indices: Vec<Option<usize>> = Vec::with_capacity(left.num_rows());
for (left_idx, &left_ts) in left_timestamps.iter().enumerate() {
let Some(left_hash) = left_keys_col.hash_at(left_idx) else {
if config.join_type == AsofSqlJoinType::Left {
left_indices.push(left_idx);
right_indices.push(None);
}
continue;
};
let matched_right = right_index.get(&left_hash).and_then(|btree| {
let candidates = find_match(btree, left_ts, config.direction, tolerance_ms)?;
if let Some(ref rk) = right_keys_col {
for &candidate in &candidates {
if left_keys_col.keys_equal(left_idx, rk, candidate) {
return Some(candidate);
}
}
}
None
});
match (&config.join_type, matched_right) {
(_, Some(right_idx)) => {
left_indices.push(left_idx);
right_indices.push(Some(right_idx));
}
(AsofSqlJoinType::Left, None) => {
left_indices.push(left_idx);
right_indices.push(None);
}
(AsofSqlJoinType::Inner, None) => {
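                // Inner join: unmatched left rows are simply dropped.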
}
}
}
build_output_batch(
&left,
&right,
&left_indices,
&right_indices,
&output_schema,
config,
)
}
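/// Finds the candidate right-row indices whose timestamp is closest to
/// `left_ts` in the given direction, returning `None` when no timestamp
/// qualifies or the closest one falls outside the optional tolerance
/// (in milliseconds).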
fn find_match(
btree: &BTreeMap<i64, Vec<usize>>,
left_ts: i64,
direction: AsofSqlDirection,
tolerance_ms: Option<i64>,
) -> Option<Vec<usize>> {
let candidate = match direction {
AsofSqlDirection::Backward => {
btree
.range(..=left_ts)
.next_back()
.map(|(&ts, indices)| (ts, indices.clone()))
}
AsofSqlDirection::Forward => {
btree
.range(left_ts..)
.next()
.map(|(&ts, indices)| (ts, indices.clone()))
}
AsofSqlDirection::Nearest => {
let backward = btree
.range(..=left_ts)
.next_back()
.map(|(&ts, indices)| (ts, indices.clone()));
let forward = btree
.range(left_ts..)
.next()
.map(|(&ts, indices)| (ts, indices.clone()));
match (backward, forward) {
(Some((b_ts, b_indices)), Some((f_ts, f_indices))) => {
let b_diff = (left_ts - b_ts).abs();
let f_diff = (f_ts - left_ts).abs();
if b_diff <= f_diff {
Some((b_ts, b_indices))
} else {
Some((f_ts, f_indices))
}
}
(Some(b), None) => Some(b),
(None, Some(f)) => Some(f),
(None, None) => None,
}
}
};
candidate.and_then(|(right_ts, indices)| {
if let Some(tol) = tolerance_ms {
if (left_ts - right_ts).abs() <= tol {
Some(indices)
} else {
None
}
} else {
Some(indices)
}
})
}
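/// Reads a time column as `i64` milliseconds, accepting Int64, millisecond
/// timestamp, and Float64 (truncated) columns.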
fn extract_column_as_timestamps(batch: &RecordBatch, col_name: &str) -> Result<Vec<i64>, DbError> {
let col_idx = batch
.schema()
.index_of(col_name)
.map_err(|_| DbError::Pipeline(format!("Timestamp column '{col_name}' not found")))?;
let array = batch.column(col_idx);
match array.data_type() {
DataType::Int64 => {
let int_array = array
.as_any()
.downcast_ref::<Int64Array>()
.ok_or_else(|| DbError::Pipeline(format!("Column '{col_name}' is not Int64")))?;
Ok(int_array.values().to_vec())
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
let ts_array = array
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.ok_or_else(|| {
DbError::Pipeline(format!("Column '{col_name}' is not TimestampMillisecond"))
})?;
Ok(ts_array.values().to_vec())
}
DataType::Float64 => {
let f_array = array
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| DbError::Pipeline(format!("Column '{col_name}' is not Float64")))?;
            #[allow(clippy::cast_possible_truncation)]
            let millis: Vec<i64> = f_array.values().iter().map(|v| *v as i64).collect();
            Ok(millis)
}
other => Err(DbError::Pipeline(format!(
"Unsupported timestamp column type for '{col_name}': {other}"
))),
}
}
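/// Builds the joined schema: all left fields followed by the right fields
/// minus the key column, suffixed with the right table name on name clashes
/// and made nullable for `Left` joins.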
fn build_output_schema(
left_schema: &SchemaRef,
right_schema: &SchemaRef,
config: &AsofJoinTranslatorConfig,
) -> SchemaRef {
let mut fields: Vec<Field> = left_schema
.fields()
.iter()
.map(|f| f.as_ref().clone())
.collect();
let left_names: FxHashSet<&str> = left_schema
.fields()
.iter()
.map(|f| f.name().as_str())
.collect();
let make_nullable = config.join_type == AsofSqlJoinType::Left;
for field in right_schema.fields() {
if field.name() == &config.key_column {
continue;
}
let mut f = field.as_ref().clone();
if make_nullable {
f = f.with_nullable(true);
}
if left_names.contains(f.name().as_str()) {
let suffixed_name = format!("{}_{}", f.name(), config.right_table);
f = f.with_name(suffixed_name);
}
fields.push(f);
}
Arc::new(Schema::new(fields))
}
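/// Materializes the output batch by `take`-ing the matched left rows and
/// gathering the corresponding right rows (nulls where unmatched).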
fn build_output_batch(
left: &RecordBatch,
right: &RecordBatch,
left_indices: &[usize],
right_indices: &[Option<usize>],
output_schema: &SchemaRef,
config: &AsofJoinTranslatorConfig,
) -> Result<RecordBatch, DbError> {
let num_rows = left_indices.len();
let mut columns: Vec<ArrayRef> = Vec::with_capacity(left.num_columns() + right.num_columns());
#[allow(clippy::cast_possible_truncation)]
let left_idx_array =
arrow::array::UInt32Array::from(left_indices.iter().map(|&i| i as u32).collect::<Vec<_>>());
for col_idx in 0..left.num_columns() {
let array = left.column(col_idx);
let taken = arrow::compute::take(array, &left_idx_array, None)
.map_err(|e| DbError::query_pipeline_arrow("ASOF join (left take)", &e))?;
columns.push(taken);
}
let right_schema = right.schema();
for col_idx in 0..right.num_columns() {
let field_name = right_schema.field(col_idx).name();
if field_name == &config.key_column {
continue;
}
let array = right.column(col_idx);
let taken = take_with_nulls(array, right_indices, num_rows)?;
columns.push(taken);
}
RecordBatch::try_new(output_schema.clone(), columns)
.map_err(|e| DbError::query_pipeline_arrow("ASOF join (result)", &e))
}
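/// Gathers rows from `array` at the given optional indices, producing nulls
/// for `None` entries; an empty source array yields an all-null column.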
fn take_with_nulls(
array: &dyn Array,
indices: &[Option<usize>],
num_rows: usize,
) -> Result<ArrayRef, DbError> {
if array.is_empty() {
return Ok(arrow::array::new_null_array(array.data_type(), num_rows));
}
#[allow(clippy::cast_possible_truncation)]
let index_array = arrow::array::UInt32Array::from(
indices
.iter()
.map(|opt| opt.map(|i| i as u32))
.collect::<Vec<Option<u32>>>(),
);
arrow::compute::take(array, &index_array, None)
.map_err(|e| DbError::query_pipeline_arrow("ASOF join (right take)", &e))
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
fn trades_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, false),
Field::new("trade_ts", DataType::Int64, false),
Field::new("price", DataType::Float64, false),
]));
RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(vec!["AAPL", "AAPL", "GOOG", "AAPL"])),
Arc::new(Int64Array::from(vec![100, 200, 150, 300])),
Arc::new(Float64Array::from(vec![150.0, 152.0, 2800.0, 155.0])),
],
)
.unwrap()
}
fn quotes_batch() -> RecordBatch {
let schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, false),
Field::new("quote_ts", DataType::Int64, false),
Field::new("bid", DataType::Float64, false),
Field::new("ask", DataType::Float64, false),
]));
RecordBatch::try_new(
schema,
vec![
Arc::new(StringArray::from(vec![
"AAPL", "AAPL", "GOOG", "AAPL", "GOOG",
])),
Arc::new(Int64Array::from(vec![90, 180, 140, 250, 160])),
Arc::new(Float64Array::from(vec![
149.0, 151.0, 2790.0, 153.0, 2795.0,
])),
Arc::new(Float64Array::from(vec![
150.0, 152.0, 2800.0, 154.0, 2805.0,
])),
],
)
.unwrap()
}
fn backward_config() -> AsofJoinTranslatorConfig {
AsofJoinTranslatorConfig {
left_table: "trades".to_string(),
right_table: "quotes".to_string(),
key_column: "symbol".to_string(),
left_time_column: "trade_ts".to_string(),
right_time_column: "quote_ts".to_string(),
direction: AsofSqlDirection::Backward,
tolerance: None,
join_type: AsofSqlJoinType::Left,
}
}
#[test]
fn test_backward_join_basic() {
let config = backward_config();
let result =
execute_asof_join_batch(&[trades_batch()], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 4);
assert_eq!(result.num_columns(), 6);
let quote_ts = result
.column(3)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
        assert_eq!(quote_ts.value(0), 90);
        assert_eq!(quote_ts.value(1), 180);
    }
#[test]
fn test_forward_join_basic() {
let mut config = backward_config();
config.direction = AsofSqlDirection::Forward;
let result =
execute_asof_join_batch(&[trades_batch()], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 4);
let quote_ts = result
.column(3)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
        assert_eq!(quote_ts.value(0), 180);
        assert_eq!(quote_ts.value(1), 250);
    }
#[test]
fn test_left_join_emits_unmatched_with_nulls() {
let trades_schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, false),
Field::new("trade_ts", DataType::Int64, false),
Field::new("price", DataType::Float64, false),
]));
let trades = RecordBatch::try_new(
trades_schema,
vec![
Arc::new(StringArray::from(vec!["MSFT"])),
Arc::new(Int64Array::from(vec![100])),
Arc::new(Float64Array::from(vec![300.0])),
],
)
.unwrap();
let config = backward_config();
let result = execute_asof_join_batch(&[trades], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 1);
        assert!(result.column(3).is_null(0));
    }
#[test]
fn test_inner_join_skips_unmatched() {
let trades_schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, false),
Field::new("trade_ts", DataType::Int64, false),
Field::new("price", DataType::Float64, false),
]));
let trades = RecordBatch::try_new(
trades_schema,
vec![
Arc::new(StringArray::from(vec!["MSFT", "AAPL"])),
Arc::new(Int64Array::from(vec![100, 200])),
Arc::new(Float64Array::from(vec![300.0, 152.0])),
],
)
.unwrap();
let mut config = backward_config();
config.join_type = AsofSqlJoinType::Inner;
let result = execute_asof_join_batch(&[trades], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 1);
}
#[test]
fn test_tolerance_filtering() {
let mut config = backward_config();
config.tolerance = Some(Duration::from_millis(15));
let result =
execute_asof_join_batch(&[trades_batch()], &[quotes_batch()], &config).unwrap();
        assert_eq!(result.num_rows(), 4);
        let quote_ts = result
.column(3)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
        assert_eq!(quote_ts.value(0), 90);
        assert!(result.column(3).is_null(1));
        assert_eq!(quote_ts.value(2), 140);
        assert!(result.column(3).is_null(3));
    }
#[test]
fn test_empty_left_input() {
let config = backward_config();
let result = execute_asof_join_batch(&[], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 0);
}
#[test]
fn test_empty_right_input() {
let config = backward_config();
let result = execute_asof_join_batch(&[trades_batch()], &[], &config).unwrap();
assert_eq!(result.num_rows(), 4);
}
#[test]
fn test_multiple_keys() {
let config = backward_config();
let result =
execute_asof_join_batch(&[trades_batch()], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 4);
let symbols = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let quote_ts = result
.column(3)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(symbols.value(2), "GOOG");
        assert_eq!(quote_ts.value(2), 140);
    }
#[test]
fn test_multiple_right_matches_picks_closest() {
let config = backward_config();
let result =
execute_asof_join_batch(&[trades_batch()], &[quotes_batch()], &config).unwrap();
let quote_ts = result
.column(3)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
assert_eq!(quote_ts.value(1), 180);
}
#[test]
fn test_nearest_join() {
let mut config = backward_config();
config.direction = AsofSqlDirection::Nearest;
let result =
execute_asof_join_batch(&[trades_batch()], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 4);
let quote_ts = result
.column(3)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
        assert_eq!(quote_ts.value(0), 90);
        assert_eq!(quote_ts.value(1), 180);
        assert_eq!(quote_ts.value(2), 140);
        assert_eq!(quote_ts.value(3), 250);
    }
#[test]
fn test_hash_collision_different_keys() {
let trades_schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, false),
Field::new("trade_ts", DataType::Int64, false),
Field::new("price", DataType::Float64, false),
]));
let trades = RecordBatch::try_new(
trades_schema,
vec![
Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
                Arc::new(Int64Array::from(vec![100, 100])),
                Arc::new(Float64Array::from(vec![150.0, 2800.0])),
],
)
.unwrap();
let quotes_schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, false),
Field::new("quote_ts", DataType::Int64, false),
Field::new("bid", DataType::Float64, false),
]));
let quotes = RecordBatch::try_new(
quotes_schema,
vec![
Arc::new(StringArray::from(vec!["AAPL", "GOOG"])),
                Arc::new(Int64Array::from(vec![100, 100])),
                Arc::new(Float64Array::from(vec![149.0, 2790.0])),
],
)
.unwrap();
let config = backward_config();
let result = execute_asof_join_batch(&[trades], &[quotes], &config).unwrap();
assert_eq!(result.num_rows(), 2);
let symbols = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let bids = result
.column(4)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert_eq!(symbols.value(0), "AAPL");
assert!((bids.value(0) - 149.0).abs() < f64::EPSILON);
assert_eq!(symbols.value(1), "GOOG");
assert!((bids.value(1) - 2790.0).abs() < f64::EPSILON);
}
#[test]
fn test_null_key_no_match() {
let trades_schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, true),
Field::new("trade_ts", DataType::Int64, false),
Field::new("price", DataType::Float64, false),
]));
let trades = RecordBatch::try_new(
trades_schema,
vec![
Arc::new(StringArray::from(vec![Some("AAPL"), None])),
Arc::new(Int64Array::from(vec![100, 100])),
Arc::new(Float64Array::from(vec![150.0, 200.0])),
],
)
.unwrap();
let mut config = backward_config();
config.join_type = AsofSqlJoinType::Inner;
let result = execute_asof_join_batch(&[trades], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 1);
let symbols = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(symbols.value(0), "AAPL");
}
#[test]
fn test_null_key_left_join_emits_nulls() {
let trades_schema = Arc::new(Schema::new(vec![
Field::new("symbol", DataType::Utf8, true),
Field::new("trade_ts", DataType::Int64, false),
Field::new("price", DataType::Float64, false),
]));
let trades = RecordBatch::try_new(
trades_schema,
vec![
Arc::new(StringArray::from(vec![Some("AAPL"), None])),
Arc::new(Int64Array::from(vec![100, 100])),
Arc::new(Float64Array::from(vec![150.0, 200.0])),
],
)
.unwrap();
let config = backward_config();
let result = execute_asof_join_batch(&[trades], &[quotes_batch()], &config).unwrap();
assert_eq!(result.num_rows(), 2);
let symbols = result
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
assert_eq!(symbols.value(0), "AAPL");
        assert!(result.column(0).is_null(1));
        assert!(result.column(3).is_null(1));
    }
}