use crate::{CompareTypes, TypeRelation, TypeSet};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
#[allow(unused_imports)]
use crate::SyntaxShape;
/// An ordered list of named columns, each paired with a value of type `T`.
///
/// `T` is presumably a type descriptor such as `Type` or `SyntaxShape` — TODO confirm
/// against callers. Columns keep the order in which they were supplied. The
/// [`CompareTypes`] and [`TypeSet`] implementations in this module treat an empty list as
/// "any columns".
///
/// Because of `#[serde(transparent)]`, the value (de)serializes exactly like its inner
/// `fields` slice.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Ord, PartialOrd)]
#[serde(transparent)]
pub struct CollectionColumns<T> {
    /// `(column name, column value)` pairs in their original order.
    fields: Box<[(String, T)]>,
}
impl<T> CollectionColumns<T> {
    /// Builds a new column list with the same names, converting every value with `f`.
    ///
    /// Column order is preserved.
    pub fn map<U>(&self, f: impl Fn(&T) -> U) -> CollectionColumns<U> {
        self.fields
            .iter()
            .map(|(name, value)| (name.clone(), f(value)))
            .collect()
    }

    /// Iterates over the `(name, value)` pairs by reference, in column order.
    pub fn iter(&self) -> impl Iterator<Item = &(String, T)> {
        self.fields.iter()
    }

    /// Returns `true` when there are no columns.
    pub fn is_empty(&self) -> bool {
        self.fields.is_empty()
    }

    /// Returns the number of columns.
    pub fn len(&self) -> usize {
        self.fields.len()
    }
}
impl<T> CollectionColumns<T> {
pub fn new(fields: Box<[(String, T)]>) -> Self {
Self { fields }
}
pub fn get<'s>(&'s self, key: &'_ str) -> Option<&'s T> {
self.iter()
.find(|(name, _)| name == key)
.map(|(_, val)| val)
}
}
impl<T> IntoIterator for CollectionColumns<T> {
    type Item = (String, T);
    type IntoIter = std::vec::IntoIter<Self::Item>;

    /// Consumes the columns, yielding owned `(name, value)` pairs in column order.
    fn into_iter(self) -> Self::IntoIter {
        let owned: Vec<(String, T)> = self.fields.into_vec();
        owned.into_iter()
    }
}
impl<'a, T> IntoIterator for &'a CollectionColumns<T> {
    type Item = &'a (String, T);
    type IntoIter = std::slice::Iter<'a, (String, T)>;

