// mq_lang/ast/node.rs

1use super::{Program, TokenId};
2#[cfg(feature = "ast-json")]
3use crate::arena::ArenaId;
4use crate::{Ident, Shared, Token, arena::Arena, number::Number, range::Range, selector::Selector};
5#[cfg(feature = "ast-json")]
6use serde::{Deserialize, Serialize};
7use smallvec::SmallVec;
8use smol_str::SmolStr;
9use std::{
10    fmt::{self, Display, Formatter},
11    hash::{Hash, Hasher},
12};
13
14/// Represents a function parameter with an optional default value
15#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
16#[derive(PartialEq, PartialOrd, Debug, Clone)]
17pub struct Param {
18    pub ident: IdentWithToken,
19    pub default: Option<Shared<Node>>,
20    pub is_variadic: bool,
21}
22
23impl Display for Param {
24    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
25        if self.is_variadic {
26            write!(f, "*{}", self.ident)
27        } else {
28            write!(f, "{}", self.ident)
29        }
30    }
31}
32
33impl Param {
34    pub fn new(name: IdentWithToken) -> Self {
35        Self::with_default(name, None)
36    }
37
38    pub fn with_default(name: IdentWithToken, default_value: Option<Shared<Node>>) -> Self {
39        Self {
40            ident: name,
41            default: default_value,
42            is_variadic: false,
43        }
44    }
45
46    /// Creates a variadic parameter (e.g., `*args`)
47    pub fn variadic(name: IdentWithToken) -> Self {
48        Self {
49            ident: name,
50            default: None,
51            is_variadic: true,
52        }
53    }
54}
55
/// Parameter list; up to four parameters are stored inline before spilling to the heap.
pub type Params = SmallVec<[Param; 4]>;
/// Argument list of a call; up to four arguments are stored inline.
pub type Args = SmallVec<[Shared<Node>; 4]>;
/// One branch of an `If` expression: an optional first node and a second node.
/// NOTE(review): presumably (condition, body), with `None` marking an
/// unconditional branch — confirm against the parser.
pub type Cond = (Option<Shared<Node>>, Shared<Node>);
/// Branches of an `If` expression; up to four are stored inline.
pub type Branches = SmallVec<[Cond; 4]>;
/// Arms of a `Match` expression; up to four are stored inline.
pub type MatchArms = SmallVec<[MatchArm; 4]>;
61
62#[derive(PartialEq, PartialOrd, Debug, Clone)]
63#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
64pub struct Node {
65    #[cfg_attr(
66        feature = "ast-json",
67        serde(skip_serializing, skip_deserializing, default = "default_token_id")
68    )]
69    pub token_id: TokenId,
70    pub expr: Shared<Expr>,
71}
72
73#[cfg(feature = "ast-json")]
74fn default_token_id() -> TokenId {
75    ArenaId::new(0)
76}
77
78impl Node {
79    #[cfg(feature = "ast-json")]
80    pub fn to_json(&self) -> Result<String, serde_json::Error> {
81        serde_json::to_string_pretty(self)
82    }
83
84    #[cfg(feature = "ast-json")]
85    pub fn from_json(json_str: &str) -> Result<Self, serde_json::Error> {
86        serde_json::from_str(json_str)
87    }
88
89    pub fn range(&self, arena: Shared<Arena<Shared<Token>>>) -> Range {
90        match &*self.expr {
91            Expr::Block(program)
92            | Expr::Def(_, _, program)
93            | Expr::Fn(_, program)
94            | Expr::While(_, program)
95            | Expr::Loop(program)
96            | Expr::Module(_, program)
97            | Expr::Foreach(_, _, program) => {
98                let start = program
99                    .first()
100                    .map(|node| node.range(Shared::clone(&arena)).start)
101                    .unwrap_or_default();
102                let end = program
103                    .last()
104                    .map(|node| node.range(Shared::clone(&arena)).end)
105                    .unwrap_or_default();
106                Range { start, end }
107            }
108            Expr::Call(_, args) => {
109                let start = args
110                    .first()
111                    .map(|node| node.range(Shared::clone(&arena)).start)
112                    .unwrap_or_default();
113                let end = args
114                    .last()
115                    .map(|node| node.range(Shared::clone(&arena)).end)
116                    .unwrap_or_default();
117                Range { start, end }
118            }
119            Expr::CallDynamic(callable, args) => {
120                let start = callable.range(Shared::clone(&arena)).start;
121                let end = args
122                    .last()
123                    .map(|node| node.range(Shared::clone(&arena)).end)
124                    .unwrap_or_else(|| callable.range(Shared::clone(&arena)).end);
125                Range { start, end }
126            }
127            Expr::Macro(_, params, block) => {
128                let start = params
129                    .first()
130                    .and_then(|param| param.ident.token.as_ref().map(|t| t.range))
131                    .unwrap_or(block.range(Shared::clone(&arena)))
132                    .start;
133                let end = block.range(arena).end;
134                Range { start, end }
135            }
136            Expr::Let(_, node)
137            | Expr::Var(_, node)
138            | Expr::Assign(_, node)
139            | Expr::Quote(node)
140            | Expr::Unquote(node) => node.range(Shared::clone(&arena)),
141            Expr::If(nodes) => {
142                if let (Some(first), Some(last)) = (nodes.first(), nodes.last()) {
143                    let start = first.1.range(Shared::clone(&arena));
144                    let end = last.1.range(Shared::clone(&arena));
145                    Range {
146                        start: start.start,
147                        end: end.end,
148                    }
149                } else {
150                    // Fallback to token range if no branches exist
151                    arena[self.token_id].range
152                }
153            }
154            Expr::Match(value, arms) => {
155                let start = value.range(Shared::clone(&arena)).start;
156                let end = arms
157                    .last()
158                    .map(|arm| arm.body.range(Shared::clone(&arena)).end)
159                    .unwrap_or_else(|| arena[self.token_id].range.end);
160                Range { start, end }
161            }
162            Expr::Paren(node) => node.range(Shared::clone(&arena)),
163            Expr::Try(try_expr, catch_expr) => {
164                let start = try_expr.range(Shared::clone(&arena)).start;
165                let end = catch_expr.range(Shared::clone(&arena)).end;
166                Range { start, end }
167            }
168            Expr::And(exprs) | Expr::Or(exprs) => {
169                if let (Some(first), Some(last)) = (exprs.first(), exprs.last()) {
170                    Range {
171                        start: first.range(Shared::clone(&arena)).start,
172                        end: last.range(Shared::clone(&arena)).end,
173                    }
174                } else {
175                    arena[self.token_id].range
176                }
177            }
178            Expr::Break(Some(value_node)) => {
179                let start = arena[self.token_id].range.start;
180                let end = value_node.range(Shared::clone(&arena)).end;
181                Range { start, end }
182            }
183            Expr::Literal(_)
184            | Expr::Ident(_)
185            | Expr::Selector(_)
186            | Expr::Include(_)
187            | Expr::Import(_)
188            | Expr::InterpolatedString(_)
189            | Expr::QualifiedAccess(_, _)
190            | Expr::Nodes
191            | Expr::Self_
192            | Expr::Break(None)
193            | Expr::Continue => arena[self.token_id].range,
194        }
195    }
196
197    pub fn is_nodes(&self) -> bool {
198        matches!(*self.expr, Expr::Nodes)
199    }
200}
201
202#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
203#[derive(PartialEq, Debug, Eq, Clone)]
204pub struct IdentWithToken {
205    pub name: Ident,
206    #[cfg_attr(feature = "ast-json", serde(skip_serializing_if = "Option::is_none", default))]
207    pub token: Option<Shared<Token>>,
208}
209
210impl Hash for IdentWithToken {
211    fn hash<H: Hasher>(&self, state: &mut H) {
212        self.name.hash(state);
213    }
214}
215
216impl Ord for IdentWithToken {
217    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
218        self.name.cmp(&other.name)
219    }
220}
221
222impl PartialOrd for IdentWithToken {
223    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
224        Some(self.cmp(other))
225    }
226}
227
228impl IdentWithToken {
229    pub fn new(name: &str) -> Self {
230        Self::new_with_token(name, None)
231    }
232
233    pub fn new_with_token(name: &str, token: Option<Shared<Token>>) -> Self {
234        Self {
235            name: name.into(),
236            token,
237        }
238    }
239}
240
241impl Display for IdentWithToken {
242    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
243        write!(f, "{}", self.name)
244    }
245}
246
247#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
248#[derive(Debug, Clone, PartialOrd, PartialEq)]
249pub enum StringSegment {
250    Text(String),
251    Expr(Shared<Node>),
252    Env(SmolStr),
253    Self_,
254}
255
256#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
257#[derive(PartialEq, PartialOrd, Debug, Clone)]
258pub enum Pattern {
259    Literal(Literal),
260    Ident(IdentWithToken),
261    Wildcard,
262    Array(Vec<Pattern>),
263    ArrayRest(Vec<Pattern>, IdentWithToken), // patterns before .., rest binding
264    Dict(Vec<(IdentWithToken, Pattern)>),
265    Type(Ident),      // :string, :number, etc.
266    Or(Vec<Pattern>), // p1 || p2 || p3
267}
268
269#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
270#[derive(PartialEq, PartialOrd, Debug, Clone)]
271pub struct MatchArm {
272    pub pattern: Pattern,
273    pub guard: Option<Shared<Node>>,
274    pub body: Shared<Node>,
275}
276
277#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
278#[derive(PartialEq, PartialOrd, Debug, Clone)]
279pub enum Literal {
280    String(String),
281    Number(Number),
282    Symbol(Ident),
283    Bool(bool),
284    None,
285}
286
287impl From<&str> for Literal {
288    fn from(s: &str) -> Self {
289        Literal::String(s.to_owned())
290    }
291}
292
293#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
294#[derive(PartialEq, PartialOrd, Debug, Clone)]
295pub enum AccessTarget {
296    Call(IdentWithToken, Args),
297    Ident(IdentWithToken),
298}
299
300impl Display for Literal {
301    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
302        match self {
303            Literal::String(s) => write!(f, "{}", s),
304            Literal::Number(n) => write!(f, "{}", n),
305            Literal::Symbol(i) => write!(f, "{}", i),
306            Literal::Bool(b) => write!(f, "{}", b),
307            Literal::None => write!(f, "none"),
308        }
309    }
310}
311
312#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
313#[derive(PartialEq, PartialOrd, Debug, Clone)]
314pub enum Expr {
315    Block(Program),
316    Call(IdentWithToken, Args),
317    CallDynamic(Shared<Node>, Args),
318    Def(IdentWithToken, Params, Program),
319    Macro(IdentWithToken, Params, Shared<Node>),
320    Fn(Params, Program),
321    Let(Pattern, Shared<Node>),
322    Loop(Program),
323    Var(Pattern, Shared<Node>),
324    Assign(IdentWithToken, Shared<Node>),
325    And(Vec<Shared<Node>>),
326    Or(Vec<Shared<Node>>),
327    Literal(Literal),
328    Ident(IdentWithToken),
329    InterpolatedString(Vec<StringSegment>),
330    Selector(Selector),
331    While(Shared<Node>, Program),
332    Foreach(IdentWithToken, Shared<Node>, Program),
333    If(Branches),
334    Match(Shared<Node>, MatchArms),
335    Include(Literal),
336    Import(Literal),
337    Module(IdentWithToken, Program),
338    QualifiedAccess(Vec<IdentWithToken>, AccessTarget),
339    Self_,
340    Nodes,
341    Paren(Shared<Node>),
342    Quote(Shared<Node>),
343    Unquote(Shared<Node>),
344    Try(Shared<Node>, Shared<Node>),
345    Break(Option<Shared<Node>>),
346    Continue,
347}
348
#[cfg(feature = "debugger")]
impl Display for Expr {
    /// Renders call expressions as `callee(arg1, arg2, ...)`.
    ///
    /// Every other expression kind renders as the empty string.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        /// Writes the comma-separated arguments followed by the closing `)`.
        fn write_args(f: &mut Formatter<'_>, args: &Args) -> fmt::Result {
            let mut separator = "";
            for arg in args.iter() {
                write!(f, "{separator}{}", arg.expr)?;
                separator = ", ";
            }
            write!(f, ")")
        }

