1#[allow(
11 clippy::disallowed_types,
12 reason = "HashMap needed for O(1) hash-join probe; output follows left-table order, which is deterministic"
13)]
14use std::collections::HashMap;
15
16use crate::column::Column;
17use crate::dataframe::DataFrame;
18use crate::error::DataFrameError;
19use crate::scalar::Scalar;
20
/// Placeholder written into a stringified join key for a component that is null
/// (or whose column/row could not be read).
///
/// Every key consumer in this module checks for it: right rows whose key contains it
/// are never inserted into the hash index, and left rows whose key contains it never
/// probe the index, so null keys never match. The embedded NUL bytes presumably keep
/// it distinct from ordinary display text.
const NULL_SENTINEL: &str = "\0NULL\0";
24
/// How [`DataFrame::join`] / [`DataFrame::join_on`] combine the rows of two frames.
///
/// In every mode, a row whose key has any null component never matches any row.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JoinType {
    /// Only left/right row pairs whose keys match, in left-row order.
    Inner,
    /// Every left row; left rows without a match get nulls in the right-side columns.
    Left,
    /// Matched pairs in left-row order, followed by every unmatched right row
    /// (with nulls in the left-side columns).
    Right,
    /// All matched pairs plus unmatched left rows (in left-row order), followed by
    /// unmatched right rows.
    Outer,
    /// Left rows that have at least one match, each kept once; left columns only.
    Semi,
    /// Left rows without any match (including rows with a null key); left columns only.
    Anti,
}
42
43impl DataFrame {
44 pub fn join(
49 &self,
50 other: &DataFrame,
51 on: &[&str],
52 how: JoinType,
53 ) -> Result<DataFrame, DataFrameError> {
54 self.join_on(other, on, on, how)
55 }
56
57 pub fn join_on(
63 &self,
64 other: &DataFrame,
65 left_on: &[&str],
66 right_on: &[&str],
67 how: JoinType,
68 ) -> Result<DataFrame, DataFrameError> {
69 if left_on.len() != right_on.len() {
71 return Err(DataFrameError::JoinKeyMismatch {
72 left_count: left_on.len(),
73 right_count: right_on.len(),
74 });
75 }
76 if left_on.is_empty() {
77 return Err(DataFrameError::Other(
78 "join requires at least one key column".to_string(),
79 ));
80 }
81
82 for name in left_on {
84 self.column(name)?;
85 }
86 for name in right_on {
87 other.column(name)?;
88 }
89
90 #[allow(
92 clippy::disallowed_types,
93 reason = "HashMap for O(1) hash-join index; see module-level allow"
94 )]
95 let mut right_index: HashMap<Vec<String>, Vec<usize>> = HashMap::new();
96 for row_idx in 0..other.height() {
97 let key = extract_key(other, right_on, row_idx);
98 if key.iter().any(|k| k == NULL_SENTINEL) {
100 continue;
101 }
102 right_index.entry(key).or_default().push(row_idx);
103 }
104
105 match how {
107 JoinType::Semi => self.join_semi(left_on, &right_index),
108 JoinType::Anti => self.join_anti(left_on, &right_index),
109 JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Outer => {
110 self.join_matching(other, left_on, right_on, how, &right_index)
111 }
112 }
113 }
114
115 #[allow(
117 clippy::disallowed_types,
118 reason = "HashMap parameter from join_on caller"
119 )]
120 #[allow(
121 clippy::too_many_arguments,
122 reason = "Internal method — 6 args needed for left/right context + index; not part of public API"
123 )]
124 fn join_matching(
125 &self,
126 other: &DataFrame,
127 left_on: &[&str],
128 right_on: &[&str],
129 how: JoinType,
130 right_index: &HashMap<Vec<String>, Vec<usize>>,
131 ) -> Result<DataFrame, DataFrameError> {
132 let mut left_indices: Vec<Option<usize>> = Vec::new();
133 let mut right_indices: Vec<Option<usize>> = Vec::new();
134
135 let mut right_matched = vec![false; other.height()];
137
138 for left_row in 0..self.height() {
140 let key = extract_key(self, left_on, left_row);
141
142 if key.iter().any(|k| k == NULL_SENTINEL) {
144 match how {
145 JoinType::Left | JoinType::Outer => {
146 left_indices.push(Some(left_row));
147 right_indices.push(None);
148 }
149 JoinType::Inner | JoinType::Right | JoinType::Semi | JoinType::Anti => {}
150 }
151 continue;
152 }
153
154 match right_index.get(&key) {
155 Some(matches) => {
156 for &right_row in matches {
157 left_indices.push(Some(left_row));
158 right_indices.push(Some(right_row));
159 #[allow(
161 clippy::indexing_slicing,
162 reason = "right_row is from right_index which was built from 0..other.height()"
163 )]
164 {
165 right_matched[right_row] = true;
166 }
167 }
168 }
169 None => match how {
170 JoinType::Left | JoinType::Outer => {
171 left_indices.push(Some(left_row));
172 right_indices.push(None);
173 }
174 JoinType::Inner | JoinType::Right | JoinType::Semi | JoinType::Anti => {}
175 },
176 }
177 }
178
179 if matches!(how, JoinType::Right | JoinType::Outer) {
181 for (right_row, matched) in right_matched.iter().enumerate() {
182 if !matched {
183 left_indices.push(None);
184 right_indices.push(Some(right_row));
185 }
186 }
187 }
188
189 assemble_columns(
191 self,
192 other,
193 left_on,
194 right_on,
195 &left_indices,
196 &right_indices,
197 )
198 }
199
200 #[allow(
202 clippy::disallowed_types,
203 reason = "HashMap parameter from join_on caller"
204 )]
205 fn join_semi(
206 &self,
207 left_on: &[&str],
208 right_index: &HashMap<Vec<String>, Vec<usize>>,
209 ) -> Result<DataFrame, DataFrameError> {
210 let mut keep_indices: Vec<usize> = Vec::new();
211
212 for left_row in 0..self.height() {
213 let key = extract_key(self, left_on, left_row);
214 if key.iter().any(|k| k == NULL_SENTINEL) {
215 continue;
216 }
217 if right_index.contains_key(&key) {
218 keep_indices.push(left_row);
219 }
220 }
221
222 Ok(DataFrame::from_columns_unchecked(
223 self.columns()
224 .iter()
225 .map(|c| c.take(&keep_indices))
226 .collect(),
227 ))
228 }
229
230 #[allow(
232 clippy::disallowed_types,
233 reason = "HashMap parameter from join_on caller"
234 )]
235 fn join_anti(
236 &self,
237 left_on: &[&str],
238 right_index: &HashMap<Vec<String>, Vec<usize>>,
239 ) -> Result<DataFrame, DataFrameError> {
240 let mut keep_indices: Vec<usize> = Vec::new();
241
242 for left_row in 0..self.height() {
243 let key = extract_key(self, left_on, left_row);
244 if key.iter().any(|k| k == NULL_SENTINEL) {
246 keep_indices.push(left_row);
247 continue;
248 }
249 if !right_index.contains_key(&key) {
250 keep_indices.push(left_row);
251 }
252 }
253
254 Ok(DataFrame::from_columns_unchecked(
255 self.columns()
256 .iter()
257 .map(|c| c.take(&keep_indices))
258 .collect(),
259 ))
260 }
261}
262
263fn extract_key(df: &DataFrame, key_cols: &[&str], row_idx: usize) -> Vec<String> {
266 key_cols
267 .iter()
268 .map(|name| {
269 df.column(name)
270 .ok()
271 .and_then(|col| col.get(row_idx))
272 .map_or_else(
273 || NULL_SENTINEL.to_string(),
274 |s| {
275 if s.is_null() {
276 NULL_SENTINEL.to_string()
277 } else {
278 s.to_string()
279 }
280 },
281 )
282 })
283 .collect()
284}
285
286#[allow(
290 clippy::too_many_arguments,
291 reason = "Assembling columns requires both table refs, both key slices, and both index vecs — cannot reduce without wrapper struct"
292)]
293fn assemble_columns(
294 left: &DataFrame,
295 right: &DataFrame,
296 left_on: &[&str],
297 right_on: &[&str],
298 left_indices: &[Option<usize>],
299 right_indices: &[Option<usize>],
300) -> Result<DataFrame, DataFrameError> {
301 let mut result_columns: Vec<Column> = Vec::new();
302 let left_names: Vec<&str> = left.column_names();
303 let right_names: Vec<&str> = right.column_names();
304
305 let left_key_set: Vec<&str> = left_on.to_vec();
307 let right_key_set: Vec<&str> = right_on.to_vec();
308
309 let left_non_keys: Vec<&str> = left_names
311 .iter()
312 .filter(|n| !left_key_set.contains(n))
313 .copied()
314 .collect();
315 let right_non_keys: Vec<&str> = right_names
316 .iter()
317 .filter(|n| !right_key_set.contains(n))
318 .copied()
319 .collect();
320
321 let collisions: Vec<&str> = left_non_keys
322 .iter()
323 .filter(|n| right_non_keys.contains(n))
324 .copied()
325 .collect();
326
327 for (i, &left_key_name) in left_on.iter().enumerate() {
329 let left_col = left.column(left_key_name)?;
330 #[allow(
332 clippy::indexing_slicing,
333 reason = "i < left_on.len() == right_on.len() by JoinKeyMismatch validation in join_on"
334 )]
335 let right_key_name = right_on[i];
336 let right_col = right.column(right_key_name)?;
337
338 let merged = merge_key_column(left_col, right_col, left_indices, right_indices);
339 result_columns.push(merged.rename(left_key_name));
340 }
341
342 for &name in &left_non_keys {
344 let col = left.column(name)?;
345 let taken = col.take_optional(left_indices);
346 if collisions.contains(&name) {
347 result_columns.push(taken.rename(format!("{name}_left")));
348 } else {
349 result_columns.push(taken);
350 }
351 }
352
353 for &name in &right_non_keys {
355 let col = right.column(name)?;
356 let taken = col.take_optional(right_indices);
357 if collisions.contains(&name) {
358 result_columns.push(taken.rename(format!("{name}_right")));
359 } else {
360 result_columns.push(taken);
361 }
362 }
363
364 Ok(DataFrame::from_columns_unchecked(result_columns))
365}
366
367fn merge_key_column(
369 left_col: &Column,
370 right_col: &Column,
371 left_indices: &[Option<usize>],
372 right_indices: &[Option<usize>],
373) -> Column {
374 let data = left_col.data();
377
378 let len = left_indices.len();
380 match data {
381 crate::column::ColumnData::Bool(_) => {
382 let vals: Vec<Option<bool>> = (0..len)
383 .map(|i| {
384 #[allow(
386 clippy::indexing_slicing,
387 reason = "i iterates 0..len where len = left_indices.len()"
388 )]
389 match (left_indices[i], right_indices[i]) {
390 (Some(li), _) => left_col.get(li).and_then(|s| s.as_bool()),
391 (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_bool()),
392 (None, None) => None,
393 }
394 })
395 .collect();
396 Column::new_bool(left_col.name(), vals)
397 }
398 crate::column::ColumnData::Int64(_) => {
399 let vals: Vec<Option<i64>> = (0..len)
400 .map(|i| {
401 #[allow(
402 clippy::indexing_slicing,
403 reason = "i iterates 0..len where len = left_indices.len()"
404 )]
405 match (left_indices[i], right_indices[i]) {
406 (Some(li), _) => left_col.get(li).and_then(|s| s.as_i64()),
407 (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_i64()),
408 (None, None) => None,
409 }
410 })
411 .collect();
412 Column::new_i64(left_col.name(), vals)
413 }
414 crate::column::ColumnData::UInt64(_) => {
415 let vals: Vec<Option<u64>> = (0..len)
416 .map(|i| {
417 #[allow(
418 clippy::indexing_slicing,
419 reason = "i iterates 0..len where len = left_indices.len()"
420 )]
421 match (left_indices[i], right_indices[i]) {
422 (Some(li), _) => left_col.get(li).and_then(|s| s.as_u64()),
423 (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_u64()),
424 (None, None) => None,
425 }
426 })
427 .collect();
428 Column::new_u64(left_col.name(), vals)
429 }
430 crate::column::ColumnData::Float64(_) => {
431 let vals: Vec<Option<f64>> = (0..len)
432 .map(|i| {
433 #[allow(
434 clippy::indexing_slicing,
435 reason = "i iterates 0..len where len = left_indices.len()"
436 )]
437 match (left_indices[i], right_indices[i]) {
438 (Some(li), _) => left_col.get(li).and_then(|s| s.as_f64()),
439 (None, Some(ri)) => right_col.get(ri).and_then(|s| s.as_f64()),
440 (None, None) => None,
441 }
442 })
443 .collect();
444 Column::new_f64(left_col.name(), vals)
445 }
446 crate::column::ColumnData::String(_) => {
447 let vals: Vec<Option<String>> = (0..len)
448 .map(|i| {
449 #[allow(
450 clippy::indexing_slicing,
451 reason = "i iterates 0..len where len = left_indices.len()"
452 )]
453 match (left_indices[i], right_indices[i]) {
454 (Some(li), _) => left_col
455 .get(li)
456 .and_then(|s| s.as_str().map(|s| s.to_string())),
457 (None, Some(ri)) => right_col
458 .get(ri)
459 .and_then(|s| s.as_str().map(|s| s.to_string())),
460 (None, None) => None,
461 }
462 })
463 .collect();
464 Column::new_string(left_col.name(), vals)
465 }
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;

    fn drugs() -> DataFrame {
        DataFrame::new(vec![
            Column::from_strs("drug_id", &["D1", "D2", "D3", "D4"]),
            Column::from_strs(
                "drug_name",
                &["aspirin", "metformin", "ibuprofen", "lisinopril"],
            ),
        ])
        .unwrap_or_else(|_| unreachable!())
    }

    fn events() -> DataFrame {
        DataFrame::new(vec![
            Column::from_strs("drug_id", &["D1", "D1", "D2", "D5"]),
            Column::from_strs("event", &["headache", "nausea", "rash", "dizziness"]),
            Column::from_i64s("count", vec![10, 5, 3, 7]),
        ])
        .unwrap_or_else(|_| unreachable!())
    }

    /// Every value of column `name` in row order (panics if the column is missing).
    fn values(df: &DataFrame, name: &str) -> Vec<Option<Scalar>> {
        let col = df.column(name).unwrap_or_else(|_| unreachable!());
        (0..df.height()).map(|i| col.get(i)).collect()
    }

    /// A present string cell, in the form `Column::get` returns it.
    fn text(value: &str) -> Option<Scalar> {
        Some(Scalar::String(value.into()))
    }

    /// A null cell, in the form `Column::get` returns it.
    fn null() -> Option<Scalar> {
        Some(Scalar::Null)
    }

    #[test]
    fn inner_join_basic() {
        let result = drugs()
            .join(&events(), &["drug_id"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 3);
        assert_eq!(result.width(), 4);
        assert_eq!(
            values(&result, "drug_id"),
            vec![text("D1"), text("D1"), text("D2")]
        );
        assert_eq!(
            values(&result, "event"),
            vec![text("headache"), text("nausea"), text("rash")]
        );
    }

    #[test]
    fn inner_join_no_matches() {
        let left = DataFrame::new(vec![
            Column::from_strs("k", &["a", "b"]),
            Column::from_i64s("v", vec![1, 2]),
        ])
        .unwrap_or_else(|_| unreachable!());
        let right = DataFrame::new(vec![
            Column::from_strs("k", &["c", "d"]),
            Column::from_i64s("w", vec![3, 4]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["k"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 0);
    }

    #[test]
    fn left_join_basic() {
        let result = drugs()
            .join(&events(), &["drug_id"], JoinType::Left)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 5);
        assert_eq!(result.width(), 4);
        assert_eq!(
            values(&result, "drug_id"),
            vec![text("D1"), text("D1"), text("D2"), text("D3"), text("D4")]
        );
        // Unmatched left rows (D3, D4) get nulls on the right side.
        assert_eq!(
            values(&result, "event"),
            vec![text("headache"), text("nausea"), text("rash"), null(), null()]
        );
    }

    #[test]
    fn right_join_basic() {
        let result = drugs()
            .join(&events(), &["drug_id"], JoinType::Right)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 4);
        // Matched pairs in left order, then the unmatched right row D5.
        assert_eq!(
            values(&result, "drug_id"),
            vec![text("D1"), text("D1"), text("D2"), text("D5")]
        );
        assert_eq!(
            values(&result, "drug_name"),
            vec![text("aspirin"), text("aspirin"), text("metformin"), null()]
        );
    }

    #[test]
    fn outer_join_basic() {
        let result = drugs()
            .join(&events(), &["drug_id"], JoinType::Outer)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 6);
        assert_eq!(result.width(), 4);
        assert_eq!(
            values(&result, "drug_id"),
            vec![
                text("D1"),
                text("D1"),
                text("D2"),
                text("D3"),
                text("D4"),
                text("D5"),
            ]
        );
    }

    #[test]
    fn semi_join_basic() {
        let result = drugs()
            .join(&events(), &["drug_id"], JoinType::Semi)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        assert_eq!(result.width(), 2);
        assert_eq!(values(&result, "drug_id"), vec![text("D1"), text("D2")]);
    }

    #[test]
    fn anti_join_basic() {
        let result = drugs()
            .join(&events(), &["drug_id"], JoinType::Anti)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        assert_eq!(result.width(), 2);
        assert_eq!(values(&result, "drug_id"), vec![text("D3"), text("D4")]);
    }

    #[test]
    fn multi_key_join() {
        let left = DataFrame::new(vec![
            Column::from_strs("drug", &["asp", "asp", "met"]),
            Column::from_strs("event", &["ha", "na", "ha"]),
            Column::from_i64s("left_val", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("drug", &["asp", "met", "met"]),
            Column::from_strs("event", &["ha", "ha", "na"]),
            Column::from_i64s("right_val", vec![10, 20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["drug", "event"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        assert_eq!(result.width(), 4);
        assert_eq!(values(&result, "drug"), vec![text("asp"), text("met")]);
        assert_eq!(values(&result, "event"), vec![text("ha"), text("ha")]);
    }

    #[test]
    fn join_on_different_key_names() {
        let left = DataFrame::new(vec![
            Column::from_strs("id", &["a", "b", "c"]),
            Column::from_i64s("val", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("key", &["b", "c", "d"]),
            Column::from_i64s("score", vec![10, 20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join_on(&right, &["id"], &["key"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        assert_eq!(result.width(), 3);
        assert_eq!(values(&result, "id"), vec![text("b"), text("c")]);
        // The right key column is merged into the left-named key, not kept separately.
        assert!(result.column("key").is_err());
    }

    #[test]
    fn null_keys_never_match() {
        let left = DataFrame::new(vec![
            Column::new_string("k", vec![Some("a".into()), None, Some("c".into())]),
            Column::from_i64s("v", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::new_string("k", vec![Some("a".into()), None]),
            Column::from_i64s("w", vec![10, 20]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let inner = left
            .join(&right, &["k"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(inner.height(), 1);
        assert_eq!(values(&inner, "k"), vec![text("a")]);

        let lj = left
            .join(&right, &["k"], JoinType::Left)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(lj.height(), 3);
    }

    #[test]
    fn null_keys_outer_join_keeps_both_sides() {
        let left = DataFrame::new(vec![
            Column::new_string("k", vec![Some("a".into()), None, Some("c".into())]),
            Column::from_i64s("v", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::new_string("k", vec![Some("a".into()), None]),
            Column::from_i64s("w", vec![10, 20]),
        ])
        .unwrap_or_else(|_| unreachable!());

        // Rows: (a, a) matched, left null-key row, left "c", right null-key row.
        let outer = left
            .join(&right, &["k"], JoinType::Outer)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(outer.height(), 4);
        assert_eq!(values(&outer, "w").get(1), Some(&null()));
        assert_eq!(values(&outer, "w").get(2), Some(&null()));
        assert_eq!(values(&outer, "v").get(3), Some(&null()));
    }

    #[test]
    fn null_keys_kept_in_anti_join() {
        let left = DataFrame::new(vec![
            Column::new_string("k", vec![Some("a".into()), None, Some("c".into())]),
            Column::from_i64s("v", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("k", &["a"]),
            Column::from_i64s("w", vec![10]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["k"], JoinType::Anti)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
    }

    #[test]
    fn name_collision_suffixes() {
        let left = DataFrame::new(vec![
            Column::from_strs("k", &["a", "b"]),
            Column::from_i64s("value", vec![1, 2]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("k", &["a", "b"]),
            Column::from_i64s("value", vec![10, 20]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["k"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        assert!(result.column("value_left").is_ok());
        assert!(result.column("value_right").is_ok());
        assert!(result.column("value").is_err());
    }

    #[test]
    fn join_empty_left() {
        let left = DataFrame::new(vec![
            Column::from_strs("k", &[]),
            Column::from_i64s("v", vec![]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("k", &["a"]),
            Column::from_i64s("w", vec![1]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["k"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 0);

        let result = left
            .join(&right, &["k"], JoinType::Left)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 0);
    }

    #[test]
    fn join_empty_right() {
        let left = DataFrame::new(vec![
            Column::from_strs("k", &["a"]),
            Column::from_i64s("v", vec![1]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("k", &[]),
            Column::from_i64s("w", vec![]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let inner = left
            .join(&right, &["k"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(inner.height(), 0);

        let lj = left
            .join(&right, &["k"], JoinType::Left)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(lj.height(), 1);
    }

    #[test]
    fn error_key_mismatch() {
        let left = DataFrame::new(vec![
            Column::from_strs("a", &["x"]),
            Column::from_strs("b", &["y"]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right =
            DataFrame::new(vec![Column::from_strs("c", &["x"])]).unwrap_or_else(|_| unreachable!());

        let err = left.join_on(&right, &["a", "b"], &["c"], JoinType::Inner);
        assert!(matches!(
            err,
            Err(DataFrameError::JoinKeyMismatch {
                left_count: 2,
                right_count: 1
            })
        ));
    }

    #[test]
    fn error_empty_keys() {
        let left =
            DataFrame::new(vec![Column::from_strs("a", &["x"])]).unwrap_or_else(|_| unreachable!());
        let right =
            DataFrame::new(vec![Column::from_strs("a", &["x"])]).unwrap_or_else(|_| unreachable!());

        let err = left.join(&right, &[], JoinType::Inner);
        assert!(matches!(err, Err(DataFrameError::Other(_))));
    }

    #[test]
    fn error_missing_column() {
        let left =
            DataFrame::new(vec![Column::from_strs("a", &["x"])]).unwrap_or_else(|_| unreachable!());
        let right =
            DataFrame::new(vec![Column::from_strs("b", &["x"])]).unwrap_or_else(|_| unreachable!());

        let err = left.join(&right, &["a"], JoinType::Inner);
        assert!(err.is_err());
    }

    #[test]
    fn type_preservation_numeric_keys() {
        let left = DataFrame::new(vec![
            Column::from_i64s("id", vec![1, 2, 3]),
            Column::from_strs("name", &["a", "b", "c"]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_i64s("id", vec![2, 3, 4]),
            Column::from_f64s("score", vec![9.5, 8.0, 7.5]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["id"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 2);
        assert_eq!(values(&result, "name"), vec![text("b"), text("c")]);

        let id_col = result.column("id").unwrap_or_else(|_| unreachable!());
        assert_eq!(id_col.dtype(), crate::column::DataType::Int64);

        let score_col = result.column("score").unwrap_or_else(|_| unreachable!());
        assert_eq!(score_col.dtype(), crate::column::DataType::Float64);
    }

    #[test]
    fn many_to_many_join() {
        let left = DataFrame::new(vec![
            Column::from_strs("k", &["a", "a", "b"]),
            Column::from_i64s("lv", vec![1, 2, 3]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("k", &["a", "a"]),
            Column::from_i64s("rv", vec![10, 20]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["k"], JoinType::Inner)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 4);
        assert_eq!(values(&result, "k"), vec![text("a"); 4]);
    }

    #[test]
    fn semi_join_deduplicates() {
        let left = DataFrame::new(vec![
            Column::from_strs("k", &["a", "b"]),
            Column::from_i64s("v", vec![1, 2]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("k", &["a", "a", "a"]),
            Column::from_i64s("w", vec![10, 20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["k"], JoinType::Semi)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 1);
        assert_eq!(values(&result, "k"), vec![text("a")]);
    }

    #[test]
    fn outer_join_preserves_all_keys() {
        let left = DataFrame::new(vec![
            Column::from_strs("k", &["a", "b"]),
            Column::from_i64s("v", vec![1, 2]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right = DataFrame::new(vec![
            Column::from_strs("k", &["b", "c"]),
            Column::from_i64s("w", vec![20, 30]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let result = left
            .join(&right, &["k"], JoinType::Outer)
            .unwrap_or_else(|_| unreachable!());
        assert_eq!(result.height(), 3);
        // "c" exists only on the right; its key must be filled from the right side.
        assert_eq!(values(&result, "k"), vec![text("a"), text("b"), text("c")]);
    }

    #[test]
    fn right_join_symmetric_to_left() {
        let a = DataFrame::new(vec![
            Column::from_strs("k", &["x", "y"]),
            Column::from_i64s("a_val", vec![1, 2]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let b = DataFrame::new(vec![
            Column::from_strs("k", &["y", "z"]),
            Column::from_i64s("b_val", vec![10, 20]),
        ])
        .unwrap_or_else(|_| unreachable!());

        let right_result = a
            .join(&b, &["k"], JoinType::Right)
            .unwrap_or_else(|_| unreachable!());
        let left_result = b
            .join(&a, &["k"], JoinType::Left)
            .unwrap_or_else(|_| unreachable!());

        assert_eq!(right_result.height(), left_result.height());
        for name in ["k", "a_val", "b_val"] {
            assert_eq!(
                values(&right_result, name),
                values(&left_result, name),
                "column {name} differs"
            );
        }
    }
}