llkv_join/
lib.rs

1#![forbid(unsafe_code)]
2
3mod hash_join;
4
5use arrow::array::RecordBatch;
6use llkv_result::{Error, Result as LlkvResult};
7use llkv_storage::pager::Pager;
8use llkv_table::table::Table;
9use llkv_table::types::FieldId;
10use simd_r_drive_entry_handle::EntryHandle;
11use std::fmt;
12
13pub use hash_join::hash_join_stream;
14
/// The flavor of relational join to execute.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Produce only the row pairs whose keys match.
    Inner,
    /// Produce every left row; right columns are NULL where nothing matched.
    Left,
    /// Produce every right row; left columns are NULL where nothing matched.
    Right,
    /// Produce every row of both inputs; the unmatched side is filled with NULLs.
    Full,
    /// Produce the left rows that found one or more matches (right columns omitted).
    Semi,
    /// Produce the left rows that found zero matches (right columns omitted).
    Anti,
}
31
32impl fmt::Display for JoinType {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            JoinType::Inner => write!(f, "INNER"),
36            JoinType::Left => write!(f, "LEFT"),
37            JoinType::Right => write!(f, "RIGHT"),
38            JoinType::Full => write!(f, "FULL"),
39            JoinType::Semi => write!(f, "SEMI"),
40            JoinType::Anti => write!(f, "ANTI"),
41        }
42    }
43}
44
45/// Join key pair describing which columns to equate.
46#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct JoinKey {
48    /// Field ID from the left table.
49    pub left_field: FieldId,
50    /// Field ID from the right table.
51    pub right_field: FieldId,
52    /// If true, NULL == NULL for this key (SQL-style NULL-safe equality).
53    /// If false, NULL != NULL (Arrow default).
54    pub null_equals_null: bool,
55}
56
57impl JoinKey {
58    /// Create a join key with standard Arrow null semantics (NULL != NULL).
59    pub fn new(left_field: FieldId, right_field: FieldId) -> Self {
60        Self {
61            left_field,
62            right_field,
63            null_equals_null: false,
64        }
65    }
66
67    /// Create a join key with SQL-style NULL-safe equality (NULL == NULL).
68    pub fn null_safe(left_field: FieldId, right_field: FieldId) -> Self {
69        Self {
70            left_field,
71            right_field,
72            null_equals_null: true,
73        }
74    }
75}
76
/// Execution strategy used to carry out a join.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinAlgorithm {
    /// Hash join: one input is loaded into a hash table and the other input
    /// probes it.
    ///
    /// O(N+M) complexity - suitable for production workloads. This is the
    /// default and the recommended choice for all equality joins.
    #[default]
    Hash,
    /// Sort-merge join: both inputs are sorted and then merged.
    ///
    /// Good for pre-sorted inputs or when memory is constrained.
    /// Not yet implemented.
    SortMerge,
}
90
91impl fmt::Display for JoinAlgorithm {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        match self {
94            JoinAlgorithm::Hash => write!(f, "Hash"),
95            JoinAlgorithm::SortMerge => write!(f, "SortMerge"),
96        }
97    }
98}
99
100/// Options controlling join execution.
101#[derive(Clone, Debug)]
102pub struct JoinOptions {
103    /// Type of join to perform.
104    pub join_type: JoinType,
105    /// Algorithm to use. Planner may override this based on table sizes.
106    pub algorithm: JoinAlgorithm,
107    /// Batch size for output RecordBatches.
108    pub batch_size: usize,
109    /// Memory limit in bytes for hash table (hash join only).
110    /// When exceeded, algorithm will partition and spill to disk.
111    pub memory_limit_bytes: Option<usize>,
112    /// Concurrency hint: number of threads for parallel partitions.
113    pub concurrency: usize,
114}
115
116impl Default for JoinOptions {
117    fn default() -> Self {
118        Self {
119            join_type: JoinType::Inner,
120            algorithm: JoinAlgorithm::Hash,
121            batch_size: 8192,
122            memory_limit_bytes: None,
123            concurrency: 1,
124        }
125    }
126}
127
128impl JoinOptions {
129    /// Create options for an inner join with default settings.
130    pub fn inner() -> Self {
131        Self {
132            join_type: JoinType::Inner,
133            ..Default::default()
134        }
135    }
136
137    /// Create options for a left outer join with default settings.
138    pub fn left() -> Self {
139        Self {
140            join_type: JoinType::Left,
141            ..Default::default()
142        }
143    }
144
145    /// Create options for a right outer join with default settings.
146    pub fn right() -> Self {
147        Self {
148            join_type: JoinType::Right,
149            ..Default::default()
150        }
151    }
152
153    /// Create options for a full outer join with default settings.
154    pub fn full() -> Self {
155        Self {
156            join_type: JoinType::Full,
157            ..Default::default()
158        }
159    }
160
161    /// Create options for a semi join with default settings.
162    pub fn semi() -> Self {
163        Self {
164            join_type: JoinType::Semi,
165            ..Default::default()
166        }
167    }
168
169    /// Create options for an anti join with default settings.
170    pub fn anti() -> Self {
171        Self {
172            join_type: JoinType::Anti,
173            ..Default::default()
174        }
175    }
176
177    /// Set the join algorithm.
178    pub fn with_algorithm(mut self, algorithm: JoinAlgorithm) -> Self {
179        self.algorithm = algorithm;
180        self
181    }
182
183    /// Set the output batch size.
184    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
185        self.batch_size = batch_size;
186        self
187    }
188
189    /// Set the memory limit for hash joins.
190    pub fn with_memory_limit(mut self, limit_bytes: usize) -> Self {
191        self.memory_limit_bytes = Some(limit_bytes);
192        self
193    }
194
195    /// Set the concurrency level.
196    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
197        self.concurrency = concurrency;
198        self
199    }
200}
201
202/// Validate join keys before execution.
203pub fn validate_join_keys(keys: &[JoinKey]) -> LlkvResult<()> {
204    if keys.is_empty() {
205        return Err(Error::InvalidArgumentError(
206            "join requires at least one key pair".to_string(),
207        ));
208    }
209    Ok(())
210}
211
212/// Validate join options before execution.
213pub fn validate_join_options(options: &JoinOptions) -> LlkvResult<()> {
214    if options.batch_size == 0 {
215        return Err(Error::InvalidArgumentError(
216            "join batch_size must be > 0".to_string(),
217        ));
218    }
219    if options.concurrency == 0 {
220        return Err(Error::InvalidArgumentError(
221            "join concurrency must be > 0".to_string(),
222        ));
223    }
224    Ok(())
225}
226
227/// Extension trait adding join operations to `Table`.
228pub trait TableJoinExt<P>
229where
230    P: Pager<Blob = EntryHandle> + Send + Sync,
231{
232    /// Join this table with another table based on equality predicates.
233    fn join_stream<F>(
234        &self,
235        right: &Table<P>,
236        keys: &[JoinKey],
237        options: &JoinOptions,
238        on_batch: F,
239    ) -> LlkvResult<()>
240    where
241        F: FnMut(RecordBatch);
242}
243
244impl<P> TableJoinExt<P> for Table<P>
245where
246    P: Pager<Blob = EntryHandle> + Send + Sync,
247{
248    fn join_stream<F>(
249        &self,
250        right: &Table<P>,
251        keys: &[JoinKey],
252        options: &JoinOptions,
253        on_batch: F,
254    ) -> LlkvResult<()>
255    where
256        F: FnMut(RecordBatch),
257    {
258        validate_join_keys(keys)?;
259        validate_join_options(options)?;
260
261        match options.algorithm {
262            JoinAlgorithm::Hash => {
263                hash_join::hash_join_stream(self, right, keys, options, on_batch)
264            }
265            JoinAlgorithm::SortMerge => Err(Error::Internal(
266                "Sort-merge join not yet implemented; use JoinAlgorithm::Hash".to_string(),
267            )),
268        }
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Both constructors store the field ids; they differ only in the
    /// NULL-handling flag.
    #[test]
    fn test_join_key_constructors() {
        let key = JoinKey::new(10, 20);
        assert_eq!(key.left_field, 10);
        assert_eq!(key.right_field, 20);
        assert!(!key.null_equals_null);

        let key_null_safe = JoinKey::null_safe(10, 20);
        assert_eq!(key_null_safe.left_field, 10);
        assert_eq!(key_null_safe.right_field, 20);
        assert!(key_null_safe.null_equals_null);
    }

    /// Pins the documented defaults so an accidental change is caught.
    #[test]
    fn test_join_options_defaults() {
        let defaults = JoinOptions::default();
        assert_eq!(defaults.join_type, JoinType::Inner);
        assert_eq!(defaults.algorithm, JoinAlgorithm::Hash);
        assert_eq!(defaults.batch_size, 8192);
        assert_eq!(defaults.memory_limit_bytes, None);
        assert_eq!(defaults.concurrency, 1);

        assert_eq!(JoinAlgorithm::default(), JoinAlgorithm::Hash);
    }

    /// Every join-type constructor sets its own join type and leaves the
    /// remaining settings at their defaults.
    #[test]
    fn test_join_options_constructors() {
        let cases = [
            (JoinOptions::inner(), JoinType::Inner),
            (JoinOptions::left(), JoinType::Left),
            (JoinOptions::right(), JoinType::Right),
            (JoinOptions::full(), JoinType::Full),
            (JoinOptions::semi(), JoinType::Semi),
            (JoinOptions::anti(), JoinType::Anti),
        ];
        for (options, expected) in cases {
            assert_eq!(options.join_type, expected);
            assert_eq!(options.algorithm, JoinAlgorithm::Hash);
            assert_eq!(options.batch_size, 8192);
            assert_eq!(options.memory_limit_bytes, None);
            assert_eq!(options.concurrency, 1);
        }
    }

    #[test]
    fn test_join_options_builders() {
        let inner = JoinOptions::inner();
        assert_eq!(inner.join_type, JoinType::Inner);

        let left = JoinOptions::left()
            .with_algorithm(JoinAlgorithm::Hash)
            .with_batch_size(1024)
            .with_memory_limit(1_000_000)
            .with_concurrency(4);
        assert_eq!(left.join_type, JoinType::Left);
        assert_eq!(left.algorithm, JoinAlgorithm::Hash);
        assert_eq!(left.batch_size, 1024);
        assert_eq!(left.memory_limit_bytes, Some(1_000_000));
        assert_eq!(left.concurrency, 4);

        // A non-default algorithm must be stored too.
        let sort_merge = JoinOptions::inner().with_algorithm(JoinAlgorithm::SortMerge);
        assert_eq!(sort_merge.algorithm, JoinAlgorithm::SortMerge);
    }

    #[test]
    fn test_validate_join_keys() {
        let empty: Vec<JoinKey> = vec![];
        assert!(validate_join_keys(&empty).is_err());

        let keys = vec![JoinKey::new(1, 2)];
        assert!(validate_join_keys(&keys).is_ok());

        let multi = vec![JoinKey::new(1, 2), JoinKey::null_safe(3, 4)];
        assert!(validate_join_keys(&multi).is_ok());
    }

    #[test]
    fn test_validate_join_options() {
        let bad_batch = JoinOptions {
            batch_size: 0,
            ..Default::default()
        };
        assert!(validate_join_options(&bad_batch).is_err());

        let bad_concurrency = JoinOptions {
            concurrency: 0,
            ..Default::default()
        };
        assert!(validate_join_options(&bad_concurrency).is_err());

        // Invalid values that arrive through the builders are rejected as well.
        let bad_both = JoinOptions::inner().with_batch_size(0).with_concurrency(0);
        assert!(validate_join_options(&bad_both).is_err());

        let good = JoinOptions::default();
        assert!(validate_join_options(&good).is_ok());
    }

    #[test]
    fn test_join_type_display() {
        assert_eq!(JoinType::Inner.to_string(), "INNER");
        assert_eq!(JoinType::Left.to_string(), "LEFT");
        assert_eq!(JoinType::Right.to_string(), "RIGHT");
        assert_eq!(JoinType::Full.to_string(), "FULL");
        assert_eq!(JoinType::Semi.to_string(), "SEMI");
        assert_eq!(JoinType::Anti.to_string(), "ANTI");
    }

    #[test]
    fn test_join_algorithm_display() {
        assert_eq!(JoinAlgorithm::Hash.to_string(), "Hash");
        assert_eq!(JoinAlgorithm::SortMerge.to_string(), "SortMerge");
    }
}