Skip to main content

graphix_compiler/typ/
normalize.rs

1use crate::typ::{TVar, Type};
2use arcstr::ArcStr;
3use enumflags2::BitFlags;
4use netidx::publisher::Typ;
5use poolshark::local::LPooled;
6use smallvec::SmallVec;
7use std::iter;
8use triomphe::Arc;
9
10impl Type {
11    pub(crate) fn flatten_set(set: impl IntoIterator<Item = Self>) -> Self {
12        let init: Box<dyn Iterator<Item = Self>> = Box::new(set.into_iter());
13        let mut iters: LPooled<Vec<Box<dyn Iterator<Item = Self>>>> =
14            LPooled::from_iter([init]);
15        let mut acc: LPooled<Vec<Self>> = LPooled::take();
16        loop {
17            match iters.last_mut() {
18                None => break,
19                Some(iter) => match iter.next() {
20                    None => {
21                        iters.pop();
22                    }
23                    Some(Type::Set(s)) => {
24                        let v: SmallVec<[Self; 16]> =
25                            s.iter().map(|t| t.clone()).collect();
26                        iters.push(Box::new(v.into_iter()))
27                    }
28                    Some(Type::Any) => return Type::Any,
29                    Some(t) => {
30                        acc.push(t);
31                        let mut i = 0;
32                        let mut j = 0;
33                        while i < acc.len() {
34                            while j < acc.len() {
35                                if j == i {
36                                    j += 1;
37                                    continue;
38                                }
39                                match acc[i].merge(&acc[j]) {
40                                    None => j += 1,
41                                    Some(t) => {
42                                        acc[i] = t;
43                                        acc.remove(j);
44                                        i = 0;
45                                        j = 0;
46                                    }
47                                }
48                            }
49                            i += 1;
50                            j = 0;
51                        }
52                    }
53                },
54            }
55        }
56        acc.sort();
57        match &**acc {
58            [] => Type::Primitive(BitFlags::empty()),
59            [t] => t.clone(),
60            _ => Type::Set(Arc::from_iter(acc.drain(..))),
61        }
62    }
63
64    /// Deep-clone the type tree, replacing every bound TVar with its
65    /// concrete binding (recursively). Unbound TVars are kept as fresh
66    /// named TVars. This produces a snapshot that is independent of the
67    /// original TVar cells.
68    pub fn resolve_tvars(&self) -> Self {
69        match self {
70            Type::Bottom | Type::Any | Type::Primitive(_) => self.clone(),
71            Type::Abstract { id, params } => Type::Abstract {
72                id: *id,
73                params: Arc::from_iter(params.iter().map(|t| t.resolve_tvars())),
74            },
75            Type::Ref { scope, name, params } => Type::Ref {
76                scope: scope.clone(),
77                name: name.clone(),
78                params: Arc::from_iter(params.iter().map(|t| t.resolve_tvars())),
79            },
80            Type::TVar(tv) => match tv.read().typ.read().as_ref() {
81                Some(t) => t.resolve_tvars(),
82                None => Type::TVar(TVar::empty_named(tv.name.clone())),
83            },
84            Type::Set(s) => {
85                Type::Set(Arc::from_iter(s.iter().map(|t| t.resolve_tvars())))
86            }
87            Type::Error(t) => Type::Error(Arc::new(t.resolve_tvars())),
88            Type::Array(t) => Type::Array(Arc::new(t.resolve_tvars())),
89            Type::Map { key, value } => Type::Map {
90                key: Arc::new(key.resolve_tvars()),
91                value: Arc::new(value.resolve_tvars()),
92            },
93            Type::ByRef(t) => Type::ByRef(Arc::new(t.resolve_tvars())),
94            Type::Tuple(t) => {
95                Type::Tuple(Arc::from_iter(t.iter().map(|t| t.resolve_tvars())))
96            }
97            Type::Struct(t) => Type::Struct(Arc::from_iter(
98                t.iter().map(|(n, t)| (n.clone(), t.resolve_tvars())),
99            )),
100            Type::Variant(tag, t) => Type::Variant(
101                tag.clone(),
102                Arc::from_iter(t.iter().map(|t| t.resolve_tvars())),
103            ),
104            Type::Fn(ft) => Type::Fn(Arc::new(ft.resolve_tvars())),
105        }
106    }
107
108    pub(crate) fn normalize(&self) -> Self {
109        match self {
110            Type::Bottom | Type::Any | Type::Abstract { .. } | Type::Primitive(_) => {
111                self.clone()
112            }
113            Type::Ref { scope, name, params } => {
114                let params = Arc::from_iter(params.iter().map(|t| t.normalize()));
115                Type::Ref { scope: scope.clone(), name: name.clone(), params }
116            }
117            Type::TVar(tv) => Type::TVar(tv.normalize()),
118            Type::Set(s) => Self::flatten_set(s.iter().map(|t| t.normalize())),
119            Type::Error(t) => Type::Error(Arc::new(t.normalize())),
120            Type::Array(t) => Type::Array(Arc::new(t.normalize())),
121            Type::Map { key, value } => {
122                let key = Arc::new(key.normalize());
123                let value = Arc::new(value.normalize());
124                Type::Map { key, value }
125            }
126            Type::ByRef(t) => Type::ByRef(Arc::new(t.normalize())),
127            Type::Tuple(t) => {
128                Type::Tuple(Arc::from_iter(t.iter().map(|t| t.normalize())))
129            }
130            Type::Struct(t) => Type::Struct(Arc::from_iter(
131                t.iter().map(|(n, t)| (n.clone(), t.normalize())),
132            )),
133            Type::Variant(tag, t) => Type::Variant(
134                tag.clone(),
135                Arc::from_iter(t.iter().map(|t| t.normalize())),
136            ),
137            Type::Fn(ft) => Type::Fn(Arc::new(ft.normalize())),
138        }
139    }
140
141    fn merge(&self, t: &Self) -> Option<Self> {
142        macro_rules! flatten {
143            ($t:expr) => {
144                match $t {
145                    Type::Set(et) => Self::flatten_set(et.iter().cloned()),
146                    t => t.clone(),
147                }
148            };
149        }
150        match (self, t) {
151            (
152                Type::Ref { scope: s0, name: r0, params: a0 },
153                Type::Ref { scope: s1, name: r1, params: a1 },
154            ) => {
155                if s0 == s1 && r0 == r1 && a0 == a1 {
156                    Some(Type::Ref {
157                        scope: s0.clone(),
158                        name: r0.clone(),
159                        params: a0.clone(),
160                    })
161                } else {
162                    None
163                }
164            }
165            (Type::Ref { .. }, _) | (_, Type::Ref { .. }) => None,
166            (Type::Bottom, t) | (t, Type::Bottom) => Some(t.clone()),
167            (Type::Any, _) | (_, Type::Any) => Some(Type::Any),
168            (Type::Primitive(s0), Type::Primitive(s1)) => {
169                Some(Type::Primitive(*s0 | *s1))
170            }
171            (Type::Primitive(p), t) | (t, Type::Primitive(p)) if p.is_empty() => {
172                Some(t.clone())
173            }
174            (
175                Type::Abstract { id: id0, params: p0 },
176                Type::Abstract { id: id1, params: p1 },
177            ) => {
178                if id0 == id1 && p0 == p1 {
179                    Some(self.clone())
180                } else {
181                    None
182                }
183            }
184            (Type::Fn(f0), Type::Fn(f1)) => {
185                if f0 == f1 {
186                    Some(Type::Fn(f0.clone()))
187                } else {
188                    None
189                }
190            }
191            (Type::Array(t0), Type::Array(t1)) => {
192                if flatten!(&**t0) == flatten!(&**t1) {
193                    Some(Type::Array(t0.clone()))
194                } else {
195                    None
196                }
197            }
198            (Type::Primitive(p), Type::Array(_))
199            | (Type::Array(_), Type::Primitive(p)) => {
200                if p.contains(Typ::Array) {
201                    Some(Type::Primitive(*p))
202                } else {
203                    None
204                }
205            }
206            (Type::Map { key: k0, value: v0 }, Type::Map { key: k1, value: v1 }) => {
207                if flatten!(&**k0) == flatten!(&**k1)
208                    && flatten!(&**v0) == flatten!(&**v1)
209                {
210                    Some(Type::Map { key: k0.clone(), value: v0.clone() })
211                } else {
212                    None
213                }
214            }
215            (Type::Error(t0), Type::Error(t1)) => {
216                if flatten!(&**t0) == flatten!(&**t1) {
217                    Some(Type::Error(t0.clone()))
218                } else {
219                    None
220                }
221            }
222            (Type::ByRef(t0), Type::ByRef(t1)) => {
223                t0.merge(t1).map(|t| Type::ByRef(Arc::new(t)))
224            }
225            (Type::Set(s0), Type::Set(s1)) => {
226                Some(Self::flatten_set(s0.iter().cloned().chain(s1.iter().cloned())))
227            }
228            (Type::Set(s), Type::Primitive(p)) | (Type::Primitive(p), Type::Set(s))
229                if p.is_empty() =>
230            {
231                Some(Type::Set(s.clone()))
232            }
233            (Type::Set(s), t) | (t, Type::Set(s)) => {
234                Some(Self::flatten_set(s.iter().cloned().chain(iter::once(t.clone()))))
235            }
236            (Type::Tuple(t0), Type::Tuple(t1)) => {
237                if t0.len() == t1.len() {
238                    let mut t = t0
239                        .iter()
240                        .zip(t1.iter())
241                        .map(|(t0, t1)| t0.merge(t1))
242                        .collect::<Option<LPooled<Vec<Type>>>>()?;
243                    Some(Type::Tuple(Arc::from_iter(t.drain(..))))
244                } else {
245                    None
246                }
247            }
248            (Type::Variant(tag0, t0), Type::Variant(tag1, t1)) => {
249                if tag0 == tag1 && t0.len() == t1.len() {
250                    let t = t0
251                        .iter()
252                        .zip(t1.iter())
253                        .map(|(t0, t1)| t0.merge(t1))
254                        .collect::<Option<SmallVec<[Type; 8]>>>()?;
255                    Some(Type::Variant(tag0.clone(), Arc::from_iter(t)))
256                } else {
257                    None
258                }
259            }
260            (Type::Struct(t0), Type::Struct(t1)) => {
261                if t0.len() == t1.len() {
262                    let t = t0
263                        .iter()
264                        .zip(t1.iter())
265                        .map(|((n0, t0), (n1, t1))| {
266                            if n0 != n1 {
267                                None
268                            } else {
269                                t0.merge(t1).map(|t| (n0.clone(), t))
270                            }
271                        })
272                        .collect::<Option<SmallVec<[(ArcStr, Type); 8]>>>()?;
273                    Some(Type::Struct(Arc::from_iter(t)))
274                } else {
275                    None
276                }
277            }
278            (Type::TVar(tv0), Type::TVar(tv1)) if tv0.name == tv1.name && tv0 == tv1 => {
279                Some(Type::TVar(tv0.clone()))
280            }
281            (Type::TVar(tv), t) => {
282                tv.read().typ.read().as_ref().and_then(|tv| tv.merge(t))
283            }
284            (t, Type::TVar(tv)) => {
285                tv.read().typ.read().as_ref().and_then(|tv| t.merge(tv))
286            }
287            (Type::ByRef(_), _)
288            | (_, Type::ByRef(_))
289            | (Type::Abstract { .. }, _)
290            | (_, Type::Abstract { .. })
291            | (Type::Array(_), _)
292            | (_, Type::Array(_))
293            | (_, Type::Map { .. })
294            | (Type::Map { .. }, _)
295            | (Type::Tuple(_), _)
296            | (_, Type::Tuple(_))
297            | (Type::Struct(_), _)
298            | (_, Type::Struct(_))
299            | (Type::Variant(_, _), _)
300            | (_, Type::Variant(_, _))
301            | (_, Type::Fn(_))
302            | (Type::Fn(_), _)
303            | (Type::Error(_), _)
304            | (_, Type::Error(_)) => None,
305        }
306    }
307}