Skip to main content

graphix_compiler/typ/
normalize.rs

1use crate::typ::{TVar, Type, TypeRef};
2use arcstr::ArcStr;
3use enumflags2::BitFlags;
4use netidx::publisher::Typ;
5use poolshark::local::LPooled;
6use smallvec::SmallVec;
7use std::iter;
8use triomphe::Arc;
9
10impl Type {
11    pub(crate) fn flatten_set(set: impl IntoIterator<Item = Self>) -> Self {
12        let init: Box<dyn Iterator<Item = Self>> = Box::new(set.into_iter());
13        let mut iters: LPooled<Vec<Box<dyn Iterator<Item = Self>>>> =
14            LPooled::from_iter([init]);
15        let mut acc: LPooled<Vec<Self>> = LPooled::take();
16        loop {
17            match iters.last_mut() {
18                None => break,
19                Some(iter) => match iter.next() {
20                    None => {
21                        iters.pop();
22                    }
23                    Some(Type::Set(s)) => {
24                        let v: SmallVec<[Self; 16]> =
25                            s.iter().map(|t| t.clone()).collect();
26                        iters.push(Box::new(v.into_iter()))
27                    }
28                    Some(Type::Any) => return Type::Any,
29                    Some(t) => {
30                        acc.push(t);
31                        let mut i = 0;
32                        let mut j = 0;
33                        while i < acc.len() {
34                            while j < acc.len() {
35                                if j == i {
36                                    j += 1;
37                                    continue;
38                                }
39                                match acc[i].merge(&acc[j]) {
40                                    None => j += 1,
41                                    Some(t) => {
42                                        acc[i] = t;
43                                        acc.remove(j);
44                                        i = 0;
45                                        j = 0;
46                                    }
47                                }
48                            }
49                            i += 1;
50                            j = 0;
51                        }
52                    }
53                },
54            }
55        }
56        acc.sort();
57        match &**acc {
58            [] => Type::Primitive(BitFlags::empty()),
59            [t] => t.clone(),
60            _ => Type::Set(Arc::from_iter(acc.drain(..))),
61        }
62    }
63
64    /// Deep-clone the type tree, replacing every bound TVar with its
65    /// concrete binding (recursively). Unbound TVars are kept as fresh
66    /// named TVars. This produces a snapshot that is independent of the
67    /// original TVar cells.
68    pub fn resolve_tvars(&self) -> Self {
69        match self {
70            Type::Bottom | Type::Any | Type::Primitive(_) => self.clone(),
71            Type::Abstract { id, params } => Type::Abstract {
72                id: *id,
73                params: Arc::from_iter(params.iter().map(|t| t.resolve_tvars())),
74            },
75            Type::Ref(tr) => Type::Ref(TypeRef {
76                params: Arc::from_iter(tr.params.iter().map(|t| t.resolve_tvars())),
77                ..tr.clone()
78            }),
79            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
80                Some(t) => t.resolve_tvars(),
81                None => Type::TVar(TVar::empty_named(tv.name.clone())),
82            },
83            Type::Set(s) => {
84                Type::Set(Arc::from_iter(s.iter().map(|t| t.resolve_tvars())))
85            }
86            Type::Error(t) => Type::Error(Arc::new(t.resolve_tvars())),
87            Type::Array(t) => Type::Array(Arc::new(t.resolve_tvars())),
88            Type::Map { key, value } => Type::Map {
89                key: Arc::new(key.resolve_tvars()),
90                value: Arc::new(value.resolve_tvars()),
91            },
92            Type::ByRef(t) => Type::ByRef(Arc::new(t.resolve_tvars())),
93            Type::Tuple(t) => {
94                Type::Tuple(Arc::from_iter(t.iter().map(|t| t.resolve_tvars())))
95            }
96            Type::Struct(t) => Type::Struct(Arc::from_iter(
97                t.iter().map(|(n, t)| (n.clone(), t.resolve_tvars())),
98            )),
99            Type::Variant(tag, t) => Type::Variant(
100                tag.clone(),
101                Arc::from_iter(t.iter().map(|t| t.resolve_tvars())),
102            ),
103            Type::Fn(ft) => Type::Fn(Arc::new(ft.resolve_tvars())),
104        }
105    }
106
107    pub(crate) fn normalize(&self) -> Self {
108        match self {
109            Type::Bottom | Type::Any | Type::Abstract { .. } | Type::Primitive(_) => {
110                self.clone()
111            }
112            Type::Ref(tr) => {
113                let params = Arc::from_iter(tr.params.iter().map(|t| t.normalize()));
114                Type::Ref(TypeRef { params, ..tr.clone() })
115            }
116            Type::TVar(tv) => Type::TVar(tv.normalize()),
117            Type::Set(s) => Self::flatten_set(s.iter().map(|t| t.normalize())),
118            Type::Error(t) => Type::Error(Arc::new(t.normalize())),
119            Type::Array(t) => Type::Array(Arc::new(t.normalize())),
120            Type::Map { key, value } => {
121                let key = Arc::new(key.normalize());
122                let value = Arc::new(value.normalize());
123                Type::Map { key, value }
124            }
125            Type::ByRef(t) => Type::ByRef(Arc::new(t.normalize())),
126            Type::Tuple(t) => {
127                Type::Tuple(Arc::from_iter(t.iter().map(|t| t.normalize())))
128            }
129            Type::Struct(t) => Type::Struct(Arc::from_iter(
130                t.iter().map(|(n, t)| (n.clone(), t.normalize())),
131            )),
132            Type::Variant(tag, t) => Type::Variant(
133                tag.clone(),
134                Arc::from_iter(t.iter().map(|t| t.normalize())),
135            ),
136            Type::Fn(ft) => Type::Fn(Arc::new(ft.normalize())),
137        }
138    }
139
140    fn merge(&self, t: &Self) -> Option<Self> {
141        macro_rules! flatten {
142            ($t:expr) => {
143                match $t {
144                    Type::Set(et) => Self::flatten_set(et.iter().cloned()),
145                    t => t.clone(),
146                }
147            };
148        }
149        match (self, t) {
150            (Type::Ref(t0), Type::Ref(t1)) => {
151                if t0 == t1 {
152                    Some(Type::Ref(t0.clone()))
153                } else {
154                    None
155                }
156            }
157            (Type::Ref (TypeRef { .. }), _) | (_, Type::Ref (TypeRef { .. })) => None,
158            (Type::Bottom, t) | (t, Type::Bottom) => Some(t.clone()),
159            (Type::Any, _) | (_, Type::Any) => Some(Type::Any),
160            (Type::Primitive(s0), Type::Primitive(s1)) => {
161                Some(Type::Primitive(*s0 | *s1))
162            }
163            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
164                Some(t.clone())
165            }
166            (
167                Type::Abstract { id: id0, params: p0 },
168                Type::Abstract { id: id1, params: p1 },
169            ) => {
170                if id0 == id1 && p0 == p1 {
171                    Some(self.clone())
172                } else {
173                    None
174                }
175            }
176            (Type::Fn(f0), Type::Fn(f1)) => {
177                if f0 == f1 {
178                    Some(Type::Fn(f0.clone()))
179                } else {
180                    None
181                }
182            }
183            (Type::Array(t0), Type::Array(t1)) => {
184                if flatten!(&**t0) == flatten!(&**t1) {
185                    Some(Type::Array(t0.clone()))
186                } else {
187                    None
188                }
189            }
190            (Type::Primitive(p), Type::Array(_))
191            | (Type::Array(_), Type::Primitive(p)) => {
192                if p.contains(Typ::Array) {
193                    Some(Type::Primitive(*p))
194                } else {
195                    None
196                }
197            }
198            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
199                if flatten!(&**k0) == flatten!(&**k1)
200                    && flatten!(&**v0) == flatten!(&**v1)
201                {
202                    Some(Type::Map { key: k0.clone(), value: v0.clone() })
203                } else {
204                    None
205                }
206            }
207            (Type::Error(t0), Type::Error(t1)) => {
208                if flatten!(&**t0) == flatten!(&**t1) {
209                    Some(Type::Error(t0.clone()))
210                } else {
211                    None
212                }
213            }
214            (Type::ByRef(t0), Type::ByRef(t1)) => {
215                t0.merge(t1).map(|t| Type::ByRef(Arc::new(t)))
216            }
217            (Type::Set(s0), Type::Set(s1)) => {
218                Some(Self::flatten_set(s0.iter().cloned().chain(s1.iter().cloned())))
219            }
220            (Type::Set(s), Type::Primitive(p)) | (Type::Primitive(p), Type::Set(s))
221                if p.is_empty() =>
222            {
223                Some(Type::Set(s.clone()))
224            }
225            (Type::Set(s), t) | (t, Type::Set(s)) => {
226                Some(Self::flatten_set(s.iter().cloned().chain(iter::once(t.clone()))))
227            }
228            (Type::Tuple(t0), Type::Tuple(t1)) => {
229                if t0.len() == t1.len() {
230                    let mut t = t0
231                        .iter()
232                        .zip(t1.iter())
233                        .map(|(t0, t1)| t0.merge(t1))
234                        .collect::<Option<LPooled<Vec<Type>>>>()?;
235                    Some(Type::Tuple(Arc::from_iter(t.drain(..))))
236                } else {
237                    None
238                }
239            }
240            (Type::Variant(tag0, t0), Type::Variant(tag1, t1)) => {
241                if tag0 == tag1 && t0.len() == t1.len() {
242                    let t = t0
243                        .iter()
244                        .zip(t1.iter())
245                        .map(|(t0, t1)| t0.merge(t1))
246                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
247                    Some(Type::Variant(tag0.clone(), Arc::from_iter(t)))
248                } else {
249                    None
250                }
251            }
252            (Type::Struct(t0), Type::Struct(t1)) => {
253                if t0.len() == t1.len() {
254                    let t = t0
255                        .iter()
256                        .zip(t1.iter())
257                        .map(|((n0, t0), (n1, t1))| {
258                            if n0 != n1 {
259                                None
260                            } else {
261                                t0.merge(t1).map(|t| (n0.clone(), t))
262                            }
263                        })
264                        .collect::<Option<SmallVec<[(ArcStr, Type); 8]>>>()?;
265                    Some(Type::Struct(Arc::from_iter(t)))
266                } else {
267                    None
268                }
269            }
270            (Type::TVar(tv0), Type::TVar(tv1)) if tv0.name == tv1.name && tv0 == tv1 => {
271                Some(Type::TVar(tv0.clone()))
272            }
273            (Type::TVar(tv), t) => {
274                tv.read().typ.read().as_ref().and_then(|tv| tv.merge(t))
275            }
276            (t, Type::TVar(tv)) => {
277                tv.read().typ.read().as_ref().and_then(|tv| t.merge(tv))
278            }
279            (Type::ByRef(_), _)
280            | (_, Type::ByRef(_))
281            | (Type::Abstract { .. }, _)
282            | (_, Type::Abstract { .. })
283            | (Type::Array(_), _)
284            | (_, Type::Array(_))
285            | (_, Type::Map { .. })
286            | (Type::Map { .. }, _)
287            | (Type::Tuple(_), _)
288            | (_, Type::Tuple(_))
289            | (Type::Struct(_), _)
290            | (_, Type::Struct(_))
291            | (Type::Variant(_, _), _)
292            | (_, Type::Variant(_, _))
293            | (_, Type::Fn(_))
294            | (Type::Fn(_), _)
295            | (Type::Error(_), _)
296            | (_, Type::Error(_)) => None,
297        }
298    }
299}