use std::collections::HashMap;
use polars::prelude::*;
use crate::error::{Error, Result};
use crate::Value;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
Inner,
Left,
Right,
Outer,
Cross,
Semi,
Anti,
}
impl JoinType {
pub fn to_polars(&self) -> Result<polars::prelude::JoinType> {
match self {
JoinType::Inner => Ok(polars::prelude::JoinType::Inner),
JoinType::Left => Ok(polars::prelude::JoinType::Left),
JoinType::Right => Err(Error::operation(
"Right join not supported in this Polars version, so cannot convert to Polars",
)),
JoinType::Outer => Ok(polars::prelude::JoinType::Full),
JoinType::Cross => Ok(polars::prelude::JoinType::Cross),
JoinType::Semi => Err(Error::operation(
"Semi join not supported in this Polars version, so cannot convert to Polars",
)),
JoinType::Anti => Err(Error::operation(
"Anti join not supported in this Polars version, so cannot convert to Polars",
)),
}
}
#[must_use]
pub fn as_str(&self) -> &'static str {
match self {
JoinType::Inner => "inner",
JoinType::Left => "left",
JoinType::Right => "right",
JoinType::Outer => "outer",
JoinType::Cross => "cross",
JoinType::Semi => "semi",
JoinType::Anti => "anti",
}
}
#[allow(clippy::should_implement_trait)]
pub fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"inner" => Ok(JoinType::Inner),
"left" | "left_outer" => Ok(JoinType::Left),
"right" | "right_outer" => Ok(JoinType::Right),
"outer" | "full" | "full_outer" => Ok(JoinType::Outer),
"cross" => Ok(JoinType::Cross),
"semi" => Ok(JoinType::Semi),
"anti" => Ok(JoinType::Anti),
_ => Err(Error::operation(format!("Unknown join type: {s}"))),
}
}
}
/// Options that describe how a join should be carried out.
#[derive(Debug, Clone)]
pub struct JoinOptions {
/// Which kind of join to perform.
pub join_type: JoinType,
/// Suffix appended to a right-hand column name that collides with a left-hand one.
pub suffix: String,
/// Requested key-cardinality validation.
/// NOTE(review): no join path visible in this file reads this field - confirm where it is applied.
pub validate: JoinValidation,
/// Whether the joined result should be sorted by the left key column(s).
pub sort: bool,
/// Polars key-coalescing mode.
/// NOTE(review): confirm which join paths are expected to honour this.
pub coalesce: polars::prelude::JoinCoalesce,
}
impl Default for JoinOptions {
// Defaults: inner join, "_right" suffix, no validation, no sorting,
// and the join-specific coalescing mode.
fn default() -> Self {
Self {
join_type: JoinType::Inner,
suffix: "_right".to_string(),
validate: JoinValidation::None,
sort: false,
coalesce: polars::prelude::JoinCoalesce::JoinSpecific,
}
}
}
/// Key-cardinality checks that can be requested for a join.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinValidation {
    /// No explicit validation requested.
    None,
    /// One-to-many relationship between left and right keys.
    OneToMany,
    /// Many-to-one relationship between left and right keys.
    ManyToOne,
    /// One-to-one relationship between left and right keys.
    OneToOne,
}

impl JoinValidation {
    /// Translates this validation mode into the Polars enum.
    ///
    /// NOTE(review): `None` is translated to Polars `ManyToOne`, exactly like
    /// `ManyToOne` itself. The unit tests pin this mapping, so it is kept here;
    /// confirm that a no-check mode (Polars many-to-many) was not intended.
    #[must_use]
    pub fn to_polars(&self) -> polars::prelude::JoinValidation {
        use polars::prelude::JoinValidation as PolarsValidation;

        match *self {
            Self::None | Self::ManyToOne => PolarsValidation::ManyToOne,
            Self::OneToMany => PolarsValidation::OneToMany,
            Self::OneToOne => PolarsValidation::OneToOne,
        }
    }
}
/// The columns used to match rows of the two join inputs.
#[derive(Debug, Clone)]
pub enum JoinKeys {
    /// The same column names are used on both sides.
    On(Vec<String>),
    /// Separate column names for each side, paired by position.
    LeftRight {
        /// Key columns of the left input.
        left: Vec<String>,
        /// Key columns of the right input.
        right: Vec<String>,
    },
}

impl JoinKeys {
    /// Builds keys that use `columns` on both sides.
    #[must_use]
    pub fn on(columns: Vec<String>) -> Self {
        Self::On(columns)
    }

    /// Builds keys with distinct column names per side.
    #[must_use]
    pub fn left_right(left: Vec<String>, right: Vec<String>) -> Self {
        Self::LeftRight { left, right }
    }

    /// Key columns to read from the left input.
    #[must_use]
    pub fn left_columns(&self) -> &[String] {
        // Both variants carry a left-hand list; `On` shares it with the right.
        let (Self::On(names) | Self::LeftRight { left: names, .. }) = self;
        names.as_slice()
    }

