1#![forbid(unsafe_code)]
8
9mod cartesian;
10mod hash_join;
11
12use arrow::array::RecordBatch;
13use llkv_result::{Error, Result as LlkvResult};
14use llkv_storage::pager::Pager;
15use llkv_table::table::Table;
16use llkv_table::types::FieldId;
17use simd_r_drive_entry_handle::EntryHandle;
18use std::fmt;
19
20pub use cartesian::cross_join_pair;
21pub use hash_join::hash_join_stream;
22
/// Logical join variant requested for a join.
///
/// Variant names follow SQL join terminology; the row-level behavior of each
/// variant is implemented by the join executors, not by this type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// `INNER` join.
    Inner,
    /// `LEFT` join.
    Left,
    /// `RIGHT` join.
    Right,
    /// `FULL` join.
    Full,
    /// `SEMI` join.
    Semi,
    /// `ANTI` join.
    Anti,
}

impl fmt::Display for JoinType {
    /// Writes the upper-case keyword (`INNER`, `LEFT`, `RIGHT`, `FULL`, `SEMI`, `ANTI`).
    ///
    /// Uses [`fmt::Formatter::pad`] so that width, fill and alignment flags
    /// (e.g. `{:<8}`) are honored. Plain `to_string()` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            JoinType::Inner => "INNER",
            JoinType::Left => "LEFT",
            JoinType::Right => "RIGHT",
            JoinType::Full => "FULL",
            JoinType::Semi => "SEMI",
            JoinType::Anti => "ANTI",
        };
        f.pad(keyword)
    }
}
52
53#[derive(Clone, Debug, PartialEq, Eq)]
55pub struct JoinKey {
56 pub left_field: FieldId,
58 pub right_field: FieldId,
60 pub null_equals_null: bool,
63}
64
65impl JoinKey {
66 pub fn new(left_field: FieldId, right_field: FieldId) -> Self {
68 Self {
69 left_field,
70 right_field,
71 null_equals_null: false,
72 }
73 }
74
75 pub fn null_safe(left_field: FieldId, right_field: FieldId) -> Self {
77 Self {
78 left_field,
79 right_field,
80 null_equals_null: true,
81 }
82 }
83}
84
/// Physical strategy used to execute a join.
///
/// [`JoinAlgorithm::Hash`] is the [`Default`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum JoinAlgorithm {
    /// Hash join (the default).
    #[default]
    Hash,
    /// Sort-merge join. `TableJoinExt::join_stream` in this module currently
    /// returns an error when this variant is requested.
    SortMerge,
}

