Skip to main content

rexlang_ast/
expr.rs

1use std::{
2    borrow::Borrow,
3    collections::{BTreeMap, HashMap},
4    fmt::{self, Display, Formatter},
5    ops::Deref,
6    sync::{Arc, Mutex, OnceLock},
7};
8
9use rexlang_lexer::span::{Position, Span};
10use rpds::HashTrieMapSync;
11
12use chrono::{DateTime, Utc};
13use uuid::Uuid;
14
15#[derive(
16    Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Deserialize, serde::Serialize,
17)]
18#[serde(transparent)]
19pub struct Symbol(Arc<str>);
20
21impl Symbol {
22    pub fn as_str(&self) -> &str {
23        self.0.as_ref()
24    }
25}
26
27impl Deref for Symbol {
28    type Target = str;
29
30    fn deref(&self) -> &Self::Target {
31        self.0.as_ref()
32    }
33}
34
35impl AsRef<str> for Symbol {
36    fn as_ref(&self) -> &str {
37        self.0.as_ref()
38    }
39}
40
41impl Borrow<str> for Symbol {
42    fn borrow(&self) -> &str {
43        self.0.as_ref()
44    }
45}
46
47impl Display for Symbol {
48    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
49        self.0.fmt(f)
50    }
51}
52
53impl From<&str> for Symbol {
54    fn from(value: &str) -> Self {
55        Symbol(Arc::from(value))
56    }
57}
58
59impl From<String> for Symbol {
60    fn from(value: String) -> Self {
61        Symbol(Arc::from(value))
62    }
63}
64
65impl From<Arc<str>> for Symbol {
66    fn from(value: Arc<str>) -> Self {
67        Symbol(value)
68    }
69}
70
/// Map from symbols to shared expressions; this is the type of the `Scope`
/// field carried by `Expr::Lam`.
pub type Scope = HashTrieMapSync<Symbol, Arc<Expr>>;

