1use linked_hash_map::LinkedHashMap;
2
/// Structural description of a value, as combined by `common_shape`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must always
/// include a wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq)]
pub enum Shape {
    /// No information yet. Merging `Bottom` with any shape yields that other
    /// shape, which makes it the starting value used by `fold_shapes`.
    Bottom,

    /// Fallback result when two shapes have nothing more specific in common.
    Any,

    /// The boxed shape, or nothing (a nullable / possibly missing value).
    Optional(Box<Shape>),

    /// Null without further information. Merging it with another shape makes
    /// that shape optional.
    Null,

    /// Boolean value.
    Bool,
    /// String value.
    StringT,
    /// Integer number. Merging it with `Floating` yields `Floating`.
    Integer,
    /// Floating point number.
    Floating,
    /// Sequence whose elements all share `elem_type`.
    VecT { elem_type: Box<Shape> },
    /// Record with named fields, kept in insertion order.
    Struct { fields: LinkedHashMap<String, Shape> },
    /// Fixed-length sequence with one shape per position.
    ///
    /// When two tuples of equal length are merged, their `u64` values are
    /// added together.
    /// NOTE(review): the `u64` is presumably a count of observed samples — confirm.
    Tuple(Vec<Shape>, u64),
    /// Map whose values all share `val_type`; no key type is recorded here.
    MapT { val_type: Box<Shape> },
    /// Shape identified only by a string payload.
    /// NOTE(review): presumably a user-supplied type name — confirm against callers.
    Opaque(String),
}
40
41pub fn fold_shapes(shapes: Vec<Shape>) -> Shape {
42 shapes.into_iter().fold(Shape::Bottom, common_shape)
43}
44
45pub fn common_shape(a: Shape, b: Shape) -> Shape {
46 if a == b {
47 return a;
48 }
49 use self::Shape::*;
50 match (a, b) {
51 (a, Bottom) | (Bottom, a) => a,
52 (Integer, Floating) | (Floating, Integer) => Floating,
53 (a, Null) | (Null, a) => a.into_optional(),
54 (a, Optional(b)) | (Optional(b), a) => common_shape(a, *b).into_optional(),
55 (Tuple(shapes1, n1), Tuple(shapes2, n2)) => {
56 if shapes1.len() == shapes2.len() {
57 let shapes: Vec<_> = shapes1
58 .into_iter()
59 .zip(shapes2.into_iter())
60 .map(|(a, b)| common_shape(a, b))
61 .collect();
62 Tuple(shapes, n1 + n2)
63 } else {
64 VecT {
65 elem_type: Box::new(common_shape(fold_shapes(shapes1), fold_shapes(shapes2))),
66 }
67 }
68 }
69 (Tuple(shapes, _), VecT { elem_type: e1 }) | (VecT { elem_type: e1 }, Tuple(shapes, _)) => {
70 VecT {
71 elem_type: Box::new(common_shape(*e1, fold_shapes(shapes))),
72 }
73 }
74 (VecT { elem_type: e1 }, VecT { elem_type: e2 }) => VecT {
75 elem_type: Box::new(common_shape(*e1, *e2)),
76 },
77 (MapT { val_type: v1 }, MapT { val_type: v2 }) => MapT {
78 val_type: Box::new(common_shape(*v1, *v2)),
79 },
80 (Struct { fields: f1 }, Struct { fields: f2 }) => Struct {
81 fields: common_field_shapes(f1, f2),
82 },
83 (Opaque(t), _) | (_, Opaque(t)) => Opaque(t),
84 _ => Any,
85 }
86}
87
88fn common_field_shapes(
89 mut f1: LinkedHashMap<String, Shape>,
90 mut f2: LinkedHashMap<String, Shape>,
91) -> LinkedHashMap<String, Shape> {
92 if f1 == f2 {
93 return f1;
94 }
95 for (key, val) in f1.iter_mut() {
96 let temp = std::mem::replace(val, Shape::Bottom);
97 match f2.remove(key) {
98 Some(val2) => {
99 *val = common_shape(temp, val2);
100 }
101 None => {
102 *val = temp.into_optional();
103 }
104 };
105 }
106 for (key, val) in f2.into_iter() {
107 f1.insert(key, val.into_optional());
108 }
109 f1
110}
111
112impl Shape {
113 fn into_optional(self) -> Self {
114 use self::Shape::*;
115 match self {
116 Null | Any | Bottom | Optional(_) => self,
117 non_nullable => Optional(Box::new(non_nullable)),
118 }
119 }
120
121 pub(crate) fn is_acceptable_substitution_for(&self, other: &Shape) -> bool {
124 use self::Shape::*;
125 if self == other {
126 return true;
127 }
128 match (self, other) {
129 (_, Bottom) => true,
130 (Optional(_), Null) => true,
131 (Optional(a), Optional(b)) => a.is_acceptable_substitution_for(b),
132 (VecT { elem_type: e1 }, VecT { elem_type: e2 }) => {
133 e1.is_acceptable_substitution_for(e2)
134 }
135 (MapT { val_type: v1 }, MapT { val_type: v2 }) => v1.is_acceptable_substitution_for(v2),
136 (Tuple(a, _), Tuple(b, _)) => {
137 a.len() == b.len()
138 && a.iter()
139 .zip(b.iter())
140 .all(|(e1, e2)| e1.is_acceptable_substitution_for(e2))
141 }
142 (Struct { fields: f1 }, Struct { fields: f2 }) => {
143 f1.len() == f2.len() && f1.iter().all(|(key, shape1)| {
146 if let Some(shape2) = f2.get(key) {
147 shape1.is_acceptable_substitution_for(shape2)
148 } else {
149 false
150 }
151 })
152 }
153 _ => false,
154 }
155 }
156}
157
158pub fn collapse_option(typ: &Shape) -> (bool, &Shape) {
159 if let Shape::Optional(inner) = typ {
160 return (true, &**inner);
161 }
162 (false, typ)
163}
164
/// Exercises `common_shape` on scalar, null and optional combinations,
/// including swapped argument order for the asymmetric-looking cases.
#[test]
fn test_unify() {
    use self::Shape::*;
    assert_eq!(common_shape(Bool, Bool), Bool);
    assert_eq!(common_shape(Bool, Integer), Any);
    assert_eq!(common_shape(Integer, Floating), Floating);
    // `Bottom` carries no information, so the other shape is returned as is.
    assert_eq!(common_shape(Bottom, Integer), Integer);
    assert_eq!(common_shape(Null, Any), Any);
    assert_eq!(common_shape(Null, Bool), Optional(Box::new(Bool)));
    assert_eq!(common_shape(Bool, Null), Optional(Box::new(Bool)));
    assert_eq!(
        common_shape(Null, Optional(Box::new(Integer))),
        Optional(Box::new(Integer))
    );
    assert_eq!(common_shape(Any, Optional(Box::new(Integer))), Any);
    // Same pair with the arguments swapped (this line used to be an exact
    // duplicate of the assertion above).
    assert_eq!(common_shape(Optional(Box::new(Integer)), Any), Any);
    assert_eq!(
        common_shape(Optional(Box::new(Integer)), Optional(Box::new(Floating))),
        Optional(Box::new(Floating))
    );
    assert_eq!(
        common_shape(Optional(Box::new(StringT)), Optional(Box::new(Integer))),
        Any
    );
}
188
/// Checks `common_field_shapes` on two maps with shared fields, fields
/// unique to each side, and a `Null` counterpart.
#[test]
fn test_common_field_shapes() {
    use self::Shape::*;
    use crate::word_case::string_hashmap;

    let left = string_hashmap! {
        "a" => Integer,
        "b" => Bool,
        "c" => Integer,
        "d" => StringT,
    };
    let right = string_hashmap! {
        "a" => Integer,
        "c" => Floating,
        "d" => Null,
        "e" => Any,
    };
    // "b" exists only on the left and "e" only on the right, so both go
    // through `into_optional` ("e" stays `Any`); "d" meets `Null` and
    // becomes optional as well.
    let expected = string_hashmap! {
        "a" => Integer,
        "b" => Optional(Box::new(Bool)),
        "c" => Floating,
        "d" => Optional(Box::new(StringT)),
        "e" => Any,
    };

    assert_eq!(common_field_shapes(left, right), expected);
}