1#![forbid(unsafe_code)]
2
3mod hash_join;
4
5use arrow::array::RecordBatch;
6use llkv_result::{Error, Result as LlkvResult};
7use llkv_storage::pager::Pager;
8use llkv_table::table::Table;
9use llkv_table::types::FieldId;
10use simd_r_drive_entry_handle::EntryHandle;
11use std::fmt;
12
13pub use hash_join::hash_join_stream;
14
/// Logical join type requested by the caller.
///
/// The `Display` form is the upper-case keyword (`INNER`, `LEFT`, ...).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// `INNER` join.
    Inner,
    /// `LEFT` join.
    Left,
    /// `RIGHT` join.
    Right,
    /// `FULL` join.
    Full,
    /// `SEMI` join (presumably emits left rows that have a match — confirm in `hash_join`).
    Semi,
    /// `ANTI` join (presumably emits left rows with no match — confirm in `hash_join`).
    Anti,
}

impl fmt::Display for JoinType {
    /// Writes the upper-case keyword for this join type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::Inner => "INNER",
            Self::Left => "LEFT",
            Self::Right => "RIGHT",
            Self::Full => "FULL",
            Self::Semi => "SEMI",
            Self::Anti => "ANTI",
        };
        f.write_str(keyword)
    }
}
44
45#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct JoinKey {
48 pub left_field: FieldId,
50 pub right_field: FieldId,
52 pub null_equals_null: bool,
55}
56
57impl JoinKey {
58 pub fn new(left_field: FieldId, right_field: FieldId) -> Self {
60 Self {
61 left_field,
62 right_field,
63 null_equals_null: false,
64 }
65 }
66
67 pub fn null_safe(left_field: FieldId, right_field: FieldId) -> Self {
69 Self {
70 left_field,
71 right_field,
72 null_equals_null: true,
73 }
74 }
75}
76
/// Physical strategy used to execute a join.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum JoinAlgorithm {
    /// Hash join. This is the `Default` variant.
    #[default]
    Hash,
    /// Sort-merge join. `TableJoinExt::join_stream` currently rejects this
    /// variant with an error.
    SortMerge,
}

