Skip to main content

graphix_compiler/typ/
cast.rs

1use crate::{env::Env, errf, typ::{RefHist, Type}, AbstractTypeRegistry, CAST_ERR_TAG};
2use anyhow::{anyhow, bail, Result};
3use arcstr::ArcStr;
4use fxhash::FxHashSet;
5use immutable_chunkmap::map::Map;
6use netidx::publisher::{Typ, Value};
7use netidx_value::ValArray;
8use poolshark::local::LPooled;
9use std::iter;
10use triomphe::Arc;
11
12impl Type {
13    fn check_cast_int(&self, env: &Env, hist: &mut RefHist<FxHashSet<Option<usize>>>) -> Result<()> {
14        match self {
15            Type::Primitive(_) | Type::Any => Ok(()),
16            Type::Fn(_) => bail!("can't cast a value to a function"),
17            Type::Bottom => bail!("can't cast a value to bottom"),
18            Type::Set(s) | Type::Abstract { id: _, params: s } => Ok(for t in s.iter() {
19                t.check_cast_int(env, hist)?
20            }),
21            Type::TVar(tv) => match &*tv.read().typ.read() {
22                Some(t) => t.check_cast_int(env, hist),
23                None => bail!("can't cast a value to a free type variable"),
24            },
25            Type::Error(e) => e.check_cast_int(env, hist),
26            Type::Array(et) => et.check_cast_int(env, hist),
27            Type::Map { key, value } => {
28                key.check_cast_int(env, hist)?;
29                value.check_cast_int(env, hist)
30            }
31            Type::ByRef(_) => bail!("can't cast a reference"),
32            Type::Tuple(ts) => Ok(for t in ts.iter() {
33                t.check_cast_int(env, hist)?
34            }),
35            Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
36                t.check_cast_int(env, hist)?
37            }),
38            Type::Variant(_, ts) => Ok(for t in ts.iter() {
39                t.check_cast_int(env, hist)?
40            }),
41            Type::Ref { .. } => {
42                let id = hist.ref_id(self, env);
43                let t = self.lookup_ref(env)?;
44                if hist.contains(&id) {
45                    Ok(())
46                } else {
47                    hist.insert(id);
48                    t.check_cast_int(env, hist)
49                }
50            }
51        }
52    }
53
54    pub fn check_cast(&self, env: &Env) -> Result<()> {
55        self.check_cast_int(env, &mut RefHist::new(LPooled::take()))
56    }
57
58    fn cast_value_int(
59        &self,
60        env: &Env,
61        hist: &mut FxHashSet<(usize, usize)>,
62        v: Value,
63    ) -> Result<Value> {
64        if self.is_a_int(env, hist, &v) {
65            return Ok(v);
66        }
67        match self {
68            Type::Bottom => bail!("can't cast {v} to Bottom"),
69            Type::Fn(_) => bail!("can't cast {v} to a function"),
70            Type::Abstract { id: _, params: _ } => {
71                bail!("can't cast {v} to an abstract type")
72            }
73            Type::ByRef(_) => bail!("can't cast {v} to a reference"),
74            Type::Primitive(s) => s
75                .iter()
76                .find_map(|t| v.clone().cast(t))
77                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
78            Type::Any => Ok(v),
79            Type::Error(e) => {
80                let v = match v {
81                    Value::Error(v) => (*v).clone(),
82                    v => v,
83                };
84                Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
85            }
86            Type::Array(et) => match v {
87                Value::Array(elts) => {
88                    let mut va = elts
89                        .iter()
90                        .map(|el| et.cast_value_int(env, hist, el.clone()))
91                        .collect::<Result<LPooled<Vec<Value>>>>()?;
92                    Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
93                }
94                v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
95            },
96            Type::Map { key, value } => match v {
97                Value::Map(m) => {
98                    let mut m = m
99                        .into_iter()
100                        .map(|(k, v)| {
101                            Ok((
102                                key.cast_value_int(env, hist, k.clone())?,
103                                value.cast_value_int(env, hist, v.clone())?,
104                            ))
105                        })
106                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
107                    Ok(Value::Map(Map::from_iter(m.drain(..))))
108                }
109                Value::Array(a) => {
110                    let mut m = a
111                        .iter()
112                        .map(|a| match a {
113                            Value::Array(a) if a.len() == 2 => Ok((
114                                key.cast_value_int(env, hist, a[0].clone())?,
115                                value.cast_value_int(env, hist, a[1].clone())?,
116                            )),
117                            _ => bail!("expected an array of pairs"),
118                        })
119                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
120                    Ok(Value::Map(Map::from_iter(m.drain(..))))
121                }
122                _ => bail!("can't cast {v} to {self}"),
123            },
124            Type::Tuple(ts) => match v {
125                Value::Array(elts) => {
126                    if elts.len() != ts.len() {
127                        bail!("tuple size mismatch {self} with {}", Value::Array(elts))
128                    }
129                    let mut a = ts
130                        .iter()
131                        .zip(elts.iter())
132                        .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
133                        .collect::<Result<LPooled<Vec<Value>>>>()?;
134                    Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
135                }
136                v => bail!("can't cast {v} to {self}"),
137            },
138            Type::Struct(ts) => match v {
139                Value::Array(elts) => {
140                    if elts.len() != ts.len() {
141                        bail!("struct size mismatch {self} with {}", Value::Array(elts))
142                    }
143                    let is_pairs = elts.iter().all(|v| match v {
144                        Value::Array(a) if a.len() == 2 => match &a[0] {
145                            Value::String(_) => true,
146                            _ => false,
147                        },
148                        _ => false,
149                    });
150                    if !is_pairs {
151                        bail!("expected array of pairs, got {}", Value::Array(elts))
152                    }
153                    let mut elts_s: LPooled<Vec<&Value>> = elts.iter().collect();
154                    elts_s.sort_by_key(|v| match v {
155                        Value::Array(a) => match &a[0] {
156                            Value::String(s) => s,
157                            _ => unreachable!(),
158                        },
159                        _ => unreachable!(),
160                    });
161                    let keys_ok = ts.iter().zip(elts_s.iter()).fold(
162                        Ok(true),
163                        |acc: Result<_>, ((fname, t), v)| {
164                            let kok = acc?;
165                            let (name, v) = match v {
166                                Value::Array(a) => match (&a[0], &a[1]) {
167                                    (Value::String(n), v) => (n, v),
168                                    _ => unreachable!(),
169                                },
170                                _ => unreachable!(),
171                            };
172                            Ok(kok
173                                && name == fname
174                                && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
175                        },
176                    )?;
177                    if keys_ok {
178                        let mut elts = ts
179                            .iter()
180                            .zip(elts_s.iter())
181                            .map(|((n, t), v)| match v {
182                                Value::Array(a) => {
183                                    let a = [
184                                        Value::String(n.clone()),
185                                        t.cast_value_int(env, hist, a[1].clone())?,
186                                    ];
187                                    Ok(Value::Array(ValArray::from_iter_exact(
188                                        a.into_iter(),
189                                    )))
190                                }
191                                _ => unreachable!(),
192                            })
193                            .collect::<Result<LPooled<Vec<Value>>>>()?;
194                        Ok(Value::Array(ValArray::from_iter_exact(elts.drain(..))))
195                    } else {
196                        drop(elts_s);
197                        bail!("struct fields mismatch {self}, {}", Value::Array(elts))
198                    }
199                }
200                v => bail!("can't cast {v} to {self}"),
201            },
202            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
203                Value::String(s) if s == tag => Ok(v),
204                _ => bail!("variant tag mismatch expected {tag} got {v}"),
205            },
206            Type::Variant(tag, ts) => match &v {
207                Value::Array(elts) => {
208                    if ts.len() + 1 == elts.len() {
209                        match &elts[0] {
210                            Value::String(s) if s == tag => (),
211                            v => bail!("variant tag mismatch expected {tag} got {v}"),
212                        }
213                        let mut a = iter::once(&Type::Primitive(Typ::String.into()))
214                            .chain(ts.iter())
215                            .zip(elts.iter())
216                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
217                            .collect::<Result<LPooled<Vec<Value>>>>()?;
218                        Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
219                    } else if ts.len() == elts.len() {
220                        let mut a = ts
221                            .iter()
222                            .zip(elts.iter())
223                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
224                            .collect::<Result<LPooled<Vec<Value>>>>()?;
225                        a.insert(0, Value::String(tag.clone()));
226                        Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
227                    } else {
228                        bail!("variant length mismatch")
229                    }
230                }
231                v => bail!("can't cast {v} to {self}"),
232            },
233            Type::Ref { .. } => {
234                let t = self.lookup_ref(env)?;
235                t.cast_value_int(env, hist, v)
236            }
237            Type::Set(ts) => ts
238                .iter()
239                .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
240                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
241            Type::TVar(tv) => match &*tv.read().typ.read() {
242                Some(t) => t.cast_value_int(env, hist, v.clone()),
243                None => Ok(v),
244            },
245        }
246    }
247
248    pub fn cast_value(&self, env: &Env, v: Value) -> Value {
249        match self.cast_value_int(env, &mut LPooled::take(), v) {
250            Ok(v) => v,
251            Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
252        }
253    }
254
255    fn is_a_int(
256        &self,
257        env: &Env,
258        hist: &mut FxHashSet<(usize, usize)>,
259        v: &Value,
260    ) -> bool {
261        match self {
262            Type::Ref { .. } => match self.lookup_ref(env) {
263                Err(_) => false,
264                Ok(t) => {
265                    let t_addr = (&t as *const Type).addr();
266                    let v_addr = (v as *const Value).addr();
267                    !hist.contains(&(t_addr, v_addr)) && {
268                        hist.insert((t_addr, v_addr));
269                        t.is_a_int(env, hist, v)
270                    }
271                }
272            },
273            Type::Primitive(t) => t.contains(Typ::get(&v)),
274            Type::Abstract { .. } => false,
275            Type::Any => true,
276            Type::Array(et) => match v {
277                Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, v)),
278                _ => false,
279            },
280            Type::Map { key, value } => match v {
281                Value::Map(m) => m.into_iter().all(|(k, v)| {
282                    key.is_a_int(env, hist, k) && value.is_a_int(env, hist, v)
283                }),
284                _ => false,
285            },
286            Type::Error(e) => match v {
287                Value::Error(v) => e.is_a_int(env, hist, v),
288                _ => false,
289            },
290            Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
291            Type::Tuple(ts) => match v {
292                Value::Array(elts) => {
293                    elts.len() == ts.len()
294                        && ts
295                            .iter()
296                            .zip(elts.iter())
297                            .all(|(t, v)| t.is_a_int(env, hist, v))
298                }
299                _ => false,
300            },
301            Type::Struct(ts) => match v {
302                Value::Array(elts) => {
303                    elts.len() == ts.len()
304                        && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
305                            Value::Array(a) if a.len() == 2 => match &a[..] {
306                                [Value::String(key), v] => {
307                                    n == key && t.is_a_int(env, hist, v)
308                                }
309                                _ => false,
310                            },
311                            _ => false,
312                        })
313                }
314                _ => false,
315            },
316            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
317                Value::String(s) => s == tag,
318                _ => false,
319            },
320            Type::Variant(tag, ts) => match &v {
321                Value::Array(elts) => {
322                    ts.len() + 1 == elts.len()
323                        && match &elts[0] {
324                            Value::String(s) => s == tag,
325                            _ => false,
326                        }
327                        && ts
328                            .iter()
329                            .zip(elts[1..].iter())
330                            .all(|(t, v)| t.is_a_int(env, hist, v))
331                }
332                _ => false,
333            },
334            Type::TVar(tv) => match &*tv.read().typ.read() {
335                None => true,
336                Some(t) => t.is_a_int(env, hist, v),
337            },
338            Type::Fn(_) => match v {
339                Value::Abstract(a) if AbstractTypeRegistry::is_a(a, "lambda") => true,
340                _ => false,
341            },
342            Type::Bottom => true,
343            Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, v)),
344        }
345    }
346
347    /// return true if v is structurally compatible with the type
348    pub fn is_a(&self, env: &Env, v: &Value) -> bool {
349        self.is_a_int(env, &mut LPooled::take(), v)
350    }
351}