1use crate::{
2 error::{AdtConflict, CollectAdtsError, TypeError},
3 typesystem::TypeVarSupply,
4 unification::{Subst, subst_is_empty},
5};
6use chrono::{DateTime, Utc};
7use rex_ast::{Pattern, Symbol};
8use rpds::HashTrieMapSync;
9use std::{
10 cmp::Ordering,
11 collections::{BTreeMap, BTreeSet},
12 fmt::{self, Display, Formatter},
13 mem,
14 sync::Arc,
15};
16use uuid::Uuid;
17
/// Numeric identifier of a type variable (the `id` field of [`TypeVar`]).
pub type TypeVarId = usize;
19
/// Identifies one of the types built into the language.
///
/// The first group (`U8` .. `DateTime`) are nullary scalar types; the second
/// group (`List` .. `Result`) are type constructors taking arguments (see
/// `BuiltinTypeId::arity`). Variant order matters: it defines the derived
/// `Ord`.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum BuiltinTypeId {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
    Bool,
    String,
    Uuid,
    DateTime,
    List,
    Array,
    Dict,
    Option,
    Promise,
    Result,
}
43
44impl BuiltinTypeId {
45 pub fn as_symbol(self) -> Symbol {
46 Symbol::intern(self.as_str())
47 }
48
49 pub fn as_str(self) -> &'static str {
50 match self {
51 Self::U8 => "u8",
52 Self::U16 => "u16",
53 Self::U32 => "u32",
54 Self::U64 => "u64",
55 Self::I8 => "i8",
56 Self::I16 => "i16",
57 Self::I32 => "i32",
58 Self::I64 => "i64",
59 Self::F32 => "f32",
60 Self::F64 => "f64",
61 Self::Bool => "bool",
62 Self::String => "string",
63 Self::Uuid => "uuid",
64 Self::DateTime => "datetime",
65 Self::List => "List",
66 Self::Array => "Array",
67 Self::Dict => "Dict",
68 Self::Option => "Option",
69 Self::Promise => "Promise",
70 Self::Result => "Result",
71 }
72 }
73
74 pub fn arity(self) -> usize {
75 match self {
76 Self::List | Self::Array | Self::Dict | Self::Option | Self::Promise => 1,
77 Self::Result => 2,
78 _ => 0,
79 }
80 }
81
82 pub fn from_symbol(name: &Symbol) -> Option<Self> {
83 Self::from_name(name.as_ref())
84 }
85
86 pub fn from_name(name: &str) -> Option<Self> {
87 match name {
88 "u8" => Some(Self::U8),
89 "u16" => Some(Self::U16),
90 "u32" => Some(Self::U32),
91 "u64" => Some(Self::U64),
92 "i8" => Some(Self::I8),
93 "i16" => Some(Self::I16),
94 "i32" => Some(Self::I32),
95 "i64" => Some(Self::I64),
96 "f32" => Some(Self::F32),
97 "f64" => Some(Self::F64),
98 "bool" => Some(Self::Bool),
99 "string" => Some(Self::String),
100 "uuid" => Some(Self::Uuid),
101 "datetime" => Some(Self::DateTime),
102 "List" => Some(Self::List),
103 "Array" => Some(Self::Array),
104 "Dict" => Some(Self::Dict),
105 "Option" => Some(Self::Option),
106 "Promise" => Some(Self::Promise),
107 "Result" => Some(Self::Result),
108 _ => None,
109 }
110 }
111}
112
113#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
114pub struct TypeVar {
115 pub id: TypeVarId,
116 pub name: Option<Symbol>,
117}
118
119impl TypeVar {
120 pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
121 Self {
122 id,
123 name: name.into(),
124 }
125 }
126}
127
/// A type constant: either a builtin or a user-declared constructor.
///
/// `Ord` is implemented by hand below (by name, then arity, then builtin id).
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub enum TypeConst {
    /// One of the language's builtin types.
    Builtin(BuiltinTypeId),
    /// A user-declared type constructor with its number of type arguments.
    User { name: Symbol, arity: usize },
}
133
134impl TypeConst {
135 pub fn builtin_id(&self) -> Option<BuiltinTypeId> {
136 match self {
137 Self::Builtin(id) => Some(*id),
138 Self::User { .. } => None,
139 }
140 }
141
142 pub fn is_builtin(&self, id: BuiltinTypeId) -> bool {
143 self.builtin_id() == Some(id)
144 }
145
146 pub fn name(&self) -> Symbol {
147 match self {
148 Self::Builtin(id) => id.as_symbol(),
149 Self::User { name, .. } => name.clone(),
150 }
151 }
152
153 pub fn name_str(&self) -> &str {
154 match self {
155 Self::Builtin(id) => id.as_str(),
156 Self::User { name, .. } => name.as_ref(),
157 }
158 }
159
160 pub fn user_name(&self) -> Option<&Symbol> {
161 match self {
162 Self::Builtin(_) => None,
163 Self::User { name, .. } => Some(name),
164 }
165 }
166
167 pub fn arity(&self) -> usize {
168 match self {
169 Self::Builtin(id) => id.arity(),
170 Self::User { arity, .. } => *arity,
171 }
172 }
173}
174
175impl Ord for TypeConst {
176 fn cmp(&self, other: &Self) -> Ordering {
177 self.name_str()
178 .cmp(other.name_str())
179 .then_with(|| self.arity().cmp(&other.arity()))
180 .then_with(|| self.builtin_id().cmp(&other.builtin_id()))
181 }
182}
183
184impl PartialOrd for TypeConst {
185 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
186 Some(self.cmp(other))
187 }
188}
189
/// A type, as a cheaply clonable shared node (`Arc<TypeKind>`).
///
/// Equality, hashing and ordering are structural (delegated through the `Arc`).
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct Type(Arc<TypeKind>);

