rumoca/ir/
ast.rs

//! This module defines the Abstract Syntax Tree (AST) and Intermediate Representation (IR)
//! structures for Modelica models. It provides a comprehensive set of data structures to
//! represent the components, expressions, equations, and statements of the language. The
//! module also includes serialization and deserialization support via `serde`, plus custom
//! implementations of the `Debug` and `Display` traits for more readable debugging output
//! and formatting.
7//!
8//! # Key Structures
9//!
10//! - **Location**: Represents the location of a token or element in the source file, including
11//!   line and column numbers.
12//! - **Token**: Represents a lexical token with its text, location, type, and number.
13//! - **Name**: Represents a hierarchical name composed of multiple tokens.
14//! - **StoredDefinition**: Represents a collection of class definitions and an optional
15//!   "within" clause.
16//! - **Component**: Represents a component with its name, type, variability, causality,
17//!   connection, description, and initial value.
18//! - **ClassDefinition**: Represents a class definition with its name, components, equations,
19//!   and algorithms.
20//! - **ComponentReference**: Represents a reference to a component, including its parts and
21//!   optional subscripts.
22//! - **Equation**: Represents various types of equations, such as simple equations, connect
23//!   equations, and conditional equations.
24//! - **Expression**: Represents various types of expressions, including binary, unary,
25//!   terminal, and function call expressions.
26//! - **Statement**: Represents various types of statements, such as assignments, loops, and
27//!   function calls.
28//!
29//! # Enums
30//!
31//! - **OpBinary**: Represents binary operators like addition, subtraction, multiplication, etc.
32//! - **OpUnary**: Represents unary operators like negation and logical NOT.
33//! - **TerminalType**: Represents the type of a terminal expression, such as real, integer,
34//!   string, or boolean.
35//! - **Variability**: Represents the variability of a component (e.g., constant, discrete,
36//!   parameter).
37//! - **Connection**: Represents the connection type of a component (e.g., flow, stream).
38//! - **Causality**: Represents the causality of a component (e.g., input, output).
39//!
//! This module is designed to be extensible and serves as the foundation for parsing,
//! analyzing, and generating code from Modelica models.
42use indexmap::IndexMap;
43use serde::{Deserialize, Serialize};
44use std::{fmt::Debug, fmt::Display};
45
/// Source span of a token or syntax element.
///
/// NOTE(review): `start`/`end` look like absolute offsets into the source text;
/// whether they count bytes or chars (and whether lines/columns are 0- or
/// 1-based) is not visible here — confirm against the lexer.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Location {
    /// Line on which the element starts.
    pub start_line: u32,
    /// Column on which the element starts.
    pub start_column: u32,
    /// Line on which the element ends.
    pub end_line: u32,
    /// Column on which the element ends.
    pub end_column: u32,
    /// Start offset of the element (units: see note above).
    pub start: u32,
    /// End offset of the element (units: see note above).
    pub end: u32,
    /// Name of the source file the element came from.
    pub file_name: String,
}
56
/// A lexical token: its text together with where it was found.
///
/// `Debug` for this type is implemented manually and prints only `text`.
#[derive(Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
    /// Text of the token.
    pub text: String,
    /// Source span of the token.
    pub location: Location,
    /// Token number assigned by the lexer.
    /// NOTE(review): sequence index vs. some other id is not visible here — confirm.
    pub token_number: u32,
    /// Numeric token kind.
    /// NOTE(review): presumably the lexer/grammar terminal id — confirm.
    pub token_type: u16,
}
65
66impl Debug for Token {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(f, "{:?}", self.text)
69    }
70}
71
/// A hierarchical, dot-separated name such as `A.B.C`, stored as one token per
/// segment.
///
/// `Display` joins the segment texts with `.`; `Debug` prints that same joined
/// string in quotes.
#[derive(Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Name {
    /// Name segments, in the order they are joined for display.
    pub name: Vec<Token>,
}
77
78impl Display for Name {
79    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
80        let mut s = Vec::new();
81        for n in &self.name {
82            s.push(n.text.clone());
83        }
84        write!(f, "{}", s.join("."))
85    }
86}
87
88impl Debug for Name {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        let mut s = Vec::new();
91        for n in &self.name {
92            s.push(n.text.clone());
93        }
94        write!(f, "{:?}", s.join("."))
95    }
96}
97
/// A stored definition: a set of top-level class definitions plus the optional
/// `within` clause.
///
/// NOTE(review): presumably one of these per parsed source file — confirm.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct StoredDefinition {
    /// Top-level class definitions; `IndexMap` preserves insertion order.
    /// NOTE(review): keys are presumably the class names — confirm against the parser.
    pub class_list: IndexMap<String, ClassDefinition>,
    /// Name given in the `within` clause, if any.
    pub within: Option<Name>,
}
104
/// A component declaration inside a class (e.g. `parameter Real x = 1;` or
/// `Resistor R1(R=10);`).
///
/// `Debug` for this type is implemented manually and prints only a subset of
/// the fields.
#[derive(Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct Component {
    /// Declared component name.
    pub name: String,
    /// The token for the component name with exact source location
    pub name_token: Token,
    /// Name of the component's declared type.
    pub type_name: Name,
    /// Variability prefix of the declaration.
    pub variability: Variability,
    /// Causality prefix of the declaration.
    pub causality: Causality,
    /// Connection prefix of the declaration.
    pub connection: Connection,
    /// Description string tokens attached to the declaration.
    pub description: Vec<Token>,
    /// Start/binding value expression; see `start_is_modification` for which form it came from.
    pub start: Expression,
    /// True if start value is from a modification (start=x), false if from binding (= x)
    pub start_is_modification: bool,
    /// True if the start modifier has `each` prefix (for array components)
    pub start_has_each: bool,
    /// Array dimensions - empty for scalars, e.g., [2, 3] for a 2x3 matrix
    /// Only populated when dimensions are known literal integers
    pub shape: Vec<usize>,
    /// Raw array dimension subscripts (e.g., `n` in `Real x[n]` or `:` in `Real a[:]`)
    /// Used for parameter-dependent dimensions that need runtime evaluation
    /// Can be Subscript::Expression for explicit dimensions or Subscript::Range for `:`
    pub shape_expr: Vec<Subscript>,
    /// True if shape is from a modification (shape=x), false if from subscript \[x\]
    pub shape_is_modification: bool,
    /// Annotation arguments (e.g., from `annotation(Icon(...), Dialog(...))`)
    pub annotation: Vec<Expression>,
    /// Component modifications (e.g., R=10 in `Resistor R1(R=10)`)
    /// Maps parameter name to its modified value expression
    pub modifications: IndexMap<String, Expression>,
    /// Full source location for the component declaration
    pub location: Location,
    /// Conditional component expression (e.g., `if use_reset` in `BooleanInput reset if use_reset`)
    /// None means the component is unconditional (always present)
    pub condition: Option<Expression>,
    /// True if declared with 'inner' prefix (provides instance to outer references)
    pub inner: bool,
    /// True if declared with 'outer' prefix (references an inner instance from enclosing scope)
    pub outer: bool,
    /// Set of attribute names that are marked as final (e.g., "start" if `final start = 1.0`)
    /// When a derived class tries to override these attributes, an error should be raised
    /// NOTE(review): `HashSet` iteration order is unspecified, so the `Debug` and
    /// serialized output of this field can differ between runs.
    pub final_attributes: std::collections::HashSet<String>,
    /// True if this component is declared in a protected section
    pub is_protected: bool,
}
150
151impl Debug for Component {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        let mut builder = f.debug_struct("Component");
154        builder
155            .field("name", &self.name)
156            .field("type_name", &self.type_name);
157        if self.variability != Variability::Empty {
158            builder.field("variability", &self.variability);
159        }
160        if self.causality != Causality::Empty {
161            builder.field("causality", &self.causality);
162        }
163        if self.connection != Connection::Empty {
164            builder.field("connection", &self.connection);
165        }
166        if !self.description.is_empty() {
167            builder.field("description", &self.description);
168        }
169        if !self.shape.is_empty() {
170            builder.field("shape", &self.shape);
171        }
172        if !self.shape_expr.is_empty() {
173            builder.field("shape_expr", &self.shape_expr);
174        }
175        if !self.annotation.is_empty() {
176            builder.field("annotation", &self.annotation);
177        }
178        if !self.modifications.is_empty() {
179            builder.field("modifications", &self.modifications);
180        }
181        if self.condition.is_some() {
182            builder.field("condition", &self.condition);
183        }
184        if self.inner {
185            builder.field("inner", &self.inner);
186        }
187        if self.outer {
188            builder.field("outer", &self.outer);
189        }
190        if !self.final_attributes.is_empty() {
191            builder.field("final_attributes", &self.final_attributes);
192        }
193        if self.is_protected {
194            builder.field("is_protected", &self.is_protected);
195        }
196        builder.finish()
197    }
198}
199
/// Enumeration literal with optional description
/// (one entry of `enumeration(...)` in an enumeration type definition).
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct EnumLiteral {
    /// The literal identifier (e.g., 'U' or Red)
    pub ident: Token,
    /// Optional description strings (e.g., "Uninitialized")
    pub description: Vec<Token>,
}
208
/// Type of class (model, function, connector, etc.)
///
/// Records which class-prefix keyword a `ClassDefinition` was declared with
/// (the keyword token itself is kept in `ClassDefinition::class_type_token`).
/// `Model` is the default.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum ClassType {
    #[default]
    Model,
    Class,
    Block,
    Connector,
    Record,
    Type,
    Package,
    Function,
    Operator,
}
223
/// A class definition (`model`, `function`, `package`, ...) and everything
/// declared inside it.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct ClassDefinition {
    /// Token for the class name.
    pub name: Token,
    /// Kind of class (see `ClassType`).
    pub class_type: ClassType,
    /// Token for the class type keyword (model, class, function, etc.)
    pub class_type_token: Token,
    /// True if the class is declared with the `encapsulated` keyword
    pub encapsulated: bool,
    /// True if the class is declared with the `partial` keyword
    pub partial: bool,
    /// Causality from type alias definition (e.g., `connector RealInput = input Real`)
    /// Components of this type inherit this causality
    pub causality: Causality,
    /// Description string for this class (e.g., "A test model")
    pub description: Vec<Token>,
    /// Full source location spanning from class keyword to end statement
    pub location: Location,
    /// `extends` clauses of this class.
    pub extends: Vec<Extend>,
    /// Import clauses of this class.
    pub imports: Vec<Import>,
    /// Nested class definitions (functions, models, packages, etc.)
    pub classes: IndexMap<String, ClassDefinition>,
    /// Component declarations keyed by component name (insertion-ordered).
    pub components: IndexMap<String, Component>,
    /// Equations from `equation` sections.
    pub equations: Vec<Equation>,
    /// Equations from `initial equation` sections.
    pub initial_equations: Vec<Equation>,
    /// Algorithm blocks; each inner `Vec` holds the statements of one block.
    pub algorithms: Vec<Vec<Statement>>,
    /// Initial algorithm blocks; each inner `Vec` holds the statements of one block.
    pub initial_algorithms: Vec<Vec<Statement>>,
    /// Token for "equation" keyword (if present)
    pub equation_keyword: Option<Token>,
    /// Token for "initial equation" keyword (if present)
    pub initial_equation_keyword: Option<Token>,
    /// Token for "algorithm" keyword (if present)
    pub algorithm_keyword: Option<Token>,
    /// Token for "initial algorithm" keyword (if present)
    pub initial_algorithm_keyword: Option<Token>,
    /// Token for the class name in "end ClassName;" (for rename support)
    pub end_name_token: Option<Token>,
    /// Enumeration literals for enum types (e.g., `type MyEnum = enumeration(A "desc", B, C)`)
    pub enum_literals: Vec<EnumLiteral>,
    /// Annotation clause for this class (e.g., Documentation, Icon, Diagram)
    pub annotation: Vec<Expression>,
}
265
266impl ClassDefinition {
267    /// Iterate over all component declarations with their names.
268    ///
269    /// This provides a convenient way to iterate over components without
270    /// directly accessing the `components` field.
271    pub fn iter_components(&self) -> impl Iterator<Item = (&str, &Component)> {
272        self.components
273            .iter()
274            .map(|(name, comp)| (name.as_str(), comp))
275    }
276
277    /// Iterate over all nested class definitions with their names.
278    ///
279    /// This includes functions, models, packages, types, etc. defined within this class.
280    pub fn iter_classes(&self) -> impl Iterator<Item = (&str, &ClassDefinition)> {
281        self.classes
282            .iter()
283            .map(|(name, class)| (name.as_str(), class))
284    }
285
286    /// Iterate over all equations (regular + initial).
287    ///
288    /// This chains `equations` and `initial_equations` into a single iterator.
289    pub fn iter_all_equations(&self) -> impl Iterator<Item = &Equation> {
290        self.equations.iter().chain(self.initial_equations.iter())
291    }
292
293    /// Iterate over all statements from all algorithm sections (regular + initial).
294    ///
295    /// This flattens all algorithm blocks and chains regular and initial algorithms.
296    pub fn iter_all_statements(&self) -> impl Iterator<Item = &Statement> {
297        self.algorithms
298            .iter()
299            .flatten()
300            .chain(self.initial_algorithms.iter().flatten())
301    }
302}
303
/// An `extends` clause inheriting from another class.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Extend {
    /// Name of the class being extended.
    pub comp: Name,
    /// Source location of the extends clause
    pub location: Location,
    /// Modifications applied to the extends clause (e.g., extends Foo(bar=1))
    pub modifications: Vec<Expression>,
}
313
/// Import clause for bringing names into scope
/// Modelica supports several import styles:
/// - `import A.B.C;` - qualified import (use as C)
/// - `import D = A.B.C;` - renamed import (use as D)
/// - `import A.B.*;` - unqualified import (all names from A.B)
/// - `import A.B.{C, D, E};` - selective import (specific names)
///
/// Every variant carries the imported `path` and the clause's `location`;
/// use [`Import::base_path`] and [`Import::location`] to read them uniformly.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Import {
    /// Qualified import: `import A.B.C;` - imports C, accessed as C
    Qualified { path: Name, location: Location },
    /// Renamed import: `import D = A.B.C;` - imports C, accessed as D
    Renamed {
        /// The local alias (`D` in `import D = A.B.C;`).
        alias: Token,
        path: Name,
        location: Location,
    },
    /// Unqualified import: `import A.B.*;` - imports all from A.B
    Unqualified { path: Name, location: Location },
    /// Selective import: `import A.B.{C, D};` - imports specific names
    Selective {
        path: Name,
        /// The names listed inside the braces.
        names: Vec<Token>,
        location: Location,
    },
}
339
340impl Import {
341    /// Get the base path for this import
342    pub fn base_path(&self) -> &Name {
343        match self {
344            Import::Qualified { path, .. } => path,
345            Import::Renamed { path, .. } => path,
346            Import::Unqualified { path, .. } => path,
347            Import::Selective { path, .. } => path,
348        }
349    }
350
351    /// Get the source location of this import
352    pub fn location(&self) -> &Location {
353        match self {
354            Import::Qualified { location, .. } => location,
355            Import::Renamed { location, .. } => location,
356            Import::Unqualified { location, .. } => location,
357            Import::Selective { location, .. } => location,
358        }
359    }
360}
361
/// One dot-separated segment of a component reference, e.g. `b[1, 2]` in
/// `a.b[1, 2].c`.
#[derive(Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct ComponentRefPart {
    /// Identifier of this segment.
    pub ident: Token,
    /// Subscripts of this segment. `None` is displayed without brackets;
    /// `Some` (even an empty list) is displayed as `[...]`.
    pub subs: Option<Vec<Subscript>>,
}
368
369impl Debug for ComponentRefPart {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        // Use Display for debug to keep formatting consistent
372        write!(f, "{}", self)
373    }
374}
375
376impl Display for ComponentRefPart {
377    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378        write!(f, "{}", self.ident.text)?;
379        if let Some(subs) = &self.subs {
380            write!(f, "[")?;
381            for (i, sub) in subs.iter().enumerate() {
382                if i > 0 {
383                    write!(f, ", ")?;
384                }
385                write!(f, "{}", sub)?;
386            }
387            write!(f, "]")?;
388        }
389        Ok(())
390    }
391}
392
/// A (possibly dotted) reference to a component, e.g. `a.b[1].c`.
///
/// NOTE(review): the meaning of `local` is not established in this file, and
/// the `Display` impl does not render it — confirm against the parser.
#[derive(Default, Clone, PartialEq, Serialize, Deserialize)]
pub struct ComponentReference {
    pub local: bool,
    /// The dot-separated segments, in the order they are displayed.
    pub parts: Vec<ComponentRefPart>,
}
399
400impl Display for ComponentReference {
401    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
402        for (i, part) in self.parts.iter().enumerate() {
403            if i > 0 {
404                write!(f, ".")?;
405            }
406            write!(f, "{}", part)?;
407        }
408        Ok(())
409    }
410}
411
412impl Debug for ComponentReference {
413    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
414        // Use Display for debug to keep formatting consistent
415        write!(f, "{}", self)
416    }
417}
418
419impl ComponentReference {
420    /// Get the source location of the first token in this component reference.
421    pub fn get_location(&self) -> Option<&Location> {
422        self.parts.first().map(|part| &part.ident.location)
423    }
424}
425
/// A condition together with the equations it guards; used for the branches of
/// `Equation::If` and `Equation::When`.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EquationBlock {
    /// Branch condition.
    pub cond: Expression,
    /// Equations active when `cond` holds.
    pub eqs: Vec<Equation>,
}
432
/// A condition together with the statements it guards.
///
/// NOTE(review): presumably the statement-level counterpart of `EquationBlock`
/// (if/when/while branches); `Statement` is defined outside this view — confirm.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StatementBlock {
    /// Branch condition.
    pub cond: Expression,
    /// Statements guarded by `cond`.
    pub stmts: Vec<Statement>,
}
439
/// A loop index binding `ident in range`, used by for-equations and array
/// comprehensions.
#[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ForIndex {
    /// The index variable.
    pub ident: Token,
    /// The expression the index ranges over.
    pub range: Expression,
}
446
/// An equation, as stored in a class's `equations` / `initial_equations`.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum Equation {
    /// No equation (the default placeholder).
    #[default]
    Empty,
    /// `lhs = rhs`
    Simple {
        lhs: Expression,
        rhs: Expression,
    },
    /// `connect(lhs, rhs)`
    Connect {
        lhs: ComponentReference,
        rhs: ComponentReference,
    },
    /// `for <indices> loop <equations> end for`
    For {
        indices: Vec<ForIndex>,
        equations: Vec<Equation>,
    },
    /// `when` / `elsewhen` branches; each block pairs a condition with its equations.
    When(Vec<EquationBlock>),
    /// `if` / `elseif` branches plus an optional `else` branch.
    If {
        cond_blocks: Vec<EquationBlock>,
        else_block: Option<Vec<Equation>>,
    },
    /// A function call used as an equation: `comp(args...)`.
    FunctionCall {
        comp: ComponentReference,
        args: Vec<Expression>,
    },
}
474
475impl Equation {
476    /// Get the source location of the first token in this equation.
477    /// Returns None for Empty equations.
478    pub fn get_location(&self) -> Option<&Location> {
479        match self {
480            Equation::Empty => None,
481            Equation::Simple { lhs, .. } => lhs.get_location(),
482            Equation::Connect { lhs, .. } => lhs.get_location(),
483            Equation::For { indices, .. } => indices.first().map(|i| &i.ident.location),
484            Equation::When(blocks) => blocks.first().and_then(|b| b.cond.get_location()),
485            Equation::If { cond_blocks, .. } => {
486                cond_blocks.first().and_then(|b| b.cond.get_location())
487            }
488            Equation::FunctionCall { comp, .. } => comp.get_location(),
489        }
490    }
491}
492
/// Binary operator; each variant carries the operator's source token.
/// The symbol shown on each variant is what its `Display` impl prints.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum OpBinary {
    /// No operator (the default placeholder); displays as an empty string.
    #[default]
    Empty,
    /// `+`
    Add(Token),
    /// `-`
    Sub(Token),
    /// `*`
    Mul(Token),
    /// `/`
    Div(Token),
    /// `==`
    Eq(Token),
    /// `<>`
    Neq(Token),
    /// `<`
    Lt(Token),
    /// `<=`
    Le(Token),
    /// `>`
    Gt(Token),
    /// `>=`
    Ge(Token),
    /// `and`
    And(Token),
    /// `or`
    Or(Token),
    /// `^`
    Exp(Token),
    /// `.+`
    AddElem(Token),
    /// `.-`
    SubElem(Token),
    /// `.*`
    MulElem(Token),
    /// `./`
    DivElem(Token),
    /// Assignment/modification operator `=` (not equality `==`)
    /// Used in named arguments and modifications like `annotation(HideResult=true)`
    Assign(Token),
}
518
519impl Display for OpBinary {
520    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
521        match self {
522            OpBinary::Empty => write!(f, ""),
523            OpBinary::Add(_) => write!(f, "+"),
524            OpBinary::Sub(_) => write!(f, "-"),
525            OpBinary::Mul(_) => write!(f, "*"),
526            OpBinary::Div(_) => write!(f, "/"),
527            OpBinary::Eq(_) => write!(f, "=="),
528            OpBinary::Neq(_) => write!(f, "<>"),
529            OpBinary::Lt(_) => write!(f, "<"),
530            OpBinary::Le(_) => write!(f, "<="),
531            OpBinary::Gt(_) => write!(f, ">"),
532            OpBinary::Ge(_) => write!(f, ">="),
533            OpBinary::And(_) => write!(f, "and"),
534            OpBinary::Or(_) => write!(f, "or"),
535            OpBinary::Exp(_) => write!(f, "^"),
536            OpBinary::AddElem(_) => write!(f, ".+"),
537            OpBinary::SubElem(_) => write!(f, ".-"),
538            OpBinary::MulElem(_) => write!(f, ".*"),
539            OpBinary::DivElem(_) => write!(f, "./"),
540            OpBinary::Assign(_) => write!(f, "="),
541        }
542    }
543}
544
/// Prefix (unary) operator; each variant carries the operator's source token.
/// The symbol shown on each variant is what its `Display` impl prints.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum OpUnary {
    /// No operator (the default placeholder); displays as an empty string.
    #[default]
    Empty,
    /// `-`
    Minus(Token),
    /// `+`
    Plus(Token),
    /// `.-`
    DotMinus(Token),
    /// `.+`
    DotPlus(Token),
    /// `not` (displayed with a trailing space)
    Not(Token),
}
555
556impl Display for OpUnary {
557    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
558        match self {
559            OpUnary::Empty => write!(f, ""),
560            OpUnary::Minus(_) => write!(f, "-"),
561            OpUnary::Plus(_) => write!(f, "+"),
562            OpUnary::DotMinus(_) => write!(f, ".-"),
563            OpUnary::DotPlus(_) => write!(f, ".+"),
564            OpUnary::Not(_) => write!(f, "not "),
565        }
566    }
567}
568
/// Kind of leaf token held by `Expression::Terminal`.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum TerminalType {
    /// Unclassified (the default placeholder).
    #[default]
    Empty,
    /// Unsigned real-number literal.
    UnsignedReal,
    /// Unsigned integer literal.
    UnsignedInteger,
    /// String literal; `Expression`'s `Display` wraps `token.text` in double quotes.
    String,
    /// Boolean literal.
    Bool,
    /// NOTE(review): presumably the `end` keyword used inside subscripts — confirm.
    End,
}
579
/// An expression node.
///
/// `Debug` and `Display` for this type are implemented manually.
#[derive(Default, Clone, PartialEq, Serialize, Deserialize)]
pub enum Expression {
    /// No expression (the default placeholder).
    #[default]
    Empty,
    /// Range `start:end`, or `start:step:end` when `step` is present.
    Range {
        start: Box<Expression>,
        step: Option<Box<Expression>>,
        end: Box<Expression>,
    },
    /// Prefix operator `op` applied to `rhs`.
    Unary {
        op: OpUnary,
        rhs: Box<Expression>,
    },
    /// Infix operation `lhs op rhs`.
    Binary {
        op: OpBinary,
        lhs: Box<Expression>,
        rhs: Box<Expression>,
    },
    /// Leaf token (literal etc.), classified by `terminal_type`.
    Terminal {
        terminal_type: TerminalType,
        token: Token,
    },
    /// Reference to a component, e.g. `a.b[1]`.
    ComponentReference(ComponentReference),
    /// Function call `comp(args...)`.
    FunctionCall {
        comp: ComponentReference,
        args: Vec<Expression>,
    },
    /// Array literal.
    Array {
        elements: Vec<Expression>,
        /// True if original syntax was `[a;b]` matrix notation, false for `{a,b}` array notation
        is_matrix: bool,
    },
    /// Tuple expression for multi-output function calls: (a, b) = func()
    Tuple {
        elements: Vec<Expression>,
    },
    /// If expression: if cond then expr elseif cond2 then expr2 else expr3
    If {
        /// List of (condition, expression) pairs for if and elseif branches
        branches: Vec<(Expression, Expression)>,
        /// The else branch expression
        else_branch: Box<Expression>,
    },
    /// Parenthesized expression to preserve explicit parentheses from source
    Parenthesized {
        inner: Box<Expression>,
    },
    /// Array comprehension: {expr for i in range}
    ArrayComprehension {
        expr: Box<Expression>,
        indices: Vec<ForIndex>,
    },
}
634
635impl Debug for Expression {
636    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
637        match self {
638            Expression::Empty => write!(f, "Empty"),
639            Expression::Range { start, step, end } => f
640                .debug_struct("Range")
641                .field("start", start)
642                .field("step", step)
643                .field("end", end)
644                .finish(),
645            Expression::ComponentReference(comp) => write!(f, "{:?}", comp),
646            Expression::FunctionCall { comp, args } => f
647                .debug_struct("FunctionCall")
648                .field("comp", comp)
649                .field("args", args)
650                .finish(),
651            Expression::Binary { op, lhs, rhs } => f
652                .debug_struct(&format!("{:?}", op))
653                .field("lhs", lhs)
654                .field("rhs", rhs)
655                .finish(),
656            Expression::Unary { op, rhs } => f
657                .debug_struct(&format!("{:?}", op))
658                .field("rhs", rhs)
659                .finish(),
660            Expression::Terminal {
661                terminal_type,
662                token,
663            } => write!(f, "{:?}({:?})", terminal_type, token),
664            Expression::Array { elements, .. } => f.debug_list().entries(elements.iter()).finish(),
665            Expression::Tuple { elements } => {
666                write!(f, "(")?;
667                for (i, e) in elements.iter().enumerate() {
668                    if i > 0 {
669                        write!(f, ", ")?;
670                    }
671                    write!(f, "{:?}", e)?;
672                }
673                write!(f, ")")
674            }
675            Expression::If {
676                branches,
677                else_branch,
678            } => {
679                write!(f, "if ")?;
680                for (i, (cond, expr)) in branches.iter().enumerate() {
681                    if i > 0 {
682                        write!(f, " elseif ")?;
683                    }
684                    write!(f, "{:?} then {:?}", cond, expr)?;
685                }
686                write!(f, " else {:?}", else_branch)
687            }
688            Expression::Parenthesized { inner } => {
689                write!(f, "({:?})", inner)
690            }
691            Expression::ArrayComprehension { expr, indices } => {
692                write!(f, "{{{{ {:?} for {:?} }}}}", expr, indices)
693            }
694        }
695    }
696}
697
698impl Expression {
699    /// Get the source location of the first token in this expression.
700    /// Returns None for Empty expressions.
701    pub fn get_location(&self) -> Option<&Location> {
702        match self {
703            Expression::Empty => None,
704            Expression::Range { start, .. } => start.get_location(),
705            Expression::Unary { rhs, .. } => rhs.get_location(),
706            Expression::Binary { lhs, .. } => lhs.get_location(),
707            Expression::Terminal { token, .. } => Some(&token.location),
708            Expression::ComponentReference(comp) => {
709                comp.parts.first().map(|part| &part.ident.location)
710            }
711            Expression::FunctionCall { comp, .. } => {
712                comp.parts.first().map(|part| &part.ident.location)
713            }
714            Expression::Array { elements, .. } => elements.first().and_then(|e| e.get_location()),
715            Expression::Tuple { elements } => elements.first().and_then(|e| e.get_location()),
716            Expression::If { branches, .. } => {
717                branches.first().and_then(|(cond, _)| cond.get_location())
718            }
719            Expression::Parenthesized { inner } => inner.get_location(),
720            Expression::ArrayComprehension { expr, .. } => expr.get_location(),
721        }
722    }
723}
724
725impl std::fmt::Display for Expression {
726    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
727        match self {
728            Expression::Empty => write!(f, ""),
729            Expression::Range { start, step, end } => {
730                if let Some(s) = step {
731                    write!(f, "{}:{}:{}", start, s, end)
732                } else {
733                    write!(f, "{}:{}", start, end)
734                }
735            }
736            Expression::Unary { op, rhs } => {
737                let op_str = match op {
738                    OpUnary::Minus(_) => "-",
739                    OpUnary::Plus(_) => "+",
740                    OpUnary::DotMinus(_) => ".-",
741                    OpUnary::DotPlus(_) => ".+",
742                    OpUnary::Not(_) => "not ",
743                    OpUnary::Empty => "",
744                };
745                write!(f, "{}{}", op_str, rhs)
746            }
747            Expression::Binary { op, lhs, rhs } => {
748                let op_str = match op {
749                    OpBinary::Add(_) => "+",
750                    OpBinary::Sub(_) => "-",
751                    OpBinary::Mul(_) => "*",
752                    OpBinary::Div(_) => "/",
753                    OpBinary::Eq(_) => "==",
754                    OpBinary::Neq(_) => "<>",
755                    OpBinary::Lt(_) => "<",
756                    OpBinary::Le(_) => "<=",
757                    OpBinary::Gt(_) => ">",
758                    OpBinary::Ge(_) => ">=",
759                    OpBinary::And(_) => "and",
760                    OpBinary::Or(_) => "or",
761                    OpBinary::Exp(_) => "^",
762                    OpBinary::AddElem(_) => ".+",
763                    OpBinary::SubElem(_) => ".-",
764                    OpBinary::MulElem(_) => ".*",
765                    OpBinary::DivElem(_) => "./",
766                    OpBinary::Assign(_) => "=",
767                    OpBinary::Empty => "?",
768                };
769                write!(f, "{} {} {}", lhs, op_str, rhs)
770            }
771            Expression::Terminal {
772                terminal_type,
773                token,
774            } => match terminal_type {
775                TerminalType::String => write!(f, "\"{}\"", token.text),
776                TerminalType::Bool => write!(f, "{}", token.text),
777                _ => write!(f, "{}", token.text),
778            },
779            Expression::ComponentReference(comp) => write!(f, "{}", comp),
780            Expression::FunctionCall { comp, args } => {
781                write!(f, "{}(", comp)?;
782                for (i, arg) in args.iter().enumerate() {
783                    if i > 0 {
784                        write!(f, ", ")?;
785                    }
786                    write!(f, "{}", arg)?;
787                }
788                write!(f, ")")
789            }
790            Expression::Array { elements, .. } => {
791                write!(f, "{{")?;
792                for (i, e) in elements.iter().enumerate() {
793                    if i > 0 {
794                        write!(f, ", ")?;
795                    }
796                    write!(f, "{}", e)?;
797                }
798                write!(f, "}}")
799            }
800            Expression::Tuple { elements } => {
801                write!(f, "(")?;
802                for (i, e) in elements.iter().enumerate() {
803                    if i > 0 {
804                        write!(f, ", ")?;
805                    }
806                    write!(f, "{}", e)?;
807                }
808                write!(f, ")")
809            }
810            Expression::If {
811                branches,
812                else_branch,
813            } => {
814                write!(f, "if ")?;
815                for (i, (cond, expr)) in branches.iter().enumerate() {
816                    if i > 0 {
817                        write!(f, " elseif ")?;
818                    }
819                    write!(f, "{} then {}", cond, expr)?;
820                }
821                write!(f, " else {}", else_branch)
822            }
823            Expression::Parenthesized { inner } => write!(f, "({})", inner),
824            Expression::ArrayComprehension { expr, indices } => {
825                write!(f, "{{ {} for ", expr)?;
826                for (i, idx) in indices.iter().enumerate() {
827                    if i > 0 {
828                        write!(f, ", ")?;
829                    }
830                    write!(f, "{} in {}", idx.ident.text, idx.range)?;
831                }
832                write!(f, " }}")
833            }
834        }
835    }
836}
837
838impl std::fmt::Display for Equation {
839    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
840        match self {
841            Equation::Empty => write!(f, ""),
842            Equation::Simple { lhs, rhs } => write!(f, "{} = {}", lhs, rhs),
843            Equation::Connect { lhs, rhs } => write!(f, "connect({}, {})", lhs, rhs),
844            Equation::For { indices, equations } => {
845                write!(f, "for ")?;
846                for (i, idx) in indices.iter().enumerate() {
847                    if i > 0 {
848                        write!(f, ", ")?;
849                    }
850                    write!(f, "{} in {}", idx.ident.text, idx.range)?;
851                }
852                writeln!(f, " loop")?;
853                for eq in equations {
854                    writeln!(f, "  {};", eq)?;
855                }
856                write!(f, "end for")
857            }
858            Equation::When(blocks) => {
859                for (i, block) in blocks.iter().enumerate() {
860                    if i == 0 {
861                        write!(f, "when {} then", block.cond)?;
862                    } else {
863                        write!(f, " elsewhen {} then", block.cond)?;
864                    }
865                    for eq in &block.eqs {
866                        write!(f, " {};", eq)?;
867                    }
868                }
869                write!(f, " end when")
870            }
871            Equation::If {
872                cond_blocks,
873                else_block,
874            } => {
875                for (i, block) in cond_blocks.iter().enumerate() {
876                    if i == 0 {
877                        write!(f, "if {} then", block.cond)?;
878                    } else {
879                        write!(f, " elseif {} then", block.cond)?;
880                    }
881                    for eq in &block.eqs {
882                        write!(f, " {};", eq)?;
883                    }
884                }
885                if let Some(else_eqs) = else_block {
886                    write!(f, " else")?;
887                    for eq in else_eqs {
888                        write!(f, " {};", eq)?;
889                    }
890                }
891                write!(f, " end if")
892            }
893            Equation::FunctionCall { comp, args } => {
894                write!(f, "{}(", comp)?;
895                for (i, arg) in args.iter().enumerate() {
896                    if i > 0 {
897                        write!(f, ", ")?;
898                    }
899                    write!(f, "{}", arg)?;
900                }
901                write!(f, ")")
902            }
903        }
904    }
905}
906
/// A single statement, e.g. an assignment, a loop, a conditional or a function call.
///
/// NOTE(review): variant order is kept stable on purpose. Index-based serde
/// formats encode enum variants by position, so reordering would change the
/// serialized representation.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]

