use datafusion_expr::Expr;
use indexmap::{Equivalent, IndexSet};
/// An insertion-ordered set of join-key expression pairs.
///
/// Each entry is a `(left, right)` pair of [`Expr`]s. Lookups and insertions
/// performed through the methods of this type treat `(left, right)` and
/// `(right, left)` as the same key, so a pair is stored at most once
/// regardless of the order in which its two sides are supplied.
///
/// `Default` yields the same empty set as [`JoinKeySet::new`].
#[derive(Debug, Default)]
pub struct JoinKeySet {
    /// The stored pairs, kept in the order in which they were first inserted.
    inner: IndexSet<(Expr, Expr)>,
}
impl JoinKeySet {
    /// Creates a new, empty set.
    pub fn new() -> Self {
        let inner = IndexSet::new();
        Self { inner }
    }

    /// Returns `true` if the set holds the pair `(left, right)` or its
    /// mirror image `(right, left)`.
    pub fn contains(&self, left: &Expr, right: &Expr) -> bool {
        let forward = ExprPair::new(left, right);
        if self.inner.contains(&forward) {
            return true;
        }
        let reversed = ExprPair::new(right, left);
        self.inner.contains(&reversed)
    }

    /// Inserts the pair `(left, right)` unless it (or its mirror image) is
    /// already present.
    ///
    /// Returns `true` if the pair was added, `false` if it was already in
    /// the set. The expressions are cloned only when the pair is added.
    pub fn insert(&mut self, left: &Expr, right: &Expr) -> bool {
        if self.contains(left, right) {
            return false;
        }
        self.inner.insert((left.clone(), right.clone()));
        true
    }

    /// Same as [`Self::insert`], but takes ownership of the expressions so
    /// no clone is needed.
    ///
    /// Returns `true` if the pair was added, `false` if it was already in
    /// the set (in either order).
    pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool {
        let is_new = !self.contains(&left, &right);
        if is_new {
            self.inner.insert((left, right));
        }
        is_new
    }

    /// Inserts every pair yielded by `iter`, in iteration order.
    ///
    /// Returns `true` if at least one pair was newly added.
    pub fn insert_all<'a>(
        &mut self,
        iter: impl IntoIterator<Item = &'a (Expr, Expr)>,
    ) -> bool {
        iter.into_iter().fold(false, |inserted, (left, right)| {
            // `insert` is evaluated first so every pair is always attempted.
            self.insert(left, right) || inserted
        })
    }

    /// Owned counterpart of [`Self::insert_all`]: inserts every pair yielded
    /// by `iter` without cloning.
    ///
    /// Returns `true` if at least one pair was newly added.
    pub fn insert_all_owned(
        &mut self,
        iter: impl IntoIterator<Item = (Expr, Expr)>,
    ) -> bool {
        iter.into_iter().fold(false, |inserted, (left, right)| {
            // `insert_owned` is evaluated first so every pair is always attempted.
            self.insert_owned(left, right) || inserted
        })
    }

    /// Inserts into `self` every pair of `s1` that is also contained in `s2`
    /// (in either order), visiting `s1` in its iteration order.
    pub fn insert_intersection(&mut self, s1: &JoinKeySet, s2: &JoinKeySet) {
        let common = s1
            .inner
            .iter()
            .filter(|(left, right)| s2.contains(left, right));
        self.insert_all(common);
    }

    /// Returns `true` if the set holds no pairs.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Returns the number of pairs in the set (only compiled for tests).
    #[cfg(test)]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns an iterator over the stored pairs as `(&left, &right)`.
    pub fn iter(&self) -> impl Iterator<Item = (&Expr, &Expr)> {
        self.inner.iter().map(|pair| (&pair.0, &pair.1))
    }
}
/// A borrowed `(left, right)` pair of expressions, used as a lookup key so
/// that the set can be queried without cloning either [`Expr`].
///
/// NOTE: lookups rely on the derived `Hash` of this type producing the same
/// hash as the owned `(Expr, Expr)` tuple stored in the set, so the field
/// order here must mirror the tuple's.
#[derive(Debug, Eq, PartialEq, Hash)]
struct ExprPair<'a>(&'a Expr, &'a Expr);

