1use crate::{
2 error::{AdtConflict, CollectAdtsError},
3 typesystem::TypeVarSupply,
4 unification::{Subst, subst_is_empty},
5};
6use chrono::{DateTime, Utc};
7use rex_ast::expr::{Pattern, Symbol, intern, sym};
8use rpds::HashTrieMapSync;
9use std::{
10 collections::{BTreeMap, BTreeSet},
11 fmt::{self, Display, Formatter},
12 sync::Arc,
13};
14use uuid::Uuid;
15
/// Numeric identifier of a type variable.
///
/// Substitutions are looked up by this id, and [`Types::ftv`] reports free
/// type variables as sets of these ids.
pub type TypeVarId = usize;
17
/// Identifies a type constructor that has built-in support.
///
/// The surface name of each variant is given by `BuiltinTypeId::as_str` and
/// the number of type arguments it takes by `BuiltinTypeId::arity`.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum BuiltinTypeId {
    U8,
    U16,
    U32,
    U64,
    I8,
    I16,
    I32,
    I64,
    F32,
    F64,
    Bool,
    String,
    Uuid,
    DateTime,
    // Parameterized constructors: `Result` takes two type arguments, the
    // others one (see `arity`).
    List,
    Array,
    Dict,
    Option,
    Promise,
    Result,
}
41
42impl BuiltinTypeId {
43 pub fn as_symbol(self) -> Symbol {
44 sym(self.as_str())
45 }
46
47 pub fn as_str(self) -> &'static str {
48 match self {
49 Self::U8 => "u8",
50 Self::U16 => "u16",
51 Self::U32 => "u32",
52 Self::U64 => "u64",
53 Self::I8 => "i8",
54 Self::I16 => "i16",
55 Self::I32 => "i32",
56 Self::I64 => "i64",
57 Self::F32 => "f32",
58 Self::F64 => "f64",
59 Self::Bool => "bool",
60 Self::String => "string",
61 Self::Uuid => "uuid",
62 Self::DateTime => "datetime",
63 Self::List => "List",
64 Self::Array => "Array",
65 Self::Dict => "Dict",
66 Self::Option => "Option",
67 Self::Promise => "Promise",
68 Self::Result => "Result",
69 }
70 }
71
72 pub fn arity(self) -> usize {
73 match self {
74 Self::List | Self::Array | Self::Dict | Self::Option | Self::Promise => 1,
75 Self::Result => 2,
76 _ => 0,
77 }
78 }
79
80 pub fn from_symbol(name: &Symbol) -> Option<Self> {
81 Self::from_name(name.as_ref())
82 }
83
84 pub fn from_name(name: &str) -> Option<Self> {
85 match name {
86 "u8" => Some(Self::U8),
87 "u16" => Some(Self::U16),
88 "u32" => Some(Self::U32),
89 "u64" => Some(Self::U64),
90 "i8" => Some(Self::I8),
91 "i16" => Some(Self::I16),
92 "i32" => Some(Self::I32),
93 "i64" => Some(Self::I64),
94 "f32" => Some(Self::F32),
95 "f64" => Some(Self::F64),
96 "bool" => Some(Self::Bool),
97 "string" => Some(Self::String),
98 "uuid" => Some(Self::Uuid),
99 "datetime" => Some(Self::DateTime),
100 "List" => Some(Self::List),
101 "Array" => Some(Self::Array),
102 "Dict" => Some(Self::Dict),
103 "Option" => Some(Self::Option),
104 "Promise" => Some(Self::Promise),
105 "Result" => Some(Self::Result),
106 _ => None,
107 }
108 }
109}
110
111#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
112pub struct TypeVar {
113 pub id: TypeVarId,
114 pub name: Option<Symbol>,
115}
116
117impl TypeVar {
118 pub fn new(id: TypeVarId, name: impl Into<Option<Symbol>>) -> Self {
119 Self {
120 id,
121 name: name.into(),
122 }
123 }
124}
125
/// A named type constructor together with the number of arguments it takes.
///
/// `builtin_id` is `Some` for constructors built by `Type::builtin`, or
/// resolved to a builtin by `Type::con`. It is `None` for constructors built
/// by `Type::user_con`.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct TypeConst {
    /// Constructor name.
    pub name: Symbol,
    /// Number of type arguments the constructor expects.
    pub arity: usize,
    /// Which builtin this constructor is, if any.
    pub builtin_id: Option<BuiltinTypeId>,
}
132
/// A type, held as a reference-counted, shared [`TypeKind`] node.
///
/// Cloning only bumps the `Arc` reference count. The derived equality,
/// ordering and hashing go through the `Arc` to the node, so they compare
/// types structurally.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct Type(Arc<TypeKind>);
135
/// The shape of a single [`Type`] node.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum TypeKind {
    /// A type variable.
    Var(TypeVar),
    /// A type constructor, possibly not yet applied to its arguments.
    Con(TypeConst),
    /// Application of a type-level function to a single argument.
    /// Multi-argument constructors are curried; e.g. `Type::result` builds
    /// `App(App(Result, err), ok)`.
    App(Type, Type),
    /// A function type `argument -> result`.
    Fun(Type, Type),
    /// A tuple type with the given element types, in order.
    Tuple(Vec<Type>),
    /// A record type as `(field name, field type)` pairs. `Type::record`
    /// sorts these by field name; constructing the variant directly does not.
    Record(Vec<(Symbol, Type)>),
}
149
150impl Type {
151 pub fn new(kind: TypeKind) -> Self {
152 Type(Arc::new(kind))
153 }
154
155 pub fn con(name: impl AsRef<str>, arity: usize) -> Self {
156 if let Some(id) = BuiltinTypeId::from_name(name.as_ref())
157 && id.arity() == arity
158 {
159 return Self::builtin(id);
160 }
161 Self::user_con(name, arity)
162 }
163
164 pub fn user_con(name: impl AsRef<str>, arity: usize) -> Self {
165 Type::new(TypeKind::Con(TypeConst {
166 name: intern(name.as_ref()),
167 arity,
168 builtin_id: None,
169 }))
170 }
171
172 pub fn builtin(id: BuiltinTypeId) -> Self {
173 Type::new(TypeKind::Con(TypeConst {
174 name: id.as_symbol(),
175 arity: id.arity(),
176 builtin_id: Some(id),
177 }))
178 }
179
180 pub fn var(tv: TypeVar) -> Self {
181 Type::new(TypeKind::Var(tv))
182 }
183
184 pub fn fun(a: Type, b: Type) -> Self {
185 Type::new(TypeKind::Fun(a, b))
186 }
187
188 pub fn app(f: Type, arg: Type) -> Self {
189 Type::new(TypeKind::App(f, arg))
190 }
191
192 pub fn tuple(elems: Vec<Type>) -> Self {
193 Type::new(TypeKind::Tuple(elems))
194 }
195
196 pub fn record(mut fields: Vec<(Symbol, Type)>) -> Self {
197 fields.sort_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
200 Type::new(TypeKind::Record(fields))
201 }
202
203 pub fn list(elem: Type) -> Type {
204 Type::app(Type::builtin(BuiltinTypeId::List), elem)
205 }
206
207 pub fn array(elem: Type) -> Type {
208 Type::app(Type::builtin(BuiltinTypeId::Array), elem)
209 }
210
211 pub fn dict(elem: Type) -> Type {
212 Type::app(Type::builtin(BuiltinTypeId::Dict), elem)
213 }
214
215 pub fn option(elem: Type) -> Type {
216 Type::app(Type::builtin(BuiltinTypeId::Option), elem)
217 }
218
219 pub fn promise(elem: Type) -> Type {
220 Type::app(Type::builtin(BuiltinTypeId::Promise), elem)
221 }
222
223 pub fn result(ok: Type, err: Type) -> Type {
224 Type::app(Type::app(Type::builtin(BuiltinTypeId::Result), err), ok)
225 }
226
227 fn apply_with_change(&self, s: &Subst) -> (Type, bool) {
228 match self.as_ref() {
229 TypeKind::Var(tv) => match s.get(&tv.id) {
230 Some(ty) => (ty.clone(), true),
231 None => (self.clone(), false),
232 },
233 TypeKind::Con(_) => (self.clone(), false),
234 TypeKind::App(l, r) => {
235 let (l_new, l_changed) = l.apply_with_change(s);
236 let (r_new, r_changed) = r.apply_with_change(s);
237 if l_changed || r_changed {
238 (Type::app(l_new, r_new), true)
239 } else {
240 (self.clone(), false)
241 }
242 }
243 TypeKind::Fun(_, _) => {
244 let mut args = Vec::new();
247 let mut changed = false;
248 let mut cur: &Type = self;
249 while let TypeKind::Fun(a, b) = cur.as_ref() {
250 let (a_new, a_changed) = a.apply_with_change(s);
251 changed |= a_changed;
252 args.push(a_new);
253 cur = b;
254 }
255 let (ret_new, ret_changed) = cur.apply_with_change(s);
256 changed |= ret_changed;
257 if !changed {
258 return (self.clone(), false);
259 }
260 let mut out = ret_new;
261 for a_new in args.into_iter().rev() {
262 out = Type::fun(a_new, out);
263 }
264 (out, true)
265 }
266 TypeKind::Tuple(ts) => {
267 let mut changed = false;
268 let mut out = Vec::with_capacity(ts.len());
269 for t in ts {
270 let (t_new, t_changed) = t.apply_with_change(s);
271 changed |= t_changed;
272 out.push(t_new);
273 }
274 if changed {
275 (Type::new(TypeKind::Tuple(out)), true)
276 } else {
277 (self.clone(), false)
278 }
279 }
280 TypeKind::Record(fields) => {
281 let mut changed = false;
282 let mut out = Vec::with_capacity(fields.len());
283 for (k, v) in fields {
284 let (v_new, v_changed) = v.apply_with_change(s);
285 changed |= v_changed;
286 out.push((k.clone(), v_new));
287 }
288 if changed {
289 (Type::new(TypeKind::Record(out)), true)
290 } else {
291 (self.clone(), false)
292 }
293 }
294 }
295 }
296
297 pub fn for_each<F>(&self, mut f: F) -> Type
298 where
299 F: FnMut(&Type),
300 {
301 self.transform(|t| {
302 f(t);
303 None
304 })
305 }
306
307 pub fn transform<F>(&self, mut f: F) -> Type
308 where
309 F: FnMut(&Type) -> Option<Type>,
310 {
311 self.transform_ref(&mut f)
312 }
313
314 fn transform_ref<F>(&self, f: &mut F) -> Type
315 where
316 F: FnMut(&Type) -> Option<Type>,
317 {
318 if let Some(repl) = f(self) {
319 return repl;
320 }
321
322 match self.as_ref() {
323 TypeKind::Var(type_var) => Type(Arc::new(TypeKind::Var(type_var.clone()))),
324 TypeKind::Con(type_const) => Type(Arc::new(TypeKind::Con(type_const.clone()))),
325 TypeKind::App(fun, arg) => Type(Arc::new(TypeKind::App(
326 fun.transform_ref(f),
327 arg.transform_ref(f),
328 ))),
329 TypeKind::Fun(arg, res) => Type(Arc::new(TypeKind::Fun(
330 arg.transform_ref(f),
331 res.transform_ref(f),
332 ))),
333 TypeKind::Tuple(ts) => Type(Arc::new(TypeKind::Tuple(
334 ts.iter().map(|t| t.transform_ref(f)).collect(),
335 ))),
336 TypeKind::Record(fields) => Type(Arc::new(TypeKind::Record(
337 fields
338 .iter()
339 .map(|(s, t)| (s.clone(), t.transform_ref(f)))
340 .collect(),
341 ))),
342 }
343 }
344}
345
346impl AsRef<TypeKind> for Type {
347 fn as_ref(&self) -> &TypeKind {
348 self.0.as_ref()
349 }
350}
351
352impl std::ops::Deref for Type {
353 type Target = TypeKind;
354
355 fn deref(&self) -> &Self::Target {
356 &self.0
357 }
358}
359
360impl Display for Type {
361 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
362 match self.as_ref() {
363 TypeKind::Var(tv) => match &tv.name {
364 Some(name) => write!(f, "'{}", name),
365 None => write!(f, "t{}", tv.id),
366 },
367 TypeKind::Con(c) => write!(f, "{}", c.name),
368 TypeKind::App(l, r) => {
369 if let TypeKind::App(head, err) = l.as_ref()
375 && matches!(
376 head.as_ref(),
377 TypeKind::Con(c)
378 if c.builtin_id == Some(BuiltinTypeId::Result) && c.arity == 2
379 )
380 {
381 return write!(f, "(Result {} {})", r, err);
382 }
383 write!(f, "({} {})", l, r)
384 }
385 TypeKind::Fun(a, b) => write!(f, "({} -> {})", a, b),
386 TypeKind::Tuple(elems) => {
387 write!(f, "(")?;
388 for (i, t) in elems.iter().enumerate() {
389 write!(f, "{}", t)?;
390 if i + 1 < elems.len() {
391 write!(f, ", ")?;
392 }
393 }
394 write!(f, ")")
395 }
396 TypeKind::Record(fields) => {
397 write!(f, "{{")?;
398 for (i, (name, ty)) in fields.iter().enumerate() {
399 write!(f, "{}: {}", name, ty)?;
400 if i + 1 < fields.len() {
401 write!(f, ", ")?;
402 }
403 }
404 write!(f, "}}")
405 }
406 }
407 }
408}
409
410#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
411pub struct Predicate {
412 pub class: Symbol,
413 pub typ: Type,
414}
415
416impl Predicate {
417 pub fn new(class: impl AsRef<str>, typ: Type) -> Self {
418 Self {
419 class: intern(class.as_ref()),
420 typ,
421 }
422 }
423}
424
425#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
426pub struct Scheme {
427 pub vars: Vec<TypeVar>,
428 pub preds: Vec<Predicate>,
429 pub typ: Type,
430}
431
432impl Scheme {
433 pub fn new(vars: Vec<TypeVar>, preds: Vec<Predicate>, typ: Type) -> Self {
434 Self { vars, preds, typ }
435 }
436}
437
/// Values that contain types: a substitution can be applied to them, and
/// they can report their free type variables.
pub trait Types: Sized {
    /// Returns a copy of `self` with the substitution `s` applied.
    fn apply(&self, s: &Subst) -> Self;
    /// Returns the ids of the type variables occurring free in `self`.
    fn ftv(&self) -> BTreeSet<TypeVarId>;
}
442
443impl Types for Type {
444 fn apply(&self, s: &Subst) -> Self {
445 self.apply_with_change(s).0
446 }
447
448 fn ftv(&self) -> BTreeSet<TypeVarId> {
449 let mut out = BTreeSet::new();
450 let mut stack: Vec<&Type> = vec![self];
451 while let Some(t) = stack.pop() {
452 match t.as_ref() {
453 TypeKind::Var(tv) => {
454 out.insert(tv.id);
455 }
456 TypeKind::Con(_) => {}
457 TypeKind::App(l, r) => {
458 stack.push(l);
459 stack.push(r);
460 }
461 TypeKind::Fun(a, b) => {
462 stack.push(a);
463 stack.push(b);
464 }
465 TypeKind::Tuple(ts) => {
466 for t in ts {
467 stack.push(t);
468 }
469 }
470 TypeKind::Record(fields) => {
471 for (_, ty) in fields {
472 stack.push(ty);
473 }
474 }
475 }
476 }
477 out
478 }
479}
480
481impl Types for Predicate {
482 fn apply(&self, s: &Subst) -> Self {
483 Predicate {
484 class: self.class.clone(),
485 typ: self.typ.apply(s),
486 }
487 }
488
489 fn ftv(&self) -> BTreeSet<TypeVarId> {
490 self.typ.ftv()
491 }
492}
493
494impl Types for Scheme {
495 fn apply(&self, s: &Subst) -> Self {
496 let mut s_pruned = Subst::new_sync();
497 for (k, v) in s.iter() {
498 if !self.vars.iter().any(|var| var.id == *k) {
499 s_pruned = s_pruned.insert(*k, v.clone());
500 }
501 }
502 Scheme::new(
503 self.vars.clone(),
504 self.preds.iter().map(|p| p.apply(&s_pruned)).collect(),
505 self.typ.apply(&s_pruned),
506 )
507 }
508
509 fn ftv(&self) -> BTreeSet<TypeVarId> {
510 let mut ftv = self.typ.ftv();
511 for p in &self.preds {
512 ftv.extend(p.ftv());
513 }
514 for v in &self.vars {
515 ftv.remove(&v.id);
516 }
517 ftv
518 }
519}
520
521impl<T: Types> Types for Vec<T> {
522 fn apply(&self, s: &Subst) -> Self {
523 self.iter().map(|t| t.apply(s)).collect()
524 }
525
526 fn ftv(&self) -> BTreeSet<TypeVarId> {
527 self.iter().flat_map(Types::ftv).collect()
528 }
529}
530
/// An expression tree in which every node carries a type.
#[derive(Clone, Debug, PartialEq)]
pub struct TypedExpr {
    /// Type of this node.
    pub typ: Type,
    /// What kind of expression this node is, including its sub-expressions.
    pub kind: TypedExprKind,
}
536
537impl TypedExpr {
538 pub fn new(typ: Type, kind: TypedExprKind) -> Self {
539 Self { typ, kind }
540 }
541
542 pub fn apply(&self, s: &Subst) -> Self {
543 match &self.kind {
544 TypedExprKind::Lam { .. } => {
545 let mut params: Vec<(Symbol, Type)> = Vec::new();
546 let mut cur = self;
547 while let TypedExprKind::Lam { param, body } = &cur.kind {
548 params.push((param.clone(), cur.typ.apply(s)));
549 cur = body.as_ref();
550 }
551 let mut out = cur.apply(s);
552 for (param, typ) in params.into_iter().rev() {
553 out = TypedExpr {
554 typ,
555 kind: TypedExprKind::Lam {
556 param,
557 body: Box::new(out),
558 },
559 };
560 }
561 return out;
562 }
563 TypedExprKind::App(..) => {
564 let mut apps: Vec<(Type, &TypedExpr)> = Vec::new();
565 let mut cur = self;
566 while let TypedExprKind::App(f, x) = &cur.kind {
567 apps.push((cur.typ.apply(s), x.as_ref()));
568 cur = f.as_ref();
569 }
570 let mut out = cur.apply(s);
571 for (typ, arg) in apps.into_iter().rev() {
572 out = TypedExpr {
573 typ,
574 kind: TypedExprKind::App(Box::new(out), Box::new(arg.apply(s))),
575 };
576 }
577 return out;
578 }
579 _ => {}
580 }
581
582 let typ = self.typ.apply(s);
583 let kind = match &self.kind {
584 TypedExprKind::Bool(v) => TypedExprKind::Bool(*v),
585 TypedExprKind::Uint(v) => TypedExprKind::Uint(*v),
586 TypedExprKind::Int(v) => TypedExprKind::Int(*v),
587 TypedExprKind::Float(v) => TypedExprKind::Float(*v),
588 TypedExprKind::String(v) => TypedExprKind::String(v.clone()),
589 TypedExprKind::Uuid(v) => TypedExprKind::Uuid(*v),
590 TypedExprKind::DateTime(v) => TypedExprKind::DateTime(*v),
591 TypedExprKind::Hole => TypedExprKind::Hole,
592 TypedExprKind::Tuple(elems) => {
593 TypedExprKind::Tuple(elems.iter().map(|e| e.apply(s)).collect())
594 }
595 TypedExprKind::List(elems) => {
596 TypedExprKind::List(elems.iter().map(|e| e.apply(s)).collect())
597 }
598 TypedExprKind::Dict(kvs) => {
599 let mut out = BTreeMap::new();
600 for (k, v) in kvs {
601 out.insert(k.clone(), v.apply(s));
602 }
603 TypedExprKind::Dict(out)
604 }
605 TypedExprKind::RecordUpdate { base, updates } => {
606 let mut out = BTreeMap::new();
607 for (k, v) in updates {
608 out.insert(k.clone(), v.apply(s));
609 }
610 TypedExprKind::RecordUpdate {
611 base: Box::new(base.apply(s)),
612 updates: out,
613 }
614 }
615 TypedExprKind::Var { name, overloads } => TypedExprKind::Var {
616 name: name.clone(),
617 overloads: overloads.iter().map(|t| t.apply(s)).collect(),
618 },
619 TypedExprKind::App(f, x) => {
620 TypedExprKind::App(Box::new(f.apply(s)), Box::new(x.apply(s)))
621 }
622 TypedExprKind::Project { expr, field } => TypedExprKind::Project {
623 expr: Box::new(expr.apply(s)),
624 field: field.clone(),
625 },
626 TypedExprKind::Lam { param, body } => TypedExprKind::Lam {
627 param: param.clone(),
628 body: Box::new(body.apply(s)),
629 },
630 TypedExprKind::Let { name, def, body } => TypedExprKind::Let {
631 name: name.clone(),
632 def: Box::new(def.apply(s)),
633 body: Box::new(body.apply(s)),
634 },
635 TypedExprKind::LetRec { bindings, body } => TypedExprKind::LetRec {
636 bindings: bindings
637 .iter()
638 .map(|(name, def)| (name.clone(), def.apply(s)))
639 .collect(),
640 body: Box::new(body.apply(s)),
641 },
642 TypedExprKind::Ite {
643 cond,
644 then_expr,
645 else_expr,
646 } => TypedExprKind::Ite {
647 cond: Box::new(cond.apply(s)),
648 then_expr: Box::new(then_expr.apply(s)),
649 else_expr: Box::new(else_expr.apply(s)),
650 },
651 TypedExprKind::Match { scrutinee, arms } => TypedExprKind::Match {
652 scrutinee: Box::new(scrutinee.apply(s)),
653 arms: arms.iter().map(|(p, e)| (p.clone(), e.apply(s))).collect(),
654 },
655 };
656 TypedExpr { typ, kind }
657 }
658}
659
/// The kind of a [`TypedExpr`] node, holding its sub-expressions.
#[derive(Clone, Debug, PartialEq)]
pub enum TypedExprKind {
    /// Boolean literal.
    Bool(bool),
    /// Unsigned integer literal.
    Uint(u64),
    /// Signed integer literal.
    Int(i64),
    /// Floating-point literal.
    Float(f64),
    /// String literal.
    String(String),
    /// UUID literal.
    Uuid(Uuid),
    /// UTC timestamp literal.
    DateTime(DateTime<Utc>),
    /// A hole expression; it carries no data.
    Hole,
    /// Tuple expression; elements in order.
    Tuple(Vec<TypedExpr>),
    /// List expression; elements in order.
    List(Vec<TypedExpr>),
    /// Dictionary expression keyed by symbols.
    Dict(BTreeMap<Symbol, TypedExpr>),
    /// Update of the record `base` with the field/value pairs in `updates`.
    RecordUpdate {
        base: Box<TypedExpr>,
        updates: BTreeMap<Symbol, TypedExpr>,
    },
    /// Reference to the variable `name`.
    Var {
        name: Symbol,
        // NOTE(review): presumably the candidate types of an overloaded
        // name; confirm against the inference code that fills this in.
        overloads: Vec<Type>,
    },
    /// Application of a function to one argument; multi-argument calls are
    /// left-nested `App` nodes.
    App(Box<TypedExpr>, Box<TypedExpr>),
    /// Access of `field` on `expr`.
    Project {
        expr: Box<TypedExpr>,
        field: Symbol,
    },
    /// Single-parameter lambda. The parameter's type is not stored here;
    /// presumably it is the argument of the node's function type (confirm).
    Lam {
        param: Symbol,
        body: Box<TypedExpr>,
    },
    /// `name` bound to `def` within `body`.
    Let {
        name: Symbol,
        def: Box<TypedExpr>,
        body: Box<TypedExpr>,
    },
    /// A group of bindings (recursive, per the name) scoped over `body`.
    LetRec {
        bindings: Vec<(Symbol, TypedExpr)>,
        body: Box<TypedExpr>,
    },
    /// If-then-else.
    Ite {
        cond: Box<TypedExpr>,
        then_expr: Box<TypedExpr>,
        else_expr: Box<TypedExpr>,
    },
    /// Pattern match of `scrutinee` against `arms`, in order.
    Match {
        scrutinee: Box<TypedExpr>,
        arms: Vec<(Pattern, TypedExpr)>,
    },
}
709
710#[derive(Default, Debug, Clone)]
711pub struct TypeEnv {
712 pub values: HashTrieMapSync<Symbol, Vec<Scheme>>,
713}
714
715impl TypeEnv {
716 pub fn new() -> Self {
717 Self {
718 values: HashTrieMapSync::new_sync(),
719 }
720 }
721
722 pub fn extend(&mut self, name: Symbol, scheme: Scheme) {
723 self.values = self.values.insert(name, vec![scheme]);
724 }
725
726 pub fn extend_overload(&mut self, name: Symbol, scheme: Scheme) {
727 let mut schemes = self.values.get(&name).cloned().unwrap_or_default();
728 schemes.push(scheme);
729 self.values = self.values.insert(name, schemes);
730 }
731
732 pub fn remove(&mut self, name: &Symbol) {
733 self.values = self.values.remove(name);
734 }
735
736 pub fn lookup(&self, name: &Symbol) -> Option<&[Scheme]> {
737 self.values.get(name).map(|schemes| schemes.as_slice())
738 }
739}
740
741impl Types for TypeEnv {
742 fn apply(&self, s: &Subst) -> Self {
743 let mut values = HashTrieMapSync::new_sync();
744 for (k, v) in self.values.iter() {
745 let updated = v
746 .iter()
747 .map(|scheme| {
748 if scheme.vars.is_empty() && !subst_is_empty(s) {
751 scheme.apply(s)
752 } else {
753 scheme.clone()
754 }
755 })
756 .collect();
757 values = values.insert(k.clone(), updated);
758 }
759 TypeEnv { values }
760 }
761
762 fn ftv(&self) -> BTreeSet<TypeVarId> {
763 self.values
764 .iter()
765 .flat_map(|(_, schemes)| schemes.iter().flat_map(Types::ftv))
766 .collect()
767 }
768}
769
/// A type parameter of an ADT declaration.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AdtParam {
    /// The parameter's source name.
    pub name: Symbol,
    /// The type variable standing for this parameter in the declaration's
    /// result type and constructor types.
    pub var: TypeVar,
}
776
/// One constructor of an ADT declaration.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AdtVariant {
    /// Constructor name.
    pub name: Symbol,
    /// Argument types, in the order the constructor takes them.
    pub args: Vec<Type>,
}
783
/// An algebraic data type declaration: its name, type parameters and
/// constructors.
#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct AdtDecl {
    /// Name of the declared type.
    pub name: Symbol,
    /// Type parameters, in declaration order.
    pub params: Vec<AdtParam>,
    /// Constructors, in the order they were added.
    pub variants: Vec<AdtVariant>,
}
795
796impl AdtDecl {
797 pub fn new(name: &Symbol, param_names: &[Symbol], supply: &mut TypeVarSupply) -> Self {
798 let params = param_names
799 .iter()
800 .map(|p| AdtParam {
801 name: p.clone(),
802 var: supply.fresh(Some(p.clone())),
803 })
804 .collect();
805 Self {
806 name: name.clone(),
807 params,
808 variants: Vec::new(),
809 }
810 }
811
812 pub fn param_type(&self, name: &Symbol) -> Option<Type> {
813 self.params
814 .iter()
815 .find(|p| &p.name == name)
816 .map(|p| Type::var(p.var.clone()))
817 }
818
819 pub fn add_variant(&mut self, name: Symbol, args: Vec<Type>) {
820 self.variants.push(AdtVariant { name, args });
821 }
822
823 pub fn result_type(&self) -> Type {
824 let mut ty = Type::con(&self.name, self.params.len());
825 for param in &self.params {
826 ty = Type::app(ty, Type::var(param.var.clone()));
827 }
828 ty
829 }
830
831 pub fn constructor_schemes(&self) -> Vec<(Symbol, Scheme)> {
834 let result_ty = self.result_type();
835 let vars: Vec<TypeVar> = self.params.iter().map(|p| p.var.clone()).collect();
836 let mut out = Vec::new();
837 for variant in &self.variants {
838 let mut typ = result_ty.clone();
839 for arg in variant.args.iter().rev() {
840 typ = Type::fun(arg.clone(), typ);
841 }
842 out.push((variant.name.clone(), Scheme::new(vars.clone(), vec![], typ)));
843 }
844 out
845 }
846}
847
848#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
849pub struct Class {
850 pub supers: Vec<Symbol>,
851}
852
853impl Class {
854 pub fn new(supers: Vec<Symbol>) -> Self {
855 Self { supers }
856 }
857}
858
859#[derive(Clone, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
860pub struct Instance {
861 pub context: Vec<Predicate>,
862 pub head: Predicate,
863}
864
865impl Instance {
866 pub fn new(context: Vec<Predicate>, head: Predicate) -> Self {
867 Self { context, head }
868 }
869}
870
871#[derive(Default, Debug, Clone)]
872pub struct ClassEnv {
873 pub classes: BTreeMap<Symbol, Class>,
874 pub instances: BTreeMap<Symbol, Vec<Instance>>,
875}
876
877impl ClassEnv {
878 pub fn new() -> Self {
879 Self {
880 classes: BTreeMap::new(),
881 instances: BTreeMap::new(),
882 }
883 }
884
885 pub fn add_class(&mut self, name: Symbol, supers: Vec<Symbol>) {
886 self.classes.insert(name, Class::new(supers));
887 }
888
889 pub fn add_instance(&mut self, class: Symbol, inst: Instance) {
890 self.instances.entry(class).or_default().push(inst);
891 }
892
893 pub fn supers_of(&self, class: &Symbol) -> Vec<Symbol> {
894 self.classes
895 .get(class)
896 .map(|c| c.supers.clone())
897 .unwrap_or_default()
898 }
899}
900
901pub fn collect_adts_in_types(types: Vec<Type>) -> Result<Vec<Type>, CollectAdtsError> {
938 let mut out = Vec::new();
939 let mut seen = BTreeSet::new();
940 let mut defs_by_name: BTreeMap<Symbol, Vec<Type>> = BTreeMap::new();
941 for typ in &types {
942 typ.for_each(|t| {
943 if let TypeKind::Con(tc) = t.as_ref() {
944 if tc.builtin_id.is_none() {
946 let adt = Type::new(TypeKind::Con(tc.clone()));
947 if seen.insert(adt.clone()) {
948 out.push(adt.clone());
949 }
950 let defs = defs_by_name.entry(tc.name.clone()).or_default();
951 if !defs.contains(&adt) {
952 defs.push(adt);
953 }
954 }
955 }
956 });
957 }
958
959 let conflicts: Vec<AdtConflict> = defs_by_name
960 .into_iter()
961 .filter_map(|(name, definitions)| {
962 (definitions.len() > 1).then_some(AdtConflict { name, definitions })
963 })
964 .collect();
965 if !conflicts.is_empty() {
966 return Err(CollectAdtsError { conflicts });
967 }
968
969 Ok(out)
970}