1#![forbid(unsafe_code)]
2
3mod hash_join;
4
5use arrow::array::RecordBatch;
6use llkv_result::{Error, Result as LlkvResult};
7use llkv_storage::pager::Pager;
8use llkv_table::table::Table;
9use llkv_table::types::FieldId;
10use simd_r_drive_entry_handle::EntryHandle;
11use std::fmt;
12
13pub use hash_join::hash_join_stream;
14
/// Logical join variant requested by the caller.
///
/// The variant docs describe the conventional relational meaning; the exact
/// output shape is decided by the join implementation.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Only pairs of rows whose keys match.
    Inner,
    /// All left rows, with right columns filled where a match exists.
    Left,
    /// All right rows, with left columns filled where a match exists.
    Right,
    /// All rows from both sides (full outer join).
    Full,
    /// Rows with at least one match on the other side (semi join).
    Semi,
    /// Rows with no match on the other side (anti join).
    Anti,
}

impl fmt::Display for JoinType {
    /// Writes the upper-case SQL-style keyword (e.g. `INNER`).
    ///
    /// Uses [`fmt::Formatter::pad`], so width, fill, alignment and precision
    /// flags such as `{:>8}` are honoured. A plain `write!` would ignore them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            JoinType::Inner => "INNER",
            JoinType::Left => "LEFT",
            JoinType::Right => "RIGHT",
            JoinType::Full => "FULL",
            JoinType::Semi => "SEMI",
            JoinType::Anti => "ANTI",
        };
        f.pad(keyword)
    }
}
44
45#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct JoinKey {
48 pub left_field: FieldId,
50 pub right_field: FieldId,
52 pub null_equals_null: bool,
55}
56
57impl JoinKey {
58 pub fn new(left_field: FieldId, right_field: FieldId) -> Self {
60 Self {
61 left_field,
62 right_field,
63 null_equals_null: false,
64 }
65 }
66
67 pub fn null_safe(left_field: FieldId, right_field: FieldId) -> Self {
69 Self {
70 left_field,
71 right_field,
72 null_equals_null: true,
73 }
74 }
75}
76
/// Physical algorithm used to execute a join.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum JoinAlgorithm {
    /// Hash join. This is the default algorithm.
    #[default]
    Hash,
    /// Sort-merge join. Not implemented yet:
    /// [`TableJoinExt::join_stream`] returns an error when it is selected.
    SortMerge,
}

