Skip to main content

mimium_lang/
ast.rs

1pub mod builder;
2pub mod operators;
3pub mod program;
4mod resolve_include;
5pub mod statement;
6use serde::{Deserialize, Serialize};
7
8use crate::ast::operators::Op;
9use crate::ast::program::QualifiedPath;
10use crate::interner::{ExprNodeId, Symbol, TypeNodeId, with_session_globals};
11use crate::pattern::{TypedId, TypedPattern};
12use crate::utils::metadata::{Location, Span};
13use crate::utils::miniprint::MiniPrint;
14use std::fmt::{self};
/// Time value used by the AST, represented as a signed 64-bit integer.
///
/// NOTE(review): the unit (e.g. samples) is not specified in this module — confirm.
pub type Time = i64;
16
17#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
18pub enum StageKind {
19    Persistent = -1,
20    Macro = 0,
21    Main,
22}
23impl std::fmt::Display for StageKind {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            StageKind::Persistent => write!(f, "persistent"),
27            StageKind::Macro => write!(f, "macro"),
28            StageKind::Main => write!(f, "main"),
29        }
30    }
31}
32#[derive(Clone, Debug, PartialEq, Hash, Serialize, Deserialize)]
33pub enum Literal {
34    String(Symbol),
35    Int(i64),
36    Float(Symbol),
37    SelfLit,
38    Now,
39    SampleRate,
40    PlaceHolder,
41}
42
43impl Expr {
44    fn into_id_inner(self, loc: Option<Location>) -> ExprNodeId {
45        let loc = loc.unwrap_or_default();
46        with_session_globals(|session_globals| session_globals.store_expr_with_location(self, loc))
47    }
48
49    pub fn into_id(self, loc: Location) -> ExprNodeId {
50        self.into_id_inner(Some(loc))
51    }
52
53    // For testing purposes
54    pub fn into_id_without_span(self) -> ExprNodeId {
55        self.into_id_inner(None)
56    }
57}
58
59#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
60pub struct RecordField {
61    pub name: Symbol,
62    pub expr: ExprNodeId,
63}
64
65/// Pattern for match expressions
66
67#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
68pub enum MatchPattern {
69    /// Literal pattern: matches a specific value (e.g., 0, 1, 2)
70    Literal(Literal),
71    /// Wildcard pattern: matches anything (_)
72    Wildcard,
73    /// Variable binding pattern: binds a value to a name
74    Variable(Symbol),
75    /// Constructor pattern for union types: TypeName(inner_pattern)
76    /// e.g., Float(x), String(s), Two((x, y))
77    /// The Symbol is the type/constructor name, the Option<Box<MatchPattern>> is the optional inner pattern
78    Constructor(Symbol, Option<Box<MatchPattern>>),
79    /// Tuple pattern: matches a tuple and binds its elements
80    /// e.g., (x, y), (a, b, c)
81    Tuple(Vec<MatchPattern>),
82}
83
84/// A single arm of a match expression
85#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
86pub struct MatchArm {
87    pub pattern: MatchPattern,
88    pub body: ExprNodeId,
89}
90
91#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
92pub enum Expr {
93    Literal(Literal), // literal, or special symbols (self, now, _)
94    Var(Symbol),
95    QualifiedVar(QualifiedPath), // qualified name like modA::funcB
96    Block(Option<ExprNodeId>),
97    Tuple(Vec<ExprNodeId>),
98    Proj(ExprNodeId, i64),
99    ArrayAccess(ExprNodeId, ExprNodeId),
100    ArrayLiteral(Vec<ExprNodeId>),   // Array literal [e1, e2, ..., en]
101    RecordLiteral(Vec<RecordField>), // Complete record literal {field1 = expr1, field2 = expr2, ...}
102    ImcompleteRecord(Vec<RecordField>), // Incomplete record literal with default values {field1 = expr1, ..}
103    RecordUpdate(ExprNodeId, Vec<RecordField>), // Record update syntax: { record <- field1 = expr1, field2 = expr2, ... }
104    FieldAccess(ExprNodeId, Symbol),            // Record field access: record.field
105    Apply(ExprNodeId, Vec<ExprNodeId>),
106
107    MacroExpand(ExprNodeId, Vec<ExprNodeId>), // syntax sugar: hoge!(a,b) => ${hoge(a,b)}
108    BinOp(ExprNodeId, (Op, Span), ExprNodeId), // syntax sugar: LHS op RHS =>  OP(LHS, RHS) except for pipe operator : RHS(LHS)
109    UniOp((Op, Span), ExprNodeId), // syntax sugar: LHS op RHS =>  OP(LHS, RHS) except for pipe operator : RHS(LHS)
110    Paren(ExprNodeId),             // syntax sugar to preserve context for pretty printing
111
112    Lambda(Vec<TypedId>, Option<TypeNodeId>, ExprNodeId), //lambda, maybe information for internal state is needed
113    Assign(ExprNodeId, ExprNodeId),
114    Then(ExprNodeId, Option<ExprNodeId>),
115    Feed(Symbol, ExprNodeId), //feedback connection primitive operation. This will be shown only after self-removal stage
116    Let(TypedPattern, ExprNodeId, Option<ExprNodeId>),
117    LetRec(TypedId, ExprNodeId, Option<ExprNodeId>),
118    If(ExprNodeId, ExprNodeId, Option<ExprNodeId>),
119    Match(ExprNodeId, Vec<MatchArm>), // match expression: match scrutinee { pattern => expr, ... }
120    //exprimental macro system using multi-stage computation
121    Bracket(ExprNodeId),
122    Escape(ExprNodeId),
123
124    Error,
125}
126
127impl ExprNodeId {
128    /// Check whether the AST contains any staging constructs (Bracket, Escape,
129    /// or MacroExpand).  Programs without these are plain stage-1 code and do
130    /// not need to be wrapped / processed through the staging pipeline.
131    pub fn has_staging_constructs(self) -> bool {
132        self.has_staging_rec()
133    }
134
135    fn has_staging_rec(self) -> bool {
136        let conv = |e: &Self| e.has_staging_rec();
137        let conv_opt = |e: &Option<Self>| e.as_ref().is_some_and(|e| e.has_staging_rec());
138        let convvec = |es: &[Self]| es.iter().any(|e| e.has_staging_rec());
139        let convfields = |fs: &[RecordField]| fs.iter().any(|f| f.expr.has_staging_rec());
140        match self.to_expr() {
141            Expr::Bracket(_) | Expr::Escape(_) | Expr::MacroExpand(..) => true,
142            Expr::Proj(e, _)
143            | Expr::FieldAccess(e, _)
144            | Expr::UniOp(_, e)
145            | Expr::Paren(e)
146            | Expr::Lambda(_, _, e)
147            | Expr::Feed(_, e) => conv(&e),
148            Expr::ArrayAccess(e1, e2) | Expr::BinOp(e1, _, e2) | Expr::Assign(e1, e2) => {
149                conv(&e1) || conv(&e2)
150            }
151            Expr::Block(e) => conv_opt(&e),
152            Expr::Tuple(es) | Expr::ArrayLiteral(es) => convvec(&es),
153            Expr::RecordLiteral(fields) | Expr::ImcompleteRecord(fields) => convfields(&fields),
154            Expr::RecordUpdate(e, fields) => conv(&e) || convfields(&fields),
155            Expr::Apply(e, args) => conv(&e) || convvec(&args),
156            Expr::Then(e1, e2) | Expr::Let(_, e1, e2) | Expr::LetRec(_, e1, e2) => {
157                conv(&e1) || conv_opt(&e2)
158            }
159            Expr::If(cond, then, orelse) => conv(&cond) || conv(&then) || conv_opt(&orelse),
160            Expr::Match(scrutinee, arms) => {
161                conv(&scrutinee) || arms.iter().any(|arm| arm.body.has_staging_rec())
162            }
163            _ => false,
164        }
165    }
166
167    pub fn wrap_to_staged_expr(self) -> Self {
168        // TODO: what if more escape is used than minimum level??
169
170        // let min_level = self.get_min_stage_rec(0);
171        // let res = if min_level < 0 {
172        //     std::iter::repeat_n((), -min_level as usize).fold(self, |wrapped, _level| {
173        //         Expr::Bracket(wrapped).into_id_without_span()
174        //     })
175        // } else {
176        //     self
177        // };
178        //we have to wrap one more time because if there are no macro-related expression, that means stage-1(runtime) code.
179        Expr::Bracket(self).into_id_without_span()
180    }
181    fn get_min_stage_rec(self, current_level: i32) -> i32 {
182        let conv = move |e: &Self| e.get_min_stage_rec(current_level);
183        let conv2 = move |e1: &Self, e2: &Self| {
184            e1.get_min_stage_rec(current_level)
185                .min(e2.get_min_stage_rec(current_level))
186        };
187        let conv_opt = move |e: &Option<Self>| {
188            e.as_ref()
189                .map_or(current_level, |e| e.get_min_stage_rec(current_level))
190        };
191        let convvec = move |es: &[Self]| es.iter().map(conv).min().unwrap_or(current_level);
192        let convfields = move |fs: &[RecordField]| {
193            fs.iter()
194                .map(|f| f.expr.get_min_stage_rec(current_level))
195                .min()
196                .unwrap_or(current_level)
197        };
198        match self.to_expr() {
199            Expr::Bracket(e) => e.get_min_stage_rec(current_level + 1),
200            Expr::Escape(e) => e.get_min_stage_rec(current_level - 1),
201            Expr::MacroExpand(e, args) => conv(&e).min(convvec(&args)) - 1,
202            Expr::Proj(e, _)
203            | Expr::FieldAccess(e, _)
204            | Expr::UniOp(_, e)
205            | Expr::Paren(e)
206            | Expr::Lambda(_, _, e)
207            | Expr::Feed(_, e) => conv(&e),
208            Expr::ArrayAccess(e1, e2) | Expr::BinOp(e1, _, e2) | Expr::Assign(e1, e2) => {
209                conv2(&e1, &e2)
210            }
211            Expr::Block(e) => conv_opt(&e),
212            Expr::Tuple(es) | Expr::ArrayLiteral(es) => convvec(&es),
213
214            Expr::RecordLiteral(fields) | Expr::ImcompleteRecord(fields) => convfields(&fields),
215            Expr::RecordUpdate(e1, fields) => conv(&e1).min(convfields(&fields)),
216            Expr::Apply(e, args) => conv(&e).min(convvec(&args)),
217            Expr::Then(e1, e2) | Expr::Let(_, e1, e2) | Expr::LetRec(_, e1, e2) => {
218                conv(&e1).min(conv_opt(&e2))
219            }
220            Expr::If(cond, then, orelse) => conv(&cond).min(conv(&then)).min(conv_opt(&orelse)),
221            Expr::Match(scrutinee, arms) => {
222                let arm_min = arms
223                    .iter()
224                    .map(|arm| arm.body.get_min_stage_rec(current_level))
225                    .min()
226                    .unwrap_or(current_level);
227                conv(&scrutinee).min(arm_min)
228            }
229
230            _ => current_level,
231        }
232    }
233}
234
235impl fmt::Display for Literal {
236    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
237        match self {
238            Literal::Float(n) => write!(f, "(float {n})"),
239            Literal::Int(n) => write!(f, "(int {n})"),
240            Literal::String(s) => write!(f, "\"{s}\""),
241            Literal::Now => write!(f, "now"),
242            Literal::SampleRate => write!(f, "samplerate"),
243            Literal::SelfLit => write!(f, "self"),
244            Literal::PlaceHolder => write!(f, "_"),
245        }
246    }
247}
248
249impl MiniPrint for Literal {
250    fn simple_print(&self) -> String {
251        self.to_string()
252    }
253}
254
255fn concat_vec<T: MiniPrint>(vec: &[T]) -> String {
256    vec.iter()
257        .map(|t| t.simple_print())
258        .collect::<Vec<_>>()
259        .join(" ")
260}
261
262impl MiniPrint for ExprNodeId {
263    fn simple_print(&self) -> String {
264        let span = self.to_span();
265        format!(
266            "{}:{}..{}",
267            self.to_expr().simple_print(),
268            span.start,
269            span.end
270        )
271    }
272}
273
274impl MiniPrint for Option<ExprNodeId> {
275    fn simple_print(&self) -> String {
276        match self {
277            Some(e) => e.simple_print(),
278            None => "()".to_string(),
279        }
280    }
281}
282
283impl MiniPrint for RecordField {
284    fn simple_print(&self) -> String {
285        format!("{}: {}", self.name, self.expr.simple_print())
286    }
287}
288
289impl MiniPrint for Expr {
290    fn simple_print(&self) -> String {
291        match self {
292            Expr::Literal(l) => l.simple_print(),
293            Expr::Var(v) => format!("{v}"),
294            Expr::QualifiedVar(path) => path
295                .segments
296                .iter()
297                .map(|s| s.to_string())
298                .collect::<Vec<_>>()
299                .join("::"),
300            Expr::Block(e) => e.map_or("".to_string(), |eid| {
301                format!("(block {})", eid.simple_print())
302            }),
303            Expr::Tuple(e) => {
304                let e1 = e.iter().map(|e| e.to_expr().clone()).collect::<Vec<Expr>>();
305                format!("(tuple ({}))", concat_vec(&e1))
306            }
307            Expr::Proj(e, idx) => format!("(proj {} {})", e.simple_print(), idx),
308            Expr::Apply(e1, e2) => {
309                format!("(app {} ({}))", e1.simple_print(), concat_vec(e2))
310            }
311            Expr::MacroExpand(e1, e2s) => {
312                format!("(macro {} ({}))", e1.simple_print(), concat_vec(e2s))
313            }
314            Expr::ArrayAccess(e, i) => {
315                format!("(arrayaccess {} ({}))", e.simple_print(), i.simple_print())
316            }
317            Expr::ArrayLiteral(items) => {
318                let items_str = items
319                    .iter()
320                    .map(|e| e.simple_print())
321                    .collect::<Vec<String>>()
322                    .join(", ");
323                format!("(array [{items_str}])")
324            }
325            Expr::RecordLiteral(fields) => {
326                let fields_str = fields
327                    .iter()
328                    .map(|f| f.simple_print())
329                    .collect::<Vec<String>>()
330                    .join(", ");
331                format!("(record {{{fields_str}}})")
332            }
333            Expr::ImcompleteRecord(fields) => {
334                let fields_str = fields
335                    .iter()
336                    .map(|f| f.simple_print())
337                    .collect::<Vec<String>>()
338                    .join(", ");
339                format!("(incomplete-record {{{fields_str}, ..}})")
340            }
341            Expr::RecordUpdate(record, fields) => {
342                let fields_str = fields
343                    .iter()
344                    .map(|f| f.simple_print())
345                    .collect::<Vec<String>>()
346                    .join(", ");
347                format!(
348                    "(record-update {} {{{}}})",
349                    record.simple_print(),
350                    fields_str
351                )
352            }
353            Expr::FieldAccess(record, field) => {
354                format!("(field-access {} {})", record.simple_print(), field)
355            }
356            Expr::UniOp(op, expr) => {
357                format!("(unary {} {})", op.0, expr.simple_print())
358            }
359            Expr::BinOp(lhs, op, rhs) => {
360                format!(
361                    "(binop {} {} {})",
362                    op.0,
363                    lhs.simple_print(),
364                    rhs.simple_print()
365                )
366            }
367            Expr::Lambda(params, _, body) => {
368                format!("(lambda ({}) {})", concat_vec(params), body.simple_print())
369            }
370            Expr::Feed(id, body) => format!("(feed {} {})", id, body.simple_print()),
371            Expr::Let(id, body, then) => format!(
372                "(let {} {} {})",
373                id.simple_print(),
374                body.simple_print(),
375                then.simple_print()
376            ),
377            Expr::LetRec(id, body, then) => format!(
378                "(letrec {} {} {})",
379                &id.simple_print(),
380                body.simple_print(),
381                then.simple_print()
382            ),
383            Expr::Assign(lid, rhs) => {
384                format!("(assign {} {})", lid.simple_print(), rhs.simple_print())
385            }
386            Expr::Then(first, second) => {
387                format!("(then {} {})", first.simple_print(), second.simple_print())
388            }
389            Expr::If(cond, then, optelse) => format!(
390                "(if {} {} {})",
391                cond.simple_print(),
392                then.simple_print(),
393                optelse.simple_print()
394            ),
395            Expr::Match(scrutinee, arms) => {
396                let arms_str = arms
397                    .iter()
398                    .map(|arm| format!("{:?} => {}", arm.pattern, arm.body.simple_print()))
399                    .collect::<Vec<_>>()
400                    .join(", ");
401                format!("(match {} [{}])", scrutinee.simple_print(), arms_str)
402            }
403            Expr::Bracket(e) => format!("(bracket {})", e.simple_print()),
404            Expr::Escape(e) => format!("(escape {})", e.simple_print()),
405            Expr::Error => "(error)".to_string(),
406            Expr::Paren(expr_node_id) => format!("(paren {})", expr_node_id.simple_print()),
407        }
408    }
409}