/// The shape of a [`Type`] node.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum TypeKind {
    /// A type variable.
    Var(TypeVar),
    /// A type constant (builtin or user-declared constructor).
    Con(TypeConst),
    /// Type application `f arg`; multi-argument constructors are curried.
    App(Type, Type),
    /// Function type `arg -> result`.
    Fun(Type, Type),
    /// Tuple type; the empty tuple is used as the unit type.
    Tuple(Vec<Type>),
    /// Record type as `(field name, field type)` pairs. `Type::record` sorts
    /// them by name; this variant itself does not enforce it.
    Record(Vec<(Symbol, Type)>),
}
206
207impl Type {
208 pub fn new(kind: TypeKind) -> Self {
209 Type(Arc::new(kind))
210 }
211
212 pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
213 if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
214 && id.arity() == arity
215 {
216 return Self::builtin(id);
217 }
218 Self::user_con(name, arity)
219 }
220
221 pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
222 Type::new(TypeKind::Con(TypeConst::User {
223 name: Symbol::intern(name.as_ref()),
224 arity,
225 }))
226 }
227
228 pub fn builtin(id: BuiltinTypeId) -> Self {
229 Type::new(TypeKind::Con(TypeConst::Builtin(id)))
230 }
231
232 pub fn var(tv: TypeVar) -> Self {
233 Type::new(TypeKind::Var(tv))
234 }
235
236 pub fn fun(a: Type, b: Type) -> Self {
237 Type::new(TypeKind::Fun(a, b))
238 }
239
240 pub fn app(f: Type, arg: Type) -> Self {
241 Type::new(TypeKind::App(f, arg))
242 }
243
244 pub fn tuple(elems: Vec<Type>) -> Self {
245 Type::new(TypeKind::Tuple(elems))
246 }
247
248 pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
249 fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
252 Type::new(TypeKind::Record(fields))
253 }
254
255 pub fn list(elem: Type) -> Type {
256 Type::app(Type::builtin(BuiltinTypeId::List), elem)
257 }
258
259 pub fn array(elem: Type) -> Type {
260 Type::app(Type::builtin(BuiltinTypeId::Array), elem)
261 }
262
263 pub fn dict(elem: Type) -> Type {
264 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
265 }
266
267 pub fn option(elem: Type) -> Type {
268 Type::app(Type::builtin(BuiltinTypeId::Option), elem)
269 }
270
271 pub fn promise(elem: Type) -> Type {
272 Type::app(Type::builtin(BuiltinTypeId::Promise), elem)
273 }
274
275 pub fn result(ok: Type, err: Type) -> Type {
276 Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
277 }
278
279 fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
280 match self.as_ref() {
281 TypeKind::Var(tv) => match s.get(&tv.id) {
282 Some(ty) => (ty.clone(), true),
283 None => (self.clone(), false),
284 },
285 TypeKind::Con(_) => (self.clone(), false),
286 TypeKind::App(l, r) => {
287 let (l_new, l_changed) = l.apply_with_change(s);
288 let (r_new, r_changed) = r.apply_with_change(s);
289 if l_changed || r_changed {
290 (Type::app(l_new, r_new), true)
291 } else {
292 (self.clone(), false)
293 }
294 }
295 TypeKind::Fun(_, _) => {
296 let mut args = Vec::new();
299 let mut changed = false;
300 let mut cur: &Type = self;
301 while let TypeKind::Fun(a, b) = cur.as_ref() {
302 let (a_new, a_changed) = a.apply_with_change(s);
303 changed |= a_changed;
304 args.push(a_new);
305 cur = b;
306 }
307 let (ret_new, ret_changed) = cur.apply_with_change(s);
308 changed |= ret_changed;
309 if !changed {
310 return (self.clone(), false);
311 }
312 let mut out = ret_new;
313 for a_new in args.into_iter().rev() {
314 out = Type::fun(a_new, out);
315 }
316 (out, true)
317 }
318 TypeKind::Tuple(ts) => {
319 let mut changed = false;
320 let mut out = Vec::with_capacity(ts.len());
321 for t in ts {
322 let (t_new, t_changed) = t.apply_with_change(s);
323 changed |= t_changed;
324 out.push(t_new);
325 }
326 if changed {
327 (Type::new(TypeKind::Tuple(out)), true)
328 } else {
329 (self.clone(), false)
330 }
331 }
332 TypeKind::Record(fields) => {
333 let mut changed = false;
334 let mut out = Vec::with_capacity(fields.len());
335 for (k, v) in fields {
336 let (v_new, v_changed) = v.apply_with_change(s);
337 changed |= v_changed;
338 out.push((k.clone(), v_new));
339 }
340 if changed {
341 (Type::new(TypeKind::Record(out)), true)
342 } else {
343 (self.clone(), false)
344 }
345 }
346 }
347 }
348
349 pub fn for_each<F>(&self, mut f: F) -> Type
350 where
351 F: FnMut(&Type),
352 {
353 self.transform(|t| {
354 f(t);
355 None
356 })
357 }
358
359 pub fn transform<F>(&self, mut f: F) -> Type
360 where
361 F: FnMut(&Type) -> Option<Type>,
362 {
363 self.transform_ref(&mut f)
364 }
365
366 fn transform_ref<F>(&self, f: &mut F) -> Type
367 where
368 F: FnMut(&Type) -> Option<Type>,
369 {
370 if let Some(repl) = f(self) {
371 return repl;
372 }
373
374 match self.as_ref() {
375 TypeKind::Var(type_var) => Type(Arc::new(TypeKind::Var(type_var.clone()))),
376 TypeKind::Con(type_const) => Type(Arc::new(TypeKind::Con(type_const.clone()))),
377 TypeKind::App(fun, arg) => Type(Arc::new(TypeKind::App(
378 fun.transform_ref(f),
379 arg.transform_ref(f),
380 ))),
381 TypeKind::Fun(arg, res) => Type(Arc::new(TypeKind::Fun(
382 arg.transform_ref(f),
383 res.transform_ref(f),
384 ))),
385 TypeKind::Tuple(ts) => Type(Arc::new(TypeKind::Tuple(
386 ts.iter().map(|t| t.transform_ref(f)).collect(),
387 ))),
388 TypeKind::Record(fields) => Type(Arc::new(TypeKind::Record(
389 fields
390 .iter()
391 .map(|(s, t)| (s.clone(), t.transform_ref(f)))
392 .collect(),
393 ))),
394 }
395 }
396}
397
398impl AsRef<TypeKind> for Type {
399 fn as_ref(&self) -> &TypeKind {
400 self.0.as_ref()
401 }
402}
403
404impl std::ops::Deref for Type {
405 type Target = TypeKind;
406
407 fn deref(&self) -> &Self::Target {
408 &self.0
409 }
410}
411
412impl Display for Type {
413 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
414 match self.as_ref() {
415 TypeKind::Var(tv) => match &tv.name {
416 Some(name) => write!(f, "'{}", name),
417 None => write!(f, "t{}", tv.id),
418 },
419 TypeKind::Con(c) => write!(f, "{}", c.name_str()),
420 TypeKind::App(l, r) => {
421 if let TypeKind::App(head, err) = l.as_ref()
427 && matches!(
428 head.as_ref(),
429 TypeKind::Con(c)
430 if c.is_builtin(BuiltinTypeId::Result) && c.arity() == 2
431 )
432 {
433 return write!(f, "(Result {} {})", r, err);
434 }
435 write!(f, "({} {})", l, r)
436 }
437 TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
438 TypeKind::Tuple(elems) => {
439 write!(f, "(")?;
440 for (i, t) in elems.iter().enumerate() {
441 write!(f, "{}", t)?;
442 if i + 1 < elems.len() {
443 write!(f, ", ")?;
444 }
445 }
446 write!(f, ")")
447 }
448 TypeKind::Record(fields) => {
449 write!(f, "{{")?;
450 for (i, (name, ty)) in fields.iter().enumerate() {
451 write!(f, "{}: {}", name, ty)?;
452 if i + 1 < fields.len() {
453 write!(f, ", ")?;
454 }
455 }
456 write!(f, "}}")
457 }
458 }
459 }
460}
461
462#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
463pub struct Predicate {
464 pub class: Symbol,
465 pub typ: Type,
466}
467
468impl Predicate {
469 pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
470 Self {
471 class: Symbol::intern(class.as_ref()),
472 typ,
473 }
474 }
475}
476
477#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
478pub struct Scheme {
479 pub vars: Vec<TypeVar>,
480 pub preds: Vec<Predicate>,
481 pub typ: Type,
482}
483
484impl Scheme {
485 pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
486 Self { vars, preds, typ }
487 }
488}
489
/// Things that contain types: they can have a substitution applied and
/// report their free type variables.
pub trait Types: Sized {
    /// Returns a copy of `self` with substitution `s` applied.
    fn apply(&self, s: &Subst) -> Self;
    /// Returns the ids of the type variables occurring free in `self`.
    fn ftv(&self) -> BTreeSet<TypeVarId>;
}
494
495impl Types for Type {
496 fn apply(&self, s: &Subst) -> Self {
497 self.apply_with_change(s).0
498 }
499
500 fn ftv(&self) -> BTreeSet<TypeVarId> {
501 let mut out = BTreeSet::new();
502 let mut stack: Vec<&Type> = vec![self];
503 while let Some(t) = stack.pop() {
504 match t.as_ref() {
505 TypeKind::Var(tv) => {
506 out.insert(tv.id);
507 }
508 TypeKind::Con(_) => {}
509 TypeKind::App(l, r) => {
510 stack.push(l);
511 stack.push(r);
512 }
513 TypeKind::Fun(a, b) => {
514 stack.push(a);
515 stack.push(b);
516 }
517 TypeKind::Tuple(ts) => {
518 for t in ts {
519 stack.push(t);
520 }
521 }
522 TypeKind::Record(fields) => {
523 for (_, ty) in fields {
524 stack.push(ty);
525 }
526 }
527 }
528 }
529 out
530 }
531}
532
533impl Types for Predicate {
534 fn apply(&self, s: &Subst) -> Self {
535 Predicate {
536 class: self.class.clone(),
537 typ: self.typ.apply(s),
538 }
539 }
540
541 fn ftv(&self) -> BTreeSet<TypeVarId> {
542 self.typ.ftv()
543 }
544}
545
546impl Types for Scheme {
547 fn apply(&self, s: &Subst) -> Self {
548 let mut s_pruned = Subst::new_sync();
549 for (k, v) in s.iter() {
550 if !self.vars.iter().any(|var| var.id == *k) {
551 s_pruned = s_pruned.insert(*k, v.clone());
552 }
553 }
554 Scheme::new(
555 self.vars.clone(),
556 self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
557 self.typ.apply(&s_pruned),
558 )
559 }
560
561 fn ftv(&self) -> BTreeSet<TypeVarId> {
562 let mut ftv = self.typ.ftv();
563 for p in &self.preds {
564 ftv.extend(p.ftv());
565 }
566 for v in &self.vars {
567 ftv.remove(&v.id);
568 }
569 ftv
570 }
571}
572
573impl<T: Types> Types for Vec<T> {
574 fn apply(&self, s: &Subst) -> Self {
575 self.iter().map(|t| t.apply(s)).collect()
576 }
577
578 fn ftv(&self) -> BTreeSet<TypeVarId> {
579 self.iter().flat_map(Types::ftv).collect()
580 }
581}
582
/// An expression node annotated with its type.
///
/// `kind` is reference-counted, so cloning a node is cheap and subtrees may be
/// shared between trees.
#[derive(Clone, Debug, PartialEq)]
pub struct TypedExpr {
    /// Type of this node.
    pub typ: Type,
    /// Shape of this node, including its children.
    pub kind: Arc<TypedExprKind>,
}
588
/// One level of a right-nested application chain, as produced by
/// `collect_typed_tail_app_chain`: an application `head a1 .. ak last` with
/// its final argument `last` removed.
struct TypedTailAppFrame {
    /// Expression at the root of the application spine.
    head: Arc<TypedExpr>,
    /// Arguments before the final one, each paired with the type of the `App`
    /// node that applies it.
    prefix_args: Vec<(Type, Arc<TypedExpr>)>,
    /// Type of the whole application (the node that applies the final argument).
    tail_result_type: Type,
}
594
595fn collect_typed_app_chain(expr: &TypedExpr) -> (Arc<TypedExpr>, Vec<(Type, Arc<TypedExpr>)>) {
596 let mut args = Vec::new();
597 let mut cur = expr;
598 while let TypedExprKind::App(f, x) = cur.kind.as_ref() {
599 args.push((cur.typ.clone(), Arc::clone(x)));
600 cur = f.as_ref();
601 }
602 args.reverse();
603 (Arc::new(cur.clone()), args)
604}
605
/// Decomposes `expr` into a chain of application frames linked through the
/// final-argument position.
///
/// Starting at `expr`, each application spine `head a1 .. ak last` whose final
/// argument `last` is itself an application becomes one [`TypedTailAppFrame`],
/// and the walk continues into `last`. The walk stops at the first spine whose
/// final argument is not an application; that spine is returned as the leaf.
///
/// Returns `None` when no frame was produced, i.e. when `expr` is not an
/// application or its final argument is not one.
fn collect_typed_tail_app_chain(
    expr: &TypedExpr,
) -> Option<(Arc<TypedExpr>, Vec<TypedTailAppFrame>)> {
    let mut frames = Vec::new();
    let mut cur = Arc::new(expr.clone());
    while matches!(cur.kind.as_ref(), TypedExprKind::App(..)) {
        let (head, mut args) = collect_typed_app_chain(cur.as_ref());
        // The last collected entry is the outermost node: its argument is the
        // final argument and its type is the type of the whole spine.
        let Some((tail_result_type, tail)) = args.pop() else {
            break;
        };
        if !matches!(tail.kind.as_ref(), TypedExprKind::App(..)) {
            // `cur` becomes the leaf.
            break;
        }
        frames.push(TypedTailAppFrame {
            head,
            prefix_args: args,
            tail_result_type,
        });
        cur = tail;
    }
    (!frames.is_empty()).then_some((cur, frames))
}
628
629fn typed_drop_placeholder() -> Arc<TypedExpr> {
630 Arc::new(TypedExpr::new(Type::tuple(vec![]), TypedExprKind::Hole))
631}
632
633fn drain_typed_expr_kind(kind: &mut TypedExprKind, stack: &mut Vec<Arc<TypedExpr>>) {
634 match kind {
635 TypedExprKind::Tuple(elems) | TypedExprKind::List(elems) => {
636 stack.extend(mem::take(elems));
637 }
638 TypedExprKind::Dict(kvs) => {
639 stack.extend(mem::take(kvs).into_values());
640 }
641 TypedExprKind::RecordUpdate { base, updates } => {
642 stack.push(mem::replace(base, typed_drop_placeholder()));
643 stack.extend(mem::take(updates).into_values());
644 }
645 TypedExprKind::App(f, x) => {
646 stack.push(mem::replace(f, typed_drop_placeholder()));
647 stack.push(mem::replace(x, typed_drop_placeholder()));
648 }
649 TypedExprKind::Project { expr, .. } => {
650 stack.push(mem::replace(expr, typed_drop_placeholder()));
651 }
652 TypedExprKind::Lam { body, .. } => {
653 stack.push(mem::replace(body, typed_drop_placeholder()));
654 }
655 TypedExprKind::Let { def, body, .. } => {
656 stack.push(mem::replace(def, typed_drop_placeholder()));
657 stack.push(mem::replace(body, typed_drop_placeholder()));
658 }
659 TypedExprKind::LetRec { bindings, body } => {
660 for (_name, def) in mem::take(bindings) {
661 stack.push(def);
662 }
663 stack.push(mem::replace(body, typed_drop_placeholder()));
664 }
665 TypedExprKind::Ite {
666 cond,
667 then_expr,
668 else_expr,
669 } => {
670 stack.push(mem::replace(cond, typed_drop_placeholder()));
671 stack.push(mem::replace(then_expr, typed_drop_placeholder()));
672 stack.push(mem::replace(else_expr, typed_drop_placeholder()));
673 }
674 TypedExprKind::Match { scrutinee, arms } => {
675 stack.push(mem::replace(scrutinee, typed_drop_placeholder()));
676 for (_pat, arm) in mem::take(arms) {
677 stack.push(arm);
678 }
679 }
680 TypedExprKind::Bool(..)
681 | TypedExprKind::Uint(..)
682 | TypedExprKind::Int(..)
683 | TypedExprKind::Float(..)
684 | TypedExprKind::String(..)
685 | TypedExprKind::Uuid(..)
686 | TypedExprKind::DateTime(..)
687 | TypedExprKind::Hole
688 | TypedExprKind::Var { .. } => {}
689 }
690}
691
/// Drops a typed expression tree without recursing once per nesting level.
///
/// Children of exclusively owned nodes are detached onto an explicit work
/// stack (their slots get placeholder nodes) and processed in a loop. A node or
/// kind that is still shared (`Arc::get_mut` fails) is not drained: releasing
/// this handle only decrements its reference count, and the other owners keep
/// it alive.
impl Drop for TypedExpr {
    fn drop(&mut self) {
        // Shared kind: another owner still holds this subtree.
        let Some(kind) = Arc::get_mut(&mut self.kind) else {
            return;
        };
        let mut stack = Vec::new();
        drain_typed_expr_kind(kind, &mut stack);
        while let Some(mut expr) = stack.pop() {
            // Only drain nodes we exclusively own; otherwise the popped `Arc`
            // is just released at the end of this iteration.
            let Some(expr) = Arc::get_mut(&mut expr) else {
                continue;
            };
            let Some(kind) = Arc::get_mut(&mut expr.kind) else {
                continue;
            };
            drain_typed_expr_kind(kind, &mut stack);
            // The popped node is released here. Its own `Drop` finds only
            // placeholders or emptied collections, so it does not recurse deeply.
        }
    }
}
710
impl TypedExpr {
    /// Creates a node of type `typ` with shape `kind`.
    pub fn new(typ: Type, kind: TypedExprKind) -> Self {
        Self {
            typ,
            kind: Arc::new(kind),
        }
    }