impl fmt::Display for JoinAlgorithm {
    /// Writes the variant name (`Hash` or `SortMerge`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            Self::Hash => "Hash",
            Self::SortMerge => "SortMerge",
        })
    }
}
99
100#[derive(Clone, Debug)]
102pub struct JoinOptions {
103 pub join_type: JoinType,
105 pub algorithm: JoinAlgorithm,
107 pub batch_size: usize,
111 pub memory_limit_bytes: Option<usize>,
114 pub concurrency: usize,
116}
117
118impl Default for JoinOptions {
119 fn default() -> Self {
120 Self {
121 join_type: JoinType::Inner,
122 algorithm: JoinAlgorithm::Hash,
123 batch_size: 8192,
124 memory_limit_bytes: None,
125 concurrency: 1,
126 }
127 }
128}
129
130impl JoinOptions {
131 pub fn inner() -> Self {
133 Self {
134 join_type: JoinType::Inner,
135 ..Default::default()
136 }
137 }
138
139 pub fn left() -> Self {
141 Self {
142 join_type: JoinType::Left,
143 ..Default::default()
144 }
145 }
146
147 pub fn right() -> Self {
149 Self {
150 join_type: JoinType::Right,
151 ..Default::default()
152 }
153 }
154
155 pub fn full() -> Self {
157 Self {
158 join_type: JoinType::Full,
159 ..Default::default()
160 }
161 }
162
163 pub fn semi() -> Self {
165 Self {
166 join_type: JoinType::Semi,
167 ..Default::default()
168 }
169 }
170
171 pub fn anti() -> Self {
173 Self {
174 join_type: JoinType::Anti,
175 ..Default::default()
176 }
177 }
178
179 pub fn with_algorithm(mut self, algorithm: JoinAlgorithm) -> Self {
181 self.algorithm = algorithm;
182 self
183 }
184
185 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
187 self.batch_size = batch_size;
188 self
189 }
190
191 pub fn with_memory_limit(mut self, limit_bytes: usize) -> Self {
193 self.memory_limit_bytes = Some(limit_bytes);
194 self
195 }
196
197 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
199 self.concurrency = concurrency;
200 self
201 }
202}
203
204pub fn validate_join_keys(_keys: &[JoinKey]) -> LlkvResult<()> {
208 Ok(())
210}
211
212pub fn validate_join_options(options: &JoinOptions) -> LlkvResult<()> {
214 if options.batch_size == 0 {
215 return Err(Error::InvalidArgumentError(
216 "join batch_size must be > 0".to_string(),
217 ));
218 }
219 if options.concurrency == 0 {
220 return Err(Error::InvalidArgumentError(
221 "join concurrency must be > 0".to_string(),
222 ));
223 }
224 Ok(())
225}
226
227pub trait TableJoinExt<P>
229where
230 P: Pager<Blob = EntryHandle> + Send + Sync,
231{
232 fn join_stream<F>(
234 &self,
235 right: &Table<P>,
236 keys: &[JoinKey],
237 options: &JoinOptions,
238 on_batch: F,
239 ) -> LlkvResult<()>
240 where
241 F: FnMut(RecordBatch);
242}
243
244impl<P> TableJoinExt<P> for Table<P>
245where
246 P: Pager<Blob = EntryHandle> + Send + Sync,
247{
248 fn join_stream<F>(
249 &self,
250 right: &Table<P>,
251 keys: &[JoinKey],
252 options: &JoinOptions,
253 on_batch: F,
254 ) -> LlkvResult<()>
255 where
256 F: FnMut(RecordBatch),
257 {
258 validate_join_keys(keys)?;
259 validate_join_options(options)?;
260
261 match options.algorithm {
262 JoinAlgorithm::Hash => {
263 hash_join::hash_join_stream(self, right, keys, options, on_batch)
264 }
265 JoinAlgorithm::SortMerge => Err(Error::Internal(
266 "Sort-merge join not yet implemented; use JoinAlgorithm::Hash".to_string(),
267 )),
268 }
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_join_key_constructors() {
        let key = JoinKey::new(10, 20);
        assert_eq!(key.left_field, 10);
        assert_eq!(key.right_field, 20);
        assert!(!key.null_equals_null);

        let key_null_safe = JoinKey::null_safe(10, 20);
        assert_eq!(key_null_safe.left_field, 10);
        assert_eq!(key_null_safe.right_field, 20);
        assert!(key_null_safe.null_equals_null);

        // Null semantics take part in key equality.
        assert_ne!(key, key_null_safe);
        assert_eq!(key, key.clone());
    }

    #[test]
    fn test_join_options_builders() {
        let inner = JoinOptions::inner();
        assert_eq!(inner.join_type, JoinType::Inner);

        let left = JoinOptions::left()
            .with_algorithm(JoinAlgorithm::Hash)
            .with_batch_size(1024)
            .with_memory_limit(1_000_000)
            .with_concurrency(4);
        assert_eq!(left.join_type, JoinType::Left);
        assert_eq!(left.algorithm, JoinAlgorithm::Hash);
        assert_eq!(left.batch_size, 1024);
        assert_eq!(left.memory_limit_bytes, Some(1_000_000));
        assert_eq!(left.concurrency, 4);
    }

    #[test]
    fn test_join_options_constructors_only_change_join_type() {
        let defaults = JoinOptions::default();
        let cases = [
            (JoinOptions::inner(), JoinType::Inner),
            (JoinOptions::left(), JoinType::Left),
            (JoinOptions::right(), JoinType::Right),
            (JoinOptions::full(), JoinType::Full),
            (JoinOptions::semi(), JoinType::Semi),
            (JoinOptions::anti(), JoinType::Anti),
        ];
        for (options, expected) in cases {
            assert_eq!(options.join_type, expected);
            assert_eq!(options.algorithm, defaults.algorithm);
            assert_eq!(options.batch_size, defaults.batch_size);
            assert_eq!(options.memory_limit_bytes, defaults.memory_limit_bytes);
            assert_eq!(options.concurrency, defaults.concurrency);
            assert!(validate_join_options(&options).is_ok());
        }
    }

    #[test]
    fn test_join_options_default() {
        let options = JoinOptions::default();
        assert_eq!(options.join_type, JoinType::Inner);
        assert_eq!(options.algorithm, JoinAlgorithm::Hash);
        assert_eq!(options.algorithm, JoinAlgorithm::default());
        assert_eq!(options.batch_size, 8192);
        assert_eq!(options.memory_limit_bytes, None);
        assert_eq!(options.concurrency, 1);
    }

    #[test]
    fn test_validate_join_keys() {
        // An empty key list is accepted.
        let empty: Vec<JoinKey> = vec![];
        assert!(validate_join_keys(&empty).is_ok());

        let keys = vec![JoinKey::new(1, 2)];
        assert!(validate_join_keys(&keys).is_ok());
    }

    #[test]
    fn test_validate_join_options() {
        let bad_batch = JoinOptions {
            batch_size: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_join_options(&bad_batch),
            Err(Error::InvalidArgumentError(msg)) if msg.contains("batch_size")
        ));

        let bad_concurrency = JoinOptions {
            concurrency: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_join_options(&bad_concurrency),
            Err(Error::InvalidArgumentError(msg)) if msg.contains("concurrency")
        ));

        // batch_size is checked first when both are invalid.
        let both_bad = JoinOptions {
            batch_size: 0,
            concurrency: 0,
            ..Default::default()
        };
        assert!(matches!(
            validate_join_options(&both_bad),
            Err(Error::InvalidArgumentError(msg)) if msg.contains("batch_size")
        ));

        let good = JoinOptions::default();
        assert!(validate_join_options(&good).is_ok());
    }

    #[test]
    fn test_join_type_display() {
        let expected = [
            (JoinType::Inner, "INNER"),
            (JoinType::Left, "LEFT"),
            (JoinType::Right, "RIGHT"),
            (JoinType::Full, "FULL"),
            (JoinType::Semi, "SEMI"),
            (JoinType::Anti, "ANTI"),
        ];
        for (join_type, text) in expected {
            assert_eq!(join_type.to_string(), text);
        }
    }

    #[test]
    fn test_join_algorithm_display() {
        let expected = [
            (JoinAlgorithm::Hash, "Hash"),
            (JoinAlgorithm::SortMerge, "SortMerge"),
        ];
        for (algorithm, text) in expected {
            assert_eq!(algorithm.to_string(), text);
        }
    }
}