aplang_lib/parser/
ast.rs

1use crate::lexer::token::Token;
2use miette::SourceSpan;
3use std::hash::{Hash, Hasher};
4use std::ops::Deref;
5use std::sync::Arc;
// To make error reporting easier later on, every AST node stores the
// tokens it was parsed from, so that errors can be reported against
// those tokens.
9
10#[derive(Debug, Clone)]
11pub struct Ast {
12    pub source: Arc<str>,
13    pub program: Vec<Stmt>,
14}
15
/// Name of a declared item; in this file it is only used for
/// `ProcDeclaration::name`.
type Ident = String;
17
18#[derive(Debug, Clone)]
19pub enum Stmt {
20    Expr(Arc<Expr>),
21
22    If(Arc<If>),
23
24    RepeatTimes(Arc<RepeatTimes>),
25
26    RepeatUntil(Arc<RepeatUntil>),
27
28    ForEach(Arc<ForEach>),
29
30    ProcDeclaration(Arc<ProcDeclaration>),
31
32    Block(Arc<Block>),
33
34    Return(Arc<Return>),
35
36    Continue(Arc<Continue>),
37
38    Break(Arc<Break>),
39
40    Import(Arc<Import>),
41}
42#[derive(Debug, Clone)]
43pub struct If {
44    pub condition: Expr,
45    pub then_branch: Stmt,
46    pub else_branch: Option<Stmt>,
47
48    pub if_token: Token,
49    pub else_token: Option<Token>,
50}
51#[derive(Debug, Clone)]
52pub struct RepeatTimes {
53    pub count: Expr,
54    pub body: Stmt,
55
56    pub repeat_token: Token,
57    pub times_token: Token,
58    pub count_token: Token,
59}
60#[derive(Debug, Clone)]
61pub struct RepeatUntil {
62    pub condition: Expr,
63    pub body: Stmt,
64
65    pub repeat_token: Token,
66    pub until_token: Token,
67}
68#[derive(Debug, Clone)]
69pub struct ForEach {
70    pub item: Variable,
71    pub list: Expr,
72    pub body: Stmt,
73
74    pub item_token: Token,
75    pub for_token: Token,
76    pub each_token: Token,
77    pub in_token: Token,
78    pub list_token: Token,
79}
80#[derive(Debug, Clone)]
81pub struct ProcDeclaration {
82    pub name: Ident,
83    pub params: Vec<Variable>,
84    pub body: Stmt,
85    pub exported: bool,
86
87    pub proc_token: Token,
88    pub name_token: Token,
89}
90#[derive(Debug, Clone)]
91pub struct Block {
92    pub lb_token: Token,
93    pub statements: Vec<Stmt>,
94    pub rb_token: Token,
95}
96#[derive(Debug, Clone)]
97pub struct Return {
98    pub token: Token,
99    pub data: Option<Expr>,
100}
101
102#[derive(Debug, Clone)]
103pub struct Continue {
104    pub token: Token,
105}
106
107#[derive(Debug, Clone)]
108pub struct Break {
109    pub token: Token,
110}
111
112#[derive(Debug, Clone)]
113pub struct Import {
114    pub import_token: Token,
115    pub mod_token: Token,
116    pub maybe_from_token: Option<Token>,
117
118    pub only_functions: Option<Vec<Token>>,
119    pub module_name: Token,
120}
121
122#[derive(Debug, Clone)]
123pub enum Expr {
124    Literal(Arc<ExprLiteral>),
125    Binary(Arc<Binary>),
126    Logical(Arc<Logical>),
127
128    Unary(Arc<Unary>),
129
130    Grouping(Arc<Grouping>),
131
132    ProcCall(Arc<ProcCall>),
133
134    Access(Arc<Access>),
135
136    List(Arc<List>),
137
138    Variable(Arc<Variable>),
139
140    Assign(Arc<Assignment>),
141
142    Set(Arc<Set>),
143}
144#[derive(Debug, Clone)]
145pub struct ExprLiteral {
146    pub value: Literal,
147    pub token: Token,
148}
149#[derive(Debug, Clone)]
150pub struct Binary {
151    pub left: Expr,
152    pub operator: BinaryOp,
153    pub right: Expr,
154    pub token: Token,
155}
156#[derive(Debug, Clone)]
157pub struct Logical {
158    pub left: Expr,
159    pub operator: LogicalOp,
160    pub right: Expr,
161    pub token: Token,
162}
163#[derive(Debug, Clone)]
164pub struct Unary {
165    pub operator: UnaryOp,
166    pub right: Expr,
167    pub token: Token,
168}
169#[derive(Debug, Clone)]
170pub struct Grouping {
171    pub expr: Expr,
172    pub parens: (Token, Token),
173}
174#[derive(Debug, Clone)]
175pub struct ProcCall {
176    pub ident: String,
177    pub arguments: Vec<Expr>,
178    pub arguments_spans: Vec<SourceSpan>,
179
180    pub token: Token,
181    pub parens: (Token, Token),
182}
183#[derive(Debug, Clone)]
184pub struct Access {
185    pub list: Expr,
186    pub list_token: Token,
187    pub key: Expr,
188    pub brackets: (Token, Token),
189}
190#[derive(Debug, Clone)]
191pub struct List {
192    pub items: Vec<Expr>,
193    pub brackets: (Token, Token),
194}
195#[derive(Debug, Clone)]
196pub struct Variable {
197    pub ident: String,
198    pub token: Token,
199}
200impl Hash for Variable {
201    fn hash<H: Hasher>(&self, state: &mut H) {
202        self.ident.hash(state);
203    }
204}
205
206impl PartialEq for Variable {
207    fn eq(&self, other: &Self) -> bool {
208        self.ident.eq(&other.ident)
209    }
210}
211
212impl Eq for Variable {}
213
214#[derive(Debug, Clone)]
215pub struct Assignment {
216    pub target: Arc<Variable>,
217    pub value: Expr,
218
219    pub ident_token: Token,
220    pub arrow_token: Token,
221}
222#[derive(Debug, Clone)]
223pub struct Set {
224    pub target: Expr,
225    pub value: Expr,
226
227    pub list: Expr,
228    pub idx: Expr,
229
230    pub list_token: Token,
231    pub brackets: (Token, Token),
232    pub arrow_token: Token,
233}
234
/// Value carried by a literal expression.
#[derive(Debug, Clone)]
pub enum Literal {
    /// Numeric literal; all numbers are stored as `f64`.
    Number(f64),
    /// String literal (rendered wrapped in double quotes).
    String(String),
    /// Boolean true, rendered as `TRUE`.
    True,
    /// Boolean false, rendered as `FALSE`.
    False,
    /// Null value, rendered as `NULL`.
    Null,
}
243
/// Operator of a [`Binary`] expression.
///
/// Derives `PartialEq`/`Eq` like [`LogicalOp`], so operators can be compared
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BinaryOp {
    /// `==`
    EqualEqual,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `%`
    Modulo,
}
258
/// Operator of a [`Unary`] expression.
///
/// Derives `PartialEq`/`Eq` like [`LogicalOp`], so operators can be compared
/// directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation, rendered as `-`.
    Minus,
    /// Logical negation, rendered as `!`.
    Not,
}
264
/// Operator of a [`Logical`] expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LogicalOp {
    /// Rendered as `or`.
    Or,
    /// Rendered as `and`.
    And,
}
270
271pub mod pretty {
272    use super::*;
273    use std::fmt;
274    use std::fmt::{Display, Formatter};
275
276    pub trait TreePrinter {
277        fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_>;
278
279        fn node(&self) -> Box<dyn Display>;
280
281        fn print_tree_base(&self, prefix: &str, last: bool) -> String {
282            let mut result = format!(
283                "{}{}{}\n",
284                prefix,
285                if last { "└── " } else { "├── " },
286                self.node()
287            );
288            let prefix_child = if last { "    " } else { "│   " };
289            let children: Vec<_> = self.node_children().collect();
290            for (i, child) in children.iter().enumerate() {
291                let last_child = i == children.len() - 1;
292                result += &child.print_tree_base(&(prefix.to_owned() + prefix_child), last_child);
293            }
294            result
295        }
296
297        fn header(&self) -> Box<dyn Display> {
298            Box::<String>::default()
299        }
300
301        fn print_tree(&self) -> String {
302            let len = self.node_children().count();
303            let tree = self
304                .node_children()
305                .enumerate()
306                .map(|(i, child)| {
307                    let last = len - 1 == i;
308                    child.print_tree_base("", last)
309                })
310                .collect::<String>();
311
312            format!("{}{}\n{}", String::default(), self.node(), tree)
313        }
314    }
315
316    impl TreePrinter for Ast {
317        fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_> {
318            Box::new(
319                self.program
320                    .iter()
321                    .map(|stmt| Box::new(stmt.clone()) as Box<dyn TreePrinter>),
322            )
323        }
324
325        fn node(&self) -> Box<dyn Display> {
326            Box::new(format!("Ast (Source: {:?})", self.source))
327        }
328    }
329
330    impl TreePrinter for Stmt {
331        fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_> {
332            match self {
333                Stmt::Expr(expr) => Box::new(std::iter::once(
334                    Box::new(expr.deref().clone()) as Box<dyn TreePrinter>
335                )),
336                Stmt::If(if_stmt) => Box::new(
337                    std::iter::once(Box::new(if_stmt.condition.clone()) as Box<dyn TreePrinter>)
338                        .chain(std::iter::once(
339                            Box::new(if_stmt.then_branch.clone()) as Box<dyn TreePrinter>
340                        ))
341                        .chain(if_stmt.else_branch.as_ref().map(|else_branch| {
342                            Box::new(else_branch.clone()) as Box<dyn TreePrinter>
343                        })),
344                ),
345                Stmt::RepeatTimes(repeat_times) => Box::new(
346                    std::iter::once(Box::new(repeat_times.count.clone()) as Box<dyn TreePrinter>)
347                        .chain(std::iter::once(
348                            Box::new(repeat_times.body.clone()) as Box<dyn TreePrinter>
349                        )),
350                ),
351                Stmt::RepeatUntil(repeat_until) => Box::new(
352                    std::iter::once(
353                        Box::new(repeat_until.condition.clone()) as Box<dyn TreePrinter>
354                    )
355                    .chain(std::iter::once(
356                        Box::new(repeat_until.body.clone()) as Box<dyn TreePrinter>,
357                    )),
358                ),
359                Stmt::ForEach(for_each) => Box::new(
360                    std::iter::once(Box::new(for_each.list.clone()) as Box<dyn TreePrinter>).chain(
361                        std::iter::once(Box::new(for_each.body.clone()) as Box<dyn TreePrinter>),
362                    ),
363                ),
364                Stmt::ProcDeclaration(proc_decl) => Box::new(std::iter::once(Box::new(
365                    proc_decl.body.clone(),
366                )
367                    as Box<dyn TreePrinter>)),
368                Stmt::Block(block) => Box::new(
369                    block
370                        .statements
371                        .iter()
372                        .map(|stmt| Box::new(stmt.clone()) as Box<dyn TreePrinter>)
373                        .collect::<Vec<_>>()
374                        .into_iter(),
375                ),
376                Stmt::Return(return_stmt) => Box::new(
377                    return_stmt
378                        .data
379                        .as_ref()
380                        .map(|expr| Box::new(expr.clone()) as Box<dyn TreePrinter>)
381                        .into_iter(),
382                ),
383                Stmt::Import(_import_stmt) => Box::new(std::iter::empty()),
384                Stmt::Continue(_import_stmt) => Box::new(std::iter::empty()),
385                Stmt::Break(_import_stmt) => Box::new(std::iter::empty()),
386            }
387        }
388
389        fn node(&self) -> Box<dyn Display> {
390            Box::new(format!("{}", self)) // Implement Display for Stmt or adjust this to a custom string representation
391        }
392    }
393
394    impl TreePrinter for Expr {
395        fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_> {
396            match self {
397                Expr::Binary(binary) => Box::new(
398                    std::iter::once(Box::new(binary.left.clone()) as Box<dyn TreePrinter>).chain(
399                        std::iter::once(Box::new(binary.right.clone()) as Box<dyn TreePrinter>),
400                    ),
401                ),
402                Expr::Logical(logical) => Box::new(
403                    std::iter::once(Box::new(logical.left.clone()) as Box<dyn TreePrinter>).chain(
404                        std::iter::once(Box::new(logical.right.clone()) as Box<dyn TreePrinter>),
405                    ),
406                ),
407                Expr::Unary(unary) => Box::new(std::iter::once(
408                    Box::new(unary.right.clone()) as Box<dyn TreePrinter>
409                )),
410                Expr::Grouping(grouping) => Box::new(std::iter::once(
411                    Box::new(grouping.expr.clone()) as Box<dyn TreePrinter>,
412                )),
413                Expr::ProcCall(proc_call) => Box::new(
414                    proc_call
415                        .arguments
416                        .iter()
417                        .map(|arg| Box::new(arg.clone()) as Box<dyn TreePrinter>)
418                        .collect::<Vec<_>>()
419                        .into_iter(),
420                ),
421                Expr::Access(access) => Box::new(
422                    std::iter::once(Box::new(access.list.clone()) as Box<dyn TreePrinter>).chain(
423                        std::iter::once(Box::new(access.key.clone()) as Box<dyn TreePrinter>),
424                    ),
425                ),
426                Expr::List(list) => Box::new(
427                    list.items
428                        .iter()
429                        .map(|item| Box::new(item.clone()) as Box<dyn TreePrinter>)
430                        .collect::<Vec<_>>()
431                        .into_iter(),
432                ),
433                Expr::Variable(_) | Expr::Literal(_) => Box::new(std::iter::empty()),
434                Expr::Assign(assignment) => Box::new(std::iter::once(Box::new(
435                    assignment.value.clone(),
436                )
437                    as Box<dyn TreePrinter>)),
438                Expr::Set(set) => Box::new(
439                    std::iter::once(Box::new(set.target.clone()) as Box<dyn TreePrinter>).chain(
440                        std::iter::once(Box::new(set.value.clone()) as Box<dyn TreePrinter>),
441                    ),
442                ),
443            }
444        }
445
446        fn node(&self) -> Box<dyn Display> {
447            Box::new(format!("{}", self)) // Implement Display for Expr or adjust this to a custom string representation
448        }
449    }
450
451    impl Display for Expr {
452        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
453            match self {
454                Expr::Literal(literal) => write!(f, "{}", literal.value),
455                Expr::Binary(binary) => {
456                    write!(f, "({} {} {})", binary.left, binary.operator, binary.right)
457                }
458                Expr::Logical(logical) => write!(
459                    f,
460                    "({} {} {})",
461                    logical.left, logical.operator, logical.right
462                ),
463                Expr::Unary(unary) => write!(f, "({}{})", unary.operator, unary.right),
464                Expr::Grouping(grouping) => write!(f, "(group {})", grouping.expr),
465                Expr::ProcCall(proc_call) => {
466                    let args = proc_call
467                        .arguments
468                        .iter()
469                        .map(|arg| format!("{}", arg))
470                        .collect::<Vec<_>>()
471                        .join(", ");
472                    write!(f, "{}({})", proc_call.ident, args)
473                }
474                Expr::Access(access) => write!(f, "{}[{}]", access.list, access.key),
475                Expr::List(list) => {
476                    let items = list
477                        .items
478                        .iter()
479                        .map(|item| format!("{}", item))
480                        .collect::<Vec<_>>()
481                        .join(", ");
482                    write!(f, "[{}]", items)
483                }
484                Expr::Variable(variable) => write!(f, "{}", variable.ident),
485                Expr::Assign(assignment) => {
486                    write!(f, "{} <- {}", assignment.target, assignment.value)
487                }
488                Expr::Set(set) => write!(f, "{}[{}] = {}", set.target, set.arrow_token, set.value),
489            }
490        }
491    }
492
493    impl Display for Stmt {
494        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
495            match self {
496                Stmt::Expr(expr) => write!(f, "{}", expr),
497                Stmt::If(if_stmt) => {
498                    let else_part = if let Some(else_branch) = &if_stmt.else_branch {
499                        format!(" else {}", else_branch)
500                    } else {
501                        String::new()
502                    };
503                    write!(
504                        f,
505                        "if {} then {}{}",
506                        if_stmt.condition, if_stmt.then_branch, else_part
507                    )
508                }
509                Stmt::RepeatTimes(repeat_times) => write!(
510                    f,
511                    "repeat {} times {}",
512                    repeat_times.count, repeat_times.body
513                ),
514                Stmt::RepeatUntil(repeat_until) => write!(
515                    f,
516                    "repeat until {} {}",
517                    repeat_until.condition, repeat_until.body
518                ),
519                Stmt::ForEach(for_each) => write!(
520                    f,
521                    "for {} in {} do {}",
522                    for_each.item, for_each.list, for_each.body
523                ),
524                Stmt::ProcDeclaration(proc_decl) => {
525                    // let params = proc_decl.join(", ");
526                    let params = proc_decl
527                        .params
528                        .iter()
529                        .map(|var| var.ident.clone())
530                        .collect::<Vec<_>>()
531                        .join(", ");
532
533                    write!(
534                        f,
535                        "procedure {}({}) {}",
536                        proc_decl.name, params, proc_decl.body
537                    )
538                }
539                Stmt::Block(block) => {
540                    let statements = block
541                        .statements
542                        .iter()
543                        .map(|stmt| format!("{}", stmt))
544                        .collect::<Vec<_>>()
545                        .join("; ");
546                    write!(f, "{{ {} }}", statements)
547                }
548                Stmt::Return(return_stmt) => match &return_stmt.data {
549                    Some(data) => write!(f, "return {}", data),
550                    None => write!(f, "return"),
551                },
552                Stmt::Import(import_stmt) => {
553                    write!(f, "import module {}", import_stmt.module_name)
554                }
555                Stmt::Break(_) => {
556                    write!(f, "loop break")
557                }
558                Stmt::Continue(_) => {
559                    write!(f, "loop continue")
560                }
561            }
562        }
563    }
564
565    impl Display for Variable {
566        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567            write!(f, "{}", self.ident)
568        }
569    }
570
571    impl Display for Literal {
572        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
573            match self {
574                Literal::Number(num) => write!(f, "{}", num),
575                Literal::String(s) => write!(f, "\"{}\"", s), // Enclose strings in quotes
576                Literal::True => write!(f, "TRUE"),
577                Literal::False => write!(f, "FALSE"),
578                Literal::Null => write!(f, "NULL"),
579            }
580        }
581    }
582
583    impl Display for BinaryOp {
584        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
585            let op = match self {
586                BinaryOp::EqualEqual => "==",
587                BinaryOp::NotEqual => "!=",
588                BinaryOp::Less => "<",
589                BinaryOp::LessEqual => "<=",
590                BinaryOp::Greater => ">",
591                BinaryOp::GreaterEqual => ">=",
592                BinaryOp::Plus => "+",
593                BinaryOp::Minus => "-",
594                BinaryOp::Star => "*",
595                BinaryOp::Slash => "/",
596                BinaryOp::Modulo => "%",
597            };
598            write!(f, "{}", op)
599        }
600    }
601
602    impl Display for UnaryOp {
603        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
604            let op = match self {
605                UnaryOp::Minus => "-",
606                UnaryOp::Not => "!",
607            };
608            write!(f, "{}", op)
609        }
610    }
611
612    impl Display for LogicalOp {
613        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
614            let op = match self {
615                LogicalOp::And => "and",
616                LogicalOp::Or => "or",
617            };
618            write!(f, "{}", op)
619        }
620    }
621}
622
/// Maps an operator token to the path of the matching `BinaryOp` variant,
/// e.g. `BinaryOp![+]` expands to `$crate::ast::BinaryOp::Plus`.
///
/// Every `BinaryOp` variant has an arm, including `%` for `Modulo`.
///
/// NOTE(review): the expansion assumes the enum is reachable as
/// `$crate::ast::BinaryOp` (i.e. the crate root exposes this module as
/// `ast`) — confirm against the crate's module layout.
#[macro_export]
macro_rules! BinaryOp {
    [==] => [$crate::ast::BinaryOp::EqualEqual];
    [!=] => [$crate::ast::BinaryOp::NotEqual];
    [<] => [$crate::ast::BinaryOp::Less];
    [<=] => [$crate::ast::BinaryOp::LessEqual];
    [>] => [$crate::ast::BinaryOp::Greater];
    [>=] => [$crate::ast::BinaryOp::GreaterEqual];
    [+] => [$crate::ast::BinaryOp::Plus];
    [-] => [$crate::ast::BinaryOp::Minus];
    [*] => [$crate::ast::BinaryOp::Star];
    [/] => [$crate::ast::BinaryOp::Slash];
    [%] => [$crate::ast::BinaryOp::Modulo];
}