// rexlang_ast/expr.rs

1use std::{
2    borrow::Borrow,
3    collections::{BTreeMap, HashMap},
4    fmt::{self, Display, Formatter},
5    ops::Deref,
6    sync::{Arc, Mutex, OnceLock},
7};
8
9use rexlang_lexer::span::{Position, Span};
10use rpds::HashTrieMapSync;
11
12use chrono::{DateTime, Utc};
13use uuid::Uuid;
14
15#[derive(
16    Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Deserialize, serde::Serialize,
17)]
18#[serde(transparent)]
19pub struct Symbol(Arc<str>);
20
21impl Symbol {
22    pub fn as_str(&self) -> &str {
23        self.0.as_ref()
24    }
25}
26
27impl Deref for Symbol {
28    type Target = str;
29
30    fn deref(&self) -> &Self::Target {
31        self.0.as_ref()
32    }
33}
34
35impl AsRef<str> for Symbol {
36    fn as_ref(&self) -> &str {
37        self.0.as_ref()
38    }
39}
40
41impl Borrow<str> for Symbol {
42    fn borrow(&self) -> &str {
43        self.0.as_ref()
44    }
45}
46
47impl Display for Symbol {
48    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
49        self.0.fmt(f)
50    }
51}
52
53impl From<&str> for Symbol {
54    fn from(value: &str) -> Self {
55        Symbol(Arc::from(value))
56    }
57}
58
59impl From<String> for Symbol {
60    fn from(value: String) -> Self {
61        Symbol(Arc::from(value))
62    }
63}
64
65impl From<Arc<str>> for Symbol {
66    fn from(value: Arc<str>) -> Self {
67        Symbol(value)
68    }
69}
70
/// Map from symbols to shared expressions, stored in an `rpds`
/// `HashTrieMapSync`. Each `Expr::Lam` carries one of these.
pub type Scope = HashTrieMapSync<Symbol, Arc<Expr>>;
72
73// Global symbol interner.
74//
75// Design constraints:
76// - Symbols wrap `Arc<str>` so cloning them is cheap and comparisons are fast.
77// - The table is process-global and monotonically grows; that's fine for a
78//   typical “compile a program, then exit” workflow.
79// - Locking makes the cost model explicit (and obvious in profiles). If this
80//   ever shows up hot, the first step is usually “reduce calls to `intern`”,
81//   not “invent a clever interner”.
82static INTERNER: OnceLock<Mutex<HashMap<String, Symbol>>> = OnceLock::new();
83
84pub fn intern(name: &str) -> Symbol {
85    let mutex = INTERNER.get_or_init(|| Mutex::new(HashMap::new()));
86    let mut table = match mutex.lock() {
87        Ok(guard) => guard,
88        Err(poisoned) => poisoned.into_inner(),
89    };
90    if let Some(existing) = table.get(name) {
91        return existing.clone();
92    }
93    let sym = Symbol(Arc::from(name));
94    table.insert(name.to_string(), sym.clone());
95    sym
96}
97
98pub fn sym(name: &str) -> Symbol {
99    intern(name)
100}
101
102pub fn sym_eq(name: &Symbol, expected: &str) -> bool {
103    name.as_ref() == expected
104}
105
106#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
107#[serde(rename_all = "lowercase")]
108pub struct Var {
109    pub span: Span,
110    pub name: Symbol,
111}
112
113impl Var {
114    pub fn new(name: impl ToString) -> Self {
115        Self {
116            span: Span::default(),
117            name: intern(&name.to_string()),
118        }
119    }
120
121    pub fn with_span(span: Span, name: impl ToString) -> Self {
122        Self {
123            span,
124            name: intern(&name.to_string()),
125        }
126    }
127
128    pub fn reset_spans(&self) -> Var {
129        Var {
130            span: Span::default(),
131            name: self.name.clone(),
132        }
133    }
134}
135
136impl Display for Var {
137    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
138        match self.name.as_ref() {
139            "+" | "-" | "*" | "/" | "==" | ">=" | ">" | "<=" | "<" | "++" | "." => {
140                '('.fmt(f)?;
141                self.name.fmt(f)?;
142                ')'.fmt(f)
143            }
144            _ => self.name.fmt(f),
145        }
146    }
147}
148
149#[derive(Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize)]
150#[serde(rename_all = "lowercase")]
151pub enum NameRef {
152    Unqualified(Symbol),
153    Qualified(Symbol, Vec<Symbol>),
154}
155
156impl NameRef {
157    pub fn from_segments(segments: Vec<Symbol>) -> Self {
158        if segments.len() == 1 {
159            NameRef::Unqualified(segments[0].clone())
160        } else {
161            let dotted = intern(
162                &segments
163                    .iter()
164                    .map(|s| s.as_ref())
165                    .collect::<Vec<_>>()
166                    .join("."),
167            );
168            NameRef::Qualified(dotted, segments)
169        }
170    }
171
172    pub fn from_dotted(name: &str) -> Self {
173        let segments: Vec<Symbol> = name.split('.').map(intern).collect();
174        NameRef::from_segments(segments)
175    }
176
177    pub fn to_dotted_symbol(&self) -> Symbol {
178        match self {
179            NameRef::Unqualified(sym) => sym.clone(),
180            NameRef::Qualified(dotted, _) => dotted.clone(),
181        }
182    }
183
184    pub fn as_segments(&self) -> Vec<Symbol> {
185        match self {
186            NameRef::Unqualified(sym) => vec![sym.clone()],
187            NameRef::Qualified(_, segments) => segments.clone(),
188        }
189    }
190}
191
192impl From<&str> for NameRef {
193    fn from(value: &str) -> Self {
194        NameRef::from_dotted(value)
195    }
196}
197
198impl From<String> for NameRef {
199    fn from(value: String) -> Self {
200        NameRef::from_dotted(&value)
201    }
202}
203
204impl AsRef<str> for NameRef {
205    fn as_ref(&self) -> &str {
206        match self {
207            NameRef::Unqualified(sym) => sym.as_ref(),
208            NameRef::Qualified(dotted, _) => dotted.as_ref(),
209        }
210    }
211}
212
213impl Display for NameRef {
214    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
215        match self {
216            NameRef::Unqualified(sym) => sym.fmt(f),
217            NameRef::Qualified(_, segments) => {
218                for (i, segment) in segments.iter().enumerate() {
219                    if i > 0 {
220                        '.'.fmt(f)?;
221                    }
222                    segment.fmt(f)?;
223                }
224                Ok(())
225            }
226        }
227    }
228}
229
230#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
231#[serde(rename_all = "lowercase")]
232pub enum Pattern {
233    Wildcard(Span),                         // _
234    Var(Var),                               // x
235    Named(Span, NameRef, Vec<Pattern>),     // Ok x y z
236    Tuple(Span, Vec<Pattern>),              // (x, y, z)
237    List(Span, Vec<Pattern>),               // [x, y, z]
238    Cons(Span, Box<Pattern>, Box<Pattern>), // x::xs
239    Dict(Span, Vec<(Symbol, Pattern)>),     // {a, b, c} or {a: x, b: y}
240}
241
242impl Pattern {
243    pub fn span(&self) -> &Span {
244        match self {
245            Pattern::Wildcard(span, ..)
246            | Pattern::Var(Var { span, .. })
247            | Pattern::Named(span, ..)
248            | Pattern::Tuple(span, ..)
249            | Pattern::List(span, ..)
250            | Pattern::Cons(span, ..)
251            | Pattern::Dict(span, ..) => span,
252        }
253    }
254
255    pub fn with_span(&self, span: Span) -> Pattern {
256        match self {
257            Pattern::Wildcard(..) => Pattern::Wildcard(span),
258            Pattern::Var(var) => Pattern::Var(Var {
259                span,
260                name: var.name.clone(),
261            }),
262            Pattern::Named(_, name, ps) => Pattern::Named(span, name.clone(), ps.clone()),
263            Pattern::Tuple(_, ps) => Pattern::Tuple(span, ps.clone()),
264            Pattern::List(_, ps) => Pattern::List(span, ps.clone()),
265            Pattern::Cons(_, head, tail) => Pattern::Cons(span, head.clone(), tail.clone()),
266            Pattern::Dict(_, fields) => Pattern::Dict(span, fields.clone()),
267        }
268    }
269
270    pub fn reset_spans(&self) -> Pattern {
271        match self {
272            Pattern::Wildcard(..) => Pattern::Wildcard(Span::default()),
273            Pattern::Var(var) => Pattern::Var(var.reset_spans()),
274            Pattern::Named(_, name, ps) => Pattern::Named(
275                Span::default(),
276                name.clone(),
277                ps.iter().map(|p| p.reset_spans()).collect(),
278            ),
279            Pattern::Tuple(_, ps) => Pattern::Tuple(
280                Span::default(),
281                ps.iter().map(|p| p.reset_spans()).collect(),
282            ),
283            Pattern::List(_, ps) => Pattern::List(
284                Span::default(),
285                ps.iter().map(|p| p.reset_spans()).collect(),
286            ),
287            Pattern::Cons(_, head, tail) => Pattern::Cons(
288                Span::default(),
289                Box::new(head.reset_spans()),
290                Box::new(tail.reset_spans()),
291            ),
292            Pattern::Dict(_, fields) => Pattern::Dict(
293                Span::default(),
294                fields
295                    .iter()
296                    .map(|(k, p)| (k.clone(), p.reset_spans()))
297                    .collect(),
298            ),
299        }
300    }
301}
302
303impl Display for Pattern {
304    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
305        match self {
306            Pattern::Wildcard(..) => write!(f, "_"),
307            Pattern::Var(var) => var.fmt(f),
308            Pattern::Named(_, name, ps) => {
309                write!(f, "{}", name)?;
310                for p in ps {
311                    write!(f, " {}", p)?;
312                }
313                Ok(())
314            }
315            Pattern::Tuple(_, ps) => {
316                '('.fmt(f)?;
317                for (i, p) in ps.iter().enumerate() {
318                    p.fmt(f)?;
319                    if i + 1 < ps.len() {
320                        ", ".fmt(f)?;
321                    }
322                }
323                ')'.fmt(f)
324            }
325            Pattern::List(_, ps) => {
326                '['.fmt(f)?;
327                for (i, p) in ps.iter().enumerate() {
328                    p.fmt(f)?;
329                    if i + 1 < ps.len() {
330                        ", ".fmt(f)?;
331                    }
332                }
333                ']'.fmt(f)
334            }
335            Pattern::Cons(_, head, tail) => write!(f, "{}::{}", head, tail),
336            Pattern::Dict(_, fields) => {
337                '{'.fmt(f)?;
338                for (i, (key, pat)) in fields.iter().enumerate() {
339                    // Use shorthand when possible to keep output stable with old syntax.
340                    match pat {
341                        Pattern::Var(var) if var.name == *key => {
342                            key.fmt(f)?;
343                        }
344                        _ => {
345                            key.fmt(f)?;
346                            ": ".fmt(f)?;
347                            pat.fmt(f)?;
348                        }
349                    }
350                    if i + 1 < fields.len() {
351                        ", ".fmt(f)?;
352                    }
353                }
354                '}'.fmt(f)
355            }
356        }
357    }
358}
359
360#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
361#[serde(rename_all = "lowercase")]
362pub enum TypeExpr {
363    Name(Span, NameRef),
364    App(Span, Box<TypeExpr>, Box<TypeExpr>),
365    Fun(Span, Box<TypeExpr>, Box<TypeExpr>),
366    Tuple(Span, Vec<TypeExpr>),
367    Record(Span, Vec<(Symbol, TypeExpr)>),
368}
369
370impl TypeExpr {
371    pub fn span(&self) -> &Span {
372        match self {
373            TypeExpr::Name(span, ..)
374            | TypeExpr::App(span, ..)
375            | TypeExpr::Fun(span, ..)
376            | TypeExpr::Tuple(span, ..)
377            | TypeExpr::Record(span, ..) => span,
378        }
379    }
380
381    pub fn reset_spans(&self) -> TypeExpr {
382        match self {
383            TypeExpr::Name(_, name) => TypeExpr::Name(Span::default(), name.clone()),
384            TypeExpr::App(_, fun, arg) => TypeExpr::App(
385                Span::default(),
386                Box::new(fun.reset_spans()),
387                Box::new(arg.reset_spans()),
388            ),
389            TypeExpr::Fun(_, arg, ret) => TypeExpr::Fun(
390                Span::default(),
391                Box::new(arg.reset_spans()),
392                Box::new(ret.reset_spans()),
393            ),
394            TypeExpr::Tuple(_, elems) => TypeExpr::Tuple(
395                Span::default(),
396                elems.iter().map(|e| e.reset_spans()).collect(),
397            ),
398            TypeExpr::Record(_, fields) => TypeExpr::Record(
399                Span::default(),
400                fields
401                    .iter()
402                    .map(|(name, ty)| (name.clone(), ty.reset_spans()))
403                    .collect(),
404            ),
405        }
406    }
407}
408
409impl Display for TypeExpr {
410    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
411        match self {
412            TypeExpr::Name(_, name) => name.fmt(f),
413            TypeExpr::App(_, fun, arg) => {
414                match fun.as_ref() {
415                    TypeExpr::Name(..) | TypeExpr::App(..) => fun.fmt(f)?,
416                    _ => {
417                        '('.fmt(f)?;
418                        fun.fmt(f)?;
419                        ')'.fmt(f)?;
420                    }
421                }
422                ' '.fmt(f)?;
423                match arg.as_ref() {
424                    TypeExpr::Name(..)
425                    | TypeExpr::App(..)
426                    | TypeExpr::Tuple(..)
427                    | TypeExpr::Record(..) => arg.fmt(f),
428                    _ => {
429                        '('.fmt(f)?;
430                        arg.fmt(f)?;
431                        ')'.fmt(f)
432                    }
433                }
434            }
435            TypeExpr::Fun(_, arg, ret) => {
436                match arg.as_ref() {
437                    TypeExpr::Fun(..) => {
438                        '('.fmt(f)?;
439                        arg.fmt(f)?;
440                        ')'.fmt(f)?;
441                    }
442                    _ => arg.fmt(f)?,
443                }
444                " -> ".fmt(f)?;
445                ret.fmt(f)
446            }
447            TypeExpr::Tuple(_, elems) => {
448                '('.fmt(f)?;
449                for (i, elem) in elems.iter().enumerate() {
450                    elem.fmt(f)?;
451                    if i + 1 < elems.len() {
452                        ", ".fmt(f)?;
453                    }
454                }
455                ')'.fmt(f)
456            }
457            TypeExpr::Record(_, fields) => {
458                '{'.fmt(f)?;
459                for (i, (name, ty)) in fields.iter().enumerate() {
460                    name.fmt(f)?;
461                    ": ".fmt(f)?;
462                    ty.fmt(f)?;
463                    if i + 1 < fields.len() {
464                        ", ".fmt(f)?;
465                    }
466                }
467                '}'.fmt(f)
468            }
469        }
470    }
471}
472
473#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
474pub struct TypeConstraint {
475    pub class: NameRef,
476    pub typ: TypeExpr,
477}
478
479impl TypeConstraint {
480    pub fn new(class: NameRef, typ: TypeExpr) -> Self {
481        Self { class, typ }
482    }
483}
484
485#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
486pub struct TypeVariant {
487    pub name: Symbol,
488    pub args: Vec<TypeExpr>,
489}
490
491#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
492pub struct TypeDecl {
493    pub span: Span,
494    pub is_pub: bool,
495    pub name: Symbol,
496    pub params: Vec<Symbol>,
497    pub variants: Vec<TypeVariant>,
498}
499
500#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
501pub struct FnDecl {
502    pub span: Span,
503    pub is_pub: bool,
504    pub name: Var,
505    pub params: Vec<(Var, TypeExpr)>,
506    pub ret: TypeExpr,
507    pub constraints: Vec<TypeConstraint>,
508    pub body: Arc<Expr>,
509}
510
511#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
512pub struct DeclareFnDecl {
513    pub span: Span,
514    pub is_pub: bool,
515    pub name: Var,
516    pub params: Vec<(Var, TypeExpr)>,
517    pub ret: TypeExpr,
518    pub constraints: Vec<TypeConstraint>,
519}
520
521#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
522pub struct ClassMethodSig {
523    pub name: Symbol,
524    pub typ: TypeExpr,
525}
526
527#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
528pub struct ClassDecl {
529    pub span: Span,
530    pub is_pub: bool,
531    pub name: Symbol,
532    pub params: Vec<Symbol>,
533    pub supers: Vec<TypeConstraint>,
534    pub methods: Vec<ClassMethodSig>,
535}
536
537#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
538pub struct InstanceMethodImpl {
539    pub name: Symbol,
540    pub body: Arc<Expr>,
541}
542
543#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
544pub struct InstanceDecl {
545    pub span: Span,
546    pub is_pub: bool,
547    pub class: Symbol,
548    pub head: TypeExpr,
549    pub context: Vec<TypeConstraint>,
550    pub methods: Vec<InstanceMethodImpl>,
551}
552
553#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
554#[serde(rename_all = "lowercase")]
555pub enum ImportPath {
556    Local {
557        segments: Vec<Symbol>,
558        sha: Option<String>,
559    },
560    Remote {
561        url: String,
562        sha: Option<String>,
563    },
564}
565
566#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
567pub struct ImportItem {
568    pub name: Symbol,
569    pub alias: Option<Symbol>,
570}
571
572#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
573#[serde(rename_all = "lowercase")]
574pub enum ImportClause {
575    All,
576    Items(Vec<ImportItem>),
577}
578
579#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
580pub struct ImportDecl {
581    pub span: Span,
582    pub is_pub: bool,
583    pub path: ImportPath,
584    pub alias: Symbol,
585    pub clause: Option<ImportClause>,
586}
587
588#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
589pub enum Decl {
590    Type(TypeDecl),
591    Fn(FnDecl),
592    DeclareFn(DeclareFnDecl),
593    Import(ImportDecl),
594    Class(ClassDecl),
595    Instance(InstanceDecl),
596}
597
598impl std::fmt::Display for Decl {
599    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
600        match self {
601            Decl::Type(d) => write!(f, "{}", d.name),
602            Decl::Fn(d) => write!(f, "{}", d.name),
603            Decl::DeclareFn(d) => write!(f, "{}", d.name),
604            Decl::Import(d) => write!(f, "{}", d.alias),
605            Decl::Class(d) => write!(f, "{}", d.name),
606            Decl::Instance(d) => write!(f, "{} {}", d.class, d.head),
607        }
608    }
609}
610
611#[derive(Clone, Debug, PartialEq, serde::Deserialize, serde::Serialize)]
612pub struct Program {
613    pub decls: Vec<Decl>,
614    pub expr: Arc<Expr>,
615}
616
617impl Program {
618    /// Lower top-level `fn` declarations into nested `let` bindings around `expr`.
619    ///
620    /// This keeps the surface syntax (`Decl::Fn`) intact for tools, while giving
621    /// the type checker and evaluator a plain expression to work with.
622    pub fn expr_with_fns(&self) -> Arc<Expr> {
623        let mut out = self.expr.clone();
624        for decl in self.decls.iter().rev() {
625            let Decl::Fn(fd) = decl else {
626                continue;
627            };
628
629            let mut lam_body = fd.body.clone();
630            let mut lam_end = lam_body.span().end;
631            for (idx, (param, ann)) in fd.params.iter().enumerate().rev() {
632                let lam_constraints = if idx == 0 {
633                    fd.constraints.clone()
634                } else {
635                    Vec::new()
636                };
637                let span = Span::from_begin_end(param.span.begin, lam_end);
638                lam_body = Arc::new(Expr::Lam(
639                    span,
640                    Scope::new_sync(),
641                    param.clone(),
642                    Some(ann.clone()),
643                    lam_constraints,
644                    lam_body,
645                ));
646                lam_end = lam_body.span().end;
647            }
648
649            let mut sig = fd.ret.clone();
650            for (_, ann) in fd.params.iter().rev() {
651                let span = Span::from_begin_end(ann.span().begin, sig.span().end);
652                sig = TypeExpr::Fun(span, Box::new(ann.clone()), Box::new(sig));
653            }
654
655            let span = Span::from_begin_end(fd.span.begin, out.span().end);
656            out = Arc::new(Expr::Let(span, fd.name.clone(), Some(sig), lam_body, out));
657        }
658        out
659    }
660}
661
662#[derive(Debug, PartialEq, serde::Deserialize, serde::Serialize)]
663#[serde(rename_all = "lowercase")]
664pub enum Expr {
665    Bool(Span, bool),              // true
666    Uint(Span, u64),               // 69
667    Int(Span, i64),                // -420
668    Float(Span, f64),              // 3.14
669    String(Span, String),          // "hello"
670    Uuid(Span, Uuid),              // a550c18e-36e1-4f6d-8c8e-2d2b1e5f3c3a
671    DateTime(Span, DateTime<Utc>), // 2023-01-01T12:00:00Z
672    Hole(Span),                    // ?
673
674    Tuple(Span, Vec<Arc<Expr>>),             // (e1, e2, e3)
675    List(Span, Vec<Arc<Expr>>),              // [e1, e2, e3]
676    Dict(Span, BTreeMap<Symbol, Arc<Expr>>), // {k1 = v1, k2 = v2}
677    RecordUpdate(Span, Arc<Expr>, BTreeMap<Symbol, Arc<Expr>>), // {base with {k1 = v1, ...}}
678
679    Var(Var),                         // x
680    App(Span, Arc<Expr>, Arc<Expr>),  // f x
681    Project(Span, Arc<Expr>, Symbol), // x.field
682    Lam(
683        Span,
684        Scope,
685        Var,
686        Option<TypeExpr>,
687        Vec<TypeConstraint>,
688        Arc<Expr>,
689    ), // λx → e
690    Let(Span, Var, Option<TypeExpr>, Arc<Expr>, Arc<Expr>), // let x = e1 in e2
691    LetRec(Span, Vec<(Var, Option<TypeExpr>, Arc<Expr>)>, Arc<Expr>), // let rec f = e1 and g = e2 in e3
692    Ite(Span, Arc<Expr>, Arc<Expr>, Arc<Expr>),                       // if e1 then e2 else e3
693    Match(Span, Arc<Expr>, Vec<(Pattern, Arc<Expr>)>),                // match e1 with patterns
694    Ann(Span, Arc<Expr>, TypeExpr),                                   // e is t
695}
696
697impl Expr {
698    pub fn span(&self) -> &Span {
699        match self {
700            Self::Bool(span, ..)
701            | Self::Uint(span, ..)
702            | Self::Int(span, ..)
703            | Self::Float(span, ..)
704            | Self::String(span, ..)
705            | Self::Uuid(span, ..)
706            | Self::DateTime(span, ..)
707            | Self::Hole(span, ..)
708            | Self::Tuple(span, ..)
709            | Self::List(span, ..)
710            | Self::Dict(span, ..)
711            | Self::RecordUpdate(span, ..)
712            | Self::Var(Var { span, .. })
713            | Self::App(span, ..)
714            | Self::Project(span, ..)
715            | Self::Lam(span, ..)
716            | Self::Let(span, ..)
717            | Self::LetRec(span, ..)
718            | Self::Ite(span, ..)
719            | Self::Match(span, ..)
720            | Self::Ann(span, ..) => span,
721        }
722    }
723
724    pub fn span_mut(&mut self) -> &mut Span {
725        match self {
726            Self::Bool(span, ..)
727            | Self::Uint(span, ..)
728            | Self::Int(span, ..)
729            | Self::Float(span, ..)
730            | Self::String(span, ..)
731            | Self::Uuid(span, ..)
732            | Self::DateTime(span, ..)
733            | Self::Hole(span, ..)
734            | Self::Tuple(span, ..)
735            | Self::List(span, ..)
736            | Self::Dict(span, ..)
737            | Self::RecordUpdate(span, ..)
738            | Self::Var(Var { span, .. })
739            | Self::App(span, ..)
740            | Self::Project(span, ..)
741            | Self::Lam(span, ..)
742            | Self::Let(span, ..)
743            | Self::LetRec(span, ..)
744            | Self::Ite(span, ..)
745            | Self::Match(span, ..)
746            | Self::Ann(span, ..) => span,
747        }
748    }
749
750    pub fn with_span_begin_end(&self, begin: Position, end: Position) -> Expr {
751        self.with_span(Span::from_begin_end(begin, end))
752    }
753
754    pub fn with_span_begin(&self, begin: Position) -> Expr {
755        let end = self.span().end;
756        self.with_span(Span::from_begin_end(begin, end))
757    }
758
759    pub fn with_span_end(&self, end: Position) -> Expr {
760        let begin = self.span().begin;
761        self.with_span(Span::from_begin_end(begin, end))
762    }
763
764    pub fn with_span(&self, span: Span) -> Expr {
765        match self {
766            Expr::Bool(_, x) => Expr::Bool(span, *x),
767            Expr::Uint(_, x) => Expr::Uint(span, *x),
768            Expr::Int(_, x) => Expr::Int(span, *x),
769            Expr::Float(_, x) => Expr::Float(span, *x),
770            Expr::String(_, x) => Expr::String(span, x.clone()),
771            Expr::Uuid(_, x) => Expr::Uuid(span, *x),
772            Expr::DateTime(_, x) => Expr::DateTime(span, *x),
773            Expr::Hole(_) => Expr::Hole(span),
774            Expr::Tuple(_, elems) => Expr::Tuple(span, elems.clone()),
775            Expr::List(_, elems) => Expr::List(span, elems.clone()),
776            Expr::Dict(_, kvs) => Expr::Dict(
777                span,
778                BTreeMap::from_iter(kvs.iter().map(|(k, v)| (k.clone(), v.clone()))),
779            ),
780            Expr::RecordUpdate(_, base, updates) => Expr::RecordUpdate(
781                span,
782                base.clone(),
783                BTreeMap::from_iter(updates.iter().map(|(k, v)| (k.clone(), v.clone()))),
784            ),
785            Expr::Var(var) => Expr::Var(Var {
786                span,
787                name: var.name.clone(),
788            }),
789            Expr::App(_, f, x) => Expr::App(span, f.clone(), x.clone()),
790            Expr::Project(_, base, field) => Expr::Project(span, base.clone(), field.clone()),
791            Expr::Lam(_, scope, param, ann, constraints, body) => Expr::Lam(
792                span,
793                scope.clone(),
794                param.clone(),
795                ann.clone(),
796                constraints.clone(),
797                body.clone(),
798            ),
799            Expr::Let(_, var, ann, def, body) => {
800                Expr::Let(span, var.clone(), ann.clone(), def.clone(), body.clone())
801            }
802            Expr::LetRec(_, bindings, body) => Expr::LetRec(
803                span,
804                bindings
805                    .iter()
806                    .map(|(var, ann, def)| (var.clone(), ann.clone(), def.clone()))
807                    .collect(),
808                body.clone(),
809            ),
810            Expr::Ite(_, cond, then, r#else) => {
811                Expr::Ite(span, cond.clone(), then.clone(), r#else.clone())
812            }
813            Expr::Match(_, scrutinee, arms) => Expr::Match(
814                span,
815                scrutinee.clone(),
816                arms.iter()
817                    .map(|(pat, expr)| (pat.clone(), expr.clone()))
818                    .collect(),
819            ),
820            Expr::Ann(_, expr, ann) => Expr::Ann(span, expr.clone(), ann.clone()),
821        }
822    }
823
824    pub fn reset_spans(&self) -> Expr {
825        match self {
826            Expr::Bool(_, x) => Expr::Bool(Span::default(), *x),
827            Expr::Uint(_, x) => Expr::Uint(Span::default(), *x),
828            Expr::Int(_, x) => Expr::Int(Span::default(), *x),
829            Expr::Float(_, x) => Expr::Float(Span::default(), *x),
830            Expr::String(_, x) => Expr::String(Span::default(), x.clone()),
831            Expr::Uuid(_, x) => Expr::Uuid(Span::default(), *x),
832            Expr::DateTime(_, x) => Expr::DateTime(Span::default(), *x),
833            Expr::Hole(_) => Expr::Hole(Span::default()),
834            Expr::Tuple(_, elems) => Expr::Tuple(
835                Span::default(),
836                elems.iter().map(|x| Arc::new(x.reset_spans())).collect(),
837            ),
838            Expr::List(_, elems) => Expr::List(
839                Span::default(),
840                elems.iter().map(|x| Arc::new(x.reset_spans())).collect(),
841            ),
842            Expr::Dict(_, kvs) => Expr::Dict(
843                Span::default(),
844                BTreeMap::from_iter(
845                    kvs.iter()
846                        .map(|(k, v)| (k.clone(), Arc::new(v.reset_spans()))),
847                ),
848            ),
849            Expr::RecordUpdate(_, base, updates) => Expr::RecordUpdate(
850                Span::default(),
851                Arc::new(base.reset_spans()),
852                BTreeMap::from_iter(
853                    updates
854                        .iter()
855                        .map(|(k, v)| (k.clone(), Arc::new(v.reset_spans()))),
856                ),
857            ),
858            Expr::Var(var) => Expr::Var(var.reset_spans()),
859            Expr::App(_, f, x) => Expr::App(
860                Span::default(),
861                Arc::new(f.reset_spans()),
862                Arc::new(x.reset_spans()),
863            ),
864            Expr::Project(_, base, field) => {
865                Expr::Project(Span::default(), Arc::new(base.reset_spans()), field.clone())
866            }
867            Expr::Lam(_, scope, param, ann, constraints, body) => Expr::Lam(
868                Span::default(),
869                scope.clone(),
870                param.reset_spans(),
871                ann.as_ref().map(TypeExpr::reset_spans),
872                constraints
873                    .iter()
874                    .map(|constraint| TypeConstraint {
875                        class: constraint.class.clone(),
876                        typ: constraint.typ.reset_spans(),
877                    })
878                    .collect(),
879                Arc::new(body.reset_spans()),
880            ),
881            Expr::Let(_, var, ann, def, body) => Expr::Let(
882                Span::default(),
883                var.reset_spans(),
884                ann.as_ref().map(|t| t.reset_spans()),
885                Arc::new(def.reset_spans()),
886                Arc::new(body.reset_spans()),
887            ),
888            Expr::LetRec(_, bindings, body) => Expr::LetRec(
889                Span::default(),
890                bindings
891                    .iter()
892                    .map(|(var, ann, def)| {
893                        (
894                            var.reset_spans(),
895                            ann.as_ref().map(TypeExpr::reset_spans),
896                            Arc::new(def.reset_spans()),
897                        )
898                    })
899                    .collect(),
900                Arc::new(body.reset_spans()),
901            ),
902            Expr::Ite(_, cond, then, r#else) => Expr::Ite(
903                Span::default(),
904                Arc::new(cond.reset_spans()),
905                Arc::new(then.reset_spans()),
906                Arc::new(r#else.reset_spans()),
907            ),
908            Expr::Match(_, scrutinee, arms) => Expr::Match(
909                Span::default(),
910                Arc::new(scrutinee.reset_spans()),
911                arms.iter()
912                    .map(|(pat, expr)| (pat.reset_spans(), Arc::new(expr.reset_spans())))
913                    .collect(),
914            ),
915            Expr::Ann(_, expr, ann) => Expr::Ann(
916                Span::default(),
917                Arc::new(expr.reset_spans()),
918                ann.reset_spans(),
919            ),
920        }
921    }
922}
923
924impl Display for Expr {
925    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
926        match self {
927            Self::Bool(_span, x) => x.fmt(f),
928            Self::Uint(_span, x) => x.fmt(f),
929            Self::Int(_span, x) => x.fmt(f),
930            Self::Float(_span, x) => x.fmt(f),
931            Self::String(_span, x) => write!(f, "{:?}", x),
932            Self::Uuid(_span, x) => x.fmt(f),
933            Self::DateTime(_span, x) => x.fmt(f),
934            Self::Hole(_span) => '?'.fmt(f),
935            Self::List(_span, xs) => {
936                '['.fmt(f)?;
937                for (i, x) in xs.iter().enumerate() {
938                    x.fmt(f)?;
939                    if i + 1 < xs.len() {
940                        ", ".fmt(f)?;
941                    }
942                }
943                ']'.fmt(f)
944            }
945            Self::Tuple(_span, xs) => {
946                '('.fmt(f)?;
947                for (i, x) in xs.iter().enumerate() {
948                    x.fmt(f)?;
949                    if i + 1 < xs.len() {
950                        ", ".fmt(f)?;
951                    }
952                }
953                ')'.fmt(f)
954            }
955            Self::Dict(_span, kvs) => {
956                '{'.fmt(f)?;
957                for (i, (k, v)) in kvs.iter().enumerate() {
958                    k.fmt(f)?;
959                    " = ".fmt(f)?;
960                    v.fmt(f)?;
961                    if i + 1 < kvs.len() {
962                        ", ".fmt(f)?;
963                    }
964                }
965                '}'.fmt(f)
966            }
967            Self::RecordUpdate(_span, base, kvs) => {
968                '{'.fmt(f)?;
969                base.fmt(f)?;
970                " with ".fmt(f)?;
971                '{'.fmt(f)?;
972                for (i, (k, v)) in kvs.iter().enumerate() {
973                    k.fmt(f)?;
974                    " = ".fmt(f)?;
975                    v.fmt(f)?;
976                    if i + 1 < kvs.len() {
977                        ", ".fmt(f)?;
978                    }
979                }
980                '}'.fmt(f)?;
981                '}'.fmt(f)
982            }
983            Self::Var(var) => var.fmt(f),
984            Self::App(_span, g, x) => {
985                g.fmt(f)?;
986                ' '.fmt(f)?;
987                match x.as_ref() {
988                    Self::Bool(..)
989                    | Self::Uint(..)
990                    | Self::Int(..)
991                    | Self::Float(..)
992                    | Self::String(..)
993                    | Self::List(..)
994                    | Self::Tuple(..)
995                    | Self::Dict(..)
996                    | Self::RecordUpdate(..)
997                    | Self::Project(..)
998                    | Self::Var(..) => x.fmt(f),
999                    _ => {
1000                        '('.fmt(f)?;
1001                        x.fmt(f)?;
1002                        ')'.fmt(f)
1003                    }
1004                }
1005            }
1006            Self::Lam(_span, _scope, param, ann, constraints, body) => {
1007                'λ'.fmt(f)?;
1008                if let Some(ann) = ann {
1009                    '('.fmt(f)?;
1010                    param.fmt(f)?;
1011                    " : ".fmt(f)?;
1012                    ann.fmt(f)?;
1013                    ')'.fmt(f)?;
1014                } else {
1015                    param.fmt(f)?;
1016                }
1017                if !constraints.is_empty() {
1018                    " where ".fmt(f)?;
1019                    for (i, constraint) in constraints.iter().enumerate() {
1020                        constraint.class.fmt(f)?;
1021                        ' '.fmt(f)?;
1022                        constraint.typ.fmt(f)?;
1023                        if i + 1 < constraints.len() {
1024                            ", ".fmt(f)?;
1025                        }
1026                    }
1027                }
1028                " → ".fmt(f)?;
1029                body.fmt(f)
1030            }
1031            Self::Let(_span, var, ann, def, body) => {
1032                "let ".fmt(f)?;
1033                var.fmt(f)?;
1034                if let Some(ann) = ann {
1035                    ": ".fmt(f)?;
1036                    ann.fmt(f)?;
1037                }
1038                " = ".fmt(f)?;
1039                def.fmt(f)?;
1040                " in ".fmt(f)?;
1041                body.fmt(f)
1042            }
1043            Self::LetRec(_span, bindings, body) => {
1044                "let rec ".fmt(f)?;
1045                for (idx, (var, ann, def)) in bindings.iter().enumerate() {
1046                    if idx > 0 {
1047                        " and ".fmt(f)?;
1048                    }
1049                    var.fmt(f)?;
1050                    if let Some(ann) = ann {
1051                        ": ".fmt(f)?;
1052                        ann.fmt(f)?;
1053                    }
1054                    " = ".fmt(f)?;
1055                    def.fmt(f)?;
1056                }
1057                " in ".fmt(f)?;
1058                body.fmt(f)
1059            }
1060            Self::Ite(_span, cond, then, r#else) => {
1061                "if ".fmt(f)?;
1062                cond.fmt(f)?;
1063                " then ".fmt(f)?;
1064                then.fmt(f)?;
1065                " else ".fmt(f)?;
1066                r#else.fmt(f)
1067            }
1068            Self::Match(_span, scrutinee, arms) => {
1069                "match ".fmt(f)?;
1070                scrutinee.fmt(f)?;
1071                ' '.fmt(f)?;
1072                for (i, (pat, expr)) in arms.iter().enumerate() {
1073                    "when ".fmt(f)?;
1074                    pat.fmt(f)?;
1075                    " -> ".fmt(f)?;
1076                    expr.fmt(f)?;
1077                    if i + 1 < arms.len() {
1078                        ' '.fmt(f)?;
1079                    }
1080                }
1081                Ok(())
1082            }
1083            Self::Project(_span, base, field) => {
1084                base.fmt(f)?;
1085                ".".fmt(f)?;
1086                field.fmt(f)
1087            }
1088            Self::Ann(_span, expr, ann) => {
1089                expr.fmt(f)?;
1090                " is ".fmt(f)?;
1091                ann.fmt(f)
1092            }
1093        }
1094    }
1095}