Skip to main content

oak_lua/ast/
mod.rs

1#![doc = include_str!("readme.md")]
2use core::range::Range;
3use oak_core::source::{SourceBuffer, ToSource};
4#[cfg(feature = "oak-pretty-print")]
5use oak_pretty_print::{AsDocument, Document};
6
7/// Lua root node
8#[derive(Debug, Clone)]
9#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
10pub struct LuaRoot {
11    /// Statements in the root.
12    pub statements: Vec<LuaStatement>,
13    /// Source span of the root.
14    #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
15    pub span: Range<usize>,
16}
17
18impl ToSource for LuaRoot {
19    fn to_source(&self, buffer: &mut SourceBuffer) {
20        for stmt in &self.statements {
21            stmt.to_source(buffer);
22            buffer.push("\n")
23        }
24    }
25}
26
27#[cfg(feature = "oak-pretty-print")]
28impl AsDocument for LuaRoot {
29    fn as_document(&self, _params: &Self::Params) -> Document<'_> {
30        Document::join(self.statements.iter().map(|s| s.as_document(&())), Document::Line)
31    }
32}
33
34/// Lua statement
35#[derive(Debug, Clone)]
36#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
37pub enum LuaStatement {
38    /// A local statement.
39    Local(LuaLocalStatement),
40    /// An assignment statement.
41    Assignment(LuaAssignmentStatement),
42    /// An expression statement.
43    Expression(LuaExpression),
44    /// A return statement.
45    Return(LuaReturnStatement),
46    /// An if statement.
47    If(LuaIfStatement),
48    /// A while statement.
49    While(LuaWhileStatement),
50    /// A for statement.
51    For(LuaForStatement),
52    /// A repeat statement.
53    Repeat(LuaRepeatStatement),
54    /// A function statement.
55    Function(LuaFunctionStatement),
56    /// A break statement.
57    Break,
58    /// A do block.
59    Do(Vec<LuaStatement>),
60    /// A goto statement.
61    Goto(String),
62    /// A label statement.
63    Label(String),
64}
65
66impl ToSource for LuaStatement {
67    fn to_source(&self, buffer: &mut SourceBuffer) {
68        match self {
69            LuaStatement::Local(s) => s.to_source(buffer),
70            LuaStatement::Assignment(s) => s.to_source(buffer),
71            LuaStatement::Expression(e) => e.to_source(buffer),
72            LuaStatement::Return(s) => s.to_source(buffer),
73            LuaStatement::If(s) => s.to_source(buffer),
74            LuaStatement::While(s) => s.to_source(buffer),
75            LuaStatement::For(s) => s.to_source(buffer),
76            LuaStatement::Repeat(s) => s.to_source(buffer),
77            LuaStatement::Function(s) => s.to_source(buffer),
78            LuaStatement::Break => buffer.push("break"),
79            LuaStatement::Do(stmts) => {
80                buffer.push("do\n");
81                for stmt in stmts {
82                    stmt.to_source(buffer);
83                    buffer.push("\n")
84                }
85                buffer.push("end")
86            }
87            LuaStatement::Goto(label) => {
88                buffer.push("goto ");
89                buffer.push(label)
90            }
91            LuaStatement::Label(name) => {
92                buffer.push("::");
93                buffer.push(name);
94                buffer.push("::")
95            }
96        }
97    }
98}
99
100#[cfg(feature = "oak-pretty-print")]
101impl AsDocument for LuaStatement {
102    fn as_document(&self, _params: &Self::Params) -> Document<'_> {
103        let mut buffer = SourceBuffer::new();
104        self.to_source(&mut buffer);
105        Document::Text(buffer.finish().into())
106    }
107}
108
109/// Local variable declaration
110#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
111#[derive(Debug, Clone)]
112pub struct LuaLocalStatement {
113    /// Names of the local variables.
114    pub names: Vec<String>,
115    /// Values assigned to the local variables.
116    pub values: Vec<LuaExpression>,
117}
118
119impl ToSource for LuaLocalStatement {
120    fn to_source(&self, buffer: &mut SourceBuffer) {
121        buffer.push("local ");
122        for (i, name) in self.names.iter().enumerate() {
123            if i > 0 {
124                buffer.push(", ")
125            }
126            buffer.push(name)
127        }
128        if !self.values.is_empty() {
129            buffer.push(" = ");
130            for (i, val) in self.values.iter().enumerate() {
131                if i > 0 {
132                    buffer.push(", ")
133                }
134                val.to_source(buffer)
135            }
136        }
137    }
138}
139
140/// Assignment statement
141#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
142#[derive(Debug, Clone)]
143pub struct LuaAssignmentStatement {
144    /// Targets of the assignment.
145    pub targets: Vec<LuaExpression>,
146    /// Values assigned to the targets.
147    pub values: Vec<LuaExpression>,
148}
149
150impl ToSource for LuaAssignmentStatement {
151    fn to_source(&self, buffer: &mut SourceBuffer) {
152        for (i, target) in self.targets.iter().enumerate() {
153            if i > 0 {
154                buffer.push(", ")
155            }
156            target.to_source(buffer)
157        }
158        buffer.push(" = ");
159        for (i, val) in self.values.iter().enumerate() {
160            if i > 0 {
161                buffer.push(", ")
162            }
163            val.to_source(buffer)
164        }
165    }
166}
167
168/// Return statement
169#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
170#[derive(Debug, Clone)]
171pub struct LuaReturnStatement {
172    /// Values returned.
173    pub values: Vec<LuaExpression>,
174}
175
176impl ToSource for LuaReturnStatement {
177    fn to_source(&self, buffer: &mut SourceBuffer) {
178        buffer.push("return ");
179        for (i, val) in self.values.iter().enumerate() {
180            if i > 0 {
181                buffer.push(", ")
182            }
183            val.to_source(buffer)
184        }
185    }
186}
187
188/// If statement
189#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
190#[derive(Debug, Clone)]
191pub struct LuaIfStatement {
192    /// The condition of the `if` block.
193    pub condition: LuaExpression,
194    /// The block of the `if` part.
195    pub then_block: Vec<LuaStatement>,
196    /// Else-if blocks.
197    pub else_ifs: Vec<(LuaExpression, Vec<LuaStatement>)>,
198    /// The block of the `else` part.
199    pub else_block: Option<Vec<LuaStatement>>,
200}
201
202impl ToSource for LuaIfStatement {
203    fn to_source(&self, buffer: &mut SourceBuffer) {
204        buffer.push("if ");
205        self.condition.to_source(buffer);
206        buffer.push(" then\n");
207        for stmt in &self.then_block {
208            stmt.to_source(buffer);
209            buffer.push("\n")
210        }
211        for (cond, block) in &self.else_ifs {
212            buffer.push("elseif ");
213            cond.to_source(buffer);
214            buffer.push(" then\n");
215            for stmt in block {
216                stmt.to_source(buffer);
217                buffer.push("\n")
218            }
219        }
220        if let Some(block) = &self.else_block {
221            buffer.push("else\n");
222            for stmt in block {
223                stmt.to_source(buffer);
224                buffer.push("\n")
225            }
226        }
227        buffer.push("end")
228    }
229}
230
231/// While statement
232#[derive(Debug, Clone)]
233#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
234pub struct LuaWhileStatement {
235    /// The condition of the `while` loop.
236    pub condition: LuaExpression,
237    /// The block of the `while` loop.
238    pub block: Vec<LuaStatement>,
239}
240
241impl ToSource for LuaWhileStatement {
242    fn to_source(&self, buffer: &mut SourceBuffer) {
243        buffer.push("while ");
244        self.condition.to_source(buffer);
245        buffer.push(" do\n");
246        for stmt in &self.block {
247            stmt.to_source(buffer);
248            buffer.push("\n")
249        }
250        buffer.push("end")
251    }
252}
253
254/// For statement
255#[derive(Debug, Clone)]
256#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
257pub enum LuaForStatement {
258    /// A numeric for loop: `for var = start, end, step do block end`.
259    Numeric {
260        /// The loop variable.
261        variable: String,
262        /// The start value.
263        start: LuaExpression,
264        /// The end value.
265        end: LuaExpression,
266        /// The step value.
267        step: Option<LuaExpression>,
268        /// The loop block.
269        block: Vec<LuaStatement>,
270    },
271    /// A generic for loop: `for vars in iters do block end`.
272    Generic {
273        /// The loop variables.
274        variables: Vec<String>,
275        /// The iterators.
276        iterators: Vec<LuaExpression>,
277        /// The loop block.
278        block: Vec<LuaStatement>,
279    },
280}
281
282impl ToSource for LuaForStatement {
283    fn to_source(&self, buffer: &mut SourceBuffer) {
284        match self {
285            LuaForStatement::Numeric { variable, start, end, step, block } => {
286                buffer.push("for ");
287                buffer.push(variable);
288                buffer.push(" = ");
289                start.to_source(buffer);
290                buffer.push(", ");
291                end.to_source(buffer);
292                if let Some(s) = step {
293                    buffer.push(", ");
294                    s.to_source(buffer)
295                }
296                buffer.push(" do\n");
297                for stmt in block {
298                    stmt.to_source(buffer);
299                    buffer.push("\n")
300                }
301                buffer.push("end")
302            }
303            LuaForStatement::Generic { variables, iterators, block } => {
304                buffer.push("for ");
305                for (i, var) in variables.iter().enumerate() {
306                    if i > 0 {
307                        buffer.push(", ")
308                    }
309                    buffer.push(var)
310                }
311                buffer.push(" in ");
312                for (i, it) in iterators.iter().enumerate() {
313                    if i > 0 {
314                        buffer.push(", ")
315                    }
316                    it.to_source(buffer)
317                }
318                buffer.push(" do\n");
319                for stmt in block {
320                    stmt.to_source(buffer);
321                    buffer.push("\n")
322                }
323                buffer.push("end")
324            }
325        }
326    }
327}
328
329/// Repeat statement
330#[derive(Debug, Clone)]
331#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
332pub struct LuaRepeatStatement {
333    /// The block of the `repeat` loop.
334    pub block: Vec<LuaStatement>,
335    /// The condition of the `repeat` loop.
336    pub condition: LuaExpression,
337}
338
339impl ToSource for LuaRepeatStatement {
340    fn to_source(&self, buffer: &mut SourceBuffer) {
341        buffer.push("repeat\n");
342        for stmt in &self.block {
343            stmt.to_source(buffer);
344            buffer.push("\n")
345        }
346        buffer.push("until ");
347        self.condition.to_source(buffer)
348    }
349}
350
351/// Function definition statement
352#[derive(Debug, Clone)]
353#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
354pub struct LuaFunctionStatement {
355    /// The name parts of the function.
356    pub name: Vec<String>,
357    /// The receiver part (after `:`) if any.
358    pub receiver: Option<String>,
359    /// The parameters of the function.
360    pub parameters: Vec<String>,
361    /// Whether the function has a vararg parameter.
362    pub is_vararg: bool,
363    /// The function body.
364    pub block: Vec<LuaStatement>,
365}
366
367impl ToSource for LuaFunctionStatement {
368    fn to_source(&self, buffer: &mut SourceBuffer) {
369        buffer.push("function ");
370        for (i, part) in self.name.iter().enumerate() {
371            if i > 0 {
372                buffer.push(".")
373            }
374            buffer.push(part)
375        }
376        if let Some(recv) = &self.receiver {
377            buffer.push(":");
378            buffer.push(recv)
379        }
380        buffer.push("(");
381        for (i, param) in self.parameters.iter().enumerate() {
382            if i > 0 {
383                buffer.push(", ")
384            }
385            buffer.push(param)
386        }
387        if self.is_vararg {
388            if !self.parameters.is_empty() {
389                buffer.push(", ")
390            }
391            buffer.push("...")
392        }
393        buffer.push(")\n");
394        for stmt in &self.block {
395            stmt.to_source(buffer);
396            buffer.push("\n")
397        }
398        buffer.push("end")
399    }
400}
401
402/// Lua expression
403#[derive(Debug, Clone)]
404#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
405pub enum LuaExpression {
406    /// An identifier.
407    Identifier(String),
408    /// A number literal.
409    Number(f64),
410    /// A string literal.
411    String(String),
412    /// A boolean literal.
413    Boolean(bool),
414    /// A nil literal.
415    Nil,
416    /// A binary expression.
417    Binary(Box<LuaBinaryExpression>),
418    /// A unary expression.
419    Unary(Box<LuaUnaryExpression>),
420    /// A call expression.
421    Call(Box<LuaCallExpression>),
422    /// A table constructor.
423    Table(LuaTableConstructor),
424    /// A function expression.
425    Function(LuaFunctionExpression),
426    /// An index expression.
427    Index(Box<LuaIndexExpression>),
428    /// A member expression.
429    Member(Box<LuaMemberExpression>),
430    /// A vararg expression.
431    Vararg,
432}
433
434impl ToSource for LuaExpression {
435    fn to_source(&self, buffer: &mut SourceBuffer) {
436        match self {
437            LuaExpression::Identifier(id) => buffer.push(id),
438            LuaExpression::Number(n) => buffer.push(&n.to_string()),
439            LuaExpression::String(s) => {
440                buffer.push("\"");
441                buffer.push(s);
442                buffer.push("\"")
443            }
444            LuaExpression::Boolean(b) => buffer.push(if *b { "true" } else { "false" }),
445            LuaExpression::Nil => buffer.push("nil"),
446            LuaExpression::Binary(bin) => bin.to_source(buffer),
447            LuaExpression::Unary(un) => un.to_source(buffer),
448            LuaExpression::Call(call) => call.to_source(buffer),
449            LuaExpression::Table(table) => table.to_source(buffer),
450            LuaExpression::Function(func) => func.to_source(buffer),
451            LuaExpression::Index(idx) => idx.to_source(buffer),
452            LuaExpression::Member(mem) => mem.to_source(buffer),
453            LuaExpression::Vararg => buffer.push("..."),
454        }
455    }
456}
457
458#[cfg(feature = "oak-pretty-print")]
459impl AsDocument for LuaExpression {
460    fn as_document(&self, _params: &Self::Params) -> Document<'_> {
461        let mut buffer = SourceBuffer::new();
462        self.to_source(&mut buffer);
463        Document::Text(buffer.finish().into())
464    }
465}
466
467/// Unary expression
468#[derive(Debug, Clone)]
469#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
470pub struct LuaUnaryExpression {
471    /// The operator.
472    pub op: String,
473    /// The operand.
474    pub operand: LuaExpression,
475}
476
477impl ToSource for LuaUnaryExpression {
478    fn to_source(&self, buffer: &mut SourceBuffer) {
479        buffer.push(&self.op);
480        self.operand.to_source(buffer)
481    }
482}
483
484/// Binary expression
485#[derive(Debug, Clone)]
486#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
487pub struct LuaBinaryExpression {
488    /// The left-hand side.
489    pub left: LuaExpression,
490    /// The operator.
491    pub op: String,
492    /// The right-hand side.
493    pub right: LuaExpression,
494}
495
496impl ToSource for LuaBinaryExpression {
497    fn to_source(&self, buffer: &mut SourceBuffer) {
498        self.left.to_source(buffer);
499        buffer.push(" ");
500        buffer.push(&self.op);
501        buffer.push(" ");
502        self.right.to_source(buffer)
503    }
504}
505
506/// Function call expression
507#[derive(Debug, Clone)]
508#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
509pub struct LuaCallExpression {
510    /// The function being called.
511    pub function: LuaExpression,
512    /// The arguments passed to the function.
513    pub arguments: Vec<LuaExpression>,
514}
515
516impl ToSource for LuaCallExpression {
517    fn to_source(&self, buffer: &mut SourceBuffer) {
518        self.function.to_source(buffer);
519        buffer.push("(");
520        for (i, arg) in self.arguments.iter().enumerate() {
521            if i > 0 {
522                buffer.push(", ")
523            }
524            arg.to_source(buffer)
525        }
526        buffer.push(")")
527    }
528}
529
530/// Table constructor
531#[derive(Debug, Clone)]
532#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
533pub struct LuaTableConstructor {
534    /// The fields in the table constructor.
535    pub fields: Vec<LuaTableField>,
536}
537
538impl ToSource for LuaTableConstructor {
539    fn to_source(&self, buffer: &mut SourceBuffer) {
540        buffer.push("{");
541        for (i, field) in self.fields.iter().enumerate() {
542            if i > 0 {
543                buffer.push(", ")
544            }
545            field.to_source(buffer)
546        }
547        buffer.push("}")
548    }
549}
550
551/// A field in a Lua table constructor.
552#[derive(Debug, Clone)]
553#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
554pub enum LuaTableField {
555    /// A keyed field: `[key] = value`.
556    Keyed {
557        /// The key expression.
558        key: LuaExpression,
559        /// The value expression.
560        value: LuaExpression,
561    },
562    /// A named field: `name = value`.
563    Named {
564        /// The name.
565        name: String,
566        /// The value expression.
567        value: LuaExpression,
568    },
569    /// A list field: `value`.
570    List {
571        /// The value expression.
572        value: LuaExpression,
573    },
574}
575
576impl ToSource for LuaTableField {
577    fn to_source(&self, buffer: &mut SourceBuffer) {
578        match self {
579            LuaTableField::Keyed { key, value } => {
580                buffer.push("[");
581                key.to_source(buffer);
582                buffer.push("] = ");
583                value.to_source(buffer)
584            }
585            LuaTableField::Named { name, value } => {
586                buffer.push(name);
587                buffer.push(" = ");
588                value.to_source(buffer)
589            }
590            LuaTableField::List { value } => value.to_source(buffer),
591        }
592    }
593}
594
595/// Anonymous function expression
596#[derive(Debug, Clone)]
597#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
598pub struct LuaFunctionExpression {
599    /// The parameters of the function.
600    pub parameters: Vec<String>,
601    /// Whether the function has a vararg parameter.
602    pub is_vararg: bool,
603    /// The function body.
604    pub block: Vec<LuaStatement>,
605}
606
607impl ToSource for LuaFunctionExpression {
608    fn to_source(&self, buffer: &mut SourceBuffer) {
609        buffer.push("function(");
610        for (i, param) in self.parameters.iter().enumerate() {
611            if i > 0 {
612                buffer.push(", ")
613            }
614            buffer.push(param)
615        }
616        if self.is_vararg {
617            if !self.parameters.is_empty() {
618                buffer.push(", ")
619            }
620            buffer.push("...")
621        }
622        buffer.push(")\n");
623        for stmt in &self.block {
624            stmt.to_source(buffer);
625            buffer.push("\n")
626        }
627        buffer.push("end")
628    }
629}
630
631/// Index access expression
632#[derive(Debug, Clone)]
633#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
634pub struct LuaIndexExpression {
635    /// The table being indexed.
636    pub table: LuaExpression,
637    /// The index expression.
638    pub index: LuaExpression,
639}
640
641impl ToSource for LuaIndexExpression {
642    fn to_source(&self, buffer: &mut SourceBuffer) {
643        self.table.to_source(buffer);
644        buffer.push("[");
645        self.index.to_source(buffer);
646        buffer.push("]")
647    }
648}
649
650/// Member access expression
651#[derive(Debug, Clone)]
652#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
653pub struct LuaMemberExpression {
654    /// The table whose member is being accessed.
655    pub table: LuaExpression,
656    /// The member name.
657    pub member: String,
658    /// Whether this is a method call (using `:`).
659    pub is_method: bool,
660}
661
662impl ToSource for LuaMemberExpression {
663    fn to_source(&self, buffer: &mut SourceBuffer) {
664        self.table.to_source(buffer);
665        if self.is_method {
666            buffer.push(":")
667        }
668        else {
669            buffer.push(".")
670        }
671        buffer.push(&self.member)
672    }
673}