    /// Returns a copy of this expression with substitution `s` applied to
    /// every node type and to the `overloads` types of `Var` nodes. Patterns
    /// in `Match` arms are cloned unchanged.
    ///
    /// Chains of nested lambdas and application spines are handled with
    /// explicit loops, so they do not cost one recursive call per link.
    pub fn apply(&self, s: &Subst) -> Self {
        match self.kind.as_ref() {
            // `\a -> \b -> .. body`: record each (param, substituted node type),
            // substitute the innermost non-lambda body once, then rebuild the
            // lambdas from the inside out.
            TypedExprKind::Lam { .. } => {
                let mut params: Vec<(Symbol, Type)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::Lam { param, body } = cur.kind.as_ref() {
                    params.push((param.clone(), cur.typ.apply(s)));
                    cur = body.as_ref();
                }
                let mut out = cur.apply(s);
                for (param, typ) in params.into_iter().rev() {
                    out = TypedExpr::new(
                        typ,
                        TypedExprKind::Lam {
                            param,
                            body: Arc::new(out),
                        },
                    );
                }
                return out;
            }
            TypedExprKind::App(..) => {
                // Right-nested applications (final argument is itself an
                // application) are split into frames. The innermost leaf is
                // substituted first, then each frame is rebuilt around the
                // result, innermost frame first.
                if let Some((leaf, frames)) = collect_typed_tail_app_chain(self) {
                    let mut out = leaf.apply(s);
                    for frame in frames.into_iter().rev() {
                        let mut typed = frame.head.apply(s);
                        for (typ, arg) in frame.prefix_args {
                            typed = TypedExpr::new(
                                typ.apply(s),
                                TypedExprKind::App(Arc::new(typed), Arc::new(arg.apply(s))),
                            );
                        }
                        out = TypedExpr::new(
                            frame.tail_result_type.apply(s),
                            TypedExprKind::App(Arc::new(typed), Arc::new(out)),
                        );
                    }
                    return out;
                }

                // Otherwise walk the left spine `((head a1) a2) ..` with a loop,
                // recording each node's substituted type and argument, then
                // rebuild the spine around the substituted head.
                let mut apps: Vec<(Type, Arc<TypedExpr>)> = Vec::new();
                let mut cur = self;
                while let TypedExprKind::App(f, x) = cur.kind.as_ref() {
                    apps.push((cur.typ.apply(s), Arc::clone(x)));
                    cur = f.as_ref();
                }
                let mut out = cur.apply(s);
                for (typ, arg) in apps.into_iter().rev() {
                    out = TypedExpr::new(
                        typ,
                        TypedExprKind::App(Arc::new(out), Arc::new(arg.apply(s))),
                    );
                }
                return out;
            }
            _ => {}
        }

        // All other kinds: substitute this node's type and recurse into the
        // children. The `App` and `Lam` arms below are not reached (handled
        // above) but keep the match exhaustive.
        let typ = self.typ.apply(s);
        let kind = match self.kind.as_ref() {
            TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
            TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
            TypedExprKind::Int(v) => TypedExprKind::Int(*v),
            TypedExprKind::Float(v) => TypedExprKind::Float(*v),
            TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
            TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
            TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
            TypedExprKind::Hole => TypedExprKind::Hole,
            TypedExprKind::Tuple(elems) => {
                TypedExprKind::Tuple(elems.iter().map(|e| Arc::new(e.apply(s))).collect())
            }
            TypedExprKind::List(elems) => {
                TypedExprKind::List(elems.iter().map(|e| Arc::new(e.apply(s))).collect())
            }
            TypedExprKind::Dict(kvs) => {
                let mut out = BTreeMap::new();
                for (k, v) in kvs {
                    out.insert(k.clone(), Arc::new(v.apply(s)));
                }
                TypedExprKind::Dict(out)
            }
            TypedExprKind::RecordUpdate { base, updates } => {
                let mut out = BTreeMap::new();
                for (k, v) in updates {
                    out.insert(k.clone(), Arc::new(v.apply(s)));
                }
                TypedExprKind::RecordUpdate {
                    base: Arc::new(base.apply(s)),
                    updates: out,
                }
            }
            TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
                name: name.clone(),
                overloads: overloads.iter().map(|t| t.apply(s)).collect(),
            },
            TypedExprKind::App(f, x) => {
                TypedExprKind::App(Arc::new(f.apply(s)), Arc::new(x.apply(s)))
            }
            TypedExprKind::Project { expr, field } => TypedExprKind::Project {
                expr: Arc::new(expr.apply(s)),
                field: field.clone(),
            },
            TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
                param: param.clone(),
                body: Arc::new(body.apply(s)),
            },
            TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
                name: name.clone(),
                def: Arc::new(def.apply(s)),
                body: Arc::new(body.apply(s)),
            },
            TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
                bindings: bindings
                    .iter()
                    .map(|(name, def)| (name.clone(), Arc::new(def.apply(s))))
                    .collect(),
                body: Arc::new(body.apply(s)),
            },
            TypedExprKind::Ite {
                cond,
                then_expr,
                else_expr,
            } => TypedExprKind::Ite {
                cond: Arc::new(cond.apply(s)),
                then_expr: Arc::new(then_expr.apply(s)),
                else_expr: Arc::new(else_expr.apply(s)),
            },
            TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
                scrutinee: Arc::new(scrutinee.apply(s)),
                arms: arms
                    .iter()
                    .map(|(p, e)| (p.clone(), Arc::new(e.apply(s))))
                    .collect(),
            },
        };
        TypedExpr::new(typ, kind)
    }
}
860
/// The shape of a [`TypedExpr`] node.
#[derive(Clone, Debug, PartialEq)]
pub enum TypedExprKind {
    /// Boolean literal.
    Bool(bool),
    /// Unsigned integer literal.
    Uint(u64),
    /// Signed integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// String literal.
    String(String),
    /// UUID literal.
    Uuid(Uuid),
    /// UTC timestamp literal.
    DateTime(DateTime<Utc>),
    /// A hole. It is also used in this file as the placeholder node during drops.
    Hole,
    /// Tuple expression.
    Tuple(Vec<Arc<TypedExpr>>),
    /// List expression.
    List(Vec<Arc<TypedExpr>>),
    /// Dictionary expression keyed by symbol.
    Dict(BTreeMap<Symbol, Arc<TypedExpr>>),
    /// `base` with the fields in `updates` replaced.
    RecordUpdate {
        base: Arc<TypedExpr>,
        updates: BTreeMap<Symbol, Arc<TypedExpr>>,
    },
    /// Variable reference; `overloads` holds types associated with the reference
    /// (presumably the overload candidates — confirm with the type checker).
    Var {
        name: Symbol,
        overloads: Vec<Type>,
    },
    /// Application of a function to one argument.
    App(Arc<TypedExpr>, Arc<TypedExpr>),
    /// Field projection `expr.field`.
    Project {
        expr: Arc<TypedExpr>,
        field: Symbol,
    },
    /// Single-parameter lambda.
    Lam {
        param: Symbol,
        body: Arc<TypedExpr>,
    },
    /// Non-recursive `let name = def in body`.
    Let {
        name: Symbol,
        def: Arc<TypedExpr>,
        body: Arc<TypedExpr>,
    },
    /// Recursive bindings scoped over `body`.
    LetRec {
        bindings: Vec<(Symbol, Arc<TypedExpr>)>,
        body: Arc<TypedExpr>,
    },
    /// `if cond then then_expr else else_expr`.
    Ite {
        cond: Arc<TypedExpr>,
        then_expr: Arc<TypedExpr>,
        else_expr: Arc<TypedExpr>,
    },
    /// Pattern match of `scrutinee` against `arms`.
    Match {
        scrutinee: Arc<TypedExpr>,
        arms: Vec<(Pattern, Arc<TypedExpr>)>,
    },
}
910
911#[derive(Default, Debug, Clone)]
912pub struct TypeEnv {
913 pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
914}
915
916impl TypeEnv {
917 pub fn new() -> Self {
918 Self {
919 values: HashTrieMapSync::new_sync(),
920 }
921 }
922
923 pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
924 self.values = self.values.insert(name, vec![scheme]);
925 }
926
927 pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
928 let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
929 schemes.push(scheme);
930 self.values = self.values.insert(name, schemes);
931 }
932
933 pub fn remove(&mut self, name: &Symbol) {
934 self.values = self.values.remove(name);
935 }
936
937 pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
938 self.values.get(name).map(|schemes| schemes.as_slice())
939 }
940}
941
942impl Types for TypeEnv {
943 fn apply(&self, s: &Subst) -> Self {
944 let mut values = HashTrieMapSync::new_sync();
945 for (k, v) in self.values.iter() {
946 let updated = v
947 .iter()
948 .map(|scheme| {
949 if scheme.vars.is_empty() && !subst_is_empty(s) {
952 scheme.apply(s)
953 } else {
954 scheme.clone()
955 }
956 })
957 .collect();
958 values = values.insert(k.clone(), updated);
959 }
960 TypeEnv { values }
961 }
962
963 fn ftv(&self) -> BTreeSet<TypeVarId> {
964 self.values
965 .iter()
966 .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
967 .collect()
968 }
969}
970
/// A type parameter of an ADT declaration.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AdtParam {
    /// Parameter name as written in the declaration.
    pub name: Symbol,
    /// Type variable that stands for this parameter in the declaration's types.
    pub var: TypeVar,
}

