Skip to main content

graphix_compiler/typ/
cast.rs

1use crate::{env::Env, errf, typ::Type, CAST_ERR_TAG};
2use anyhow::{anyhow, bail, Result};
3use arcstr::ArcStr;
4use fxhash::FxHashSet;
5use immutable_chunkmap::map::Map;
6use netidx::publisher::{Typ, Value};
7use netidx_value::ValArray;
8use poolshark::local::LPooled;
9use std::iter;
10use triomphe::Arc;
11
12impl Type {
13    fn check_cast_int(&self, env: &Env, hist: &mut FxHashSet<usize>) -> Result<()> {
14        match self {
15            Type::Primitive(_) | Type::Any => Ok(()),
16            Type::Fn(_) => bail!("can't cast a value to a function"),
17            Type::Bottom => bail!("can't cast a value to bottom"),
18            Type::Set(s) | Type::Abstract { id: _, params: s } => Ok(for t in s.iter() {
19                t.check_cast_int(env, hist)?
20            }),
21            Type::TVar(tv) => match &*tv.read().typ.read() {
22                Some(t) => t.check_cast_int(env, hist),
23                None => bail!("can't cast a value to a free type variable"),
24            },
25            Type::Error(e) => e.check_cast_int(env, hist),
26            Type::Array(et) => et.check_cast_int(env, hist),
27            Type::Map { key, value } => {
28                key.check_cast_int(env, hist)?;
29                value.check_cast_int(env, hist)
30            }
31            Type::ByRef(_) => bail!("can't cast a reference"),
32            Type::Tuple(ts) => Ok(for t in ts.iter() {
33                t.check_cast_int(env, hist)?
34            }),
35            Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
36                t.check_cast_int(env, hist)?
37            }),
38            Type::Variant(_, ts) => Ok(for t in ts.iter() {
39                t.check_cast_int(env, hist)?
40            }),
41            Type::Ref { .. } => {
42                let t = self.lookup_ref(env)?;
43                let t_addr = (t as *const Type).addr();
44                if hist.contains(&t_addr) {
45                    Ok(())
46                } else {
47                    hist.insert(t_addr);
48                    t.check_cast_int(env, hist)
49                }
50            }
51        }
52    }
53
54    pub fn check_cast(&self, env: &Env) -> Result<()> {
55        self.check_cast_int(env, &mut LPooled::take())
56    }
57
58    fn cast_value_int(
59        &self,
60        env: &Env,
61        hist: &mut FxHashSet<(usize, usize)>,
62        v: Value,
63    ) -> Result<Value> {
64        if self.is_a_int(env, hist, &v) {
65            return Ok(v);
66        }
67        match self {
68            Type::Bottom => bail!("can't cast {v} to Bottom"),
69            Type::Fn(_) => bail!("can't cast {v} to a function"),
70            Type::Abstract { id: _, params: _ } => {
71                bail!("can't cast {v} to an abstract type")
72            }
73            Type::ByRef(_) => bail!("can't cast {v} to a reference"),
74            Type::Primitive(s) => s
75                .iter()
76                .find_map(|t| v.clone().cast(t))
77                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
78            Type::Any => Ok(v),
79            Type::Error(e) => {
80                let v = match v {
81                    Value::Error(v) => (*v).clone(),
82                    v => v,
83                };
84                Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
85            }
86            Type::Array(et) => match v {
87                Value::Array(elts) => {
88                    let mut va = elts
89                        .iter()
90                        .map(|el| et.cast_value_int(env, hist, el.clone()))
91                        .collect::<Result<LPooled<Vec<Value>>>>()?;
92                    Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
93                }
94                v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
95            },
96            Type::Map { key, value } => match v {
97                Value::Map(m) => {
98                    let mut m = m
99                        .into_iter()
100                        .map(|(k, v)| {
101                            Ok((
102                                key.cast_value_int(env, hist, k.clone())?,
103                                value.cast_value_int(env, hist, v.clone())?,
104                            ))
105                        })
106                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
107                    Ok(Value::Map(Map::from_iter(m.drain(..))))
108                }
109                Value::Array(a) => {
110                    let mut m = a
111                        .iter()
112                        .map(|a| match a {
113                            Value::Array(a) if a.len() == 2 => Ok((
114                                key.cast_value_int(env, hist, a[0].clone())?,
115                                value.cast_value_int(env, hist, a[1].clone())?,
116                            )),
117                            _ => bail!("expected an array of pairs"),
118                        })
119                        .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
120                    Ok(Value::Map(Map::from_iter(m.drain(..))))
121                }
122                _ => bail!("can't cast {v} to {self}"),
123            },
124            Type::Tuple(ts) => match v {
125                Value::Array(elts) => {
126                    if elts.len() != ts.len() {
127                        bail!("tuple size mismatch {self} with {}", Value::Array(elts))
128                    }
129                    let mut a = ts
130                        .iter()
131                        .zip(elts.iter())
132                        .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
133                        .collect::<Result<LPooled<Vec<Value>>>>()?;
134                    Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
135                }
136                v => bail!("can't cast {v} to {self}"),
137            },
138            Type::Struct(ts) => match v {
139                Value::Array(elts) => {
140                    if elts.len() != ts.len() {
141                        bail!("struct size mismatch {self} with {}", Value::Array(elts))
142                    }
143                    let is_pairs = elts.iter().all(|v| match v {
144                        Value::Array(a) if a.len() == 2 => match &a[0] {
145                            Value::String(_) => true,
146                            _ => false,
147                        },
148                        _ => false,
149                    });
150                    if !is_pairs {
151                        bail!("expected array of pairs, got {}", Value::Array(elts))
152                    }
153                    let mut elts_s: LPooled<Vec<&Value>> = elts.iter().collect();
154                    elts_s.sort_by_key(|v| match v {
155                        Value::Array(a) => match &a[0] {
156                            Value::String(s) => s,
157                            _ => unreachable!(),
158                        },
159                        _ => unreachable!(),
160                    });
161                    let keys_ok = ts.iter().zip(elts_s.iter()).fold(
162                        Ok(true),
163                        |acc: Result<_>, ((fname, t), v)| {
164                            let kok = acc?;
165                            let (name, v) = match v {
166                                Value::Array(a) => match (&a[0], &a[1]) {
167                                    (Value::String(n), v) => (n, v),
168                                    _ => unreachable!(),
169                                },
170                                _ => unreachable!(),
171                            };
172                            Ok(kok
173                                && name == fname
174                                && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
175                        },
176                    )?;
177                    if keys_ok {
178                        let mut elts = ts
179                            .iter()
180                            .zip(elts_s.iter())
181                            .map(|((n, t), v)| match v {
182                                Value::Array(a) => {
183                                    let a = [
184                                        Value::String(n.clone()),
185                                        t.cast_value_int(env, hist, a[1].clone())?,
186                                    ];
187                                    Ok(Value::Array(ValArray::from_iter_exact(
188                                        a.into_iter(),
189                                    )))
190                                }
191                                _ => unreachable!(),
192                            })
193                            .collect::<Result<LPooled<Vec<Value>>>>()?;
194                        Ok(Value::Array(ValArray::from_iter_exact(elts.drain(..))))
195                    } else {
196                        drop(elts_s);
197                        bail!("struct fields mismatch {self}, {}", Value::Array(elts))
198                    }
199                }
200                v => bail!("can't cast {v} to {self}"),
201            },
202            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
203                Value::String(s) if s == tag => Ok(v),
204                _ => bail!("variant tag mismatch expected {tag} got {v}"),
205            },
206            Type::Variant(tag, ts) => match &v {
207                Value::Array(elts) => {
208                    if ts.len() + 1 == elts.len() {
209                        match &elts[0] {
210                            Value::String(s) if s == tag => (),
211                            v => bail!("variant tag mismatch expected {tag} got {v}"),
212                        }
213                        let mut a = iter::once(&Type::Primitive(Typ::String.into()))
214                            .chain(ts.iter())
215                            .zip(elts.iter())
216                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
217                            .collect::<Result<LPooled<Vec<Value>>>>()?;
218                        Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
219                    } else if ts.len() == elts.len() {
220                        let mut a = ts
221                            .iter()
222                            .zip(elts.iter())
223                            .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
224                            .collect::<Result<LPooled<Vec<Value>>>>()?;
225                        a.insert(0, Value::String(tag.clone()));
226                        Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
227                    } else {
228                        bail!("variant length mismatch")
229                    }
230                }
231                v => bail!("can't cast {v} to {self}"),
232            },
233            Type::Ref { .. } => self.lookup_ref(env)?.cast_value_int(env, hist, v),
234            Type::Set(ts) => ts
235                .iter()
236                .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
237                .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
238            Type::TVar(tv) => match &*tv.read().typ.read() {
239                Some(t) => t.cast_value_int(env, hist, v.clone()),
240                None => Ok(v),
241            },
242        }
243    }
244
245    pub fn cast_value(&self, env: &Env, v: Value) -> Value {
246        match self.cast_value_int(env, &mut LPooled::take(), v) {
247            Ok(v) => v,
248            Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
249        }
250    }
251
252    fn is_a_int(
253        &self,
254        env: &Env,
255        hist: &mut FxHashSet<(usize, usize)>,
256        v: &Value,
257    ) -> bool {
258        match self {
259            Type::Ref { .. } => match self.lookup_ref(env) {
260                Err(_) => false,
261                Ok(t) => {
262                    let t_addr = (t as *const Type).addr();
263                    let v_addr = (v as *const Value).addr();
264                    !hist.contains(&(t_addr, v_addr)) && {
265                        hist.insert((t_addr, v_addr));
266                        t.is_a_int(env, hist, v)
267                    }
268                }
269            },
270            Type::Primitive(t) => t.contains(Typ::get(&v)),
271            Type::Abstract { .. } => false,
272            Type::Any => true,
273            Type::Array(et) => match v {
274                Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, v)),
275                _ => false,
276            },
277            Type::Map { key, value } => match v {
278                Value::Map(m) => m.into_iter().all(|(k, v)| {
279                    key.is_a_int(env, hist, k) && value.is_a_int(env, hist, v)
280                }),
281                _ => false,
282            },
283            Type::Error(e) => match v {
284                Value::Error(v) => e.is_a_int(env, hist, v),
285                _ => false,
286            },
287            Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
288            Type::Tuple(ts) => match v {
289                Value::Array(elts) => {
290                    elts.len() == ts.len()
291                        && ts
292                            .iter()
293                            .zip(elts.iter())
294                            .all(|(t, v)| t.is_a_int(env, hist, v))
295                }
296                _ => false,
297            },
298            Type::Struct(ts) => match v {
299                Value::Array(elts) => {
300                    elts.len() == ts.len()
301                        && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
302                            Value::Array(a) if a.len() == 2 => match &a[..] {
303                                [Value::String(key), v] => {
304                                    n == key && t.is_a_int(env, hist, v)
305                                }
306                                _ => false,
307                            },
308                            _ => false,
309                        })
310                }
311                _ => false,
312            },
313            Type::Variant(tag, ts) if ts.len() == 0 => match &v {
314                Value::String(s) => s == tag,
315                _ => false,
316            },
317            Type::Variant(tag, ts) => match &v {
318                Value::Array(elts) => {
319                    ts.len() + 1 == elts.len()
320                        && match &elts[0] {
321                            Value::String(s) => s == tag,
322                            _ => false,
323                        }
324                        && ts
325                            .iter()
326                            .zip(elts[1..].iter())
327                            .all(|(t, v)| t.is_a_int(env, hist, v))
328                }
329                _ => false,
330            },
331            Type::TVar(tv) => match &*tv.read().typ.read() {
332                None => true,
333                Some(t) => t.is_a_int(env, hist, v),
334            },
335            Type::Fn(_) => match v {
336                Value::U64(_) => true,
337                _ => false,
338            },
339            Type::Bottom => true,
340            Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, v)),
341        }
342    }
343
344    /// return true if v is structurally compatible with the type
345    pub fn is_a(&self, env: &Env, v: &Value) -> bool {
346        self.is_a_int(env, &mut LPooled::take(), v)
347    }
348}