1use crate::lexer::token::Token;
2use miette::SourceSpan;
3use std::hash::{Hash, Hasher};
4use std::ops::Deref;
5use std::sync::Arc;
/// Root of the syntax tree: the top-level statements plus a shared handle to
/// source text.
#[derive(Debug, Clone)]
pub struct Ast {
    /// Shared source text (presumably the text `program` was parsed from —
    /// confirm with the parser).
    pub source: Arc<str>,
    /// Top-level statements of the program.
    pub program: Vec<Stmt>,
}
15
/// Alias used for identifier names stored as owned strings
/// (see `ProcDeclaration::name`).
type Ident = String;
17
/// A statement node.
///
/// Every payload is held behind an `Arc`, so cloning a `Stmt` clones the
/// pointer rather than the subtree.
#[derive(Debug, Clone)]
pub enum Stmt {
    /// An expression used in statement position.
    Expr(Arc<Expr>),

    /// Conditional with an optional `else` branch.
    If(Arc<If>),

    /// Loop driven by a count expression.
    RepeatTimes(Arc<RepeatTimes>),

    /// Loop driven by a condition expression.
    RepeatUntil(Arc<RepeatUntil>),

    /// Loop binding an item variable over a list expression.
    ForEach(Arc<ForEach>),

    /// Procedure declaration.
    ProcDeclaration(Arc<ProcDeclaration>),

    /// Delimited sequence of statements.
    Block(Arc<Block>),

    /// `return`, with an optional value.
    Return(Arc<Return>),

    /// Loop `continue`.
    Continue(Arc<Continue>),

    /// Loop `break`.
    Break(Arc<Break>),

    /// Module import.
    Import(Arc<Import>),
}
/// Payload of [`Stmt::If`].
#[derive(Debug, Clone)]
pub struct If {
    /// Expression deciding which branch is taken.
    pub condition: Expr,
    /// Statement for the "then" side.
    pub then_branch: Stmt,
    /// Statement for the "else" side, when one was written.
    pub else_branch: Option<Stmt>,

    // Tokens kept alongside the node (presumably the `if` / `else` keywords,
    // for diagnostics — confirm with the parser).
    pub if_token: Token,
    pub else_token: Option<Token>,
}
/// Payload of [`Stmt::RepeatTimes`].
#[derive(Debug, Clone)]
pub struct RepeatTimes {
    /// Expression giving the repetition count.
    pub count: Expr,
    /// Loop body.
    pub body: Stmt,

    // Tokens kept alongside the node. NOTE(review): `count_token` presumably
    // marks where the count expression starts — confirm with the parser.
    pub repeat_token: Token,
    pub times_token: Token,
    pub count_token: Token,
}
/// Payload of [`Stmt::RepeatUntil`].
#[derive(Debug, Clone)]
pub struct RepeatUntil {
    /// Expression controlling when the loop stops.
    pub condition: Expr,
    /// Loop body.
    pub body: Stmt,

    // Tokens kept alongside the node (presumably the `repeat` / `until`
    // keywords — confirm with the parser).
    pub repeat_token: Token,
    pub until_token: Token,
}
/// Payload of [`Stmt::ForEach`].
#[derive(Debug, Clone)]
pub struct ForEach {
    /// Loop variable.
    pub item: Variable,
    /// Expression producing the list that is iterated.
    pub list: Expr,
    /// Loop body.
    pub body: Stmt,

    // Tokens kept alongside the node. NOTE(review): `item_token` and
    // `list_token` presumably point at the loop variable and the list
    // expression — confirm with the parser.
    pub item_token: Token,
    pub for_token: Token,
    pub each_token: Token,
    pub in_token: Token,
    pub list_token: Token,
}
/// Payload of [`Stmt::ProcDeclaration`].
#[derive(Debug, Clone)]
pub struct ProcDeclaration {
    /// Procedure name.
    pub name: Ident,
    /// Declared parameters.
    pub params: Vec<Variable>,
    /// Procedure body.
    pub body: Stmt,
    /// Export flag (presumably whether other modules may import the
    /// procedure — confirm with the module loader).
    pub exported: bool,

