1use std::collections::HashMap;
16
17use polars::prelude::*;
18
19use crate::error::{Error, Result};
20use crate::Value;
21
/// The kind of join to perform.
///
/// Array inputs support every variant; the DataFrame/LazyFrame paths go
/// through [`JoinType::to_polars`], which only maps some of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Keep only rows whose keys match on both sides.
    Inner,
    /// Keep every left row; unmatched right columns are null.
    Left,
    /// Keep every right row; unmatched left columns are null.
    Right,
    /// Keep every row from both sides (full outer join).
    Outer,
    /// Cartesian product of both sides; keys are not consulted.
    Cross,
    /// Keep left rows that have at least one match; no right columns added.
    Semi,
    /// Keep left rows that have no match; no right columns added.
    Anti,
}
40
41impl JoinType {
42 pub fn to_polars(&self) -> Result<polars::prelude::JoinType> {
44 match self {
45 JoinType::Inner => Ok(polars::prelude::JoinType::Inner),
46 JoinType::Left => Ok(polars::prelude::JoinType::Left),
47 JoinType::Right => Err(Error::operation(
48 "Right join not supported in this Polars version, so cannot convert to Polars",
49 )),
50 JoinType::Outer => Ok(polars::prelude::JoinType::Full),
51 JoinType::Cross => Ok(polars::prelude::JoinType::Cross),
52 JoinType::Semi => Err(Error::operation(
53 "Semi join not supported in this Polars version, so cannot convert to Polars",
54 )),
55 JoinType::Anti => Err(Error::operation(
56 "Anti join not supported in this Polars version, so cannot convert to Polars",
57 )),
58 }
59 }
60
61 #[must_use]
63 pub fn as_str(&self) -> &'static str {
64 match self {
65 JoinType::Inner => "inner",
66 JoinType::Left => "left",
67 JoinType::Right => "right",
68 JoinType::Outer => "outer",
69 JoinType::Cross => "cross",
70 JoinType::Semi => "semi",
71 JoinType::Anti => "anti",
72 }
73 }
74
75 #[allow(clippy::should_implement_trait)]
77 pub fn from_str(s: &str) -> Result<Self> {
78 match s.to_lowercase().as_str() {
79 "inner" => Ok(JoinType::Inner),
80 "left" | "left_outer" => Ok(JoinType::Left),
81 "right" | "right_outer" => Ok(JoinType::Right),
82 "outer" | "full" | "full_outer" => Ok(JoinType::Outer),
83 "cross" => Ok(JoinType::Cross),
84 "semi" => Ok(JoinType::Semi),
85 "anti" => Ok(JoinType::Anti),
86 _ => Err(Error::operation(format!("Unknown join type: {s}"))),
87 }
88 }
89}
90
91#[derive(Debug, Clone)]
93pub struct JoinOptions {
94 pub join_type: JoinType,
96 pub suffix: String,
98 pub validate: JoinValidation,
100 pub sort: bool,
102 pub coalesce: polars::prelude::JoinCoalesce,
104}
105
106impl Default for JoinOptions {
107 fn default() -> Self {
108 Self {
109 join_type: JoinType::Inner,
110 suffix: "_right".to_string(),
111 validate: JoinValidation::None,
112 sort: false,
113 coalesce: polars::prelude::JoinCoalesce::JoinSpecific,
114 }
115 }
116}
117
/// Key-cardinality check that may be requested for a join.
///
/// Variant names mirror the Polars `JoinValidation` variants they convert to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinValidation {
    /// No check requested.
    None,
    /// Converts to Polars `OneToMany`.
    OneToMany,
    /// Converts to Polars `ManyToOne`.
    ManyToOne,
    /// Converts to Polars `OneToOne`.
    OneToOne,
}
130
131impl JoinValidation {
132 #[must_use]
134 pub fn to_polars(&self) -> polars::prelude::JoinValidation {
135 match self {
136 JoinValidation::OneToMany => polars::prelude::JoinValidation::OneToMany,
138 JoinValidation::None | JoinValidation::ManyToOne => {
139 polars::prelude::JoinValidation::ManyToOne
140 }
141 JoinValidation::OneToOne => polars::prelude::JoinValidation::OneToOne,
142 }
143 }
144}
145
/// The columns a join matches on.
#[derive(Debug, Clone)]
pub enum JoinKeys {
    /// The same column names are used on both sides.
    On(Vec<String>),
    /// Left and right sides use their own, positionally paired, column names.
    LeftRight {
        /// Key columns read from the left input.
        left: Vec<String>,
        /// Key columns read from the right input.
        right: Vec<String>,
    },
}

impl JoinKeys {
    /// Keys shared by name between both sides.
    #[must_use]
    pub fn on(columns: Vec<String>) -> Self {
        Self::On(columns)
    }