/// One constructor of an ADT.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AdtVariant {
    /// Constructor name.
    pub name: Symbol,
    /// Types of the constructor's arguments, in order.
    pub args: Vec<Type>,
}

/// An algebraic data type declaration: name, type parameters and constructors.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AdtDecl {
    /// Name of the declared type.
    pub name: Symbol,
    /// Type parameters, in declaration order.
    pub params: Vec<AdtParam>,
    /// Constructors, in the order they were added.
    pub variants: Vec<AdtVariant>,
}
996
997impl AdtDecl {
998 pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
999 let params = param_names
1000 .iter()
1001 .map(|p| AdtParam {
1002 name: p.clone(),
1003 var: supply.fresh(Some(p.clone())),
1004 })
1005 .collect();
1006 Self {
1007 name: name.clone(),
1008 params,
1009 variants: Vec::new(),
1010 }
1011 }
1012
1013 pub fn param_type(&self, name: &Symbol) -> Option<Type> {
1014 self.params
1015 .iter()
1016 .find(|p| &p.name == name)
1017 .map(|p| Type::var(p.var.clone()))
1018 }
1019
1020 pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
1021 self.variants.push(AdtVariant { name, args });
1022 }
1023
1024 pub fn result_type(&self) -> Type {
1025 let mut ty = Type::con(&self.name, self.params.len());
1026 for param in &self.params {
1027 ty = Type::app(ty, Type::var(param.var.clone()));
1028 }
1029 ty
1030 }
1031
1032 pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
1035 let result_ty = self.result_type();
1036 let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
1037 let mut out = Vec::new();
1038 for variant in &self.variants {
1039 let mut typ = result_ty.clone();
1040 for arg in variant.args.iter().rev() {
1041 typ = Type::fun(arg.clone(), typ);
1042 }
1043 out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
1044 }
1045 out
1046 }
1047}
1048
/// Associates a Rust type with its Rex [`Type`].
pub trait RexType {
    /// Returns the Rex type corresponding to `Self`.
    fn rex_type() -> Type;

