1use std::collections::HashMap;
29use std::hash::{Hash, Hasher};
30
31use super::super::engine::binding::{Binding, Value, Var};
32use super::value_compare::total_compare_values;
33
/// Kind of join to perform.
///
/// The variant decides which rows without a partner survive; whether two
/// rows are partners is decided separately by a [`JoinCondition`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
    /// Only merged rows for matching pairs are produced.
    Inner,
    /// Matching pairs, plus every left row that found no partner.
    Left,
    /// Matching pairs, plus every right row that found no partner.
    Right,
    /// Matching pairs only; paired with [`JoinCondition::None`] this yields
    /// the full cartesian product.
    Cross,
    /// Matching pairs, plus the unmatched rows of both sides.
    FullOuter,
}
52
53#[derive(Debug, Clone)]
55pub enum JoinCondition {
56 Eq(Var, Var),
58 And(Vec<JoinCondition>),
60 None,
62}
63
64impl JoinCondition {
65 pub fn left_vars(&self) -> Vec<Var> {
67 match self {
68 JoinCondition::Eq(left, _) => vec![left.clone()],
69 JoinCondition::And(conditions) => {
70 conditions.iter().flat_map(|c| c.left_vars()).collect()
71 }
72 JoinCondition::None => Vec::new(),
73 }
74 }
75
76 pub fn right_vars(&self) -> Vec<Var> {
78 match self {
79 JoinCondition::Eq(_, right) => vec![right.clone()],
80 JoinCondition::And(conditions) => {
81 conditions.iter().flat_map(|c| c.right_vars()).collect()
82 }
83 JoinCondition::None => Vec::new(),
84 }
85 }
86}
87
/// Physical algorithm used to evaluate a join.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinStrategy {
    /// Index one input by its key values in a hash table, then probe it with
    /// the other input.
    Hash,
    /// Test the condition on every left/right pair.
    NestedLoop,
    /// Order both inputs by their key values and walk them in lockstep.
    Merge,
}
102
/// Input statistics consulted when picking a [`JoinStrategy`].
#[derive(Clone, Debug)]
pub struct JoinStats {
    /// Number of rows on the left input.
    pub left_cardinality: usize,
    /// Number of rows on the right input.
    pub right_cardinality: usize,
    /// Whether the left input is already ordered.
    pub left_sorted: bool,
    /// Whether the right input is already ordered.
    pub right_sorted: bool,
    /// Selectivity of the join condition; the default is `1.0`.
    pub condition_selectivity: f64,
}

