1use std::{
2 fmt,
3 sync::{Arc, RwLock},
4};
5
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 format_vec,
10 interner::{Symbol, TypeNodeId, with_session_globals},
11 pattern::TypedId,
12 utils::metadata::Location,
13};
14
15#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
18pub enum PType {
19 Unit,
20 Int,
21 Numeric,
22 String,
23}
24
/// Identifier of an intermediate type variable (the `var` field of [`TypeVar`]).
///
/// Two [`Type::Intermediate`] values compare equal exactly when their ids match.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IntermediateId(pub u64);
27
28#[derive(Clone, Debug, PartialEq)]
29pub struct TypeBound {
30 pub lower: TypeNodeId,
31 pub upper: TypeNodeId,
32}
33impl Default for TypeBound {
34 fn default() -> Self {
35 Self {
36 lower: Type::Failure.into_id(),
37 upper: Type::Any.into_id(),
38 }
39 }
40}
41
42#[derive(Clone, Debug, PartialEq)]
43pub struct TypeVar {
44 pub parent: Option<TypeNodeId>,
45 pub var: IntermediateId,
46 pub level: u64,
47 pub bound: TypeBound,
48}
49impl TypeVar {
50 pub fn new(var: IntermediateId, level: u64) -> Self {
51 Self {
52 parent: None,
53 var,
54 level,
55 bound: TypeBound::default(),
56 }
57 }
58}
59
60#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
61pub struct RecordTypeField {
62 pub key: Symbol,
63 pub ty: TypeNodeId,
64 pub has_default: bool,
65}
66impl RecordTypeField {
67 pub fn new(key: Symbol, ty: TypeNodeId, has_default: bool) -> Self {
68 Self {
69 key,
70 ty,
71 has_default,
72 }
73 }
74}
75impl From<TypedId> for RecordTypeField {
76 fn from(value: TypedId) -> Self {
77 Self {
78 key: value.id,
79 ty: value.ty,
80 has_default: value.default_value.is_some(),
81 }
82 }
83}
/// Identifier carried by [`Type::TypeScheme`] (displayed as `g(id)`).
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct TypeSchemeId(pub u64);
86
87#[derive(Clone, Debug)]
88pub enum Type {
89 Primitive(PType),
90 Array(TypeNodeId),
92 Tuple(Vec<TypeNodeId>),
93 Record(Vec<RecordTypeField>),
94 Function {
97 arg: TypeNodeId,
98 ret: TypeNodeId,
99 },
100 Ref(TypeNodeId),
101 Code(TypeNodeId),
103 Union(Vec<TypeNodeId>),
105 UserSum {
108 name: Symbol,
109 variants: Vec<(Symbol, Option<TypeNodeId>)>,
110 },
111 Boxed(TypeNodeId),
115 Intermediate(Arc<RwLock<TypeVar>>),
116 TypeScheme(TypeSchemeId),
117 TypeAlias(Symbol),
119 Any,
121 Failure,
123 Unknown,
124}
125impl PartialEq for Type {
126 fn eq(&self, other: &Self) -> bool {
127 match (self, other) {
128 (Type::Intermediate(a), Type::Intermediate(b)) => {
129 let a = a.read().unwrap();
130 let b = b.read().unwrap();
131 a.var == b.var
132 }
133 (Type::Primitive(a), Type::Primitive(b)) => a == b,
134 (Type::Array(a), Type::Array(b)) => a == b,
135 (Type::Tuple(a), Type::Tuple(b)) => a == b,
136 (Type::Record(a), Type::Record(b)) => a == b,
137 (Type::Function { arg: a1, ret: a2 }, Type::Function { arg: b1, ret: b2 }) => {
138 a1 == b1 && a2 == b2
139 }
140 (Type::Ref(a), Type::Ref(b)) => a == b,
141 (Type::Code(a), Type::Code(b)) => a == b,
142 (Type::Union(a), Type::Union(b)) => a == b,
143 (
144 Type::UserSum {
145 name: n1,
146 variants: v1,
147 },
148 Type::UserSum {
149 name: n2,
150 variants: v2,
151 },
152 ) => n1 == n2 && v1 == v2,
153 (Type::Boxed(a), Type::Boxed(b)) => a == b,
154 (Type::TypeScheme(a), Type::TypeScheme(b)) => a == b,
155 (Type::TypeAlias(a), Type::TypeAlias(b)) => a == b,
156 (Type::Any, Type::Any) => true,
157 (Type::Failure, Type::Failure) => true,
158 (Type::Unknown, Type::Unknown) => true,
159 _ => false,
160 }
161 }
162}
163
/// Size of a type measured in words (see `TypeNodeId::word_size`).
pub type TypeSize = u16;
166
167impl Type {
168 pub fn contains_function(&self) -> bool {
171 match self {
172 Type::Function { arg: _, ret: _ } => true,
173 Type::Tuple(t) => t.iter().any(|t| t.to_type().contains_function()),
174 Type::Record(t) => t
175 .iter()
176 .any(|RecordTypeField { ty, .. }| ty.to_type().contains_function()),
177 Type::Union(t) => t.iter().any(|t| t.to_type().contains_function()),
178 Type::Boxed(t) => t.to_type().contains_function(),
179 Type::TypeAlias(_) => true,
182 _ => false,
183 }
184 }
185 pub fn contains_boxed(&self) -> bool {
188 match self {
189 Type::Boxed(_) => true,
190 Type::Tuple(t) => t.iter().any(|t| t.to_type().contains_boxed()),
191 Type::Record(t) => t
192 .iter()
193 .any(|RecordTypeField { ty, .. }| ty.to_type().contains_boxed()),
194 Type::Union(t) => t.iter().any(|t| t.to_type().contains_boxed()),
195 Type::UserSum { .. } => {
196 true
201 }
202 Type::TypeAlias(_) => true,
205 _ => false,
206 }
207 }
208 pub fn is_function(&self) -> bool {
209 matches!(self, Type::Function { arg: _, ret: _ })
210 }
211 pub fn contains_code(&self) -> bool {
212 match self {
213 Type::Code(_) => true,
214 Type::Array(t) => t.to_type().contains_code(),
215 Type::Tuple(t) => t.iter().any(|t| t.to_type().contains_code()),
216 Type::Record(t) => t
217 .iter()
218 .any(|RecordTypeField { ty, .. }| ty.to_type().contains_code()),
219 Type::Function { arg, ret } => {
220 arg.to_type().contains_code() || ret.to_type().contains_code()
221 }
222 Type::Ref(t) => t.to_type().contains_code(),
223 Type::Union(t) => t.iter().any(|t| t.to_type().contains_code()),
224 Type::Boxed(t) => t.to_type().contains_code(),
225 _ => false,
226 }
227 }
228
229 pub fn contains_type_scheme(&self) -> bool {
230 match self {
231 Type::TypeScheme(_) => true,
232 Type::Array(t) => t.to_type().contains_type_scheme(),
233 Type::Tuple(t) => t.iter().any(|t| t.to_type().contains_type_scheme()),
234 Type::Record(t) => t
235 .iter()
236 .any(|RecordTypeField { ty, .. }| ty.to_type().contains_type_scheme()),
237 Type::Function { arg, ret } => {
238 arg.to_type().contains_type_scheme() || ret.to_type().contains_type_scheme()
239 }
240 Type::Ref(t) => t.to_type().contains_type_scheme(),
241 Type::Code(t) => t.to_type().contains_type_scheme(),
242 Type::Union(t) => t.iter().any(|t| t.to_type().contains_type_scheme()),
243 Type::Boxed(t) => t.to_type().contains_type_scheme(),
244 _ => false,
245 }
246 }
247
248 pub fn contains_unresolved(&self) -> bool {
252 match self {
253 Type::Unknown | Type::TypeScheme(_) => true,
254 Type::Intermediate(cell) => {
255 let tv = cell.read().unwrap();
256 let fallback =
257 (!matches!(tv.bound.lower.to_type(), Type::Failure)).then_some(tv.bound.lower);
258 tv.parent
259 .or(fallback)
260 .is_none_or(|resolved| resolved.to_type().contains_unresolved())
261 }
262 Type::Array(t) | Type::Ref(t) | Type::Code(t) | Type::Boxed(t) => {
263 t.to_type().contains_unresolved()
264 }
265 Type::Tuple(t) => t.iter().any(|t| t.to_type().contains_unresolved()),
266 Type::Record(t) => t
267 .iter()
268 .any(|RecordTypeField { ty, .. }| ty.to_type().contains_unresolved()),
269 Type::Function { arg, ret } => {
270 arg.to_type().contains_unresolved() || ret.to_type().contains_unresolved()
271 }
272 Type::Union(t) => t.iter().any(|t| t.to_type().contains_unresolved()),
273 _ => false,
274 }
275 }
276
277 pub fn is_intermediate(&self) -> Option<Arc<RwLock<TypeVar>>> {
278 match self {
279 Type::Intermediate(tvar) => Some(tvar.clone()),
280 _ => None,
281 }
282 }
283
284 pub fn get_as_tuple(&self) -> Option<Vec<TypeNodeId>> {
285 match self {
286 Type::Tuple(types) => Some(types.to_vec()),
287 Type::Record(fields) => Some(
288 fields
289 .iter()
290 .map(|RecordTypeField { ty, .. }| *ty)
291 .collect::<Vec<_>>(),
292 ),
293 _ => None,
294 }
295 }
296 pub fn can_be_unpacked(&self) -> bool {
297 matches!(self, Type::Tuple(_) | Type::Record(_))
298 }
299
300 fn is_iochannel_scalar(&self) -> bool {
301 match self {
302 Type::Primitive(PType::Numeric) | Type::Unknown => true,
303 Type::Intermediate(cell) => {
304 let parent = cell.read().unwrap().parent;
305 parent.is_none_or(|resolved| resolved.to_type().is_iochannel_scalar())
306 }
307 _ => false,
308 }
309 }
310
311 pub fn get_iochannel_count(&self) -> Option<u32> {
312 match self {
313 Type::Tuple(ts) => {
314 if ts.iter().all(|t| t.to_type().is_iochannel_scalar()) {
315 Some(ts.len() as _)
316 } else {
317 None
318 }
319 }
320 Type::Record(kvs) => {
321 if kvs
322 .iter()
323 .all(|RecordTypeField { ty, .. }| ty.to_type().is_iochannel_scalar())
324 {
325 Some(kvs.len() as _)
326 } else {
327 None
328 }
329 }
330 t if t.is_iochannel_scalar() => Some(1),
331 Type::Primitive(PType::Unit) => Some(0),
332 _ => None,
333 }
334 }
335 pub fn into_id(self) -> TypeNodeId {
336 with_session_globals(|session_globals| session_globals.store_type(self))
337 }
338
339 pub fn into_id_with_location(self, loc: Location) -> TypeNodeId {
340 with_session_globals(|session_globals| session_globals.store_type_with_location(self, loc))
341 }
342
343 pub fn to_string_for_error(&self) -> String {
344 match self {
345 Type::Array(a) => {
346 format!("[{}, ...]", a.to_type().to_string_for_error())
347 }
348 Type::Tuple(v) => {
349 let vf = format_vec!(
350 v.iter()
351 .map(|x| x.to_type().to_string_for_error())
352 .collect::<Vec<_>>(),
353 ","
354 );
355 format!("({vf})")
356 }
357 Type::Record(v) => {
358 let vf = format_vec!(
359 v.iter()
360 .map(|RecordTypeField { key, ty, .. }| format!(
361 "{}: {}",
362 key.as_str(),
363 ty.to_type().to_string_for_error()
364 ))
365 .collect::<Vec<_>>(),
366 ","
367 );
368 format!("({vf})")
369 }
370 Type::Function { arg, ret } => {
371 format!(
372 "({})->{}",
373 arg.to_type().to_string_for_error(),
374 ret.to_type().to_string_for_error()
375 )
376 }
377 Type::Ref(x) => format!("&{}", x.to_type().to_string_for_error()),
378 Type::Boxed(x) => format!("boxed({})", x.to_type().to_string_for_error()),
379 Type::Code(c) => format!("`({})", c.to_type().to_string_for_error()),
380 Type::Intermediate(cell) => {
381 let tv = cell.read().unwrap();
382 match tv.parent {
383 Some(parent) => parent.to_type().to_string_for_error(),
384 None => format!("unresolved type variable ?{}", tv.var.0),
385 }
386 }
387 x => x.to_string(),
389 }
390 }
391
392 pub fn to_mangled_string(&self) -> String {
396 match self {
397 Type::Primitive(p) => match p {
398 PType::Unit => "unit".to_string(),
399 PType::Int => "int".to_string(),
400 PType::Numeric => "num".to_string(),
401 PType::String => "str".to_string(),
402 },
403 Type::Array(a) => {
404 format!("arr_{}", a.to_type().to_mangled_string())
405 }
406 Type::Tuple(v) => {
407 let mangled_types = v
408 .iter()
409 .map(|x| x.to_type().to_mangled_string())
410 .collect::<Vec<_>>()
411 .join("_");
412 format!("tup_{mangled_types}")
413 }
414 Type::Record(v) => {
415 let mangled_fields = v
416 .iter()
417 .map(|RecordTypeField { key, ty, .. }| {
418 format!("{}_{}", key.as_str(), ty.to_type().to_mangled_string())
419 })
420 .collect::<Vec<_>>()
421 .join("_");
422 format!("rec_{mangled_fields}")
423 }
424 Type::Function { arg, ret } => {
425 format!(
426 "fn_{}_{}",
427 arg.to_type().to_mangled_string(),
428 ret.to_type().to_mangled_string()
429 )
430 }
431 Type::Ref(x) => format!("ref_{}", x.to_type().to_mangled_string()),
432 Type::Boxed(x) => format!("boxed_{}", x.to_type().to_mangled_string()),
433 Type::Code(c) => format!("code_{}", c.to_type().to_mangled_string()),
434 Type::Intermediate(tvar) => {
435 let tv = tvar.read().unwrap();
436 tv.parent
437 .map(|p| p.to_type().to_mangled_string())
438 .unwrap_or_else(|| format!("ivar_{}", tv.var.0))
439 }
440 Type::TypeScheme(id) => format!("scheme_{}", id.0),
441 Type::TypeAlias(name) => format!("alias_{}", name.as_str()),
442 Type::Union(v) => {
443 let mangled_types = v
444 .iter()
445 .map(|x| x.to_type().to_mangled_string())
446 .collect::<Vec<_>>()
447 .join("_");
448 format!("union_{}", mangled_types)
449 }
450 Type::UserSum { name, variants } => {
451 let variant_str = variants
452 .iter()
453 .map(|(s, _)| s.as_str())
454 .collect::<Vec<_>>()
455 .join("_");
456 format!("{}_{}", name.as_str(), variant_str)
457 }
458 Type::Any => "any".to_string(),
459 Type::Failure => "fail".to_string(),
460 Type::Unknown => "unknown".to_string(),
461 }
462 }
463}
464
465impl TypeNodeId {
466 pub fn get_root(&self) -> TypeNodeId {
467 match self.to_type() {
468 Type::Intermediate(cell) => {
469 let tv = cell.read().unwrap();
470 tv.parent.map_or(*self, |t| t.get_root())
471 }
472 _ => *self,
473 }
474 }
475
476 pub fn to_mangled_string(&self) -> String {
478 self.to_type().to_mangled_string()
479 }
480
481 pub fn apply_fn<F>(&self, mut closure: F) -> Self
482 where
483 F: FnMut(Self) -> Self,
484 {
485 let apply_scalar = |a: Self, c: &mut F| -> Self { c(a) };
486 let apply_vec = |v: &[Self], c: &mut F| -> Vec<Self> { v.iter().map(|a| c(*a)).collect() };
487 let result = match self.to_type() {
488 Type::Array(a) => Type::Array(apply_scalar(a, &mut closure)),
489 Type::Tuple(v) => Type::Tuple(apply_vec(&v, &mut closure)),
490 Type::Record(s) => Type::Record(
491 s.iter()
492 .map(
493 |RecordTypeField {
494 key,
495 ty,
496 has_default,
497 }| {
498 RecordTypeField::new(
499 *key,
500 apply_scalar(*ty, &mut closure),
501 *has_default,
502 )
503 },
504 )
505 .collect(),
506 ),
507 Type::Function { arg, ret } => Type::Function {
508 arg: apply_scalar(arg, &mut closure),
509 ret: apply_scalar(ret, &mut closure),
510 },
511 Type::Ref(x) => Type::Ref(apply_scalar(x, &mut closure)),
512 Type::Boxed(x) => Type::Boxed(apply_scalar(x, &mut closure)),
513 Type::Code(c) => Type::Code(apply_scalar(c, &mut closure)),
514 Type::Intermediate(id) => Type::Intermediate(id.clone()),
515 _ => self.to_type(),
516 };
517
518 result.into_id()
519 }
520
521 pub fn fold<F, R>(&self, _closure: F) -> R
522 where
523 F: Fn(Self, Self) -> R,
524 {
525 todo!()
526 }
527
528 pub fn word_size(&self) -> TypeSize {
537 match self.to_type() {
538 Type::Primitive(PType::Unit) => 0,
539 Type::Primitive(PType::String) => 1,
540 Type::Primitive(_) => 1,
541 Type::Array(_) => 1, Type::Tuple(types) => types.iter().map(|t| t.word_size()).sum(),
543 Type::Record(types) => types
544 .iter()
545 .map(|RecordTypeField { ty, .. }| ty.word_size())
546 .sum(),
547 Type::Function { .. } => 1,
548 Type::Ref(_) => 1,
549 Type::Boxed(_) => 1, Type::Code(_) => 1,
551 Type::Union(variants) => {
552 let max_variant_size = variants.iter().map(|v| v.word_size()).max().unwrap_or(0);
554 1 + max_variant_size
555 }
556 Type::UserSum { variants, .. } => {
557 let max_variant_size = variants
559 .iter()
560 .filter_map(|(_, payload_ty)| *payload_ty)
561 .map(|t| t.word_size())
562 .max()
563 .unwrap_or(0);
564 1 + max_variant_size
565 }
566 Type::Intermediate(cell) => {
567 let tv = cell.read().unwrap();
568 tv.parent
569 .or_else(|| {
570 (!matches!(tv.bound.lower.to_type(), Type::Failure))
571 .then_some(tv.bound.lower)
572 })
573 .map_or(1, |resolved| resolved.word_size())
574 }
575 _ => 1, }
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, RwLock};

    /// A type variable linked to a concrete `(number, number)` tuple must count
    /// as resolved and report the tuple's two-word size.
    #[test]
    fn resolved_intermediate_is_not_unresolved() {
        let num = || Type::Primitive(PType::Numeric).into_id();
        let pair = Type::Tuple(vec![num(), num()]).into_id();

        let linked = TypeVar {
            parent: Some(pair),
            ..TypeVar::new(IntermediateId(1), 0)
        };
        let intermediate = Type::Intermediate(Arc::new(RwLock::new(linked))).into_id();

        assert!(!intermediate.to_type().contains_unresolved());
        assert_eq!(intermediate.word_size(), 2);
    }
}
600
601impl fmt::Display for PType {
602 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
603 match self {
604 PType::Unit => write!(f, "()"),
605 PType::Int => write!(f, "int"),
606 PType::Numeric => write!(f, "number"),
607 PType::String => write!(f, "string"),
608 }
609 }
610}
611impl fmt::Display for TypeVar {
612 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
613 write!(
614 f,
615 "?{}[{}]{}",
616 self.var.0,
617 self.level,
618 self.parent
619 .map_or_else(|| "".to_string(), |t| format!(":{}", t.to_type()))
620 )
621 }
622}
623impl fmt::Display for RecordTypeField {
624 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
625 let def = if self.has_default { "(default)" } else { "" };
626 write!(f, "{}:{}{def}", self.key, self.ty.to_type())
627 }
628}
629impl fmt::Display for Type {
630 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
631 match self {
632 Type::Primitive(p) => write!(f, "{p}"),
633 Type::Array(a) => write!(f, "[{}]", a.to_type()),
634 Type::Tuple(v) => {
635 let vf = format_vec!(
636 v.iter().map(|x| x.to_type().clone()).collect::<Vec<_>>(),
637 ","
638 );
639 write!(f, "({vf})")
640 }
641 Type::Record(v) => {
642 write!(f, "{{{}}}", format_vec!(v, ", "))
643 }
644 Type::Function { arg, ret } => {
645 write!(f, "({})->{}", arg.to_type(), ret.to_type())
646 }
647 Type::Ref(x) => write!(f, "&{}", x.to_type()),
648 Type::Boxed(x) => write!(f, "boxed({})", x.to_type()),
649
650 Type::Code(c) => write!(f, "<{}>", c.to_type()),
651 Type::Union(v) => {
652 let vf = format_vec!(
653 v.iter().map(|x| x.to_type().clone()).collect::<Vec<_>>(),
654 " | "
655 );
656 write!(f, "{vf}")
657 }
658 Type::UserSum { name, variants } => {
659 let variant_str = variants
660 .iter()
661 .map(|(s, _)| s.as_str())
662 .collect::<Vec<_>>()
663 .join(" | ");
664 write!(f, "{} = {}", name.as_str(), variant_str)
665 }
666 Type::Intermediate(id) => {
667 write!(f, "{}", id.read().unwrap())
668 }
669 Type::TypeScheme(id) => {
670 write!(f, "g({})", id.0)
671 }
672 Type::TypeAlias(name) => write!(f, "{}", name.as_str()),
673 Type::Any => write!(f, "any"),
674 Type::Failure => write!(f, "!"),
675 Type::Unknown => write!(f, "unknown"),
676 }
677 }
678}
679
680pub mod builder;
681mod serde_impl;
682
683