    /// Hook for appending ADT declarations associated with `Self` to `out`
    /// (collected by `RexAdt::rex_adt_family`). The default adds nothing and
    /// returns `Ok(())`.
    fn collect_rex_family(_out: &mut Vec<AdtDecl>) -> Result<(), TypeError> {
        Ok(())
    }
}
1081
/// A [`RexType`] that can describe itself as an ADT declaration.
pub trait RexAdt: RexType {
    /// Builds the ADT declaration for `Self`.
    fn rex_adt_decl() -> Result<AdtDecl, TypeError>;

    /// Gathers, into a new vector, the declarations that
    /// `RexType::collect_rex_family` reports for `Self`.
    fn rex_adt_family() -> Result<Vec<AdtDecl>, TypeError> {
        let mut out = Vec::new();
        <Self as RexType>::collect_rex_family(&mut out)?;
        Ok(out)
    }
}
1111
1112impl RexType for bool {
1113 fn rex_type() -> Type {
1114 Type::builtin(BuiltinTypeId::Bool)
1115 }
1116}
1117
1118impl RexType for u8 {
1119 fn rex_type() -> Type {
1120 Type::builtin(BuiltinTypeId::U8)
1121 }
1122}
1123
1124impl RexType for u16 {
1125 fn rex_type() -> Type {
1126 Type::builtin(BuiltinTypeId::U16)
1127 }
1128}
1129
1130impl RexType for u32 {
1131 fn rex_type() -> Type {
1132 Type::builtin(BuiltinTypeId::U32)
1133 }
1134}
1135
1136impl RexType for u64 {
1137 fn rex_type() -> Type {
1138 Type::builtin(BuiltinTypeId::U64)
1139 }
1140}
1141
1142impl RexType for i8 {
1143 fn rex_type() -> Type {
1144 Type::builtin(BuiltinTypeId::I8)
1145 }
1146}
1147
1148impl RexType for i16 {
1149 fn rex_type() -> Type {
1150 Type::builtin(BuiltinTypeId::I16)
1151 }
1152}
1153
1154impl RexType for i32 {
1155 fn rex_type() -> Type {
1156 Type::builtin(BuiltinTypeId::I32)
1157 }
1158}
1159
1160impl RexType for i64 {
1161 fn rex_type() -> Type {
1162 Type::builtin(BuiltinTypeId::I64)
1163 }
1164}
1165
1166impl RexType for f32 {
1167 fn rex_type() -> Type {
1168 Type::builtin(BuiltinTypeId::F32)
1169 }
1170}
1171
1172impl RexType for f64 {
1173 fn rex_type() -> Type {
1174 Type::builtin(BuiltinTypeId::F64)
1175 }
1176}
1177
1178impl RexType for String {
1179 fn rex_type() -> Type {
1180 Type::builtin(BuiltinTypeId::String)
1181 }
1182}
1183
1184impl RexType for &str {
1185 fn rex_type() -> Type {
1186 Type::builtin(BuiltinTypeId::String)
1187 }
1188}
1189
1190impl RexType for Uuid {
1191 fn rex_type() -> Type {
1192 Type::builtin(BuiltinTypeId::Uuid)
1193 }
1194}
1195
1196impl RexType for DateTime<Utc> {
1197 fn rex_type() -> Type {
1198 Type::builtin(BuiltinTypeId::DateTime)
1199 }
1200}
1201
1202impl<T: RexType> RexType for Vec<T> {
1203 fn rex_type() -> Type {
1204 Type::app(Type::builtin(BuiltinTypeId::Array), T::rex_type())
1205 }
1206}
1207
1208impl<T: RexType> RexType for Option<T> {
1209 fn rex_type() -> Type {
1210 Type::app(Type::builtin(BuiltinTypeId::Option), T::rex_type())
1211 }
1212}
1213
1214impl<T: RexType, E: RexType> RexType for Result<T, E> {
1215 fn rex_type() -> Type {
1216 Type::app(
1217 Type::app(Type::builtin(BuiltinTypeId::Result), E::rex_type()),
1218 T::rex_type(),
1219 )
1220 }
1221}
1222
1223impl RexType for () {
1224 fn rex_type() -> Type {
1225 Type::tuple(vec![])
1226 }
1227}
1228
1229macro_rules! impl_tuple_rex_type {
1230 ($($name:ident),+) => {
1231 impl<$($name: RexType),+> RexType for ($($name,)+) {
1232 fn rex_type() -> Type {
1233 Type::tuple(vec![$($name::rex_type()),+])
1234 }
1235 }
1236 };
1237}
1238
1239impl_tuple_rex_type!(A0);
1240impl_tuple_rex_type!(A0, A1);
1241impl_tuple_rex_type!(A0, A1, A2);
1242impl_tuple_rex_type!(A0, A1, A2, A3);
1243impl_tuple_rex_type!(A0, A1, A2, A3, A4);
1244impl_tuple_rex_type!(A0, A1, A2, A3, A4, A5);
1245impl_tuple_rex_type!(A0, A1, A2, A3, A4, A5, A6);
1246impl_tuple_rex_type!(A0, A1, A2, A3, A4, A5, A6, A7);
1247
1248#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
1249pub struct Class {
1250 pub supers: Vec<Symbol>,
1251}
1252
1253impl Class {
1254 pub fn new(supers: Vec<Symbol>) -> Self {
1255 Self { supers }
1256 }
1257}
1258
1259#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
1260pub struct Instance {
1261 pub context: Vec<Predicate>,
1262 pub head: Predicate,
1263}
1264
1265impl Instance {
1266 pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
1267 Self { context, head }
1268 }
1269}
1270
1271#[derive(Default, Debug, Clone)]
1272pub struct ClassEnv {
1273 pub classes: BTreeMap<Symbol, Class>,
1274 pub instances: BTreeMap<Symbol, Vec<Instance>>,
1275}
1276
1277impl ClassEnv {
1278 pub fn new() -> Self {
1279 Self {
1280 classes: BTreeMap::new(),
1281 instances: BTreeMap::new(),
1282 }
1283 }
1284
1285 pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
1286 self.classes.insert(name, Class::new(supers));
1287 }
1288
1289 pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
1290 self.instances.entry(class).or_default().push(inst);
1291 }
1292
1293 pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
1294 self.classes
1295 .get(class)
1296 .map(|c| c.supers.clone())
1297 .unwrap_or_default()
1298 }
1299}
1300
1301pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
1338 let mut out = Vec::new();
1339 let mut seen = BTreeSet::new();
1340 let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
1341 for typ in &types {
1342 typ.for_each(|t| {
1343 if let TypeKind::Con(tc) = t.as_ref() {
1344 if let Some(name) = tc.user_name() {
1346 let adt = Type::new(TypeKind::Con(tc.clone()));
1347 if seen.insert(adt.clone()) {
1348 out.push(adt.clone());
1349 }
1350 let defs = defs_by_name.entry(name.clone()).or_default();
1351 if !defs.contains(&adt) {
1352 defs.push(adt);
1353 }
1354 }
1355 }
1356 });
1357 }
1358
1359 let conflicts: Vec<AdtConflict> = defs_by_name
1360 .into_iter()
1361 .filter_map(|(name, definitions)| {
1362 (definitions.len() > 1).then_some(AdtConflict { name, definitions })
1363 })
1364 .collect();
1365 if !conflicts.is_empty() {
1366 return Err(CollectAdtsError { conflicts });
1367 }
1368
1369 Ok(out)
1370}
1371
1372fn collect_adts_error_to_type(err: CollectAdtsError) -> TypeError {
1373 let details = err
1374 .conflicts
1375 .into_iter()
1376 .map(|conflict| {
1377 let defs = conflict
1378 .definitions
1379 .iter()
1380 .map(ToString::to_string)
1381 .collect::<Vec<_>>()
1382 .join(", ");
1383 format!("{}: [{defs}]", conflict.name)
1384 })
1385 .collect::<Vec<_>>()
1386 .join("; ");
1387 TypeError::Internal(format!(
1388 "conflicting ADT definitions discovered in input types: {details}"
1389 ))
1390}
1391
1392fn type_head_and_args_for_adt_family(typ: &Type) -> Result<(Symbol, usize, Vec<Type>), TypeError> {
1393 let mut args = Vec::new();
1394 let mut head = typ;
1395 while let TypeKind::App(f, arg) = head.as_ref() {
1396 args.push(arg.clone());
1397 head = f;
1398 }
1399 args.reverse();
1400
1401 let TypeKind::Con(con) = head.as_ref() else {
1402 return Err(TypeError::Internal(format!(
1403 "cannot build ADT declaration from non-constructor type `{typ}`"
1404 )));
1405 };
1406 if !args.is_empty() && args.len() != con.arity() {
1407 return Err(TypeError::Internal(format!(
1408 "constructor `{}` expected {} type arguments but got {} in `{typ}`",
1409 con.name_str(),
1410 con.arity(),
1411 args.len()
1412 )));
1413 }
1414 Ok((con.name(), con.arity(), args))
1415}
1416
1417fn type_head_for_adt_family(typ: &Type) -> Result<Type, TypeError> {
1418 let (name, arity, _args) = type_head_and_args_for_adt_family(typ)?;
1419 Ok(Type::con(name.as_ref(), arity))
1420}
1421
1422fn adt_shape(adt: &AdtDecl) -> String {
1423 let param_names: BTreeMap<_, _> = adt
1424 .params
1425 .iter()
1426 .enumerate()
1427 .map(|(idx, param)| (param.var.id, format!("t{idx}")))
1428 .collect();
1429 let mut variants = adt
1430 .variants
1431 .iter()
1432 .map(|variant| {
1433 let args = variant
1434 .args
1435 .iter()
1436 .map(|arg| normalize_type_for_shape(arg, ¶m_names))
1437 .collect::<Vec<_>>()
1438 .join(", ");
1439 format!("{}({args})", variant.name)
1440 })
1441 .collect::<Vec<_>>();
1442 variants.sort();
1443 format!("{}[{}]", adt.name, variants.join(" | "))
1444}
1445
1446fn normalize_type_for_shape(typ: &Type, param_names: &BTreeMap<usize, String>) -> String {
1447 match typ.as_ref() {
1448 TypeKind::Var(tv) => param_names
1449 .get(&tv.id)
1450 .cloned()
1451 .unwrap_or_else(|| format!("v{}", tv.id)),
1452 TypeKind::Con(con) => con.name_str().to_string(),
1453 TypeKind::App(fun, arg) => format!(
1454 "({} {})",
1455 normalize_type_for_shape(fun, param_names),
1456 normalize_type_for_shape(arg, param_names)
1457 ),
1458 TypeKind::Fun(arg, ret) => format!(
1459 "({} -> {})",
1460 normalize_type_for_shape(arg, param_names),
1461 normalize_type_for_shape(ret, param_names)
1462 ),
1463 TypeKind::Tuple(elems) => format!(
1464 "({})",
1465 elems
1466 .iter()
1467 .map(|elem| normalize_type_for_shape(elem, param_names))
1468 .collect::<Vec<_>>()
1469 .join(", ")
1470 ),
1471 TypeKind::Record(fields) => format!(
1472 "{{{}}}",
1473 fields
1474 .iter()
1475 .map(|(name, typ)| format!(
1476 "{name}: {}",
1477 normalize_type_for_shape(typ, param_names)
1478 ))
1479 .collect::<Vec<_>>()
1480 .join(", ")
1481 ),
1482 }
1483}
1484
1485fn adt_shape_eq(left: &AdtDecl, right: &AdtDecl) -> bool {
1486 adt_shape(left) == adt_shape(right)
1487}
1488
1489fn adt_direct_dependencies(adt: &AdtDecl) -> Result<Vec<Type>, TypeError> {
1490 let types = adt
1491 .variants
1492 .iter()
1493 .flat_map(|variant| variant.args.iter().cloned())
1494 .collect::<Vec<_>>();
1495 let deps = collect_adts_in_types(types).map_err(collect_adts_error_to_type)?;
1496 deps.into_iter()
1497 .map(|typ| type_head_for_adt_family(&typ))
1498 .collect()
1499}
1500
1501pub fn order_adt_family(adts: Vec<AdtDecl>) -> Result<Vec<AdtDecl>, TypeError> {
1514 let mut unique = BTreeMap::new();
1515 for adt in adts {
1516 match unique.get(&adt.name) {
1517 Some(existing) if adt_shape_eq(existing, &adt) => {}
1518 Some(existing) => {
1519 return Err(TypeError::Internal(format!(
1520 "conflicting ADT family definitions for `{}`: {} vs {}",
1521 adt.name,
1522 adt_shape(existing),
1523 adt_shape(&adt)
1524 )));
1525 }
1526 None => {
1527 unique.insert(adt.name.clone(), adt);
1528 }
1529 }
1530 }
1531
1532 let mut visiting = Vec::<Symbol>::new();
1533 let mut visited = BTreeSet::<Symbol>::new();
1534 let mut ordered = Vec::<AdtDecl>::new();
1535
1536 fn visit(
1537 name: &Symbol,
1538 unique: &BTreeMap<Symbol, AdtDecl>,
1539 visiting: &mut Vec<Symbol>,
1540 visited: &mut BTreeSet<Symbol>,
1541 ordered: &mut Vec<AdtDecl>,
1542 ) -> Result<(), TypeError> {
1543 if visited.contains(name) {
1544 return Ok(());
1545 }
1546 if let Some(idx) = visiting.iter().position(|current| current == name) {
1547 let mut cycle = visiting[idx..]
1548 .iter()
1549 .map(ToString::to_string)
1550 .collect::<Vec<_>>();
1551 cycle.push(name.to_string());
1552 return Err(TypeError::Internal(format!(
1553 "cyclic ADT auto-registration is not supported yet: {}",
1554 cycle.join(" -> ")
1555 )));
1556 }
1557
1558 let adt = unique
1559 .get(name)
1560 .ok_or_else(|| TypeError::Internal(format!("missing ADT `{name}` during ordering")))?;
1561 visiting.push(name.clone());
1562 for dep in adt_direct_dependencies(adt)? {
1563 let dep_head = type_head_for_adt_family(&dep)?;
1564 let TypeKind::Con(dep_con) = dep_head.as_ref() else {
1565 return Err(TypeError::Internal(format!(
1566 "dependency head for `{name}` was not a constructor"
1567 )));
1568 };
1569 if let Some(name) = dep_con.user_name()
1570 && unique.contains_key(name)
1571 {
1572 visit(name, unique, visiting, visited, ordered)?;
1573 }
1574 }
1575 visiting.pop();
1576 visited.insert(name.clone());
1577 ordered.push(adt.clone());
1578 Ok(())
1579 }
1580
1581 let mut names = unique.keys().cloned().collect::<Vec<_>>();
1582 names.sort();
1583 for name in names {
1584 visit(&name, &unique, &mut visiting, &mut visited, &mut ordered)?;
1585 }
1586 Ok(ordered)
1587}