    /// Borrows the columns, yielding `&(name, value)` pairs in column order.
    fn into_iter(self) -> Self::IntoIter {
        let columns: &'a [(String, T)] = &self.fields;
        columns.iter()
    }
}
impl<T> FromIterator<(String, T)> for CollectionColumns<T> {
    /// Collects `(name, value)` pairs into a column list, keeping their order.
    fn from_iter<I: IntoIterator<Item = (String, T)>>(iter: I) -> Self {
        let collected: Vec<(String, T)> = iter.into_iter().collect();
        let fields = collected.into_boxed_slice();
        Self { fields }
    }
}
impl<T> From<Vec<(String, T)>> for CollectionColumns<T> {
    /// Converts a vector of `(name, value)` pairs, keeping their order.
    fn from(value: Vec<(String, T)>) -> Self {
        Self {
            fields: value.into_boxed_slice(),
        }
    }
}
impl<T> From<Box<[(String, T)]>> for CollectionColumns<T> {
fn from(value: Box<[(String, T)]>) -> Self {
Self { fields: value }
}
}
impl<T> CollectionColumns<T>
where
    T: TypeSet + Clone,
{
    /// Keeps only the columns present in both `lhs` and `rhs`, unioning the two types
    /// of every shared column.
    ///
    /// If either input is empty, the result is empty. Otherwise the result follows the
    /// column order of the shorter input (`lhs` when both have the same length). For each
    /// shared column, the shorter input's type is the receiver of [`TypeSet::union`].
    fn widen_fields(lhs: Box<[(String, T)]>, rhs: Box<[(String, T)]>) -> Box<[(String, T)]> {
        if lhs.is_empty() || rhs.is_empty() {
            return Box::default();
        }

        // Drive the scan from the shorter list; on a tie `lhs` remains the driver.
        let (driver, lookup) = if rhs.len() < lhs.len() {
            (rhs, lhs)
        } else {
            (lhs, rhs)
        };

        // Up to this many columns, a linear scan of `lookup` per driver column is used.
        // Above it, `lookup` is indexed by name first.
        const LINEAR_SCAN_MAX: usize = 16;

        if lookup.len() <= LINEAR_SCAN_MAX {
            driver
                .into_vec()
                .into_iter()
                .filter_map(|(name, ty)| {
                    let (_, other_ty) = lookup
                        .iter()
                        .find(|(other_name, _)| *other_name == name)?;
                    Some((name, ty.union(other_ty.clone())))
                })
                .collect()
        } else {
            let mut by_name: std::collections::HashMap<String, T> =
                lookup.into_vec().into_iter().collect();
            driver
                .into_vec()
                .into_iter()
                .filter_map(|(name, ty)| {
                    let other_ty = by_name.remove(&name)?;
                    Some((name, ty.union(other_ty)))
                })
                .collect()
        }
    }
}
/// Pairs every column of `lhs` with the same-named column of `rhs` and applies `f`.
///
/// Yields one item per `lhs` column, in `lhs` order. An item is `None` when `rhs` has no
/// column of that name; otherwise it is whatever `f(lhs_type, rhs_type)` returns.
/// Columns that exist only in `rhs` are never visited. The iterator is lazy, so `f` runs
/// only for the items that are actually pulled.
fn element_comparison_helper<T, F, O>(
    lhs: &CollectionColumns<T>,
    rhs: &CollectionColumns<T>,
    f: F,
) -> impl Iterator<Item = Option<O>>
where
    T: CompareTypes,
    F: Fn(&T, &T) -> Option<O>,
{
    lhs.iter()
        .map(move |(name, lhs_ty)| rhs.get(name).and_then(|rhs_ty| f(lhs_ty, rhs_ty)))
}
impl<T> CompareTypes for CollectionColumns<T>
where
    T: CompareTypes,
{
    /// Structurally relates two column lists, from `self`'s point of view.
    ///
    /// An empty list stands for "any columns". It is a supertype of every non-empty list,
    /// and two empty lists are equal.
    ///
    /// Otherwise the list with fewer columns is the candidate supertype:
    /// - Each of its columns must also exist in the other list.
    /// - The per-column relations are folded with [`TypeRelation::combine`]. The fold starts
    ///   from `Supertype`, or from `Equal` when both lists have the same length.
    /// - A missing column, or a pair of column types that cannot be related, yields `None`.
    /// - A column type for which `is_any()` holds relates as `Equal` to any other column type.
    fn compare_types(&self, other: &Self) -> Option<TypeRelation> {
        match (self.is_empty(), other.is_empty()) {
            (true, true) => return Some(TypeRelation::Equal),
            (true, false) => return Some(TypeRelation::Supertype),
            (false, true) => return Some(TypeRelation::Subtype),
            (false, false) => (),
        }

        // Drive the comparison from the side with fewer columns. Remember whether the
        // operands were swapped so the result can be flipped back to `self`'s view.
        let (flipped, start, lhs, rhs) = match self.fields.len().cmp(&other.fields.len()) {
            std::cmp::Ordering::Less => (false, TypeRelation::Supertype, self, other),
            std::cmp::Ordering::Equal => (false, TypeRelation::Equal, self, other),
            std::cmp::Ordering::Greater => (true, TypeRelation::Supertype, other, self),
        };

        let relation = element_comparison_helper(lhs, rhs, |lhs_ty, rhs_ty| {
            if lhs_ty.is_any() || rhs_ty.is_any() {
                Some(TypeRelation::Equal)
            } else {
                lhs_ty.compare_types(rhs_ty)
            }
        })
        .try_fold(start, |acc, e| acc.combine(e?))?;

        Some(if flipped { relation.reverse() } else { relation })
    }

    /// An empty column list stands for "any columns".
    fn is_any(&self) -> bool {
        self.is_empty()
    }

    /// Returns `true` when a value with these columns (`self`) may be assigned to a
    /// destination with the columns of `dst`.
    ///
    /// If either side is empty ("any columns"), the answer is `true`. Otherwise every
    /// column of `dst` must exist in `self`, and `self`'s type for that column must be
    /// assignable to `dst`'s. Extra columns in `self` are allowed.
    fn is_assignable_to(&self, dst: &Self) -> bool {
        let src = self;
        if src.is_any() || dst.is_any() {
            return true;
        }
        // `all` stops at the first missing or non-assignable column. The previous fold kept
        // checking every remaining column after the answer was already `false`.
        element_comparison_helper(dst, src, |dst_ty, src_ty| {
            Some(src_ty.is_assignable_to(dst_ty))
        })
        .all(|assignable| assignable == Some(true))
    }
}
impl<T> TypeSet for CollectionColumns<T>
where
    T: TypeSet + Clone,
{
    /// Widens two column lists by delegating to `widen_fields`.
    ///
    /// That keeps only the shared columns and unions each shared column's types. An empty
    /// ("any columns") input yields an empty result.
    fn union(self, other: Self) -> Self {
        let widened = Self::widen_fields(self.fields, other.fields);
        Self { fields: widened }
    }
}
impl<T> Default for CollectionColumns<T> {
    /// Returns an empty column list. Unlike a derived impl, this does not require `T: Default`.
    fn default() -> Self {
        let fields: Box<[(String, T)]> = Box::default();
        Self { fields }
    }
}
impl<T> Display for CollectionColumns<T>
where
    T: Display,
{
    /// Renders the columns as `<name: value, name: value>`.
    ///
    /// An empty column list renders as an empty string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.fields.is_empty() {
            return Ok(());
        }
        f.write_str("<")?;
        for (index, (name, shape)) in self.fields.iter().enumerate() {
            if index > 0 {
                f.write_str(", ")?;
            }
            write!(f, "{name}: {shape}")?;
        }
        f.write_str(">")
    }
}
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::Type;

    /// Builds a column list from `(&str, Type)` pairs, keeping their order.
    fn columns(items: impl IntoIterator<Item = (&'static str, Type)>) -> CollectionColumns<Type> {
        items
            .into_iter()
            .map(|(k, ty)| (k.to_owned(), ty))
            .collect()
    }

    #[rstest]
    #[case(Some(TypeRelation::Equal), [], [])]
    // An empty column list means "any columns" and is a supertype of everything.
    #[case(Some(TypeRelation::Supertype), [], [("a", Type::Int)])]
    #[case(Some(TypeRelation::Equal),
        [("a", Type::Int)],
        [("a", Type::Int)],
    )]
    // Columns are matched by name, so order does not matter.
    #[case(Some(TypeRelation::Equal),
        [("a", Type::Int), ("b", Type::Int)],
        [("b", Type::Int), ("a", Type::Int)],
    )]
    #[case(None,
        [("a", Type::Int)],
        [("b", Type::Int)],
    )]
    #[case(Some(TypeRelation::Supertype),
        [("a", Type::Int), ("b", Type::Int)],
        [("a", Type::Int), ("b", Type::Int), ("c", Type::Int)],
    )]
    #[case(None,
        [("name", Type::String), ("attrs", Type::list(Type::Any)), ("desc", Type::String)],
        [("attrs", Type::list(Type::String)), ("desc", Type::String)],
    )]
    fn relations(
        #[case] expected: Option<TypeRelation>,
        #[case] lhs: impl IntoIterator<Item = (&'static str, Type)>,
        #[case] rhs: impl IntoIterator<Item = (&'static str, Type)>,
    ) {
        let lhs = columns(lhs);
        let rhs = columns(rhs);
        assert_eq!(lhs.compare_types(&rhs), expected);
        assert_eq!(rhs.compare_types(&lhs), expected.map(TypeRelation::reverse));
    }

    #[rstest]
    #[case(true,
        [("name", Type::String), ("attrs", Type::list(Type::Any)), ("desc", Type::String)],
        [("attrs", Type::list(Type::String)), ("desc", Type::String)],
    )]
    // Either side being empty ("any columns") is always assignable.
    #[case(true, [], [("a", Type::Int)])]
    #[case(true, [("a", Type::Int)], [])]
    // Extra source columns are fine.
    #[case(true,
        [("a", Type::Int), ("b", Type::Int)],
        [("a", Type::Int)],
    )]
    // Every destination column must exist in the source.
    #[case(false,
        [("a", Type::Int)],
        [("a", Type::Int), ("b", Type::Int)],
    )]
    // Each shared column's type must itself be assignable.
    #[case(false,
        [("a", Type::String)],
        [("a", Type::Int)],
    )]
    fn is_assignable_to(
        #[case] expected: bool,
        #[case] src: impl IntoIterator<Item = (&'static str, Type)>,
        #[case] dst: impl IntoIterator<Item = (&'static str, Type)>,
    ) {
        let src = columns(src);
        let dst = columns(dst);
        assert_eq!(src.is_assignable_to(&dst), expected)
    }
}