impl fmt::Display for JoinAlgorithm {
    /// Writes the variant name (`Hash` or `SortMerge`).
    ///
    /// Uses [`fmt::Formatter::pad`] so that width, fill and alignment flags are
    /// honored. Plain `to_string()` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            JoinAlgorithm::Hash => "Hash",
            JoinAlgorithm::SortMerge => "SortMerge",
        };
        f.pad(name)
    }
}
107
108#[derive(Clone, Debug)]
110pub struct JoinOptions {
111 pub join_type: JoinType,
113 pub algorithm: JoinAlgorithm,
115 pub batch_size: usize,
119 pub memory_limit_bytes: Option<usize>,
122 pub concurrency: usize,
124}
125
126impl Default for JoinOptions {
127 fn default() -> Self {
128 Self {
129 join_type: JoinType::Inner,
130 algorithm: JoinAlgorithm::Hash,
131 batch_size: 8192,
132 memory_limit_bytes: None,
133 concurrency: 1,
134 }
135 }
136}
137
138impl JoinOptions {
139 pub fn inner() -> Self {
141 Self {
142 join_type: JoinType::Inner,
143 ..Default::default()
144 }
145 }
146
147 pub fn left() -> Self {
149 Self {
150 join_type: JoinType::Left,
151 ..Default::default()
152 }
153 }
154
155 pub fn right() -> Self {
157 Self {
158 join_type: JoinType::Right,
159 ..Default::default()
160 }
161 }
162
163 pub fn full() -> Self {
165 Self {
166 join_type: JoinType::Full,
167 ..Default::default()
168 }
169 }
170
171 pub fn semi() -> Self {
173 Self {
174 join_type: JoinType::Semi,
175 ..Default::default()
176 }
177 }
178
179 pub fn anti() -> Self {
181 Self {
182 join_type: JoinType::Anti,
183 ..Default::default()
184 }
185 }
186
187 pub fn with_algorithm(mut self, algorithm: JoinAlgorithm) -> Self {
189 self.algorithm = algorithm;
190 self
191 }
192
193 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
195 self.batch_size = batch_size;
196 self
197 }
198
199 pub fn with_memory_limit(mut self, limit_bytes: usize) -> Self {
201 self.memory_limit_bytes = Some(limit_bytes);
202 self
203 }
204
205 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
207 self.concurrency = concurrency;
208 self
209 }
210}
211
212pub fn validate_join_keys(_keys: &[JoinKey]) -> LlkvResult<()> {
220 Ok(())
222}
223
224pub fn validate_join_options(options: &JoinOptions) -> LlkvResult<()> {
226 if options.batch_size == 0 {
227 return Err(Error::InvalidArgumentError(
228 "join batch_size must be > 0".to_string(),
229 ));
230 }
231 if options.concurrency == 0 {
232 return Err(Error::InvalidArgumentError(
233 "join concurrency must be > 0".to_string(),
234 ));
235 }
236 Ok(())
237}
238
239pub trait TableJoinExt<P>
241where
242 P: Pager<Blob = EntryHandle> + Send + Sync,
243{
244 fn join_stream<F>(
246 &self,
247 right: &Table<P>,
248 keys: &[JoinKey],
249 options: &JoinOptions,
250 on_batch: F,
251 ) -> LlkvResult<()>
252 where
253 F: FnMut(RecordBatch);
254}
255
256impl<P> TableJoinExt<P> for Table<P>
257where
258 P: Pager<Blob = EntryHandle> + Send + Sync,
259{
260 fn join_stream<F>(
261 &self,
262 right: &Table<P>,
263 keys: &[JoinKey],
264 options: &JoinOptions,
265 on_batch: F,
266 ) -> LlkvResult<()>
267 where
268 F: FnMut(RecordBatch),
269 {
270 validate_join_keys(keys)?;
271 validate_join_options(options)?;
272
273 match options.algorithm {
274 JoinAlgorithm::Hash => {
275 hash_join::hash_join_stream(self, right, keys, options, on_batch)
276 }
277 JoinAlgorithm::SortMerge => Err(Error::Internal(
278 "Sort-merge join not yet implemented; use JoinAlgorithm::Hash".to_string(),
279 )),
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    //! Unit tests for the join configuration types and validators.
    use super::*;

    #[test]
    fn test_join_key_constructors() {
        let key = JoinKey::new(10, 20);
        assert_eq!(key.left_field, 10);
        assert_eq!(key.right_field, 20);
        assert!(!key.null_equals_null);

        let key_null_safe = JoinKey::null_safe(10, 20);
        assert_eq!(key_null_safe.left_field, 10);
        assert_eq!(key_null_safe.right_field, 20);
        assert!(key_null_safe.null_equals_null);
    }

    #[test]
    fn test_join_options_defaults() {
        let opts = JoinOptions::default();
        assert_eq!(opts.join_type, JoinType::Inner);
        assert_eq!(opts.algorithm, JoinAlgorithm::Hash);
        assert_eq!(opts.batch_size, 8192);
        assert_eq!(opts.memory_limit_bytes, None);
        assert_eq!(opts.concurrency, 1);
        assert_eq!(JoinAlgorithm::default(), JoinAlgorithm::Hash);
    }

    #[test]
    fn test_join_options_shortcuts() {
        let cases = [
            (JoinOptions::inner(), JoinType::Inner),
            (JoinOptions::left(), JoinType::Left),
            (JoinOptions::right(), JoinType::Right),
            (JoinOptions::full(), JoinType::Full),
            (JoinOptions::semi(), JoinType::Semi),
            (JoinOptions::anti(), JoinType::Anti),
        ];
        for (opts, expected) in cases {
            assert_eq!(opts.join_type, expected);
            // The shortcuts change only the join type; everything else stays default.
            assert_eq!(opts.algorithm, JoinAlgorithm::Hash);
            assert_eq!(opts.batch_size, 8192);
            assert_eq!(opts.memory_limit_bytes, None);
            assert_eq!(opts.concurrency, 1);
        }
    }

    #[test]
    fn test_join_options_builders() {
        let inner = JoinOptions::inner();
        assert_eq!(inner.join_type, JoinType::Inner);

        let left = JoinOptions::left()
            .with_algorithm(JoinAlgorithm::Hash)
            .with_batch_size(1024)
            .with_memory_limit(1_000_000)
            .with_concurrency(4);
        assert_eq!(left.join_type, JoinType::Left);
        assert_eq!(left.algorithm, JoinAlgorithm::Hash);
        assert_eq!(left.batch_size, 1024);
        assert_eq!(left.memory_limit_bytes, Some(1_000_000));
        assert_eq!(left.concurrency, 4);

        let sort_merge = JoinOptions::default().with_algorithm(JoinAlgorithm::SortMerge);
        assert_eq!(sort_merge.algorithm, JoinAlgorithm::SortMerge);
    }

    #[test]
    fn test_validate_join_keys() {
        // Empty key lists are accepted.
        let empty: Vec<JoinKey> = vec![];
        assert!(validate_join_keys(&empty).is_ok());

        let keys = vec![JoinKey::new(1, 2)];
        assert!(validate_join_keys(&keys).is_ok());
    }

    #[test]
    fn test_validate_join_options() {
        let bad_batch = JoinOptions {
            batch_size: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_join_options(&bad_batch),
            Err(Error::InvalidArgumentError(_))
        ));

        let bad_concurrency = JoinOptions {
            concurrency: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_join_options(&bad_concurrency),
            Err(Error::InvalidArgumentError(_))
        ));

        let good = JoinOptions::default();
        assert!(validate_join_options(&good).is_ok());
    }

    #[test]
    fn test_join_type_display() {
        let expected = [
            (JoinType::Inner, "INNER"),
            (JoinType::Left, "LEFT"),
            (JoinType::Right, "RIGHT"),
            (JoinType::Full, "FULL"),
            (JoinType::Semi, "SEMI"),
            (JoinType::Anti, "ANTI"),
        ];
        for (join_type, text) in expected {
            assert_eq!(join_type.to_string(), text);
        }
    }

    #[test]
    fn test_join_algorithm_display() {
        assert_eq!(JoinAlgorithm::Hash.to_string(), "Hash");
        assert_eq!(JoinAlgorithm::SortMerge.to_string(), "SortMerge");
    }
}