// Global symbol interner.
//
// Design constraints:
// - Symbols wrap `Arc<str>` so cloning them is cheap and comparisons are fast.
// - The table is process-global and monotonically grows; that's fine for a
//   typical “compile a program, then exit” workflow.
// - Locking makes the cost model explicit (and obvious in profiles). If this
//   ever shows up hot, the first step is usually “reduce calls to `intern`”,
//   not “invent a clever interner”.
//
// Created lazily on first use (`OnceLock`). Each key is an owned copy of the
// symbol's text; each value is the canonical `Symbol` handed out by `intern`.
static INTERNER: OnceLock<Mutex<HashMap<String, Symbol>>> = OnceLock::new();
83
84pub fn intern(name: &str) -> Symbol {
85    let mutex = INTERNER.get_or_init(|| Mutex::new(HashMap::new()));
86    let mut table = match mutex.lock() {
87        Ok(guard) => guard,
88        Err(poisoned) => poisoned.into_inner(),
89    };
90    if let Some(existing) = table.get(name) {
91        return existing.clone();
92    }
93    let sym = Symbol(Arc::from(name));
94    table.insert(name.to_string(), sym.clone());
95    sym
96}
97
98pub fn sym(name: &str) -> Symbol {
99    intern(name)
100}
101
102pub fn sym_eq(name: &Symbol, expected: &str) -> bool {
103    name.as_ref() == expected
104}
105
106#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
107#[serde(rename_all = "lowercase")]
108pub struct Var {
109    pub span: Span,
110    pub name: Symbol,
111}
112
113impl Var {
114    pub fn new(name: impl ToString) -> Self {
115        Self {
116            span: Span::default(),
117            name: intern(&name.to_string()),
118        }
119    }
120
121    pub fn with_span(span: Span, name: impl ToString) -> Self {
122        Self {
123            span,
124            name: intern(&name.to_string()),
125        }
126    }
127
128    pub fn reset_spans(&self) -> Var {
129        Var {
130            span: Span::default(),
131            name: self.name.clone(),
132        }
133    }
134}
135
136impl Display for Var {
137    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
138        match self.name.as_ref() {
139            "+" | "-" | "*" | "/" | "==" | ">=" | ">" | "<=" | "<" | "++" | "." => {
140                '('.fmt(f)?;
141                self.name.fmt(f)?;
142                ')'.fmt(f)
143            }
144            _ => self.name.fmt(f),
145        }
146    }
147}
148
/// A name: either a single segment or a sequence of `.`-separated segments.
///
/// `Qualified` keeps both the interned dotted text and the individual
/// segments; see `NameRef::from_segments`.
#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum NameRef {
    /// A single-segment name.
    Unqualified(Symbol),
    /// The interned dotted name (e.g. `a.b`) followed by its segments.
    Qualified(Symbol, Vec<Symbol>),
}
155
156impl NameRef {
157    pub fn from_segments(segments: Vec<Symbol>) -> Self {
158        if segments.len() == 1 {
159            NameRef::Unqualified(segments[0].clone())
160        } else {
161            let dotted = intern(
162                &segments
163                    .iter()
164                    .map(|s| s.as_ref())
165                    .collect::<Vec<_>>()
166                    .join("."),
167            );
168            NameRef::Qualified(dotted, segments)
169        }
170    }
171
172    pub fn from_dotted(name: &str) -> Self {
173        let segments: Vec<Symbol> = name.split('.').map(intern).collect();
174        NameRef::from_segments(segments)
175    }
176
177    pub fn to_dotted_symbol(&self) -> Symbol {
178        match self {
179            NameRef::Unqualified(sym) => sym.clone(),
180            NameRef::Qualified(dotted, _) => dotted.clone(),
181        }
182    }
183
184    pub fn as_segments(&self) -> Vec<Symbol> {
185        match self {
186            NameRef::Unqualified(sym) => vec![sym.clone()],
187            NameRef::Qualified(_, segments) => segments.clone(),
188        }
189    }
190}
191
192impl From<&str> for NameRef {
193    fn from(value: &str) -> Self {
194        NameRef::from_dotted(value)
195    }
196}
197
198impl From<String> for NameRef {
199    fn from(value: String) -> Self {
200        NameRef::from_dotted(&value)
201    }
202}
203
204impl AsRef<str> for NameRef {
205    fn as_ref(&self) -> &str {
206        match self {
207            NameRef::Unqualified(sym) => sym.as_ref(),
208            NameRef::Qualified(dotted, _) => dotted.as_ref(),
209        }
210    }
211}
212
213impl Display for NameRef {
214    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
215        match self {
216            NameRef::Unqualified(sym) => sym.fmt(f),
217            NameRef::Qualified(_, segments) => {
218                for (i, segment) in segments.iter().enumerate() {
219                    if i > 0 {
220                        '.'.fmt(f)?;
221                    }
222                    segment.fmt(f)?;
223                }
224                Ok(())
225            }
226        }
227    }
228}
229
/// A pattern, as used for example in the arms of `Expr::Match`.
///
/// The trailing comments show each variant's surface syntax.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Pattern {
    Wildcard(Span),                         // _
    Var(Var),                               // x
    Named(Span, NameRef, Vec<Pattern>),     // Ok x y z
    Tuple(Span, Vec<Pattern>),              // (x, y, z)
    List(Span, Vec<Pattern>),               // [x, y, z]
    Cons(Span, Box<Pattern>, Box<Pattern>), // x::xs
    Dict(Span, Vec<(Symbol, Pattern)>),     // {a, b, c} or {a: x, b: y}
}
241
242impl Pattern {
243    pub fn span(&self) -> &Span {
244        match self {
245            Pattern::Wildcard(span, ..)
246            | Pattern::Var(Var { span, .. })
247            | Pattern::Named(span, ..)
248            | Pattern::Tuple(span, ..)
249            | Pattern::List(span, ..)
250            | Pattern::Cons(span, ..)
251            | Pattern::Dict(span, ..) => span,
252        }
253    }
254
255    pub fn with_span(&self, span: Span) -> Pattern {
256        match self {
257            Pattern::Wildcard(..) => Pattern::Wildcard(span),
258            Pattern::Var(var) => Pattern::Var(Var {
259                span,
260                name: var.name.clone(),
261            }),
262            Pattern::Named(_, name, ps) => Pattern::Named(span, name.clone(), ps.clone()),
263            Pattern::Tuple(_, ps) => Pattern::Tuple(span, ps.clone()),
264            Pattern::List(_, ps) => Pattern::List(span, ps.clone()),
265            Pattern::Cons(_, head, tail) => Pattern::Cons(span, head.clone(), tail.clone()),
266            Pattern::Dict(_, fields) => Pattern::Dict(span, fields.clone()),
267        }
268    }
269
270    pub fn reset_spans(&self) -> Pattern {
271        match self {
272            Pattern::Wildcard(..) => Pattern::Wildcard(Span::default()),
273            Pattern::Var(var) => Pattern::Var(var.reset_spans()),
274            Pattern::Named(_, name, ps) => Pattern::Named(
275                Span::default(),
276                name.clone(),
277                ps.iter().map(|p| p.reset_spans()).collect(),
278            ),
279            Pattern::Tuple(_, ps) => Pattern::Tuple(
280                Span::default(),
281                ps.iter().map(|p| p.reset_spans()).collect(),
282            ),
283            Pattern::List(_, ps) => Pattern::List(
284                Span::default(),
285                ps.iter().map(|p| p.reset_spans()).collect(),
286            ),
287            Pattern::Cons(_, head, tail) => Pattern::Cons(
288                Span::default(),
289                Box::new(head.reset_spans()),
290                Box::new(tail.reset_spans()),
291            ),
292            Pattern::Dict(_, fields) => Pattern::Dict(
293                Span::default(),
294                fields
295                    .iter()
296                    .map(|(k, p)| (k.clone(), p.reset_spans()))
297                    .collect(),
298            ),
299        }
300    }
301}
302
303impl Display for Pattern {
304    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
305        match self {
306            Pattern::Wildcard(..) => write!(f, "_"),
307            Pattern::Var(var) => var.fmt(f),
308            Pattern::Named(_, name, ps) => {
309                write!(f, "{}", name)?;
310                for p in ps {
311                    write!(f, " {}", p)?;
312                }
313                Ok(())
314            }
315            Pattern::Tuple(_, ps) => {
316                '('.fmt(f)?;
317                for (i, p) in ps.iter().enumerate() {
318                    p.fmt(f)?;
319                    if i + 1 < ps.len() {
320                        ", ".fmt(f)?;
321                    }
322                }
323                ')'.fmt(f)
324            }
325            Pattern::List(_, ps) => {
326                '['.fmt(f)?;
327                for (i, p) in ps.iter().enumerate() {
328                    p.fmt(f)?;
329                    if i + 1 < ps.len() {
330                        ", ".fmt(f)?;
331                    }
332                }
333                ']'.fmt(f)
334            }
335            Pattern::Cons(_, head, tail) => write!(f, "{}::{}", head, tail),
336            Pattern::Dict(_, fields) => {
337                '{'.fmt(f)?;
338                for (i, (key, pat)) in fields.iter().enumerate() {
339                    // Use shorthand when possible to keep output stable with old syntax.
340                    match pat {
341                        Pattern::Var(var) if var.name == *key => {
342                            key.fmt(f)?;
343                        }
344                        _ => {
345                            key.fmt(f)?;
346                            ": ".fmt(f)?;
347                            pat.fmt(f)?;
348                        }
349                    }
350                    if i + 1 < fields.len() {
351                        ", ".fmt(f)?;
352                    }
353                }
354                '}'.fmt(f)
355            }
356        }
357    }
358}
359
/// Surface syntax for a type, as written in annotations, constraints and
/// declarations.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TypeExpr {
    /// A (possibly dotted) type name, e.g. `Int`.
    Name(Span, NameRef),
    /// Application of a type to one argument; multi-argument application is
    /// nested on the left, e.g. `Result e a` is `App(App(Result, e), a)`.
    App(Span, Box<TypeExpr>, Box<TypeExpr>),
    /// Function type `arg -> ret`.
    Fun(Span, Box<TypeExpr>, Box<TypeExpr>),
    /// Tuple type `(t1, t2, ...)`.
    Tuple(Span, Vec<TypeExpr>),
    /// Record type `{name: t, ...}`.
    Record(Span, Vec<(Symbol, TypeExpr)>),
}
369
370impl TypeExpr {
371    pub fn span(&self) -> &Span {
372        match self {
373            TypeExpr::Name(span, ..)
374            | TypeExpr::App(span, ..)
375            | TypeExpr::Fun(span, ..)
376            | TypeExpr::Tuple(span, ..)
377            | TypeExpr::Record(span, ..) => span,
378        }
379    }
380
381    pub fn reset_spans(&self) -> TypeExpr {
382        match self {
383            TypeExpr::Name(_, name) => TypeExpr::Name(Span::default(), name.clone()),
384            TypeExpr::App(_, fun, arg) => TypeExpr::App(
385                Span::default(),
386                Box::new(fun.reset_spans()),
387                Box::new(arg.reset_spans()),
388            ),
389            TypeExpr::Fun(_, arg, ret) => TypeExpr::Fun(
390                Span::default(),
391                Box::new(arg.reset_spans()),
392                Box::new(ret.reset_spans()),
393            ),
394            TypeExpr::Tuple(_, elems) => TypeExpr::Tuple(
395                Span::default(),
396                elems.iter().map(|e| e.reset_spans()).collect(),
397            ),
398            TypeExpr::Record(_, fields) => TypeExpr::Record(
399                Span::default(),
400                fields
401                    .iter()
402                    .map(|(name, ty)| (name.clone(), ty.reset_spans()))
403                    .collect(),
404            ),
405        }
406    }
407}
408
409impl Display for TypeExpr {
410    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
411        match self {
412            TypeExpr::Name(_, name) => name.fmt(f),
413            TypeExpr::App(_, fun, arg) => {
414                match fun.as_ref() {
415                    TypeExpr::Name(..) | TypeExpr::App(..) => fun.fmt(f)?,
416                    _ => {
417                        '('.fmt(f)?;
418                        fun.fmt(f)?;
419                        ')'.fmt(f)?;
420                    }
421                }
422                ' '.fmt(f)?;
423                match arg.as_ref() {
424                    TypeExpr::Name(..)
425                    | TypeExpr::App(..)
426                    | TypeExpr::Tuple(..)
427                    | TypeExpr::Record(..) => arg.fmt(f),
428                    _ => {
429                        '('.fmt(f)?;
430                        arg.fmt(f)?;
431                        ')'.fmt(f)
432                    }
433                }
434            }
435            TypeExpr::Fun(_, arg, ret) => {
436                match arg.as_ref() {
437                    TypeExpr::Fun(..) => {
438                        '('.fmt(f)?;
439                        arg.fmt(f)?;
440                        ')'.fmt(f)?;
441                    }
442                    _ => arg.fmt(f)?,
443                }
444                " -> ".fmt(f)?;
445                ret.fmt(f)
446            }
447            TypeExpr::Tuple(_, elems) => {
448                '('.fmt(f)?;
449                for (i, elem) in elems.iter().enumerate() {
450                    elem.fmt(f)?;
451                    if i + 1 < elems.len() {
452                        ", ".fmt(f)?;
453                    }
454                }
455                ')'.fmt(f)
456            }
457            TypeExpr::Record(_, fields) => {
458                '{'.fmt(f)?;
459                for (i, (name, ty)) in fields.iter().enumerate() {
460                    name.fmt(f)?;
461                    ": ".fmt(f)?;
462                    ty.fmt(f)?;
463                    if i + 1 < fields.len() {
464                        ", ".fmt(f)?;
465                    }
466                }
467                '}'.fmt(f)
468            }
469        }
470    }
471}
472
473#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
474pub struct TypeConstraint {
475    pub class: NameRef,
476    pub typ: TypeExpr,
477}
478
479impl TypeConstraint {
480    pub fn new(class: NameRef, typ: TypeExpr) -> Self {
481        Self { class, typ }
482    }
483}
484
/// One variant of a type declaration: its name and argument types.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct TypeVariant {
    pub name: Symbol,
    pub args: Vec<TypeExpr>,
}

