#![forbid(unsafe_code)]
mod cartesian;
mod hash_join;
use arrow::array::RecordBatch;
use llkv_result::{Error, Result as LlkvResult};
use llkv_storage::pager::Pager;
use llkv_table::table::Table;
use llkv_table::types::FieldId;
use simd_r_drive_entry_handle::EntryHandle;
use std::fmt;
pub use cartesian::cross_join_pair;
pub use hash_join::hash_join_stream;
/// The kind of join requested through [`JoinOptions::join_type`].
///
/// Variant names follow SQL join terminology. This type is only a selector:
/// the row-level semantics of each variant are implemented by the join
/// algorithm that consumes it, not in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join; displayed as `INNER`.
    Inner,
    /// Left join; displayed as `LEFT`.
    Left,
    /// Right join; displayed as `RIGHT`.
    Right,
    /// Full join; displayed as `FULL`.
    Full,
    /// Semi join; displayed as `SEMI`.
    Semi,
    /// Anti join; displayed as `ANTI`.
    Anti,
}

impl fmt::Display for JoinType {
    /// Writes the upper-case keyword for the variant (`INNER`, `LEFT`, ...).
    ///
    /// The keyword is emitted through [`fmt::Formatter::pad`] so that width,
    /// fill, alignment and precision flags (for example `{:<8}`) are honored,
    /// exactly as they are for a plain `&str`. Output for a bare `{}` is the
    /// keyword with nothing added.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            JoinType::Inner => "INNER",
            JoinType::Left => "LEFT",
            JoinType::Right => "RIGHT",
            JoinType::Full => "FULL",
            JoinType::Semi => "SEMI",
            JoinType::Anti => "ANTI",
        };
        // `write!(f, "...")` would ignore the caller's format spec; `pad` applies it.
        f.pad(keyword)
    }
}
/// One equality condition of a join: a field from the left side paired with a
/// field from the right side.
///
/// Construct with [`JoinKey::new`] or [`JoinKey::null_safe`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JoinKey {
    /// Field identifier on the left side of the join.
    pub left_field: FieldId,
    /// Field identifier on the right side of the join.
    pub right_field: FieldId,
    /// Flag set to `true` by [`JoinKey::null_safe`] and `false` by [`JoinKey::new`].
    ///
    /// NOTE(review): presumably makes two NULL key values compare as equal;
    /// the comparison itself lives in the join algorithm — confirm there.
    pub null_equals_null: bool,
}
impl JoinKey {
    /// Creates a key pairing `left_field` with `right_field`, with
    /// `null_equals_null` set to `false`.
    pub fn new(left_field: FieldId, right_field: FieldId) -> Self {
        let null_equals_null = false;
        Self {
            left_field,
            right_field,
            null_equals_null,
        }
    }

    /// Creates a key pairing `left_field` with `right_field`, with
    /// `null_equals_null` set to `true`.
    pub fn null_safe(left_field: FieldId, right_field: FieldId) -> Self {
        // Same as `new`, with only the null flag flipped.
        Self {
            null_equals_null: true,
            ..Self::new(left_field, right_field)
        }
    }
}
/// The algorithm used to execute a join, selected via [`JoinOptions::algorithm`].
///
/// [`JoinAlgorithm::Hash`] is the `Default`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum JoinAlgorithm {
    /// Hash join; displayed as `Hash`. This is the default algorithm.
    #[default]
    Hash,
    /// Sort-merge join; displayed as `SortMerge`.
    ///
    /// Selecting it currently makes `Table::join_stream` in this file return
    /// an error, because sort-merge is not implemented yet.
    SortMerge,
}

