Skip to main content

rex_ast/
ast.rs

1use std::{
2    collections::BTreeMap,
3    fmt::{self, Display, Formatter},
4    mem,
5    sync::Arc,
6};
7
8use rpds::HashTrieMapSync;
9
10use chrono::{DateTime, Utc};
11use uuid::Uuid;
12
13use crate::{
14    span::{Position, Span},
15    symbol::Symbol,
16};
17
/// Persistent map from variable names to expressions, stored on `Expr::Lam`.
///
/// This is rpds's `Arc`-backed (`Sync`) hash-trie map, so cloning a scope is cheap
/// and shares structure. NOTE(review): presumably the lambda's captured
/// environment — `CompilationUnit::body_with_fns` always creates it empty; confirm
/// against the evaluator.
pub type Scope = HashTrieMapSync<Symbol, Arc<Expr>>;
19
/// A variable occurrence or binder: an interned name plus its source span.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub struct Var {
    /// Source location of the name.
    pub span: Span,
    /// Interned identifier; may be an operator such as `+` (see `Display`).
    pub name: Symbol,
}
26
27impl Var {
28    pub fn new(name: impl ToString) -> Self {
29        Self {
30            span: Span::default(),
31            name: Symbol::intern(&name.to_string()),
32        }
33    }
34
35    pub fn with_span(span: Span, name: impl ToString) -> Self {
36        Self {
37            span,
38            name: Symbol::intern(&name.to_string()),
39        }
40    }
41
42    pub fn reset_spans(&self) -> Var {
43        Var {
44            span: Span::default(),
45            name: self.name.clone(),
46        }
47    }
48}
49
50impl Display for Var {
51    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
52        match self.name.as_ref() {
53            "+" | "-" | "*" | "/" | "==" | ">=" | ">" | "<=" | "<" | "++" | "." => {
54                '('.fmt(f)?;
55                self.name.fmt(f)?;
56                ')'.fmt(f)
57            }
58            _ => self.name.fmt(f),
59        }
60    }
61}
62
/// A possibly qualified name such as `x` or `List.map`.
///
/// `Qualified` keeps the dotted form (first field) alongside the individual
/// segments, so `to_dotted_symbol` and `AsRef<str>` can return it without
/// rebuilding the string.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum NameRef {
    /// A single-segment name.
    Unqualified(Symbol),
    /// The segments joined with `.`, followed by the segments themselves.
    Qualified(Symbol, Vec<Symbol>),
}
69
70impl NameRef {
71    pub fn from_segments(segments: Vec<Symbol>) -> Self {
72        if segments.len() == 1 {
73            NameRef::Unqualified(segments[0].clone())
74        } else {
75            let dotted = Symbol::intern(
76                &segments
77                    .iter()
78                    .map(|s| s.as_ref())
79                    .collect::<Vec<_>>()
80                    .join("."),
81            );
82            NameRef::Qualified(dotted, segments)
83        }
84    }
85
86    pub fn from_dotted(name: &str) -> Self {
87        let segments: Vec<Symbol> = name.split('.').map(Symbol::intern).collect();
88        NameRef::from_segments(segments)
89    }
90
91    pub fn to_dotted_symbol(&self) -> Symbol {
92        match self {
93            NameRef::Unqualified(sym) => sym.clone(),
94            NameRef::Qualified(dotted, _) => dotted.clone(),
95        }
96    }
97
98    pub fn as_segments(&self) -> Vec<Symbol> {
99        match self {
100            NameRef::Unqualified(sym) => vec![sym.clone()],
101            NameRef::Qualified(_, segments) => segments.clone(),
102        }
103    }
104}
105
106impl From<&str> for NameRef {
107    fn from(value: &str) -> Self {
108        NameRef::from_dotted(value)
109    }
110}
111
112impl From<String> for NameRef {
113    fn from(value: String) -> Self {
114        NameRef::from_dotted(&value)
115    }
116}
117
118impl AsRef<str> for NameRef {
119    fn as_ref(&self) -> &str {
120        match self {
121            NameRef::Unqualified(sym) => sym.as_ref(),
122            NameRef::Qualified(dotted, _) => dotted.as_ref(),
123        }
124    }
125}
126
127impl Display for NameRef {
128    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
129        match self {
130            NameRef::Unqualified(sym) => sym.fmt(f),
131            NameRef::Qualified(_, segments) => {
132                for (i, segment) in segments.iter().enumerate() {
133                    if i > 0 {
134                        '.'.fmt(f)?;
135                    }
136                    segment.fmt(f)?;
137                }
138                Ok(())
139            }
140        }
141    }
142}
143
/// A pattern, as used in the arms of `Expr::Match`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Pattern {
    Wildcard(Span),                         // _
    Var(Var),                               // x
    Named(Span, NameRef, Vec<Pattern>),     // Ok x y z  (constructor name + argument patterns)
    Tuple(Span, Vec<Pattern>),              // (x, y, z)
    List(Span, Vec<Pattern>),               // [x, y, z]
    Cons(Span, Box<Pattern>, Box<Pattern>), // x::xs  (head, tail)
    Dict(Span, Vec<(Symbol, Pattern)>),     // {a, b, c} or {a: x, b: y}
}
155
156impl Pattern {
157    pub fn span(&self) -> &Span {
158        match self {
159            Pattern::Wildcard(span, ..)
160            | Pattern::Var(Var { span, .. })
161            | Pattern::Named(span, ..)
162            | Pattern::Tuple(span, ..)
163            | Pattern::List(span, ..)
164            | Pattern::Cons(span, ..)
165            | Pattern::Dict(span, ..) => span,
166        }
167    }
168
169    pub fn with_span(&self, span: Span) -> Pattern {
170        match self {
171            Pattern::Wildcard(..) => Pattern::Wildcard(span),
172            Pattern::Var(var) => Pattern::Var(Var {
173                span,
174                name: var.name.clone(),
175            }),
176            Pattern::Named(_, name, ps) => Pattern::Named(span, name.clone(), ps.clone()),
177            Pattern::Tuple(_, ps) => Pattern::Tuple(span, ps.clone()),
178            Pattern::List(_, ps) => Pattern::List(span, ps.clone()),
179            Pattern::Cons(_, head, tail) => Pattern::Cons(span, head.clone(), tail.clone()),
180            Pattern::Dict(_, fields) => Pattern::Dict(span, fields.clone()),
181        }
182    }
183
184    pub fn reset_spans(&self) -> Pattern {
185        match self {
186            Pattern::Wildcard(..) => Pattern::Wildcard(Span::default()),
187            Pattern::Var(var) => Pattern::Var(var.reset_spans()),
188            Pattern::Named(_, name, ps) => Pattern::Named(
189                Span::default(),
190                name.clone(),
191                ps.iter().map(|p| p.reset_spans()).collect(),
192            ),
193            Pattern::Tuple(_, ps) => Pattern::Tuple(
194                Span::default(),
195                ps.iter().map(|p| p.reset_spans()).collect(),
196            ),
197            Pattern::List(_, ps) => Pattern::List(
198                Span::default(),
199                ps.iter().map(|p| p.reset_spans()).collect(),
200            ),
201            Pattern::Cons(_, head, tail) => Pattern::Cons(
202                Span::default(),
203                Box::new(head.reset_spans()),
204                Box::new(tail.reset_spans()),
205            ),
206            Pattern::Dict(_, fields) => Pattern::Dict(
207                Span::default(),
208                fields
209                    .iter()
210                    .map(|(k, p)| (k.clone(), p.reset_spans()))
211                    .collect(),
212            ),
213        }
214    }
215}
216
217impl Display for Pattern {
218    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
219        match self {
220            Pattern::Wildcard(..) => write!(f, "_"),
221            Pattern::Var(var) => var.fmt(f),
222            Pattern::Named(_, name, ps) => {
223                write!(f, "{}", name)?;
224                for p in ps {
225                    write!(f, " {}", p)?;
226                }
227                Ok(())
228            }
229            Pattern::Tuple(_, ps) => {
230                '('.fmt(f)?;
231                for (i, p) in ps.iter().enumerate() {
232                    p.fmt(f)?;
233                    if i + 1 < ps.len() {
234                        ", ".fmt(f)?;
235                    }
236                }
237                ')'.fmt(f)
238            }
239            Pattern::List(_, ps) => {
240                '['.fmt(f)?;
241                for (i, p) in ps.iter().enumerate() {
242                    p.fmt(f)?;
243                    if i + 1 < ps.len() {
244                        ", ".fmt(f)?;
245                    }
246                }
247                ']'.fmt(f)
248            }
249            Pattern::Cons(_, head, tail) => write!(f, "{}::{}", head, tail),
250            Pattern::Dict(_, fields) => {
251                '{'.fmt(f)?;
252                for (i, (key, pat)) in fields.iter().enumerate() {
253                    // Use shorthand when possible to keep output stable with old syntax.
254                    match pat {
255                        Pattern::Var(var) if var.name == *key => {
256                            key.fmt(f)?;
257                        }
258                        _ => {
259                            key.fmt(f)?;
260                            ": ".fmt(f)?;
261                            pat.fmt(f)?;
262                        }
263                    }
264                    if i + 1 < fields.len() {
265                        ", ".fmt(f)?;
266                    }
267                }
268                '}'.fmt(f)
269            }
270        }
271    }
272}
273
/// A type expression as written in source annotations.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TypeExpr {
    /// A (possibly qualified) type name.
    Name(Span, NameRef),
    /// Type application `fun arg`, e.g. `List a`.
    App(Span, Box<TypeExpr>, Box<TypeExpr>),
    /// Function type `arg -> ret`.
    Fun(Span, Box<TypeExpr>, Box<TypeExpr>),
    /// Tuple type `(t1, t2, ...)`.
    Tuple(Span, Vec<TypeExpr>),
    /// Record type `{field: t, ...}`.
    Record(Span, Vec<(Symbol, TypeExpr)>),
}
283
284impl TypeExpr {
285    pub fn span(&self) -> &Span {
286        match self {
287            TypeExpr::Name(span, ..)
288            | TypeExpr::App(span, ..)
289            | TypeExpr::Fun(span, ..)
290            | TypeExpr::Tuple(span, ..)
291            | TypeExpr::Record(span, ..) => span,
292        }
293    }
294
295    pub fn reset_spans(&self) -> TypeExpr {
296        match self {
297            TypeExpr::Name(_, name) => TypeExpr::Name(Span::default(), name.clone()),
298            TypeExpr::App(_, fun, arg) => TypeExpr::App(
299                Span::default(),
300                Box::new(fun.reset_spans()),
301                Box::new(arg.reset_spans()),
302            ),
303            TypeExpr::Fun(_, arg, ret) => TypeExpr::Fun(
304                Span::default(),
305                Box::new(arg.reset_spans()),
306                Box::new(ret.reset_spans()),
307            ),
308            TypeExpr::Tuple(_, elems) => TypeExpr::Tuple(
309                Span::default(),
310                elems.iter().map(|e| e.reset_spans()).collect(),
311            ),
312            TypeExpr::Record(_, fields) => TypeExpr::Record(
313                Span::default(),
314                fields
315                    .iter()
316                    .map(|(name, ty)| (name.clone(), ty.reset_spans()))
317                    .collect(),
318            ),
319        }
320    }
321}
322
323impl Display for TypeExpr {
324    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
325        match self {
326            TypeExpr::Name(_, name) => name.fmt(f),
327            TypeExpr::App(_, fun, arg) => {
328                match fun.as_ref() {
329                    TypeExpr::Name(..) | TypeExpr::App(..) => fun.fmt(f)?,
330                    _ => {
331                        '('.fmt(f)?;
332                        fun.fmt(f)?;
333                        ')'.fmt(f)?;
334                    }
335                }
336                ' '.fmt(f)?;
337                match arg.as_ref() {
338                    TypeExpr::Name(..)
339                    | TypeExpr::App(..)
340                    | TypeExpr::Tuple(..)
341                    | TypeExpr::Record(..) => arg.fmt(f),
342                    _ => {
343                        '('.fmt(f)?;
344                        arg.fmt(f)?;
345                        ')'.fmt(f)
346                    }
347                }
348            }
349            TypeExpr::Fun(_, arg, ret) => {
350                match arg.as_ref() {
351                    TypeExpr::Fun(..) => {
352                        '('.fmt(f)?;
353                        arg.fmt(f)?;
354                        ')'.fmt(f)?;
355                    }
356                    _ => arg.fmt(f)?,
357                }
358                " -> ".fmt(f)?;
359                ret.fmt(f)
360            }
361            TypeExpr::Tuple(_, elems) => {
362                '('.fmt(f)?;
363                for (i, elem) in elems.iter().enumerate() {
364                    elem.fmt(f)?;
365                    if i + 1 < elems.len() {
366                        ", ".fmt(f)?;
367                    }
368                }
369                ')'.fmt(f)
370            }
371            TypeExpr::Record(_, fields) => {
372                '{'.fmt(f)?;
373                for (i, (name, ty)) in fields.iter().enumerate() {
374                    name.fmt(f)?;
375                    ": ".fmt(f)?;
376                    ty.fmt(f)?;
377                    if i + 1 < fields.len() {
378                        ", ".fmt(f)?;
379                    }
380                }
381                '}'.fmt(f)
382            }
383        }
384    }
385}
386
/// A constraint pairing a class name with the type it applies to.
///
/// Used on lambdas, `fn`/`declare fn` declarations, class supers and instance
/// contexts.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct TypeConstraint {
    /// The (possibly qualified) class name.
    pub class: NameRef,
    /// The constrained type.
    pub typ: TypeExpr,
}
392
393impl TypeConstraint {
394    pub fn new(class: NameRef, typ: TypeExpr) -> Self {
395        Self { class, typ }
396    }
397}
398
/// One named variant of a `TypeDecl`, with its positional argument types.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct TypeVariant {
    pub name: Symbol,
    pub args: Vec<TypeExpr>,
}
404
/// A top-level type declaration: name, type parameters and variants.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct TypeDecl {
    pub span: Span,
    /// Whether the declaration is marked `pub`.
    pub is_pub: bool,
    pub name: Symbol,
    /// Type parameter names.
    pub params: Vec<Symbol>,
    pub variants: Vec<TypeVariant>,
}
413
/// A top-level `fn` declaration with a body.
///
/// `CompilationUnit::body_with_fns` lowers it to a `let` whose value is one
/// annotated lambda per parameter, and whose signature is
/// `param types -> ... -> ret`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct FnDecl {
    pub span: Span,
    /// Whether the declaration is marked `pub`.
    pub is_pub: bool,
    pub name: Var,
    /// Parameters in declaration order, each with its type annotation.
    pub params: Vec<(Var, TypeExpr)>,
    /// Declared return type.
    pub ret: TypeExpr,
    /// Constraints; lowering attaches them to the outermost lambda.
    pub constraints: Vec<TypeConstraint>,
    pub body: Arc<Expr>,
}
424
/// A function signature declared without a body.
///
/// NOTE(review): presumably for externally provided or builtin functions;
/// confirm.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct DeclareFnDecl {
    pub span: Span,
    /// Whether the declaration is marked `pub`.
    pub is_pub: bool,
    pub name: Var,
    /// Parameters in declaration order, each with its type annotation.
    pub params: Vec<(Var, TypeExpr)>,
    pub ret: TypeExpr,
    pub constraints: Vec<TypeConstraint>,
}
434
/// A method signature inside a class declaration: name and type.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ClassMethodSig {
    pub name: Symbol,
    pub typ: TypeExpr,
}
440
/// A class declaration with its parameters, super-constraints and method signatures.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ClassDecl {
    pub span: Span,
    /// Whether the declaration is marked `pub`.
    pub is_pub: bool,
    pub name: Symbol,
    /// Class parameter names.
    pub params: Vec<Symbol>,
    /// NOTE(review): presumably superclass constraints; confirm.
    pub supers: Vec<TypeConstraint>,
    pub methods: Vec<ClassMethodSig>,
}
450
/// A method implementation inside an instance declaration.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct InstanceMethodImpl {
    pub name: Symbol,
    pub body: Arc<Expr>,
}
456
/// An instance declaration of `class` for the type `head`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct InstanceDecl {
    pub span: Span,
    /// Whether the declaration is marked `pub`.
    pub is_pub: bool,
    pub class: Symbol,
    pub head: TypeExpr,
    /// NOTE(review): presumably constraints the instance requires; confirm.
    pub context: Vec<TypeConstraint>,
    pub methods: Vec<InstanceMethodImpl>,
}
466
/// Where an import comes from: a local module path or a remote URL.
///
/// NOTE(review): `sha` presumably pins the imported content to a hash; confirm.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ImportPath {
    Local {
        segments: Vec<Symbol>,
        sha: Option<String>,
    },
    Remote {
        url: String,
        sha: Option<String>,
    },
}
479
/// A single imported name, optionally renamed.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ImportItem {
    pub name: Symbol,
    pub alias: Option<Symbol>,
}
485
/// Which names an import brings in: all of them, or an explicit list.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ImportClause {
    All,
    Items(Vec<ImportItem>),
}
492
/// An import declaration: source path, local alias and optional item clause.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ImportDecl {
    pub span: Span,
    /// Whether the declaration is marked `pub`.
    pub is_pub: bool,
    pub path: ImportPath,
    /// Name the imported module is bound to (also what `Display for Decl` prints).
    pub alias: Symbol,
    pub clause: Option<ImportClause>,
}
501
/// A top-level declaration of a `CompilationUnit`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum Decl {
    Type(TypeDecl),
    Fn(FnDecl),
    DeclareFn(DeclareFnDecl),
    Import(ImportDecl),
    Class(ClassDecl),
    Instance(InstanceDecl),
}
511
512impl std::fmt::Display for Decl {
513    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
514        match self {
515            Decl::Type(d) => write!(f, "{}", d.name),
516            Decl::Fn(d) => write!(f, "{}", d.name),
517            Decl::DeclareFn(d) => write!(f, "{}", d.name),
518            Decl::Import(d) => write!(f, "{}", d.alias),
519            Decl::Class(d) => write!(f, "{}", d.name),
520            Decl::Instance(d) => write!(f, "{} {}", d.class, d.head),
521        }
522    }
523}
524
/// A parsed source unit: its top-level declarations and an optional body expression.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct CompilationUnit {
    pub decls: Vec<Decl>,
    pub body: Option<Arc<Expr>>,
}
530
531impl CompilationUnit {
532    /// Lower top-level `fn` declarations into nested `let` bindings around `body`.
533    ///
534    /// This keeps the surface syntax (`Decl::Fn`) intact for tools, while giving
535    /// the type checker and evaluator a plain expression to work with.
536    pub fn body_with_fns(&self) -> Option<Arc<Expr>> {
537        let mut out = self.body.clone()?;
538        for decl in self.decls.iter().rev() {
539            let Decl::Fn(fd) = decl else {
540                continue;
541            };
542
543            let mut lam_body = fd.body.clone();
544            let mut lam_end = lam_body.span().end;
545            for (idx, (param, ann)) in fd.params.iter().enumerate().rev() {
546                let lam_constraints = if idx == 0 {
547                    fd.constraints.clone()
548                } else {
549                    Vec::new()
550                };
551                let span = Span::from_begin_end(param.span.begin, lam_end);
552                lam_body = Arc::new(Expr::Lam(
553                    span,
554                    Scope::new_sync(),
555                    param.clone(),
556                    Some(ann.clone()),
557                    lam_constraints,
558                    lam_body,
559                ));
560                lam_end = lam_body.span().end;
561            }
562
563            let mut sig = fd.ret.clone();
564            for (_, ann) in fd.params.iter().rev() {
565                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
566                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
567            }
568
569            let span = Span::from_begin_end(fd.span.begin, out.span().end);
570            out = Arc::new(Expr::Let(span, fd.name.clone(), Some(sig), lam_body, out));
571        }
572        Some(out)
573    }
574}
575
/// An expression node.
///
/// Child expressions are held in `Arc`s, so subtrees may be shared between
/// parents. `Expr` implements `Drop` by hand so that a deeply nested tree is
/// torn down with an explicit worklist instead of one recursive call per level.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Expr {
    Bool(Span, bool),              // true
    Uint(Span, u64),               // 69
    Int(Span, i64),                // -420
    Float(Span, f64),              // 3.14
    String(Span, String),          // "hello"
    Uuid(Span, Uuid),              // a550c18e-36e1-4f6d-8c8e-2d2b1e5f3c3a
    DateTime(Span, DateTime<Utc>), // 2023-01-01T12:00:00Z
    Hole(Span),                    // ?

    Tuple(Span, Vec<Arc<Expr>>),             // (e1, e2, e3)
    List(Span, Vec<Arc<Expr>>),              // [e1, e2, e3]
    Dict(Span, BTreeMap<Symbol, Arc<Expr>>), // {k1 = v1, k2 = v2}
    RecordUpdate(Span, Arc<Expr>, BTreeMap<Symbol, Arc<Expr>>), // {base with {k1 = v1, ...}}

    Var(Var),                         // x
    App(Span, Arc<Expr>, Arc<Expr>),  // f x
    Project(Span, Arc<Expr>, Symbol), // x.field
    Lam(
        Span,
        // NOTE(review): presumably the captured environment; `body_with_fns` creates it empty.
        Scope,
        // Bound parameter.
        Var,
        // Optional parameter type annotation.
        Option<TypeExpr>,
        // Constraints attached to this lambda.
        Vec<TypeConstraint>,
        // Lambda body.
        Arc<Expr>,
    ), // λx → e
    Let(Span, Var, Option<TypeExpr>, Arc<Expr>, Arc<Expr>), // let x = e1 in e2
    LetRec(Span, Vec<(Var, Option<TypeExpr>, Arc<Expr>)>, Arc<Expr>), // let rec f = e1 and g = e2 in e3
    Ite(Span, Arc<Expr>, Arc<Expr>, Arc<Expr>),                       // if e1 then e2 else e3
    Match(Span, Arc<Expr>, Vec<(Pattern, Arc<Expr>)>),                // match e1 with { ... }
    Ann(Span, Arc<Expr>, TypeExpr),                                   // e is t
}
610
611fn expr_drop_placeholder() -> Arc<Expr> {
612    Arc::new(Expr::Hole(Span::default()))
613}
614
615fn drain_expr(expr: &mut Expr, stack: &mut Vec<Arc<Expr>>) {
616    match expr {
617        Expr::Tuple(_, elems) | Expr::List(_, elems) => {
618            stack.extend(mem::take(elems));
619        }
620        Expr::Dict(_, kvs) => {
621            stack.extend(mem::take(kvs).into_values());
622        }
623        Expr::RecordUpdate(_, base, updates) => {
624            stack.push(mem::replace(base, expr_drop_placeholder()));
625            stack.extend(mem::take(updates).into_values());
626        }
627        Expr::App(_, f, x) => {
628            stack.push(mem::replace(f, expr_drop_placeholder()));
629            stack.push(mem::replace(x, expr_drop_placeholder()));
630        }
631        Expr::Project(_, base, _) | Expr::Ann(_, base, _) => {
632            stack.push(mem::replace(base, expr_drop_placeholder()));
633        }
634        Expr::Lam(_, scope, _, _, _, body) => {
635            let old_scope = mem::replace(scope, Scope::new_sync());
636            stack.extend(old_scope.iter().map(|(_, expr)| expr.clone()));
637            stack.push(mem::replace(body, expr_drop_placeholder()));
638        }
639        Expr::Let(_, _, _, def, body) => {
640            stack.push(mem::replace(def, expr_drop_placeholder()));
641            stack.push(mem::replace(body, expr_drop_placeholder()));
642        }
643        Expr::LetRec(_, bindings, body) => {
644            for (_var, _ann, def) in mem::take(bindings) {
645                stack.push(def);
646            }
647            stack.push(mem::replace(body, expr_drop_placeholder()));
648        }
649        Expr::Ite(_, cond, then_expr, else_expr) => {
650            stack.push(mem::replace(cond, expr_drop_placeholder()));
651            stack.push(mem::replace(then_expr, expr_drop_placeholder()));
652            stack.push(mem::replace(else_expr, expr_drop_placeholder()));
653        }
654        Expr::Match(_, scrutinee, arms) => {
655            stack.push(mem::replace(scrutinee, expr_drop_placeholder()));
656            for (_pat, arm) in mem::take(arms) {
657                stack.push(arm);
658            }
659        }
660        Expr::Bool(..)
661        | Expr::Uint(..)
662        | Expr::Int(..)
663        | Expr::Float(..)
664        | Expr::String(..)
665        | Expr::Uuid(..)
666        | Expr::DateTime(..)
667        | Expr::Hole(..)
668        | Expr::Var(..) => {}
669    }
670}
671
672impl Drop for Expr {
673    fn drop(&mut self) {
674        let mut stack = Vec::new();
675        drain_expr(self, &mut stack);
676        while let Some(mut expr) = stack.pop() {
677            let Some(expr) = Arc::get_mut(&mut expr) else {
678                continue;
679            };
680            drain_expr(expr, &mut stack);
681        }
682    }
683}
684
685impl Expr {
686    pub fn span(&self) -> &Span {
687        match self {
688            Self::Bool(span, ..)
689            | Self::Uint(span, ..)
690            | Self::Int(span, ..)
691            | Self::Float(span, ..)
692            | Self::String(span, ..)
693            | Self::Uuid(span, ..)
694            | Self::DateTime(span, ..)
695            | Self::Hole(span, ..)
696            | Self::Tuple(span, ..)
697            | Self::List(span, ..)
698            | Self::Dict(span, ..)
699            | Self::RecordUpdate(span, ..)
700            | Self::Var(Var { span, .. })
701            | Self::App(span, ..)
702            | Self::Project(span, ..)
703            | Self::Lam(span, ..)
704            | Self::Let(span, ..)
705            | Self::LetRec(span, ..)
706            | Self::Ite(span, ..)
707            | Self::Match(span, ..)
708            | Self::Ann(span, ..) => span,
709        }
710    }
711
712    pub fn span_mut(&mut self) -> &mut Span {
713        match self {
714            Self::Bool(span, ..)
715            | Self::Uint(span, ..)
716            | Self::Int(span, ..)
717            | Self::Float(span, ..)
718            | Self::String(span, ..)
719            | Self::Uuid(span, ..)
720            | Self::DateTime(span, ..)
721            | Self::Hole(span, ..)
722            | Self::Tuple(span, ..)
723            | Self::List(span, ..)
724            | Self::Dict(span, ..)
725            | Self::RecordUpdate(span, ..)
726            | Self::Var(Var { span, .. })
727            | Self::App(span, ..)
728            | Self::Project(span, ..)
729            | Self::Lam(span, ..)
730            | Self::Let(span, ..)
731            | Self::LetRec(span, ..)
732            | Self::Ite(span, ..)
733            | Self::Match(span, ..)
734            | Self::Ann(span, ..) => span,
735        }
736    }
737
738    pub fn with_span_begin_end(&self, begin: Position, end: Position) -> Expr {
739        self.with_span(Span::from_begin_end(begin, end))
740    }
741
742    pub fn with_span_begin(&self, begin: Position) -> Expr {
743        let end = self.span().end;
744        self.with_span(Span::from_begin_end(begin, end))
745    }
746
747    pub fn with_span_end(&self, end: Position) -> Expr {
748        let begin = self.span().begin;
749        self.with_span(Span::from_begin_end(begin, end))
750    }
751
752    pub fn with_span(&self, span: Span) -> Expr {
753        match self {
754            Expr::Bool(_, x) => Expr::Bool(span, *x),
755            Expr::Uint(_, x) => Expr::Uint(span, *x),
756            Expr::Int(_, x) => Expr::Int(span, *x),
757            Expr::Float(_, x) => Expr::Float(span, *x),
758            Expr::String(_, x) => Expr::String(span, x.clone()),
759            Expr::Uuid(_, x) => Expr::Uuid(span, *x),
760            Expr::DateTime(_, x) => Expr::DateTime(span, *x),
761            Expr::Hole(_) => Expr::Hole(span),
762            Expr::Tuple(_, elems) => Expr::Tuple(span, elems.clone()),
763            Expr::List(_, elems) => Expr::List(span, elems.clone()),
764            Expr::Dict(_, kvs) => Expr::Dict(
765                span,
766                BTreeMap::from_iter(kvs.iter().map(|(k, v)| (k.clone(), v.clone()))),
767            ),
768            Expr::RecordUpdate(_, base, updates) => Expr::RecordUpdate(
769                span,
770                base.clone(),
771                BTreeMap::from_iter(updates.iter().map(|(k, v)| (k.clone(), v.clone()))),
772            ),
773            Expr::Var(var) => Expr::Var(Var {
774                span,
775                name: var.name.clone(),
776            }),
777            Expr::App(_, f, x) => Expr::App(span, f.clone(), x.clone()),
778            Expr::Project(_, base, field) => Expr::Project(span, base.clone(), field.clone()),
779            Expr::Lam(_, scope, param, ann, constraints, body) => Expr::Lam(
780                span,
781                scope.clone(),
782                param.clone(),
783                ann.clone(),
784                constraints.clone(),
785                body.clone(),
786            ),
787            Expr::Let(_, var, ann, def, body) => {
788                Expr::Let(span, var.clone(), ann.clone(), def.clone(), body.clone())
789            }
790            Expr::LetRec(_, bindings, body) => Expr::LetRec(
791                span,
792                bindings
793                    .iter()
794                    .map(|(var, ann, def)| (var.clone(), ann.clone(), def.clone()))
795                    .collect(),
796                body.clone(),
797            ),
798            Expr::Ite(_, cond, then, r#else) => {
799                Expr::Ite(span, cond.clone(), then.clone(), r#else.clone())
800            }
801            Expr::Match(_, scrutinee, arms) => Expr::Match(
802                span,
803                scrutinee.clone(),
804                arms.iter()
805                    .map(|(pat, expr)| (pat.clone(), expr.clone()))
806                    .collect(),
807            ),
808            Expr::Ann(_, expr, ann) => Expr::Ann(span, expr.clone(), ann.clone()),
809        }
810    }
811
812    pub fn reset_spans(&self) -> Expr {
813        match self {
814            Expr::Bool(_, x) => Expr::Bool(Span::default(), *x),
815            Expr::Uint(_, x) => Expr::Uint(Span::default(), *x),
816            Expr::Int(_, x) => Expr::Int(Span::default(), *x),
817            Expr::Float(_, x) => Expr::Float(Span::default(), *x),
818            Expr::String(_, x) => Expr::String(Span::default(), x.clone()),
819            Expr::Uuid(_, x) => Expr::Uuid(Span::default(), *x),
820            Expr::DateTime(_, x) => Expr::DateTime(Span::default(), *x),
821            Expr::Hole(_) => Expr::Hole(Span::default()),
822            Expr::Tuple(_, elems) => Expr::Tuple(
823                Span::default(),
824                elems.iter().map(|x| Arc::new(x.reset_spans())).collect(),
825            ),
826            Expr::List(_, elems) => Expr::List(
827                Span::default(),
828                elems.iter().map(|x| Arc::new(x.reset_spans())).collect(),
829            ),
830            Expr::Dict(_, kvs) => Expr::Dict(
831                Span::default(),
832                BTreeMap::from_iter(
833                    kvs.iter()
834                        .map(|(k, v)| (k.clone(), Arc::new(v.reset_spans()))),
835                ),
836            ),
837            Expr::RecordUpdate(_, base, updates) => Expr::RecordUpdate(
838                Span::default(),
839                Arc::new(base.reset_spans()),
840                BTreeMap::from_iter(
841                    updates
842                        .iter()
843                        .map(|(k, v)| (k.clone(), Arc::new(v.reset_spans()))),
844                ),
845            ),
846            Expr::Var(var) => Expr::Var(var.reset_spans()),
847            Expr::App(_, f, x) => Expr::App(
848                Span::default(),
849                Arc::new(f.reset_spans()),
850                Arc::new(x.reset_spans()),
851            ),
852            Expr::Project(_, base, field) => {
853                Expr::Project(Span::default(), Arc::new(base.reset_spans()), field.clone())
854            }
855            Expr::Lam(_, scope, param, ann, constraints, body) => Expr::Lam(
856                Span::default(),
857                scope.clone(),
858                param.reset_spans(),
859                ann.as_ref().map(TypeExpr::reset_spans),
860                constraints
861                    .iter()
862                    .map(|constraint| TypeConstraint {
863                        class: constraint.class.clone(),
864                        typ: constraint.typ.reset_spans(),
865                    })
866                    .collect(),
867                Arc::new(body.reset_spans()),
868            ),
869            Expr::Let(_, var, ann, def, body) => Expr::Let(
870                Span::default(),
871                var.reset_spans(),
872                ann.as_ref().map(|t| t.reset_spans()),
873                Arc::new(def.reset_spans()),
874                Arc::new(body.reset_spans()),
875            ),
876            Expr::LetRec(_, bindings, body) => Expr::LetRec(
877                Span::default(),
878                bindings
879                    .iter()
880                    .map(|(var, ann, def)| {
881                        (
882                            var.reset_spans(),
883                            ann.as_ref().map(TypeExpr::reset_spans),
884                            Arc::new(def.reset_spans()),
885                        )
886                    })
887                    .collect(),
888                Arc::new(body.reset_spans()),
889            ),
890            Expr::Ite(_, cond, then, r#else) => Expr::Ite(
891                Span::default(),
892                Arc::new(cond.reset_spans()),
893                Arc::new(then.reset_spans()),
894                Arc::new(r#else.reset_spans()),
895            ),
896            Expr::Match(_, scrutinee, arms) => Expr::Match(
897                Span::default(),
898                Arc::new(scrutinee.reset_spans()),
899                arms.iter()
900                    .map(|(pat, expr)| (pat.reset_spans(), Arc::new(expr.reset_spans())))
901                    .collect(),
902            ),
903            Expr::Ann(_, expr, ann) => Expr::Ann(
904                Span::default(),
905                Arc::new(expr.reset_spans()),
906                ann.reset_spans(),
907            ),
908        }
909    }
910}
911
912impl Display for Expr {
913    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
914        match self {
915            Self::Bool(_span, x) => x.fmt(f),
916            Self::Uint(_span, x) => x.fmt(f),
917            Self::Int(_span, x) => x.fmt(f),
918            Self::Float(_span, x) => x.fmt(f),
919            Self::String(_span, x) => write!(f, "{:?}", x),
920            Self::Uuid(_span, x) => x.fmt(f),
921            Self::DateTime(_span, x) => x.fmt(f),
922            Self::Hole(_span) => '?'.fmt(f),
923            Self::List(_span, xs) => {
924                '['.fmt(f)?;
925                for (i, x) in xs.iter().enumerate() {
926                    x.fmt(f)?;
927                    if i + 1 < xs.len() {
928                        ", ".fmt(f)?;
929                    }
930                }
931                ']'.fmt(f)
932            }
933            Self::Tuple(_span, xs) => {
934                '('.fmt(f)?;
935                for (i, x) in xs.iter().enumerate() {
936                    x.fmt(f)?;
937                    if i + 1 < xs.len() {
938                        ", ".fmt(f)?;
939                    }
940                }
941                ')'.fmt(f)
942            }
943            Self::Dict(_span, kvs) => {
944                '{'.fmt(f)?;
945                for (i, (k, v)) in kvs.iter().enumerate() {
946                    k.fmt(f)?;
947                    " = ".fmt(f)?;
948                    v.fmt(f)?;
949                    if i + 1 < kvs.len() {
950                        ", ".fmt(f)?;
951                    }
952                }
953                '}'.fmt(f)
954            }
955            Self::RecordUpdate(_span, base, kvs) => {
956                '{'.fmt(f)?;
957                base.fmt(f)?;
958                " with ".fmt(f)?;
959                '{'.fmt(f)?;
960                for (i, (k, v)) in kvs.iter().enumerate() {
961                    k.fmt(f)?;
962                    " = ".fmt(f)?;
963                    v.fmt(f)?;
964                    if i + 1 < kvs.len() {
965                        ", ".fmt(f)?;
966                    }
967                }
968                '}'.fmt(f)?;
969                '}'.fmt(f)
970            }
971            Self::Var(var) => var.fmt(f),
972            Self::App(_span, g, x) => {
973                g.fmt(f)?;
974                ' '.fmt(f)?;
975                match x.as_ref() {
976                    Self::Bool(..)
977                    | Self::Uint(..)
978                    | Self::Int(..)
979                    | Self::Float(..)
980                    | Self::String(..)
981                    | Self::List(..)
982                    | Self::Tuple(..)
983                    | Self::Dict(..)
984                    | Self::RecordUpdate(..)
985                    | Self::Project(..)
986                    | Self::Var(..) => x.fmt(f),
987                    _ => {
988                        '('.fmt(f)?;
989                        x.fmt(f)?;
990                        ')'.fmt(f)
991                    }
992                }
993            }
994            Self::Lam(_span, _scope, param, ann, constraints, body) => {
995                'λ'.fmt(f)?;
996                if let Some(ann) = ann {
997                    '('.fmt(f)?;
998                    param.fmt(f)?;
999                    " : ".fmt(f)?;
1000                    ann.fmt(f)?;
1001                    ')'.fmt(f)?;
1002                } else {
1003                    param.fmt(f)?;
1004                }
1005                if !constraints.is_empty() {
1006                    " where ".fmt(f)?;
1007                    for (i, constraint) in constraints.iter().enumerate() {
1008                        constraint.class.fmt(f)?;
1009                        ' '.fmt(f)?;
1010                        constraint.typ.fmt(f)?;
1011                        if i + 1 < constraints.len() {
1012                            ", ".fmt(f)?;
1013                        }
1014                    }
1015                }
1016                " → ".fmt(f)?;
1017                body.fmt(f)
1018            }
1019            Self::Let(_span, var, ann, def, body) => {
1020                "let ".fmt(f)?;
1021                var.fmt(f)?;
1022                if let Some(ann) = ann {
1023                    ": ".fmt(f)?;
1024                    ann.fmt(f)?;
1025                }
1026                " = ".fmt(f)?;
1027                def.fmt(f)?;
1028                " in ".fmt(f)?;
1029                body.fmt(f)
1030            }
1031            Self::LetRec(_span, bindings, body) => {
1032                "let rec ".fmt(f)?;
1033                for (idx, (var, ann, def)) in bindings.iter().enumerate() {
1034                    if idx > 0 {
1035                        " and ".fmt(f)?;
1036                    }
1037                    var.fmt(f)?;
1038                    if let Some(ann) = ann {
1039                        ": ".fmt(f)?;
1040                        ann.fmt(f)?;
1041                    }
1042                    " = ".fmt(f)?;
1043                    def.fmt(f)?;
1044                }
1045                " in ".fmt(f)?;
1046                body.fmt(f)
1047            }
1048            Self::Ite(_span, cond, then, r#else) => {
1049                "if ".fmt(f)?;
1050                cond.fmt(f)?;
1051                " then ".fmt(f)?;
1052                then.fmt(f)?;
1053                " else ".fmt(f)?;
1054                r#else.fmt(f)
1055            }
1056            Self::Match(_span, scrutinee, arms) => {
1057                "match ".fmt(f)?;
1058                scrutinee.fmt(f)?;
1059                " with { ".fmt(f)?;
1060                for (pat, expr) in arms {
1061                    "case ".fmt(f)?;
1062                    pat.fmt(f)?;
1063                    " -> ".fmt(f)?;
1064                    expr.fmt(f)?;
1065                    "; ".fmt(f)?;
1066                }
1067                "}".fmt(f)
1068            }
1069            Self::Project(_span, base, field) => {
1070                base.fmt(f)?;
1071                ".".fmt(f)?;
1072                field.fmt(f)
1073            }
1074            Self::Ann(_span, expr, ann) => {
1075                expr.fmt(f)?;
1076                " is ".fmt(f)?;
1077                ann.fmt(f)
1078            }
1079        }
1080    }
1081}