/// A type declaration: name, type parameter names and variants.
///
/// `is_pub` (here and on the other declarations) is a visibility flag;
/// presumably set when the declaration is marked `pub` — confirm in the parser.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct TypeDecl {
    pub span: Span,
    pub is_pub: bool,
    pub name: Symbol,
    pub params: Vec<Symbol>,
    pub variants: Vec<TypeVariant>,
}

/// A top-level function declaration with a body.
///
/// `Program::expr_with_fns` lowers each one into a `let` binding. The bound
/// value is one annotated lambda per entry in `params`, with `constraints`
/// attached to the outermost lambda. The `let` annotation is the curried
/// signature built from the parameter types and `ret`.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct FnDecl {
    pub span: Span,
    pub is_pub: bool,
    pub name: Var,
    // Parameters in order, each with its type annotation.
    pub params: Vec<(Var, TypeExpr)>,
    pub ret: TypeExpr,
    pub constraints: Vec<TypeConstraint>,
    pub body: Arc<Expr>,
}

/// A function signature declared without a body: the same fields as
/// [`FnDecl`] minus `body`. `Program::expr_with_fns` does not lower these.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct DeclareFnDecl {
    pub span: Span,
    pub is_pub: bool,
    pub name: Var,
    pub params: Vec<(Var, TypeExpr)>,
    pub ret: TypeExpr,
    pub constraints: Vec<TypeConstraint>,
}