        match self {
            Expr::Call(ident, args) => {
                write!(f, "{ident}(")?;
                write_args(f, args)
            }
            Expr::CallDynamic(callable, args) => {
                write!(f, "{}(", callable.expr)?;
                write_args(f, args)
            }
            _ => write!(f, ""),
        }
    }
}
377
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Position, TokenKind, arena::ArenaId};
    use rstest::rstest;
    use smallvec::smallvec;

    /// Builds a token covering `range`; its kind and module id are arbitrary
    /// placeholders that the range computation never inspects.
    fn create_token(range: Range) -> Shared<Token> {
        Shared::new(Token {
            range,
            kind: TokenKind::Eof,
            module_id: ArenaId::new(0),
        })
    }

    /// Builds a string-literal node attached to the token with id `token_id`.
    fn literal_node(token_id: TokenId, text: &str) -> Shared<Node> {
        Shared::new(Node {
            token_id,
            expr: Shared::new(Expr::Literal(Literal::from(text))),
        })
    }

    /// Checks `Node::range` for each composite expression kind.
    ///
    /// Each case supplies the expression, the `(index, range)` token fixtures
    /// to allocate in the arena, and the expected overall range.
    #[rstest]
    #[case::call_dynamic(
        Expr::CallDynamic(
            literal_node(ArenaId::new(1), "callee"),
            smallvec![
                literal_node(ArenaId::new(0), "arg1"),
                literal_node(ArenaId::new(1), "arg2"),
            ]
        ),
        vec![
            (0, Range { start: Position::new(1, 1), end: Position::new(1, 5) }),
            (1, Range { start: Position::new(2, 1), end: Position::new(2, 5) }),
        ],
        Range { start: Position::new(2, 1), end: Position::new(2, 5) }
    )]
    #[case::match_expr(
        Expr::Match(
            literal_node(ArenaId::new(0), "val"),
            smallvec![
                MatchArm {
                    pattern: Pattern::Literal(Literal::from("a")),
                    guard: None,
                    body: literal_node(ArenaId::new(1), "body1"),
                },
                MatchArm {
                    pattern: Pattern::Literal(Literal::from("b")),
                    guard: None,
                    body: literal_node(ArenaId::new(2), "body2"),
                },
            ]
        ),
        vec![
            (0, Range { start: Position::new(10, 1), end: Position::new(10, 5) }),
            (1, Range { start: Position::new(11, 1), end: Position::new(11, 5) }),
            (2, Range { start: Position::new(12, 1), end: Position::new(12, 5) }),
        ],
        Range { start: Position::new(10, 1), end: Position::new(12, 5) }
    )]
    #[case::try_expr(
        Expr::Try(
            literal_node(ArenaId::new(0), "try"),
            literal_node(ArenaId::new(1), "catch")
        ),
        vec![
            (0, Range { start: Position::new(20, 1), end: Position::new(20, 5) }),
            (1, Range { start: Position::new(21, 1), end: Position::new(21, 5) }),
        ],
        Range { start: Position::new(20, 1), end: Position::new(21, 5) }
    )]
    #[case::let_expr(
        Expr::Let(
            Pattern::Ident(IdentWithToken::new("x")),
            literal_node(ArenaId::new(0), "letval")
        ),
        vec![
            (0, Range { start: Position::new(30, 1), end: Position::new(30, 5) }),
        ],
        Range { start: Position::new(30, 1), end: Position::new(30, 5) }
    )]
    #[case::paren(
        Expr::Paren(literal_node(ArenaId::new(0), "paren")),
        vec![
            (0, Range { start: Position::new(40, 1), end: Position::new(40, 5) }),
        ],
        Range { start: Position::new(40, 1), end: Position::new(40, 5) }
    )]
    #[case::block(
        Expr::Block(vec![
            literal_node(ArenaId::new(0), "block1"),
            literal_node(ArenaId::new(1), "block2"),
        ]),
        vec![
            (0, Range { start: Position::new(50, 1), end: Position::new(50, 5) }),
            (1, Range { start: Position::new(51, 1), end: Position::new(51, 5) }),
        ],
        Range { start: Position::new(50, 1), end: Position::new(51, 5) }
    )]
    #[case::def(
        Expr::Def(
            IdentWithToken::new("f"),
            smallvec![],
            vec![
                literal_node(ArenaId::new(0), "def1"),
                literal_node(ArenaId::new(1), "def2"),
            ]
        ),
        vec![
            (0, Range { start: Position::new(60, 1), end: Position::new(60, 5) }),
            (1, Range { start: Position::new(61, 1), end: Position::new(61, 5) }),
        ],
        Range { start: Position::new(60, 1), end: Position::new(61, 5) }
    )]
    #[case::fn_expr(
        Expr::Fn(
            smallvec![],
            vec![
                literal_node(ArenaId::new(0), "fn1"),
                literal_node(ArenaId::new(1), "fn2"),
            ]
        ),
        vec![
            (0, Range { start: Position::new(70, 1), end: Position::new(70, 5) }),
            (1, Range { start: Position::new(71, 1), end: Position::new(71, 5) }),
        ],
        Range { start: Position::new(70, 1), end: Position::new(71, 5) }
    )]
    #[case::while_expr(
        Expr::While(
            literal_node(ArenaId::new(0), "cond"),
            vec![
                literal_node(ArenaId::new(1), "while1"),
                literal_node(ArenaId::new(2), "while2"),
            ]
        ),
        vec![
            (0, Range { start: Position::new(81, 1), end: Position::new(81, 5) }),
            (1, Range { start: Position::new(82, 1), end: Position::new(82, 5) }),
            (2, Range { start: Position::new(82, 1), end: Position::new(82, 5) }),
        ],
        Range { start: Position::new(82, 1), end: Position::new(82, 5) }
    )]
    #[case::foreach(
        Expr::Foreach(
            IdentWithToken::new("item"),
            literal_node(ArenaId::new(0), "iter"),
            vec![
                literal_node(ArenaId::new(1), "foreach1"),
                literal_node(ArenaId::new(2), "foreach2"),
            ]
        ),
        vec![
            (0, Range { start: Position::new(101, 1), end: Position::new(101, 5) }),
            (1, Range { start: Position::new(102, 1), end: Position::new(102, 5) }),
            (2, Range { start: Position::new(102, 1), end: Position::new(102, 5) }),
        ],
        Range { start: Position::new(102, 1), end: Position::new(102, 5) }
    )]
    #[case::if_expr(
        Expr::If(smallvec![
            (
                Some(literal_node(ArenaId::new(0), "cond1")),
                literal_node(ArenaId::new(1), "if1")
            ),
            (
                Some(literal_node(ArenaId::new(2), "cond2")),
                literal_node(ArenaId::new(3), "if2")
            ),
        ]),
        vec![
            (0, Range { start: Position::new(111, 1), end: Position::new(111, 5) }),
            (1, Range { start: Position::new(113, 1), end: Position::new(113, 5) }),
            (2, Range { start: Position::new(114, 1), end: Position::new(115, 5) }),
            (3, Range { start: Position::new(116, 1), end: Position::new(117, 5) }),
        ],
        Range { start: Position::new(113, 1), end: Position::new(117, 5) }
    )]
    #[case::call(
        Expr::Call(
            IdentWithToken::new("func"),
            smallvec![
                literal_node(ArenaId::new(0), "arg1"),
                literal_node(ArenaId::new(1), "arg2"),
            ]
        ),
        vec![
            (0, Range { start: Position::new(120, 1), end: Position::new(120, 5) }),
            (1, Range { start: Position::new(121, 1), end: Position::new(121, 5) }),
        ],
        Range { start: Position::new(120, 1), end: Position::new(121, 5) }
    )]
    fn test_node_range_various_exprs(
        #[case] expr: Expr,
        #[case] token_ranges: Vec<(usize, Range)>,
        #[case] expected: Range,
    ) {
        let mut arena = Arena::new(150);
        for (position, (index, range)) in token_ranges.iter().enumerate() {
            // Tokens are allocated sequentially, so a fixture whose declared
            // index differs from its position would silently test the wrong
            // token; fail loudly instead.
            assert_eq!(
                *index, position,
                "token fixtures must be listed in allocation order"
            );
            let _ = arena.alloc(create_token(*range));
        }

        let node = Node {
            token_id: ArenaId::new(0),
            expr: Shared::new(expr),
        };
        assert_eq!(node.range(Shared::new(arena)), expected);
    }
}
665}