use crate::core::error::{Error, Result};
use crate::dataframe::base::DataFrame;
use crate::series::Series;
use std::collections::HashMap;
/// Selects which rows `merge` keeps when a join key is present on only one side.
///
/// Rows whose key occurs in both frames are emitted for every variant; the
/// variants differ only in how unmatched rows are treated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
/// Keep only rows whose key occurs in both the left and the right frame.
Inner,
/// Additionally keep left rows that have no matching key in the right frame.
Left,
/// Additionally keep right rows that have no matching key in the left frame.
Right,
/// Keep unmatched rows from both the left and the right frame.
Outer,
}
pub fn merge(
left: &DataFrame,
right: &DataFrame,
on: &str,
how: JoinType,
suffixes: (&str, &str),
) -> Result<DataFrame> {
if !left.contains_column(on) {
return Err(Error::InvalidValue(format!(
"Join column '{}' not found in left DataFrame",
on
)));
}
if !right.contains_column(on) {
return Err(Error::InvalidValue(format!(
"Join column '{}' not found in right DataFrame",
on
)));
}
let left_join_values = if let Ok(vals) = left.get_column_numeric_values(on) {
vals.into_iter()
.map(|v| v.to_bits().to_string())
.collect::<Vec<_>>()
} else if let Ok(vals) = left.get_column_string_values(on) {
vals
} else {
return Err(Error::InvalidValue(format!(
"Cannot read join column '{}' from left DataFrame",
on
)));
};
let right_join_values = if let Ok(vals) = right.get_column_numeric_values(on) {
vals.into_iter()
.map(|v| v.to_bits().to_string())
.collect::<Vec<_>>()
} else if let Ok(vals) = right.get_column_string_values(on) {
vals
} else {
return Err(Error::InvalidValue(format!(
"Cannot read join column '{}' from right DataFrame",
on
)));
};
let mut right_index: HashMap<String, Vec<usize>> = HashMap::new();
for (i, val) in right_join_values.iter().enumerate() {
right_index
.entry(val.clone())
.or_insert_with(Vec::new)
.push(i);
}
let mut matched_pairs: Vec<(Option<usize>, Option<usize>)> = Vec::new();
let mut left_matched = vec![false; left_join_values.len()];
let mut right_matched = vec![false; right_join_values.len()];
for (left_idx, left_val) in left_join_values.iter().enumerate() {
if let Some(right_indices) = right_index.get(left_val) {
for &right_idx in right_indices {
matched_pairs.push((Some(left_idx), Some(right_idx)));
left_matched[left_idx] = true;
right_matched[right_idx] = true;
}
} else if matches!(how, JoinType::Left | JoinType::Outer) {
matched_pairs.push((Some(left_idx), None));
left_matched[left_idx] = true;
}
}
if matches!(how, JoinType::Right | JoinType::Outer) {
for (right_idx, matched) in right_matched.iter().enumerate() {
if !matched {
matched_pairs.push((None, Some(right_idx)));
}
}
}
let mut result = DataFrame::new();
let left_cols = left.column_names();
let right_cols = right.column_names();
let mut overlapping: Vec<String> = Vec::new();
for col in &right_cols {
if col != on && left_cols.contains(col) {
overlapping.push(col.clone());
}
}
for col_name in &left_cols {
if let Ok(values) = left.get_column_numeric_values(col_name) {
let merged: Vec<f64> = if col_name == on {
let right_values = right
.get_column_numeric_values(on)
.expect("test should succeed");
matched_pairs
.iter()
.map(|(left_idx, right_idx)| {
left_idx
.map(|i| values.get(i).copied().unwrap_or(f64::NAN))
.or_else(|| {
right_idx.map(|i| right_values.get(i).copied().unwrap_or(f64::NAN))
})
.unwrap_or(f64::NAN)
})
.collect()
} else {
matched_pairs
.iter()
.map(|(left_idx, _)| {
left_idx
.map(|i| values.get(i).copied().unwrap_or(f64::NAN))
.unwrap_or(f64::NAN)
})
.collect()
};
result.add_column(
col_name.clone(),
Series::new(merged, Some(col_name.clone()))?,
)?;
} else if let Ok(values) = left.get_column_string_values(col_name) {
let merged: Vec<String> = if col_name == on {
let right_values = right
.get_column_string_values(on)
.expect("test should succeed");
matched_pairs
.iter()
.map(|(left_idx, right_idx)| {
left_idx
.and_then(|i| values.get(i).cloned())
.or_else(|| right_idx.and_then(|i| right_values.get(i).cloned()))
.unwrap_or_else(|| "".to_string())
})
.collect()
} else {
matched_pairs
.iter()
.map(|(left_idx, _)| {
left_idx
.and_then(|i| values.get(i).cloned())
.unwrap_or_else(|| "".to_string())
})
.collect()
};
result.add_column(
col_name.clone(),
Series::new(merged, Some(col_name.clone()))?,
)?;
}
}
for col_name in &right_cols {
if col_name == on {
continue;
}
let final_name = if overlapping.contains(col_name) {
format!("{}{}", col_name, suffixes.1)
} else {
col_name.clone()
};
if let Ok(values) = right.get_column_numeric_values(col_name) {
let merged: Vec<f64> = matched_pairs
.iter()
.map(|(_, right_idx)| {
right_idx
.map(|i| values.get(i).copied().unwrap_or(f64::NAN))
.unwrap_or(f64::NAN)
})
.collect();
result.add_column(
final_name.clone(),
Series::new(merged, Some(final_name.clone()))?,
)?;
} else if let Ok(values) = right.get_column_string_values(col_name) {
let merged: Vec<String> = matched_pairs
.iter()
.map(|(_, right_idx)| {
right_idx
.and_then(|i| values.get(i).cloned())
.unwrap_or_else(|| "".to_string())
})
.collect();
result.add_column(
final_name.clone(),
Series::new(merged, Some(final_name.clone()))?,
)?;
}
}
if !overlapping.is_empty() {
let mut rename_map = HashMap::new();
for col in &overlapping {
rename_map.insert(col.clone(), format!("{}{}", col, suffixes.0));
}
let mut renamed_df = DataFrame::new();
for col_name in result.column_names() {
let new_name = rename_map.get(&col_name).unwrap_or(&col_name).clone();
if let Ok(vals) = result.get_column_numeric_values(&col_name) {
renamed_df
.add_column(new_name.clone(), Series::new(vals, Some(new_name.clone()))?)?;
} else if let Ok(vals) = result.get_column_string_values(&col_name) {
renamed_df
.add_column(new_name.clone(), Series::new(vals, Some(new_name.clone()))?)?;
}
}
result = renamed_df;
}
Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `df` extended by a string column `name` holding `values`.
    fn with_text(mut df: DataFrame, name: &str, values: &[&str]) -> DataFrame {
        let data: Vec<String> = values.iter().map(|v| v.to_string()).collect();
        let series = Series::new(data, Some(name.to_string())).expect("test should succeed");
        df.add_column(name.to_string(), series)
            .expect("test should succeed");
        df
    }

    /// Returns `df` extended by a numeric column `name` holding `values`.
    fn with_numbers(mut df: DataFrame, name: &str, values: &[f64]) -> DataFrame {
        let series =
            Series::new(values.to_vec(), Some(name.to_string())).expect("test should succeed");
        df.add_column(name.to_string(), series)
            .expect("test should succeed");
        df
    }

    /// Reads column `name` of `df` as strings, failing the test on error.
    fn text_of(df: &DataFrame, name: &str) -> Vec<String> {
        df.get_column_string_values(name)
            .expect("test should succeed")
    }

    /// Reads column `name` of `df` as numbers, failing the test on error.
    fn numbers_of(df: &DataFrame, name: &str) -> Vec<f64> {
        df.get_column_numeric_values(name)
            .expect("test should succeed")
    }

    /// Left fixture: keys A-D with `value1` = 1..4.
    fn create_left_df() -> DataFrame {
        let df = with_text(DataFrame::new(), "key", &["A", "B", "C", "D"]);
        with_numbers(df, "value1", &[1.0, 2.0, 3.0, 4.0])
    }

    /// Right fixture: keys B-E with `value2` = 20..50.
    fn create_right_df() -> DataFrame {
        let df = with_text(DataFrame::new(), "key", &["B", "C", "D", "E"]);
        with_numbers(df, "value2", &[20.0, 30.0, 40.0, 50.0])
    }

    #[test]
    fn test_merge_inner() {
        let (left, right) = (create_left_df(), create_right_df());
        let joined = merge(&left, &right, "key", JoinType::Inner, ("_x", "_y"))
            .expect("test should succeed");

        assert_eq!(joined.row_count(), 3);
        assert_eq!(text_of(&joined, "key"), vec!["B", "C", "D"]);
        assert_eq!(numbers_of(&joined, "value1"), vec![2.0, 3.0, 4.0]);
        assert_eq!(numbers_of(&joined, "value2"), vec![20.0, 30.0, 40.0]);
    }

    #[test]
    fn test_merge_left() {
        let (left, right) = (create_left_df(), create_right_df());
        let joined = merge(&left, &right, "key", JoinType::Left, ("_x", "_y"))
            .expect("test should succeed");

        assert_eq!(joined.row_count(), 4);
        assert_eq!(text_of(&joined, "key"), vec!["A", "B", "C", "D"]);
        assert_eq!(numbers_of(&joined, "value1"), vec![1.0, 2.0, 3.0, 4.0]);

        // Key "A" has no partner on the right, so its value2 is missing.
        let value2 = numbers_of(&joined, "value2");
        assert!(value2[0].is_nan());
        assert_eq!(&value2[1..4], &[20.0, 30.0, 40.0]);
    }

    #[test]
    fn test_merge_right() {
        let (left, right) = (create_left_df(), create_right_df());
        let joined = merge(&left, &right, "key", JoinType::Right, ("_x", "_y"))
            .expect("test should succeed");

        assert_eq!(joined.row_count(), 4);
        assert_eq!(text_of(&joined, "key"), vec!["B", "C", "D", "E"]);

        // Key "E" has no partner on the left, so its value1 is missing.
        let value1 = numbers_of(&joined, "value1");
        assert_eq!(&value1[0..3], &[2.0, 3.0, 4.0]);
        assert!(value1[3].is_nan());

        assert_eq!(numbers_of(&joined, "value2"), vec![20.0, 30.0, 40.0, 50.0]);
    }

    #[test]
    fn test_merge_outer() {
        let (left, right) = (create_left_df(), create_right_df());
        let joined = merge(&left, &right, "key", JoinType::Outer, ("_x", "_y"))
            .expect("test should succeed");

        assert_eq!(joined.row_count(), 5);
        assert_eq!(text_of(&joined, "key"), vec!["A", "B", "C", "D", "E"]);

        let value1 = numbers_of(&joined, "value1");
        assert_eq!(&value1[0..4], &[1.0, 2.0, 3.0, 4.0]);
        assert!(value1[4].is_nan());

        let value2 = numbers_of(&joined, "value2");
        assert!(value2[0].is_nan());
        assert_eq!(&value2[1..5], &[20.0, 30.0, 40.0, 50.0]);
    }

    #[test]
    fn test_merge_with_overlapping_columns() {
        let left = with_numbers(
            with_text(DataFrame::new(), "key", &["A", "B"]),
            "value",
            &[1.0, 2.0],
        );
        let right = with_numbers(
            with_text(DataFrame::new(), "key", &["A", "B"]),
            "value",
            &[10.0, 20.0],
        );

        let joined = merge(&left, &right, "key", JoinType::Inner, ("_left", "_right"))
            .expect("test should succeed");

        assert!(joined.contains_column("value_left"));
        assert!(joined.contains_column("value_right"));
        assert_eq!(numbers_of(&joined, "value_left"), vec![1.0, 2.0]);
        assert_eq!(numbers_of(&joined, "value_right"), vec![10.0, 20.0]);
    }

    #[test]
    fn test_merge_numeric_key() {
        let left = with_text(
            with_numbers(DataFrame::new(), "id", &[1.0, 2.0, 3.0]),
            "name",
            &["Alice", "Bob", "Charlie"],
        );
        let right = with_numbers(
            with_numbers(DataFrame::new(), "id", &[2.0, 3.0, 4.0]),
            "score",
            &[85.0, 90.0, 95.0],
        );

        let joined = merge(&left, &right, "id", JoinType::Inner, ("_x", "_y"))
            .expect("test should succeed");

        assert_eq!(joined.row_count(), 2);
        assert_eq!(text_of(&joined, "name"), vec!["Bob", "Charlie"]);
        assert_eq!(numbers_of(&joined, "score"), vec![85.0, 90.0]);
    }
}