Skip to main content

graphix_compiler/typ/
cast.rs

1use crate::{
2    env::Env,
3    errf,
4    typ::{RefHist, Type, TypeRef},
5    AbstractTypeRegistry, CAST_ERR_TAG,
6};
7use ahash::AHashSet;
8use anyhow::{anyhow, bail, Result};
9use arcstr::ArcStr;
10use enumflags2::{bitflags, BitFlags};
11use immutable_chunkmap::map::Map;
12use netidx::publisher::{Typ, Value};
13use netidx_value::ValArray;
14use poolshark::local::LPooled;
15use std::iter;
16use triomphe::Arc;
17
18#[derive(Debug, Clone, Copy)]
19#[bitflags]
20#[repr(u8)]
21pub enum IsAFlags {
22    /// When set, Type::Abstract matches any Value::Abstract
23    MatchAbstract,
24}
25
26impl Type {
27    fn check_cast_int(
28        &self,
29        env: &Env,
30        hist: &mut RefHist<AHashSet<Option<usize>>>,
31    ) -> Result<()> {
32        match self {
33            Type::Primitive(_) | Type::Any => Ok(()),
34            Type::Fn(_) => bail!("can't cast a value to a function"),
35            Type::Bottom => bail!("can't cast a value to bottom"),
36            Type::Set(s) | Type::Abstract { id: _, params: s } => Ok(for t in s.iter() {
37                t.check_cast_int(env, hist)?
38            }),
39            Type::TVar(tv) => match &*tv.read().typ.read() {
40                Some(t) => t.check_cast_int(env, hist),
41                None => bail!("can't cast a value to a free type variable"),
42            },
43            Type::Error(e) => e.check_cast_int(env, hist),
44            Type::Array(et) => et.check_cast_int(env, hist),
45            Type::Map { key, value } => {
46                key.check_cast_int(env, hist)?;
47                value.check_cast_int(env, hist)
48            }
49            Type::ByRef(_) => bail!("can't cast a reference"),
50            Type::Tuple(ts) => Ok(for t in ts.iter() {
51                t.check_cast_int(env, hist)?
52            }),
53            Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
54                t.check_cast_int(env, hist)?
55            }),
56            Type::Variant(_, ts) => Ok(for t in ts.iter() {
57                t.check_cast_int(env, hist)?
58            }),
59            Type::Ref(TypeRef { .. }) => {
60                let id = hist.ref_id(self, env);
61                let t = self.lookup_ref(env)?;
62                if hist.contains(&id) {
63                    Ok(())
64                } else {
65                    hist.insert(id);
66                    t.check_cast_int(env, hist)
67                }
68            }
69        }
70    }
71
72    pub fn check_cast(&self, env: &Env) -> Result<()> {
73        self.check_cast_int(env, &mut RefHist::new(LPooled::take()))
74    }
75
76    fn cast_value_int(
77        &self,
78        env: &Env,
79        hist: &mut AHashSet<(usize, usize)>,
80        v: Value,
81    ) -> Result<Value> {
82        if self.is_a_int(env, hist, BitFlags::empty(), &v) {
83            return Ok(v);
84        }
85        match self {
86            Type::Bottom => bail!("can't cast {v} to Bottom"),
87            Type::Fn(_) => bail!("can't cast {v} to a function"),
88            Type::Abstract { id: _, params: _ } => {
89                bail!("can't cast {v} to an abstract type")
90            }
91            Type::ByRef(_) => bail!("can't cast {v} to a reference"),
92            Type::Primitive(s) => s
93                .iter()
94                .find_map(|t| v.clone().cast(t))
95                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
96            Type::Any => Ok(v),
97            Type::Error(e) => {
98                let v = match v {
99                    Value::Error(v) => (*v).clone(),
100                    v => v,
101                };
102                Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
103            }
104            Type::Array(et) => match v {
105                Value::Array(elts) => {
106                    let mut va = elts
107                        .iter()
108                        .map(|el| et.cast_value_int(env, hist, el.clone()))
109                        .collect::<Result<LPooled<Vec<Value>>>>()?;
110                    Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
111                }
112                v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
113            },
114            Type::Map { key, value } => match v {
115                Value::Map(m) => {
116                    let mut m = m
117                        .into_iter()
118                        .map(|(k, v)| {
119                            Ok((
120                                key.cast_value_int(env, hist, k.clone())?,
121                                value.cast_value_int(env, hist, v.clone())?,
122                            ))
123                        })
124                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
125                    Ok(Value::Map(Map::from_iter(m.drain(..))))
126                }
127                Value::Array(a) => {
128                    let mut m = a
129                        .iter()
130                        .map(|a| match a {
131                            Value::Array(a) if a.len() == 2 => Ok((
132                                key.cast_value_int(env, hist, a[0].clone())?,
133                                value.cast_value_int(env, hist, a[1].clone())?,
134                            )),
135                            _ => bail!("expected an array of pairs"),
136                        })
137                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
138                    Ok(Value::Map(Map::from_iter(m.drain(..))))
139                }
140                _ => bail!("can't cast {v} to {self}"),
141            },
142            Type::Tuple(ts) => match v {
143                Value::Array(elts) => {
144                    if elts.len() != ts.len() {
145                        bail!("tuple size mismatch {self} with {}", Value::Array(elts))
146                    }
147                    let mut a = ts
148                        .iter()
149                        .zip(elts.iter())
150                        .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
151                        .collect::<Result<LPooled<Vec<Value>>>>()?;
152                    Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
153                }
154                v => bail!("can't cast {v} to {self}"),
155            },
156            Type::Struct(ts) => match v {
157                Value::Array(elts) => {
158                    if elts.len() != ts.len() {
159                        bail!("struct size mismatch {self} with {}", Value::Array(elts))
160                    }
161                    let is_pairs = elts.iter().all(|v| match v {
162                        Value::Array(a) if a.len() == 2 => match &a[0] {
163                            Value::String(_) => true,
164                            _ => false,
165                        },
166                        _ => false,
167                    });
168                    if !is_pairs {
169                        bail!("expected array of pairs, got {}", Value::Array(elts))
170                    }
171                    let mut elts_s: LPooled<Vec<&Value>> = elts.iter().collect();
172                    elts_s.sort_by_key(|v| match v {
173                        Value::Array(a) => match &a[0] {
174                            Value::String(s) => s,
175                            _ => unreachable!(),
176                        },
177                        _ => unreachable!(),
178                    });
179                    let keys_ok = ts.iter().zip(elts_s.iter()).fold(
180                        Ok(true),
181                        |acc: Result<_>, ((fname, t), v)| {
182                            let kok = acc?;
183                            let (name, v) = match v {
184                                Value::Array(a) => match (&a[0], &a[1]) {
185                                    (Value::String(n), v) => (n, v),
186                                    _ => unreachable!(),
187                                },
188                                _ => unreachable!(),
189                            };
190                            Ok(kok
191                                && name == fname
192                                && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
193                        },
194                    )?;
195                    if keys_ok {
196                        let mut elts = ts
197                            .iter()
198                            .zip(elts_s.iter())
199                            .map(|((n, t), v)| match v {
200                                Value::Array(a) => {
201                                    let a = [
202                                        Value::String(n.clone()),
203                                        t.cast_value_int(env, hist, a[1].clone())?,
204                                    ];
205                                    Ok(Value::Array(ValArray::from_iter_exact(
206                                        a.into_iter(),
207                                    )))
208                                }
209                                _ => unreachable!(),
210                            })
211                            .collect::<Result<LPooled<Vec<Value>>>>()?;
212                        Ok(Value::Array(ValArray::from_iter_exact(elts.drain(..))))
213                    } else {
214                        drop(elts_s);
215                        bail!("struct fields mismatch {self}, {}", Value::Array(elts))
216                    }
217                }
218                v => bail!("can't cast {v} to {self}"),
219            },
220            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
221                Value::String(s) if s == tag => Ok(v),
222                _ => bail!("variant tag mismatch expected {tag} got {v}"),
223            },
224            Type::Variant(tag, ts) => match &v {
225                Value::Array(elts) => {
226                    if ts.len() + 1 == elts.len() {
227                        match &elts[0] {
228                            Value::String(s) if s == tag => (),
229                            v => bail!("variant tag mismatch expected {tag} got {v}"),
230                        }
231                        let mut a = iter::once(&Type::Primitive(Typ::String.into()))
232                            .chain(ts.iter())
233                            .zip(elts.iter())
234                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
235                            .collect::<Result<LPooled<Vec<Value>>>>()?;
236                        Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
237                    } else if ts.len() == elts.len() {
238                        let mut a = ts
239                            .iter()
240                            .zip(elts.iter())
241                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
242                            .collect::<Result<LPooled<Vec<Value>>>>()?;
243                        a.insert(0, Value::String(tag.clone()));
244                        Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
245                    } else {
246                        bail!("variant length mismatch")
247                    }
248                }
249                v => bail!("can't cast {v} to {self}"),
250            },
251            Type::Ref(TypeRef { .. }) => {
252                let t = self.lookup_ref(env)?;
253                t.cast_value_int(env, hist, v)
254            }
255            Type::Set(ts) => ts
256                .iter()
257                .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
258                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
259            Type::TVar(tv) => match &*tv.read().typ.read() {
260                Some(t) => t.cast_value_int(env, hist, v.clone()),
261                None => Ok(v),
262            },
263        }
264    }
265
266    pub fn cast_value(&self, env: &Env, v: Value) -> Value {
267        match self.cast_value_int(env, &mut LPooled::take(), v) {
268            Ok(v) => v,
269            Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
270        }
271    }
272
273    fn is_a_int(
274        &self,
275        env: &Env,
276        hist: &mut AHashSet<(usize, usize)>,
277        flags: BitFlags<IsAFlags>,
278        v: &Value,
279    ) -> bool {
280        match self {
281            Type::Ref(TypeRef { scope, name, .. }) => match self.lookup_ref(env) {
282                Err(_) => false,
283                Ok(t) => {
284                    let t_addr = (scope.as_ref() as *const _ as *const u8).addr()
285                        ^ (name.as_ref() as *const _ as *const u8).addr();
286                    let v_addr = (v as *const Value).addr();
287                    !hist.contains(&(t_addr, v_addr)) && {
288                        hist.insert((t_addr, v_addr));
289                        t.is_a_int(env, hist, flags, v)
290                    }
291                }
292            },
293            Type::Primitive(t) => t.contains(Typ::get(&v)),
294            Type::Abstract { .. } => {
295                flags.contains(IsAFlags::MatchAbstract) && matches!(v, Value::Abstract(_))
296            }
297            Type::Any => true,
298            Type::Array(et) => match v {
299                Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, flags, v)),
300                _ => false,
301            },
302            Type::Map { key, value } => match v {
303                Value::Map(m) => m.into_iter().all(|(k, v)| {
304                    key.is_a_int(env, hist, flags, k)
305                        && value.is_a_int(env, hist, flags, v)
306                }),
307                _ => false,
308            },
309            Type::Error(e) => match v {
310                Value::Error(v) => e.is_a_int(env, hist, flags, v),
311                _ => false,
312            },
313            Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
314            Type::Tuple(ts) => match v {
315                Value::Array(elts) => {
316                    elts.len() == ts.len()
317                        && ts
318                            .iter()
319                            .zip(elts.iter())
320                            .all(|(t, v)| t.is_a_int(env, hist, flags, v))
321                }
322                _ => false,
323            },
324            Type::Struct(ts) => match v {
325                Value::Array(elts) => {
326                    elts.len() == ts.len()
327                        && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
328                            Value::Array(a) if a.len() == 2 => match &a[..] {
329                                [Value::String(key), v] => {
330                                    n == key && t.is_a_int(env, hist, flags, v)
331                                }
332                                _ => false,
333                            },
334                            _ => false,
335                        })
336                }
337                _ => false,
338            },
339            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
340                Value::String(s) => s == tag,
341                _ => false,
342            },
343            Type::Variant(tag, ts) => match &v {
344                Value::Array(elts) => {
345                    ts.len() + 1 == elts.len()
346                        && match &elts[0] {
347                            Value::String(s) => s == tag,
348                            _ => false,
349                        }
350                        && ts
351                            .iter()
352                            .zip(elts[1..].iter())
353                            .all(|(t, v)| t.is_a_int(env, hist, flags, v))
354                }
355                _ => false,
356            },
357            Type::TVar(tv) => match &*tv.read().typ.read() {
358                None => true,
359                Some(t) => t.is_a_int(env, hist, flags, v),
360            },
361            Type::Fn(_) => match v {
362                Value::Abstract(a) if AbstractTypeRegistry::is_a(a, "lambda") => true,
363                _ => false,
364            },
365            Type::Bottom => true,
366            Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, flags, v)),
367        }
368    }
369
370    /// return true if v is structurally compatible with the type
371    pub fn is_a(&self, env: &Env, v: &Value) -> bool {
372        self.is_a_int(env, &mut LPooled::take(), BitFlags::empty(), v)
373    }
374
375    /// return true if v is structurally compatible with the type, with flags
376    pub fn is_a_with(&self, env: &Env, flags: BitFlags<IsAFlags>, v: &Value) -> bool {
377        self.is_a_int(env, &mut LPooled::take(), flags, v)
378    }
379}