miden_assembly_syntax/ast/constants.rs

1use alloc::{boxed::Box, string::String, sync::Arc};
2use core::fmt;
3
4use miden_core::FieldElement;
5use miden_debug_types::{SourceSpan, Span, Spanned};
6
7use super::DocString;
8use crate::{
9    Felt,
10    ast::Ident,
11    parser::{ParsingError, WordValue},
12};
13
14// CONSTANT
15// ================================================================================================
16
17/// Represents a constant definition in Miden Assembly syntax, i.e. `const.FOO = 1 + 1`.
18#[derive(Clone)]
19pub struct Constant {
20    /// The source span of the definition.
21    pub span: SourceSpan,
22    /// The documentation string attached to this definition.
23    pub docs: Option<DocString>,
24    /// The name of the constant.
25    pub name: Ident,
26    /// The expression associated with the constant.
27    pub value: ConstantExpr,
28}
29
30impl Constant {
31    /// Creates a new [Constant] from the given source span, name, and value.
32    pub fn new(span: SourceSpan, name: Ident, value: ConstantExpr) -> Self {
33        Self { span, docs: None, name, value }
34    }
35
36    /// Adds documentation to this constant declaration.
37    pub fn with_docs(mut self, docs: Option<Span<String>>) -> Self {
38        self.docs = docs.map(DocString::new);
39        self
40    }
41}
42
43impl fmt::Debug for Constant {
44    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45        f.debug_struct("Constant")
46            .field("docs", &self.docs)
47            .field("name", &self.name)
48            .field("value", &self.value)
49            .finish()
50    }
51}
52
53impl crate::prettier::PrettyPrint for Constant {
54    fn render(&self) -> crate::prettier::Document {
55        use crate::prettier::*;
56
57        let mut doc = self
58            .docs
59            .as_ref()
60            .map(|docstring| docstring.render())
61            .unwrap_or(Document::Empty);
62
63        doc += flatten(const_text("const") + const_text(".") + display(&self.name));
64        doc += const_text("=");
65
66        doc + self.value.render()
67    }
68}
69
70impl Eq for Constant {}
71
72impl PartialEq for Constant {
73    fn eq(&self, other: &Self) -> bool {
74        self.name == other.name && self.value == other.value
75    }
76}
77
78impl Spanned for Constant {
79    fn span(&self) -> SourceSpan {
80        self.span
81    }
82}
83
84// CONSTANT EXPRESSION
85// ================================================================================================
86
87/// Represents a constant expression or value in Miden Assembly syntax.
88#[derive(Clone)]
89pub enum ConstantExpr {
90    /// A literal integer value.
91    Literal(Span<Felt>),
92    /// A reference to another constant.
93    Var(Ident),
94    /// An binary arithmetic operator.
95    BinaryOp {
96        span: SourceSpan,
97        op: ConstantOp,
98        lhs: Box<ConstantExpr>,
99        rhs: Box<ConstantExpr>,
100    },
101    String(Ident),
102    Word(Span<WordValue>),
103}
104
105impl ConstantExpr {
106    /// Unwrap a literal value from this expression or panic.
107    ///
108    /// This is used in places where we expect the expression to have been folded to a value,
109    /// otherwise a bug occurred.
110    #[track_caller]
111    pub fn expect_literal(&self) -> Felt {
112        match self {
113            Self::Literal(spanned) => spanned.into_inner(),
114            other => panic!("expected constant expression to be a literal, got {other:#?}"),
115        }
116    }
117
118    pub fn expect_string(&self) -> Arc<str> {
119        match self {
120            Self::String(spanned) => spanned.clone().into_inner(),
121            other => panic!("expected constant expression to be a string, got {other:#?}"),
122        }
123    }
124
125    /// Attempt to fold to a single value.
126    ///
127    /// This will only succeed if the expression has no references to other constants.
128    ///
129    /// # Errors
130    /// Returns an error if an invalid expression is found while folding, such as division by zero.
131    pub fn try_fold(self) -> Result<Self, ParsingError> {
132        match self {
133            Self::String(_) | Self::Literal(_) | Self::Var(_) | Self::Word(_) => Ok(self),
134            Self::BinaryOp { span, op, lhs, rhs } => {
135                if rhs.is_literal() {
136                    let rhs = Self::into_inner(rhs).try_fold()?;
137                    match rhs {
138                        Self::String(ident) => {
139                            Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
140                        },
141                        Self::Literal(rhs) => {
142                            let lhs = Self::into_inner(lhs).try_fold()?;
143                            match lhs {
144                                Self::String(ident) => {
145                                    Err(ParsingError::StringInArithmeticExpression {
146                                        span: ident.span(),
147                                    })
148                                },
149                                Self::Literal(lhs) => {
150                                    let lhs = lhs.into_inner();
151                                    let rhs = rhs.into_inner();
152                                    let is_division =
153                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
154                                    let is_division_by_zero = is_division && rhs == Felt::ZERO;
155                                    if is_division_by_zero {
156                                        return Err(ParsingError::DivisionByZero { span });
157                                    }
158                                    match op {
159                                        ConstantOp::Add => {
160                                            Ok(Self::Literal(Span::new(span, lhs + rhs)))
161                                        },
162                                        ConstantOp::Sub => {
163                                            Ok(Self::Literal(Span::new(span, lhs - rhs)))
164                                        },
165                                        ConstantOp::Mul => {
166                                            Ok(Self::Literal(Span::new(span, lhs * rhs)))
167                                        },
168                                        ConstantOp::Div => {
169                                            Ok(Self::Literal(Span::new(span, lhs / rhs)))
170                                        },
171                                        ConstantOp::IntDiv => Ok(Self::Literal(Span::new(
172                                            span,
173                                            Felt::new(lhs.as_int() / rhs.as_int()),
174                                        ))),
175                                    }
176                                },
177                                lhs => Ok(Self::BinaryOp {
178                                    span,
179                                    op,
180                                    lhs: Box::new(lhs),
181                                    rhs: Box::new(Self::Literal(rhs)),
182                                }),
183                            }
184                        },
185                        rhs => {
186                            let lhs = Self::into_inner(lhs).try_fold()?;
187                            Ok(Self::BinaryOp {
188                                span,
189                                op,
190                                lhs: Box::new(lhs),
191                                rhs: Box::new(rhs),
192                            })
193                        },
194                    }
195                } else {
196                    let lhs = Self::into_inner(lhs).try_fold()?;
197                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
198                }
199            },
200        }
201    }
202
203    fn is_literal(&self) -> bool {
204        match self {
205            Self::Literal(_) | Self::String(_) | Self::Word(_) => true,
206            Self::Var(_) => false,
207            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
208        }
209    }
210
211    #[inline(always)]
212    #[allow(clippy::boxed_local)]
213    fn into_inner(self: Box<Self>) -> Self {
214        *self
215    }
216}
217
218impl Eq for ConstantExpr {}
219
220impl PartialEq for ConstantExpr {
221    fn eq(&self, other: &Self) -> bool {
222        match (self, other) {
223            (Self::Literal(l), Self::Literal(y)) => l == y,
224            (Self::Var(l), Self::Var(y)) => l == y,
225            (
226                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
227                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
228            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
229            _ => false,
230        }
231    }
232}
233
234impl core::hash::Hash for ConstantExpr {
235    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
236        core::mem::discriminant(self).hash(state);
237        match self {
238            Self::Literal(value) => value.as_int().hash(state),
239            Self::String(value) => value.hash(state),
240            Self::Word(value) => value.hash(state),
241            Self::Var(value) => value.hash(state),
242            Self::BinaryOp { op, lhs, rhs, .. } => {
243                op.hash(state);
244                lhs.hash(state);
245                rhs.hash(state);
246            },
247        }
248    }
249}
250
251impl fmt::Debug for ConstantExpr {
252    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
253        match self {
254            Self::Literal(lit) => fmt::Debug::fmt(&**lit, f),
255            Self::Word(w) => fmt::Debug::fmt(&**w, f),
256            Self::Var(name) | Self::String(name) => fmt::Debug::fmt(&**name, f),
257            Self::BinaryOp { op, lhs, rhs, .. } => {
258                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
259            },
260        }
261    }
262}
263
264impl crate::prettier::PrettyPrint for ConstantExpr {
265    fn render(&self) -> crate::prettier::Document {
266        use crate::prettier::*;
267
268        match self {
269            Self::Literal(literal) => display(literal),
270            Self::Word(spanned) => spanned.render(),
271            Self::Var(ident) | Self::String(ident) => display(ident),
272            Self::BinaryOp { op, lhs, rhs, .. } => {
273                let single_line = lhs.render() + display(op) + rhs.render();
274                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
275                single_line | multi_line
276            },
277        }
278    }
279}
280
281impl Spanned for ConstantExpr {
282    fn span(&self) -> SourceSpan {
283        match self {
284            Self::Literal(spanned) => spanned.span(),
285            Self::Word(spanned) => spanned.span(),
286            Self::Var(spanned) | Self::String(spanned) => spanned.span(),
287            Self::BinaryOp { span, .. } => *span,
288        }
289    }
290}
291
292// CONSTANT OPERATION
293// ================================================================================================
294
/// The binary arithmetic operators that may appear in a Miden Assembly constant expression.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ConstantOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Field division, written `/`.
    Div,
    /// Integer (truncating) division, written `//`.
    IntDiv,
}
304
305impl ConstantOp {
306    const fn name(&self) -> &'static str {
307        match self {
308            Self::Add => "Add",
309            Self::Sub => "Sub",
310            Self::Mul => "Mul",
311            Self::Div => "Div",
312            Self::IntDiv => "IntDiv",
313        }
314    }
315}
316
317impl fmt::Display for ConstantOp {
318    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
319        match self {
320            Self::Add => f.write_str("+"),
321            Self::Sub => f.write_str("-"),
322            Self::Mul => f.write_str("*"),
323            Self::Div => f.write_str("/"),
324            Self::IntDiv => f.write_str("//"),
325        }
326    }
327}