    // Tokens kept alongside the node.
    pub proc_token: Token,
    pub name_token: Token,
}
/// Payload of [`Stmt::Block`]: statements enclosed by two delimiter tokens.
#[derive(Debug, Clone)]
pub struct Block {
    /// Opening delimiter token.
    pub lb_token: Token,
    /// Statements inside the block.
    pub statements: Vec<Stmt>,
    /// Closing delimiter token.
    pub rb_token: Token,
}
/// Payload of [`Stmt::Return`].
#[derive(Debug, Clone)]
pub struct Return {
    /// Token kept alongside the node (presumably the `return` keyword).
    pub token: Token,
    /// Returned expression; `None` for a bare `return`.
    pub data: Option<Expr>,
}
101
/// Payload of [`Stmt::Continue`]; carries only its token.
#[derive(Debug, Clone)]
pub struct Continue {
    /// Token kept alongside the node (presumably the `continue` keyword).
    pub token: Token,
}
106
/// Payload of [`Stmt::Break`]; carries only its token.
#[derive(Debug, Clone)]
pub struct Break {
    /// Token kept alongside the node (presumably the `break` keyword).
    pub token: Token,
}
111
/// Payload of [`Stmt::Import`].
#[derive(Debug, Clone)]
pub struct Import {
    // Keyword tokens kept alongside the node. NOTE(review): `maybe_from_token`
    // is presumably present only for the selective form — confirm with the
    // parser.
    pub import_token: Token,
    pub mod_token: Token,
    pub maybe_from_token: Option<Token>,

    /// Optional list of name tokens (presumably the procedures picked by a
    /// selective import; `None` otherwise — confirm with the parser).
    pub only_functions: Option<Vec<Token>>,
    /// Token naming the imported module.
    pub module_name: Token,
}
121
/// An expression node.
///
/// Every payload is held behind an `Arc`, so cloning an `Expr` clones the
/// pointer rather than the subtree.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Literal value.
    Literal(Arc<ExprLiteral>),
    /// Two operands joined by a [`BinaryOp`].
    Binary(Arc<Binary>),
    /// Two operands joined by a [`LogicalOp`].
    Logical(Arc<Logical>),

    /// One operand prefixed by a [`UnaryOp`].
    Unary(Arc<Unary>),

    /// Parenthesised expression.
    Grouping(Arc<Grouping>),

    /// Procedure call with arguments.
    ProcCall(Arc<ProcCall>),

    /// Indexing into a list expression.
    Access(Arc<Access>),

    /// List literal.
    List(Arc<List>),

    /// Variable reference.
    Variable(Arc<Variable>),

    /// Assignment to a variable.
    Assign(Arc<Assignment>),

    /// Assignment through an index.
    Set(Arc<Set>),
}
/// Payload of [`Expr::Literal`].
#[derive(Debug, Clone)]
pub struct ExprLiteral {
    /// The literal's value.
    pub value: Literal,
    /// Token kept alongside the node.
    pub token: Token,
}
/// Payload of [`Expr::Binary`].
#[derive(Debug, Clone)]
pub struct Binary {
    /// Left operand.
    pub left: Expr,
    /// Operator applied to the operands.
    pub operator: BinaryOp,
    /// Right operand.
    pub right: Expr,
    /// Token kept alongside the node (presumably the operator token —
    /// confirm with the parser).
    pub token: Token,
}
/// Payload of [`Expr::Logical`].
#[derive(Debug, Clone)]
pub struct Logical {
    /// Left operand.
    pub left: Expr,
    /// Operator applied to the operands.
    pub operator: LogicalOp,
    /// Right operand.
    pub right: Expr,
    /// Token kept alongside the node (presumably the operator token —
    /// confirm with the parser).
    pub token: Token,
}
/// Payload of [`Expr::Unary`].
#[derive(Debug, Clone)]
pub struct Unary {
    /// Operator applied to the operand.
    pub operator: UnaryOp,
    /// The operand.
    pub right: Expr,
    /// Token kept alongside the node (presumably the operator token —
    /// confirm with the parser).
    pub token: Token,
}
/// Payload of [`Expr::Grouping`].
#[derive(Debug, Clone)]
pub struct Grouping {
    /// The wrapped expression.
    pub expr: Expr,
    /// Pair of parenthesis tokens (presumably opening, then closing).
    pub parens: (Token, Token),
}
/// Payload of [`Expr::ProcCall`].
#[derive(Debug, Clone)]
pub struct ProcCall {
    /// Name of the called procedure.
    pub ident: String,
    /// Argument expressions.
    pub arguments: Vec<Expr>,
    /// Source spans (presumably one per entry of `arguments`, in the same
    /// order — confirm with the parser).
    pub arguments_spans: Vec<SourceSpan>,

