1use crate::{env::Env, errf, typ::{RefHist, Type}, AbstractTypeRegistry, CAST_ERR_TAG};
2use anyhow::{anyhow, bail, Result};
3use arcstr::ArcStr;
4use fxhash::FxHashSet;
5use immutable_chunkmap::map::Map;
6use netidx::publisher::{Typ, Value};
7use netidx_value::ValArray;
8use poolshark::local::LPooled;
9use std::iter;
10use triomphe::Arc;
11
12impl Type {
13 fn check_cast_int(&self, env: &Env, hist: &mut RefHist<FxHashSet<Option<usize>>>) -> Result<()> {
14 match self {
15 Type::Primitive(_) | Type::Any => Ok(()),
16 Type::Fn(_) => bail!("can't cast a value to a function"),
17 Type::Bottom => bail!("can't cast a value to bottom"),
18 Type::Set(s) | Type::Abstract { id: _, params: s } => Ok(for t in s.iter() {
19 t.check_cast_int(env, hist)?
20 }),
21 Type::TVar(tv) => match &*tv.read().typ.read() {
22 Some(t) => t.check_cast_int(env, hist),
23 None => bail!("can't cast a value to a free type variable"),
24 },
25 Type::Error(e) => e.check_cast_int(env, hist),
26 Type::Array(et) => et.check_cast_int(env, hist),
27 Type::Map { key, value } => {
28 key.check_cast_int(env, hist)?;
29 value.check_cast_int(env, hist)
30 }
31 Type::ByRef(_) => bail!("can't cast a reference"),
32 Type::Tuple(ts) => Ok(for t in ts.iter() {
33 t.check_cast_int(env, hist)?
34 }),
35 Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
36 t.check_cast_int(env, hist)?
37 }),
38 Type::Variant(_, ts) => Ok(for t in ts.iter() {
39 t.check_cast_int(env, hist)?
40 }),
41 Type::Ref { .. } => {
42 let id = hist.ref_id(self, env);
43 let t = self.lookup_ref(env)?;
44 if hist.contains(&id) {
45 Ok(())
46 } else {
47 hist.insert(id);
48 t.check_cast_int(env, hist)
49 }
50 }
51 }
52 }
53
54 pub fn check_cast(&self, env: &Env) -> Result<()> {
55 self.check_cast_int(env, &mut RefHist::new(LPooled::take()))
56 }
57
58 fn cast_value_int(
59 &self,
60 env: &Env,
61 hist: &mut FxHashSet<(usize, usize)>,
62 v: Value,
63 ) -> Result<Value> {
64 if self.is_a_int(env, hist, &v) {
65 return Ok(v);
66 }
67 match self {
68 Type::Bottom => bail!("can't cast {v} to Bottom"),
69 Type::Fn(_) => bail!("can't cast {v} to a function"),
70 Type::Abstract { id: _, params: _ } => {
71 bail!("can't cast {v} to an abstract type")
72 }
73 Type::ByRef(_) => bail!("can't cast {v} to a reference"),
74 Type::Primitive(s) => s
75 .iter()
76 .find_map(|t| v.clone().cast(t))
77 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
78 Type::Any => Ok(v),
79 Type::Error(e) => {
80 let v = match v {
81 Value::Error(v) => (*v).clone(),
82 v => v,
83 };
84 Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
85 }
86 Type::Array(et) => match v {
87 Value::Array(elts) => {
88 let mut va = elts
89 .iter()
90 .map(|el| et.cast_value_int(env, hist, el.clone()))
91 .collect::<Result<LPooled<Vec<Value>>>>()?;
92 Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
93 }
94 v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
95 },
96 Type::Map { key, value } => match v {
97 Value::Map(m) => {
98 let mut m = m
99 .into_iter()
100 .map(|(k, v)| {
101 Ok((
102 key.cast_value_int(env, hist, k.clone())?,
103 value.cast_value_int(env, hist, v.clone())?,
104 ))
105 })
106 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
107 Ok(Value::Map(Map::from_iter(m.drain(..))))
108 }
109 Value::Array(a) => {
110 let mut m = a
111 .iter()
112 .map(|a| match a {
113 Value::Array(a) if a.len() == 2 => Ok((
114 key.cast_value_int(env, hist, a[0].clone())?,
115 value.cast_value_int(env, hist, a[1].clone())?,
116 )),
117 _ => bail!("expected an array of pairs"),
118 })
119 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
120 Ok(Value::Map(Map::from_iter(m.drain(..))))
121 }
122 _ => bail!("can't cast {v} to {self}"),
123 },
124 Type::Tuple(ts) => match v {
125 Value::Array(elts) => {
126 if elts.len() != ts.len() {
127 bail!("tuple size mismatch {self} with {}", Value::Array(elts))
128 }
129 let mut a = ts
130 .iter()
131 .zip(elts.iter())
132 .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
133 .collect::<Result<LPooled<Vec<Value>>>>()?;
134 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
135 }
136 v => bail!("can't cast {v} to {self}"),
137 },
138 Type::Struct(ts) => match v {
139 Value::Array(elts) => {
140 if elts.len() != ts.len() {
141 bail!("struct size mismatch {self} with {}", Value::Array(elts))
142 }
143 let is_pairs = elts.iter().all(|v| match v {
144 Value::Array(a) if a.len() == 2 => match &a[0] {
145 Value::String(_) => true,
146 _ => false,
147 },
148 _ => false,
149 });
150 if !is_pairs {
151 bail!("expected array of pairs, got {}", Value::Array(elts))
152 }
153 let mut elts_s: LPooled<Vec<&Value>> = elts.iter().collect();
154 elts_s.sort_by_key(|v| match v {
155 Value::Array(a) => match &a[0] {
156 Value::String(s) => s,
157 _ => unreachable!(),
158 },
159 _ => unreachable!(),
160 });
161 let keys_ok = ts.iter().zip(elts_s.iter()).fold(
162 Ok(true),
163 |acc: Result<_>, ((fname, t), v)| {
164 let kok = acc?;
165 let (name, v) = match v {
166 Value::Array(a) => match (&a[0], &a[1]) {
167 (Value::String(n), v) => (n, v),
168 _ => unreachable!(),
169 },
170 _ => unreachable!(),
171 };
172 Ok(kok
173 && name == fname
174 && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
175 },
176 )?;
177 if keys_ok {
178 let mut elts = ts
179 .iter()
180 .zip(elts_s.iter())
181 .map(|((n, t), v)| match v {
182 Value::Array(a) => {
183 let a = [
184 Value::String(n.clone()),
185 t.cast_value_int(env, hist, a[1].clone())?,
186 ];
187 Ok(Value::Array(ValArray::from_iter_exact(
188 a.into_iter(),
189 )))
190 }
191 _ => unreachable!(),
192 })
193 .collect::<Result<LPooled<Vec<Value>>>>()?;
194 Ok(Value::Array(ValArray::from_iter_exact(elts.drain(..))))
195 } else {
196 drop(elts_s);
197 bail!("struct fields mismatch {self}, {}", Value::Array(elts))
198 }
199 }
200 v => bail!("can't cast {v} to {self}"),
201 },
202 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
203 Value::String(s) if s == tag => Ok(v),
204 _ => bail!("variant tag mismatch expected {tag} got {v}"),
205 },
206 Type::Variant(tag, ts) => match &v {
207 Value::Array(elts) => {
208 if ts.len() + 1 == elts.len() {
209 match &elts[0] {
210 Value::String(s) if s == tag => (),
211 v => bail!("variant tag mismatch expected {tag} got {v}"),
212 }
213 let mut a = iter::once(&Type::Primitive(Typ::String.into()))
214 .chain(ts.iter())
215 .zip(elts.iter())
216 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
217 .collect::<Result<LPooled<Vec<Value>>>>()?;
218 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
219 } else if ts.len() == elts.len() {
220 let mut a = ts
221 .iter()
222 .zip(elts.iter())
223 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
224 .collect::<Result<LPooled<Vec<Value>>>>()?;
225 a.insert(0, Value::String(tag.clone()));
226 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
227 } else {
228 bail!("variant length mismatch")
229 }
230 }
231 v => bail!("can't cast {v} to {self}"),
232 },
233 Type::Ref { .. } => {
234 let t = self.lookup_ref(env)?;
235 t.cast_value_int(env, hist, v)
236 }
237 Type::Set(ts) => ts
238 .iter()
239 .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
240 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
241 Type::TVar(tv) => match &*tv.read().typ.read() {
242 Some(t) => t.cast_value_int(env, hist, v.clone()),
243 None => Ok(v),
244 },
245 }
246 }
247
248 pub fn cast_value(&self, env: &Env, v: Value) -> Value {
249 match self.cast_value_int(env, &mut LPooled::take(), v) {
250 Ok(v) => v,
251 Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
252 }
253 }
254
255 fn is_a_int(
256 &self,
257 env: &Env,
258 hist: &mut FxHashSet<(usize, usize)>,
259 v: &Value,
260 ) -> bool {
261 match self {
262 Type::Ref { .. } => match self.lookup_ref(env) {
263 Err(_) => false,
264 Ok(t) => {
265 let t_addr = (&t as *const Type).addr();
266 let v_addr = (v as *const Value).addr();
267 !hist.contains(&(t_addr, v_addr)) && {
268 hist.insert((t_addr, v_addr));
269 t.is_a_int(env, hist, v)
270 }
271 }
272 },
273 Type::Primitive(t) => t.contains(Typ::get(&v)),
274 Type::Abstract { .. } => false,
275 Type::Any => true,
276 Type::Array(et) => match v {
277 Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, v)),
278 _ => false,
279 },
280 Type::Map { key, value } => match v {
281 Value::Map(m) => m.into_iter().all(|(k, v)| {
282 key.is_a_int(env, hist, k) && value.is_a_int(env, hist, v)
283 }),
284 _ => false,
285 },
286 Type::Error(e) => match v {
287 Value::Error(v) => e.is_a_int(env, hist, v),
288 _ => false,
289 },
290 Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
291 Type::Tuple(ts) => match v {
292 Value::Array(elts) => {
293 elts.len() == ts.len()
294 && ts
295 .iter()
296 .zip(elts.iter())
297 .all(|(t, v)| t.is_a_int(env, hist, v))
298 }
299 _ => false,
300 },
301 Type::Struct(ts) => match v {
302 Value::Array(elts) => {
303 elts.len() == ts.len()
304 && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
305 Value::Array(a) if a.len() == 2 => match &a[..] {
306 [Value::String(key), v] => {
307 n == key && t.is_a_int(env, hist, v)
308 }
309 _ => false,
310 },
311 _ => false,
312 })
313 }
314 _ => false,
315 },
316 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
317 Value::String(s) => s == tag,
318 _ => false,
319 },
320 Type::Variant(tag, ts) => match &v {
321 Value::Array(elts) => {
322 ts.len() + 1 == elts.len()
323 && match &elts[0] {
324 Value::String(s) => s == tag,
325 _ => false,
326 }
327 && ts
328 .iter()
329 .zip(elts[1..].iter())
330 .all(|(t, v)| t.is_a_int(env, hist, v))
331 }
332 _ => false,
333 },
334 Type::TVar(tv) => match &*tv.read().typ.read() {
335 None => true,
336 Some(t) => t.is_a_int(env, hist, v),
337 },
338 Type::Fn(_) => match v {
339 Value::Abstract(a) if AbstractTypeRegistry::is_a(a, "lambda") => true,
340 _ => false,
341 },
342 Type::Bottom => true,
343 Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, v)),
344 }
345 }
346
347 pub fn is_a(&self, env: &Env, v: &Value) -> bool {
349 self.is_a_int(env, &mut LPooled::take(), v)
350 }
351}