impl Default for JoinStats {
    /// Statistics for two empty, unsorted inputs with a selectivity of `1.0`.
    fn default() -> Self {
        JoinStats {
            condition_selectivity: 1.0,
            left_sorted: false,
            right_sorted: false,
            left_cardinality: 0,
            right_cardinality: 0,
        }
    }
}
124
125pub fn choose_strategy(stats: &JoinStats, condition: &JoinCondition) -> JoinStrategy {
127 if matches!(condition, JoinCondition::None) {
129 return JoinStrategy::NestedLoop;
130 }
131
132 if stats.left_sorted && stats.right_sorted {
134 return JoinStrategy::Merge;
135 }
136
137 let total = stats.left_cardinality * stats.right_cardinality;
139 if total < 1000 {
140 return JoinStrategy::NestedLoop;
141 }
142
143 if matches!(condition, JoinCondition::Eq(_, _) | JoinCondition::And(_)) {
145 return JoinStrategy::Hash;
146 }
147
148 JoinStrategy::NestedLoop
149}
150
151#[derive(Clone, PartialEq, Eq)]
157struct HashKey(Vec<Option<Value>>);
158
159impl Hash for HashKey {
160 fn hash<H: Hasher>(&self, state: &mut H) {
161 for value in &self.0 {
162 match value {
163 Some(Value::String(s)) => {
164 1u8.hash(state);
165 s.hash(state);
166 }
167 Some(Value::Integer(i)) => {
168 2u8.hash(state);
169 i.hash(state);
170 }
171 Some(Value::Float(f)) => {
172 3u8.hash(state);
173 f.to_bits().hash(state);
174 }
175 Some(Value::Boolean(b)) => {
176 4u8.hash(state);
177 b.hash(state);
178 }
179 Some(Value::Uri(u)) => {
180 5u8.hash(state);
181 u.hash(state);
182 }
183 Some(Value::Node(n)) => {
184 6u8.hash(state);
185 n.hash(state);
186 }
187 Some(Value::Edge(e)) => {
188 7u8.hash(state);
189 e.hash(state);
190 }
191 Some(Value::Null) | None => {
192 0u8.hash(state);
193 }
194 }
195 }
196 }
197}
198
199pub fn hash_join(
201 left: Vec<Binding>,
202 right: Vec<Binding>,
203 condition: &JoinCondition,
204 join_type: JoinType,
205) -> Vec<Binding> {
206 let left_keys = condition.left_vars();
207 let right_keys = condition.right_vars();
208
209 if left_keys.is_empty() {
210 return nested_loop_join(left, right, condition, join_type);
212 }
213
214 let (build_side, probe_side, build_keys, probe_keys, is_left_build) =
216 if left.len() <= right.len() {
217 (&left, &right, &left_keys, &right_keys, true)
218 } else {
219 (&right, &left, &right_keys, &left_keys, false)
220 };
221
222 let mut hash_table: HashMap<HashKey, Vec<&Binding>> = HashMap::new();
224 for binding in build_side {
225 let key = extract_key(binding, build_keys);
226 hash_table.entry(key).or_default().push(binding);
227 }
228
229 let mut results = Vec::new();
231 let mut matched_build: std::collections::HashSet<usize> = std::collections::HashSet::new();
232
233 for (probe_idx, probe_binding) in probe_side.iter().enumerate() {
234 let key = extract_key(probe_binding, probe_keys);
235 let matches = hash_table.get(&key);
236
237 let mut had_match = false;
238 if let Some(build_bindings) = matches {
239 for (build_idx, &build_binding) in build_bindings.iter().enumerate() {
240 had_match = true;
241
242 if matches!(join_type, JoinType::FullOuter) {
244 let original_idx = build_side
246 .iter()
247 .position(|b| std::ptr::eq(b, build_binding));
248 if let Some(idx) = original_idx {
249 matched_build.insert(idx);
250 }
251 }
252
253 let merged = if is_left_build {
255 merge_bindings(build_binding, probe_binding)
256 } else {
257 merge_bindings(probe_binding, build_binding)
258 };
259 results.push(merged);
260 }
261 }
262
263 if !had_match {
265 match join_type {
266 JoinType::Left if !is_left_build => {
267 results.push(probe_binding.clone());
269 }
270 JoinType::Right if is_left_build => {
271 results.push(probe_binding.clone());
273 }
274 JoinType::FullOuter => {
275 results.push(probe_binding.clone());
276 }
277 _ => {}
278 }
279 }
280 }
281
282 if matches!(join_type, JoinType::FullOuter) {
284 for (idx, binding) in build_side.iter().enumerate() {
285 if !matched_build.contains(&idx) {
286 results.push((*binding).clone());
287 }
288 }
289 }
290
291 match (join_type, is_left_build) {
293 (JoinType::Left, true) => {
294 let mut all_left_matched: std::collections::HashSet<usize> =
296 std::collections::HashSet::new();
297 for binding in &results {
298 for (idx, left_binding) in left.iter().enumerate() {
300 if bindings_match(binding, left_binding, &left_keys) {
301 all_left_matched.insert(idx);
302 }
303 }
304 }
305 for (idx, binding) in left.iter().enumerate() {
306 if !all_left_matched.contains(&idx) {
307 results.push(binding.clone());
308 }
309 }
310 }
311 (JoinType::Right, false) => {
312 let mut all_right_matched: std::collections::HashSet<usize> =
314 std::collections::HashSet::new();
315 for binding in &results {
316 for (idx, right_binding) in right.iter().enumerate() {
317 if bindings_match(binding, right_binding, &right_keys) {
318 all_right_matched.insert(idx);
319 }
320 }
321 }
322 for (idx, binding) in right.iter().enumerate() {
323 if !all_right_matched.contains(&idx) {
324 results.push(binding.clone());
325 }
326 }
327 }
328 _ => {}
329 }
330
331 results
332}
333
334fn extract_key(binding: &Binding, vars: &[Var]) -> HashKey {
335 HashKey(vars.iter().map(|v| binding.get(v).cloned()).collect())
336}
337
338fn bindings_match(a: &Binding, b: &Binding, keys: &[Var]) -> bool {
339 keys.iter().all(|k| match (a.get(k), b.get(k)) {
340 (Some(v1), Some(v2)) => v1 == v2,
341 _ => false,
342 })
343}
344
345pub fn nested_loop_join(
351 left: Vec<Binding>,
352 right: Vec<Binding>,
353 condition: &JoinCondition,
354 join_type: JoinType,
355) -> Vec<Binding> {
356 let mut results = Vec::new();
357 let mut left_matched = vec![false; left.len()];
358 let mut right_matched = vec![false; right.len()];
359
360 for (left_idx, left_binding) in left.iter().enumerate() {
361 let mut found_match = false;
362
363 for (right_idx, right_binding) in right.iter().enumerate() {
364 if check_condition(left_binding, right_binding, condition) {
365 found_match = true;
366 left_matched[left_idx] = true;
367 right_matched[right_idx] = true;
368
369 let merged = merge_bindings(left_binding, right_binding);
370 results.push(merged);
371 }
372 }
373
374 if !found_match && matches!(join_type, JoinType::Left | JoinType::FullOuter) {
376 results.push(left_binding.clone());
377 }
378 }
379
380 if matches!(join_type, JoinType::Right | JoinType::FullOuter) {
382 for (right_idx, right_binding) in right.iter().enumerate() {
383 if !right_matched[right_idx] {
384 results.push(right_binding.clone());
385 }
386 }
387 }
388
389 results
390}
391
392fn check_condition(left: &Binding, right: &Binding, condition: &JoinCondition) -> bool {
393 match condition {
394 JoinCondition::Eq(left_var, right_var) => {
395 match (left.get(left_var), right.get(right_var)) {
396 (Some(l), Some(r)) => l == r,
397 _ => false,
398 }
399 }
400 JoinCondition::And(conditions) => {
401 conditions.iter().all(|c| check_condition(left, right, c))
402 }
403 JoinCondition::None => true,
404 }
405}
406
407pub fn merge_join(
413 left: Vec<Binding>,
414 right: Vec<Binding>,
415 condition: &JoinCondition,
416 join_type: JoinType,
417) -> Vec<Binding> {
418 let left_keys = condition.left_vars();
421 let right_keys = condition.right_vars();
422
423 if left_keys.is_empty() || right_keys.is_empty() {
424 return nested_loop_join(left, right, condition, join_type);
425 }
426
427 let mut left_sorted = left;
429 let mut right_sorted = right;
430
431 left_sorted.sort_by(|a, b| compare_by_keys(a, b, &left_keys));
432 right_sorted.sort_by(|a, b| compare_by_keys(a, b, &right_keys));
433
434 let mut results = Vec::new();
435 let mut left_idx = 0;
436 let mut right_idx = 0;
437 let mut left_matched = vec![false; left_sorted.len()];
438 let mut right_matched = vec![false; right_sorted.len()];
439
440 while left_idx < left_sorted.len() && right_idx < right_sorted.len() {
441 let left_key = extract_key(&left_sorted[left_idx], &left_keys);
442 let right_key = extract_key(&right_sorted[right_idx], &right_keys);
443
444 match compare_keys(&left_key, &right_key) {
445 std::cmp::Ordering::Less => {
446 if matches!(join_type, JoinType::Left | JoinType::FullOuter)
448 && !left_matched[left_idx]
449 {
450 results.push(left_sorted[left_idx].clone());
451 }
452 left_idx += 1;
453 }
454 std::cmp::Ordering::Greater => {
455 if matches!(join_type, JoinType::Right | JoinType::FullOuter)
457 && !right_matched[right_idx]
458 {
459 results.push(right_sorted[right_idx].clone());
460 }
461 right_idx += 1;
462 }
463 std::cmp::Ordering::Equal => {
464 let match_start_right = right_idx;
466
467 while right_idx < right_sorted.len() {
469 let current_right_key = extract_key(&right_sorted[right_idx], &right_keys);
470 if compare_keys(&left_key, ¤t_right_key) != std::cmp::Ordering::Equal {
471 break;
472 }
473
474 left_matched[left_idx] = true;
475 right_matched[right_idx] = true;
476
477 let merged = merge_bindings(&left_sorted[left_idx], &right_sorted[right_idx]);
478 results.push(merged);
479 right_idx += 1;
480 }
481
482 left_idx += 1;
484 while left_idx < left_sorted.len() {
485 let current_left_key = extract_key(&left_sorted[left_idx], &left_keys);
486 if compare_keys(¤t_left_key, &left_key) != std::cmp::Ordering::Equal {
487 break;
488 }
489
490 for right_row in right_sorted.iter().take(right_idx).skip(match_start_right) {
492 left_matched[left_idx] = true;
493 let merged = merge_bindings(&left_sorted[left_idx], right_row);
494 results.push(merged);
495 }
496 left_idx += 1;
497 }
498
499 right_idx = match_start_right;
501 if left_idx >= left_sorted.len() || {
502 let next_left_key = extract_key(
503 &left_sorted[left_idx.min(left_sorted.len() - 1)],
504 &left_keys,
505 );
506 compare_keys(&next_left_key, &left_key) != std::cmp::Ordering::Equal
507 } {
508 while right_idx < right_sorted.len() {
510 let current_right_key = extract_key(&right_sorted[right_idx], &right_keys);
511 if compare_keys(&left_key, ¤t_right_key) != std::cmp::Ordering::Equal
512 {
513 break;
514 }
515 right_idx += 1;
516 }
517 }
518 }
519 }
520 }
521
522 while left_idx < left_sorted.len() {
524 if matches!(join_type, JoinType::Left | JoinType::FullOuter) && !left_matched[left_idx] {
525 results.push(left_sorted[left_idx].clone());
526 }
527 left_idx += 1;
528 }
529
530 while right_idx < right_sorted.len() {
531 if matches!(join_type, JoinType::Right | JoinType::FullOuter) && !right_matched[right_idx] {
532 results.push(right_sorted[right_idx].clone());
533 }
534 right_idx += 1;
535 }
536
537 results
538}
539
540fn compare_by_keys(a: &Binding, b: &Binding, keys: &[Var]) -> std::cmp::Ordering {
541 for key in keys {
542 match (a.get(key), b.get(key)) {
543 (Some(av), Some(bv)) => {
544 let cmp = total_compare_values(av, bv);
545 if cmp != std::cmp::Ordering::Equal {
546 return cmp;
547 }
548 }
549 (Some(_), None) => return std::cmp::Ordering::Less,
550 (None, Some(_)) => return std::cmp::Ordering::Greater,
551 (None, None) => {}
552 }
553 }
554 std::cmp::Ordering::Equal
555}
556
557fn compare_keys(a: &HashKey, b: &HashKey) -> std::cmp::Ordering {
558 for (av, bv) in a.0.iter().zip(b.0.iter()) {
559 match (av, bv) {
560 (Some(av), Some(bv)) => {
561 let cmp = total_compare_values(av, bv);
562 if cmp != std::cmp::Ordering::Equal {
563 return cmp;
564 }
565 }
566 (Some(_), None) => return std::cmp::Ordering::Less,
567 (None, Some(_)) => return std::cmp::Ordering::Greater,
568 (None, None) => {}
569 }
570 }
571 std::cmp::Ordering::Equal
572}
573
574fn merge_bindings(left: &Binding, right: &Binding) -> Binding {
580 if let Some(merged) = left.merge(right) {
583 merged
584 } else {
585 left.clone()
587 }
588}
589
590pub fn execute_join(
596 left: Vec<Binding>,
597 right: Vec<Binding>,
598 condition: JoinCondition,
599 join_type: JoinType,
600 stats: Option<JoinStats>,
601) -> Vec<Binding> {
602 let actual_stats = stats.unwrap_or(JoinStats {
604 left_cardinality: left.len(),
605 right_cardinality: right.len(),
606 left_sorted: false,
607 right_sorted: false,
608 condition_selectivity: 1.0,
609 });
610
611 let strategy = choose_strategy(&actual_stats, &condition);
612
613 match strategy {
614 JoinStrategy::Hash => hash_join(left, right, &condition, join_type),
615 JoinStrategy::NestedLoop => nested_loop_join(left, right, &condition, join_type),
616 JoinStrategy::Merge => merge_join(left, right, &condition, join_type),
617 }
618}
619
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row of string values from `(variable, value)` pairs.
    fn make_binding(pairs: &[(&str, &str)]) -> Binding {
        let mut singles = pairs
            .iter()
            .map(|&(name, text)| Binding::one(Var::new(name), Value::String(text.to_string())));

        let first = match singles.next() {
            Some(binding) => binding,
            None => return Binding::empty(),
        };

        // Keep the accumulated row when a pair cannot be merged in.
        singles.fold(first, |acc, next| match acc.merge(&next) {
            Some(merged) => merged,
            None => acc,
        })
    }

    /// True when some row binds `name` to `expected`.
    fn has_name(rows: &[Binding], expected: &str) -> bool {
        let wanted = Value::String(expected.to_string());
        rows.iter().any(|row| row.get(&Var::new("name")) == Some(&wanted))
    }

    /// The `id = user_id` condition shared by several tests.
    fn id_matches_user_id() -> JoinCondition {
        JoinCondition::Eq(Var::new("id"), Var::new("user_id"))
    }

    /// Statistics for two unsorted or two sorted inputs of `rows` rows each.
    fn square_stats(rows: usize, sorted: bool) -> JoinStats {
        JoinStats {
            left_cardinality: rows,
            right_cardinality: rows,
            left_sorted: sorted,
            right_sorted: sorted,
            condition_selectivity: 1.0,
        }
    }

    #[test]
    fn test_inner_join() {
        let left = vec![
            make_binding(&[("id", "1"), ("name", "Alice")]),
            make_binding(&[("id", "2"), ("name", "Bob")]),
            make_binding(&[("id", "3"), ("name", "Charlie")]),
        ];
        let right = vec![
            make_binding(&[("user_id", "1"), ("score", "100")]),
            make_binding(&[("user_id", "2"), ("score", "90")]),
            make_binding(&[("user_id", "4"), ("score", "80")]),
        ];

        let results = execute_join(left, right, id_matches_user_id(), JoinType::Inner, None);

        assert_eq!(results.len(), 2);
        assert!(has_name(&results, "Alice"));
        assert!(has_name(&results, "Bob"));
    }

    #[test]
    fn test_left_join() {
        let left = vec![
            make_binding(&[("id", "1"), ("name", "Alice")]),
            make_binding(&[("id", "2"), ("name", "Bob")]),
            make_binding(&[("id", "3"), ("name", "Charlie")]),
        ];
        let right = vec![make_binding(&[("user_id", "1"), ("score", "100")])];

        let results = execute_join(left, right, id_matches_user_id(), JoinType::Left, None);

        // Every left row survives, matched or not.
        assert_eq!(results.len(), 3);
        assert!(has_name(&results, "Charlie"));
    }

    #[test]
    fn test_right_join() {
        let left = vec![make_binding(&[("id", "1"), ("name", "Alice")])];
        let right = vec![
            make_binding(&[("user_id", "1"), ("score", "100")]),
            make_binding(&[("user_id", "2"), ("score", "90")]),
            make_binding(&[("user_id", "3"), ("score", "80")]),
        ];

        let results = execute_join(left, right, id_matches_user_id(), JoinType::Right, None);

        // Every right row survives, matched or not.
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_cross_join() {
        let left = vec![make_binding(&[("a", "1")]), make_binding(&[("a", "2")])];
        let right = vec![
            make_binding(&[("b", "x")]),
            make_binding(&[("b", "y")]),
            make_binding(&[("b", "z")]),
        ];

        let results = execute_join(left, right, JoinCondition::None, JoinType::Cross, None);

        // 2 x 3 rows.
        assert_eq!(results.len(), 6);
    }

    #[test]
    fn test_merge_join() {
        let left = vec![
            make_binding(&[("id", "1"), ("name", "Alice")]),
            make_binding(&[("id", "2"), ("name", "Bob")]),
        ];
        let right = vec![
            make_binding(&[("id", "1"), ("dept", "Eng")]),
            make_binding(&[("id", "2"), ("dept", "Sales")]),
        ];

        let condition = JoinCondition::Eq(Var::new("id"), Var::new("id"));
        // Both sides flagged as sorted, which selects the merge strategy.
        let stats = square_stats(2, true);

        let results = execute_join(left, right, condition, JoinType::Inner, Some(stats));
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_strategy_selection() {
        let a_equals_b = || JoinCondition::Eq(Var::new("a"), Var::new("b"));

        // Small inputs: nested loop.
        assert_eq!(
            choose_strategy(&square_stats(10, false), &a_equals_b()),
            JoinStrategy::NestedLoop
        );

        // Large unsorted inputs: hash join.
        assert_eq!(
            choose_strategy(&square_stats(10000, false), &a_equals_b()),
            JoinStrategy::Hash
        );

        // Sorted inputs: merge join.
        assert_eq!(
            choose_strategy(&square_stats(1000, true), &a_equals_b()),
            JoinStrategy::Merge
        );
    }
}