    /// Key columns to read from the right input.
    #[must_use]
    pub fn right_columns(&self) -> &[String] {
        let (Self::On(names) | Self::LeftRight { right: names, .. }) = self;
        names.as_slice()
    }
}
pub fn join(left: &Value, right: &Value, keys: &JoinKeys, options: &JoinOptions) -> Result<Value> {
match (left, right) {
(Value::DataFrame(left_df), Value::DataFrame(right_df)) => {
join_dataframes(left_df, right_df, keys, options)
}
(Value::LazyFrame(left_lf), Value::LazyFrame(right_lf)) => {
join_lazy_frames(left_lf, right_lf, keys, options)
}
(Value::DataFrame(left_df), Value::LazyFrame(right_lf)) => {
let right_df = right_lf.clone().collect().map_err(Error::from)?;
join_dataframes(left_df, &right_df, keys, options)
}
(Value::LazyFrame(left_lf), Value::DataFrame(right_df)) => {
let left_df = left_lf.clone().collect().map_err(Error::from)?;
join_dataframes(&left_df, right_df, keys, options)
}
(Value::Array(left_arr), Value::Array(right_arr)) => {
join_arrays(left_arr, right_arr, keys, options)
}
(left_val, right_val) => {
let left_df = left_val.to_dataframe()?;
let right_df = right_val.to_dataframe()?;
join_dataframes(&left_df, &right_df, keys, options)
}
}
}
fn join_dataframes(
left_df: &DataFrame,
right_df: &DataFrame,
keys: &JoinKeys,
options: &JoinOptions,
) -> Result<Value> {
let left_on: Vec<Expr> = keys.left_columns().iter().map(col).collect();
let right_on: Vec<Expr> = keys.right_columns().iter().map(col).collect();
let join_args = JoinArgs::new(options.join_type.to_polars()?);
let join_builder =
left_df
.clone()
.lazy()
.join(right_df.clone().lazy(), left_on, right_on, join_args);
let result_df = join_builder.collect().map_err(Error::from)?;
Ok(Value::DataFrame(result_df))
}
fn join_lazy_frames(
left_lf: &LazyFrame,
right_lf: &LazyFrame,
keys: &JoinKeys,
options: &JoinOptions,
) -> Result<Value> {
let left_on: Vec<Expr> = keys.left_columns().iter().map(col).collect();
let right_on: Vec<Expr> = keys.right_columns().iter().map(col).collect();
let join_args = JoinArgs::new(options.join_type.to_polars()?);
let mut join_builder = left_lf
.clone()
.join(right_lf.clone(), left_on, right_on, join_args);
if options.sort {
let sort_exprs: Vec<Expr> = keys.left_columns().iter().map(col).collect();
join_builder = join_builder.sort_by_exprs(sort_exprs, SortMultipleOptions::default());
}
Ok(Value::LazyFrame(Box::new(join_builder)))
}
/// Joins two arrays of objects in memory, dispatching on `options.join_type`.
///
/// Elements that are not `Value::Object` are skipped by the per-type helpers.
/// When `options.sort` is set, the result is stably sorted by all left key
/// columns in order, matching the lazy-frame path. Previously only the first
/// key column was used, so rows that tied on it kept join order. A missing
/// column, or a row that is not an object, sorts as `Null`.
///
/// # Errors
///
/// Propagates errors from the selected join helper, such as left and right
/// key lists of different lengths.
fn join_arrays(
    left_arr: &[Value],
    right_arr: &[Value],
    keys: &JoinKeys,
    options: &JoinOptions,
) -> Result<Value> {
    /// Value of `column` in `row`, or `Null` when absent or `row` is not an object.
    fn field<'a>(row: &'a Value, column: &str) -> &'a Value {
        match row {
            Value::Object(fields) => fields.get(column).unwrap_or(&Value::Null),
            _ => &Value::Null,
        }
    }

    let suffix = options.suffix.as_str();
    let mut rows = match options.join_type {
        JoinType::Inner => inner_join_arrays(left_arr, right_arr, keys, suffix)?,
        JoinType::Left => left_join_arrays(left_arr, right_arr, keys, suffix)?,
        JoinType::Right => right_join_arrays(left_arr, right_arr, keys, suffix)?,
        JoinType::Outer => outer_join_arrays(left_arr, right_arr, keys, suffix)?,
        JoinType::Cross => cross_join_arrays(left_arr, right_arr, suffix)?,
        JoinType::Semi => semi_join_arrays(left_arr, right_arr, keys)?,
        JoinType::Anti => anti_join_arrays(left_arr, right_arr, keys)?,
    };

    let sort_columns = keys.left_columns();
    if options.sort && !sort_columns.is_empty() {
        // Lexicographic comparison: the first key column that differs decides.
        rows.sort_by(|a, b| {
            sort_columns
                .iter()
                .map(|column| compare_values_for_sorting(field(a, column), field(b, column)))
                .find(|ordering| ordering.is_ne())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    Ok(Value::Array(rows))
}
/// Inner join over arrays of objects: emits one merged row for every
/// (left, right) pair whose key values match, in left-major order.
///
/// Elements that are not `Value::Object` are skipped on both sides.
///
/// # Errors
///
/// Propagates errors from key matching and from merging.
fn inner_join_arrays(
    left_arr: &[Value],
    right_arr: &[Value],
    keys: &JoinKeys,
    suffix: &str,
) -> Result<Vec<Value>> {
    // Matched rows never need null filling, so the null-key set is empty.
    let no_fill = std::collections::HashSet::new();
    let mut rows = Vec::new();

    for left_item in left_arr {
        let Value::Object(left_obj) = left_item else {
            continue;
        };
        for right_item in right_arr {
            let Value::Object(right_obj) = right_item else {
                continue;
            };
            if !objects_match_on_keys(left_obj, right_obj, keys)? {
                continue;
            }
            let merged = merge_objects(left_obj, right_obj, suffix, false, &no_fill)?;
            rows.push(Value::Object(merged));
        }
    }

    Ok(rows)
}
/// Left join over arrays of objects.
///
/// Every left object yields one merged row per matching right object. A left
/// object with no match yields a single row in which every column name seen
/// anywhere on the right is filled with `Null` (suffixed when it collides with
/// a left column). Elements that are not `Value::Object` are skipped.
///
/// # Errors
///
/// Propagates errors from key matching and from merging.
fn left_join_arrays(
    left_arr: &[Value],
    right_arr: &[Value],
    keys: &JoinKeys,
    suffix: &str,
) -> Result<Vec<Value>> {
    // Union of the column names of all right-hand objects.
    let mut right_columns = std::collections::HashSet::new();
    for item in right_arr {
        if let Value::Object(fields) = item {
            right_columns.extend(fields.keys().cloned());
        }
    }

    let no_fill = std::collections::HashSet::new();
    let empty_right = HashMap::new();
    let mut rows = Vec::new();

    for left_item in left_arr {
        let Value::Object(left_obj) = left_item else {
            continue;
        };
        // If the row count does not grow, this left object had no match.
        let rows_before = rows.len();
        for right_item in right_arr {
            let Value::Object(right_obj) = right_item else {
                continue;
            };
            if objects_match_on_keys(left_obj, right_obj, keys)? {
                let merged = merge_objects(left_obj, right_obj, suffix, false, &no_fill)?;
                rows.push(Value::Object(merged));
            }
        }
        if rows.len() == rows_before {
            let padded = merge_objects(left_obj, &empty_right, suffix, true, &right_columns)?;
            rows.push(Value::Object(padded));
        }
    }

    Ok(rows)
}
/// Right join over arrays of objects.
///
/// Every right object yields one merged row per matching left object, with
/// left columns unsuffixed and colliding right columns suffixed. Elements that
/// are not `Value::Object` are skipped.
///
/// A right object with no match yields one row with the same column layout as
/// a matched row:
///
/// * every column name seen anywhere on the left is present and `Null`,
/// * except the left key columns, which take the right row's key values
///   (in a matched row the two sides' key values are equal anyway),
/// * right columns that collide with one of those names are stored under
///   `<name><suffix>`; the rest keep their own name.
///
/// Previously the unmatched right row was merged into an empty left object, so
/// right values landed under unsuffixed names and the suffixed name was filled
/// with `Null` - the opposite of the layout used for matched rows.
///
/// # Errors
///
/// Propagates errors from key matching and from merging.
fn right_join_arrays(
    left_arr: &[Value],
    right_arr: &[Value],
    keys: &JoinKeys,
    suffix: &str,
) -> Result<Vec<Value>> {
    // Stand-in for a missing left row: every left column name mapped to Null.
    let null_left: HashMap<String, Value> = left_arr
        .iter()
        .filter_map(|item| match item {
            Value::Object(fields) => Some(fields.keys()),
            _ => None,
        })
        .flatten()
        .map(|column| (column.clone(), Value::Null))
        .collect();

    let no_fill = std::collections::HashSet::new();
    let mut rows = Vec::new();

    for right_item in right_arr {
        let Value::Object(right_obj) = right_item else {
            continue;
        };
        let mut found_match = false;
        for left_item in left_arr {
            let Value::Object(left_obj) = left_item else {
                continue;
            };
            if objects_match_on_keys(left_obj, right_obj, keys)? {
                let merged = merge_objects(left_obj, right_obj, suffix, false, &no_fill)?;
                rows.push(Value::Object(merged));
                found_match = true;
            }
        }
        if !found_match {
            let mut left_side = null_left.clone();
            // Carry the key values over so the left key columns stay populated.
            for (left_key, right_key) in keys.left_columns().iter().zip(keys.right_columns()) {
                if let Some(key_value) = right_obj.get(right_key) {
                    left_side.insert(left_key.clone(), key_value.clone());
                }
            }
            let merged = merge_objects(&left_side, right_obj, suffix, false, &no_fill)?;
            rows.push(Value::Object(merged));
        }
    }

    Ok(rows)
}
/// Full outer join over arrays of objects.
///
/// Output order: for each left object, its matched rows (or one null-padded
/// row when nothing matches); then one row for each right object that matched
/// no left object. Elements that are not `Value::Object` are skipped.
///
/// An unmatched right object yields a row with the same column layout as a
/// matched row:
///
/// * every column name seen anywhere on the left is present and `Null`,
/// * except the left key columns, which take the right row's key values
///   (in a matched row the two sides' key values are equal anyway),
/// * right columns that collide with one of those names are stored under
///   `<name><suffix>`; the rest keep their own name.
///
/// Previously the unmatched right row was merged into an empty left object, so
/// right values landed under unsuffixed names and the suffixed name was filled
/// with `Null` - the opposite of the layout used for matched rows.
///
/// # Errors
///
/// Propagates errors from key matching and from merging.
fn outer_join_arrays(
    left_arr: &[Value],
    right_arr: &[Value],
    keys: &JoinKeys,
    suffix: &str,
) -> Result<Vec<Value>> {
    // Stand-in for a missing left row: every left column name mapped to Null.
    let null_left: HashMap<String, Value> = left_arr
        .iter()
        .filter_map(|item| match item {
            Value::Object(fields) => Some(fields.keys()),
            _ => None,
        })
        .flatten()
        .map(|column| (column.clone(), Value::Null))
        .collect();

    // Column names to null-fill when a left row has no partner.
    let right_columns: std::collections::HashSet<String> = right_arr
        .iter()
        .filter_map(|item| match item {
            Value::Object(fields) => Some(fields.keys()),
            _ => None,
        })
        .flatten()
        .cloned()
        .collect();

    let no_fill = std::collections::HashSet::new();
    let empty_right = HashMap::new();
    let mut rows = Vec::new();
    let mut right_matched = vec![false; right_arr.len()];

    for left_item in left_arr {
        let Value::Object(left_obj) = left_item else {
            continue;
        };
        let mut found_match = false;
        for (right_item, matched) in right_arr.iter().zip(right_matched.iter_mut()) {
            let Value::Object(right_obj) = right_item else {
                continue;
            };
            if objects_match_on_keys(left_obj, right_obj, keys)? {
                let merged = merge_objects(left_obj, right_obj, suffix, false, &no_fill)?;
                rows.push(Value::Object(merged));
                *matched = true;
                found_match = true;
            }
        }
        if !found_match {
            let padded = merge_objects(left_obj, &empty_right, suffix, true, &right_columns)?;
            rows.push(Value::Object(padded));
        }
    }

    for (right_item, matched) in right_arr.iter().zip(&right_matched) {
        if *matched {
            continue;
        }
        let Value::Object(right_obj) = right_item else {
            continue;
        };
        let mut left_side = null_left.clone();
        // Carry the key values over so the left key columns stay populated.
        for (left_key, right_key) in keys.left_columns().iter().zip(keys.right_columns()) {
            if let Some(key_value) = right_obj.get(right_key) {
                left_side.insert(left_key.clone(), key_value.clone());
            }
        }
        let merged = merge_objects(&left_side, right_obj, suffix, false, &no_fill)?;
        rows.push(Value::Object(merged));
    }

    Ok(rows)
}
/// Cross join over arrays of objects: pairs every left object with every
/// right object, in left-major order. No keys are consulted.
///
/// Elements that are not `Value::Object` are skipped on both sides.
///
/// # Errors
///
/// Propagates errors from merging.
fn cross_join_arrays(left_arr: &[Value], right_arr: &[Value], suffix: &str) -> Result<Vec<Value>> {
    let no_fill = std::collections::HashSet::new();
    let mut product = Vec::new();

    for left_item in left_arr {
        let Value::Object(left_obj) = left_item else {
            continue;
        };
        for right_item in right_arr {
            let Value::Object(right_obj) = right_item else {
                continue;
            };
            let merged = merge_objects(left_obj, right_obj, suffix, false, &no_fill)?;
            product.push(Value::Object(merged));
        }
    }

    Ok(product)
}
/// Semi join over arrays of objects: keeps each left object that has at least
/// one matching right object. A kept row is an unmodified clone of the left
/// element and appears once, however many right rows match.
///
/// Elements that are not `Value::Object` are skipped on both sides.
///
/// # Errors
///
/// Propagates errors from key matching.
fn semi_join_arrays(
    left_arr: &[Value],
    right_arr: &[Value],
    keys: &JoinKeys,
) -> Result<Vec<Value>> {
    let mut kept = Vec::new();

    for left_item in left_arr {
        let Value::Object(left_obj) = left_item else {
            continue;
        };
        let mut has_partner = false;
        for right_item in right_arr {
            let Value::Object(right_obj) = right_item else {
                continue;
            };
            if objects_match_on_keys(left_obj, right_obj, keys)? {
                // One match is enough; stop scanning the right side.
                has_partner = true;
                break;
            }
        }
        if has_partner {
            kept.push(left_item.clone());
        }
    }

    Ok(kept)
}
/// Anti join over arrays of objects: keeps each left object that has no
/// matching right object. Kept rows are unmodified clones of the left element.
///
/// Elements that are not `Value::Object` are skipped on both sides.
///
/// # Errors
///
/// Propagates errors from key matching.
fn anti_join_arrays(
    left_arr: &[Value],
    right_arr: &[Value],
    keys: &JoinKeys,
) -> Result<Vec<Value>> {
    let mut survivors = Vec::new();

    'left: for left_item in left_arr {
        let Value::Object(left_obj) = left_item else {
            continue;
        };
        for right_item in right_arr {
            let Value::Object(right_obj) = right_item else {
                continue;
            };
            if objects_match_on_keys(left_obj, right_obj, keys)? {
                // A single match disqualifies this left row.
                continue 'left;
            }
        }
        survivors.push(left_item.clone());
    }

    Ok(survivors)
}
/// Reports whether two objects agree on every (left key, right key) pair.
///
/// Keys are paired by position. A key missing from an object is treated as
/// `Null`. With no keys at all the objects trivially match.
///
/// # Errors
///
/// Returns an operation error when the left and right key lists differ in
/// length.
fn objects_match_on_keys(
    left_obj: &HashMap<String, Value>,
    right_obj: &HashMap<String, Value>,
    keys: &JoinKeys,
) -> Result<bool> {
    let (left_names, right_names) = (keys.left_columns(), keys.right_columns());
    if left_names.len() != right_names.len() {
        return Err(Error::operation(
            "Left and right join keys must have the same length",
        ));
    }

    // `all` stops at the first pair that differs.
    let every_pair_equal = left_names.iter().zip(right_names).all(|(left_name, right_name)| {
        let left_val = left_obj.get(left_name).unwrap_or(&Value::Null);
        let right_val = right_obj.get(right_name).unwrap_or(&Value::Null);
        values_equal_for_join(left_val, right_val)
    });
    Ok(every_pair_equal)
}
/// Equality used when matching join keys.
///
/// * `Null` equals `Null`.
/// * Booleans, integers and strings compare exactly within their own kind.
/// * Floats, and integer/float pairs (the integer is converted to `f64`),
///   are equal when their absolute difference is below `f64::EPSILON`.
/// * Every other combination is unequal.
// Comparing an integer key against a float key requires converting the integer.
#[allow(clippy::cast_precision_loss)]
fn values_equal_for_join(left: &Value, right: &Value) -> bool {
    let nearly_equal = |x: f64, y: f64| (x - y).abs() < f64::EPSILON;

    match (left, right) {
        (Value::Null, Value::Null) => true,
        (Value::Bool(a), Value::Bool(b)) => a == b,
        (Value::Int(a), Value::Int(b)) => a == b,
        (Value::String(a), Value::String(b)) => a == b,
        (Value::Float(a), Value::Float(b)) => nearly_equal(*a, *b),
        (Value::Int(a), Value::Float(b)) => nearly_equal(*a as f64, *b),
        (Value::Float(a), Value::Int(b)) => nearly_equal(*a, *b as f64),
        _ => false,
    }
}
/// Builds one joined row from a left and a right object.
///
/// The row starts as a copy of `left_obj`. Each right entry is then added
/// under its own name, or under `<name><suffix>` when that name is already
/// present in the row. When `fill_nulls` is set, every name in `null_keys` is
/// additionally given a `Null` entry: under `<name><suffix>` if the name is
/// already present (an existing suffixed entry is left alone), otherwise
/// under the name itself.
///
/// # Errors
///
/// Never fails at present; the `Result` return type is kept for callers.
#[allow(clippy::unnecessary_wraps)]
fn merge_objects(
    left_obj: &HashMap<String, Value>,
    right_obj: &HashMap<String, Value>,
    suffix: &str,
    fill_nulls: bool,
    null_keys: &std::collections::HashSet<String>,
) -> Result<HashMap<String, Value>> {
    let mut merged = left_obj.clone();

    for (name, value) in right_obj {
        let target = if merged.contains_key(name) {
            format!("{name}{suffix}")
        } else {
            name.clone()
        };
        merged.insert(target, value.clone());
    }

    if !fill_nulls {
        return Ok(merged);
    }

    for name in null_keys {
        let target = if merged.contains_key(name) {
            format!("{name}{suffix}")
        } else {
            name.clone()
        };
        // Only fills the slot when it is still empty.
        merged.entry(target).or_insert(Value::Null);
    }

    Ok(merged)
}
/// Ordering used when sorting joined array rows.
///
/// * `Null` sorts before everything else; two `Null`s are equal.
/// * Booleans, integers and strings use their natural order within their kind.
/// * Floats, and integer/float pairs (the integer is converted to `f64`), use
///   `partial_cmp`; incomparable values are treated as equal.
/// * Any other combination is ordered by the values' `to_string()` output.
// Comparing an integer against a float requires converting the integer.
#[allow(clippy::cast_precision_loss)]
fn compare_values_for_sorting(a: &Value, b: &Value) -> std::cmp::Ordering {
    use std::cmp::Ordering::{Equal, Greater, Less};

    let float_order = |x: f64, y: f64| x.partial_cmp(&y).unwrap_or(Equal);

    match (a, b) {
        (Value::Null, Value::Null) => Equal,
        (Value::Null, _) => Less,
        (_, Value::Null) => Greater,
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        (Value::Int(x), Value::Int(y)) => x.cmp(y),
        (Value::String(x), Value::String(y)) => x.cmp(y),
        (Value::Float(x), Value::Float(y)) => float_order(*x, *y),
        (Value::Int(x), Value::Float(y)) => float_order(*x as f64, *y),
        (Value::Float(x), Value::Int(y)) => float_order(*x, *y as f64),
        _ => a.to_string().cmp(&b.to_string()),
    }
}
pub fn inner_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
let options = JoinOptions {
join_type: JoinType::Inner,
..Default::default()
};
join(left, right, keys, &options)
}
pub fn left_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
let options = JoinOptions {
join_type: JoinType::Left,
..Default::default()
};
join(left, right, keys, &options)
}
pub fn right_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
let options = JoinOptions {
join_type: JoinType::Right,
..Default::default()
};
join(left, right, keys, &options)
}
pub fn outer_join(left: &Value, right: &Value, keys: &JoinKeys) -> Result<Value> {
let options = JoinOptions {
join_type: JoinType::Outer,
..Default::default()
};
join(left, right, keys, &options)
}
pub fn join_multiple(
dataframes: &[Value],
keys: &JoinKeys,
options: &JoinOptions,
) -> Result<Value> {
if dataframes.is_empty() {
return Err(Error::operation("No DataFrames provided for join"));
}
if dataframes.len() == 1 {
return Ok(dataframes[0].clone());
}
let mut result = dataframes[0].clone();
for (i, df) in dataframes.iter().enumerate().skip(1) {
let mut join_options = options.clone();
join_options.suffix = format!("_right_{i}");
result = join(&result, df, keys, &join_options)?;
}
Ok(result)
}
#[allow(clippy::used_underscore_binding)]
pub fn join_with_condition(
left: &Value,
right: &Value,
condition: Expr,
_join_type: JoinType,
) -> Result<Value> {
match (left, right) {
(Value::DataFrame(left_df), Value::DataFrame(right_df)) => {
let how = JoinType::Cross.to_polars()?;
let join_args = JoinArgs::new(how);
let cross_joined =
left_df
.clone()
.lazy()
.join(right_df.clone().lazy(), vec![], vec![], join_args);
let filtered = cross_joined.filter(condition);
let result_df = filtered.collect().map_err(Error::from)?;
Ok(Value::DataFrame(result_df))
}
(Value::LazyFrame(left_lf), Value::LazyFrame(right_lf)) => {
let how = JoinType::Cross.to_polars()?;
let join_args = JoinArgs::new(how);
let cross_joined = left_lf
.clone()
.join(*right_lf.clone(), vec![], vec![], join_args);
let filtered = cross_joined.filter(condition);
Ok(Value::LazyFrame(Box::new(filtered)))
}
_ => {
let left_df = left.to_dataframe()?;
let right_df = right.to_dataframe()?;
join_with_condition(
&Value::DataFrame(left_df),
&Value::DataFrame(right_df),
condition,
_join_type,
)
}
}
}
// Unit tests for the join module: DataFrame / LazyFrame joins via Polars and
// the in-memory joins over arrays of objects.
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
// Left fixture: four people, each with a `dept_id` (10, 20, 10, 30).
fn create_left_dataframe() -> DataFrame {
let id = Column::new("id".into(), &[1, 2, 3, 4]);
let name = Column::new("name".into(), &["Alice", "Bob", "Charlie", "Dave"]);
let dept_id = Column::new("dept_id".into(), &[10, 20, 10, 30]);
DataFrame::new(vec![id, name, dept_id]).unwrap()
}
// Right fixture: three departments with ids 10, 20 and 40.
fn create_right_dataframe() -> DataFrame {
let id = Column::new("id".into(), &[10, 20, 40]);
let dept_name = Column::new("dept_name".into(), &["Engineering", "Sales", "Marketing"]);
let budget = Column::new("budget".into(), &[100000, 50000, 75000]);
DataFrame::new(vec![id, dept_name, budget]).unwrap()
}
// Inner join on dept_id = id: dept 30 has no partner, so 3 of the 4 people remain.
#[test]
fn test_inner_join() {
let left_df = create_left_dataframe();
let right_df = create_right_dataframe();
let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
let result = inner_join(
&Value::DataFrame(left_df),
&Value::DataFrame(right_df),
&keys,
)
.unwrap();
match result {
Value::DataFrame(df) => {
assert_eq!(df.shape().0, 3); assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
assert!(df
.get_column_names()
.contains(&&PlSmallStr::from("dept_name")));
}
_ => panic!("Expected DataFrame"),
}
}
// Left join keeps all four left rows, matched or not.
#[test]
fn test_left_join() {
let left_df = create_left_dataframe();
let right_df = create_right_dataframe();
let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
let result = left_join(
&Value::DataFrame(left_df),
&Value::DataFrame(right_df),
&keys,
)
.unwrap();
match result {
Value::DataFrame(df) => {
assert_eq!(df.height(), 4); assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
assert!(df
.get_column_names()
.contains(&&PlSmallStr::from("dept_name")));
}
_ => panic!("Expected DataFrame"),
}
}
// Ignored by default. When run, it expects a DataFrame right join to fail with
// the "Right join not supported" conversion error.
#[test]
#[ignore = "Right join not supported in this Polars version"]
fn test_right_join() {
let left_df = create_left_dataframe();
let right_df = create_right_dataframe();
let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
let result = right_join(
&Value::DataFrame(left_df),
&Value::DataFrame(right_df),
&keys,
);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Right join not supported"));
}
// Inner join over arrays of objects: only id 1 exists on both sides.
#[test]
fn test_array_join() {
let left_array = Value::Array(vec![
Value::Object(HashMap::from([
("id".to_string(), Value::Int(1)),
("name".to_string(), Value::String("Alice".to_string())),
])),
Value::Object(HashMap::from([
("id".to_string(), Value::Int(2)),
("name".to_string(), Value::String("Bob".to_string())),
])),
]);
let right_array = Value::Array(vec![
Value::Object(HashMap::from([
("id".to_string(), Value::Int(1)),
("age".to_string(), Value::Int(30)),
])),
Value::Object(HashMap::from([
("id".to_string(), Value::Int(3)),
("age".to_string(), Value::Int(25)),
])),
]);
let keys = JoinKeys::on(vec!["id".to_string()]);
let result = inner_join(&left_array, &right_array, &keys).unwrap();
match result {
Value::Array(arr) => {
assert_eq!(arr.len(), 1); if let Value::Object(obj) = &arr[0] {
assert_eq!(obj.get("name"), Some(&Value::String("Alice".to_string())));
assert_eq!(obj.get("age"), Some(&Value::Int(30)));
}
}
_ => panic!("Expected Array"),
}
}
// Parsing of join type names, including aliases and an unknown name.
#[test]
fn test_join_types() {
assert_eq!(JoinType::from_str("inner").unwrap(), JoinType::Inner);
assert_eq!(JoinType::from_str("left_outer").unwrap(), JoinType::Left);
assert_eq!(JoinType::from_str("full").unwrap(), JoinType::Outer);
assert_eq!(JoinType::from_str("cross").unwrap(), JoinType::Cross);
assert!(JoinType::from_str("invalid").is_err());
}
// `On` keys report the same columns for both sides; `LeftRight` keeps them apart.
#[test]
fn test_join_keys() {
let keys = JoinKeys::on(vec!["id".to_string(), "name".to_string()]);
assert_eq!(keys.left_columns(), &["id", "name"]);
assert_eq!(keys.right_columns(), &["id", "name"]);
let keys = JoinKeys::left_right(vec!["left_id".to_string()], vec!["right_id".to_string()]);
assert_eq!(keys.left_columns(), &["left_id"]);
assert_eq!(keys.right_columns(), &["right_id"]);
}
// Semi join over arrays: ids 1 and 3 have partners; rows keep only left columns.
#[test]
fn test_semi_join() {
let left_array = Value::Array(vec![
Value::Object(HashMap::from([
("id".to_string(), Value::Int(1)),
("name".to_string(), Value::String("Alice".to_string())),
])),
Value::Object(HashMap::from([
("id".to_string(), Value::Int(2)),
("name".to_string(), Value::String("Bob".to_string())),
])),
Value::Object(HashMap::from([
("id".to_string(), Value::Int(3)),
("name".to_string(), Value::String("Charlie".to_string())),
])),
]);
let right_array = Value::Array(vec![
Value::Object(HashMap::from([("id".to_string(), Value::Int(1))])),
Value::Object(HashMap::from([("id".to_string(), Value::Int(3))])),
]);
let keys = JoinKeys::on(vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Semi,
..Default::default()
};
let result = join(&left_array, &right_array, &keys, &options).unwrap();
match result {
Value::Array(arr) => {
assert_eq!(arr.len(), 2); if let Value::Object(obj) = &arr[0] {
assert!(obj.contains_key("name"));
assert!(!obj.contains_key("age")); }
}
_ => panic!("Expected Array"),
}
}
// Anti join over arrays: only Bob (id 2) has no partner on the right.
#[test]
fn test_anti_join() {
let left_array = Value::Array(vec![
Value::Object(HashMap::from([
("id".to_string(), Value::Int(1)),
("name".to_string(), Value::String("Alice".to_string())),
])),
Value::Object(HashMap::from([
("id".to_string(), Value::Int(2)),
("name".to_string(), Value::String("Bob".to_string())),
])),
]);
let right_array = Value::Array(vec![Value::Object(HashMap::from([(
"id".to_string(),
Value::Int(1),
)]))]);
let keys = JoinKeys::on(vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Anti,
..Default::default()
};
let result = join(&left_array, &right_array, &keys, &options).unwrap();
match result {
Value::Array(arr) => {
assert_eq!(arr.len(), 1); if let Value::Object(obj) = &arr[0] {
assert_eq!(obj.get("name"), Some(&Value::String("Bob".to_string())));
}
}
_ => panic!("Expected Array"),
}
}
// Cross join over arrays: 2 x 2 inputs give 4 rows carrying both columns.
#[test]
fn test_cross_join() {
let left_array = Value::Array(vec![
Value::Object(HashMap::from([(
"name".to_string(),
Value::String("Alice".to_string()),
)])),
Value::Object(HashMap::from([(
"name".to_string(),
Value::String("Bob".to_string()),
)])),
]);
let right_array = Value::Array(vec![
Value::Object(HashMap::from([(
"color".to_string(),
Value::String("Red".to_string()),
)])),
Value::Object(HashMap::from([(
"color".to_string(),
Value::String("Blue".to_string()),
)])),
]);
let keys = JoinKeys::on(vec![]); let options = JoinOptions {
join_type: JoinType::Cross,
..Default::default()
};
let result = join(&left_array, &right_array, &keys, &options).unwrap();
match result {
Value::Array(arr) => {
assert_eq!(arr.len(), 4); for item in &arr {
if let Value::Object(obj) = item {
assert!(obj.contains_key("name"));
assert!(obj.contains_key("color"));
}
}
}
_ => panic!("Expected Array"),
}
}
// Chained inner join of three frames on `id`: both rows survive with all columns.
#[test]
fn test_join_multiple() {
let df1 = DataFrame::new(vec![
Column::new("id".into(), &[1, 2]),
Column::new("name".into(), &["Alice", "Bob"]),
])
.unwrap();
let df2 = DataFrame::new(vec![
Column::new("id".into(), &[1, 2]),
Column::new("age".into(), &[30, 25]),
])
.unwrap();
let df3 = DataFrame::new(vec![
Column::new("id".into(), &[1, 2]),
Column::new("city".into(), &["NYC", "LA"]),
])
.unwrap();
let dataframes = vec![
Value::DataFrame(df1),
Value::DataFrame(df2),
Value::DataFrame(df3),
];
let keys = JoinKeys::on(vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Inner,
..Default::default()
};
let result = join_multiple(&dataframes, &keys, &options).unwrap();
match result {
Value::DataFrame(df) => {
assert_eq!(df.height(), 2);
assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
assert!(df.get_column_names().contains(&&PlSmallStr::from("age")));
assert!(df.get_column_names().contains(&&PlSmallStr::from("city")));
}
_ => panic!("Expected DataFrame"),
}
}
// Same inner join as above, with every option field spelled out explicitly.
#[test]
fn test_join_with_options() {
let left_df = create_left_dataframe();
let right_df = create_right_dataframe();
let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Inner,
suffix: "_right".to_string(),
validate: JoinValidation::None,
sort: false,
coalesce: polars::prelude::JoinCoalesce::JoinSpecific,
};
let result = join(
&Value::DataFrame(left_df),
&Value::DataFrame(right_df),
&keys,
&options,
)
.unwrap();
match result {
Value::DataFrame(df) => {
assert_eq!(df.height(), 3);
assert!(df.get_column_names().contains(&&PlSmallStr::from("name")));
assert!(df
.get_column_names()
.contains(&&PlSmallStr::from("dept_name")));
}
_ => panic!("Expected DataFrame"),
}
}
// Joining two lazy frames must return a lazy frame (the plan is not collected here).
#[test]
fn test_join_lazy_frames() {
let left_df = create_left_dataframe();
let right_df = create_right_dataframe();
let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Inner,
..Default::default()
};
let result = join(
&Value::LazyFrame(Box::new(left_df.lazy())),
&Value::LazyFrame(Box::new(right_df.lazy())),
&keys,
&options,
)
.unwrap();
match result {
Value::LazyFrame(_) => {
}
_ => panic!("Expected LazyFrame"),
}
}
// DataFrame joined with a LazyFrame: the result is an eager DataFrame.
#[test]
fn test_join_mixed_types() {
let left_df = create_left_dataframe();
let right_lf = create_right_dataframe().lazy();
let keys = JoinKeys::left_right(vec!["dept_id".to_string()], vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Inner,
..Default::default()
};
let result = join(
&Value::DataFrame(left_df),
&Value::LazyFrame(Box::new(right_lf)),
&keys,
&options,
)
.unwrap();
match result {
Value::DataFrame(df) => {
assert_eq!(df.height(), 3);
}
_ => panic!("Expected DataFrame"),
}
}
// Colliding non-key column in the array join: the right value lands in `name_right`.
#[test]
fn test_join_with_suffix() {
let left_array = Value::Array(vec![Value::Object(HashMap::from([
("id".to_string(), Value::Int(1)),
("name".to_string(), Value::String("Alice".to_string())),
]))]);
let right_array = Value::Array(vec![Value::Object(HashMap::from([
("id".to_string(), Value::Int(1)),
("name".to_string(), Value::String("Bob".to_string())), ]))]);
let keys = JoinKeys::on(vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Inner,
..Default::default()
};
let result = join(&left_array, &right_array, &keys, &options).unwrap();
match result {
Value::Array(arr) => {
assert_eq!(arr.len(), 1);
if let Value::Object(obj) = &arr[0] {
assert!(obj.contains_key("name"));
assert!(obj.contains_key("name_right"));
}
}
_ => panic!("Expected Array"),
}
}
// Joining two empty arrays gives an empty array.
#[test]
fn test_join_empty_arrays() {
let left_array = Value::Array(vec![]);
let right_array = Value::Array(vec![]);
let keys = JoinKeys::on(vec!["id".to_string()]);
let options = JoinOptions {
join_type: JoinType::Inner,
..Default::default()
};
let result = join(&left_array, &right_array, &keys, &options).unwrap();
match result {
Value::Array(arr) => {
assert_eq!(arr.len(), 0);
}
_ => panic!("Expected Array"),
}
}
// A key column that exists in neither frame must produce an error.
#[test]
fn test_join_invalid_keys() {
let left_df = create_left_dataframe();
let right_df = create_right_dataframe();
let keys = JoinKeys::on(vec!["nonexistent".to_string()]);
let result = join(
&Value::DataFrame(left_df),
&Value::DataFrame(right_df),
&keys,
&JoinOptions::default(),
);
assert!(result.is_err()); }
// Every canonical join type name parses; an unknown name is rejected.
#[test]
fn test_join_type_parsing() {
assert_eq!(JoinType::from_str("inner").unwrap(), JoinType::Inner);
assert_eq!(JoinType::from_str("left").unwrap(), JoinType::Left);
assert_eq!(JoinType::from_str("right").unwrap(), JoinType::Right);
assert_eq!(JoinType::from_str("outer").unwrap(), JoinType::Outer);
assert_eq!(JoinType::from_str("full").unwrap(), JoinType::Outer);
assert_eq!(JoinType::from_str("cross").unwrap(), JoinType::Cross);
assert_eq!(JoinType::from_str("semi").unwrap(), JoinType::Semi);
assert_eq!(JoinType::from_str("anti").unwrap(), JoinType::Anti);
assert!(JoinType::from_str("invalid").is_err());
}
// Pins the validation mapping, including `None` -> Polars `ManyToOne`.
#[test]
fn test_join_validation() {
assert_eq!(
JoinValidation::None.to_polars(),
polars::prelude::JoinValidation::ManyToOne
);
assert_eq!(
JoinValidation::OneToMany.to_polars(),
polars::prelude::JoinValidation::OneToMany
);
assert_eq!(
JoinValidation::ManyToOne.to_polars(),
polars::prelude::JoinValidation::ManyToOne
);
assert_eq!(
JoinValidation::OneToOne.to_polars(),
polars::prelude::JoinValidation::OneToOne
);
}
// Accessor behaviour of both `JoinKeys` constructors.
#[test]
fn test_join_keys_methods() {
let keys = JoinKeys::on(vec!["a".to_string(), "b".to_string()]);
assert_eq!(keys.left_columns(), &["a", "b"]);
assert_eq!(keys.right_columns(), &["a", "b"]);
let keys = JoinKeys::left_right(vec!["la".to_string()], vec!["ra".to_string()]);
assert_eq!(keys.left_columns(), &["la"]);
assert_eq!(keys.right_columns(), &["ra"]);
}
// Default options: inner join, "_right" suffix, no validation, no sorting.
#[test]
fn test_join_options_default() {
let options = JoinOptions::default();
assert_eq!(options.join_type, JoinType::Inner);
assert_eq!(options.suffix, "_right");
assert_eq!(options.validate, JoinValidation::None);
assert!(!options.sort);
}
}