use crate::error::{DataFusionError, Result};
use arrow::datatypes::{Field, Schema};
use std::collections::HashSet;
/// The kind of join to perform.
///
/// When the join's output schema is built, the variant decides which input
/// gives up its copy of a join key that has the same name on both sides:
/// `Inner` and `Left` drop the right-side copy, `Right` drops the left-side one.
#[derive(Debug, Copy, Clone)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
}
/// The `on` clause of a join: each pair names a column of the left input
/// (`.0`) and a column of the right input (`.1`) that the join is performed on.
pub type JoinOn = [(String, String)];
/// Checks that `on` is a valid join clause for inputs with schemas `left` and
/// `right`, comparing columns by field name only.
///
/// # Errors
///
/// Returns a plan error when `on` is empty, when it references a column that
/// the corresponding input does not have, or when the two inputs have
/// clashing column names.
pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result<()> {
    // Validation only needs the field names of each schema.
    let column_names = |schema: &Schema| -> HashSet<String> {
        schema
            .fields()
            .iter()
            .map(|field| field.name().to_owned())
            .collect()
    };
    check_join_set_is_valid(&column_names(left), &column_names(right), on)
}
/// Validates a join's `on` clause against the column names of its two inputs.
///
/// `left` and `right` are the column names of the left and right inputs; each
/// pair in `on` names a left column (`.0`) and a right column (`.1`).
///
/// # Errors
///
/// Returns [`DataFusionError::Plan`] when:
/// * `on` is empty;
/// * a left key of `on` is not in `left`, or a right key is not in `right`;
/// * a column name exists in both inputs without being joined to itself by a
///   `(name, name)` pair of `on`. [`build_join_schema`] only keeps a single
///   copy of such same-name keys, so any other shared name would appear twice
///   in the join schema.
fn check_join_set_is_valid(
    left: &HashSet<String>,
    right: &HashSet<String>,
    on: &JoinOn,
) -> Result<()> {
    if on.is_empty() {
        return Err(DataFusionError::Plan(
            "The 'on' clause of a join cannot be empty".to_string(),
        ));
    }

    // Key columns referenced by `on` that the corresponding input lacks.
    let left_missing: HashSet<&String> = on
        .iter()
        .map(|(l, _)| l)
        .filter(|name| !left.contains(*name))
        .collect();
    let right_missing: HashSet<&String> = on
        .iter()
        .map(|(_, r)| r)
        .filter(|name| !right.contains(*name))
        .collect();
    if !left_missing.is_empty() || !right_missing.is_empty() {
        return Err(DataFusionError::Plan(format!(
            "The left or right side of the join does not have all columns on \"on\": \nMissing on the left: {:?}\nMissing on the right: {:?}",
            left_missing,
            right_missing,
        )));
    }

    // Keys joined under the same name on both sides; the join schema keeps a
    // single copy of each of these.
    let same_name_keys: HashSet<&String> = on
        .iter()
        .filter(|(l, r)| l == r)
        .map(|(_, r)| r)
        .collect();

    // Every other name present in both inputs would be duplicated in the join
    // schema. Checking only right columns that are not `on` keys is not
    // enough: with `on = [("a", "b")]` and a left input that also has a `b`,
    // both `b` columns would end up in the output.
    let collisions: HashSet<&String> = left
        .intersection(right)
        .filter(|name| !same_name_keys.contains(name))
        .collect();
    if !collisions.is_empty() {
        return Err(DataFusionError::Plan(format!(
            "The left schema and the right schema have the following columns with the same name without being on the ON statement: {:?}. Consider aliasing them.",
            collisions,
        )));
    }

    Ok(())
}
/// Builds the output schema of a join between `left` and `right`.
///
/// The result always lists the left input's fields first, followed by the
/// right input's. A key of `on` that uses the same column name on both sides
/// (`(l, r)` with `l == r`) would otherwise appear twice, so one copy is
/// removed: `Inner` and `Left` joins drop the right-side copy, `Right` joins
/// drop the left-side copy. Keys whose names differ between the two sides are
/// kept from both inputs.
pub fn build_join_schema(
    left: &Schema,
    right: &Schema,
    on: &JoinOn,
    join_type: &JoinType,
) -> Schema {
    // Names used identically on both sides of an `on` pair.
    let shared_keys: HashSet<String> = on
        .iter()
        .filter(|(l, r)| l == r)
        .map(|(_, r)| r.clone())
        .collect();
    let is_shared_key = |field: &&Field| shared_keys.contains(field.name());

    let fields: Vec<Field> = match join_type {
        JoinType::Inner | JoinType::Left => left
            .fields()
            .iter()
            .chain(right.fields().iter().filter(|f| !is_shared_key(f)))
            .cloned()
            .collect(),
        JoinType::Right => left
            .fields()
            .iter()
            .filter(|f| !is_shared_key(f))
            .chain(right.fields().iter())
            .cloned()
            .collect(),
    };
    Schema::new(fields)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string-literal column names into an owned set.
    fn owned_set(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    /// Runs `check_join_set_is_valid` against plain `&str` fixtures.
    fn check(left: &[&str], right: &[&str], on: &[(&str, &str)]) -> Result<()> {
        let pairs: Vec<(String, String)> = on
            .iter()
            .map(|&(l, r)| (String::from(l), String::from(r)))
            .collect();
        check_join_set_is_valid(&owned_set(left), &owned_set(right), &pairs)
    }

    #[test]
    fn check_valid() -> Result<()> {
        // `a` is joined to itself, so both inputs may contain it.
        check(&["a", "b1"], &["a", "b2"], &[("a", "a")])
    }

    #[test]
    fn check_not_in_right() {
        // The right input has no `a` column.
        assert!(check(&["a", "b"], &["b"], &[("a", "a")]).is_err());
    }

    #[test]
    fn check_not_in_left() {
        // The left input has no `a` column.
        assert!(check(&["b"], &["a"], &[("a", "a")]).is_err());
    }

    #[test]
    fn check_collision() {
        // `a` appears in both inputs, but `on` pairs it with `b` rather than
        // with itself.
        assert!(check(&["a", "c"], &["a", "b"], &[("a", "b")]).is_err());
    }

    #[test]
    fn check_in_right() {
        // Distinctly named keys with no other shared columns.
        assert!(check(&["a", "c"], &["b"], &[("a", "b")]).is_ok());
    }
}