impl<'a> ExprPair<'a> {
    /// Builds a lookup key borrowing `left` and `right`.
    fn new(left: &'a Expr, right: &'a Expr) -> Self {
        ExprPair(left, right)
    }
}
/// Lets a borrowed [`ExprPair`] be compared against the owned `(Expr, Expr)`
/// entries of the set.
impl Equivalent<(Expr, Expr)> for ExprPair<'_> {
    /// Returns `true` when both sides match position-wise (no swapping).
    fn equivalent(&self, other: &(Expr, Expr)) -> bool {
        let (other_left, other_right) = other;
        self.0 == other_left && self.1 == other_right
    }
}
#[cfg(test)]
mod test {
    use crate::join_key_set::JoinKeySet;
    use datafusion_expr::{col, Expr};

    /// Asserts that iterating over `set` yields exactly `expected`, in order.
    fn assert_contents(set: &JoinKeySet, expected: Vec<(&Expr, &Expr)>) {
        let contents = set.iter().collect::<Vec<_>>();
        assert_eq!(contents, expected);
    }

    #[test]
    fn test_insert() {
        let (a, b, c) = (col("a"), col("b"), col("c"));
        let mut set = JoinKeySet::new();
        assert!(set.is_empty());

        // a brand new pair is accepted
        assert!(set.insert(&a, &b));
        assert!(!set.is_empty());

        // the identical pair is rejected
        assert!(!set.insert(&a, &b));
        assert_eq!(set.len(), 1);

        // the mirrored pair is rejected as well
        assert!(!set.insert(&b, &a));
        assert_eq!(set.len(), 1);

        // a different pair is accepted
        assert!(set.insert(&a, &c));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_insert_owned() {
        let (a, b, c) = (col("a"), col("b"), col("c"));
        let mut set = JoinKeySet::new();
        assert!(set.insert_owned(col("a"), col("b")));

        assert!(set.contains(&a, &b));
        assert!(set.contains(&b, &a));
        assert!(!set.contains(&a, &c));
    }

    #[test]
    fn test_contains() {
        let (a, b, c) = (col("a"), col("b"), col("c"));
        let mut set = JoinKeySet::new();

        assert!(set.insert(&a, &b));
        assert!(set.contains(&a, &b));
        assert!(set.contains(&b, &a));
        assert!(!set.contains(&a, &c));

        assert!(set.insert(&a, &c));
        assert!(set.contains(&a, &c));
        assert!(set.contains(&c, &a));
    }

    #[test]
    fn test_iterator() {
        let (a, b, c) = (col("a"), col("b"), col("c"));
        let mut set = JoinKeySet::new();

        set.insert(&c, &a);
        set.insert(&b, &c);
        // mirror of the first pair: must not show up in the output
        set.insert(&a, &c);

        assert_contents(&set, vec![(&c, &a), (&b, &c)]);
    }

    #[test]
    fn test_insert_intersection() {
        let (a, b, c, d) = (col("a"), col("b"), col("c"), col("d"));
        let (x, y) = (col("x"), col("y"));

        let mut set1 = JoinKeySet::new();
        set1.insert(&a, &b);
        set1.insert(&b, &c);
        set1.insert(&c, &d);

        let mut set2 = JoinKeySet::new();
        set2.insert(&a, &a);
        set2.insert(&b, &b);
        set2.insert(&b, &c);
        set2.insert(&d, &c);

        // the pre-existing entry must be kept ahead of the intersection
        let mut set = JoinKeySet::new();
        set.insert(&x, &y);
        set.insert_intersection(&set1, &set2);

        let expected = vec![(&x, &y), (&b, &c), (&c, &d)];
        assert_contents(&set, expected);
    }

    #[test]
    fn test_insert_all() {
        let (a, b, c) = (col("a"), col("b"), col("c"));
        let pairs = [
            (col("a"), col("b")),
            (col("b"), col("c")),
            (col("b"), col("a")),
        ];

        let mut set = JoinKeySet::new();
        set.insert_all(&pairs);

        assert_eq!(set.len(), 2);
        assert!(set.contains(&a, &b));
        assert!(set.contains(&b, &c));
        assert!(set.contains(&b, &a));
        assert!(!set.contains(&a, &c));
    }

    #[test]
    fn test_insert_all_owned() {
        let (a, b, c) = (col("a"), col("b"), col("c"));
        let pairs = vec![
            (col("a"), col("b")),
            (col("b"), col("c")),
            (col("b"), col("a")),
        ];

        let mut set = JoinKeySet::new();
        set.insert_all_owned(pairs);

        assert_eq!(set.len(), 2);
        assert!(set.contains(&a, &b));
        assert!(set.contains(&b, &c));
        assert!(set.contains(&b, &a));
        assert!(!set.contains(&a, &c));
    }
}