/// A method name and its declared type inside a [`ClassDecl`].
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ClassMethodSig {
    pub name: Symbol,
    pub typ: TypeExpr,
}

/// A class declaration: name, type parameters, `supers` constraints
/// (presumably superclasses — confirm) and method signatures.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ClassDecl {
    pub span: Span,
    pub is_pub: bool,
    pub name: Symbol,
    pub params: Vec<Symbol>,
    pub supers: Vec<TypeConstraint>,
    pub methods: Vec<ClassMethodSig>,
}

/// A method implementation (name and body) inside an [`InstanceDecl`].
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct InstanceMethodImpl {
    pub name: Symbol,
    pub body: Arc<Expr>,
}

/// An instance declaration of `class` for the type `head`, with `context`
/// constraints (presumably prerequisites of the instance — confirm) and
/// method implementations.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct InstanceDecl {
    pub span: Span,
    pub is_pub: bool,
    pub class: Symbol,
    pub head: TypeExpr,
    pub context: Vec<TypeConstraint>,
    pub methods: Vec<InstanceMethodImpl>,
}

/// Where an import comes from: a path of symbol segments, or a URL.
///
/// NOTE(review): `sha` is an optional string whose meaning (presumably a
/// content-hash pin) is not visible here — confirm.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ImportPath {
    Local {
        segments: Vec<Symbol>,
        sha: Option<String>,
    },
    Remote {
        url: String,
        sha: Option<String>,
    },
}

/// One imported name with an optional local alias.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ImportItem {
    pub name: Symbol,
    pub alias: Option<Symbol>,
}

/// Selects what an import brings in: `All` (presumably every item — confirm)
/// or an explicit list of items.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ImportClause {
    All,
    Items(Vec<ImportItem>),
}

/// An import declaration: source `path`, `alias` (presumably the name the
/// module is bound under — confirm) and an optional item clause.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct ImportDecl {
    pub span: Span,
    pub is_pub: bool,
    pub path: ImportPath,
    pub alias: Symbol,
    pub clause: Option<ImportClause>,
}

