Skip to main content

mq_lang/ast/
node.rs

1use super::{Program, TokenId};
2#[cfg(feature = "ast-json")]
3use crate::arena::ArenaId;
4use crate::{Ident, Shared, Token, arena::Arena, number::Number, range::Range, selector::Selector};
5#[cfg(feature = "ast-json")]
6use serde::{Deserialize, Serialize};
7use smallvec::SmallVec;
8use smol_str::SmolStr;
9use std::{
10    fmt::{self, Display, Formatter},
11    hash::{Hash, Hasher},
12};
13
/// Represents a function parameter with an optional default value.
///
/// A parameter may also be variadic (see [`Param::variadic`]).
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(PartialEq, PartialOrd, Debug, Clone)]
pub struct Param {
    /// The parameter's name.
    pub ident: IdentWithToken,
    /// Default value expression, if one was given (see [`Param::with_default`]).
    pub default: Option<Shared<Node>>,
    /// `true` for a variadic parameter (e.g. `*args`); `Display` prefixes it with `*`.
    pub is_variadic: bool,
}
22
23impl Display for Param {
24    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
25        if self.is_variadic {
26            write!(f, "*{}", self.ident)
27        } else {
28            write!(f, "{}", self.ident)
29        }
30    }
31}
32
33impl Param {
34    pub fn new(name: IdentWithToken) -> Self {
35        Self::with_default(name, None)
36    }
37
38    pub fn with_default(name: IdentWithToken, default_value: Option<Shared<Node>>) -> Self {
39        Self {
40            ident: name,
41            default: default_value,
42            is_variadic: false,
43        }
44    }
45
46    /// Creates a variadic parameter (e.g., `*args`)
47    pub fn variadic(name: IdentWithToken) -> Self {
48        Self {
49            ident: name,
50            default: None,
51            is_variadic: true,
52        }
53    }
54}
55
/// Parameter list of a named function, macro, or anonymous function.
pub type Params = SmallVec<[Param; 4]>;
/// Argument expressions of a call.
pub type Args = SmallVec<[Shared<Node>; 4]>;
/// One `if` branch as `(condition, body)`.
///
/// NOTE(review): a `None` condition presumably marks an `else` branch — confirm against the parser.
pub type Cond = (Option<Shared<Node>>, Shared<Node>);
/// The branches of an [`Expr::If`].
pub type Branches = SmallVec<[Cond; 4]>;
/// The arms of an [`Expr::Match`].
pub type MatchArms = SmallVec<[MatchArm; 4]>;
61
/// An AST node: an expression together with the id of its associated token.
#[derive(PartialEq, PartialOrd, Debug, Clone)]
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
pub struct Node {
    /// Id of this node's token in the token arena (read by [`Node::range`]).
    ///
    /// Not part of the JSON form: it is skipped when serializing and set to
    /// `default_token_id()` when deserializing.
    #[cfg_attr(
        feature = "ast-json",
        serde(skip_serializing, skip_deserializing, default = "default_token_id")
    )]
    pub token_id: TokenId,
    /// The expression held by this node.
    pub expr: Shared<Expr>,
}
72
/// Token id given to nodes deserialized from JSON, where `token_id` is not stored.
///
/// NOTE(review): every deserialized node therefore refers to arena slot 0, so
/// `Node::range` on such a tree reports that slot's token range for leaf nodes.
#[cfg(feature = "ast-json")]
fn default_token_id() -> TokenId {
    ArenaId::new(0)
}
77
78impl Node {
79    #[cfg(feature = "ast-json")]
80    pub fn to_json(&self) -> Result<String, serde_json::Error> {
81        serde_json::to_string_pretty(self)
82    }
83
84    #[cfg(feature = "ast-json")]
85    pub fn from_json(json_str: &str) -> Result<Self, serde_json::Error> {
86        serde_json::from_str(json_str)
87    }
88
89    pub fn range(&self, arena: Shared<Arena<Shared<Token>>>) -> Range {
90        match &*self.expr {
91            Expr::Block(program)
92            | Expr::Def(_, _, program)
93            | Expr::Fn(_, program)
94            | Expr::While(_, program)
95            | Expr::Loop(program)
96            | Expr::Module(_, program)
97            | Expr::Foreach(_, _, program) => {
98                let start = program
99                    .first()
100                    .map(|node| node.range(Shared::clone(&arena)).start)
101                    .unwrap_or_default();
102                let end = program
103                    .last()
104                    .map(|node| node.range(Shared::clone(&arena)).end)
105                    .unwrap_or_default();
106                Range { start, end }
107            }
108            Expr::Call(_, args) => {
109                let start = args
110                    .first()
111                    .map(|node| node.range(Shared::clone(&arena)).start)
112                    .unwrap_or_default();
113                let end = args
114                    .last()
115                    .map(|node| node.range(Shared::clone(&arena)).end)
116                    .unwrap_or_default();
117                Range { start, end }
118            }
119            Expr::CallDynamic(callable, args) => {
120                let start = callable.range(Shared::clone(&arena)).start;
121                let end = args
122                    .last()
123                    .map(|node| node.range(Shared::clone(&arena)).end)
124                    .unwrap_or_else(|| callable.range(Shared::clone(&arena)).end);
125                Range { start, end }
126            }
127            Expr::Macro(_, params, block) => {
128                let start = params
129                    .first()
130                    .and_then(|param| param.ident.token.as_ref().map(|t| t.range))
131                    .unwrap_or(block.range(Shared::clone(&arena)))
132                    .start;
133                let end = block.range(arena).end;
134                Range { start, end }
135            }
136            Expr::Let(_, node)
137            | Expr::Var(_, node)
138            | Expr::Assign(_, node)
139            | Expr::Quote(node)
140            | Expr::Unquote(node) => node.range(Shared::clone(&arena)),
141            Expr::If(nodes) => {
142                if let (Some(first), Some(last)) = (nodes.first(), nodes.last()) {
143                    let start = first.1.range(Shared::clone(&arena));
144                    let end = last.1.range(Shared::clone(&arena));
145                    Range {
146                        start: start.start,
147                        end: end.end,
148                    }
149                } else {
150                    // Fallback to token range if no branches exist
151                    arena[self.token_id].range
152                }
153            }
154            Expr::Match(value, arms) => {
155                let start = value.range(Shared::clone(&arena)).start;
156                let end = arms
157                    .last()
158                    .map(|arm| arm.body.range(Shared::clone(&arena)).end)
159                    .unwrap_or_else(|| arena[self.token_id].range.end);
160                Range { start, end }
161            }
162            Expr::Paren(node) => node.range(Shared::clone(&arena)),
163            Expr::Try(try_expr, catch_expr) => {
164                let start = try_expr.range(Shared::clone(&arena)).start;
165                let end = catch_expr.range(Shared::clone(&arena)).end;
166                Range { start, end }
167            }
168            Expr::And(expr1, expr2) | Expr::Or(expr1, expr2) => {
169                let start = expr1.range(Shared::clone(&arena)).start;
170                let end = expr2.range(Shared::clone(&arena)).end;
171                Range { start, end }
172            }
173            Expr::Break(Some(value_node)) => {
174                let start = arena[self.token_id].range.start;
175                let end = value_node.range(Shared::clone(&arena)).end;
176                Range { start, end }
177            }
178            Expr::Literal(_)
179            | Expr::Ident(_)
180            | Expr::Selector(_)
181            | Expr::Include(_)
182            | Expr::Import(_)
183            | Expr::InterpolatedString(_)
184            | Expr::QualifiedAccess(_, _)
185            | Expr::Nodes
186            | Expr::Self_
187            | Expr::Break(None)
188            | Expr::Continue => arena[self.token_id].range,
189        }
190    }
191
192    pub fn is_nodes(&self) -> bool {
193        matches!(*self.expr, Expr::Nodes)
194    }
195}
196
/// An identifier name, optionally paired with an associated source token.
///
/// The derived `PartialEq`/`Eq` compare both `name` and `token`, whereas the
/// hand-written `Hash`, `Ord`, and `PartialOrd` impls look only at `name`.
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(PartialEq, Debug, Eq, Clone)]
pub struct IdentWithToken {
    /// The identifier text.
    pub name: Ident,
    /// Associated token, if any; omitted from JSON output when `None`.
    #[cfg_attr(feature = "ast-json", serde(skip_serializing_if = "Option::is_none", default))]
    pub token: Option<Shared<Token>>,
}
204
205impl Hash for IdentWithToken {
206    fn hash<H: Hasher>(&self, state: &mut H) {
207        self.name.hash(state);
208    }
209}
210
211impl Ord for IdentWithToken {
212    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
213        self.name.cmp(&other.name)
214    }
215}
216
217impl PartialOrd for IdentWithToken {
218    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
219        Some(self.cmp(other))
220    }
221}
222
223impl IdentWithToken {
224    pub fn new(name: &str) -> Self {
225        Self::new_with_token(name, None)
226    }
227
228    pub fn new_with_token(name: &str, token: Option<Shared<Token>>) -> Self {
229        Self {
230            name: name.into(),
231            token,
232        }
233    }
234}
235
236impl Display for IdentWithToken {
237    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
238        write!(f, "{}", self.name)
239    }
240}
241
/// One segment of an interpolated string (see [`Expr::InterpolatedString`]).
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialOrd, PartialEq)]
pub enum StringSegment {
    /// Literal text.
    Text(String),
    /// An embedded expression.
    Expr(Shared<Node>),
    /// A named reference; presumably an environment variable name — confirm.
    Env(SmolStr),
    /// A reference to `self` (presumably the current input value — confirm).
    Self_,
}
250
/// A pattern used by `match` arms and by `let`/`var` bindings ([`Expr::Let`], [`Expr::Var`]).
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(PartialEq, PartialOrd, Debug, Clone)]
pub enum Pattern {
    /// Matches a literal value.
    Literal(Literal),
    /// An identifier pattern (presumably binds the matched value to this name — confirm).
    Ident(IdentWithToken),
    /// The wildcard pattern (presumably `_`).
    Wildcard,
    /// An array pattern with one sub-pattern per element.
    Array(Vec<Pattern>),
    /// An array pattern with a rest binding: the patterns before `..`, then the rest binding.
    ArrayRest(Vec<Pattern>, IdentWithToken), // patterns before .., rest binding
    /// A dictionary pattern: `(key, sub-pattern)` pairs.
    Dict(Vec<(IdentWithToken, Pattern)>),
    /// A type pattern such as `:string` or `:number`.
    Type(Ident), // :string, :number, etc.
}
262
/// One arm of an [`Expr::Match`].
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(PartialEq, PartialOrd, Debug, Clone)]
pub struct MatchArm {
    /// Pattern the matched value is tested against.
    pub pattern: Pattern,
    /// Optional guard expression (presumably must hold for the arm to apply — confirm).
    pub guard: Option<Shared<Node>>,
    /// The arm's body; `Node::range` ends a `match` at the last arm's body.
    pub body: Shared<Node>,
}
270
/// A literal value in the AST.
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(PartialEq, PartialOrd, Debug, Clone)]
pub enum Literal {
    /// A string literal.
    String(String),
    /// A numeric literal.
    Number(Number),
    /// A symbol literal, stored as an identifier.
    Symbol(Ident),
    /// A boolean literal.
    Bool(bool),
    /// The `none` value (displayed as `none`).
    None,
}
280
281impl From<&str> for Literal {
282    fn from(s: &str) -> Self {
283        Literal::String(s.to_owned())
284    }
285}
286
/// What an [`Expr::QualifiedAccess`] reaches through its qualifying path.
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(PartialEq, PartialOrd, Debug, Clone)]
pub enum AccessTarget {
    /// A call of the named function with the given arguments.
    Call(IdentWithToken, Args),
    /// A plain identifier.
    Ident(IdentWithToken),
}
293
294impl Display for Literal {
295    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), fmt::Error> {
296        match self {
297            Literal::String(s) => write!(f, "{}", s),
298            Literal::Number(n) => write!(f, "{}", n),
299            Literal::Symbol(i) => write!(f, "{}", i),
300            Literal::Bool(b) => write!(f, "{}", b),
301            Literal::None => write!(f, "none"),
302        }
303    }
304}
305
/// An expression in the AST.
#[cfg_attr(feature = "ast-json", derive(Serialize, Deserialize))]
#[derive(PartialEq, PartialOrd, Debug, Clone)]
pub enum Expr {
    /// A sequence of expressions.
    Block(Program),
    /// A call of a named function: `(name, arguments)`.
    Call(IdentWithToken, Args),
    /// A call whose callee is itself an expression: `(callee, arguments)`.
    CallDynamic(Shared<Node>, Args),
    /// A named function definition: `(name, parameters, body)`.
    Def(IdentWithToken, Params, Program),
    /// A macro definition: `(name, parameters, body)`.
    Macro(IdentWithToken, Params, Shared<Node>),
    /// An anonymous function: `(parameters, body)`.
    Fn(Params, Program),
    /// A `let` binding: `(pattern, value)`.
    Let(Pattern, Shared<Node>),
    /// A loop with the given body and no condition.
    Loop(Program),
    /// A `var` binding: `(pattern, value)`.
    Var(Pattern, Shared<Node>),
    /// An assignment: `(target name, value)`.
    Assign(IdentWithToken, Shared<Node>),
    /// Logical "and" of two operands.
    And(Shared<Node>, Shared<Node>),
    /// Logical "or" of two operands.
    Or(Shared<Node>, Shared<Node>),
    /// A literal value.
    Literal(Literal),
    /// An identifier reference.
    Ident(IdentWithToken),
    /// A string built from text and embedded segments.
    InterpolatedString(Vec<StringSegment>),
    /// A selector expression (see [`Selector`]).
    Selector(Selector),
    /// A conditional loop: `(condition, body)`.
    While(Shared<Node>, Program),
    /// Iteration: `(binding name, iterable expression, body)`.
    Foreach(IdentWithToken, Shared<Node>, Program),
    /// A conditional with one or more branches.
    If(Branches),
    /// A pattern match: `(matched value, arms)`.
    Match(Shared<Node>, MatchArms),
    /// An `include` directive; the literal presumably names what is included — confirm.
    Include(Literal),
    /// An `import` directive; the literal presumably names the module — confirm.
    Import(Literal),
    /// A module definition: `(name, body)`.
    Module(IdentWithToken, Program),
    /// Access through a qualifying path: `(path segments, target)`.
    QualifiedAccess(Vec<IdentWithToken>, AccessTarget),
    /// The `self` expression.
    Self_,
    /// The `Nodes` expression (see `Node::is_nodes`).
    Nodes,
    /// A parenthesized expression.
    Paren(Shared<Node>),
    /// A quoted expression (presumably for macros — confirm).
    Quote(Shared<Node>),
    /// An unquoted expression inside a quote (presumably for macros — confirm).
    Unquote(Shared<Node>),
    /// A try/catch: `(try expression, catch expression)`.
    Try(Shared<Node>, Shared<Node>),
    /// A `break`, optionally carrying a value.
    Break(Option<Shared<Node>>),
    /// A `continue`.
    Continue,
}
342
#[cfg(feature = "debugger")]
impl Display for Expr {
    /// Renders call expressions as `callee(arg, ...)` for the debugger; every
    /// other expression renders as an empty string.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        /// Writes a parenthesized, comma-separated list of argument expressions.
        fn write_arg_list(f: &mut Formatter<'_>, args: &Args) -> fmt::Result {
            f.write_str("(")?;
            for (index, arg) in args.iter().enumerate() {
                let separator = if index == 0 { "" } else { ", " };
                write!(f, "{separator}{}", arg.expr)?;
            }
            f.write_str(")")
        }

