1#![forbid(unsafe_code)]
8
9mod hash_join;
10
11use arrow::array::RecordBatch;
12use llkv_result::{Error, Result as LlkvResult};
13use llkv_storage::pager::Pager;
14use llkv_table::table::Table;
15use llkv_table::types::FieldId;
16use simd_r_drive_entry_handle::EntryHandle;
17use std::fmt;
18
19pub use hash_join::hash_join_stream;
20
/// Kind of join requested through `JoinOptions::join_type`.
///
/// Variant names mirror the SQL join vocabulary. This file only defines and
/// displays the variants; what each one emits is decided by the join
/// implementation, which is not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Displayed as `INNER`; the join type used by `JoinOptions::default()`.
    Inner,
    /// Displayed as `LEFT`.
    Left,
    /// Displayed as `RIGHT`.
    Right,
    /// Displayed as `FULL`.
    Full,
    /// Displayed as `SEMI`.
    Semi,
    /// Displayed as `ANTI`.
    Anti,
}
37
38impl fmt::Display for JoinType {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 JoinType::Inner => write!(f, "INNER"),
42 JoinType::Left => write!(f, "LEFT"),
43 JoinType::Right => write!(f, "RIGHT"),
44 JoinType::Full => write!(f, "FULL"),
45 JoinType::Semi => write!(f, "SEMI"),
46 JoinType::Anti => write!(f, "ANTI"),
47 }
48 }
49}
50
51#[derive(Clone, Debug, PartialEq, Eq)]
53pub struct JoinKey {
54 pub left_field: FieldId,
56 pub right_field: FieldId,
58 pub null_equals_null: bool,
61}
62
63impl JoinKey {
64 pub fn new(left_field: FieldId, right_field: FieldId) -> Self {
66 Self {
67 left_field,
68 right_field,
69 null_equals_null: false,
70 }
71 }
72
73 pub fn null_safe(left_field: FieldId, right_field: FieldId) -> Self {
75 Self {
76 left_field,
77 right_field,
78 null_equals_null: true,
79 }
80 }
81}
82
/// Strategy selector consulted by `TableJoinExt::join_stream`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinAlgorithm {
    /// Hash join; this is the `Default` variant and the one `join_stream`
    /// dispatches to `hash_join::hash_join_stream`.
    #[default]
    Hash,
    /// Sort-merge join. `join_stream` in this file currently answers this
    /// variant with a "not yet implemented" error.
    SortMerge,
}
96
97impl fmt::Display for JoinAlgorithm {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 match self {
100 JoinAlgorithm::Hash => write!(f, "Hash"),
101 JoinAlgorithm::SortMerge => write!(f, "SortMerge"),
102 }
103 }
104}
105
106#[derive(Clone, Debug)]
108pub struct JoinOptions {
109 pub join_type: JoinType,
111 pub algorithm: JoinAlgorithm,
113 pub batch_size: usize,
117 pub memory_limit_bytes: Option<usize>,
120 pub concurrency: usize,
122}
123
124impl Default for JoinOptions {
125 fn default() -> Self {
126 Self {
127 join_type: JoinType::Inner,
128 algorithm: JoinAlgorithm::Hash,
129 batch_size: 8192,
130 memory_limit_bytes: None,
131 concurrency: 1,
132 }
133 }
134}
135
136impl JoinOptions {
137 pub fn inner() -> Self {
139 Self {
140 join_type: JoinType::Inner,
141 ..Default::default()
142 }
143 }
144
145 pub fn left() -> Self {
147 Self {
148 join_type: JoinType::Left,
149 ..Default::default()
150 }
151 }
152
153 pub fn right() -> Self {
155 Self {
156 join_type: JoinType::Right,
157 ..Default::default()
158 }
159 }
160
161 pub fn full() -> Self {
163 Self {
164 join_type: JoinType::Full,
165 ..Default::default()
166 }
167 }
168
169 pub fn semi() -> Self {
171 Self {
172 join_type: JoinType::Semi,
173 ..Default::default()
174 }
175 }
176
177 pub fn anti() -> Self {
179 Self {
180 join_type: JoinType::Anti,
181 ..Default::default()
182 }
183 }
184
185 pub fn with_algorithm(mut self, algorithm: JoinAlgorithm) -> Self {
187 self.algorithm = algorithm;
188 self
189 }
190
191 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
193 self.batch_size = batch_size;
194 self
195 }
196
197 pub fn with_memory_limit(mut self, limit_bytes: usize) -> Self {
199 self.memory_limit_bytes = Some(limit_bytes);
200 self
201 }
202
203 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
205 self.concurrency = concurrency;
206 self
207 }
208}
209
210pub fn validate_join_keys(_keys: &[JoinKey]) -> LlkvResult<()> {
218 Ok(())
220}
221
222pub fn validate_join_options(options: &JoinOptions) -> LlkvResult<()> {
224 if options.batch_size == 0 {
225 return Err(Error::InvalidArgumentError(
226 "join batch_size must be > 0".to_string(),
227 ));
228 }
229 if options.concurrency == 0 {
230 return Err(Error::InvalidArgumentError(
231 "join concurrency must be > 0".to_string(),
232 ));
233 }
234 Ok(())
235}
236
237pub trait TableJoinExt<P>
239where
240 P: Pager<Blob = EntryHandle> + Send + Sync,
241{
242 fn join_stream<F>(
244 &self,
245 right: &Table<P>,
246 keys: &[JoinKey],
247 options: &JoinOptions,
248 on_batch: F,
249 ) -> LlkvResult<()>
250 where
251 F: FnMut(RecordBatch);
252}
253
254impl<P> TableJoinExt<P> for Table<P>
255where
256 P: Pager<Blob = EntryHandle> + Send + Sync,
257{
258 fn join_stream<F>(
259 &self,
260 right: &Table<P>,
261 keys: &[JoinKey],
262 options: &JoinOptions,
263 on_batch: F,
264 ) -> LlkvResult<()>
265 where
266 F: FnMut(RecordBatch),
267 {
268 validate_join_keys(keys)?;
269 validate_join_options(options)?;
270
271 match options.algorithm {
272 JoinAlgorithm::Hash => {
273 hash_join::hash_join_stream(self, right, keys, options, on_batch)
274 }
275 JoinAlgorithm::SortMerge => Err(Error::Internal(
276 "Sort-merge join not yet implemented; use JoinAlgorithm::Hash".to_string(),
277 )),
278 }
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` stores both ids with the NULL flag off; `null_safe` turns it on.
    #[test]
    fn test_join_key_constructors() {
        let plain = JoinKey::new(10, 20);
        assert_eq!(plain.left_field, 10);
        assert_eq!(plain.right_field, 20);
        assert!(!plain.null_equals_null);

        let null_safe = JoinKey::null_safe(10, 20);
        assert!(null_safe.null_equals_null);
    }

    /// Constructors pick the join type; `with_*` builders overwrite one field each.
    #[test]
    fn test_join_options_builders() {
        assert_eq!(JoinOptions::inner().join_type, JoinType::Inner);

        let JoinOptions {
            join_type,
            algorithm,
            batch_size,
            memory_limit_bytes,
            concurrency,
        } = JoinOptions::left()
            .with_algorithm(JoinAlgorithm::Hash)
            .with_batch_size(1024)
            .with_memory_limit(1_000_000)
            .with_concurrency(4);

        assert_eq!(join_type, JoinType::Left);
        assert_eq!(algorithm, JoinAlgorithm::Hash);
        assert_eq!(batch_size, 1024);
        assert_eq!(memory_limit_bytes, Some(1_000_000));
        assert_eq!(concurrency, 4);
    }

    /// Both an empty and a non-empty key list pass validation.
    #[test]
    fn test_validate_join_keys() {
        let no_keys: &[JoinKey] = &[];
        assert!(validate_join_keys(no_keys).is_ok());

        let one_key = [JoinKey::new(1, 2)];
        assert!(validate_join_keys(&one_key).is_ok());
    }

    /// Zero `batch_size` or `concurrency` is rejected; the defaults are accepted.
    #[test]
    fn test_validate_join_options() {
        let rejected = [
            JoinOptions {
                batch_size: 0,
                ..Default::default()
            },
            JoinOptions {
                concurrency: 0,
                ..Default::default()
            },
        ];
        for options in &rejected {
            assert!(validate_join_options(options).is_err());
        }

        assert!(validate_join_options(&JoinOptions::default()).is_ok());
    }

    /// Every join type renders as its upper-case keyword.
    #[test]
    fn test_join_type_display() {
        let expectations = [
            (JoinType::Inner, "INNER"),
            (JoinType::Left, "LEFT"),
            (JoinType::Right, "RIGHT"),
            (JoinType::Full, "FULL"),
            (JoinType::Semi, "SEMI"),
            (JoinType::Anti, "ANTI"),
        ];
        for (join_type, keyword) in expectations {
            assert_eq!(join_type.to_string(), keyword);
        }
    }

    /// Every algorithm renders as its variant name.
    #[test]
    fn test_join_algorithm_display() {
        let expectations = [
            (JoinAlgorithm::Hash, "Hash"),
            (JoinAlgorithm::SortMerge, "SortMerge"),
        ];
        for (algorithm, label) in expectations {
            assert_eq!(algorithm.to_string(), label);
        }
    }
}