/// A top-level declaration.
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum Decl {
    Type(TypeDecl),
    Fn(FnDecl),
    DeclareFn(DeclareFnDecl),
    Import(ImportDecl),
    Class(ClassDecl),
    Instance(InstanceDecl),
}

/// A program: its top-level declarations plus the expression they surround
/// (see `Program::expr_with_fns`).
#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct Program {
    pub decls: Vec<Decl>,
    pub expr: Arc<Expr>,
}
603
604impl Program {
605    /// Lower top-level `fn` declarations into nested `let` bindings around `expr`.
606    ///
607    /// This keeps the surface syntax (`Decl::Fn`) intact for tools, while giving
608    /// the type checker and evaluator a plain expression to work with.
609    pub fn expr_with_fns(&self) -> Arc<Expr> {
610        let mut out = self.expr.clone();
611        for decl in self.decls.iter().rev() {
612            let Decl::Fn(fd) = decl else {
613                continue;
614            };
615
616            let mut lam_body = fd.body.clone();
617            let mut lam_end = lam_body.span().end;
618            for (idx, (param, ann)) in fd.params.iter().enumerate().rev() {
619                let lam_constraints = if idx == 0 {
620                    fd.constraints.clone()
621                } else {
622                    Vec::new()
623                };
624                let span = Span::from_begin_end(param.span.begin, lam_end);
625                lam_body = Arc::new(Expr::Lam(
626                    span,
627                    Scope::new_sync(),
628                    param.clone(),
629                    Some(ann.clone()),
630                    lam_constraints,
631                    lam_body,
632                ));
633                lam_end = lam_body.span().end;
634            }
635
636            let mut sig = fd.ret.clone();
637            for (_, ann) in fd.params.iter().rev() {
638                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
639                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
640            }
641
642            let span = Span::from_begin_end(fd.span.begin, out.span().end);
643            out = Arc::new(Expr::Let(span, fd.name.clone(), Some(sig), lam_body, out));
644        }
645        out
646    }
647}
648
/// An expression node.
///
/// Sub-expressions are held as `Arc<Expr>`, so subtrees can be shared between
/// nodes. `Expr` does not derive `Clone`; `Expr::with_span` makes a shallow
/// copy. `#[serde(rename_all = "lowercase")]` lowercases the variant names in
/// serialized form (e.g. `recordupdate`, `letrec`).
/// The trailing comments show each variant's surface syntax.
#[derive(Debug, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(rename_all = "lowercase")]
pub enum Expr {
    Bool(Span, bool),              // true
    Uint(Span, u64),               // 69
    Int(Span, i64),                // -420
    Float(Span, f64),              // 3.14
    String(Span, String),          // "hello"
    Uuid(Span, Uuid),              // a550c18e-36e1-4f6d-8c8e-2d2b1e5f3c3a
    DateTime(Span, DateTime<Utc>), // 2023-01-01T12:00:00Z
    Hole(Span),                    // ?

    Tuple(Span, Vec<Arc<Expr>>),             // (e1, e2, e3)
    List(Span, Vec<Arc<Expr>>),              // [e1, e2, e3]
    Dict(Span, BTreeMap<Symbol, Arc<Expr>>), // {k1 = v1, k2 = v2}
    RecordUpdate(Span, Arc<Expr>, BTreeMap<Symbol, Arc<Expr>>), // {base with {k1 = v1, ...}}