impl fmt::Display for JoinAlgorithm {
    /// Writes the variant name (`Hash` or `SortMerge`).
    ///
    /// The name is emitted through [`fmt::Formatter::pad`] so that width,
    /// fill, alignment and precision flags (for example `{:>10}`) are honored,
    /// exactly as they are for a plain `&str`. Output for a bare `{}` is the
    /// name with nothing added.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            JoinAlgorithm::Hash => "Hash",
            JoinAlgorithm::SortMerge => "SortMerge",
        };
        // `write!(f, "...")` would ignore the caller's format spec; `pad` applies it.
        f.pad(name)
    }
}
/// Settings for a single join invocation.
///
/// Start from [`JoinOptions::default`] or a per-join-type constructor such as
/// [`JoinOptions::left`], then adjust with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct JoinOptions {
    /// Which kind of join to perform.
    pub join_type: JoinType,
    /// Which algorithm executes the join.
    pub algorithm: JoinAlgorithm,
    /// Batch size; `validate_join_options` rejects `0`.
    ///
    /// NOTE(review): presumably the number of rows per produced batch —
    /// confirm against the algorithm implementation.
    pub batch_size: usize,
    /// Optional memory budget in bytes; `None` means no limit was configured.
    ///
    /// NOTE(review): how (or whether) the limit is enforced is not visible in
    /// this file.
    pub memory_limit_bytes: Option<usize>,
    /// Concurrency setting; `validate_join_options` rejects `0`.
    ///
    /// NOTE(review): presumably a worker/thread count — confirm against the
    /// algorithm implementation.
    pub concurrency: usize,
}
impl Default for JoinOptions {
fn default() -> Self {
Self {
join_type: JoinType::Inner,
algorithm: JoinAlgorithm::Hash,
batch_size: 8192,
memory_limit_bytes: None,
concurrency: 1,
}
}
}
impl JoinOptions {
pub fn inner() -> Self {
Self {
join_type: JoinType::Inner,
..Default::default()
}
}
pub fn left() -> Self {
Self {
join_type: JoinType::Left,
..Default::default()
}
}
pub fn right() -> Self {
Self {
join_type: JoinType::Right,
..Default::default()
}
}
pub fn full() -> Self {
Self {
join_type: JoinType::Full,
..Default::default()
}
}
pub fn semi() -> Self {
Self {
join_type: JoinType::Semi,
..Default::default()
}
}
pub fn anti() -> Self {
Self {
join_type: JoinType::Anti,
..Default::default()
}
}
pub fn with_algorithm(mut self, algorithm: JoinAlgorithm) -> Self {
self.algorithm = algorithm;
self
}
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
self
}
pub fn with_memory_limit(mut self, limit_bytes: usize) -> Self {
self.memory_limit_bytes = Some(limit_bytes);
self
}
pub fn with_concurrency(mut self, concurrency: usize) -> Self {
self.concurrency = concurrency;
self
}
}
/// Validates the key list supplied to a join.
///
/// At present every input is accepted and the argument is not inspected; the
/// unit tests in this file pin that both an empty slice and a non-empty slice
/// yield `Ok(())`.
///
/// # Errors
///
/// Never returns an error in the current implementation.
pub fn validate_join_keys(_keys: &[JoinKey]) -> LlkvResult<()> {
    // No per-key checks are performed here.
    Ok(())
}
pub fn validate_join_options(options: &JoinOptions) -> LlkvResult<()> {
if options.batch_size == 0 {
return Err(Error::InvalidArgumentError(
"join batch_size must be > 0".to_string(),
));
}
if options.concurrency == 0 {
return Err(Error::InvalidArgumentError(
"join concurrency must be > 0".to_string(),
));
}
Ok(())
}
/// Extension trait that exposes a streaming join entry point.
///
/// In this file it is implemented for `Table<P>`; that implementation
/// validates its arguments and then dispatches on `options.algorithm`.
pub trait TableJoinExt<P>
where
P: Pager<Blob = EntryHandle> + Send + Sync,
{
/// Joins `self` with `right` using `keys` and `options`, delivering output
/// through the `on_batch` callback instead of returning it.
///
/// `on_batch` receives `RecordBatch` values.
/// NOTE(review): presumably it is called once per produced result batch —
/// confirm against the algorithm implementation.
///
/// # Errors
///
/// Returns an `LlkvResult` error when the join cannot be carried out; the
/// concrete conditions are defined by the implementor.
fn join_stream<F>(
&self,
right: &Table<P>,
keys: &[JoinKey],
options: &JoinOptions,
on_batch: F,
) -> LlkvResult<()>
where
F: FnMut(RecordBatch);
}
impl<P> TableJoinExt<P> for Table<P>
where
P: Pager<Blob = EntryHandle> + Send + Sync,
{
fn join_stream<F>(
&self,
right: &Table<P>,
keys: &[JoinKey],
options: &JoinOptions,
on_batch: F,
) -> LlkvResult<()>
where
F: FnMut(RecordBatch),
{
validate_join_keys(keys)?;
validate_join_options(options)?;
match options.algorithm {
JoinAlgorithm::Hash => {
hash_join::hash_join_stream(self, right, keys, options, on_batch)
}
JoinAlgorithm::SortMerge => Err(Error::Internal(
"Sort-merge join not yet implemented; use JoinAlgorithm::Hash".to_string(),
)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` leaves null matching off; `null_safe` turns it on.
    #[test]
    fn test_join_key_constructors() {
        let plain = JoinKey::new(10, 20);
        assert_eq!(
            plain,
            JoinKey {
                left_field: 10,
                right_field: 20,
                null_equals_null: false,
            }
        );

        assert!(JoinKey::null_safe(10, 20).null_equals_null);
    }

    /// Per-type constructors set `join_type`; each builder overrides one field.
    #[test]
    fn test_join_options_builders() {
        assert_eq!(JoinOptions::inner().join_type, JoinType::Inner);

        let configured = JoinOptions::left()
            .with_algorithm(JoinAlgorithm::Hash)
            .with_batch_size(1024)
            .with_memory_limit(1_000_000)
            .with_concurrency(4);

        assert_eq!(configured.join_type, JoinType::Left);
        assert_eq!(configured.algorithm, JoinAlgorithm::Hash);
        assert_eq!(configured.batch_size, 1024);
        assert_eq!(configured.memory_limit_bytes, Some(1_000_000));
        assert_eq!(configured.concurrency, 4);
    }

    /// Both an empty and a single-element key list are accepted.
    #[test]
    fn test_validate_join_keys() {
        let no_keys: Vec<JoinKey> = Vec::new();
        assert!(validate_join_keys(&no_keys).is_ok());

        let one_key = vec![JoinKey::new(1, 2)];
        assert!(validate_join_keys(&one_key).is_ok());
    }

    /// Zero batch size and zero concurrency are rejected; defaults pass.
    #[test]
    fn test_validate_join_options() {
        let zero_batch = JoinOptions {
            batch_size: 0,
            ..JoinOptions::default()
        };
        assert!(validate_join_options(&zero_batch).is_err());

        let zero_concurrency = JoinOptions {
            concurrency: 0,
            ..JoinOptions::default()
        };
        assert!(validate_join_options(&zero_concurrency).is_err());

        assert!(validate_join_options(&JoinOptions::default()).is_ok());
    }

    /// Every join type renders as its upper-case keyword.
    #[test]
    fn test_join_type_display() {
        let expected = [
            (JoinType::Inner, "INNER"),
            (JoinType::Left, "LEFT"),
            (JoinType::Right, "RIGHT"),
            (JoinType::Full, "FULL"),
            (JoinType::Semi, "SEMI"),
            (JoinType::Anti, "ANTI"),
        ];
        for &(join_type, text) in expected.iter() {
            assert_eq!(join_type.to_string(), text);
        }
    }

    /// Every algorithm renders as its variant name.
    #[test]
    fn test_join_algorithm_display() {
        let expected = [
            (JoinAlgorithm::Hash, "Hash"),
            (JoinAlgorithm::SortMerge, "SortMerge"),
        ];
        for &(algorithm, text) in expected.iter() {
            assert_eq!(algorithm.to_string(), text);
        }
    }
}