    /// Keys named separately per side; `left[i]` pairs with `right[i]`.
    #[must_use]
    pub fn left_right(left: Vec<String>, right: Vec<String>) -> Self {
        Self::LeftRight { left, right }
    }

    /// Key column names on the left side.
    #[must_use]
    pub fn left_columns(&self) -> &[String] {
        match self {
            Self::LeftRight { left: columns, .. } | Self::On(columns) => columns,
        }
    }

    /// Key column names on the right side.
    #[must_use]
    pub fn right_columns(&self) -> &[String] {
        match self {
            Self::LeftRight { right: columns, .. } | Self::On(columns) => columns,
        }
    }
}
191
192pub fn join(left: &Value, right: &Value, keys: &JoinKeys, options: &JoinOptions) -> Result<Value> {
208 match (left, right) {
209 (Value::DataFrame(left_df), Value::DataFrame(right_df)) => {
210 join_dataframes(left_df, right_df, keys, options)
211 }
212 (Value::LazyFrame(left_lf), Value::LazyFrame(right_lf)) => {
213 join_lazy_frames(left_lf, right_lf, keys, options)
214 }
215 (Value::DataFrame(left_df), Value::LazyFrame(right_lf)) => {
216 let right_df = right_lf.clone().collect().map_err(Error::from)?;
217 join_dataframes(left_df, &right_df, keys, options)
218 }
219 (Value::LazyFrame(left_lf), Value::DataFrame(right_df)) => {
220 let left_df = left_lf.clone().collect().map_err(Error::from)?;
221 join_dataframes(&left_df, right_df, keys, options)
222 }
223 (Value::Array(left_arr), Value::Array(right_arr)) => {
224 join_arrays(left_arr, right_arr, keys, options)
225 }
226 (left_val, right_val) => {
227 let left_df = left_val.to_dataframe()?;
229 let right_df = right_val.to_dataframe()?;
230 join_dataframes(&left_df, &right_df, keys, options)
231 }
232 }
233}
234
235fn join_dataframes(
237 left_df: &DataFrame,
238 right_df: &DataFrame,
239 keys: &JoinKeys,
240 options: &JoinOptions,
241) -> Result<Value> {
242 let left_on: Vec<Expr> = keys.left_columns().iter().map(col).collect();
243 let right_on: Vec<Expr> = keys.right_columns().iter().map(col).collect();
244
245 let join_args = JoinArgs::new(options.join_type.to_polars()?);
246
247 let join_builder =
248 left_df
249 .clone()
250 .lazy()
251 .join(right_df.clone().lazy(), left_on, right_on, join_args);
252
253 let result_df = join_builder.collect().map_err(Error::from)?;
254 Ok(Value::DataFrame(result_df))
255}
256
257fn join_lazy_frames(
259 left_lf: &LazyFrame,
260 right_lf: &LazyFrame,
261 keys: &JoinKeys,
262 options: &JoinOptions,
263) -> Result<Value> {
264 let left_on: Vec<Expr> = keys.left_columns().iter().map(col).collect();
265 let right_on: Vec<Expr> = keys.right_columns().iter().map(col).collect();
266
267 let join_args = JoinArgs::new(options.join_type.to_polars()?);
268
269 let mut join_builder = left_lf
270 .clone()
271 .join(right_lf.clone(), left_on, right_on, join_args);
272
273 if options.sort {
274 let sort_exprs: Vec<Expr> = keys.left_columns().iter().map(col).collect();
276 join_builder = join_builder.sort_by_exprs(sort_exprs, SortMultipleOptions::default());
277 }
278
279 Ok(Value::LazyFrame(Box::new(join_builder)))
280}
281
282fn join_arrays(
284 left_arr: &[Value],
285 right_arr: &[Value],
286 keys: &JoinKeys,
287 options: &JoinOptions,
288) -> Result<Value> {
289 let mut result = match options.join_type {
290 JoinType::Inner => inner_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
291 JoinType::Left => left_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
292 JoinType::Right => right_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
293 JoinType::Outer => outer_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
294 JoinType::Cross => cross_join_arrays(left_arr, right_arr, &options.suffix)?,
295 JoinType::Semi => semi_join_arrays(left_arr, right_arr, keys)?,
296 JoinType::Anti => anti_join_arrays(left_arr, right_arr, keys)?,
297 };
298
299 if options.sort {
300 if let Some(first_key) = keys.left_columns().first() {
302 result.sort_by(|a, b| {
303 let a_val = match a {
304 Value::Object(obj) => obj.get(first_key).unwrap_or(&Value::Null),
305 _ => &Value::Null,
306 };
307 let b_val = match b {
308 Value::Object(obj) => obj.get(first_key).unwrap_or(&Value::Null),
309 _ => &Value::Null,
310 };
311 compare_values_for_sorting(a_val, b_val)
312 });
313 }
314 }
315
316 Ok(Value::Array(result))
317}
318
319fn inner_join_arrays(
321 left_arr: &[Value],
322 right_arr: &[Value],
323 keys: &JoinKeys,
324 suffix: &str,
325) -> Result<Vec<Value>> {
326 let mut result = Vec::new();
327
328 for left_item in left_arr {
329 if let Value::Object(left_obj) = left_item {
330 for right_item in right_arr {
331 if let Value::Object(right_obj) = right_item {
332 if objects_match_on_keys(left_obj, right_obj, keys)? {
333 let joined = merge_objects(
334 left_obj,
335 right_obj,
336 suffix,
337 false,
338 &std::collections::HashSet::new(),
339 )?;
340 result.push(Value::Object(joined));
341 }
342 }
343 }
344 }
345 }
346
347 Ok(result)
348}
349
350fn left_join_arrays(
352 left_arr: &[Value],
353 right_arr: &[Value],
354 keys: &JoinKeys,
355 suffix: &str,
356) -> Result<Vec<Value>> {
357 let right_keys: std::collections::HashSet<String> = right_arr
358 .iter()
359 .filter_map(|v| {
360 if let Value::Object(o) = v {
361 Some(o.keys().cloned().collect::<Vec<_>>())
362 } else {
363 None
364 }
365 })
366 .flatten()
367 .collect();
368
369 let mut result = Vec::new();
370
371 for left_item in left_arr {
372 if let Value::Object(left_obj) = left_item {
373 let mut found_match = false;
374
375 for right_item in right_arr {
376 if let Value::Object(right_obj) = right_item {
377 if objects_match_on_keys(left_obj, right_obj, keys)? {
378 let joined = merge_objects(
379 left_obj,
380 right_obj,
381 suffix,
382 false,
383 &std::collections::HashSet::new(),
384 )?;
385 result.push(Value::Object(joined));
386 found_match = true;
387 }
388 }
389 }
390
391 if !found_match {
392 let joined = merge_objects(left_obj, &HashMap::new(), suffix, true, &right_keys)?;
394 result.push(Value::Object(joined));
395 }
396 }
397 }
398
399 Ok(result)
400}
401
402fn right_join_arrays(
404 left_arr: &[Value],
405 right_arr: &[Value],
406 keys: &JoinKeys,
407 suffix: &str,
408) -> Result<Vec<Value>> {
409 let left_keys: std::collections::HashSet<String> = left_arr
410 .iter()
411 .filter_map(|v| {
412 if let Value::Object(o) = v {
413 Some(o.keys().cloned().collect::<Vec<_>>())
414 } else {
415 None
416 }
417 })
418 .flatten()
419 .collect();
420
421 let mut result = Vec::new();
422
423 for right_item in right_arr {
424 if let Value::Object(right_obj) = right_item {
425 let mut found_match = false;
426
427 for left_item in left_arr {
428 if let Value::Object(left_obj) = left_item {
429 if objects_match_on_keys(left_obj, right_obj, keys)? {
430 let joined = merge_objects(
431 left_obj,
432 right_obj,
433 suffix,
434 false,
435 &std::collections::HashSet::new(),
436 )?;
437 result.push(Value::Object(joined));
438 found_match = true;
439 }
440 }
441 }
442
443 if !found_match {
444 let joined = merge_objects(&HashMap::new(), right_obj, suffix, true, &left_keys)?;
446 result.push(Value::Object(joined));
447 }
448 }
449 }
450
451 Ok(result)
452}
453
454fn outer_join_arrays(
456 left_arr: &[Value],
457 right_arr: &[Value],
458 keys: &JoinKeys,
459 suffix: &str,
460) -> Result<Vec<Value>> {
461 let left_keys: std::collections::HashSet<String> = left_arr
462 .iter()
463 .filter_map(|v| {
464 if let Value::Object(o) = v {
465 Some(o.keys().cloned().collect::<Vec<_>>())
466 } else {
467 None
468 }
469 })
470 .flatten()
471 .collect();
472 let right_keys: std::collections::HashSet<String> = right_arr
473 .iter()
474 .filter_map(|v| {
475 if let Value::Object(o) = v {
476 Some(o.keys().cloned().collect::<Vec<_>>())
477 } else {
478 None
479 }
480 })
481 .flatten()
482 .collect();
483
484 let mut result = Vec::new();
485 let mut right_matched = vec![false; right_arr.len()];
486
487 for left_item in left_arr {
489 if let Value::Object(left_obj) = left_item {
490 let mut found_match = false;
491
492 for (right_idx, right_item) in right_arr.iter().enumerate() {
493 if let Value::Object(right_obj) = right_item {
494 if objects_match_on_keys(left_obj, right_obj, keys)? {
495 let joined = merge_objects(
496 left_obj,
497 right_obj,
498 suffix,
499 false,
500 &std::collections::HashSet::new(),
501 )?;
502 result.push(Value::Object(joined));
503 right_matched[right_idx] = true;
504 found_match = true;
505 }
506 }
507 }
508
509 if !found_match {
510 let joined = merge_objects(left_obj, &HashMap::new(), suffix, true, &right_keys)?;
512 result.push(Value::Object(joined));
513 }
514 }
515 }
516
517 for (right_idx, right_item) in right_arr.iter().enumerate() {
519 if !right_matched[right_idx] {
520 if let Value::Object(right_obj) = right_item {
521 let joined = merge_objects(&HashMap::new(), right_obj, suffix, true, &left_keys)?;
522 result.push(Value::Object(joined));
523 }
524 }
525 }
526
527 Ok(result)
528}
529
530fn cross_join_arrays(left_arr: &[Value], right_arr: &[Value], suffix: &str) -> Result<Vec<Value>> {
532 let mut result = Vec::new();
533
534 for left_item in left_arr {
535 if let Value::Object(left_obj) = left_item {
536 for right_item in right_arr {
537 if let Value::Object(right_obj) = right_item {
538 let joined = merge_objects(
539 left_obj,
540 right_obj,
541 suffix,
542 false,
543 &std::collections::HashSet::new(),
544 )?;
545 result.push(Value::Object(joined));
546 }
547 }
548 }
549 }
550
551 Ok(result)
552}
553
554fn semi_join_arrays(
556 left_arr: &[Value],
557 right_arr: &[Value],
558 keys: &JoinKeys,
559) -> Result<Vec<Value>> {
560 let mut result = Vec::new();
561
562 for left_item in left_arr {
563 if let Value::Object(left_obj) = left_item {
564 for right_item in right_arr {
565 if let Value::Object(right_obj) = right_item {
566 if objects_match_on_keys(left_obj, right_obj, keys)? {
567 result.push(left_item.clone());
568 break; }
570 }
571 }
572 }
573 }
574
575 Ok(result)
576}
577
578fn anti_join_arrays(
580 left_arr: &[Value],
581 right_arr: &[Value],
582 keys: &JoinKeys,
583) -> Result<Vec<Value>> {
584 let mut result = Vec::new();
585
586 for left_item in left_arr {
587 if let Value::Object(left_obj) = left_item {
588 let mut found_match = false;
589
590 for right_item in right_arr {
591 if let Value::Object(right_obj) = right_item {
592 if objects_match_on_keys(left_obj, right_obj, keys)? {
593 found_match = true;
594 break;
595 }
596 }
597 }
598
599 if !found_match {
600 result.push(left_item.clone());
601 }
602 }
603 }
604
605 Ok(result)
606}
607
608fn objects_match_on_keys(
610 left_obj: &HashMap<String, Value>,
611 right_obj: &HashMap<String, Value>,
612 keys: &JoinKeys,
613) -> Result<bool> {
614 let left_keys = keys.left_columns();
615 let right_keys = keys.right_columns();
616
617 if left_keys.len() != right_keys.len() {
618 return Err(Error::operation(
619 "Left and right join keys must have the same length",
620 ));
621 }
622
623 for (left_key, right_key) in left_keys.iter().zip(right_keys.iter()) {
624 let left_val = left_obj.get(left_key).unwrap_or(&Value::Null);
625 let right_val = right_obj.get(right_key).unwrap_or(&Value::Null);
626
627 if !values_equal_for_join(left_val, right_val) {
628 return Ok(false);
629 }
630 }
631
632 Ok(true)
633}
634
635fn values_equal_for_join(left: &Value, right: &Value) -> bool {
637 match (left, right) {
638 (Value::Null, Value::Null) => true,
639 (Value::Bool(a), Value::Bool(b)) => a == b,
640 (Value::Int(a), Value::Int(b)) => a == b,
641 (Value::Float(a), Value::Float(b)) => (a - b).abs() < f64::EPSILON,
642 (Value::String(a), Value::String(b)) => a == b,
643 #[allow(clippy::cast_precision_loss)]
645 (Value::Int(a), Value::Float(b)) => (*a as f64 - b).abs() < f64::EPSILON,
646 #[allow(clippy::cast_precision_loss)]
647 (Value::Float(a), Value::Int(b)) => (a - *b as f64).abs() < f64::EPSILON,
648 _ => false,
649 }
650}
651
652#[allow(clippy::unnecessary_wraps)]
654fn merge_objects(
655 left_obj: &HashMap<String, Value>,
656 right_obj: &HashMap<String, Value>,
657 suffix: &str,
658 fill_nulls: bool,
659 null_keys: &std::collections::HashSet<String>,
660) -> Result<HashMap<String, Value>> {
661 let mut result = left_obj.clone();
662
663 for (right_key, right_val) in right_obj {
664 let key = if result.contains_key(right_key) {
665 format!("{right_key}{suffix}")
667 } else {
668 right_key.clone()
669 };
670 result.insert(key, right_val.clone());
671 }
672
673 if fill_nulls {
674 for key in null_keys {
675 if result.contains_key(key) {
676 let suffixed = format!("{key}{suffix}");
678 result.entry(suffixed).or_insert(Value::Null);
679 } else {
680 result.insert(key.clone(), Value::Null);
681 }
682 }
683 }
684
685 Ok(result)
686}
687
688fn compare_values_for_sorting(a: &Value, b: &Value) -> std::cmp::Ordering {
690 use std::cmp::Ordering;
691
692 match (a, b) {
693 (Value::Null, Value::Null) => Ordering::Equal,
694 (Value::Null, _) => Ordering::Less,
695 (_, Value::Null) => Ordering::Greater,
696
697 (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
698 (Value::Int(a), Value::Int(b)) => a.cmp(b),
699 (Value::Float(a), Value::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
700 (Value::String(a), Value::String(b)) => a.cmp(b),
701
702 #[allow(clippy::cast_precision_loss)]
704 (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal),
705 #[allow(clippy::cast_precision_loss)]
706 (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal),
707
708 _ => a.to_string().cmp(&b.to_string()),
710 }
711}
712
713pub fn inner_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
715 let options = JoinOptions {
716 join_type: JoinType::Inner,
717 ..Default::default()
718 };
719 join(left, right, keys, &options)
720}
721
722pub fn left_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
724 let options = JoinOptions {
725 join_type: JoinType::Left,
726 ..Default::default()
727 };
728 join(left, right, keys, &options)
729}
730
731pub fn right_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
733 let options = JoinOptions {
734 join_type: JoinType::Right,
735 ..Default::default()
736 };
737 join(left, right, keys, &options)
738}
739
740pub fn outer_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
742 let options = JoinOptions {
743 join_type: JoinType::Outer,
744 ..Default::default()
745 };
746 join(left, right, keys, &options)
747}
748
749pub fn join_multiple(
768 dataframes: &[Value],
769 keys: &JoinKeys,
770 options: &JoinOptions,
771) -> Result<Value> {
772 if dataframes.is_empty() {
773 return Err(Error::operation("No DataFrames provided for join"));
774 }
775
776 if dataframes.len() == 1 {
777 return Ok(dataframes[0].clone());
778 }
779
780 let mut result = dataframes[0].clone();
781
782 for (i, df) in dataframes.iter().enumerate().skip(1) {
783 let mut join_options = options.clone();
785 join_options.suffix = format!("_right_{i}");
786
787 result = join(&result, df, keys, &join_options)?;
788 }
789
790 Ok(result)
791}
792
793#[allow(clippy::used_underscore_binding)]
815pub fn join_with_condition(
816 left: &Value,
817 right: &Value,
818 condition: Expr,
819 _join_type: JoinType,
820) -> Result<Value> {
821 match (left, right) {
822 (Value::DataFrame(left_df), Value::DataFrame(right_df)) => {
823 let how = JoinType::Cross.to_polars()?;
827 let join_args = JoinArgs::new(how);
828
829 let cross_joined =
830 left_df
831 .clone()
832 .lazy()
833 .join(right_df.clone().lazy(), vec![], vec![], join_args);
834
835 let filtered = cross_joined.filter(condition);
836
837 let result_df = filtered.collect().map_err(Error::from)?;
838 Ok(Value::DataFrame(result_df))
839 }
840 (Value::LazyFrame(left_lf), Value::LazyFrame(right_lf)) => {
841 let how = JoinType::Cross.to_polars()?;
842 let join_args = JoinArgs::new(how);
843
844 let cross_joined = left_lf
845 .clone()
846 .join(*right_lf.clone(), vec![], vec![], join_args);
847
848 let filtered = cross_joined.filter(condition);
849 Ok(Value::LazyFrame(Box::new(filtered)))
850 }
851 _ => {
852 let left_df = left.to_dataframe()?;
854 let right_df = right.to_dataframe()?;
855 join_with_condition(
856 &Value::DataFrame(left_df),
857 &Value::DataFrame(right_df),
858 condition,
859 _join_type,
860 )
861 }
862 }
863}
864
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    fn create_left_dataframe() -> DataFrame {
        let id = Column::new("id".into(), &[1, 2, 3, 4]);
        let name = Column::new("name".into(), &["Alice", "Bob", "Charlie", "Dave"]);
        let dept_id = Column::new("dept_id".into(), &[10, 20, 10, 30]);
        DataFrame::new(vec![id, name, dept_id]).unwrap()
    }

    fn create_right_dataframe() -> DataFrame {
        let id = Column::new("id".into(), &[10, 20, 40]);
        let dept_name = Column::new("dept_name".into(), &["Engineering", "Sales", "Marketing"]);
        let budget = Column::new("budget".into(), &[100000, 50000, 75000]);
        DataFrame::new(vec![id, dept_name, budget]).unwrap()
    }

    /// Builds a `Value::Object` from `(field, value)` pairs.
    fn object<const N: usize>(fields: [(&str, Value); N]) -> Value {
        Value::Object(
            fields
                .into_iter()
                .map(|(name, value)| (name.to_string(), value))
                .collect::<HashMap<String, Value>>(),
        )
    }

    /// Shorthand for `Value::String`.
    fn text(s: &str) -> Value {
        Value::String(s.to_string())
    }

    /// Whether `df` has a column named `name`.
    fn has_column(df: &DataFrame, name: &str) -> bool {
        df.get_column_names().contains(&&PlSmallStr::from(name))
    }

    /// Unwraps an array join result into its object rows, failing on any other shape.
    fn object_rows(value: Value) -> Vec<HashMap<String, Value>> {
        let Value::Array(items) = value else {
            panic!("Expected Array");
        };
        items
            .into_iter()
            .map(|item| match item {
                Value::Object(obj) => obj,
                other => panic!("Expected Object row, got {other:?}"),
            })
            .collect()
    }

    /// `dept_id` (left) / `id` (right) key pair used by the DataFrame tests.
    fn dept_keys() -> JoinKeys {
        JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()])
    }

    /// Shared `id` key used by the array tests.
    fn id_keys() -> JoinKeys {
        JoinKeys::on(vec!["id".to_string()])
    }

    /// Default options with the given join type.
    fn options_for(join_type: JoinType) -> JoinOptions {
        JoinOptions {
            join_type,
            ..Default::default()
        }
    }

    #[test]
    fn test_inner_join() {
        let result = inner_join(
            &Value::DataFrame(create_left_dataframe()),
            &Value::DataFrame(create_right_dataframe()),
            &dept_keys(),
        )
        .unwrap();

        let Value::DataFrame(df) = result else {
            panic!("Expected DataFrame");
        };
        // dept_ids 10, 20, 10 have a department; 30 does not.
        assert_eq!(df.height(), 3);
        assert!(has_column(&df, "name"));
        assert!(has_column(&df, "dept_name"));
    }

    #[test]
    fn test_left_join() {
        let result = left_join(
            &Value::DataFrame(create_left_dataframe()),
            &Value::DataFrame(create_right_dataframe()),
            &dept_keys(),
        )
        .unwrap();

        let Value::DataFrame(df) = result else {
            panic!("Expected DataFrame");
        };
        // Every left row survives.
        assert_eq!(df.height(), 4);
        assert!(has_column(&df, "name"));
        assert!(has_column(&df, "dept_name"));
    }

    #[test]
    fn test_right_join() {
        // `JoinType::to_polars` has no mapping for right joins, so the
        // DataFrame path must report an error rather than panic.
        let result = right_join(
            &Value::DataFrame(create_left_dataframe()),
            &Value::DataFrame(create_right_dataframe()),
            &dept_keys(),
        );

        let err = result.unwrap_err();
        assert!(err.to_string().contains("Right join not supported"));
    }

    #[test]
    fn test_array_join() {
        let left_array = Value::Array(vec![
            object([("id", Value::Int(1)), ("name", text("Alice"))]),
            object([("id", Value::Int(2)), ("name", text("Bob"))]),
        ]);
        let right_array = Value::Array(vec![
            object([("id", Value::Int(1)), ("age", Value::Int(30))]),
            object([("id", Value::Int(3)), ("age", Value::Int(25))]),
        ]);

        let rows = object_rows(inner_join(&left_array, &right_array, &id_keys()).unwrap());

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("name"), Some(&text("Alice")));
        assert_eq!(rows[0].get("age"), Some(&Value::Int(30)));
    }

    #[test]
    fn test_array_join_sorted_by_key() {
        let left_array = Value::Array(vec![
            object([("id", Value::Int(2))]),
            object([("id", Value::Int(1))]),
        ]);
        let right_array = Value::Array(vec![
            object([("id", Value::Int(1))]),
            object([("id", Value::Int(2))]),
        ]);
        let options = JoinOptions {
            sort: true,
            ..options_for(JoinType::Inner)
        };

        let rows = object_rows(join(&left_array, &right_array, &id_keys(), &options).unwrap());

        let ids: Vec<Option<Value>> = rows.iter().map(|row| row.get("id").cloned()).collect();
        assert_eq!(ids, vec![Some(Value::Int(1)), Some(Value::Int(2))]);
    }

    #[test]
    fn test_join_types() {
        assert_eq!(JoinType::from_str("inner").unwrap(), JoinType::Inner);
        assert_eq!(JoinType::from_str("left_outer").unwrap(), JoinType::Left);
        assert_eq!(JoinType::from_str("full").unwrap(), JoinType::Outer);
        assert_eq!(JoinType::from_str("cross").unwrap(), JoinType::Cross);

        assert!(JoinType::from_str("invalid").is_err());
    }

    #[test]
    fn test_join_keys() {
        let keys = JoinKeys::on(vec!["id".to_string(), "name".to_string()]);
        assert_eq!(keys.left_columns(), &["id", "name"]);
        assert_eq!(keys.right_columns(), &["id", "name"]);

        let keys = JoinKeys::left_right(vec!["left_id".to_string()], vec!["right_id".to_string()]);
        assert_eq!(keys.left_columns(), &["left_id"]);
        assert_eq!(keys.right_columns(), &["right_id"]);
    }

    #[test]
    fn test_semi_join() {
        let left_array = Value::Array(vec![
            object([("id", Value::Int(1)), ("name", text("Alice"))]),
            object([("id", Value::Int(2)), ("name", text("Bob"))]),
            object([("id", Value::Int(3)), ("name", text("Charlie"))]),
        ]);
        let right_array = Value::Array(vec![
            object([("id", Value::Int(1))]),
            object([("id", Value::Int(3))]),
        ]);

        let rows = object_rows(
            join(&left_array, &right_array, &id_keys(), &options_for(JoinType::Semi)).unwrap(),
        );

        // Matching left rows, unchanged and in left order.
        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].get("id"), Some(&Value::Int(1)));
        assert_eq!(rows[0].get("name"), Some(&text("Alice")));
        assert_eq!(rows[1].get("id"), Some(&Value::Int(3)));
        assert_eq!(rows[1].get("name"), Some(&text("Charlie")));
        assert!(rows.iter().all(|row| row.len() == 2));
    }

    #[test]
    fn test_anti_join() {
        let left_array = Value::Array(vec![
            object([("id", Value::Int(1)), ("name", text("Alice"))]),
            object([("id", Value::Int(2)), ("name", text("Bob"))]),
        ]);
        let right_array = Value::Array(vec![object([("id", Value::Int(1))])]);

        let rows = object_rows(
            join(&left_array, &right_array, &id_keys(), &options_for(JoinType::Anti)).unwrap(),
        );

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("name"), Some(&text("Bob")));
    }

    #[test]
    fn test_cross_join() {
        let left_array = Value::Array(vec![
            object([("name", text("Alice"))]),
            object([("name", text("Bob"))]),
        ]);
        let right_array = Value::Array(vec![
            object([("color", text("Red"))]),
            object([("color", text("Blue"))]),
        ]);

        // Cross joins ignore keys.
        let keys = JoinKeys::on(vec![]);
        let rows = object_rows(
            join(&left_array, &right_array, &keys, &options_for(JoinType::Cross)).unwrap(),
        );

        assert_eq!(rows.len(), 4);
        for row in &rows {
            assert!(row.contains_key("name"));
            assert!(row.contains_key("color"));
        }
    }

    #[test]
    fn test_join_multiple() {
        let df1 = DataFrame::new(vec![
            Column::new("id".into(), &[1, 2]),
            Column::new("name".into(), &["Alice", "Bob"]),
        ])
        .unwrap();

        let df2 = DataFrame::new(vec![
            Column::new("id".into(), &[1, 2]),
            Column::new("age".into(), &[30, 25]),
        ])
        .unwrap();

        let df3 = DataFrame::new(vec![
            Column::new("id".into(), &[1, 2]),
            Column::new("city".into(), &["NYC", "LA"]),
        ])
        .unwrap();

        let dataframes = vec![
            Value::DataFrame(df1),
            Value::DataFrame(df2),
            Value::DataFrame(df3),
        ];

        let result =
            join_multiple(&dataframes, &id_keys(), &options_for(JoinType::Inner)).unwrap();

        let Value::DataFrame(df) = result else {
            panic!("Expected DataFrame");
        };
        assert_eq!(df.height(), 2);
        assert!(has_column(&df, "name"));
        assert!(has_column(&df, "age"));
        assert!(has_column(&df, "city"));
    }

    #[test]
    fn test_join_with_options() {
        let options = JoinOptions {
            join_type: JoinType::Inner,
            suffix: "_right".to_string(),
            validate: JoinValidation::None,
            sort: false,
            coalesce: polars::prelude::JoinCoalesce::JoinSpecific,
        };

        let result = join(
            &Value::DataFrame(create_left_dataframe()),
            &Value::DataFrame(create_right_dataframe()),
            &dept_keys(),
            &options,
        )
        .unwrap();

        let Value::DataFrame(df) = result else {
            panic!("Expected DataFrame");
        };
        assert_eq!(df.height(), 3);
        assert!(has_column(&df, "name"));
        assert!(has_column(&df, "dept_name"));
    }

    #[test]
    fn test_join_lazy_frames() {
        let result = join(
            &Value::LazyFrame(Box::new(create_left_dataframe().lazy())),
            &Value::LazyFrame(Box::new(create_right_dataframe().lazy())),
            &dept_keys(),
            &options_for(JoinType::Inner),
        )
        .unwrap();

        let Value::LazyFrame(lf) = result else {
            panic!("Expected LazyFrame");
        };
        // Collect to check that the lazy plan yields the same rows as the eager join.
        let df = (*lf).collect().unwrap();
        assert_eq!(df.height(), 3);
    }

    #[test]
    fn test_join_mixed_types() {
        let result = join(
            &Value::DataFrame(create_left_dataframe()),
            &Value::LazyFrame(Box::new(create_right_dataframe().lazy())),
            &dept_keys(),
            &options_for(JoinType::Inner),
        )
        .unwrap();

        let Value::DataFrame(df) = result else {
            panic!("Expected DataFrame");
        };
        assert_eq!(df.height(), 3);
    }

    #[test]
    fn test_join_with_suffix() {
        let left_array = Value::Array(vec![object([
            ("id", Value::Int(1)),
            ("name", text("Alice")),
        ])]);
        // `name` clashes with the left side and must be suffixed.
        let right_array = Value::Array(vec![object([
            ("id", Value::Int(1)),
            ("name", text("Bob")),
        ])]);

        let rows = object_rows(
            join(&left_array, &right_array, &id_keys(), &options_for(JoinType::Inner)).unwrap(),
        );

        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("name"), Some(&text("Alice")));
        assert_eq!(rows[0].get("name_right"), Some(&text("Bob")));
    }

    #[test]
    fn test_join_empty_arrays() {
        let rows = object_rows(
            join(
                &Value::Array(vec![]),
                &Value::Array(vec![]),
                &id_keys(),
                &options_for(JoinType::Inner),
            )
            .unwrap(),
        );

        assert!(rows.is_empty());
    }

    #[test]
    fn test_join_invalid_keys() {
        let keys = JoinKeys::on(vec!["nonexistent".to_string()]);

        let result = join(
            &Value::DataFrame(create_left_dataframe()),
            &Value::DataFrame(create_right_dataframe()),
            &keys,
            &JoinOptions::default(),
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_join_type_parsing() {
        assert_eq!(JoinType::from_str("inner").unwrap(), JoinType::Inner);
        assert_eq!(JoinType::from_str("left").unwrap(), JoinType::Left);
        assert_eq!(JoinType::from_str("right").unwrap(), JoinType::Right);
        assert_eq!(JoinType::from_str("outer").unwrap(), JoinType::Outer);
        assert_eq!(JoinType::from_str("full").unwrap(), JoinType::Outer);
        assert_eq!(JoinType::from_str("cross").unwrap(), JoinType::Cross);
        assert_eq!(JoinType::from_str("semi").unwrap(), JoinType::Semi);
        assert_eq!(JoinType::from_str("anti").unwrap(), JoinType::Anti);
        assert!(JoinType::from_str("invalid").is_err());
    }

    #[test]
    fn test_join_validation() {
        assert_eq!(
            JoinValidation::None.to_polars(),
            polars::prelude::JoinValidation::ManyToOne
        );
        assert_eq!(
            JoinValidation::OneToMany.to_polars(),
            polars::prelude::JoinValidation::OneToMany
        );
        assert_eq!(
            JoinValidation::ManyToOne.to_polars(),
            polars::prelude::JoinValidation::ManyToOne
        );
        assert_eq!(
            JoinValidation::OneToOne.to_polars(),
            polars::prelude::JoinValidation::OneToOne
        );
    }

    #[test]
    fn test_join_keys_methods() {
        let keys = JoinKeys::on(vec!["a".to_string(), "b".to_string()]);
        assert_eq!(keys.left_columns(), &["a", "b"]);
        assert_eq!(keys.right_columns(), &["a", "b"]);

        let keys = JoinKeys::left_right(vec!["la".to_string()], vec!["ra".to_string()]);
        assert_eq!(keys.left_columns(), &["la"]);
        assert_eq!(keys.right_columns(), &["ra"]);
    }

    #[test]
    fn test_join_options_default() {
        let options = JoinOptions::default();
        assert_eq!(options.join_type, JoinType::Inner);
        assert_eq!(options.suffix, "_right");
        assert_eq!(options.validate, JoinValidation::None);
        assert!(!options.sort);
    }
}