miden_assembly/ast/
constants.rs

1use alloc::{boxed::Box, string::String, sync::Arc};
2use core::fmt;
3
4use vm_core::FieldElement;
5
6use super::DocString;
7use crate::{Felt, SourceSpan, Span, Spanned, ast::Ident, parser::ParsingError};
8
9// CONSTANT
10// ================================================================================================
11
12/// Represents a constant definition in Miden Assembly syntax, i.e. `const.FOO = 1 + 1`.
13pub struct Constant {
14    /// The source span of the definition.
15    pub span: SourceSpan,
16    /// The documentation string attached to this definition.
17    pub docs: Option<DocString>,
18    /// The name of the constant.
19    pub name: Ident,
20    /// The expression associated with the constant.
21    pub value: ConstantExpr,
22}
23
24impl Constant {
25    /// Creates a new [Constant] from the given source span, name, and value.
26    pub fn new(span: SourceSpan, name: Ident, value: ConstantExpr) -> Self {
27        Self { span, docs: None, name, value }
28    }
29
30    /// Adds documentation to this constant declaration.
31    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
32        self.docs = docs.map(DocString::new);
33        self
34    }
35}
36
37impl fmt::Debug for Constant {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        f.debug_struct("Constant")
40            .field("docs", &self.docs)
41            .field("name", &self.name)
42            .field("value", &self.value)
43            .finish()
44    }
45}
46
47impl crate::prettier::PrettyPrint for Constant {
48    fn render(&self) -> crate::prettier::Document {
49        use crate::prettier::*;
50
51        let mut doc = self
52            .docs
53            .as_ref()
54            .map(|docstring| docstring.render())
55            .unwrap_or(Document::Empty);
56
57        doc += flatten(const_text("const") + const_text(".") + display(&self.name));
58        doc += const_text("=");
59
60        doc + self.value.render()
61    }
62}
63
64impl Eq for Constant {}
65
66impl PartialEq for Constant {
67    fn eq(&self, other: &Self) -> bool {
68        self.name == other.name && self.value == other.value
69    }
70}
71
72impl Spanned for Constant {
73    fn span(&self) -> SourceSpan {
74        self.span
75    }
76}
77
78// CONSTANT EXPRESSION
79// ================================================================================================
80
81/// Represents a constant expression or value in Miden Assembly syntax.
82#[derive(Clone)]
83pub enum ConstantExpr {
84    /// A literal integer value.
85    Literal(Span<Felt>),
86    /// A reference to another constant.
87    Var(Ident),
88    /// An binary arithmetic operator.
89    BinaryOp {
90        span: SourceSpan,
91        op: ConstantOp,
92        lhs: Box<ConstantExpr>,
93        rhs: Box<ConstantExpr>,
94    },
95    String(Ident),
96}
97
98impl ConstantExpr {
99    /// Unwrap a literal value from this expression or panic.
100    ///
101    /// This is used in places where we expect the expression to have been folded to a value,
102    /// otherwise a bug occurred.
103    #[track_caller]
104    pub fn expect_literal(&self) -> Felt {
105        match self {
106            Self::Literal(spanned) => spanned.into_inner(),
107            other => panic!("expected constant expression to be a literal, got {other:#?}"),
108        }
109    }
110
111    pub fn expect_string(&self) -> Arc<str> {
112        match self {
113            Self::String(spanned) => spanned.clone().into_inner(),
114            other => panic!("expected constant expression to be a string, got {other:#?}"),
115        }
116    }
117
118    /// Attempt to fold to a single value.
119    ///
120    /// This will only succeed if the expression has no references to other constants.
121    ///
122    /// # Errors
123    /// Returns an error if an invalid expression is found while folding, such as division by zero.
124    pub fn try_fold(self) -> Result<Self, ParsingError> {
125        match self {
126            Self::String(_) | Self::Literal(_) | Self::Var(_) => Ok(self),
127            Self::BinaryOp { span, op, lhs, rhs } => {
128                if rhs.is_literal() {
129                    let rhs = Self::into_inner(rhs).try_fold()?;
130                    match rhs {
131                        Self::String(ident) => {
132                            Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
133                        },
134                        Self::Literal(rhs) => {
135                            let lhs = Self::into_inner(lhs).try_fold()?;
136                            match lhs {
137                                Self::String(ident) => {
138                                    Err(ParsingError::StringInArithmeticExpression {
139                                        span: ident.span(),
140                                    })
141                                },
142                                Self::Literal(lhs) => {
143                                    let lhs = lhs.into_inner();
144                                    let rhs = rhs.into_inner();
145                                    let is_division =
146                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
147                                    let is_division_by_zero = is_division && rhs == Felt::ZERO;
148                                    if is_division_by_zero {
149                                        return Err(ParsingError::DivisionByZero { span });
150                                    }
151                                    match op {
152                                        ConstantOp::Add => {
153                                            Ok(Self::Literal(Span::new(span, lhs + rhs)))
154                                        },
155                                        ConstantOp::Sub => {
156                                            Ok(Self::Literal(Span::new(span, lhs - rhs)))
157                                        },
158                                        ConstantOp::Mul => {
159                                            Ok(Self::Literal(Span::new(span, lhs * rhs)))
160                                        },
161                                        ConstantOp::Div => {
162                                            Ok(Self::Literal(Span::new(span, lhs / rhs)))
163                                        },
164                                        ConstantOp::IntDiv => Ok(Self::Literal(Span::new(
165                                            span,
166                                            Felt::new(lhs.as_int() / rhs.as_int()),
167                                        ))),
168                                    }
169                                },
170                                lhs => Ok(Self::BinaryOp {
171                                    span,
172                                    op,
173                                    lhs: Box::new(lhs),
174                                    rhs: Box::new(Self::Literal(rhs)),
175                                }),
176                            }
177                        },
178                        rhs => {
179                            let lhs = Self::into_inner(lhs).try_fold()?;
180                            Ok(Self::BinaryOp {
181                                span,
182                                op,
183                                lhs: Box::new(lhs),
184                                rhs: Box::new(rhs),
185                            })
186                        },
187                    }
188                } else {
189                    let lhs = Self::into_inner(lhs).try_fold()?;
190                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
191                }
192            },
193        }
194    }
195
196    fn is_literal(&self) -> bool {
197        match self {
198            Self::Literal(_) | Self::String(_) => true,
199            Self::Var(_) => false,
200            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
201        }
202    }
203
204    #[inline(always)]
205    #[allow(clippy::boxed_local)]
206    fn into_inner(self: Box<Self>) -> Self {
207        *self
208    }
209}
210
211impl Eq for ConstantExpr {}
212
213impl PartialEq for ConstantExpr {
214    fn eq(&self, other: &Self) -> bool {
215        match (self, other) {
216            (Self::Literal(l), Self::Literal(y)) => l == y,
217            (Self::Var(l), Self::Var(y)) => l == y,
218            (
219                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
220                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
221            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
222            _ => false,
223        }
224    }
225}
226
227impl fmt::Debug for ConstantExpr {
228    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
229        match self {
230            Self::Literal(lit) => fmt::Debug::fmt(&**lit, f),
231            Self::Var(name) | Self::String(name) => fmt::Debug::fmt(&**name, f),
232            Self::BinaryOp { op, lhs, rhs, .. } => {
233                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
234            },
235        }
236    }
237}
238
239impl crate::prettier::PrettyPrint for ConstantExpr {
240    fn render(&self) -> crate::prettier::Document {
241        use crate::prettier::*;
242
243        match self {
244            Self::Literal(literal) => display(literal),
245            Self::Var(ident) | Self::String(ident) => display(ident),
246            Self::BinaryOp { op, lhs, rhs, .. } => {
247                let single_line = lhs.render() + display(op) + rhs.render();
248                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
249                single_line | multi_line
250            },
251        }
252    }
253}
254
255impl Spanned for ConstantExpr {
256    fn span(&self) -> SourceSpan {
257        match self {
258            Self::Literal(spanned) => spanned.span(),
259            Self::Var(spanned) | Self::String(spanned) => spanned.span(),
260            Self::BinaryOp { span, .. } => *span,
261        }
262    }
263}
264
265// CONSTANT OPERATION
266// ================================================================================================
267
/// The binary arithmetic operators that may appear in constant expressions in Miden Assembly
/// syntax.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ConstantOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division in the field (`/`).
    Div,
    /// Integer division (`//`), evaluated on the operands' integer representations.
    IntDiv,
}

impl ConstantOp {
    /// The variant's name, used as the tuple name when debug-printing a binary expression.
    const fn name(&self) -> &'static str {
        match *self {
            ConstantOp::Add => "Add",
            ConstantOp::Sub => "Sub",
            ConstantOp::Mul => "Mul",
            ConstantOp::Div => "Div",
            ConstantOp::IntDiv => "IntDiv",
        }
    }
}

impl fmt::Display for ConstantOp {
    /// Writes the operator's source symbol, e.g. `+` or `//`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::IntDiv => "//",
        };
        f.write_str(symbol)
    }
}