    /// Token kept alongside the node (presumably the procedure name).
    pub token: Token,
    /// Pair of parenthesis tokens (presumably opening, then closing).
    pub parens: (Token, Token),
}
/// Payload of [`Expr::Access`].
#[derive(Debug, Clone)]
pub struct Access {
    /// Expression being indexed.
    pub list: Expr,
    /// Token kept alongside `list`.
    pub list_token: Token,
    /// Index expression.
    pub key: Expr,
    /// Pair of bracket tokens (presumably opening, then closing).
    pub brackets: (Token, Token),
}
/// Payload of [`Expr::List`].
#[derive(Debug, Clone)]
pub struct List {
    /// Element expressions.
    pub items: Vec<Expr>,
    /// Pair of bracket tokens (presumably opening, then closing).
    pub brackets: (Token, Token),
}
195#[derive(Debug, Clone)]
196pub struct Variable {
197 pub ident: String,
198 pub token: Token,
199}
200impl Hash for Variable {
201 fn hash<H: Hasher>(&self, state: &mut H) {
202 self.ident.hash(state);
203 }
204}
205
206impl PartialEq for Variable {
207 fn eq(&self, other: &Self) -> bool {
208 self.ident.eq(&other.ident)
209 }
210}
211
212impl Eq for Variable {}
213
/// Payload of [`Expr::Assign`].
#[derive(Debug, Clone)]
pub struct Assignment {
    /// Variable being assigned.
    pub target: Arc<Variable>,
    /// Expression whose value is assigned.
    pub value: Expr,

    // Tokens kept alongside the node (presumably the variable name and the
    // assignment arrow — confirm with the parser).
    pub ident_token: Token,
    pub arrow_token: Token,
}
/// Payload of [`Expr::Set`].
///
/// NOTE(review): `target` appears to overlap with `list` / `idx`; how the
/// parser fills the three is not visible in this file — confirm there.
#[derive(Debug, Clone)]
pub struct Set {
    /// Expression on the receiving side of the assignment.
    pub target: Expr,
    /// Expression whose value is stored.
    pub value: Expr,

    /// List expression being written to.
    pub list: Expr,
    /// Index expression.
    pub idx: Expr,

