Skip to main content

oak_lua/ast/
mod.rs

1#![doc = include_str!("readme.md")]
2use core::range::Range;
3use oak_core::source::{SourceBuffer, ToSource};
4#[cfg(feature = "oak-pretty-print")]
5use oak_pretty_print::{AsDocument, Document};
6#[cfg(feature = "serde")]
7use serde::{Deserialize, Serialize};
8
9/// Lua root node
10#[derive(Debug, Clone)]
11#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
12pub struct LuaRoot {
13    /// Statements in the root.
14    pub statements: Vec<LuaStatement>,
15    /// Source span of the root.
16    #[cfg_attr(feature = "serde", serde(with = "oak_core::serde_range"))]
17    pub span: Range<usize>,
18}
19
20impl ToSource for LuaRoot {
21    fn to_source(&self, buffer: &mut SourceBuffer) {
22        for stmt in &self.statements {
23            stmt.to_source(buffer);
24            buffer.push("\n")
25        }
26    }
27}
28
29#[cfg(feature = "oak-pretty-print")]
30impl AsDocument for LuaRoot {
31    fn as_document(&self) -> Document<'_> {
32        Document::join(self.statements.iter().map(|s| s.as_document()), Document::Line)
33    }
34}
35
36/// Lua statement
37#[derive(Debug, Clone)]
38#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
39pub enum LuaStatement {
40    /// A local statement.
41    Local(LuaLocalStatement),
42    /// An assignment statement.
43    Assignment(LuaAssignmentStatement),
44    /// An expression statement.
45    Expression(LuaExpression),
46    /// A return statement.
47    Return(LuaReturnStatement),
48    /// An if statement.
49    If(LuaIfStatement),
50    /// A while statement.
51    While(LuaWhileStatement),
52    /// A for statement.
53    For(LuaForStatement),
54    /// A repeat statement.
55    Repeat(LuaRepeatStatement),
56    /// A function statement.
57    Function(LuaFunctionStatement),
58    /// A break statement.
59    Break,
60    /// A do block.
61    Do(Vec<LuaStatement>),
62    /// A goto statement.
63    Goto(String),
64    /// A label statement.
65    Label(String),
66}
67
68impl ToSource for LuaStatement {
69    fn to_source(&self, buffer: &mut SourceBuffer) {
70        match self {
71            LuaStatement::Local(s) => s.to_source(buffer),
72            LuaStatement::Assignment(s) => s.to_source(buffer),
73            LuaStatement::Expression(e) => e.to_source(buffer),
74            LuaStatement::Return(s) => s.to_source(buffer),
75            LuaStatement::If(s) => s.to_source(buffer),
76            LuaStatement::While(s) => s.to_source(buffer),
77            LuaStatement::For(s) => s.to_source(buffer),
78            LuaStatement::Repeat(s) => s.to_source(buffer),
79            LuaStatement::Function(s) => s.to_source(buffer),
80            LuaStatement::Break => buffer.push("break"),
81            LuaStatement::Do(stmts) => {
82                buffer.push("do\n");
83                for stmt in stmts {
84                    stmt.to_source(buffer);
85                    buffer.push("\n")
86                }
87                buffer.push("end")
88            }
89            LuaStatement::Goto(label) => {
90                buffer.push("goto ");
91                buffer.push(label)
92            }
93            LuaStatement::Label(name) => {
94                buffer.push("::");
95                buffer.push(name);
96                buffer.push("::")
97            }
98        }
99    }
100}
101
102#[cfg(feature = "oak-pretty-print")]
103impl AsDocument for LuaStatement {
104    fn as_document(&self) -> Document<'_> {
105        let mut buffer = SourceBuffer::new();
106        self.to_source(&mut buffer);
107        Document::Text(buffer.finish().into())
108    }
109}
110
111/// Local variable declaration
112#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
113#[derive(Debug, Clone)]
114pub struct LuaLocalStatement {
115    /// Names of the local variables.
116    pub names: Vec<String>,
117    /// Values assigned to the local variables.
118    pub values: Vec<LuaExpression>,
119}
120
121impl ToSource for LuaLocalStatement {
122    fn to_source(&self, buffer: &mut SourceBuffer) {
123        buffer.push("local ");
124        for (i, name) in self.names.iter().enumerate() {
125            if i > 0 {
126                buffer.push(", ")
127            }
128            buffer.push(name)
129        }
130        if !self.values.is_empty() {
131            buffer.push(" = ");
132            for (i, val) in self.values.iter().enumerate() {
133                if i > 0 {
134                    buffer.push(", ")
135                }
136                val.to_source(buffer)
137            }
138        }
139    }
140}
141
142/// Assignment statement
143#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
144#[derive(Debug, Clone)]
145pub struct LuaAssignmentStatement {
146    /// Targets of the assignment.
147    pub targets: Vec<LuaExpression>,
148    /// Values assigned to the targets.
149    pub values: Vec<LuaExpression>,
150}
151
152impl ToSource for LuaAssignmentStatement {
153    fn to_source(&self, buffer: &mut SourceBuffer) {
154        for (i, target) in self.targets.iter().enumerate() {
155            if i > 0 {
156                buffer.push(", ")
157            }
158            target.to_source(buffer)
159        }
160        buffer.push(" = ");
161        for (i, val) in self.values.iter().enumerate() {
162            if i > 0 {
163                buffer.push(", ")
164            }
165            val.to_source(buffer)
166        }
167    }
168}
169
170/// Return statement
171#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
172#[derive(Debug, Clone)]
173pub struct LuaReturnStatement {
174    /// Values returned.
175    pub values: Vec<LuaExpression>,
176}
177
178impl ToSource for LuaReturnStatement {
179    fn to_source(&self, buffer: &mut SourceBuffer) {
180        buffer.push("return ");
181        for (i, val) in self.values.iter().enumerate() {
182            if i > 0 {
183                buffer.push(", ")
184            }
185            val.to_source(buffer)
186        }
187    }
188}
189
190/// If statement
191#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
192#[derive(Debug, Clone)]
193pub struct LuaIfStatement {
194    /// The condition of the `if` block.
195    pub condition: LuaExpression,
196    /// The block of the `if` part.
197    pub then_block: Vec<LuaStatement>,
198    /// Else-if blocks.
199    pub else_ifs: Vec<(LuaExpression, Vec<LuaStatement>)>,
200    /// The block of the `else` part.
201    pub else_block: Option<Vec<LuaStatement>>,
202}
203
204impl ToSource for LuaIfStatement {
205    fn to_source(&self, buffer: &mut SourceBuffer) {
206        buffer.push("if ");
207        self.condition.to_source(buffer);
208        buffer.push(" then\n");
209        for stmt in &self.then_block {
210            stmt.to_source(buffer);
211            buffer.push("\n")
212        }
213        for (cond, block) in &self.else_ifs {
214            buffer.push("elseif ");
215            cond.to_source(buffer);
216            buffer.push(" then\n");
217            for stmt in block {
218                stmt.to_source(buffer);
219                buffer.push("\n")
220            }
221        }
222        if let Some(block) = &self.else_block {
223            buffer.push("else\n");
224            for stmt in block {
225                stmt.to_source(buffer);
226                buffer.push("\n")
227            }
228        }
229        buffer.push("end")
230    }
231}
232
233/// While statement
234#[derive(Debug, Clone)]
235#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
236pub struct LuaWhileStatement {
237    /// The condition of the `while` loop.
238    pub condition: LuaExpression,
239    /// The block of the `while` loop.
240    pub block: Vec<LuaStatement>,
241}
242
243impl ToSource for LuaWhileStatement {
244    fn to_source(&self, buffer: &mut SourceBuffer) {
245        buffer.push("while ");
246        self.condition.to_source(buffer);
247        buffer.push(" do\n");
248        for stmt in &self.block {
249            stmt.to_source(buffer);
250            buffer.push("\n")
251        }
252        buffer.push("end")
253    }
254}
255
256/// For statement
257#[derive(Debug, Clone)]
258#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
259pub enum LuaForStatement {
260    /// A numeric for loop: `for var = start, end, step do block end`.
261    Numeric {
262        /// The loop variable.
263        variable: String,
264        /// The start value.
265        start: LuaExpression,
266        /// The end value.
267        end: LuaExpression,
268        /// The step value.
269        step: Option<LuaExpression>,
270        /// The loop block.
271        block: Vec<LuaStatement>,
272    },
273    /// A generic for loop: `for vars in iters do block end`.
274    Generic {
275        /// The loop variables.
276        variables: Vec<String>,
277        /// The iterators.
278        iterators: Vec<LuaExpression>,
279        /// The loop block.
280        block: Vec<LuaStatement>,
281    },
282}
283
284impl ToSource for LuaForStatement {
285    fn to_source(&self, buffer: &mut SourceBuffer) {
286        match self {
287            LuaForStatement::Numeric { variable, start, end, step, block } => {
288                buffer.push("for ");
289                buffer.push(variable);
290                buffer.push(" = ");
291                start.to_source(buffer);
292                buffer.push(", ");
293                end.to_source(buffer);
294                if let Some(s) = step {
295                    buffer.push(", ");
296                    s.to_source(buffer)
297                }
298                buffer.push(" do\n");
299                for stmt in block {
300                    stmt.to_source(buffer);
301                    buffer.push("\n")
302                }
303                buffer.push("end")
304            }
305            LuaForStatement::Generic { variables, iterators, block } => {
306                buffer.push("for ");
307                for (i, var) in variables.iter().enumerate() {
308                    if i > 0 {
309                        buffer.push(", ")
310                    }
311                    buffer.push(var)
312                }
313                buffer.push(" in ");
314                for (i, it) in iterators.iter().enumerate() {
315                    if i > 0 {
316                        buffer.push(", ")
317                    }
318                    it.to_source(buffer)
319                }
320                buffer.push(" do\n");
321                for stmt in block {
322                    stmt.to_source(buffer);
323                    buffer.push("\n")
324                }
325                buffer.push("end")
326            }
327        }
328    }
329}
330
331/// Repeat statement
332#[derive(Debug, Clone)]
333#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
334pub struct LuaRepeatStatement {
335    /// The block of the `repeat` loop.
336    pub block: Vec<LuaStatement>,
337    /// The condition of the `repeat` loop.
338    pub condition: LuaExpression,
339}
340
341impl ToSource for LuaRepeatStatement {
342    fn to_source(&self, buffer: &mut SourceBuffer) {
343        buffer.push("repeat\n");
344        for stmt in &self.block {
345            stmt.to_source(buffer);
346            buffer.push("\n")
347        }
348        buffer.push("until ");
349        self.condition.to_source(buffer)
350    }
351}
352
353/// Function definition statement
354#[derive(Debug, Clone)]
355#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
356pub struct LuaFunctionStatement {
357    /// The name parts of the function.
358    pub name: Vec<String>,
359    /// The receiver part (after `:`) if any.
360    pub receiver: Option<String>,
361    /// The parameters of the function.
362    pub parameters: Vec<String>,
363    /// Whether the function has a vararg parameter.
364    pub is_vararg: bool,
365    /// The function body.
366    pub block: Vec<LuaStatement>,
367}
368
369impl ToSource for LuaFunctionStatement {
370    fn to_source(&self, buffer: &mut SourceBuffer) {
371        buffer.push("function ");
372        for (i, part) in self.name.iter().enumerate() {
373            if i > 0 {
374                buffer.push(".")
375            }
376            buffer.push(part)
377        }
378        if let Some(recv) = &self.receiver {
379            buffer.push(":");
380            buffer.push(recv)
381        }
382        buffer.push("(");
383        for (i, param) in self.parameters.iter().enumerate() {
384            if i > 0 {
385                buffer.push(", ")
386            }
387            buffer.push(param)
388        }
389        if self.is_vararg {
390            if !self.parameters.is_empty() {
391                buffer.push(", ")
392            }
393            buffer.push("...")
394        }
395        buffer.push(")\n");
396        for stmt in &self.block {
397            stmt.to_source(buffer);
398            buffer.push("\n")
399        }
400        buffer.push("end")
401    }
402}
403
404/// Lua expression
405#[derive(Debug, Clone)]
406#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
407pub enum LuaExpression {
408    /// An identifier.
409    Identifier(String),
410    /// A number literal.
411    Number(f64),
412    /// A string literal.
413    String(String),
414    /// A boolean literal.
415    Boolean(bool),
416    /// A nil literal.
417    Nil,
418    /// A binary expression.
419    Binary(Box<LuaBinaryExpression>),
420    /// A unary expression.
421    Unary(Box<LuaUnaryExpression>),
422    /// A call expression.
423    Call(Box<LuaCallExpression>),
424    /// A table constructor.
425    Table(LuaTableConstructor),
426    /// A function expression.
427    Function(LuaFunctionExpression),
428    /// An index expression.
429    Index(Box<LuaIndexExpression>),
430    /// A member expression.
431    Member(Box<LuaMemberExpression>),
432    /// A vararg expression.
433    Vararg,
434}
435
436impl ToSource for LuaExpression {
437    fn to_source(&self, buffer: &mut SourceBuffer) {
438        match self {
439            LuaExpression::Identifier(id) => buffer.push(id),
440            LuaExpression::Number(n) => buffer.push(&n.to_string()),
441            LuaExpression::String(s) => {
442                buffer.push("\"");
443                buffer.push(s);
444                buffer.push("\"")
445            }
446            LuaExpression::Boolean(b) => buffer.push(if *b { "true" } else { "false" }),
447            LuaExpression::Nil => buffer.push("nil"),
448            LuaExpression::Binary(bin) => bin.to_source(buffer),
449            LuaExpression::Unary(un) => un.to_source(buffer),
450            LuaExpression::Call(call) => call.to_source(buffer),
451            LuaExpression::Table(table) => table.to_source(buffer),
452            LuaExpression::Function(func) => func.to_source(buffer),
453            LuaExpression::Index(idx) => idx.to_source(buffer),
454            LuaExpression::Member(mem) => mem.to_source(buffer),
455            LuaExpression::Vararg => buffer.push("..."),
456        }
457    }
458}
459
460#[cfg(feature = "oak-pretty-print")]
461impl AsDocument for LuaExpression {
462    fn as_document(&self) -> Document<'_> {
463        let mut buffer = SourceBuffer::new();
464        self.to_source(&mut buffer);
465        Document::Text(buffer.finish().into())
466    }
467}
468
469/// Unary expression
470#[derive(Debug, Clone)]
471#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
472pub struct LuaUnaryExpression {
473    /// The operator.
474    pub op: String,
475    /// The operand.
476    pub operand: LuaExpression,
477}
478
479impl ToSource for LuaUnaryExpression {
480    fn to_source(&self, buffer: &mut SourceBuffer) {
481        buffer.push(&self.op);
482        self.operand.to_source(buffer)
483    }
484}
485
486/// Binary expression
487#[derive(Debug, Clone)]
488#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
489pub struct LuaBinaryExpression {
490    /// The left-hand side.
491    pub left: LuaExpression,
492    /// The operator.
493    pub op: String,
494    /// The right-hand side.
495    pub right: LuaExpression,
496}
497
498impl ToSource for LuaBinaryExpression {
499    fn to_source(&self, buffer: &mut SourceBuffer) {
500        self.left.to_source(buffer);
501        buffer.push(" ");
502        buffer.push(&self.op);
503        buffer.push(" ");
504        self.right.to_source(buffer)
505    }
506}
507
508/// Function call expression
509#[derive(Debug, Clone)]
510#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
511pub struct LuaCallExpression {
512    /// The function being called.
513    pub function: LuaExpression,
514    /// The arguments passed to the function.
515    pub arguments: Vec<LuaExpression>,
516}
517
518impl ToSource for LuaCallExpression {
519    fn to_source(&self, buffer: &mut SourceBuffer) {
520        self.function.to_source(buffer);
521        buffer.push("(");
522        for (i, arg) in self.arguments.iter().enumerate() {
523            if i > 0 {
524                buffer.push(", ")
525            }
526            arg.to_source(buffer)
527        }
528        buffer.push(")")
529    }
530}
531
532/// Table constructor
533#[derive(Debug, Clone)]
534#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
535pub struct LuaTableConstructor {
536    /// The fields in the table constructor.
537    pub fields: Vec<LuaTableField>,
538}
539
540impl ToSource for LuaTableConstructor {
541    fn to_source(&self, buffer: &mut SourceBuffer) {
542        buffer.push("{");
543        for (i, field) in self.fields.iter().enumerate() {
544            if i > 0 {
545                buffer.push(", ")
546            }
547            field.to_source(buffer)
548        }
549        buffer.push("}")
550    }
551}
552
553/// A field in a Lua table constructor.
554#[derive(Debug, Clone)]
555#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
556pub enum LuaTableField {
557    /// A keyed field: `[key] = value`.
558    Keyed {
559        /// The key expression.
560        key: LuaExpression,
561        /// The value expression.
562        value: LuaExpression,
563    },
564    /// A named field: `name = value`.
565    Named {
566        /// The name.
567        name: String,
568        /// The value expression.
569        value: LuaExpression,
570    },
571    /// A list field: `value`.
572    List {
573        /// The value expression.
574        value: LuaExpression,
575    },
576}
577
578impl ToSource for LuaTableField {
579    fn to_source(&self, buffer: &mut SourceBuffer) {
580        match self {
581            LuaTableField::Keyed { key, value } => {
582                buffer.push("[");
583                key.to_source(buffer);
584                buffer.push("] = ");
585                value.to_source(buffer)
586            }
587            LuaTableField::Named { name, value } => {
588                buffer.push(name);
589                buffer.push(" = ");
590                value.to_source(buffer)
591            }
592            LuaTableField::List { value } => value.to_source(buffer),
593        }
594    }
595}
596
597/// Anonymous function expression
598#[derive(Debug, Clone)]
599#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
600pub struct LuaFunctionExpression {
601    /// The parameters of the function.
602    pub parameters: Vec<String>,
603    /// Whether the function has a vararg parameter.
604    pub is_vararg: bool,
605    /// The function body.
606    pub block: Vec<LuaStatement>,
607}
608
609impl ToSource for LuaFunctionExpression {
610    fn to_source(&self, buffer: &mut SourceBuffer) {
611        buffer.push("function(");
612        for (i, param) in self.parameters.iter().enumerate() {
613            if i > 0 {
614                buffer.push(", ")
615            }
616            buffer.push(param)
617        }
618        if self.is_vararg {
619            if !self.parameters.is_empty() {
620                buffer.push(", ")
621            }
622            buffer.push("...")
623        }
624        buffer.push(")\n");
625        for stmt in &self.block {
626            stmt.to_source(buffer);
627            buffer.push("\n")
628        }
629        buffer.push("end")
630    }
631}
632
633/// Index access expression
634#[derive(Debug, Clone)]
635#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
636pub struct LuaIndexExpression {
637    /// The table being indexed.
638    pub table: LuaExpression,
639    /// The index expression.
640    pub index: LuaExpression,
641}
642
643impl ToSource for LuaIndexExpression {
644    fn to_source(&self, buffer: &mut SourceBuffer) {
645        self.table.to_source(buffer);
646        buffer.push("[");
647        self.index.to_source(buffer);
648        buffer.push("]")
649    }
650}
651
652/// Member access expression
653#[derive(Debug, Clone)]
654#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
655pub struct LuaMemberExpression {
656    /// The table whose member is being accessed.
657    pub table: LuaExpression,
658    /// The member name.
659    pub member: String,
660    /// Whether this is a method call (using `:`).
661    pub is_method: bool,
662}
663
664impl ToSource for LuaMemberExpression {
665    fn to_source(&self, buffer: &mut SourceBuffer) {
666        self.table.to_source(buffer);
667        if self.is_method {
668            buffer.push(":")
669        }
670        else {
671            buffer.push(".")
672        }
673        buffer.push(&self.member)
674    }
675}