1use crate::{env::Env, errf, typ::Type, CAST_ERR_TAG};
2use anyhow::{anyhow, bail, Result};
3use arcstr::ArcStr;
4use fxhash::FxHashSet;
5use immutable_chunkmap::map::Map;
6use netidx::publisher::{Typ, Value};
7use netidx_value::ValArray;
8use poolshark::local::LPooled;
9use std::iter;
10use triomphe::Arc;
11
12impl Type {
13 fn check_cast_int(&self, env: &Env, hist: &mut FxHashSet<usize>) -> Result<()> {
14 match self {
15 Type::Primitive(_) | Type::Any => Ok(()),
16 Type::Fn(_) => bail!("can't cast a value to a function"),
17 Type::Bottom => bail!("can't cast a value to bottom"),
18 Type::Set(s) | Type::Abstract { id: _, params: s } => Ok(for t in s.iter() {
19 t.check_cast_int(env, hist)?
20 }),
21 Type::TVar(tv) => match &*tv.read().typ.read() {
22 Some(t) => t.check_cast_int(env, hist),
23 None => bail!("can't cast a value to a free type variable"),
24 },
25 Type::Error(e) => e.check_cast_int(env, hist),
26 Type::Array(et) => et.check_cast_int(env, hist),
27 Type::Map { key, value } => {
28 key.check_cast_int(env, hist)?;
29 value.check_cast_int(env, hist)
30 }
31 Type::ByRef(_) => bail!("can't cast a reference"),
32 Type::Tuple(ts) => Ok(for t in ts.iter() {
33 t.check_cast_int(env, hist)?
34 }),
35 Type::Struct(ts) => Ok(for (_, t) in ts.iter() {
36 t.check_cast_int(env, hist)?
37 }),
38 Type::Variant(_, ts) => Ok(for t in ts.iter() {
39 t.check_cast_int(env, hist)?
40 }),
41 Type::Ref { .. } => {
42 let t = self.lookup_ref(env)?;
43 let t_addr = (t as *const Type).addr();
44 if hist.contains(&t_addr) {
45 Ok(())
46 } else {
47 hist.insert(t_addr);
48 t.check_cast_int(env, hist)
49 }
50 }
51 }
52 }
53
54 pub fn check_cast(&self, env: &Env) -> Result<()> {
55 self.check_cast_int(env, &mut LPooled::take())
56 }
57
58 fn cast_value_int(
59 &self,
60 env: &Env,
61 hist: &mut FxHashSet<(usize, usize)>,
62 v: Value,
63 ) -> Result<Value> {
64 if self.is_a_int(env, hist, &v) {
65 return Ok(v);
66 }
67 match self {
68 Type::Bottom => bail!("can't cast {v} to Bottom"),
69 Type::Fn(_) => bail!("can't cast {v} to a function"),
70 Type::Abstract { id: _, params: _ } => {
71 bail!("can't cast {v} to an abstract type")
72 }
73 Type::ByRef(_) => bail!("can't cast {v} to a reference"),
74 Type::Primitive(s) => s
75 .iter()
76 .find_map(|t| v.clone().cast(t))
77 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
78 Type::Any => Ok(v),
79 Type::Error(e) => {
80 let v = match v {
81 Value::Error(v) => (*v).clone(),
82 v => v,
83 };
84 Ok(Value::Error(Arc::new(e.cast_value_int(env, hist, v)?)))
85 }
86 Type::Array(et) => match v {
87 Value::Array(elts) => {
88 let mut va = elts
89 .iter()
90 .map(|el| et.cast_value_int(env, hist, el.clone()))
91 .collect::<Result<LPooled<Vec<Value>>>>()?;
92 Ok(Value::Array(ValArray::from_iter_exact(va.drain(..))))
93 }
94 v => Ok(Value::Array([et.cast_value_int(env, hist, v)?].into())),
95 },
96 Type::Map { key, value } => match v {
97 Value::Map(m) => {
98 let mut m = m
99 .into_iter()
100 .map(|(k, v)| {
101 Ok((
102 key.cast_value_int(env, hist, k.clone())?,
103 value.cast_value_int(env, hist, v.clone())?,
104 ))
105 })
106 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
107 Ok(Value::Map(Map::from_iter(m.drain(..))))
108 }
109 Value::Array(a) => {
110 let mut m = a
111 .iter()
112 .map(|a| match a {
113 Value::Array(a) if a.len() == 2 => Ok((
114 key.cast_value_int(env, hist, a[0].clone())?,
115 value.cast_value_int(env, hist, a[1].clone())?,
116 )),
117 _ => bail!("expected an array of pairs"),
118 })
119 .collect::<Result<LPooled<Vec<(Value, Value)>>>>()?;
120 Ok(Value::Map(Map::from_iter(m.drain(..))))
121 }
122 _ => bail!("can't cast {v} to {self}"),
123 },
124 Type::Tuple(ts) => match v {
125 Value::Array(elts) => {
126 if elts.len() != ts.len() {
127 bail!("tuple size mismatch {self} with {}", Value::Array(elts))
128 }
129 let mut a = ts
130 .iter()
131 .zip(elts.iter())
132 .map(|(t, el)| t.cast_value_int(env, hist, el.clone()))
133 .collect::<Result<LPooled<Vec<Value>>>>()?;
134 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
135 }
136 v => bail!("can't cast {v} to {self}"),
137 },
138 Type::Struct(ts) => match v {
139 Value::Array(elts) => {
140 if elts.len() != ts.len() {
141 bail!("struct size mismatch {self} with {}", Value::Array(elts))
142 }
143 let is_pairs = elts.iter().all(|v| match v {
144 Value::Array(a) if a.len() == 2 => match &a[0] {
145 Value::String(_) => true,
146 _ => false,
147 },
148 _ => false,
149 });
150 if !is_pairs {
151 bail!("expected array of pairs, got {}", Value::Array(elts))
152 }
153 let mut elts_s: LPooled<Vec<&Value>> = elts.iter().collect();
154 elts_s.sort_by_key(|v| match v {
155 Value::Array(a) => match &a[0] {
156 Value::String(s) => s,
157 _ => unreachable!(),
158 },
159 _ => unreachable!(),
160 });
161 let keys_ok = ts.iter().zip(elts_s.iter()).fold(
162 Ok(true),
163 |acc: Result<_>, ((fname, t), v)| {
164 let kok = acc?;
165 let (name, v) = match v {
166 Value::Array(a) => match (&a[0], &a[1]) {
167 (Value::String(n), v) => (n, v),
168 _ => unreachable!(),
169 },
170 _ => unreachable!(),
171 };
172 Ok(kok
173 && name == fname
174 && t.contains(env, &Type::Primitive(Typ::get(v).into()))?)
175 },
176 )?;
177 if keys_ok {
178 let mut elts = ts
179 .iter()
180 .zip(elts_s.iter())
181 .map(|((n, t), v)| match v {
182 Value::Array(a) => {
183 let a = [
184 Value::String(n.clone()),
185 t.cast_value_int(env, hist, a[1].clone())?,
186 ];
187 Ok(Value::Array(ValArray::from_iter_exact(
188 a.into_iter(),
189 )))
190 }
191 _ => unreachable!(),
192 })
193 .collect::<Result<LPooled<Vec<Value>>>>()?;
194 Ok(Value::Array(ValArray::from_iter_exact(elts.drain(..))))
195 } else {
196 drop(elts_s);
197 bail!("struct fields mismatch {self}, {}", Value::Array(elts))
198 }
199 }
200 v => bail!("can't cast {v} to {self}"),
201 },
202 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
203 Value::String(s) if s == tag => Ok(v),
204 _ => bail!("variant tag mismatch expected {tag} got {v}"),
205 },
206 Type::Variant(tag, ts) => match &v {
207 Value::Array(elts) => {
208 if ts.len() + 1 == elts.len() {
209 match &elts[0] {
210 Value::String(s) if s == tag => (),
211 v => bail!("variant tag mismatch expected {tag} got {v}"),
212 }
213 let mut a = iter::once(&Type::Primitive(Typ::String.into()))
214 .chain(ts.iter())
215 .zip(elts.iter())
216 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
217 .collect::<Result<LPooled<Vec<Value>>>>()?;
218 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
219 } else if ts.len() == elts.len() {
220 let mut a = ts
221 .iter()
222 .zip(elts.iter())
223 .map(|(t, v)| t.cast_value_int(env, hist, v.clone()))
224 .collect::<Result<LPooled<Vec<Value>>>>()?;
225 a.insert(0, Value::String(tag.clone()));
226 Ok(Value::Array(ValArray::from_iter_exact(a.drain(..))))
227 } else {
228 bail!("variant length mismatch")
229 }
230 }
231 v => bail!("can't cast {v} to {self}"),
232 },
233 Type::Ref { .. } => self.lookup_ref(env)?.cast_value_int(env, hist, v),
234 Type::Set(ts) => ts
235 .iter()
236 .find_map(|t| t.cast_value_int(env, hist, v.clone()).ok())
237 .ok_or_else(|| anyhow!("can't cast {v} to {self}")),
238 Type::TVar(tv) => match &*tv.read().typ.read() {
239 Some(t) => t.cast_value_int(env, hist, v.clone()),
240 None => Ok(v),
241 },
242 }
243 }
244
245 pub fn cast_value(&self, env: &Env, v: Value) -> Value {
246 match self.cast_value_int(env, &mut LPooled::take(), v) {
247 Ok(v) => v,
248 Err(e) => errf!(CAST_ERR_TAG, "{e:?}"),
249 }
250 }
251
252 fn is_a_int(
253 &self,
254 env: &Env,
255 hist: &mut FxHashSet<(usize, usize)>,
256 v: &Value,
257 ) -> bool {
258 match self {
259 Type::Ref { .. } => match self.lookup_ref(env) {
260 Err(_) => false,
261 Ok(t) => {
262 let t_addr = (t as *const Type).addr();
263 let v_addr = (v as *const Value).addr();
264 !hist.contains(&(t_addr, v_addr)) && {
265 hist.insert((t_addr, v_addr));
266 t.is_a_int(env, hist, v)
267 }
268 }
269 },
270 Type::Primitive(t) => t.contains(Typ::get(&v)),
271 Type::Abstract { .. } => false,
272 Type::Any => true,
273 Type::Array(et) => match v {
274 Value::Array(a) => a.iter().all(|v| et.is_a_int(env, hist, v)),
275 _ => false,
276 },
277 Type::Map { key, value } => match v {
278 Value::Map(m) => m.into_iter().all(|(k, v)| {
279 key.is_a_int(env, hist, k) && value.is_a_int(env, hist, v)
280 }),
281 _ => false,
282 },
283 Type::Error(e) => match v {
284 Value::Error(v) => e.is_a_int(env, hist, v),
285 _ => false,
286 },
287 Type::ByRef(_) => matches!(v, Value::U64(_) | Value::V64(_)),
288 Type::Tuple(ts) => match v {
289 Value::Array(elts) => {
290 elts.len() == ts.len()
291 && ts
292 .iter()
293 .zip(elts.iter())
294 .all(|(t, v)| t.is_a_int(env, hist, v))
295 }
296 _ => false,
297 },
298 Type::Struct(ts) => match v {
299 Value::Array(elts) => {
300 elts.len() == ts.len()
301 && ts.iter().zip(elts.iter()).all(|((n, t), v)| match v {
302 Value::Array(a) if a.len() == 2 => match &a[..] {
303 [Value::String(key), v] => {
304 n == key && t.is_a_int(env, hist, v)
305 }
306 _ => false,
307 },
308 _ => false,
309 })
310 }
311 _ => false,
312 },
313 Type::Variant(tag, ts) if ts.len() == 0 => match &v {
314 Value::String(s) => s == tag,
315 _ => false,
316 },
317 Type::Variant(tag, ts) => match &v {
318 Value::Array(elts) => {
319 ts.len() + 1 == elts.len()
320 && match &elts[0] {
321 Value::String(s) => s == tag,
322 _ => false,
323 }
324 && ts
325 .iter()
326 .zip(elts[1..].iter())
327 .all(|(t, v)| t.is_a_int(env, hist, v))
328 }
329 _ => false,
330 },
331 Type::TVar(tv) => match &*tv.read().typ.read() {
332 None => true,
333 Some(t) => t.is_a_int(env, hist, v),
334 },
335 Type::Fn(_) => match v {
336 Value::U64(_) => true,
337 _ => false,
338 },
339 Type::Bottom => true,
340 Type::Set(ts) => ts.iter().any(|t| t.is_a_int(env, hist, v)),
341 }
342 }
343
344 pub fn is_a(&self, env: &Env, v: &Value) -> bool {
346 self.is_a_int(env, &mut LPooled::take(), v)
347 }
348}