    Var(Var),                         // x
    App(Span, Arc<Expr>, Arc<Expr>),  // f x
    Project(Span, Arc<Expr>, Symbol), // x.field
    // Fields: span, scope (`Program::expr_with_fns` creates it empty; presumably
    // a captured environment — confirm), parameter, optional parameter
    // annotation, constraints, body.
    Lam(
        Span,
        Scope,
        Var,
        Option<TypeExpr>,
        Vec<TypeConstraint>,
        Arc<Expr>,
    ), // λx → e
    // Fields: span, binder, optional annotation, bound expression, body.
    Let(Span, Var, Option<TypeExpr>, Arc<Expr>, Arc<Expr>), // let x = e1 in e2
    LetRec(Span, Vec<(Var, Option<TypeExpr>, Arc<Expr>)>, Arc<Expr>), // let rec f = e1 and g = e2 in e3
    Ite(Span, Arc<Expr>, Arc<Expr>, Arc<Expr>),                       // if e1 then e2 else e3
    Match(Span, Arc<Expr>, Vec<(Pattern, Arc<Expr>)>),                // match e1 with patterns
    Ann(Span, Arc<Expr>, TypeExpr),                                   // e is t
}
683
684impl Expr {
685    pub fn span(&self) -> &Span {
686        match self {
687            Self::Bool(span, ..)
688            | Self::Uint(span, ..)
689            | Self::Int(span, ..)
690            | Self::Float(span, ..)
691            | Self::String(span, ..)
692            | Self::Uuid(span, ..)
693            | Self::DateTime(span, ..)
694            | Self::Hole(span, ..)
695            | Self::Tuple(span, ..)
696            | Self::List(span, ..)
697            | Self::Dict(span, ..)
698            | Self::RecordUpdate(span, ..)
699            | Self::Var(Var { span, .. })
700            | Self::App(span, ..)
701            | Self::Project(span, ..)
702            | Self::Lam(span, ..)
703            | Self::Let(span, ..)
704            | Self::LetRec(span, ..)
705            | Self::Ite(span, ..)
706            | Self::Match(span, ..)
707            | Self::Ann(span, ..) => span,
708        }
709    }
710
711    pub fn span_mut(&mut self) -> &mut Span {
712        match self {
713            Self::Bool(span, ..)
714            | Self::Uint(span, ..)
715            | Self::Int(span, ..)
716            | Self::Float(span, ..)
717            | Self::String(span, ..)
718            | Self::Uuid(span, ..)
719            | Self::DateTime(span, ..)
720            | Self::Hole(span, ..)
721            | Self::Tuple(span, ..)
722            | Self::List(span, ..)
723            | Self::Dict(span, ..)
724            | Self::RecordUpdate(span, ..)
725            | Self::Var(Var { span, .. })
726            | Self::App(span, ..)
727            | Self::Project(span, ..)
728            | Self::Lam(span, ..)
729            | Self::Let(span, ..)
730            | Self::LetRec(span, ..)
731            | Self::Ite(span, ..)
732            | Self::Match(span, ..)
733            | Self::Ann(span, ..) => span,
734        }
735    }
736
737    pub fn with_span_begin_end(&self, begin: Position, end: Position) -> Expr {
738        self.with_span(Span::from_begin_end(begin, end))
739    }
740
741    pub fn with_span_begin(&self, begin: Position) -> Expr {
742        let end = self.span().end;
743        self.with_span(Span::from_begin_end(begin, end))
744    }
745
746    pub fn with_span_end(&self, end: Position) -> Expr {
747        let begin = self.span().begin;
748        self.with_span(Span::from_begin_end(begin, end))
749    }
750
751    pub fn with_span(&self, span: Span) -> Expr {
752        match self {
753            Expr::Bool(_, x) => Expr::Bool(span, *x),
754            Expr::Uint(_, x) => Expr::Uint(span, *x),
755            Expr::Int(_, x) => Expr::Int(span, *x),
756            Expr::Float(_, x) => Expr::Float(span, *x),
757            Expr::String(_, x) => Expr::String(span, x.clone()),
758            Expr::Uuid(_, x) => Expr::Uuid(span, *x),
759            Expr::DateTime(_, x) => Expr::DateTime(span, *x),
760            Expr::Hole(_) => Expr::Hole(span),
761            Expr::Tuple(_, elems) => Expr::Tuple(span, elems.clone()),
762            Expr::List(_, elems) => Expr::List(span, elems.clone()),
763            Expr::Dict(_, kvs) => Expr::Dict(
764                span,
765                BTreeMap::from_iter(kvs.iter().map(|(k, v)| (k.clone(), v.clone()))),
766            ),
767            Expr::RecordUpdate(_, base, updates) => Expr::RecordUpdate(
768                span,
769                base.clone(),
770                BTreeMap::from_iter(updates.iter().map(|(k, v)| (k.clone(), v.clone()))),
771            ),
772            Expr::Var(var) => Expr::Var(Var {
773                span,
774                name: var.name.clone(),
775            }),
776            Expr::App(_, f, x) => Expr::App(span, f.clone(), x.clone()),
777            Expr::Project(_, base, field) => Expr::Project(span, base.clone(), field.clone()),
778            Expr::Lam(_, scope, param, ann, constraints, body) => Expr::Lam(
779                span,
780                scope.clone(),
781                param.clone(),
782                ann.clone(),
783                constraints.clone(),
784                body.clone(),
785            ),
786            Expr::Let(_, var, ann, def, body) => {
787                Expr::Let(span, var.clone(), ann.clone(), def.clone(), body.clone())
788            }
789            Expr::LetRec(_, bindings, body) => Expr::LetRec(
790                span,
791                bindings
792                    .iter()
793                    .map(|(var, ann, def)| (var.clone(), ann.clone(), def.clone()))
794                    .collect(),
795                body.clone(),
796            ),
797            Expr::Ite(_, cond, then, r#else) => {
798                Expr::Ite(span, cond.clone(), then.clone(), r#else.clone())
799            }
800            Expr::Match(_, scrutinee, arms) => Expr::Match(
801                span,
802                scrutinee.clone(),
803                arms.iter()
804                    .map(|(pat, expr)| (pat.clone(), expr.clone()))
805                    .collect(),
806            ),
807            Expr::Ann(_, expr, ann) => Expr::Ann(span, expr.clone(), ann.clone()),
808        }
809    }
810
811    pub fn reset_spans(&self) -> Expr {
812        match self {
813            Expr::Bool(_, x) => Expr::Bool(Span::default(), *x),
814            Expr::Uint(_, x) => Expr::Uint(Span::default(), *x),
815            Expr::Int(_, x) => Expr::Int(Span::default(), *x),
816            Expr::Float(_, x) => Expr::Float(Span::default(), *x),
817            Expr::String(_, x) => Expr::String(Span::default(), x.clone()),
818            Expr::Uuid(_, x) => Expr::Uuid(Span::default(), *x),
819            Expr::DateTime(_, x) => Expr::DateTime(Span::default(), *x),
820            Expr::Hole(_) => Expr::Hole(Span::default()),
821            Expr::Tuple(_, elems) => Expr::Tuple(
822                Span::default(),
823                elems.iter().map(|x| Arc::new(x.reset_spans())).collect(),
824            ),
825            Expr::List(_, elems) => Expr::List(
826                Span::default(),
827                elems.iter().map(|x| Arc::new(x.reset_spans())).collect(),
828            ),
829            Expr::Dict(_, kvs) => Expr::Dict(
830                Span::default(),
831                BTreeMap::from_iter(
832                    kvs.iter()
833                        .map(|(k, v)| (k.clone(), Arc::new(v.reset_spans()))),
834                ),
835            ),
836            Expr::RecordUpdate(_, base, updates) => Expr::RecordUpdate(
837                Span::default(),
838                Arc::new(base.reset_spans()),
839                BTreeMap::from_iter(
840                    updates
841                        .iter()
842                        .map(|(k, v)| (k.clone(), Arc::new(v.reset_spans()))),
843                ),
844            ),
845            Expr::Var(var) => Expr::Var(var.reset_spans()),
846            Expr::App(_, f, x) => Expr::App(
847                Span::default(),
848                Arc::new(f.reset_spans()),
849                Arc::new(x.reset_spans()),
850            ),
851            Expr::Project(_, base, field) => {
852                Expr::Project(Span::default(), Arc::new(base.reset_spans()), field.clone())
853            }
854            Expr::Lam(_, scope, param, ann, constraints, body) => Expr::Lam(
855                Span::default(),
856                scope.clone(),
857                param.reset_spans(),
858                ann.as_ref().map(TypeExpr::reset_spans),
859                constraints
860                    .iter()
861                    .map(|constraint| TypeConstraint {
862                        class: constraint.class.clone(),
863                        typ: constraint.typ.reset_spans(),
864                    })
865                    .collect(),
866                Arc::new(body.reset_spans()),
867            ),
868            Expr::Let(_, var, ann, def, body) => Expr::Let(
869                Span::default(),
870                var.reset_spans(),
871                ann.as_ref().map(|t| t.reset_spans()),
872                Arc::new(def.reset_spans()),
873                Arc::new(body.reset_spans()),
874            ),
875            Expr::LetRec(_, bindings, body) => Expr::LetRec(
876                Span::default(),
877                bindings
878                    .iter()
879                    .map(|(var, ann, def)| {
880                        (
881                            var.reset_spans(),
882                            ann.as_ref().map(TypeExpr::reset_spans),
883                            Arc::new(def.reset_spans()),
884                        )
885                    })
886                    .collect(),
887                Arc::new(body.reset_spans()),
888            ),
889            Expr::Ite(_, cond, then, r#else) => Expr::Ite(
890                Span::default(),
891                Arc::new(cond.reset_spans()),
892                Arc::new(then.reset_spans()),
893                Arc::new(r#else.reset_spans()),
894            ),
895            Expr::Match(_, scrutinee, arms) => Expr::Match(
896                Span::default(),
897                Arc::new(scrutinee.reset_spans()),
898                arms.iter()
899                    .map(|(pat, expr)| (pat.reset_spans(), Arc::new(expr.reset_spans())))
900                    .collect(),
901            ),
902            Expr::Ann(_, expr, ann) => Expr::Ann(
903                Span::default(),
904                Arc::new(expr.reset_spans()),
905                ann.reset_spans(),
906            ),
907        }
908    }
909}
910
911impl Display for Expr {
912    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
913        match self {
914            Self::Bool(_span, x) => x.fmt(f),
915            Self::Uint(_span, x) => x.fmt(f),
916            Self::Int(_span, x) => x.fmt(f),
917            Self::Float(_span, x) => x.fmt(f),
918            Self::String(_span, x) => write!(f, "{:?}", x),
919            Self::Uuid(_span, x) => x.fmt(f),
920            Self::DateTime(_span, x) => x.fmt(f),
921            Self::Hole(_span) => '?'.fmt(f),
922            Self::List(_span, xs) => {
923                '['.fmt(f)?;
924                for (i, x) in xs.iter().enumerate() {
925                    x.fmt(f)?;
926                    if i + 1 < xs.len() {
927                        ", ".fmt(f)?;
928                    }
929                }
930                ']'.fmt(f)
931            }
932            Self::Tuple(_span, xs) => {
933                '('.fmt(f)?;
934                for (i, x) in xs.iter().enumerate() {
935                    x.fmt(f)?;
936                    if i + 1 < xs.len() {
937                        ", ".fmt(f)?;
938                    }
939                }
940                ')'.fmt(f)
941            }
942            Self::Dict(_span, kvs) => {
943                '{'.fmt(f)?;
944                for (i, (k, v)) in kvs.iter().enumerate() {
945                    k.fmt(f)?;
946                    " = ".fmt(f)?;
947                    v.fmt(f)?;
948                    if i + 1 < kvs.len() {
949                        ", ".fmt(f)?;
950                    }
951                }
952                '}'.fmt(f)
953            }
954            Self::RecordUpdate(_span, base, kvs) => {
955                '{'.fmt(f)?;
956                base.fmt(f)?;
957                " with ".fmt(f)?;
958                '{'.fmt(f)?;
959                for (i, (k, v)) in kvs.iter().enumerate() {
960                    k.fmt(f)?;
961                    " = ".fmt(f)?;
962                    v.fmt(f)?;
963                    if i + 1 < kvs.len() {
964                        ", ".fmt(f)?;
965                    }
966                }
967                '}'.fmt(f)?;
968                '}'.fmt(f)
969            }
970            Self::Var(var) => var.fmt(f),
971            Self::App(_span, g, x) => {
972                g.fmt(f)?;
973                ' '.fmt(f)?;
974                match x.as_ref() {
975                    Self::Bool(..)
976                    | Self::Uint(..)
977                    | Self::Int(..)
978                    | Self::Float(..)
979                    | Self::String(..)
980                    | Self::List(..)
981                    | Self::Tuple(..)
982                    | Self::Dict(..)
983                    | Self::RecordUpdate(..)
984                    | Self::Project(..)
985                    | Self::Var(..) => x.fmt(f),
986                    _ => {
987                        '('.fmt(f)?;
988                        x.fmt(f)?;
989                        ')'.fmt(f)
990                    }
991                }
992            }
993            Self::Lam(_span, _scope, param, ann, constraints, body) => {
994                'λ'.fmt(f)?;
995                if let Some(ann) = ann {
996                    '('.fmt(f)?;
997                    param.fmt(f)?;
998                    " : ".fmt(f)?;
999                    ann.fmt(f)?;
1000                    ')'.fmt(f)?;
1001                } else {
1002                    param.fmt(f)?;
1003                }
1004                if !constraints.is_empty() {
1005                    " where ".fmt(f)?;
1006                    for (i, constraint) in constraints.iter().enumerate() {
1007                        constraint.class.fmt(f)?;
1008                        ' '.fmt(f)?;
1009                        constraint.typ.fmt(f)?;
1010                        if i + 1 < constraints.len() {
1011                            ", ".fmt(f)?;
1012                        }
1013                    }
1014                }
1015                " → ".fmt(f)?;
1016                body.fmt(f)
1017            }
1018            Self::Let(_span, var, ann, def, body) => {
1019                "let ".fmt(f)?;
1020                var.fmt(f)?;
1021                if let Some(ann) = ann {
1022                    ": ".fmt(f)?;
1023                    ann.fmt(f)?;
1024                }
1025                " = ".fmt(f)?;
1026                def.fmt(f)?;
1027                " in ".fmt(f)?;
1028                body.fmt(f)
1029            }
1030            Self::LetRec(_span, bindings, body) => {
1031                "let rec ".fmt(f)?;
1032                for (idx, (var, ann, def)) in bindings.iter().enumerate() {
1033                    if idx > 0 {
1034                        " and ".fmt(f)?;
1035                    }
1036                    var.fmt(f)?;
1037                    if let Some(ann) = ann {
1038                        ": ".fmt(f)?;
1039                        ann.fmt(f)?;
1040                    }
1041                    " = ".fmt(f)?;
1042                    def.fmt(f)?;
1043                }
1044                " in ".fmt(f)?;
1045                body.fmt(f)
1046            }
1047            Self::Ite(_span, cond, then, r#else) => {
1048                "if ".fmt(f)?;
1049                cond.fmt(f)?;
1050                " then ".fmt(f)?;
1051                then.fmt(f)?;
1052                " else ".fmt(f)?;
1053                r#else.fmt(f)
1054            }
1055            Self::Match(_span, scrutinee, arms) => {
1056                "match ".fmt(f)?;
1057                scrutinee.fmt(f)?;
1058                ' '.fmt(f)?;
1059                for (i, (pat, expr)) in arms.iter().enumerate() {
1060                    "when ".fmt(f)?;
1061                    pat.fmt(f)?;
1062                    " -> ".fmt(f)?;
1063                    expr.fmt(f)?;
1064                    if i + 1 < arms.len() {
1065                        ' '.fmt(f)?;
1066                    }
1067                }
1068                Ok(())
1069            }
1070            Self::Project(_span, base, field) => {
1071                base.fmt(f)?;
1072                ".".fmt(f)?;
1073                field.fmt(f)
1074            }
1075            Self::Ann(_span, expr, ann) => {
1076                expr.fmt(f)?;
1077                " is ".fmt(f)?;
1078                ann.fmt(f)
1079            }
1080        }
1081    }
1082}