pub enum Statement {
    /// No statement; this is the `Default` variant.
    #[default]
    Empty,
    /// Assignment of `value` to the component referenced by `comp`
    /// (written with `:=`, as in the `FunctionCall` example below).
    Assignment {
        comp: ComponentReference,
        value: Expression,
    },
    /// `return` statement. `token` presumably is the keyword token as it
    /// appeared in source (e.g. for diagnostics); confirm with the parser.
    Return {
        token: Token,
    },
    /// `break` statement. `token` presumably is the keyword token as it
    /// appeared in source; confirm with the parser.
    Break {
        token: Token,
    },
    /// `for` loop over `indices`.
    ///
    /// Despite its name, `equations` holds the loop-body *statements*
    /// (its element type is `Statement`).
    For {
        indices: Vec<ForIndex>,
        equations: Vec<Statement>,
    },
    /// `while` loop. The block presumably carries the loop condition and
    /// its body statements; see `StatementBlock` to confirm.
    While(StatementBlock),
    /// If statement: if cond then stmts elseif cond2 then stmts2 else stmts3
    ///
    /// `cond_blocks` holds the `if`/`elseif` branches in order.
    /// `else_block` is `None` when there is no `else` branch.
    If {
        cond_blocks: Vec<StatementBlock>,
        else_block: Option<Vec<Statement>>,
    },
    /// When statement: when cond then stmts elsewhen cond2 then stmts2
    ///
    /// The blocks are the `when`/`elsewhen` branches in order.
    When(Vec<StatementBlock>),
    /// Function call statement, optionally with output assignments
    /// For `(a, b) := func(x)`, outputs contains [a, b]
    FunctionCall {
        /// The function being called.
        comp: ComponentReference,
        /// Positional call arguments.
        args: Vec<Expression>,
        /// Output variables being assigned (for `(a, b) := func(x)` style calls)
        outputs: Vec<Expression>,
    },
}
943
/// One subscript inside an array index such as `x[i, :]`.
///
/// Its `Display` impl renders `Empty` as nothing, `Expression` as the
/// wrapped expression, and `Range` as `:`.
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]

