dsq_core/ops/
join.rs

1//! Join operations for dsq
2//!
3//! This module provides join operations for `DataFrames` including:
4//! - Inner joins
5//! - Left outer joins  
6//! - Right outer joins
7//! - Full outer joins
8//! - Cross joins
9//! - Semi joins
10//! - Anti joins
11//!
12//! These operations correspond to SQL JOIN operations and allow combining
13//! data from multiple `DataFrames` based on common keys.
14
15use std::collections::HashMap;
16
17use polars::prelude::*;
18
19use crate::error::{Error, Result};
20use crate::Value;
21
22/// Types of join operations supported
23#[derive(Debug, Clone, Copy, PartialEq, Eq)]
24pub enum JoinType {
25    /// Inner join - returns only matching rows from both `DataFrames`
26    Inner,
27    /// Left outer join - returns all rows from left `DataFrame`, matching rows from right
28    Left,
29    /// Right outer join - returns all rows from right `DataFrame`, matching rows from left
30    Right,
31    /// Full outer join - returns all rows from both `DataFrames`
32    Outer,
33    /// Cross join - cartesian product of both `DataFrames`
34    Cross,
35    /// Semi join - returns rows from left `DataFrame` that have matches in right
36    Semi,
37    /// Anti join - returns rows from left `DataFrame` that have no matches in right
38    Anti,
39}
40
41impl JoinType {
42    /// Convert to Polars `JoinType`
43    pub fn to_polars(&self) -> Result<polars::prelude::JoinType> {
44        match self {
45            JoinType::Inner => Ok(polars::prelude::JoinType::Inner),
46            JoinType::Left => Ok(polars::prelude::JoinType::Left),
47            JoinType::Right => Err(Error::operation(
48                "Right join not supported in this Polars version, so cannot convert to Polars",
49            )),
50            JoinType::Outer => Ok(polars::prelude::JoinType::Full),
51            JoinType::Cross => Ok(polars::prelude::JoinType::Cross),
52            JoinType::Semi => Err(Error::operation(
53                "Semi join not supported in this Polars version, so cannot convert to Polars",
54            )),
55            JoinType::Anti => Err(Error::operation(
56                "Anti join not supported in this Polars version, so cannot convert to Polars",
57            )),
58        }
59    }
60
61    /// Get the string representation
62    #[must_use]
63    pub fn as_str(&self) -> &'static str {
64        match self {
65            JoinType::Inner => "inner",
66            JoinType::Left => "left",
67            JoinType::Right => "right",
68            JoinType::Outer => "outer",
69            JoinType::Cross => "cross",
70            JoinType::Semi => "semi",
71            JoinType::Anti => "anti",
72        }
73    }
74
75    /// Parse from string
76    #[allow(clippy::should_implement_trait)]
77    pub fn from_str(s: &str) -> Result<Self> {
78        match s.to_lowercase().as_str() {
79            "inner" => Ok(JoinType::Inner),
80            "left" | "left_outer" => Ok(JoinType::Left),
81            "right" | "right_outer" => Ok(JoinType::Right),
82            "outer" | "full" | "full_outer" => Ok(JoinType::Outer),
83            "cross" => Ok(JoinType::Cross),
84            "semi" => Ok(JoinType::Semi),
85            "anti" => Ok(JoinType::Anti),
86            _ => Err(Error::operation(format!("Unknown join type: {s}"))),
87        }
88    }
89}
90
91/// Options for join operations
92#[derive(Debug, Clone)]
93pub struct JoinOptions {
94    /// Type of join to perform
95    pub join_type: JoinType,
96    /// Suffix to add to duplicate column names from the right `DataFrame`
97    pub suffix: String,
98    /// Whether to validate that join keys are unique (for performance)
99    pub validate: JoinValidation,
100    /// Whether to sort the result by join keys
101    pub sort: bool,
102    /// How to coalesce join keys
103    pub coalesce: polars::prelude::JoinCoalesce,
104}
105
106impl Default for JoinOptions {
107    fn default() -> Self {
108        Self {
109            join_type: JoinType::Inner,
110            suffix: "_right".to_string(),
111            validate: JoinValidation::None,
112            sort: false,
113            coalesce: polars::prelude::JoinCoalesce::JoinSpecific,
114        }
115    }
116}
117
118/// Join validation options
119#[derive(Debug, Clone, Copy, PartialEq, Eq)]
120pub enum JoinValidation {
121    /// No validation
122    None,
123    /// Validate that left keys are unique
124    OneToMany,
125    /// Validate that right keys are unique  
126    ManyToOne,
127    /// Validate that both left and right keys are unique
128    OneToOne,
129}
130
131impl JoinValidation {
132    /// Convert to Polars `JoinValidation`
133    #[must_use]
134    pub fn to_polars(&self) -> polars::prelude::JoinValidation {
135        match self {
136            // None not available in Polars 0.35, default to ManyToOne (allows duplicate left keys)
137            JoinValidation::OneToMany => polars::prelude::JoinValidation::OneToMany,
138            JoinValidation::None | JoinValidation::ManyToOne => {
139                polars::prelude::JoinValidation::ManyToOne
140            }
141            JoinValidation::OneToOne => polars::prelude::JoinValidation::OneToOne,
142        }
143    }
144}
145
146/// Join specification for columns
147#[derive(Debug, Clone)]
148pub enum JoinKeys {
149    /// Join on columns with the same name
150    On(Vec<String>),
151    /// Join with different column names for left and right
152    LeftRight {
153        /// Column names from left `DataFrame`
154        left: Vec<String>,
155        /// Column names from right `DataFrame`
156        right: Vec<String>,
157    },
158}
159
160impl JoinKeys {
161    /// Create join keys for columns with the same name
162    #[must_use]
163    pub fn on(columns: Vec<String>) -> Self {
164        JoinKeys::On(columns)
165    }
166
167    /// Create join keys with different left and right column names
168    #[must_use]
169    pub fn left_right(left: Vec<String>, right: Vec<String>) -> Self {
170        JoinKeys::LeftRight { left, right }
171    }
172
173    /// Get the left column names
174    #[must_use]
175    pub fn left_columns(&self) -> &[String] {
176        match self {
177            JoinKeys::On(cols) => cols,
178            JoinKeys::LeftRight { left, .. } => left,
179        }
180    }
181
182    /// Get the right column names
183    #[must_use]
184    pub fn right_columns(&self) -> &[String] {
185        match self {
186            JoinKeys::On(cols) => cols,
187            JoinKeys::LeftRight { right, .. } => right,
188        }
189    }
190}
191
192/// Join two `DataFrames`
193///
194/// # Examples
195///
196/// ```rust,ignore
197/// use dsq_core::ops::join::{join, JoinKeys, JoinOptions, JoinType};
198/// use dsq_core::value::Value;
199///
200/// let keys = JoinKeys::on(vec!["id".to_string()]);
201/// let options = JoinOptions {
202///     join_type: JoinType::Inner,
203///     ..Default::default()
204/// };
205/// let result = join(&left_df, &right_df, &keys, &options).unwrap();
206/// ```
207pub fn join(left: &Value, right: &Value, keys: &JoinKeys, options: &JoinOptions) -> Result<Value> {
208    match (left, right) {
209        (Value::DataFrame(left_df), Value::DataFrame(right_df)) => {
210            join_dataframes(left_df, right_df, keys, options)
211        }
212        (Value::LazyFrame(left_lf), Value::LazyFrame(right_lf)) => {
213            join_lazy_frames(left_lf, right_lf, keys, options)
214        }
215        (Value::DataFrame(left_df), Value::LazyFrame(right_lf)) => {
216            let right_df = right_lf.clone().collect().map_err(Error::from)?;
217            join_dataframes(left_df, &right_df, keys, options)
218        }
219        (Value::LazyFrame(left_lf), Value::DataFrame(right_df)) => {
220            let left_df = left_lf.clone().collect().map_err(Error::from)?;
221            join_dataframes(&left_df, right_df, keys, options)
222        }
223        (Value::Array(left_arr), Value::Array(right_arr)) => {
224            join_arrays(left_arr, right_arr, keys, options)
225        }
226        (left_val, right_val) => {
227            // Try to convert to DataFrames
228            let left_df = left_val.to_dataframe()?;
229            let right_df = right_val.to_dataframe()?;
230            join_dataframes(&left_df, &right_df, keys, options)
231        }
232    }
233}
234
235/// Join two `DataFrames` using Polars
236fn join_dataframes(
237    left_df: &DataFrame,
238    right_df: &DataFrame,
239    keys: &JoinKeys,
240    options: &JoinOptions,
241) -> Result<Value> {
242    let left_on: Vec<Expr> = keys.left_columns().iter().map(col).collect();
243    let right_on: Vec<Expr> = keys.right_columns().iter().map(col).collect();
244
245    let join_args = JoinArgs::new(options.join_type.to_polars()?);
246
247    let join_builder =
248        left_df
249            .clone()
250            .lazy()
251            .join(right_df.clone().lazy(), left_on, right_on, join_args);
252
253    let result_df = join_builder.collect().map_err(Error::from)?;
254    Ok(Value::DataFrame(result_df))
255}
256
257/// Join two `LazyFrames`
258fn join_lazy_frames(
259    left_lf: &LazyFrame,
260    right_lf: &LazyFrame,
261    keys: &JoinKeys,
262    options: &JoinOptions,
263) -> Result<Value> {
264    let left_on: Vec<Expr> = keys.left_columns().iter().map(col).collect();
265    let right_on: Vec<Expr> = keys.right_columns().iter().map(col).collect();
266
267    let join_args = JoinArgs::new(options.join_type.to_polars()?);
268
269    let mut join_builder = left_lf
270        .clone()
271        .join(right_lf.clone(), left_on, right_on, join_args);
272
273    if options.sort {
274        // Sort by the join keys
275        let sort_exprs: Vec<Expr> = keys.left_columns().iter().map(col).collect();
276        join_builder = join_builder.sort_by_exprs(sort_exprs, SortMultipleOptions::default());
277    }
278
279    Ok(Value::LazyFrame(Box::new(join_builder)))
280}
281
282/// Join two arrays of objects (jq-style)
283fn join_arrays(
284    left_arr: &[Value],
285    right_arr: &[Value],
286    keys: &JoinKeys,
287    options: &JoinOptions,
288) -> Result<Value> {
289    let mut result = match options.join_type {
290        JoinType::Inner => inner_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
291        JoinType::Left => left_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
292        JoinType::Right => right_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
293        JoinType::Outer => outer_join_arrays(left_arr, right_arr, keys, &options.suffix)?,
294        JoinType::Cross => cross_join_arrays(left_arr, right_arr, &options.suffix)?,
295        JoinType::Semi => semi_join_arrays(left_arr, right_arr, keys)?,
296        JoinType::Anti => anti_join_arrays(left_arr, right_arr, keys)?,
297    };
298
299    if options.sort {
300        // Sort by the first join key
301        if let Some(first_key) = keys.left_columns().first() {
302            result.sort_by(|a, b| {
303                let a_val = match a {
304                    Value::Object(obj) => obj.get(first_key).unwrap_or(&Value::Null),
305                    _ => &Value::Null,
306                };
307                let b_val = match b {
308                    Value::Object(obj) => obj.get(first_key).unwrap_or(&Value::Null),
309                    _ => &Value::Null,
310                };
311                compare_values_for_sorting(a_val, b_val)
312            });
313        }
314    }
315
316    Ok(Value::Array(result))
317}
318
319/// Inner join for arrays
320fn inner_join_arrays(
321    left_arr: &[Value],
322    right_arr: &[Value],
323    keys: &JoinKeys,
324    suffix: &str,
325) -> Result<Vec<Value>> {
326    let mut result = Vec::new();
327
328    for left_item in left_arr {
329        if let Value::Object(left_obj) = left_item {
330            for right_item in right_arr {
331                if let Value::Object(right_obj) = right_item {
332                    if objects_match_on_keys(left_obj, right_obj, keys)? {
333                        let joined = merge_objects(
334                            left_obj,
335                            right_obj,
336                            suffix,
337                            false,
338                            &std::collections::HashSet::new(),
339                        )?;
340                        result.push(Value::Object(joined));
341                    }
342                }
343            }
344        }
345    }
346
347    Ok(result)
348}
349
350/// Left join for arrays
351fn left_join_arrays(
352    left_arr: &[Value],
353    right_arr: &[Value],
354    keys: &JoinKeys,
355    suffix: &str,
356) -> Result<Vec<Value>> {
357    let right_keys: std::collections::HashSet<String> = right_arr
358        .iter()
359        .filter_map(|v| {
360            if let Value::Object(o) = v {
361                Some(o.keys().cloned().collect::<Vec<_>>())
362            } else {
363                None
364            }
365        })
366        .flatten()
367        .collect();
368
369    let mut result = Vec::new();
370
371    for left_item in left_arr {
372        if let Value::Object(left_obj) = left_item {
373            let mut found_match = false;
374
375            for right_item in right_arr {
376                if let Value::Object(right_obj) = right_item {
377                    if objects_match_on_keys(left_obj, right_obj, keys)? {
378                        let joined = merge_objects(
379                            left_obj,
380                            right_obj,
381                            suffix,
382                            false,
383                            &std::collections::HashSet::new(),
384                        )?;
385                        result.push(Value::Object(joined));
386                        found_match = true;
387                    }
388                }
389            }
390
391            if !found_match {
392                // Add left row with nulls for right columns
393                let joined = merge_objects(left_obj, &HashMap::new(), suffix, true, &right_keys)?;
394                result.push(Value::Object(joined));
395            }
396        }
397    }
398
399    Ok(result)
400}
401
402/// Right join for arrays
403fn right_join_arrays(
404    left_arr: &[Value],
405    right_arr: &[Value],
406    keys: &JoinKeys,
407    suffix: &str,
408) -> Result<Vec<Value>> {
409    let left_keys: std::collections::HashSet<String> = left_arr
410        .iter()
411        .filter_map(|v| {
412            if let Value::Object(o) = v {
413                Some(o.keys().cloned().collect::<Vec<_>>())
414            } else {
415                None
416            }
417        })
418        .flatten()
419        .collect();
420
421    let mut result = Vec::new();
422
423    for right_item in right_arr {
424        if let Value::Object(right_obj) = right_item {
425            let mut found_match = false;
426
427            for left_item in left_arr {
428                if let Value::Object(left_obj) = left_item {
429                    if objects_match_on_keys(left_obj, right_obj, keys)? {
430                        let joined = merge_objects(
431                            left_obj,
432                            right_obj,
433                            suffix,
434                            false,
435                            &std::collections::HashSet::new(),
436                        )?;
437                        result.push(Value::Object(joined));
438                        found_match = true;
439                    }
440                }
441            }
442
443            if !found_match {
444                // Add right row with nulls for left columns
445                let joined = merge_objects(&HashMap::new(), right_obj, suffix, true, &left_keys)?;
446                result.push(Value::Object(joined));
447            }
448        }
449    }
450
451    Ok(result)
452}
453
454/// Full outer join for arrays
455fn outer_join_arrays(
456    left_arr: &[Value],
457    right_arr: &[Value],
458    keys: &JoinKeys,
459    suffix: &str,
460) -> Result<Vec<Value>> {
461    let left_keys: std::collections::HashSet<String> = left_arr
462        .iter()
463        .filter_map(|v| {
464            if let Value::Object(o) = v {
465                Some(o.keys().cloned().collect::<Vec<_>>())
466            } else {
467                None
468            }
469        })
470        .flatten()
471        .collect();
472    let right_keys: std::collections::HashSet<String> = right_arr
473        .iter()
474        .filter_map(|v| {
475            if let Value::Object(o) = v {
476                Some(o.keys().cloned().collect::<Vec<_>>())
477            } else {
478                None
479            }
480        })
481        .flatten()
482        .collect();
483
484    let mut result = Vec::new();
485    let mut right_matched = vec![false; right_arr.len()];
486
487    // First pass: left join
488    for left_item in left_arr {
489        if let Value::Object(left_obj) = left_item {
490            let mut found_match = false;
491
492            for (right_idx, right_item) in right_arr.iter().enumerate() {
493                if let Value::Object(right_obj) = right_item {
494                    if objects_match_on_keys(left_obj, right_obj, keys)? {
495                        let joined = merge_objects(
496                            left_obj,
497                            right_obj,
498                            suffix,
499                            false,
500                            &std::collections::HashSet::new(),
501                        )?;
502                        result.push(Value::Object(joined));
503                        right_matched[right_idx] = true;
504                        found_match = true;
505                    }
506                }
507            }
508
509            if !found_match {
510                // Add left row with nulls for right columns
511                let joined = merge_objects(left_obj, &HashMap::new(), suffix, true, &right_keys)?;
512                result.push(Value::Object(joined));
513            }
514        }
515    }
516
517    // Second pass: add unmatched right rows
518    for (right_idx, right_item) in right_arr.iter().enumerate() {
519        if !right_matched[right_idx] {
520            if let Value::Object(right_obj) = right_item {
521                let joined = merge_objects(&HashMap::new(), right_obj, suffix, true, &left_keys)?;
522                result.push(Value::Object(joined));
523            }
524        }
525    }
526
527    Ok(result)
528}
529
530/// Cross join for arrays
531fn cross_join_arrays(left_arr: &[Value], right_arr: &[Value], suffix: &str) -> Result<Vec<Value>> {
532    let mut result = Vec::new();
533
534    for left_item in left_arr {
535        if let Value::Object(left_obj) = left_item {
536            for right_item in right_arr {
537                if let Value::Object(right_obj) = right_item {
538                    let joined = merge_objects(
539                        left_obj,
540                        right_obj,
541                        suffix,
542                        false,
543                        &std::collections::HashSet::new(),
544                    )?;
545                    result.push(Value::Object(joined));
546                }
547            }
548        }
549    }
550
551    Ok(result)
552}
553
554/// Semi join for arrays - returns left rows that have matches in right
555fn semi_join_arrays(
556    left_arr: &[Value],
557    right_arr: &[Value],
558    keys: &JoinKeys,
559) -> Result<Vec<Value>> {
560    let mut result = Vec::new();
561
562    for left_item in left_arr {
563        if let Value::Object(left_obj) = left_item {
564            for right_item in right_arr {
565                if let Value::Object(right_obj) = right_item {
566                    if objects_match_on_keys(left_obj, right_obj, keys)? {
567                        result.push(left_item.clone());
568                        break; // Only add once per left row
569                    }
570                }
571            }
572        }
573    }
574
575    Ok(result)
576}
577
578/// Anti join for arrays - returns left rows that have no matches in right
579fn anti_join_arrays(
580    left_arr: &[Value],
581    right_arr: &[Value],
582    keys: &JoinKeys,
583) -> Result<Vec<Value>> {
584    let mut result = Vec::new();
585
586    for left_item in left_arr {
587        if let Value::Object(left_obj) = left_item {
588            let mut found_match = false;
589
590            for right_item in right_arr {
591                if let Value::Object(right_obj) = right_item {
592                    if objects_match_on_keys(left_obj, right_obj, keys)? {
593                        found_match = true;
594                        break;
595                    }
596                }
597            }
598
599            if !found_match {
600                result.push(left_item.clone());
601            }
602        }
603    }
604
605    Ok(result)
606}
607
608/// Check if two objects match on the specified join keys
609fn objects_match_on_keys(
610    left_obj: &HashMap<String, Value>,
611    right_obj: &HashMap<String, Value>,
612    keys: &JoinKeys,
613) -> Result<bool> {
614    let left_keys = keys.left_columns();
615    let right_keys = keys.right_columns();
616
617    if left_keys.len() != right_keys.len() {
618        return Err(Error::operation(
619            "Left and right join keys must have the same length",
620        ));
621    }
622
623    for (left_key, right_key) in left_keys.iter().zip(right_keys.iter()) {
624        let left_val = left_obj.get(left_key).unwrap_or(&Value::Null);
625        let right_val = right_obj.get(right_key).unwrap_or(&Value::Null);
626
627        if !values_equal_for_join(left_val, right_val) {
628            return Ok(false);
629        }
630    }
631
632    Ok(true)
633}
634
635/// Check if two values are equal for join purposes
636fn values_equal_for_join(left: &Value, right: &Value) -> bool {
637    match (left, right) {
638        (Value::Null, Value::Null) => true,
639        (Value::Bool(a), Value::Bool(b)) => a == b,
640        (Value::Int(a), Value::Int(b)) => a == b,
641        (Value::Float(a), Value::Float(b)) => (a - b).abs() < f64::EPSILON,
642        (Value::String(a), Value::String(b)) => a == b,
643        // Cross-type numeric comparisons
644        #[allow(clippy::cast_precision_loss)]
645        (Value::Int(a), Value::Float(b)) => (*a as f64 - b).abs() < f64::EPSILON,
646        #[allow(clippy::cast_precision_loss)]
647        (Value::Float(a), Value::Int(b)) => (a - *b as f64).abs() < f64::EPSILON,
648        _ => false,
649    }
650}
651
652/// Merge two objects, handling column name conflicts
653#[allow(clippy::unnecessary_wraps)]
654fn merge_objects(
655    left_obj: &HashMap<String, Value>,
656    right_obj: &HashMap<String, Value>,
657    suffix: &str,
658    fill_nulls: bool,
659    null_keys: &std::collections::HashSet<String>,
660) -> Result<HashMap<String, Value>> {
661    let mut result = left_obj.clone();
662
663    for (right_key, right_val) in right_obj {
664        let key = if result.contains_key(right_key) {
665            // Column name conflict - add suffix to right column
666            format!("{right_key}{suffix}")
667        } else {
668            right_key.clone()
669        };
670        result.insert(key, right_val.clone());
671    }
672
673    if fill_nulls {
674        for key in null_keys {
675            if result.contains_key(key) {
676                // If conflict, suffix
677                let suffixed = format!("{key}{suffix}");
678                result.entry(suffixed).or_insert(Value::Null);
679            } else {
680                result.insert(key.clone(), Value::Null);
681            }
682        }
683    }
684
685    Ok(result)
686}
687
688/// Compare values for sorting
689fn compare_values_for_sorting(a: &Value, b: &Value) -> std::cmp::Ordering {
690    use std::cmp::Ordering;
691
692    match (a, b) {
693        (Value::Null, Value::Null) => Ordering::Equal,
694        (Value::Null, _) => Ordering::Less,
695        (_, Value::Null) => Ordering::Greater,
696
697        (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
698        (Value::Int(a), Value::Int(b)) => a.cmp(b),
699        (Value::Float(a), Value::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
700        (Value::String(a), Value::String(b)) => a.cmp(b),
701
702        // Cross-type numeric comparisons
703        #[allow(clippy::cast_precision_loss)]
704        (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b).unwrap_or(Ordering::Equal),
705        #[allow(clippy::cast_precision_loss)]
706        (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)).unwrap_or(Ordering::Equal),
707
708        // For complex types, compare string representations
709        _ => a.to_string().cmp(&b.to_string()),
710    }
711}
712
713/// Convenience function for inner join
714pub fn inner_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
715    let options = JoinOptions {
716        join_type: JoinType::Inner,
717        ..Default::default()
718    };
719    join(left, right, keys, &options)
720}
721
722/// Convenience function for left join
723pub fn left_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
724    let options = JoinOptions {
725        join_type: JoinType::Left,
726        ..Default::default()
727    };
728    join(left, right, keys, &options)
729}
730
731/// Convenience function for right join
732pub fn right_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
733    let options = JoinOptions {
734        join_type: JoinType::Right,
735        ..Default::default()
736    };
737    join(left, right, keys, &options)
738}
739
740/// Convenience function for outer join
741pub fn outer_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
742    let options = JoinOptions {
743        join_type: JoinType::Outer,
744        ..Default::default()
745    };
746    join(left, right, keys, &options)
747}
748
749/// Join multiple `DataFrames` in sequence
750///
751/// Performs a series of joins on multiple `DataFrames` using the same join keys.
752///
753/// # Examples
754///
755/// ```rust,ignore
756/// use dsq_core::ops::join::{join_multiple, JoinKeys, JoinOptions, JoinType};
757/// use dsq_core::value::Value;
758///
759/// let dataframes = vec![df1, df2, df3];
760/// let keys = JoinKeys::on(vec!["id".to_string()]);
761/// let options = JoinOptions {
762///     join_type: JoinType::Inner,
763///     ..Default::default()
764/// };
765/// let result = join_multiple(&dataframes, &keys, &options).unwrap();
766/// ```
767pub fn join_multiple(
768    dataframes: &[Value],
769    keys: &JoinKeys,
770    options: &JoinOptions,
771) -> Result<Value> {
772    if dataframes.is_empty() {
773        return Err(Error::operation("No DataFrames provided for join"));
774    }
775
776    if dataframes.len() == 1 {
777        return Ok(dataframes[0].clone());
778    }
779
780    let mut result = dataframes[0].clone();
781
782    for (i, df) in dataframes.iter().enumerate().skip(1) {
783        // Adjust suffix for each join to avoid conflicts
784        let mut join_options = options.clone();
785        join_options.suffix = format!("_right_{i}");
786
787        result = join(&result, df, keys, &join_options)?;
788    }
789
790    Ok(result)
791}
792
793/// Perform a join with a custom condition
794///
795/// This allows for more complex join conditions beyond simple equality.
796///
797/// # Examples
798///
799/// ```rust,ignore
800/// use dsq_core::ops::join::{join_with_condition, JoinType};
801/// use dsq_core::value::Value;
802/// use polars::prelude::*;
803///
804/// let condition = col("left.price").gt(col("right.min_price"))
805///     .and(col("left.price").lt(col("right.max_price")));
806/// let result = join_with_condition(
807///     &left_df,
808///     &right_df,
809///     condition,
810///     JoinType::Inner,
811///     "_right"
812/// ).unwrap();
813/// ```
814#[allow(clippy::used_underscore_binding)]
815pub fn join_with_condition(
816    left: &Value,
817    right: &Value,
818    condition: Expr,
819    _join_type: JoinType,
820) -> Result<Value> {
821    match (left, right) {
822        (Value::DataFrame(left_df), Value::DataFrame(right_df)) => {
823            // For complex conditions, we need to use a cross join followed by a filter
824            // This is less efficient but more flexible
825
826            let how = JoinType::Cross.to_polars()?;
827            let join_args = JoinArgs::new(how);
828
829            let cross_joined =
830                left_df
831                    .clone()
832                    .lazy()
833                    .join(right_df.clone().lazy(), vec![], vec![], join_args);
834
835            let filtered = cross_joined.filter(condition);
836
837            let result_df = filtered.collect().map_err(Error::from)?;
838            Ok(Value::DataFrame(result_df))
839        }
840        (Value::LazyFrame(left_lf), Value::LazyFrame(right_lf)) => {
841            let how = JoinType::Cross.to_polars()?;
842            let join_args = JoinArgs::new(how);
843
844            let cross_joined = left_lf
845                .clone()
846                .join(*right_lf.clone(), vec![], vec![], join_args);
847
848            let filtered = cross_joined.filter(condition);
849            Ok(Value::LazyFrame(Box::new(filtered)))
850        }
851        _ => {
852            // Convert to DataFrames and retry
853            let left_df = left.to_dataframe()?;
854            let right_df = right.to_dataframe()?;
855            join_with_condition(
856                &Value::DataFrame(left_df),
857                &Value::DataFrame(right_df),
858                condition,
859                _join_type,
860            )
861        }
862    }
863}
864
865#[cfg(test)]
866mod tests {
867    use std::collections::HashMap;
868
869    use super::*;
870
871    fn create_left_dataframe() -> DataFrame {
872        let id = Column::new("id".into(), &[1, 2, 3, 4]);
873        let name = Column::new("name".into(), &["Alice", "Bob", "Charlie", "Dave"]);
874        let dept_id = Column::new("dept_id".into(), &[10, 20, 10, 30]);
875        DataFrame::new(vec![id, name, dept_id]).unwrap()
876    }
877
878    fn create_right_dataframe() -> DataFrame {
879        let id = Column::new("id".into(), &[10, 20, 40]);
880        let dept_name = Column::new("dept_name".into(), &["Engineering", "Sales", "Marketing"]);
881        let budget = Column::new("budget".into(), &[100000, 50000, 75000]);
882        DataFrame::new(vec![id, dept_name, budget]).unwrap()
883    }
884
885    #[test]
886    fn test_inner_join() {
887        let left_df = create_left_dataframe();
888        let right_df = create_right_dataframe();
889
890        let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
891
892        let result = inner_join(
893            &Value::DataFrame(left_df),
894            &Value::DataFrame(right_df),
895            &keys,
896        )
897        .unwrap();
898
899        match result {
900            Value::DataFrame(df) => {
901                assert_eq!(df.shape().0, 3); // Alice, Bob, and Charlie should match
902                assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
903                assert!(df
904                    .get_column_names()
905                    .contains(&&PlSmallStr::from("dept_name")));
906            }
907            _ => panic!("Expected DataFrame"),
908        }
909    }
910
911    #[test]
912    fn test_left_join() {
913        let left_df = create_left_dataframe();
914        let right_df = create_right_dataframe();
915
916        let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
917
918        let result = left_join(
919            &Value::DataFrame(left_df),
920            &Value::DataFrame(right_df),
921            &keys,
922        )
923        .unwrap();
924
925        match result {
926            Value::DataFrame(df) => {
927                assert_eq!(df.height(), 4); // All left rows should be present
928                assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
929                assert!(df
930                    .get_column_names()
931                    .contains(&&PlSmallStr::from("dept_name")));
932            }
933            _ => panic!("Expected DataFrame"),
934        }
935    }
936
937    #[test]
938    #[ignore = "Right join not supported in this Polars version"]
939    fn test_right_join() {
940        let left_df = create_left_dataframe();
941        let right_df = create_right_dataframe();
942
943        let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
944
945        let result = right_join(
946            &Value::DataFrame(left_df),
947            &Value::DataFrame(right_df),
948            &keys,
949        );
950
951        assert!(result.is_err());
952        assert!(result
953            .unwrap_err()
954            .to_string()
955            .contains("Right join not supported"));
956    }
957
958    #[test]
959    fn test_array_join() {
960        let left_array = Value::Array(vec![
961            Value::Object(HashMap::from([
962                ("id".to_string(), Value::Int(1)),
963                ("name".to_string(), Value::String("Alice".to_string())),
964            ])),
965            Value::Object(HashMap::from([
966                ("id".to_string(), Value::Int(2)),
967                ("name".to_string(), Value::String("Bob".to_string())),
968            ])),
969        ]);
970
971        let right_array = Value::Array(vec![
972            Value::Object(HashMap::from([
973                ("id".to_string(), Value::Int(1)),
974                ("age".to_string(), Value::Int(30)),
975            ])),
976            Value::Object(HashMap::from([
977                ("id".to_string(), Value::Int(3)),
978                ("age".to_string(), Value::Int(25)),
979            ])),
980        ]);
981
982        let keys = JoinKeys::on(vec!["id".to_string()]);
983        let result = inner_join(&left_array, &right_array, &keys).unwrap();
984
985        match result {
986            Value::Array(arr) => {
987                assert_eq!(arr.len(), 1); // Only Alice should match
988                if let Value::Object(obj) = &arr[0] {
989                    assert_eq!(obj.get("name"), Some(&Value::String("Alice".to_string())));
990                    assert_eq!(obj.get("age"), Some(&Value::Int(30)));
991                }
992            }
993            _ => panic!("Expected Array"),
994        }
995    }
996
997    #[test]
998    fn test_join_types() {
999        assert_eq!(JoinType::from_str("inner").unwrap(), JoinType::Inner);
1000        assert_eq!(JoinType::from_str("left_outer").unwrap(), JoinType::Left);
1001        assert_eq!(JoinType::from_str("full").unwrap(), JoinType::Outer);
1002        assert_eq!(JoinType::from_str("cross").unwrap(), JoinType::Cross);
1003
1004        assert!(JoinType::from_str("invalid").is_err());
1005    }
1006
1007    #[test]
1008    fn test_join_keys() {
1009        let keys = JoinKeys::on(vec!["id".to_string(), "name".to_string()]);
1010        assert_eq!(keys.left_columns(), &["id", "name"]);
1011        assert_eq!(keys.right_columns(), &["id", "name"]);
1012
1013        let keys = JoinKeys::left_right(vec!["left_id".to_string()], vec!["right_id".to_string()]);
1014        assert_eq!(keys.left_columns(), &["left_id"]);
1015        assert_eq!(keys.right_columns(), &["right_id"]);
1016    }
1017
1018    #[test]
1019    fn test_semi_join() {
1020        let left_array = Value::Array(vec![
1021            Value::Object(HashMap::from([
1022                ("id".to_string(), Value::Int(1)),
1023                ("name".to_string(), Value::String("Alice".to_string())),
1024            ])),
1025            Value::Object(HashMap::from([
1026                ("id".to_string(), Value::Int(2)),
1027                ("name".to_string(), Value::String("Bob".to_string())),
1028            ])),
1029            Value::Object(HashMap::from([
1030                ("id".to_string(), Value::Int(3)),
1031                ("name".to_string(), Value::String("Charlie".to_string())),
1032            ])),
1033        ]);
1034
1035        let right_array = Value::Array(vec![
1036            Value::Object(HashMap::from([("id".to_string(), Value::Int(1))])),
1037            Value::Object(HashMap::from([("id".to_string(), Value::Int(3))])),
1038        ]);
1039
1040        let keys = JoinKeys::on(vec!["id".to_string()]);
1041        let options = JoinOptions {
1042            join_type: JoinType::Semi,
1043            ..Default::default()
1044        };
1045
1046        let result = join(&left_array, &right_array, &keys, &options).unwrap();
1047
1048        match result {
1049            Value::Array(arr) => {
1050                assert_eq!(arr.len(), 2); // Alice and Charlie should be returned
1051                                          // Should only contain left columns
1052                if let Value::Object(obj) = &arr[0] {
1053                    assert!(obj.contains_key("name"));
1054                    assert!(!obj.contains_key("age")); // No right columns
1055                }
1056            }
1057            _ => panic!("Expected Array"),
1058        }
1059    }
1060
1061    #[test]
1062    fn test_anti_join() {
1063        let left_array = Value::Array(vec![
1064            Value::Object(HashMap::from([
1065                ("id".to_string(), Value::Int(1)),
1066                ("name".to_string(), Value::String("Alice".to_string())),
1067            ])),
1068            Value::Object(HashMap::from([
1069                ("id".to_string(), Value::Int(2)),
1070                ("name".to_string(), Value::String("Bob".to_string())),
1071            ])),
1072        ]);
1073
1074        let right_array = Value::Array(vec![Value::Object(HashMap::from([(
1075            "id".to_string(),
1076            Value::Int(1),
1077        )]))]);
1078
1079        let keys = JoinKeys::on(vec!["id".to_string()]);
1080        let options = JoinOptions {
1081            join_type: JoinType::Anti,
1082            ..Default::default()
1083        };
1084
1085        let result = join(&left_array, &right_array, &keys, &options).unwrap();
1086
1087        match result {
1088            Value::Array(arr) => {
1089                assert_eq!(arr.len(), 1); // Only Bob should be returned
1090                if let Value::Object(obj) = &arr[0] {
1091                    assert_eq!(obj.get("name"), Some(&Value::String("Bob".to_string())));
1092                }
1093            }
1094            _ => panic!("Expected Array"),
1095        }
1096    }
1097
1098    #[test]
1099    fn test_cross_join() {
1100        let left_array = Value::Array(vec![
1101            Value::Object(HashMap::from([(
1102                "name".to_string(),
1103                Value::String("Alice".to_string()),
1104            )])),
1105            Value::Object(HashMap::from([(
1106                "name".to_string(),
1107                Value::String("Bob".to_string()),
1108            )])),
1109        ]);
1110
1111        let right_array = Value::Array(vec![
1112            Value::Object(HashMap::from([(
1113                "color".to_string(),
1114                Value::String("Red".to_string()),
1115            )])),
1116            Value::Object(HashMap::from([(
1117                "color".to_string(),
1118                Value::String("Blue".to_string()),
1119            )])),
1120        ]);
1121
1122        let keys = JoinKeys::on(vec![]); // No join keys for cross join
1123        let options = JoinOptions {
1124            join_type: JoinType::Cross,
1125            ..Default::default()
1126        };
1127
1128        let result = join(&left_array, &right_array, &keys, &options).unwrap();
1129
1130        match result {
1131            Value::Array(arr) => {
1132                assert_eq!(arr.len(), 4); // 2 x 2 = 4 combinations
1133                                          // Each result should have both name and color
1134                for item in &arr {
1135                    if let Value::Object(obj) = item {
1136                        assert!(obj.contains_key("name"));
1137                        assert!(obj.contains_key("color"));
1138                    }
1139                }
1140            }
1141            _ => panic!("Expected Array"),
1142        }
1143    }
1144
1145    #[test]
1146    fn test_join_multiple() {
1147        let df1 = DataFrame::new(vec![
1148            Column::new("id".into(), &[1, 2]),
1149            Column::new("name".into(), &["Alice", "Bob"]),
1150        ])
1151        .unwrap();
1152
1153        let df2 = DataFrame::new(vec![
1154            Column::new("id".into(), &[1, 2]),
1155            Column::new("age".into(), &[30, 25]),
1156        ])
1157        .unwrap();
1158
1159        let df3 = DataFrame::new(vec![
1160            Column::new("id".into(), &[1, 2]),
1161            Column::new("city".into(), &["NYC", "LA"]),
1162        ])
1163        .unwrap();
1164
1165        let dataframes = vec![
1166            Value::DataFrame(df1),
1167            Value::DataFrame(df2),
1168            Value::DataFrame(df3),
1169        ];
1170
1171        let keys = JoinKeys::on(vec!["id".to_string()]);
1172        let options = JoinOptions {
1173            join_type: JoinType::Inner,
1174            ..Default::default()
1175        };
1176
1177        let result = join_multiple(&dataframes, &keys, &options).unwrap();
1178
1179        match result {
1180            Value::DataFrame(df) => {
1181                assert_eq!(df.height(), 2);
1182                assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
1183                assert!(df.get_column_names().contains(&&PlSmallStr::from("age")));
1184                assert!(df.get_column_names().contains(&&PlSmallStr::from("city")));
1185            }
1186            _ => panic!("Expected DataFrame"),
1187        }
1188    }
1189
1190    #[test]
1191    fn test_join_with_options() {
1192        let left_df = create_left_dataframe();
1193        let right_df = create_right_dataframe();
1194
1195        let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
1196
1197        let options = JoinOptions {
1198            join_type: JoinType::Inner,
1199            suffix: "_right".to_string(),
1200            validate: JoinValidation::None,
1201            sort: false,
1202            coalesce: polars::prelude::JoinCoalesce::JoinSpecific,
1203        };
1204
1205        let result = join(
1206            &Value::DataFrame(left_df),
1207            &Value::DataFrame(right_df),
1208            &keys,
1209            &options,
1210        )
1211        .unwrap();
1212
1213        match result {
1214            Value::DataFrame(df) => {
1215                assert_eq!(df.height(), 3);
1216                // Check that columns are present
1217                assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
1218                assert!(df
1219                    .get_column_names()
1220                    .contains(&&PlSmallStr::from("dept_name")));
1221            }
1222            _ => panic!("Expected DataFrame"),
1223        }
1224    }
1225
1226    #[test]
1227    fn test_join_lazy_frames() {
1228        let left_df = create_left_dataframe();
1229        let right_df = create_right_dataframe();
1230
1231        let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
1232
1233        let options = JoinOptions {
1234            join_type: JoinType::Inner,
1235            ..Default::default()
1236        };
1237
1238        let result = join(
1239            &Value::LazyFrame(Box::new(left_df.lazy())),
1240            &Value::LazyFrame(Box::new(right_df.lazy())),
1241            &keys,
1242            &options,
1243        )
1244        .unwrap();
1245
1246        match result {
1247            Value::LazyFrame(_) => {
1248                // Just check it's a LazyFrame
1249            }
1250            _ => panic!("Expected LazyFrame"),
1251        }
1252    }
1253
1254    #[test]
1255    fn test_join_mixed_types() {
1256        let left_df = create_left_dataframe();
1257        let right_lf = create_right_dataframe().lazy();
1258
1259        let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
1260
1261        let options = JoinOptions {
1262            join_type: JoinType::Inner,
1263            ..Default::default()
1264        };
1265
1266        let result = join(
1267            &Value::DataFrame(left_df),
1268            &Value::LazyFrame(Box::new(right_lf)),
1269            &keys,
1270            &options,
1271        )
1272        .unwrap();
1273
1274        match result {
1275            Value::DataFrame(df) => {
1276                assert_eq!(df.height(), 3);
1277            }
1278            _ => panic!("Expected DataFrame"),
1279        }
1280    }
1281
1282    #[test]
1283    fn test_join_with_suffix() {
1284        let left_array = Value::Array(vec![Value::Object(HashMap::from([
1285            ("id".to_string(), Value::Int(1)),
1286            ("name".to_string(), Value::String("Alice".to_string())),
1287        ]))]);
1288
1289        let right_array = Value::Array(vec![Value::Object(HashMap::from([
1290            ("id".to_string(), Value::Int(1)),
1291            ("name".to_string(), Value::String("Bob".to_string())), // Conflicting column
1292        ]))]);
1293
1294        let keys = JoinKeys::on(vec!["id".to_string()]);
1295        let options = JoinOptions {
1296            join_type: JoinType::Inner,
1297            ..Default::default()
1298        };
1299
1300        let result = join(&left_array, &right_array, &keys, &options).unwrap();
1301
1302        match result {
1303            Value::Array(arr) => {
1304                assert_eq!(arr.len(), 1);
1305                if let Value::Object(obj) = &arr[0] {
1306                    assert!(obj.contains_key("name"));
1307                    assert!(obj.contains_key("name_right"));
1308                }
1309            }
1310            _ => panic!("Expected Array"),
1311        }
1312    }
1313
1314    #[test]
1315    fn test_join_empty_arrays() {
1316        let left_array = Value::Array(vec![]);
1317        let right_array = Value::Array(vec![]);
1318
1319        let keys = JoinKeys::on(vec!["id".to_string()]);
1320        let options = JoinOptions {
1321            join_type: JoinType::Inner,
1322            ..Default::default()
1323        };
1324
1325        let result = join(&left_array, &right_array, &keys, &options).unwrap();
1326
1327        match result {
1328            Value::Array(arr) => {
1329                assert_eq!(arr.len(), 0);
1330            }
1331            _ => panic!("Expected Array"),
1332        }
1333    }
1334
1335    #[test]
1336    fn test_join_invalid_keys() {
1337        let left_df = create_left_dataframe();
1338        let right_df = create_right_dataframe();
1339
1340        let keys = JoinKeys::on(vec!["nonexistent".to_string()]);
1341
1342        let result = join(
1343            &Value::DataFrame(left_df),
1344            &Value::DataFrame(right_df),
1345            &keys,
1346            &JoinOptions::default(),
1347        );
1348
1349        assert!(result.is_err()); // Should fail due to invalid column
1350    }
1351
1352    #[test]
1353    fn test_join_type_parsing() {
1354        assert_eq!(JoinType::from_str("inner").unwrap(), JoinType::Inner);
1355        assert_eq!(JoinType::from_str("left").unwrap(), JoinType::Left);
1356        assert_eq!(JoinType::from_str("right").unwrap(), JoinType::Right);
1357        assert_eq!(JoinType::from_str("outer").unwrap(), JoinType::Outer);
1358        assert_eq!(JoinType::from_str("full").unwrap(), JoinType::Outer);
1359        assert_eq!(JoinType::from_str("cross").unwrap(), JoinType::Cross);
1360        assert_eq!(JoinType::from_str("semi").unwrap(), JoinType::Semi);
1361        assert_eq!(JoinType::from_str("anti").unwrap(), JoinType::Anti);
1362        assert!(JoinType::from_str("invalid").is_err());
1363    }
1364
1365    #[test]
1366    fn test_join_validation() {
1367        assert_eq!(
1368            JoinValidation::None.to_polars(),
1369            polars::prelude::JoinValidation::ManyToOne
1370        );
1371        assert_eq!(
1372            JoinValidation::OneToMany.to_polars(),
1373            polars::prelude::JoinValidation::OneToMany
1374        );
1375        assert_eq!(
1376            JoinValidation::ManyToOne.to_polars(),
1377            polars::prelude::JoinValidation::ManyToOne
1378        );
1379        assert_eq!(
1380            JoinValidation::OneToOne.to_polars(),
1381            polars::prelude::JoinValidation::OneToOne
1382        );
1383    }
1384
1385    #[test]
1386    fn test_join_keys_methods() {
1387        let keys = JoinKeys::on(vec!["a".to_string(), "b".to_string()]);
1388        assert_eq!(keys.left_columns(), &["a", "b"]);
1389        assert_eq!(keys.right_columns(), &["a", "b"]);
1390
1391        let keys = JoinKeys::left_right(vec!["la".to_string()], vec!["ra".to_string()]);
1392        assert_eq!(keys.left_columns(), &["la"]);
1393        assert_eq!(keys.right_columns(), &["ra"]);
1394    }
1395
1396    #[test]
1397    fn test_join_options_default() {
1398        let options = JoinOptions::default();
1399        assert_eq!(options.join_type, JoinType::Inner);
1400        assert_eq!(options.suffix, "_right");
1401        assert_eq!(options.validate, JoinValidation::None);
1402        assert!(!options.sort);
1403    }
1404}