1use crate::{CompareTypes, TypeRelation, TypeSet};
2use serde::{Deserialize, Serialize};
3use std::fmt::Display;
4
5#[allow(unused_imports)]
6use crate::SyntaxShape;
7
8#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Ord, PartialOrd)]
44#[serde(transparent)]
45pub struct CollectionColumns<T> {
46 fields: Box<[(String, T)]>,
47}
48
49impl<T> CollectionColumns<T> {
50 pub fn map<U>(&self, f: impl Fn(&T) -> U) -> CollectionColumns<U> {
51 self.iter().map(|(k, v)| (k.clone(), f(v))).collect()
52 }
53
54 pub fn iter(&self) -> impl Iterator<Item = &(String, T)> {
55 self.into_iter()
56 }
57
58 pub fn is_empty(&self) -> bool {
59 self.fields.is_empty()
60 }
61
62 pub fn len(&self) -> usize {
63 self.fields.len()
64 }
65}
66
67impl<T> CollectionColumns<T> {
68 pub fn new(fields: Box<[(String, T)]>) -> Self {
69 Self { fields }
70 }
71
72 pub fn get<'s>(&'s self, key: &'_ str) -> Option<&'s T> {
73 self.iter()
74 .find(|(name, _)| name == key)
75 .map(|(_, val)| val)
76 }
77}
78
79impl<T> IntoIterator for CollectionColumns<T> {
80 type Item = (String, T);
81 type IntoIter = std::vec::IntoIter<Self::Item>;
82
83 fn into_iter(self) -> Self::IntoIter {
84 self.fields.into_iter()
85 }
86}
87
88impl<'a, T> IntoIterator for &'a CollectionColumns<T> {
89 type Item = &'a (String, T);
90 type IntoIter = std::slice::Iter<'a, (String, T)>;
91
92 fn into_iter(self) -> Self::IntoIter {
93 self.fields.iter()
94 }
95}
96
97impl<T> FromIterator<(String, T)> for CollectionColumns<T> {
98 fn from_iter<I: IntoIterator<Item = (String, T)>>(iter: I) -> Self {
99 Self {
100 fields: iter.into_iter().collect(),
101 }
102 }
103}
104
105impl<T> From<Vec<(String, T)>> for CollectionColumns<T> {
106 fn from(value: Vec<(String, T)>) -> Self {
107 value.into_boxed_slice().into()
108 }
109}
110
111impl<T> From<Box<[(String, T)]>> for CollectionColumns<T> {
112 fn from(value: Box<[(String, T)]>) -> Self {
113 Self { fields: value }
114 }
115}
116
117impl<T> CollectionColumns<T>
118where
119 T: TypeSet + Clone,
120{
121 fn widen_fields(lhs: Box<[(String, T)]>, rhs: Box<[(String, T)]>) -> Box<[(String, T)]> {
122 if lhs.is_empty() || rhs.is_empty() {
123 return [].into();
124 }
125
126 let (small, big) = if lhs.len() <= rhs.len() {
128 (lhs, rhs)
129 } else {
130 (rhs, lhs)
131 };
132
133 const MAP_THRESH: usize = 16;
134 if big.len() > MAP_THRESH {
135 use std::collections::HashMap;
136 let mut big_map: HashMap<String, T> = big.into_iter().collect();
137 small
138 .into_iter()
139 .filter_map(|(col, typ)| big_map.remove(&col).map(|b_typ| (col, typ.union(b_typ))))
140 .collect()
141 } else {
142 small
143 .into_iter()
144 .filter_map(|(col, typ)| {
145 big.iter()
146 .find_map(|(b_col, b_typ)| (&col == b_col).then(|| b_typ.clone()))
147 .map(|b_typ| (col, typ.union(b_typ)))
148 })
149 .collect()
150 }
151 }
152}
153
154fn element_comparison_helper<T, F, O>(
155 lhs: &CollectionColumns<T>,
156 rhs: &CollectionColumns<T>,
157 f: F,
158) -> impl Iterator<Item = Option<O>>
159where
160 T: CompareTypes,
161 F: Fn(&T, &T) -> Option<O>,
162{
163 lhs.iter()
164 .map(move |(lhs_key, lhs_ty)| match rhs.get(lhs_key) {
165 Some(rhs_ty) => f(lhs_ty, rhs_ty),
166 None => None,
170 })
171}
172
173impl<T> CompareTypes for CollectionColumns<T>
174where
175 T: CompareTypes,
176{
177 fn compare_types(&self, other: &Self) -> Option<TypeRelation> {
178 match (self.is_empty(), other.is_empty()) {
187 (true, true) => return Some(TypeRelation::Equal),
188 (true, false) => return Some(TypeRelation::Supertype),
189 (false, true) => return Some(TypeRelation::Subtype),
190 (false, false) => (),
191 }
192
193 let (flipped, eq, (lhs, rhs)) = match self.fields.len().cmp(&other.fields.len()) {
195 std::cmp::Ordering::Less => (false, false, (self, other)),
196 std::cmp::Ordering::Equal => (false, true, (self, other)),
197 std::cmp::Ordering::Greater => (true, false, (other, self)),
198 };
199
200 let start = match eq {
201 true => TypeRelation::Equal,
202 false => TypeRelation::Supertype,
203 };
204
205 let out = element_comparison_helper(lhs, rhs, |lhs_ty, rhs_ty| {
206 if lhs_ty.is_any() || rhs_ty.is_any() {
207 Some(TypeRelation::Equal)
209 } else {
210 lhs_ty.compare_types(rhs_ty)
211 }
212 })
213 .try_fold(start, |acc, e| acc.combine(e?))?;
214
215 Some(match flipped {
216 true => out.reverse(),
217 false => out,
218 })
219 }
220
221 fn is_any(&self) -> bool {
223 self.fields.is_empty()
224 }
225
226 fn is_assignable_to(&self, dst: &Self) -> bool {
227 let src = self;
228
229 (src.is_any() || dst.is_any())
230 || element_comparison_helper(dst, src, |dst_ty, src_ty| {
231 Some(src_ty.is_assignable_to(dst_ty))
232 })
233 .try_fold(true, |acc, e| Some(acc && (e?)))
234 .unwrap_or(false)
235 }
236}
237
238impl<T> TypeSet for CollectionColumns<T>
239where
240 T: TypeSet + Clone,
241{
242 fn union(self, other: Self) -> Self {
243 let Self {
244 fields: self_fields,
245 } = self;
246 let Self {
247 fields: other_fields,
248 } = other;
249
250 Self {
251 fields: Self::widen_fields(self_fields, other_fields),
252 }
253 }
254}
255
256impl<T> Default for CollectionColumns<T> {
257 fn default() -> Self {
258 Self {
259 fields: Default::default(),
260 }
261 }
262}
263
264impl<T> Display for CollectionColumns<T>
265where
266 T: Display,
267{
268 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
269 match self.fields.as_ref() {
270 [] => Ok(()),
271 [(name, shape), tail @ ..] => {
272 write!(f, "<{name}: {shape}")?;
273 for (name, shape) in tail {
274 write!(f, ", {name}: {shape}")?;
275 }
276
277 write!(f, ">")?;
278 Ok(())
279 }
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;
    use rstest::rstest;

    use super::*;
    use crate::Type;

    /// Turns borrowed `(name, Type)` pairs into an owned `CollectionColumns<Type>`.
    fn columns(pairs: impl IntoIterator<Item = (&'static str, Type)>) -> CollectionColumns<Type> {
        pairs
            .into_iter()
            .map(|(name, ty)| (String::from(name), ty))
            .collect()
    }

    #[rstest]
    #[case(Some(TypeRelation::Equal), [], [])]
    #[case(Some(TypeRelation::Equal),
        [("a", Type::Int)],
        [("a", Type::Int)],
    )]
    #[case(None,
        [("a", Type::Int)],
        [("b", Type::Int)],
    )]
    #[case(Some(TypeRelation::Supertype),
        [("a", Type::Int), ("b", Type::Int)],
        [("a", Type::Int), ("b", Type::Int), ("c", Type::Int)],
    )]
    #[case(None,
        [("name", Type::String), ("attrs", Type::list(Type::Any)), ("desc", Type::String)],
        [("attrs", Type::list(Type::String)), ("desc", Type::String)],
    )]
    fn relations(
        #[case] expected: Option<TypeRelation>,
        #[case] lhs: impl IntoIterator<Item = (&'static str, Type)>,
        #[case] rhs: impl IntoIterator<Item = (&'static str, Type)>,
    ) {
        let lhs = columns(lhs);
        let rhs = columns(rhs);

        // Comparing in the opposite direction must produce the reversed relation.
        assert_eq!(lhs.compare_types(&rhs), expected);
        assert_eq!(rhs.compare_types(&lhs), expected.map(TypeRelation::reverse));
    }

    #[rstest]
    #[case(true,
        [("name", Type::String), ("attrs", Type::list(Type::Any)), ("desc", Type::String)],
        [("attrs", Type::list(Type::String)), ("desc", Type::String)],
    )]
    fn is_assignable_to(
        #[case] expected: bool,
        #[case] src: impl IntoIterator<Item = (&'static str, Type)>,
        #[case] dst: impl IntoIterator<Item = (&'static str, Type)>,
    ) {
        let (src, dst) = (columns(src), columns(dst));

        assert_eq!(src.is_assignable_to(&dst), expected)
    }
}