pub enum Subscript {
    /// No subscript; this is the `Default` variant.
    #[default]
    Empty,
    /// A subscript given by an arbitrary expression.
    Expression(Expression),
    /// The full-range subscript, displayed as `:`. `token` presumably is the
    /// `:` token from source; confirm with the parser.
    Range {
        token: Token,
    },
}
954
955impl Display for Subscript {
956    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
957        match self {
958            Subscript::Empty => write!(f, ""),
959            Subscript::Expression(expr) => write!(f, "{}", expr),
960            Subscript::Range { .. } => write!(f, ":"),
961        }
962    }
963}
964
/// Variability prefix of a component declaration.
///
/// Each non-empty variant carries a `Token`, presumably the prefix keyword
/// as it appeared in source (confirm with the parser).
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]

pub enum Variability {
    /// No variability prefix was given; this is the `Default` variant.
    #[default]
    Empty,
    /// Presumably the `constant` prefix.
    Constant(Token),
    /// Presumably the `discrete` prefix.
    Discrete(Token),
    /// Presumably the `parameter` prefix.
    Parameter(Token),
}
974
/// Connection prefix of a component declaration.
///
/// Each non-empty variant carries a `Token`, presumably the prefix keyword
/// as it appeared in source (confirm with the parser).
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]

pub enum Connection {
    /// No connection prefix was given; this is the `Default` variant.
    #[default]
    Empty,
    /// Presumably the `flow` prefix.
    Flow(Token),
    /// Presumably the `stream` prefix.
    Stream(Token),
}
983
/// Causality prefix of a component declaration.
///
/// Each non-empty variant carries a `Token`, presumably the prefix keyword
/// as it appeared in source (confirm with the parser).
#[derive(Debug, Default, Clone, PartialEq, Serialize, Deserialize)]

pub enum Causality {
    /// No causality prefix was given; this is the `Default` variant.
    #[default]
    Empty,
    /// Presumably the `input` prefix.
    Input(Token),
    /// Presumably the `output` prefix.
    Output(Token),
}