impl fmt::Display for JoinAlgorithm {
    /// Writes the algorithm name (`Hash` or `SortMerge`).
    ///
    /// Uses [`fmt::Formatter::pad`], so width, fill, alignment and precision
    /// flags are honoured instead of being silently dropped.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            JoinAlgorithm::Hash => "Hash",
            JoinAlgorithm::SortMerge => "SortMerge",
        };
        f.pad(name)
    }
}
99
100#[derive(Clone, Debug)]
102pub struct JoinOptions {
103 pub join_type: JoinType,
105 pub algorithm: JoinAlgorithm,
107 pub batch_size: usize,
109 pub memory_limit_bytes: Option<usize>,
112 pub concurrency: usize,
114}
115
116impl Default for JoinOptions {
117 fn default() -> Self {
118 Self {
119 join_type: JoinType::Inner,
120 algorithm: JoinAlgorithm::Hash,
121 batch_size: 8192,
122 memory_limit_bytes: None,
123 concurrency: 1,
124 }
125 }
126}
127
128impl JoinOptions {
129 pub fn inner() -> Self {
131 Self {
132 join_type: JoinType::Inner,
133 ..Default::default()
134 }
135 }
136
137 pub fn left() -> Self {
139 Self {
140 join_type: JoinType::Left,
141 ..Default::default()
142 }
143 }
144
145 pub fn right() -> Self {
147 Self {
148 join_type: JoinType::Right,
149 ..Default::default()
150 }
151 }
152
153 pub fn full() -> Self {
155 Self {
156 join_type: JoinType::Full,
157 ..Default::default()
158 }
159 }
160
161 pub fn semi() -> Self {
163 Self {
164 join_type: JoinType::Semi,
165 ..Default::default()
166 }
167 }
168
169 pub fn anti() -> Self {
171 Self {
172 join_type: JoinType::Anti,
173 ..Default::default()
174 }
175 }
176
177 pub fn with_algorithm(mut self, algorithm: JoinAlgorithm) -> Self {
179 self.algorithm = algorithm;
180 self
181 }
182
183 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
185 self.batch_size = batch_size;
186 self
187 }
188
189 pub fn with_memory_limit(mut self, limit_bytes: usize) -> Self {
191 self.memory_limit_bytes = Some(limit_bytes);
192 self
193 }
194
195 pub fn with_concurrency(mut self, concurrency: usize) -> Self {
197 self.concurrency = concurrency;
198 self
199 }
200}
201
202pub fn validate_join_keys(keys: &[JoinKey]) -> LlkvResult<()> {
204 if keys.is_empty() {
205 return Err(Error::InvalidArgumentError(
206 "join requires at least one key pair".to_string(),
207 ));
208 }
209 Ok(())
210}
211
212pub fn validate_join_options(options: &JoinOptions) -> LlkvResult<()> {
214 if options.batch_size == 0 {
215 return Err(Error::InvalidArgumentError(
216 "join batch_size must be > 0".to_string(),
217 ));
218 }
219 if options.concurrency == 0 {
220 return Err(Error::InvalidArgumentError(
221 "join concurrency must be > 0".to_string(),
222 ));
223 }
224 Ok(())
225}
226
227pub trait TableJoinExt<P>
229where
230 P: Pager<Blob = EntryHandle> + Send + Sync,
231{
232 fn join_stream<F>(
234 &self,
235 right: &Table<P>,
236 keys: &[JoinKey],
237 options: &JoinOptions,
238 on_batch: F,
239 ) -> LlkvResult<()>
240 where
241 F: FnMut(RecordBatch);
242}
243
244impl<P> TableJoinExt<P> for Table<P>
245where
246 P: Pager<Blob = EntryHandle> + Send + Sync,
247{
248 fn join_stream<F>(
249 &self,
250 right: &Table<P>,
251 keys: &[JoinKey],
252 options: &JoinOptions,
253 on_batch: F,
254 ) -> LlkvResult<()>
255 where
256 F: FnMut(RecordBatch),
257 {
258 validate_join_keys(keys)?;
259 validate_join_options(options)?;
260
261 match options.algorithm {
262 JoinAlgorithm::Hash => {
263 hash_join::hash_join_stream(self, right, keys, options, on_batch)
264 }
265 JoinAlgorithm::SortMerge => Err(Error::Internal(
266 "Sort-merge join not yet implemented; use JoinAlgorithm::Hash".to_string(),
267 )),
268 }
269 }
270}
271
#[cfg(test)]
mod tests {
    //! Unit tests for the join configuration types and validators.
    //! Execution of joins over real tables is not covered here.

    use super::*;

    #[test]
    fn test_join_key_constructors() {
        let key = JoinKey::new(10, 20);
        assert_eq!(key.left_field, 10);
        assert_eq!(key.right_field, 20);
        assert!(!key.null_equals_null);

        let key_null_safe = JoinKey::null_safe(10, 20);
        assert!(key_null_safe.null_equals_null);
    }

    #[test]
    fn test_join_key_equality_includes_null_semantics() {
        assert_eq!(JoinKey::new(1, 2), JoinKey::new(1, 2));
        assert_ne!(JoinKey::new(1, 2), JoinKey::null_safe(1, 2));
        assert_ne!(JoinKey::new(1, 2), JoinKey::new(2, 1));
    }