    // Tokens kept alongside the node.
    pub list_token: Token,
    pub brackets: (Token, Token),
    pub arrow_token: Token,
}
234
/// Value carried by an [`ExprLiteral`].
#[derive(Debug, Clone)]
pub enum Literal {
    /// Numeric literal, stored as `f64`.
    Number(f64),
    /// String literal contents.
    String(String),
    /// Boolean true.
    True,
    /// Boolean false.
    False,
    /// Null value.
    Null,
}
243
/// Operator of a [`Binary`] expression.
///
/// Derives `PartialEq`/`Eq` like [`LogicalOp`], so operators can be compared
/// with `==` or used in `assert_eq!` without matching on every variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BinaryOp {
    /// `==`
    EqualEqual,
    /// `!=`
    NotEqual,
    /// `<`
    Less,
    /// `<=`
    LessEqual,
    /// `>`
    Greater,
    /// `>=`
    GreaterEqual,
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Star,
    /// `/`
    Slash,
    /// `%`
    Modulo,
}
258
/// Operator of a [`Unary`] expression.
///
/// Derives `PartialEq`/`Eq` like [`LogicalOp`], so operators can be compared
/// with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation, `-`.
    Minus,
    /// Logical negation, `!`.
    Not,
}
264
/// Operator of a [`Logical`] expression.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LogicalOp {
    /// Logical disjunction, `or`.
    Or,
    /// Logical conjunction, `and`.
    And,
}
270
271pub mod pretty {
272 use super::*;
273 use std::fmt;
274 use std::fmt::{Display, Formatter};
275
276 pub trait TreePrinter {
277 fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_>;
278
279 fn node(&self) -> Box<dyn Display>;
280
281 fn print_tree_base(&self, prefix: &str, last: bool) -> String {
282 let mut result = format!(
283 "{}{}{}\n",
284 prefix,
285 if last { "└── " } else { "├── " },
286 self.node()
287 );
288 let prefix_child = if last { " " } else { "│ " };
289 let children: Vec<_> = self.node_children().collect();
290 for (i, child) in children.iter().enumerate() {
291 let last_child = i == children.len() - 1;
292 result += &child.print_tree_base(&(prefix.to_owned() + prefix_child), last_child);
293 }
294 result
295 }
296
297 fn header(&self) -> Box<dyn Display> {
298 Box::<String>::default()
299 }
300
301 fn print_tree(&self) -> String {
302 let len = self.node_children().count();
303 let tree = self
304 .node_children()
305 .enumerate()
306 .map(|(i, child)| {
307 let last = len - 1 == i;
308 child.print_tree_base("", last)
309 })
310 .collect::<String>();
311
312 format!("{}{}\n{}", String::default(), self.node(), tree)
313 }
314 }
315
316 impl TreePrinter for Ast {
317 fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_> {
318 Box::new(
319 self.program
320 .iter()
321 .map(|stmt| Box::new(stmt.clone()) as Box<dyn TreePrinter>),
322 )
323 }
324
325 fn node(&self) -> Box<dyn Display> {
326 Box::new(format!("Ast (Source: {:?})", self.source))
327 }
328 }
329
330 impl TreePrinter for Stmt {
331 fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_> {
332 match self {
333 Stmt::Expr(expr) => Box::new(std::iter::once(
334 Box::new(expr.deref().clone()) as Box<dyn TreePrinter>
335 )),
336 Stmt::If(if_stmt) => Box::new(
337 std::iter::once(Box::new(if_stmt.condition.clone()) as Box<dyn TreePrinter>)
338 .chain(std::iter::once(
339 Box::new(if_stmt.then_branch.clone()) as Box<dyn TreePrinter>
340 ))
341 .chain(if_stmt.else_branch.as_ref().map(|else_branch| {
342 Box::new(else_branch.clone()) as Box<dyn TreePrinter>
343 })),
344 ),
345 Stmt::RepeatTimes(repeat_times) => Box::new(
346 std::iter::once(Box::new(repeat_times.count.clone()) as Box<dyn TreePrinter>)
347 .chain(std::iter::once(
348 Box::new(repeat_times.body.clone()) as Box<dyn TreePrinter>
349 )),
350 ),
351 Stmt::RepeatUntil(repeat_until) => Box::new(
352 std::iter::once(
353 Box::new(repeat_until.condition.clone()) as Box<dyn TreePrinter>
354 )
355 .chain(std::iter::once(
356 Box::new(repeat_until.body.clone()) as Box<dyn TreePrinter>,
357 )),
358 ),
359 Stmt::ForEach(for_each) => Box::new(
360 std::iter::once(Box::new(for_each.list.clone()) as Box<dyn TreePrinter>).chain(
361 std::iter::once(Box::new(for_each.body.clone()) as Box<dyn TreePrinter>),
362 ),
363 ),
364 Stmt::ProcDeclaration(proc_decl) => Box::new(std::iter::once(Box::new(
365 proc_decl.body.clone(),
366 )
367 as Box<dyn TreePrinter>)),
368 Stmt::Block(block) => Box::new(
369 block
370 .statements
371 .iter()
372 .map(|stmt| Box::new(stmt.clone()) as Box<dyn TreePrinter>)
373 .collect::<Vec<_>>()
374 .into_iter(),
375 ),
376 Stmt::Return(return_stmt) => Box::new(
377 return_stmt
378 .data
379 .as_ref()
380 .map(|expr| Box::new(expr.clone()) as Box<dyn TreePrinter>)
381 .into_iter(),
382 ),
383 Stmt::Import(_import_stmt) => Box::new(std::iter::empty()),
384 Stmt::Continue(_import_stmt) => Box::new(std::iter::empty()),
385 Stmt::Break(_import_stmt) => Box::new(std::iter::empty()),
386 }
387 }
388
389 fn node(&self) -> Box<dyn Display> {
390 Box::new(format!("{}", self)) }
392 }
393
394 impl TreePrinter for Expr {
395 fn node_children(&self) -> Box<dyn Iterator<Item = Box<dyn TreePrinter>> + '_> {
396 match self {
397 Expr::Binary(binary) => Box::new(
398 std::iter::once(Box::new(binary.left.clone()) as Box<dyn TreePrinter>).chain(
399 std::iter::once(Box::new(binary.right.clone()) as Box<dyn TreePrinter>),
400 ),
401 ),
402 Expr::Logical(logical) => Box::new(
403 std::iter::once(Box::new(logical.left.clone()) as Box<dyn TreePrinter>).chain(
404 std::iter::once(Box::new(logical.right.clone()) as Box<dyn TreePrinter>),
405 ),
406 ),
407 Expr::Unary(unary) => Box::new(std::iter::once(
408 Box::new(unary.right.clone()) as Box<dyn TreePrinter>
409 )),
410 Expr::Grouping(grouping) => Box::new(std::iter::once(
411 Box::new(grouping.expr.clone()) as Box<dyn TreePrinter>,
412 )),
413 Expr::ProcCall(proc_call) => Box::new(
414 proc_call
415 .arguments
416 .iter()
417 .map(|arg| Box::new(arg.clone()) as Box<dyn TreePrinter>)
418 .collect::<Vec<_>>()
419 .into_iter(),
420 ),
421 Expr::Access(access) => Box::new(
422 std::iter::once(Box::new(access.list.clone()) as Box<dyn TreePrinter>).chain(
423 std::iter::once(Box::new(access.key.clone()) as Box<dyn TreePrinter>),
424 ),
425 ),
426 Expr::List(list) => Box::new(
427 list.items
428 .iter()
429 .map(|item| Box::new(item.clone()) as Box<dyn TreePrinter>)
430 .collect::<Vec<_>>()
431 .into_iter(),
432 ),
433 Expr::Variable(_) | Expr::Literal(_) => Box::new(std::iter::empty()),
434 Expr::Assign(assignment) => Box::new(std::iter::once(Box::new(
435 assignment.value.clone(),
436 )
437 as Box<dyn TreePrinter>)),
438 Expr::Set(set) => Box::new(
439 std::iter::once(Box::new(set.target.clone()) as Box<dyn TreePrinter>).chain(
440 std::iter::once(Box::new(set.value.clone()) as Box<dyn TreePrinter>),
441 ),
442 ),
443 }
444 }
445
446 fn node(&self) -> Box<dyn Display> {
447 Box::new(format!("{}", self)) }
449 }
450
451 impl Display for Expr {
452 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
453 match self {
454 Expr::Literal(literal) => write!(f, "{}", literal.value),
455 Expr::Binary(binary) => {
456 write!(f, "({} {} {})", binary.left, binary.operator, binary.right)
457 }
458 Expr::Logical(logical) => write!(
459 f,
460 "({} {} {})",
461 logical.left, logical.operator, logical.right
462 ),
463 Expr::Unary(unary) => write!(f, "({}{})", unary.operator, unary.right),
464 Expr::Grouping(grouping) => write!(f, "(group {})", grouping.expr),
465 Expr::ProcCall(proc_call) => {
466 let args = proc_call
467 .arguments
468 .iter()
469 .map(|arg| format!("{}", arg))
470 .collect::<Vec<_>>()
471 .join(", ");
472 write!(f, "{}({})", proc_call.ident, args)
473 }
474 Expr::Access(access) => write!(f, "{}[{}]", access.list, access.key),
475 Expr::List(list) => {
476 let items = list
477 .items
478 .iter()
479 .map(|item| format!("{}", item))
480 .collect::<Vec<_>>()
481 .join(", ");
482 write!(f, "[{}]", items)
483 }
484 Expr::Variable(variable) => write!(f, "{}", variable.ident),
485 Expr::Assign(assignment) => {
486 write!(f, "{} <- {}", assignment.target, assignment.value)
487 }
488 Expr::Set(set) => write!(f, "{}[{}] = {}", set.target, set.arrow_token, set.value),
489 }
490 }
491 }
492
493 impl Display for Stmt {
494 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
495 match self {
496 Stmt::Expr(expr) => write!(f, "{}", expr),
497 Stmt::If(if_stmt) => {
498 let else_part = if let Some(else_branch) = &if_stmt.else_branch {
499 format!(" else {}", else_branch)
500 } else {
501 String::new()
502 };
503 write!(
504 f,
505 "if {} then {}{}",
506 if_stmt.condition, if_stmt.then_branch, else_part
507 )
508 }
509 Stmt::RepeatTimes(repeat_times) => write!(
510 f,
511 "repeat {} times {}",
512 repeat_times.count, repeat_times.body
513 ),
514 Stmt::RepeatUntil(repeat_until) => write!(
515 f,
516 "repeat until {} {}",
517 repeat_until.condition, repeat_until.body
518 ),
519 Stmt::ForEach(for_each) => write!(
520 f,
521 "for {} in {} do {}",
522 for_each.item, for_each.list, for_each.body
523 ),
524 Stmt::ProcDeclaration(proc_decl) => {
525 let params = proc_decl
527 .params
528 .iter()
529 .map(|var| var.ident.clone())
530 .collect::<Vec<_>>()
531 .join(", ");
532
533 write!(
534 f,
535 "procedure {}({}) {}",
536 proc_decl.name, params, proc_decl.body
537 )
538 }
539 Stmt::Block(block) => {
540 let statements = block
541 .statements
542 .iter()
543 .map(|stmt| format!("{}", stmt))
544 .collect::<Vec<_>>()
545 .join("; ");
546 write!(f, "{{ {} }}", statements)
547 }
548 Stmt::Return(return_stmt) => match &return_stmt.data {
549 Some(data) => write!(f, "return {}", data),
550 None => write!(f, "return"),
551 },
552 Stmt::Import(import_stmt) => {
553 write!(f, "import module {}", import_stmt.module_name)
554 }
555 Stmt::Break(_) => {
556 write!(f, "loop break")
557 }
558 Stmt::Continue(_) => {
559 write!(f, "loop continue")
560 }
561 }
562 }
563 }
564
565 impl Display for Variable {
566 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
567 write!(f, "{}", self.ident)
568 }
569 }
570
571 impl Display for Literal {
572 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
573 match self {
574 Literal::Number(num) => write!(f, "{}", num),
575 Literal::String(s) => write!(f, "\"{}\"", s), Literal::True => write!(f, "TRUE"),
577 Literal::False => write!(f, "FALSE"),
578 Literal::Null => write!(f, "NULL"),
579 }
580 }
581 }
582
583 impl Display for BinaryOp {
584 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
585 let op = match self {
586 BinaryOp::EqualEqual => "==",
587 BinaryOp::NotEqual => "!=",
588 BinaryOp::Less => "<",
589 BinaryOp::LessEqual => "<=",
590 BinaryOp::Greater => ">",
591 BinaryOp::GreaterEqual => ">=",
592 BinaryOp::Plus => "+",
593 BinaryOp::Minus => "-",
594 BinaryOp::Star => "*",
595 BinaryOp::Slash => "/",
596 BinaryOp::Modulo => "%",
597 };
598 write!(f, "{}", op)
599 }
600 }
601
602 impl Display for UnaryOp {
603 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
604 let op = match self {
605 UnaryOp::Minus => "-",
606 UnaryOp::Not => "!",
607 };
608 write!(f, "{}", op)
609 }
610 }
611
612 impl Display for LogicalOp {
613 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
614 let op = match self {
615 LogicalOp::And => "and",
616 LogicalOp::Or => "or",
617 };
618 write!(f, "{}", op)
619 }
620 }
621}
622
/// Maps an operator token to the matching `BinaryOp` variant, so call sites
/// can write `BinaryOp![+]` instead of spelling out the variant path.
///
/// Every variant of the enum has an arm, including `%` for `Modulo`.
#[macro_export]
macro_rules! BinaryOp [
    [==] => [$crate::ast::BinaryOp::EqualEqual];
    [!=] => [$crate::ast::BinaryOp::NotEqual];
    [<] => [$crate::ast::BinaryOp::Less];
    [<=] => [$crate::ast::BinaryOp::LessEqual];
    [>] => [$crate::ast::BinaryOp::Greater];
    [>=] => [$crate::ast::BinaryOp::GreaterEqual];
    [+] => [$crate::ast::BinaryOp::Plus];
    [-] => [$crate::ast::BinaryOp::Minus];
    [*] => [$crate::ast::BinaryOp::Star];
    [/] => [$crate::ast::BinaryOp::Slash];
    [%] => [$crate::ast::BinaryOp::Modulo];
];