1use crate::{
2 env::Env,
3 errf,
4 typ::{RefHist, Type, TypeRef},
5 AbstractTypeRegistry, CAST_ERR_TAG,
6};
7use ahash::AHashSet;
8use anyhow::{anyhow, bail, Result};
9use arcstr::ArcStr;
10use enumflags2::{bitflags, BitFlags};
11use immutable_chunkmap::map::Map;
12use netidx::publisher::{Typ, Value};
13use netidx_value::ValArray;
14use poolshark::local::LPooled;
15use std::iter;
16use triomphe::Arc;
17
18#[derive(Debug, Clone, Copy)]
19#[bitflags]
20#[repr(u8)]
21pub enum IsAFlags {
22 MatchAbstract,
24}
25
26impl Type {
27 fn check_cast_int(
28 &self,
29 env: &Env,
30 hist: &mut RefHist<AHashSet<Option<usize>>>,
31 ) -> Result<()> {
32 match self {
33 Type::Primitive(_) | Type::Any => Ok(()),
34 Type::Fn(_) => bail!("can't cast a value to a function"),
35 Type::Bottom => bail!("can't cast a value to bottom"),
36 Type::Set(s) | Type::Abstract { id: _, params: s } => Ok(for t in s.iter() {
37 t.check_cast_int(env, hist)?
38 }),
39 Type::TVar(tv) => match &*tv.read().typ.read() {
40 Some(t) => t.check_cast_int(env, hist),
41 None => bail!("can't cast a value to a free type variable"),
42 },
43 Type::Error(e) => e.check_cast_int(env, hist),
44 Type::Array(et) => et.check_cast_int(env, hist),
45 Type::Map { key, value } => {
46 key.check_cast_int(env, hist)?;
47 value.check_cast_int(env, hist)
48 }
49 Type::ByRef(_) => bail!("can't cast a reference"),
50 Type::Tuple(ts) => Ok(for t in ts.iter() {
51 t.check_cast_int(env, hist)?
52 }),
53 Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
54 t.check_cast_int(env, hist)?
55 }),
56 Type::Variant(_, ts) => Ok(for t in ts.iter() {
57 t.check_cast_int(env, hist)?
58 }),
59 Type::Ref(TypeRef { .. }) => {
60 let id = hist.ref_id(self, env);
61 let t = self.lookup_ref(env)?;
62 if hist.contains(&id) {
63 Ok(())
64 } else {
65 hist.insert(id);
66 t.check_cast_int(env, hist)
67 }
68 }
69 }
70 }
71
72 pub fn check_cast(&self, env: &Env) -> Result<()> {
73 self.check_cast_int(env, &mut RefHist::new(LPooled::take()))
74 }
75
76 fn cast_value_int(
77 &self,
78 env: &Env,
79 hist: &mut AHashSet<(usize, usize)>,
80 v: Value,
81 ) -> Result<Value> {
82 if self.is_a_int(env, hist, BitFlags::empty(), &v) {
83 return Ok(v);
84 }
85 match self {
86 Type::Bottom => bail!("can't cast {v} to Bottom"),
87 Type::Fn(_) => bail!("can't cast {v} to a function"),
88 Type::Abstract { id: _, params: _ } => {
89 bail!("can't cast {v} to an abstract type")
90 }
91 Type::ByRef(_) => bail!("can't cast {v} to a reference"),
92 Type::Primitive(s) => s
93 .iter()
94 .find_map(|t| v.clone().cast(t))
95 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
96 Type::Any => Ok(v),
97 Type::Error(e) => {
98 let v = match v {
99 Value::Error(v) => (*v).clone(),
100 v => v,
101 };
102 Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
103 }
104 Type::Array(et) => match v {
105 Value::Array(elts) => {
106 let mut va = elts
107 .iter()
108 .map(|el| et.cast_value_int(env, hist, el.clone()))
109 .collect::<Result<LPooled<Vec<Value>>>>()?;
110 Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
111 }
112 v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
113 },
114 Type::Map { key, value } => match v {
115 Value::Map(m) => {
116 let mut m = m
117 .into_iter()
118 .map(|(k, v)| {
119 Ok((
120 key.cast_value_int(env, hist, k.clone())?,
121 value.cast_value_int(env, hist, v.clone())?,
122 ))
123 })
124 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
125 Ok(Value::Map(Map::from_iter(m.drain(..))))
126 }
127 Value::Array(a) => {
128 let mut m = a
129 .iter()
130 .map(|a| match a {
131 Value::Array(a) if a.len() == 2 => Ok((
132 key.cast_value_int(env, hist, a[0].clone())?,
133 value.cast_value_int(env, hist, a[1].clone())?,
134 )),
135 _ => bail!("expected an array of pairs"),
136 })
137 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
138 Ok(Value::Map(Map::from_iter(m.drain(..))))
139 }
140 _ => bail!("can't cast {v} to {self}"),
141 },
142 Type::Tuple(ts) => match v {
143 Value::Array(elts) => {
144 if elts.len() != ts.len() {
145 bail!("tuple size mismatch {self} with {}", Value::Array(elts))
146 }
147 let mut a = ts
148 .iter()
149 .zip(elts.iter())
150 .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
151 .collect::<Result<LPooled<Vec<Value>>>>()?;
152 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
153 }
154 v => bail!("can't cast {v} to {self}"),
155 },
156 Type::Struct(ts) => match v {
157 Value::Array(elts) => {
158 if elts.len() != ts.len() {
159 bail!("struct size mismatch {self} with {}", Value::Array(elts))
160 }
161 let is_pairs = elts.iter().all(|v| match v {
162 Value::Array(a) if a.len() == 2 => match &a[0] {
163 Value::String(_) => true,
164 _ => false,
165 },
166 _ => false,
167 });
168 if !is_pairs {
169 bail!("expected array of pairs, got {}", Value::Array(elts))
170 }
171 let mut elts_s: LPooled<Vec<&Value>> = elts.iter().collect();
172 elts_s.sort_by_key(|v| match v {
173 Value::Array(a) => match &a[0] {
174 Value::String(s) => s,
175 _ => unreachable!(),
176 },
177 _ => unreachable!(),
178 });
179 let keys_ok = ts.iter().zip(elts_s.iter()).fold(
180 Ok(true),
181 |acc: Result<_>, ((fname, t), v)| {
182 let kok = acc?;
183 let (name, v) = match v {
184 Value::Array(a) => match (&a[0], &a[1]) {
185 (Value::String(n), v) => (n, v),
186 _ => unreachable!(),
187 },
188 _ => unreachable!(),
189 };
190 Ok(kok
191 && name == fname
192 && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
193 },
194 )?;
195 if keys_ok {
196 let mut elts = ts
197 .iter()
198 .zip(elts_s.iter())
199 .map(|((n, t), v)| match v {
200 Value::Array(a) => {
201 let a = [
202 Value::String(n.clone()),
203 t.cast_value_int(env, hist, a[1].clone())?,
204 ];
205 Ok(Value::Array(ValArray::from_iter_exact(
206 a.into_iter(),
207 )))
208 }
209 _ => unreachable!(),
210 })
211 .collect::<Result<LPooled<Vec<Value>>>>()?;
212 Ok(Value::Array(ValArray::from_iter_exact(elts.drain(..))))
213 } else {
214 drop(elts_s);
215 bail!("struct fields mismatch {self}, {}", Value::Array(elts))
216 }
217 }
218 v => bail!("can't cast {v} to {self}"),
219 },
220 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
221 Value::String(s) if s == tag => Ok(v),
222 _ => bail!("variant tag mismatch expected {tag} got {v}"),
223 },
224 Type::Variant(tag, ts) => match &v {
225 Value::Array(elts) => {
226 if ts.len() + 1 == elts.len() {
227 match &elts[0] {
228 Value::String(s) if s == tag => (),
229 v => bail!("variant tag mismatch expected {tag} got {v}"),
230 }
231 let mut a = iter::once(&Type::Primitive(Typ::String.into()))
232 .chain(ts.iter())
233 .zip(elts.iter())
234 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
235 .collect::<Result<LPooled<Vec<Value>>>>()?;
236 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
237 } else if ts.len() == elts.len() {
238 let mut a = ts
239 .iter()
240 .zip(elts.iter())
241 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
242 .collect::<Result<LPooled<Vec<Value>>>>()?;
243 a.insert(0, Value::String(tag.clone()));
244 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
245 } else {
246 bail!("variant length mismatch")
247 }
248 }
249 v => bail!("can't cast {v} to {self}"),
250 },
251 Type::Ref(TypeRef { .. }) => {
252 let t = self.lookup_ref(env)?;
253 t.cast_value_int(env, hist, v)
254 }
255 Type::Set(ts) => ts
256 .iter()
257 .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
258 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
259 Type::TVar(tv) => match &*tv.read().typ.read() {
260 Some(t) => t.cast_value_int(env, hist, v.clone()),
261 None => Ok(v),
262 },
263 }
264 }
265
266 pub fn cast_value(&self, env: &Env, v: Value) -> Value {
267 match self.cast_value_int(env, &mut LPooled::take(), v) {
268 Ok(v) => v,
269 Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
270 }
271 }
272
273 fn is_a_int(
274 &self,
275 env: &Env,
276 hist: &mut AHashSet<(usize, usize)>,
277 flags: BitFlags<IsAFlags>,
278 v: &Value,
279 ) -> bool {
280 match self {
281 Type::Ref(TypeRef { scope, name, .. }) => match self.lookup_ref(env) {
282 Err(_) => false,
283 Ok(t) => {
284 let t_addr = (scope.as_ref() as *const _ as *const u8).addr()
285 ^ (name.as_ref() as *const _ as *const u8).addr();
286 let v_addr = (v as *const Value).addr();
287 !hist.contains(&(t_addr, v_addr)) && {
288 hist.insert((t_addr, v_addr));
289 t.is_a_int(env, hist, flags, v)
290 }
291 }
292 },
293 Type::Primitive(t) => t.contains(Typ::get(&v)),
294 Type::Abstract { .. } => {
295 flags.contains(IsAFlags::MatchAbstract) && matches!(v, Value::Abstract(_))
296 }
297 Type::Any => true,
298 Type::Array(et) => match v {
299 Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, flags, v)),
300 _ => false,
301 },
302 Type::Map { key, value } => match v {
303 Value::Map(m) => m.into_iter().all(|(k, v)| {
304 key.is_a_int(env, hist, flags, k)
305 && value.is_a_int(env, hist, flags, v)
306 }),
307 _ => false,
308 },
309 Type::Error(e) => match v {
310 Value::Error(v) => e.is_a_int(env, hist, flags, v),
311 _ => false,
312 },
313 Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
314 Type::Tuple(ts) => match v {
315 Value::Array(elts) => {
316 elts.len() == ts.len()
317 && ts
318 .iter()
319 .zip(elts.iter())
320 .all(|(t, v)| t.is_a_int(env, hist, flags, v))
321 }
322 _ => false,
323 },
324 Type::Struct(ts) => match v {
325 Value::Array(elts) => {
326 elts.len() == ts.len()
327 && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
328 Value::Array(a) if a.len() == 2 => match &a[..] {
329 [Value::String(key), v] => {
330 n == key && t.is_a_int(env, hist, flags, v)
331 }
332 _ => false,
333 },
334 _ => false,
335 })
336 }
337 _ => false,
338 },
339 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
340 Value::String(s) => s == tag,
341 _ => false,
342 },
343 Type::Variant(tag, ts) => match &v {
344 Value::Array(elts) => {
345 ts.len() + 1 == elts.len()
346 && match &elts[0] {
347 Value::String(s) => s == tag,
348 _ => false,
349 }
350 && ts
351 .iter()
352 .zip(elts[1..].iter())
353 .all(|(t, v)| t.is_a_int(env, hist, flags, v))
354 }
355 _ => false,
356 },
357 Type::TVar(tv) => match &*tv.read().typ.read() {
358 None => true,
359 Some(t) => t.is_a_int(env, hist, flags, v),
360 },
361 Type::Fn(_) => match v {
362 Value::Abstract(a) if AbstractTypeRegistry::is_a(a, "lambda") => true,
363 _ => false,
364 },
365 Type::Bottom => true,
366 Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, flags, v)),
367 }
368 }
369
370 pub fn is_a(&self, env: &Env, v: &Value) -> bool {
372 self.is_a_int(env, &mut LPooled::take(), BitFlags::empty(), v)
373 }
374
375 pub fn is_a_with(&self, env: &Env, flags: BitFlags<IsAFlags>, v: &Value) -> bool {
377 self.is_a_int(env, &mut LPooled::take(), flags, v)
378 }
379}