1use crate::{
2 env::Env,
3 errf,
4 typ::{RefHist, Type},
5 AbstractTypeRegistry, CAST_ERR_TAG,
6};
7use anyhow::{anyhow, bail, Result};
8use arcstr::ArcStr;
9use enumflags2::{bitflags, BitFlags};
10use fxhash::FxHashSet;
11use immutable_chunkmap::map::Map;
12use netidx::publisher::{Typ, Value};
13use netidx_value::ValArray;
14use poolshark::local::LPooled;
15use std::iter;
16use triomphe::Arc;
17
18#[derive(Debug, Clone, Copy)]
19#[bitflags]
20#[repr(u8)]
21pub enum IsAFlags {
22 MatchAbstract,
24}
25
26impl Type {
27 fn check_cast_int(
28 &self,
29 env: &Env,
30 hist: &mut RefHist<FxHashSet<Option<usize>>>,
31 ) -> Result<()> {
32 match self {
33 Type::Primitive(_) | Type::Any => Ok(()),
34 Type::Fn(_) => bail!("can't cast a value to a function"),
35 Type::Bottom => bail!("can't cast a value to bottom"),
36 Type::Set(s) | Type::Abstract { id: _, params: s } => Ok(for t in s.iter() {
37 t.check_cast_int(env, hist)?
38 }),
39 Type::TVar(tv) => match &*tv.read().typ.read() {
40 Some(t) => t.check_cast_int(env, hist),
41 None => bail!("can't cast a value to a free type variable"),
42 },
43 Type::Error(e) => e.check_cast_int(env, hist),
44 Type::Array(et) => et.check_cast_int(env, hist),
45 Type::Map { key, value } => {
46 key.check_cast_int(env, hist)?;
47 value.check_cast_int(env, hist)
48 }
49 Type::ByRef(_) => bail!("can't cast a reference"),
50 Type::Tuple(ts) => Ok(for t in ts.iter() {
51 t.check_cast_int(env, hist)?
52 }),
53 Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
54 t.check_cast_int(env, hist)?
55 }),
56 Type::Variant(_, ts) => Ok(for t in ts.iter() {
57 t.check_cast_int(env, hist)?
58 }),
59 Type::Ref { .. } => {
60 let id = hist.ref_id(self, env);
61 let t = self.lookup_ref(env)?;
62 if hist.contains(&id) {
63 Ok(())
64 } else {
65 hist.insert(id);
66 t.check_cast_int(env, hist)
67 }
68 }
69 }
70 }
71
72 pub fn check_cast(&self, env: &Env) -> Result<()> {
73 self.check_cast_int(env, &mut RefHist::new(LPooled::take()))
74 }
75
76 fn cast_value_int(
77 &self,
78 env: &Env,
79 hist: &mut FxHashSet<(usize, usize)>,
80 v: Value,
81 ) -> Result<Value> {
82 if self.is_a_int(env, hist, BitFlags::empty(), &v) {
83 return Ok(v);
84 }
85 match self {
86 Type::Bottom => bail!("can't cast {v} to Bottom"),
87 Type::Fn(_) => bail!("can't cast {v} to a function"),
88 Type::Abstract { id: _, params: _ } => {
89 bail!("can't cast {v} to an abstract type")
90 }
91 Type::ByRef(_) => bail!("can't cast {v} to a reference"),
92 Type::Primitive(s) => s
93 .iter()
94 .find_map(|t| v.clone().cast(t))
95 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
96 Type::Any => Ok(v),
97 Type::Error(e) => {
98 let v = match v {
99 Value::Error(v) => (*v).clone(),
100 v => v,
101 };
102 Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
103 }
104 Type::Array(et) => match v {
105 Value::Array(elts) => {
106 let mut va = elts
107 .iter()
108 .map(|el| et.cast_value_int(env, hist, el.clone()))
109 .collect::<Result<LPooled<Vec<Value>>>>()?;
110 Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
111 }
112 v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
113 },
114 Type::Map { key, value } => match v {
115 Value::Map(m) => {
116 let mut m = m
117 .into_iter()
118 .map(|(k, v)| {
119 Ok((
120 key.cast_value_int(env, hist, k.clone())?,
121 value.cast_value_int(env, hist, v.clone())?,
122 ))
123 })
124 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
125 Ok(Value::Map(Map::from_iter(m.drain(..))))
126 }
127 Value::Array(a) => {
128 let mut m = a
129 .iter()
130 .map(|a| match a {
131 Value::Array(a) if a.len() == 2 => Ok((
132 key.cast_value_int(env, hist, a[0].clone())?,
133 value.cast_value_int(env, hist, a[1].clone())?,
134 )),
135 _ => bail!("expected an array of pairs"),
136 })
137 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
138 Ok(Value::Map(Map::from_iter(m.drain(..))))
139 }
140 _ => bail!("can't cast {v} to {self}"),
141 },
142 Type::Tuple(ts) => match v {
143 Value::Array(elts) => {
144 if elts.len() != ts.len() {
145 bail!("tuple size mismatch {self} with {}", Value::Array(elts))
146 }
147 let mut a = ts
148 .iter()
149 .zip(elts.iter())
150 .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
151 .collect::<Result<LPooled<Vec<Value>>>>()?;
152 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
153 }
154 v => bail!("can't cast {v} to {self}"),
155 },
156 Type::Struct(ts) => match v {
157 Value::Array(elts) => {
158 if elts.len() != ts.len() {
159 bail!("struct size mismatch {self} with {}", Value::Array(elts))
160 }
161 let is_pairs = elts.iter().all(|v| match v {
162 Value::Array(a) if a.len() == 2 => match &a[0] {
163 Value::String(_) => true,
164 _ => false,
165 },
166 _ => false,
167 });
168 if !is_pairs {
169 bail!("expected array of pairs, got {}", Value::Array(elts))
170 }
171 let mut elts_s: LPooled<Vec<&Value>> = elts.iter().collect();
172 elts_s.sort_by_key(|v| match v {
173 Value::Array(a) => match &a[0] {
174 Value::String(s) => s,
175 _ => unreachable!(),
176 },
177 _ => unreachable!(),
178 });
179 let keys_ok = ts.iter().zip(elts_s.iter()).fold(
180 Ok(true),
181 |acc: Result<_>, ((fname, t), v)| {
182 let kok = acc?;
183 let (name, v) = match v {
184 Value::Array(a) => match (&a[0], &a[1]) {
185 (Value::String(n), v) => (n, v),
186 _ => unreachable!(),
187 },
188 _ => unreachable!(),
189 };
190 Ok(kok
191 && name == fname
192 && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
193 },
194 )?;
195 if keys_ok {
196 let mut elts = ts
197 .iter()
198 .zip(elts_s.iter())
199 .map(|((n, t), v)| match v {
200 Value::Array(a) => {
201 let a = [
202 Value::String(n.clone()),
203 t.cast_value_int(env, hist, a[1].clone())?,
204 ];
205 Ok(Value::Array(ValArray::from_iter_exact(
206 a.into_iter(),
207 )))
208 }
209 _ => unreachable!(),
210 })
211 .collect::<Result<LPooled<Vec<Value>>>>()?;
212 Ok(Value::Array(ValArray::from_iter_exact(elts.drain(..))))
213 } else {
214 drop(elts_s);
215 bail!("struct fields mismatch {self}, {}", Value::Array(elts))
216 }
217 }
218 v => bail!("can't cast {v} to {self}"),
219 },
220 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
221 Value::String(s) if s == tag => Ok(v),
222 _ => bail!("variant tag mismatch expected {tag} got {v}"),
223 },
224 Type::Variant(tag, ts) => match &v {
225 Value::Array(elts) => {
226 if ts.len() + 1 == elts.len() {
227 match &elts[0] {
228 Value::String(s) if s == tag => (),
229 v => bail!("variant tag mismatch expected {tag} got {v}"),
230 }
231 let mut a = iter::once(&Type::Primitive(Typ::String.into()))
232 .chain(ts.iter())
233 .zip(elts.iter())
234 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
235 .collect::<Result<LPooled<Vec<Value>>>>()?;
236 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
237 } else if ts.len() == elts.len() {
238 let mut a = ts
239 .iter()
240 .zip(elts.iter())
241 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
242 .collect::<Result<LPooled<Vec<Value>>>>()?;
243 a.insert(0, Value::String(tag.clone()));
244 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
245 } else {
246 bail!("variant length mismatch")
247 }
248 }
249 v => bail!("can't cast {v} to {self}"),
250 },
251 Type::Ref { .. } => {
252 let t = self.lookup_ref(env)?;
253 t.cast_value_int(env, hist, v)
254 }
255 Type::Set(ts) => ts
256 .iter()
257 .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
258 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
259 Type::TVar(tv) => match &*tv.read().typ.read() {
260 Some(t) => t.cast_value_int(env, hist, v.clone()),
261 None => Ok(v),
262 },
263 }
264 }
265
266 pub fn cast_value(&self, env: &Env, v: Value) -> Value {
267 match self.cast_value_int(env, &mut LPooled::take(), v) {
268 Ok(v) => v,
269 Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
270 }
271 }
272
273 fn is_a_int(
274 &self,
275 env: &Env,
276 hist: &mut FxHashSet<(usize, usize)>,
277 flags: BitFlags<IsAFlags>,
278 v: &Value,
279 ) -> bool {
280 match self {
281 Type::Ref { .. } => match self.lookup_ref(env) {
282 Err(_) => false,
283 Ok(t) => {
284 let t_addr = (&t as *const Type).addr();
285 let v_addr = (v as *const Value).addr();
286 !hist.contains(&(t_addr, v_addr)) && {
287 hist.insert((t_addr, v_addr));
288 t.is_a_int(env, hist, flags, v)
289 }
290 }
291 },
292 Type::Primitive(t) => t.contains(Typ::get(&v)),
293 Type::Abstract { .. } => {
294 flags.contains(IsAFlags::MatchAbstract) && matches!(v, Value::Abstract(_))
295 }
296 Type::Any => true,
297 Type::Array(et) => match v {
298 Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, flags, v)),
299 _ => false,
300 },
301 Type::Map { key, value } => match v {
302 Value::Map(m) => m.into_iter().all(|(k, v)| {
303 key.is_a_int(env, hist, flags, k)
304 && value.is_a_int(env, hist, flags, v)
305 }),
306 _ => false,
307 },
308 Type::Error(e) => match v {
309 Value::Error(v) => e.is_a_int(env, hist, flags, v),
310 _ => false,
311 },
312 Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
313 Type::Tuple(ts) => match v {
314 Value::Array(elts) => {
315 elts.len() == ts.len()
316 && ts
317 .iter()
318 .zip(elts.iter())
319 .all(|(t, v)| t.is_a_int(env, hist, flags, v))
320 }
321 _ => false,
322 },
323 Type::Struct(ts) => match v {
324 Value::Array(elts) => {
325 elts.len() == ts.len()
326 && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
327 Value::Array(a) if a.len() == 2 => match &a[..] {
328 [Value::String(key), v] => {
329 n == key && t.is_a_int(env, hist, flags, v)
330 }
331 _ => false,
332 },
333 _ => false,
334 })
335 }
336 _ => false,
337 },
338 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
339 Value::String(s) => s == tag,
340 _ => false,
341 },
342 Type::Variant(tag, ts) => match &v {
343 Value::Array(elts) => {
344 ts.len() + 1 == elts.len()
345 && match &elts[0] {
346 Value::String(s) => s == tag,
347 _ => false,
348 }
349 && ts
350 .iter()
351 .zip(elts[1..].iter())
352 .all(|(t, v)| t.is_a_int(env, hist, flags, v))
353 }
354 _ => false,
355 },
356 Type::TVar(tv) => match &*tv.read().typ.read() {
357 None => true,
358 Some(t) => t.is_a_int(env, hist, flags, v),
359 },
360 Type::Fn(_) => match v {
361 Value::Abstract(a) if AbstractTypeRegistry::is_a(a, "lambda") => true,
362 _ => false,
363 },
364 Type::Bottom => true,
365 Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, flags, v)),
366 }
367 }
368
369 pub fn is_a(&self, env: &Env, v: &Value) -> bool {
371 self.is_a_int(env, &mut LPooled::take(), BitFlags::empty(), v)
372 }
373
374 pub fn is_a_with(&self, env: &Env, flags: BitFlags<IsAFlags>, v: &Value) -> bool {
376 self.is_a_int(env, &mut LPooled::take(), flags, v)
377 }
378}