// llkv_join/lib.rs

1#![forbid(unsafe_code)]
2
3mod hash_join;
4
5use arrow::array::RecordBatch;
6use llkv_result::{Error, Result as LlkvResult};
7use llkv_storage::pager::Pager;
8use llkv_table::table::Table;
9use llkv_table::types::FieldId;
10use simd_r_drive_entry_handle::EntryHandle;
11use std::fmt;
12
13pub use hash_join::hash_join_stream;
14
/// Type of join to perform.
///
/// The default is [`JoinType::Inner`], the same join type that
/// `JoinOptions::default()` selects.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Emit only matching row pairs.
    #[default]
    Inner,
    /// Emit all left rows; unmatched left rows have NULL right columns.
    Left,
    /// Emit all right rows; unmatched right rows have NULL left columns.
    Right,
    /// Emit all rows from both sides; unmatched rows have NULLs.
    Full,
    /// Emit left rows that have at least one match (no right columns).
    Semi,
    /// Emit left rows that have no match (no right columns).
    Anti,
}
31
32impl fmt::Display for JoinType {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            JoinType::Inner => write!(f, "INNER"),
36            JoinType::Left => write!(f, "LEFT"),
37            JoinType::Right => write!(f, "RIGHT"),
38            JoinType::Full => write!(f, "FULL"),
39            JoinType::Semi => write!(f, "SEMI"),
40            JoinType::Anti => write!(f, "ANTI"),
41        }
42    }
43}
44
45/// Join key pair describing which columns to equate.
46#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct JoinKey {
48    /// Field ID from the left table.
49    pub left_field: FieldId,
50    /// Field ID from the right table.
51    pub right_field: FieldId,
52    /// If true, NULL == NULL for this key (SQL-style NULL-safe equality).
53    /// If false, NULL != NULL (Arrow default).
54    pub null_equals_null: bool,
55}
56
57impl JoinKey {
58    /// Create a join key with standard Arrow null semantics (NULL != NULL).
59    pub fn new(left_field: FieldId, right_field: FieldId) -> Self {
60        Self {
61            left_field,
62            right_field,
63            null_equals_null: false,
64        }
65    }
66
67    /// Create a join key with SQL-style NULL-safe equality (NULL == NULL).
68    pub fn null_safe(left_field: FieldId, right_field: FieldId) -> Self {
69        Self {
70            left_field,
71            right_field,
72            null_equals_null: true,
73        }
74    }
75}
76
/// Algorithm to use for join execution.
///
/// [`JoinAlgorithm::Hash`] is the default.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum JoinAlgorithm {
    /// Hash join: build hash table on one side, probe with other.
    /// O(N+M) complexity - suitable for production workloads.
    /// Default and recommended for all equality joins.
    #[default]
    Hash,
    /// Sort-merge join: sort both sides, then merge.
    /// Good for pre-sorted inputs or when memory is constrained.
    ///
    /// Not yet implemented: `TableJoinExt::join_stream` returns an error when
    /// this algorithm is selected.
    SortMerge,
}
90
91impl fmt::Display for JoinAlgorithm {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        match self {
94            JoinAlgorithm::Hash => write!(f, "Hash"),
95            JoinAlgorithm::SortMerge => write!(f, "SortMerge"),
96        }
97    }
98}
99
100/// Options controlling join execution.
101#[derive(Clone, Debug)]
102pub struct JoinOptions {
103    /// Type of join to perform.
104    pub join_type: JoinType,
105    /// Algorithm to use. Planner may override this based on table sizes.
106    pub algorithm: JoinAlgorithm,
107    /// Target number of probe rows per output `RecordBatch`.
108    /// Larger batches reduce per-batch overhead (fewer Arrow gathers) at the
109    /// cost of increased peak memory; smaller batches improve latency.
110    pub batch_size: usize,
111    /// Memory limit in bytes for hash table (hash join only).
112    /// When exceeded, algorithm will partition and spill to disk.
113    pub memory_limit_bytes: Option<usize>,
114    /// Concurrency hint: number of threads for parallel partitions.
115    pub concurrency: usize,
116}
117
118impl Default for JoinOptions {
119    fn default() -> Self {
120        Self {
121            join_type: JoinType::Inner,
122            algorithm: JoinAlgorithm::Hash,
123            batch_size: 8192,
124            memory_limit_bytes: None,
125            concurrency: 1,
126        }
127    }
128}
129
130impl JoinOptions {
131    /// Create options for an inner join with default settings.
132    pub fn inner() -> Self {
133        Self {
134            join_type: JoinType::Inner,
135            ..Default::default()
136        }
137    }
138
139    /// Create options for a left outer join with default settings.
140    pub fn left() -> Self {
141        Self {
142            join_type: JoinType::Left,
143            ..Default::default()
144        }
145    }
146
147    /// Create options for a right outer join with default settings.
148    pub fn right() -> Self {
149        Self {
150            join_type: JoinType::Right,
151            ..Default::default()
152        }
153    }
154
155    /// Create options for a full outer join with default settings.
156    pub fn full() -> Self {
157        Self {
158            join_type: JoinType::Full,
159            ..Default::default()
160        }
161    }
162
163    /// Create options for a semi join with default settings.
164    pub fn semi() -> Self {
165        Self {
166            join_type: JoinType::Semi,
167            ..Default::default()
168        }
169    }
170
171    /// Create options for an anti join with default settings.
172    pub fn anti() -> Self {
173        Self {
174            join_type: JoinType::Anti,
175            ..Default::default()
176        }
177    }
178
179    /// Set the join algorithm.
180    pub fn with_algorithm(mut self, algorithm: JoinAlgorithm) -> Self {
181        self.algorithm = algorithm;
182        self
183    }
184
185    /// Set the output batch size.
186    pub fn with_batch_size(mut self, batch_size: usize) -> Self {
187        self.batch_size = batch_size;
188        self
189    }
190
191    /// Set the memory limit for hash joins.
192    pub fn with_memory_limit(mut self, limit_bytes: usize) -> Self {
193        self.memory_limit_bytes = Some(limit_bytes);
194        self
195    }
196
197    /// Set the concurrency level.
198    pub fn with_concurrency(mut self, concurrency: usize) -> Self {
199        self.concurrency = concurrency;
200        self
201    }
202}
203
204// TODO: Build out more fully or remove
205/// Validate join keys before execution.
206/// Note: Empty keys = cross product (Cartesian product)
207pub fn validate_join_keys(_keys: &[JoinKey]) -> LlkvResult<()> {
208    // Empty keys is valid for cross product
209    Ok(())
210}
211
212/// Validate join options before execution.
213pub fn validate_join_options(options: &JoinOptions) -> LlkvResult<()> {
214    if options.batch_size == 0 {
215        return Err(Error::InvalidArgumentError(
216            "join batch_size must be > 0".to_string(),
217        ));
218    }
219    if options.concurrency == 0 {
220        return Err(Error::InvalidArgumentError(
221            "join concurrency must be > 0".to_string(),
222        ));
223    }
224    Ok(())
225}
226
227/// Extension trait adding join operations to `Table`.
228pub trait TableJoinExt<P>
229where
230    P: Pager<Blob = EntryHandle> + Send + Sync,
231{
232    /// Join this table with another table based on equality predicates.
233    fn join_stream<F>(
234        &self,
235        right: &Table<P>,
236        keys: &[JoinKey],
237        options: &JoinOptions,
238        on_batch: F,
239    ) -> LlkvResult<()>
240    where
241        F: FnMut(RecordBatch);
242}
243
244impl<P> TableJoinExt<P> for Table<P>
245where
246    P: Pager<Blob = EntryHandle> + Send + Sync,
247{
248    fn join_stream<F>(
249        &self,
250        right: &Table<P>,
251        keys: &[JoinKey],
252        options: &JoinOptions,
253        on_batch: F,
254    ) -> LlkvResult<()>
255    where
256        F: FnMut(RecordBatch),
257    {
258        validate_join_keys(keys)?;
259        validate_join_options(options)?;
260
261        match options.algorithm {
262            JoinAlgorithm::Hash => {
263                hash_join::hash_join_stream(self, right, keys, options, on_batch)
264            }
265            JoinAlgorithm::SortMerge => Err(Error::Internal(
266                "Sort-merge join not yet implemented; use JoinAlgorithm::Hash".to_string(),
267            )),
268        }
269    }
270}
271
/// Unit tests for the option/key types and the validation helpers.
/// Join execution itself is not exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_join_key_constructors() {
        let key = JoinKey::new(10, 20);
        assert_eq!(key.left_field, 10);
        assert_eq!(key.right_field, 20);
        assert!(!key.null_equals_null);

        let key_null_safe = JoinKey::null_safe(10, 20);
        assert_eq!(key_null_safe.left_field, 10);
        assert_eq!(key_null_safe.right_field, 20);
        assert!(key_null_safe.null_equals_null);
    }

    #[test]
    fn test_join_options_builders() {
        let inner = JoinOptions::inner();
        assert_eq!(inner.join_type, JoinType::Inner);

        let left = JoinOptions::left()
            .with_algorithm(JoinAlgorithm::Hash)
            .with_batch_size(1024)
            .with_memory_limit(1_000_000)
            .with_concurrency(4);
        assert_eq!(left.join_type, JoinType::Left);
        assert_eq!(left.algorithm, JoinAlgorithm::Hash);
        assert_eq!(left.batch_size, 1024);
        assert_eq!(left.memory_limit_bytes, Some(1_000_000));
        assert_eq!(left.concurrency, 4);
    }

    #[test]
    fn test_join_options_constructors_keep_defaults() {
        // Every per-join-type constructor must change only `join_type`.
        let cases = [
            (JoinOptions::inner(), JoinType::Inner),
            (JoinOptions::left(), JoinType::Left),
            (JoinOptions::right(), JoinType::Right),
            (JoinOptions::full(), JoinType::Full),
            (JoinOptions::semi(), JoinType::Semi),
            (JoinOptions::anti(), JoinType::Anti),
        ];
        for (options, expected) in cases {
            assert_eq!(options.join_type, expected);
            assert_eq!(options.algorithm, JoinAlgorithm::Hash);
            assert_eq!(options.batch_size, 8192);
            assert_eq!(options.memory_limit_bytes, None);
            assert_eq!(options.concurrency, 1);
        }
    }

    #[test]
    fn test_validate_join_keys() {
        // Empty keys are valid (cross product)
        let empty: Vec<JoinKey> = vec![];
        assert!(validate_join_keys(&empty).is_ok());

        let keys = vec![JoinKey::new(1, 2)];
        assert!(validate_join_keys(&keys).is_ok());
    }

    #[test]
    fn test_validate_join_options() {
        let bad_batch = JoinOptions {
            batch_size: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_join_options(&bad_batch),
            Err(Error::InvalidArgumentError(_))
        ));

        let bad_concurrency = JoinOptions {
            concurrency: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_join_options(&bad_concurrency),
            Err(Error::InvalidArgumentError(_))
        ));

        let good = JoinOptions::default();
        assert!(validate_join_options(&good).is_ok());
    }

    #[test]
    fn test_join_type_display() {
        assert_eq!(JoinType::Inner.to_string(), "INNER");
        assert_eq!(JoinType::Left.to_string(), "LEFT");
        assert_eq!(JoinType::Right.to_string(), "RIGHT");
        assert_eq!(JoinType::Full.to_string(), "FULL");
        assert_eq!(JoinType::Semi.to_string(), "SEMI");
        assert_eq!(JoinType::Anti.to_string(), "ANTI");
    }

    #[test]
    fn test_join_algorithm_display() {
        assert_eq!(JoinAlgorithm::Hash.to_string(), "Hash");
        assert_eq!(JoinAlgorithm::SortMerge.to_string(), "SortMerge");
    }
}