miden_assembly/ast/
constants.rs

1use alloc::{boxed::Box, string::String};
2use core::fmt;
3
4use vm_core::FieldElement;
5
6use super::DocString;
7use crate::{Felt, SourceSpan, Span, Spanned, ast::Ident, parser::ParsingError};
8
9// CONSTANT
10// ================================================================================================
11
12/// Represents a constant definition in Miden Assembly syntax, i.e. `const.FOO = 1 + 1`.
13pub struct Constant {
14    /// The source span of the definition.
15    pub span: SourceSpan,
16    /// The documentation string attached to this definition.
17    pub docs: Option<DocString>,
18    /// The name of the constant.
19    pub name: Ident,
20    /// The expression associated with the constant.
21    pub value: ConstantExpr,
22}
23
24impl Constant {
25    /// Creates a new [Constant] from the given source span, name, and value.
26    pub fn new(span: SourceSpan, name: Ident, value: ConstantExpr) -> Self {
27        Self { span, docs: None, name, value }
28    }
29
30    /// Adds documentation to this constant declaration.
31    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
32        self.docs = docs.map(DocString::new);
33        self
34    }
35}
36
37impl fmt::Debug for Constant {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        f.debug_struct("Constant")
40            .field("docs", &self.docs)
41            .field("name", &self.name)
42            .field("value", &self.value)
43            .finish()
44    }
45}
46
47impl crate::prettier::PrettyPrint for Constant {
48    fn render(&self) -> crate::prettier::Document {
49        use crate::prettier::*;
50
51        let mut doc = self
52            .docs
53            .as_ref()
54            .map(|docstring| docstring.render())
55            .unwrap_or(Document::Empty);
56
57        doc += flatten(const_text("const") + const_text(".") + display(&self.name));
58        doc += const_text("=");
59
60        doc + self.value.render()
61    }
62}
63
64impl Eq for Constant {}
65
66impl PartialEq for Constant {
67    fn eq(&self, other: &Self) -> bool {
68        self.name == other.name && self.value == other.value
69    }
70}
71
72impl Spanned for Constant {
73    fn span(&self) -> SourceSpan {
74        self.span
75    }
76}
77
78// CONSTANT EXPRESSION
79// ================================================================================================
80
81/// Represents a constant expression or value in Miden Assembly syntax.
82pub enum ConstantExpr {
83    /// A literal integer value.
84    Literal(Span<Felt>),
85    /// A reference to another constant.
86    Var(Ident),
87    /// An binary arithmetic operator.
88    BinaryOp {
89        span: SourceSpan,
90        op: ConstantOp,
91        lhs: Box<ConstantExpr>,
92        rhs: Box<ConstantExpr>,
93    },
94}
95
96impl ConstantExpr {
97    /// Unwrap a literal value from this expression or panic.
98    ///
99    /// This is used in places where we expect the expression to have been folded to a value,
100    /// otherwise a bug occurred.
101    #[track_caller]
102    pub fn expect_literal(&self) -> Felt {
103        match self {
104            Self::Literal(spanned) => spanned.into_inner(),
105            other => panic!("expected constant expression to be a literal, got {other:#?}"),
106        }
107    }
108
109    /// Attempt to fold to a single value.
110    ///
111    /// This will only succeed if the expression has no references to other constants.
112    ///
113    /// # Errors
114    /// Returns an error if an invalid expression is found while folding, such as division by zero.
115    pub fn try_fold(self) -> Result<Self, ParsingError> {
116        match self {
117            Self::Literal(_) | Self::Var(_) => Ok(self),
118            Self::BinaryOp { span, op, lhs, rhs } => {
119                if rhs.is_literal() {
120                    let rhs = Self::into_inner(rhs).try_fold()?;
121                    match rhs {
122                        Self::Literal(rhs) => {
123                            let lhs = Self::into_inner(lhs).try_fold()?;
124                            match lhs {
125                                Self::Literal(lhs) => {
126                                    let lhs = lhs.into_inner();
127                                    let rhs = rhs.into_inner();
128                                    let is_division =
129                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
130                                    let is_division_by_zero = is_division && rhs == Felt::ZERO;
131                                    if is_division_by_zero {
132                                        return Err(ParsingError::DivisionByZero { span });
133                                    }
134                                    match op {
135                                        ConstantOp::Add => {
136                                            Ok(Self::Literal(Span::new(span, lhs + rhs)))
137                                        },
138                                        ConstantOp::Sub => {
139                                            Ok(Self::Literal(Span::new(span, lhs - rhs)))
140                                        },
141                                        ConstantOp::Mul => {
142                                            Ok(Self::Literal(Span::new(span, lhs * rhs)))
143                                        },
144                                        ConstantOp::Div => {
145                                            Ok(Self::Literal(Span::new(span, lhs / rhs)))
146                                        },
147                                        ConstantOp::IntDiv => Ok(Self::Literal(Span::new(
148                                            span,
149                                            Felt::new(lhs.as_int() / rhs.as_int()),
150                                        ))),
151                                    }
152                                },
153                                lhs => Ok(Self::BinaryOp {
154                                    span,
155                                    op,
156                                    lhs: Box::new(lhs),
157                                    rhs: Box::new(Self::Literal(rhs)),
158                                }),
159                            }
160                        },
161                        rhs => {
162                            let lhs = Self::into_inner(lhs).try_fold()?;
163                            Ok(Self::BinaryOp {
164                                span,
165                                op,
166                                lhs: Box::new(lhs),
167                                rhs: Box::new(rhs),
168                            })
169                        },
170                    }
171                } else {
172                    let lhs = Self::into_inner(lhs).try_fold()?;
173                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
174                }
175            },
176        }
177    }
178
179    fn is_literal(&self) -> bool {
180        match self {
181            Self::Literal(_) => true,
182            Self::Var(_) => false,
183            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
184        }
185    }
186
187    #[inline(always)]
188    #[allow(clippy::boxed_local)]
189    fn into_inner(self: Box<Self>) -> Self {
190        *self
191    }
192}
193
194impl Eq for ConstantExpr {}
195
196impl PartialEq for ConstantExpr {
197    fn eq(&self, other: &Self) -> bool {
198        match (self, other) {
199            (Self::Literal(l), Self::Literal(y)) => l == y,
200            (Self::Var(l), Self::Var(y)) => l == y,
201            (
202                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
203                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
204            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
205            _ => false,
206        }
207    }
208}
209
210impl fmt::Debug for ConstantExpr {
211    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
212        match self {
213            Self::Literal(lit) => fmt::Debug::fmt(&**lit, f),
214            Self::Var(name) => fmt::Debug::fmt(&**name, f),
215            Self::BinaryOp { op, lhs, rhs, .. } => {
216                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
217            },
218        }
219    }
220}
221
222impl crate::prettier::PrettyPrint for ConstantExpr {
223    fn render(&self) -> crate::prettier::Document {
224        use crate::prettier::*;
225
226        match self {
227            Self::Literal(literal) => display(literal),
228            Self::Var(ident) => display(ident),
229            Self::BinaryOp { op, lhs, rhs, .. } => {
230                let single_line = lhs.render() + display(op) + rhs.render();
231                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
232                single_line | multi_line
233            },
234        }
235    }
236}
237
238impl Spanned for ConstantExpr {
239    fn span(&self) -> SourceSpan {
240        match self {
241            Self::Literal(spanned) => spanned.span(),
242            Self::Var(spanned) => spanned.span(),
243            Self::BinaryOp { span, .. } => *span,
244        }
245    }
246}
247
248// CONSTANT OPERATION
249// ================================================================================================
250
/// The binary arithmetic operators that may appear in a Miden Assembly constant expression.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ConstantOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Division, written `/`.
    Div,
    /// Integer division, written `//`.
    IntDiv,
}

impl ConstantOp {
    /// Returns the variant's name, used as the label in debug output of binary operations.
    const fn name(&self) -> &'static str {
        match *self {
            ConstantOp::Add => "Add",
            ConstantOp::Sub => "Sub",
            ConstantOp::Mul => "Mul",
            ConstantOp::Div => "Div",
            ConstantOp::IntDiv => "IntDiv",
        }
    }
}

impl fmt::Display for ConstantOp {
    /// Writes the operator's source-level symbol.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            ConstantOp::Add => "+",
            ConstantOp::Sub => "-",
            ConstantOp::Mul => "*",
            ConstantOp::Div => "/",
            ConstantOp::IntDiv => "//",
        };
        f.write_str(symbol)
    }
}