miden_assembly/ast/
constants.rs

1use alloc::{boxed::Box, string::String};
2use core::fmt;
3
4use vm_core::FieldElement;
5
6use crate::{Felt, SourceSpan, Span, Spanned, ast::Ident, parser::ParsingError};
7
8// CONSTANT
9// ================================================================================================
10
11/// Represents a constant definition in Miden Assembly syntax, i.e. `const.FOO = 1 + 1`.
12pub struct Constant {
13    /// The source span of the definition.
14    pub span: SourceSpan,
15    /// The documentation string attached to this definition.
16    pub docs: Option<Span<String>>,
17    /// The name of the constant.
18    pub name: Ident,
19    /// The expression associated with the constant.
20    pub value: ConstantExpr,
21}
22
23impl Constant {
24    /// Creates a new [Constant] from the given source span, name, and value.
25    pub fn new(span: SourceSpan, name: Ident, value: ConstantExpr) -> Self {
26        Self { span, docs: None, name, value }
27    }
28
29    /// Adds documentation to this constant declaration.
30    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
31        self.docs = docs;
32        self
33    }
34}
35
36impl fmt::Debug for Constant {
37    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
38        f.debug_struct("Constant")
39            .field("docs", &self.docs)
40            .field("name", &self.name)
41            .field("value", &self.value)
42            .finish()
43    }
44}
45
46impl crate::prettier::PrettyPrint for Constant {
47    fn render(&self) -> crate::prettier::Document {
48        use crate::prettier::*;
49
50        let mut doc = Document::Empty;
51        if let Some(docs) = self.docs.as_ref() {
52            let fragment =
53                docs.lines().map(text).reduce(|acc, line| acc + nl() + const_text("#! ") + line);
54
55            if let Some(fragment) = fragment {
56                doc += fragment;
57            }
58        }
59
60        doc += nl();
61        doc += flatten(const_text("const") + const_text(".") + display(&self.name));
62        doc += const_text("=");
63
64        doc + self.value.render()
65    }
66}
67
68impl Eq for Constant {}
69
70impl PartialEq for Constant {
71    fn eq(&self, other: &Self) -> bool {
72        self.name == other.name && self.value == other.value
73    }
74}
75
76impl Spanned for Constant {
77    fn span(&self) -> SourceSpan {
78        self.span
79    }
80}
81
82// CONSTANT EXPRESSION
83// ================================================================================================
84
85/// Represents a constant expression or value in Miden Assembly syntax.
86pub enum ConstantExpr {
87    /// A literal integer value.
88    Literal(Span<Felt>),
89    /// A reference to another constant.
90    Var(Ident),
91    /// An binary arithmetic operator.
92    BinaryOp {
93        span: SourceSpan,
94        op: ConstantOp,
95        lhs: Box<ConstantExpr>,
96        rhs: Box<ConstantExpr>,
97    },
98}
99
100impl ConstantExpr {
101    /// Unwrap a literal value from this expression or panic.
102    ///
103    /// This is used in places where we expect the expression to have been folded to a value,
104    /// otherwise a bug occurred.
105    #[track_caller]
106    pub fn expect_literal(&self) -> Felt {
107        match self {
108            Self::Literal(spanned) => spanned.into_inner(),
109            other => panic!("expected constant expression to be a literal, got {other:#?}"),
110        }
111    }
112
113    /// Attempt to fold to a single value.
114    ///
115    /// This will only succeed if the expression has no references to other constants.
116    ///
117    /// # Errors
118    /// Returns an error if an invalid expression is found while folding, such as division by zero.
119    pub fn try_fold(self) -> Result<Self, ParsingError> {
120        match self {
121            Self::Literal(_) | Self::Var(_) => Ok(self),
122            Self::BinaryOp { span, op, lhs, rhs } => {
123                if rhs.is_literal() {
124                    let rhs = Self::into_inner(rhs).try_fold()?;
125                    match rhs {
126                        Self::Literal(rhs) => {
127                            let lhs = Self::into_inner(lhs).try_fold()?;
128                            match lhs {
129                                Self::Literal(lhs) => {
130                                    let lhs = lhs.into_inner();
131                                    let rhs = rhs.into_inner();
132                                    let is_division =
133                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
134                                    let is_division_by_zero = is_division && rhs == Felt::ZERO;
135                                    if is_division_by_zero {
136                                        return Err(ParsingError::DivisionByZero { span });
137                                    }
138                                    match op {
139                                        ConstantOp::Add => {
140                                            Ok(Self::Literal(Span::new(span, lhs + rhs)))
141                                        },
142                                        ConstantOp::Sub => {
143                                            Ok(Self::Literal(Span::new(span, lhs - rhs)))
144                                        },
145                                        ConstantOp::Mul => {
146                                            Ok(Self::Literal(Span::new(span, lhs * rhs)))
147                                        },
148                                        ConstantOp::Div => {
149                                            Ok(Self::Literal(Span::new(span, lhs / rhs)))
150                                        },
151                                        ConstantOp::IntDiv => Ok(Self::Literal(Span::new(
152                                            span,
153                                            Felt::new(lhs.as_int() / rhs.as_int()),
154                                        ))),
155                                    }
156                                },
157                                lhs => Ok(Self::BinaryOp {
158                                    span,
159                                    op,
160                                    lhs: Box::new(lhs),
161                                    rhs: Box::new(Self::Literal(rhs)),
162                                }),
163                            }
164                        },
165                        rhs => {
166                            let lhs = Self::into_inner(lhs).try_fold()?;
167                            Ok(Self::BinaryOp {
168                                span,
169                                op,
170                                lhs: Box::new(lhs),
171                                rhs: Box::new(rhs),
172                            })
173                        },
174                    }
175                } else {
176                    let lhs = Self::into_inner(lhs).try_fold()?;
177                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
178                }
179            },
180        }
181    }
182
183    fn is_literal(&self) -> bool {
184        match self {
185            Self::Literal(_) => true,
186            Self::Var(_) => false,
187            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
188        }
189    }
190
191    #[inline(always)]
192    #[allow(clippy::boxed_local)]
193    fn into_inner(self: Box<Self>) -> Self {
194        *self
195    }
196}
197
198impl Eq for ConstantExpr {}
199
200impl PartialEq for ConstantExpr {
201    fn eq(&self, other: &Self) -> bool {
202        match (self, other) {
203            (Self::Literal(l), Self::Literal(y)) => l == y,
204            (Self::Var(l), Self::Var(y)) => l == y,
205            (
206                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
207                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
208            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
209            _ => false,
210        }
211    }
212}
213
214impl fmt::Debug for ConstantExpr {
215    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
216        match self {
217            Self::Literal(lit) => fmt::Debug::fmt(&**lit, f),
218            Self::Var(name) => fmt::Debug::fmt(&**name, f),
219            Self::BinaryOp { op, lhs, rhs, .. } => {
220                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
221            },
222        }
223    }
224}
225
226impl crate::prettier::PrettyPrint for ConstantExpr {
227    fn render(&self) -> crate::prettier::Document {
228        use crate::prettier::*;
229
230        match self {
231            Self::Literal(literal) => display(literal),
232            Self::Var(ident) => display(ident),
233            Self::BinaryOp { op, lhs, rhs, .. } => {
234                let single_line = lhs.render() + display(op) + rhs.render();
235                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
236                single_line | multi_line
237            },
238        }
239    }
240}
241
242impl Spanned for ConstantExpr {
243    fn span(&self) -> SourceSpan {
244        match self {
245            Self::Literal(spanned) => spanned.span(),
246            Self::Var(spanned) => spanned.span(),
247            Self::BinaryOp { span, .. } => *span,
248        }
249    }
250}
251
252// CONSTANT OPERATION
253// ================================================================================================
254
/// The binary arithmetic operators that may appear in a Miden Assembly constant expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Field division, written `/`.
    Div,
    /// Integer division, written `//`.
    IntDiv,
}

impl ConstantOp {
    /// The variant name of this operator, used as the tuple name in debug output.
    const fn name(&self) -> &'static str {
        match *self {
            ConstantOp::Add => "Add",
            ConstantOp::Sub => "Sub",
            ConstantOp::Mul => "Mul",
            ConstantOp::Div => "Div",
            ConstantOp::IntDiv => "IntDiv",
        }
    }
}

impl fmt::Display for ConstantOp {
    /// Writes the operator's source-level symbol.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::IntDiv => "//",
        };
        f.write_str(symbol)
    }
}