miden_assembly_syntax/ast/constants.rs

1use alloc::{boxed::Box, string::String, sync::Arc};
2use core::fmt;
3
4use miden_core::FieldElement;
5use miden_debug_types::{SourceSpan, Span, Spanned};
6
7use super::DocString;
8use crate::{
9    Felt,
10    ast::Ident,
11    parser::{ParsingError, WordValue},
12};
13
14// CONSTANT
15// ================================================================================================
16
17/// Represents a constant definition in Miden Assembly syntax, i.e. `const.FOO = 1 + 1`.
18pub struct Constant {
19    /// The source span of the definition.
20    pub span: SourceSpan,
21    /// The documentation string attached to this definition.
22    pub docs: Option<DocString>,
23    /// The name of the constant.
24    pub name: Ident,
25    /// The expression associated with the constant.
26    pub value: ConstantExpr,
27}
28
29impl Constant {
30    /// Creates a new [Constant] from the given source span, name, and value.
31    pub fn new(span: SourceSpan, name: Ident, value: ConstantExpr) -> Self {
32        Self { span, docs: None, name, value }
33    }
34
35    /// Adds documentation to this constant declaration.
36    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
37        self.docs = docs.map(DocString::new);
38        self
39    }
40}
41
42impl fmt::Debug for Constant {
43    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
44        f.debug_struct("Constant")
45            .field("docs", &self.docs)
46            .field("name", &self.name)
47            .field("value", &self.value)
48            .finish()
49    }
50}
51
52impl crate::prettier::PrettyPrint for Constant {
53    fn render(&self) -> crate::prettier::Document {
54        use crate::prettier::*;
55
56        let mut doc = self
57            .docs
58            .as_ref()
59            .map(|docstring| docstring.render())
60            .unwrap_or(Document::Empty);
61
62        doc += flatten(const_text("const") + const_text(".") + display(&self.name));
63        doc += const_text("=");
64
65        doc + self.value.render()
66    }
67}
68
69impl Eq for Constant {}
70
71impl PartialEq for Constant {
72    fn eq(&self, other: &Self) -> bool {
73        self.name == other.name && self.value == other.value
74    }
75}
76
77impl Spanned for Constant {
78    fn span(&self) -> SourceSpan {
79        self.span
80    }
81}
82
83// CONSTANT EXPRESSION
84// ================================================================================================
85
86/// Represents a constant expression or value in Miden Assembly syntax.
87#[derive(Clone)]
88pub enum ConstantExpr {
89    /// A literal integer value.
90    Literal(Span<Felt>),
91    /// A reference to another constant.
92    Var(Ident),
93    /// An binary arithmetic operator.
94    BinaryOp {
95        span: SourceSpan,
96        op: ConstantOp,
97        lhs: Box<ConstantExpr>,
98        rhs: Box<ConstantExpr>,
99    },
100    String(Ident),
101    Word(Span<WordValue>),
102}
103
104impl ConstantExpr {
105    /// Unwrap a literal value from this expression or panic.
106    ///
107    /// This is used in places where we expect the expression to have been folded to a value,
108    /// otherwise a bug occurred.
109    #[track_caller]
110    pub fn expect_literal(&self) -> Felt {
111        match self {
112            Self::Literal(spanned) => spanned.into_inner(),
113            other => panic!("expected constant expression to be a literal, got {other:#?}"),
114        }
115    }
116
117    pub fn expect_string(&self) -> Arc<str> {
118        match self {
119            Self::String(spanned) => spanned.clone().into_inner(),
120            other => panic!("expected constant expression to be a string, got {other:#?}"),
121        }
122    }
123
124    /// Attempt to fold to a single value.
125    ///
126    /// This will only succeed if the expression has no references to other constants.
127    ///
128    /// # Errors
129    /// Returns an error if an invalid expression is found while folding, such as division by zero.
130    pub fn try_fold(self) -> Result<Self, ParsingError> {
131        match self {
132            Self::String(_) | Self::Literal(_) | Self::Var(_) | Self::Word(_) => Ok(self),
133            Self::BinaryOp { span, op, lhs, rhs } => {
134                if rhs.is_literal() {
135                    let rhs = Self::into_inner(rhs).try_fold()?;
136                    match rhs {
137                        Self::String(ident) => {
138                            Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
139                        },
140                        Self::Literal(rhs) => {
141                            let lhs = Self::into_inner(lhs).try_fold()?;
142                            match lhs {
143                                Self::String(ident) => {
144                                    Err(ParsingError::StringInArithmeticExpression {
145                                        span: ident.span(),
146                                    })
147                                },
148                                Self::Literal(lhs) => {
149                                    let lhs = lhs.into_inner();
150                                    let rhs = rhs.into_inner();
151                                    let is_division =
152                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
153                                    let is_division_by_zero = is_division && rhs == Felt::ZERO;
154                                    if is_division_by_zero {
155                                        return Err(ParsingError::DivisionByZero { span });
156                                    }
157                                    match op {
158                                        ConstantOp::Add => {
159                                            Ok(Self::Literal(Span::new(span, lhs + rhs)))
160                                        },
161                                        ConstantOp::Sub => {
162                                            Ok(Self::Literal(Span::new(span, lhs - rhs)))
163                                        },
164                                        ConstantOp::Mul => {
165                                            Ok(Self::Literal(Span::new(span, lhs * rhs)))
166                                        },
167                                        ConstantOp::Div => {
168                                            Ok(Self::Literal(Span::new(span, lhs / rhs)))
169                                        },
170                                        ConstantOp::IntDiv => Ok(Self::Literal(Span::new(
171                                            span,
172                                            Felt::new(lhs.as_int() / rhs.as_int()),
173                                        ))),
174                                    }
175                                },
176                                lhs => Ok(Self::BinaryOp {
177                                    span,
178                                    op,
179                                    lhs: Box::new(lhs),
180                                    rhs: Box::new(Self::Literal(rhs)),
181                                }),
182                            }
183                        },
184                        rhs => {
185                            let lhs = Self::into_inner(lhs).try_fold()?;
186                            Ok(Self::BinaryOp {
187                                span,
188                                op,
189                                lhs: Box::new(lhs),
190                                rhs: Box::new(rhs),
191                            })
192                        },
193                    }
194                } else {
195                    let lhs = Self::into_inner(lhs).try_fold()?;
196                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
197                }
198            },
199        }
200    }
201
202    fn is_literal(&self) -> bool {
203        match self {
204            Self::Literal(_) | Self::String(_) | Self::Word(_) => true,
205            Self::Var(_) => false,
206            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
207        }
208    }
209
210    #[inline(always)]
211    #[allow(clippy::boxed_local)]
212    fn into_inner(self: Box<Self>) -> Self {
213        *self
214    }
215}
216
217impl Eq for ConstantExpr {}
218
219impl PartialEq for ConstantExpr {
220    fn eq(&self, other: &Self) -> bool {
221        match (self, other) {
222            (Self::Literal(l), Self::Literal(y)) => l == y,
223            (Self::Var(l), Self::Var(y)) => l == y,
224            (
225                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
226                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
227            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
228            _ => false,
229        }
230    }
231}
232
233impl fmt::Debug for ConstantExpr {
234    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
235        match self {
236            Self::Literal(lit) => fmt::Debug::fmt(&**lit, f),
237            Self::Word(w) => fmt::Debug::fmt(&**w, f),
238            Self::Var(name) | Self::String(name) => fmt::Debug::fmt(&**name, f),
239            Self::BinaryOp { op, lhs, rhs, .. } => {
240                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
241            },
242        }
243    }
244}
245
246impl crate::prettier::PrettyPrint for ConstantExpr {
247    fn render(&self) -> crate::prettier::Document {
248        use crate::prettier::*;
249
250        match self {
251            Self::Literal(literal) => display(literal),
252            Self::Word(spanned) => spanned.render(),
253            Self::Var(ident) | Self::String(ident) => display(ident),
254            Self::BinaryOp { op, lhs, rhs, .. } => {
255                let single_line = lhs.render() + display(op) + rhs.render();
256                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
257                single_line | multi_line
258            },
259        }
260    }
261}
262
263impl Spanned for ConstantExpr {
264    fn span(&self) -> SourceSpan {
265        match self {
266            Self::Literal(spanned) => spanned.span(),
267            Self::Word(spanned) => spanned.span(),
268            Self::Var(spanned) | Self::String(spanned) => spanned.span(),
269            Self::BinaryOp { span, .. } => *span,
270        }
271    }
272}
273
274// CONSTANT OPERATION
275// ================================================================================================
276
/// The binary arithmetic operators that may appear in a Miden Assembly constant expression.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConstantOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Division, written `/`.
    Div,
    /// Integer division, written `//`.
    IntDiv,
}

impl ConstantOp {
    /// Returns the variant's name; `ConstantExpr`'s `Debug` impl uses it as a tuple label.
    const fn name(&self) -> &'static str {
        match *self {
            ConstantOp::Add => "Add",
            ConstantOp::Sub => "Sub",
            ConstantOp::Mul => "Mul",
            ConstantOp::Div => "Div",
            ConstantOp::IntDiv => "IntDiv",
        }
    }
}

impl fmt::Display for ConstantOp {
    /// Writes the operator's source-level symbol.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::IntDiv => "//",
        };
        f.write_str(symbol)
    }
}