    #[test]
    fn test_join_options_builders() {
        let inner = JoinOptions::inner();
        assert_eq!(inner.join_type, JoinType::Inner);

        let left = JoinOptions::left()
            .with_algorithm(JoinAlgorithm::Hash)
            .with_batch_size(1024)
            .with_memory_limit(1_000_000)
            .with_concurrency(4);
        assert_eq!(left.join_type, JoinType::Left);
        assert_eq!(left.algorithm, JoinAlgorithm::Hash);
        assert_eq!(left.batch_size, 1024);
        assert_eq!(left.memory_limit_bytes, Some(1_000_000));
        assert_eq!(left.concurrency, 4);
    }

    #[test]
    fn test_join_options_default_values() {
        let options = JoinOptions::default();
        assert_eq!(options.join_type, JoinType::Inner);
        assert_eq!(options.algorithm, JoinAlgorithm::Hash);
        assert_eq!(options.batch_size, 8192);
        assert_eq!(options.memory_limit_bytes, None);
        assert_eq!(options.concurrency, 1);
        assert_eq!(JoinAlgorithm::default(), JoinAlgorithm::Hash);
    }

    #[test]
    fn test_join_type_constructors_only_set_join_type() {
        let cases = [
            (JoinOptions::inner(), JoinType::Inner),
            (JoinOptions::left(), JoinType::Left),
            (JoinOptions::right(), JoinType::Right),
            (JoinOptions::full(), JoinType::Full),
            (JoinOptions::semi(), JoinType::Semi),
            (JoinOptions::anti(), JoinType::Anti),
        ];
        let defaults = JoinOptions::default();
        for (options, expected) in cases {
            assert_eq!(options.join_type, expected);
            assert_eq!(options.algorithm, defaults.algorithm);
            assert_eq!(options.batch_size, defaults.batch_size);
            assert_eq!(options.memory_limit_bytes, defaults.memory_limit_bytes);
            assert_eq!(options.concurrency, defaults.concurrency);
        }
    }

    #[test]
    fn test_validate_join_keys() {
        let empty: Vec<JoinKey> = vec![];
        assert!(validate_join_keys(&empty).is_err());

        let keys = vec![JoinKey::new(1, 2)];
        assert!(validate_join_keys(&keys).is_ok());

        let composite = [JoinKey::new(1, 2), JoinKey::null_safe(3, 4)];
        assert!(validate_join_keys(&composite).is_ok());
    }

    #[test]
    fn test_validate_join_options() {
        let bad_batch = JoinOptions {
            batch_size: 0,
            ..Default::default()
        };
        assert!(validate_join_options(&bad_batch).is_err());

        let bad_concurrency = JoinOptions {
            concurrency: 0,
            ..Default::default()
        };
        assert!(validate_join_options(&bad_concurrency).is_err());

        let good = JoinOptions::default();
        assert!(validate_join_options(&good).is_ok());

        // Invalid values set through the builders are caught too.
        assert!(validate_join_options(&JoinOptions::inner().with_batch_size(0)).is_err());
        assert!(validate_join_options(&JoinOptions::inner().with_concurrency(0)).is_err());
    }

    #[test]
    fn test_validate_join_options_reports_batch_size_first() {
        let both_bad = JoinOptions {
            batch_size: 0,
            concurrency: 0,
            ..Default::default()
        };
        let err = validate_join_options(&both_bad).unwrap_err();
        assert!(matches!(
            err,
            Error::InvalidArgumentError(ref msg) if msg.contains("batch_size")
        ));
    }

    #[test]
    fn test_join_type_display() {
        assert_eq!(JoinType::Inner.to_string(), "INNER");
        assert_eq!(JoinType::Left.to_string(), "LEFT");
        assert_eq!(JoinType::Right.to_string(), "RIGHT");
        assert_eq!(JoinType::Full.to_string(), "FULL");
        assert_eq!(JoinType::Semi.to_string(), "SEMI");
        assert_eq!(JoinType::Anti.to_string(), "ANTI");
    }

    #[test]
    fn test_join_algorithm_display() {
        assert_eq!(JoinAlgorithm::Hash.to_string(), "Hash");
        assert_eq!(JoinAlgorithm::SortMerge.to_string(), "SortMerge");
    }
}