        match self {
            Self::Call(ident, args) => {
                write!(f, "{}", ident)?;
                write_arg_list(f, args)
            }
            Self::CallDynamic(callable, args) => {
                write!(f, "{}", callable.expr)?;
                write_arg_list(f, args)
            }
            _ => Ok(()),
        }
    }
}
371
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Position, TokenKind, arena::ArenaId};
    use rstest::rstest;
    use smallvec::smallvec;

    /// Builds a token covering `range`; its kind (`Eof`) and module id are
    /// placeholders because `Node::range` only reads a token's range.
    fn create_token(range: Range) -> Shared<Token> {
        Shared::new(Token {
            range,
            kind: TokenKind::Eof,
            module_id: ArenaId::new(0),
        })
    }

    // Each case supplies:
    // - `expr`: the expression under test; the test wraps it in a `Node` whose
    //   own token id is 0,
    // - `token_ranges`: ranges allocated into the arena in order; the `usize` is
    //   documentation only (the loop ignores it) and presumably equals the id
    //   `arena.alloc` assigns,
    // - `expected`: the range `Node::range` must return.
    #[rstest]
    // CallDynamic: from the callee (token 1) to the last argument (token 1).
    #[case(
        Expr::CallDynamic(
            Shared::new(Node {
                token_id: ArenaId::new(1),
                expr: Shared::new(Expr::Literal(Literal::String("callee".to_string()))),
            }),
            smallvec![
                Shared::new(Node {
                    token_id: ArenaId::new(0),
                    expr: Shared::new(Expr::Literal(Literal::String("arg1".to_string()))),
                }),
                Shared::new(Node {
                    token_id: ArenaId::new(1),
                    expr: Shared::new(Expr::Literal(Literal::String("arg2".to_string()))),
                }),
            ]
        ),
        vec![
            (0, Range { start: Position::new(1, 1), end: Position::new(1, 5) }),
            (1, Range { start: Position::new(2, 1), end: Position::new(2, 5) }),
        ],
        Range { start: Position::new(2, 1), end: Position::new(2, 5) }
    )]
    // Match: from the matched value (token 0) to the last arm body (token 2).
    #[case(
        Expr::Match(
            Shared::new(Node {
                token_id: ArenaId::new(0),
                expr: Shared::new(Expr::Literal(Literal::String("val".to_string()))),
            }),
            smallvec![
                MatchArm {
                    pattern: Pattern::Literal(Literal::String("a".to_string())),
                    guard: None,
                    body: Shared::new(Node {
                        token_id: ArenaId::new(1),
                        expr: Shared::new(Expr::Literal(Literal::String("body1".to_string()))),
                    }),
                },
                MatchArm {
                    pattern: Pattern::Literal(Literal::String("b".to_string())),
                    guard: None,
                    body: Shared::new(Node {
                        token_id: ArenaId::new(2),
                        expr: Shared::new(Expr::Literal(Literal::String("body2".to_string()))),
                    }),
                },
            ]
        ),
        vec![
            (0, Range { start: Position::new(10, 1), end: Position::new(10, 5) }),
            (1, Range { start: Position::new(11, 1), end: Position::new(11, 5) }),
            (2, Range { start: Position::new(12, 1), end: Position::new(12, 5) }),
        ],
        Range { start: Position::new(10, 1), end: Position::new(12, 5) }
    )]
    // Try: from the try expression (token 0) to the catch expression (token 1).
    #[case(
        Expr::Try(
            Shared::new(Node {
                token_id: ArenaId::new(0),
                expr: Shared::new(Expr::Literal(Literal::String("try".to_string()))),
            }),
            Shared::new(Node {
                token_id: ArenaId::new(1),
                expr: Shared::new(Expr::Literal(Literal::String("catch".to_string()))),
            })
        ),
        vec![
            (0, Range { start: Position::new(20, 1), end: Position::new(20, 5) }),
            (1, Range { start: Position::new(21, 1), end: Position::new(21, 5) }),
        ],
        Range { start: Position::new(20, 1), end: Position::new(21, 5) }
    )]
    // Let: the range of the bound value.
    #[case(
        Expr::Let(
            Pattern::Ident(IdentWithToken::new("x")),
            Shared::new(Node {
                token_id: ArenaId::new(0),
                expr: Shared::new(Expr::Literal(Literal::String("letval".to_string()))),
            })
        ),
        vec![
            (0, Range { start: Position::new(30, 1), end: Position::new(30, 5) }),
        ],
        Range { start: Position::new(30, 1), end: Position::new(30, 5) }
    )]
    // Paren: the range of the inner expression.
    #[case(
        Expr::Paren(
            Shared::new(Node {
                token_id: ArenaId::new(0),
                expr: Shared::new(Expr::Literal(Literal::String("paren".to_string()))),
            })
        ),
        vec![
            (0, Range { start: Position::new(40, 1), end: Position::new(40, 5) }),
        ],
        Range { start: Position::new(40, 1), end: Position::new(40, 5) }
    )]
    // Block: first statement to last statement.
    #[case(
        Expr::Block(vec![
            Shared::new(Node {
                token_id: ArenaId::new(0),
                expr: Shared::new(Expr::Literal(Literal::String("block1".to_string()))),
            }),
            Shared::new(Node {
                token_id: ArenaId::new(1),
                expr: Shared::new(Expr::Literal(Literal::String("block2".to_string()))),
            }),
        ]),
        vec![
            (0, Range { start: Position::new(50, 1), end: Position::new(50, 5) }),
            (1, Range { start: Position::new(51, 1), end: Position::new(51, 5) }),
        ],
        Range { start: Position::new(50, 1), end: Position::new(51, 5) }
    )]
    // Def: first to last body statement; the name and parameters are not measured.
    #[case(
        Expr::Def(
            IdentWithToken::new("f"),
            smallvec![],
            vec![
                Shared::new(Node {
                    token_id: ArenaId::new(0),
                    expr: Shared::new(Expr::Literal(Literal::String("def1".to_string()))),
                }),
                Shared::new(Node {
                    token_id: ArenaId::new(1),
                    expr: Shared::new(Expr::Literal(Literal::String("def2".to_string()))),
                }),
            ]
        ),
        vec![
            (0, Range { start: Position::new(60, 1), end: Position::new(60, 5) }),
            (1, Range { start: Position::new(61, 1), end: Position::new(61, 5) }),
        ],
        Range { start: Position::new(60, 1), end: Position::new(61, 5) }
    )]
    // Fn: first to last body statement.
    #[case(
        Expr::Fn(
            smallvec![],
            vec![
                Shared::new(Node {
                    token_id: ArenaId::new(0),
                    expr: Shared::new(Expr::Literal(Literal::String("fn1".to_string()))),
                }),
                Shared::new(Node {
                    token_id: ArenaId::new(1),
                    expr: Shared::new(Expr::Literal(Literal::String("fn2".to_string()))),
                }),
            ]
        ),
        vec![
            (0, Range { start: Position::new(70, 1), end: Position::new(70, 5) }),
            (1, Range { start: Position::new(71, 1), end: Position::new(71, 5) }),
        ],
        Range { start: Position::new(70, 1), end: Position::new(71, 5) }
    )]
    // While: only the body (tokens 1 and 2) is measured; the condition (token 0) is excluded.
    #[case(
        Expr::While(
            Shared::new(Node {
                token_id: ArenaId::new(0),
                expr: Shared::new(Expr::Literal(Literal::String("cond".to_string()))),
            }),
            vec![
                Shared::new(Node {
                    token_id: ArenaId::new(1),
                    expr: Shared::new(Expr::Literal(Literal::String("while1".to_string()))),
                }),
                Shared::new(Node {
                    token_id: ArenaId::new(2),
                    expr: Shared::new(Expr::Literal(Literal::String("while2".to_string()))),
                }),
            ]
        ),
        vec![
            (0, Range { start: Position::new(81, 1), end: Position::new(81, 5) }),
            (1, Range { start: Position::new(82, 1), end: Position::new(82, 5) }),
            (2, Range { start: Position::new(82, 1), end: Position::new(82, 5) }),
        ],
        Range { start: Position::new(82, 1), end: Position::new(82, 5) }
    )]
    // Foreach: only the body (tokens 1 and 2) is measured; the iterable (token 0) is excluded.
    #[case(
        Expr::Foreach(
            IdentWithToken::new("item"),
            Shared::new(Node {
                token_id: ArenaId::new(0),
                expr: Shared::new(Expr::Literal(Literal::String("iter".to_string()))),
            }),
            vec![
                Shared::new(Node {
                    token_id: ArenaId::new(1),
                    expr: Shared::new(Expr::Literal(Literal::String("foreach1".to_string()))),
                }),
                Shared::new(Node {
                    token_id: ArenaId::new(2),
                    expr: Shared::new(Expr::Literal(Literal::String("foreach2".to_string()))),
                }),
            ]
        ),
        vec![
            (0, Range { start: Position::new(101, 1), end: Position::new(101, 5) }),
            (1, Range { start: Position::new(102, 1), end: Position::new(102, 5) }),
            (2, Range { start: Position::new(102, 1), end: Position::new(102, 5) }),
        ],
        Range { start: Position::new(102, 1), end: Position::new(102, 5) }
    )]
    // If: only the branch bodies (tokens 1 and 3) are measured; conditions are excluded.
    #[case(
        Expr::If(smallvec![
            (
                Some(Shared::new(Node {
                    token_id: ArenaId::new(0),
                    expr: Shared::new(Expr::Literal(Literal::String("cond1".to_string()))),
                })),
                Shared::new(Node {
                    token_id: ArenaId::new(1),
                    expr: Shared::new(Expr::Literal(Literal::String("if1".to_string()))),
                })
            ),
            (
                Some(Shared::new(Node {
                    token_id: ArenaId::new(2),
                    expr: Shared::new(Expr::Literal(Literal::String("cond2".to_string()))),
                })),
                Shared::new(Node {
                    token_id: ArenaId::new(3),
                    expr: Shared::new(Expr::Literal(Literal::String("if2".to_string()))),
                })
            ),
        ]),
        vec![
            (0, Range { start: Position::new(111, 1), end: Position::new(111, 5) }),
            (1, Range { start: Position::new(113, 1), end: Position::new(113, 5) }),
            (2, Range { start: Position::new(114, 1), end: Position::new(115, 5) }),
            (3, Range { start: Position::new(116, 1), end: Position::new(117, 5) }),
        ],
        Range { start: Position::new(113, 1), end: Position::new(117, 5) }
    )]
    // Call: first to last argument.
    #[case(
        Expr::Call(
            IdentWithToken::new("func"),
            smallvec![
                Shared::new(Node {
                    token_id: ArenaId::new(0),
                    expr: Shared::new(Expr::Literal(Literal::String("arg1".to_string()))),
                }),
                Shared::new(Node {
                    token_id: ArenaId::new(1),
                    expr: Shared::new(Expr::Literal(Literal::String("arg2".to_string()))),
                }),
            ]
        ),
        vec![
            (0, Range { start: Position::new(120, 1), end: Position::new(120, 5) }),
            (1, Range { start: Position::new(121, 1), end: Position::new(121, 5) }),
        ],
        Range { start: Position::new(120, 1), end: Position::new(121, 5) }
    )]
    fn test_node_range_various_exprs(
        #[case] expr: Expr,
        #[case] token_ranges: Vec<(usize, Range)>,
        #[case] expected: Range,
    ) {
        // Allocate the tokens in listed order so the ids used by the case's nodes resolve.
        let mut arena = Arena::new(150);
        for (_, range) in &token_ranges {
            let token = create_token(*range);
            let _ = arena.alloc(token);
        }
        // The outer node itself always uses token id 0.
        let node = Node {
            token_id: ArenaId::new(0),
            expr: Shared::new(expr),
        };
        assert_eq!(node.range(Shared::new(arena)), expected);
    }
}