miden_assembly_syntax/ast/constants/
expr.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::fmt;
3
4use miden_core::{
5    FieldElement,
6    utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
7};
8use miden_debug_types::{SourceSpan, Span, Spanned};
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12use crate::{
13    Felt, Path,
14    ast::{ConstantValue, Ident},
15    parser::{IntValue, ParsingError, WordValue},
16};
17
18// CONSTANT EXPRESSION
19// ================================================================================================
20
21/// Represents a constant expression or value in Miden Assembly syntax.
22#[derive(Clone)]
23#[repr(u8)]
24pub enum ConstantExpr {
25    /// A literal [`Felt`] value.
26    Int(Span<IntValue>),
27    /// A reference to another constant.
28    Var(Span<Arc<Path>>),
29    /// An binary arithmetic operator.
30    BinaryOp {
31        span: SourceSpan,
32        op: ConstantOp,
33        lhs: Box<ConstantExpr>,
34        rhs: Box<ConstantExpr>,
35    },
36    /// A plain spanned string.
37    String(Ident),
38    /// A literal ['WordValue'].
39    Word(Span<WordValue>),
40    /// A spanned string with a [`HashKind`] showing to which type of value the given string should
41    /// be hashed.
42    Hash(HashKind, Ident),
43}
44
45impl ConstantExpr {
46    /// Returns true if this expression is already evaluated to a concrete value
47    pub fn is_value(&self) -> bool {
48        matches!(self, Self::Int(_) | Self::Word(_) | Self::Hash(_, _) | Self::String(_))
49    }
50
51    /// Unwrap an [`IntValue`] from this expression or panic.
52    ///
53    /// This is used in places where we expect the expression to have been folded to an integer,
54    /// otherwise a bug occurred.
55    #[track_caller]
56    pub fn expect_int(&self) -> IntValue {
57        match self {
58            Self::Int(spanned) => spanned.into_inner(),
59            other => panic!("expected constant expression to be a literal, got {other:#?}"),
60        }
61    }
62
63    /// Unwrap a [`Felt`] value from this expression or panic.
64    ///
65    /// This is used in places where we expect the expression to have been folded to a felt value,
66    /// otherwise a bug occurred.
67    #[track_caller]
68    pub fn expect_felt(&self) -> Felt {
69        match self {
70            Self::Int(spanned) => Felt::new(spanned.inner().as_int()),
71            other => panic!("expected constant expression to be a literal, got {other:#?}"),
72        }
73    }
74
75    /// Unwrap a [`Arc<str>`] value from this expression or panic.
76    ///
77    /// This is used in places where we expect the expression to have been folded to a string value,
78    /// otherwise a bug occurred.
79    #[track_caller]
80    pub fn expect_string(&self) -> Arc<str> {
81        match self {
82            Self::String(spanned) => spanned.clone().into_inner(),
83            other => panic!("expected constant expression to be a string, got {other:#?}"),
84        }
85    }
86
87    /// Unwrap a [ConstantValue] from this expression or panic.
88    ///
89    /// This is used in places where we expect the expression to have been folded to a concrete
90    /// value, otherwise a bug occurred.
91    #[track_caller]
92    pub fn expect_value(&self) -> ConstantValue {
93        self.as_value().unwrap_or_else(|| {
94            panic!("expected constant expression to be a value, got {:#?}", self)
95        })
96    }
97
98    /// Try to convert this expression into a [ConstantValue], if the expression is a value.
99    ///
100    /// Returns `Err` if the expression cannot be represented as a [ConstantValue].
101    pub fn into_value(self) -> Result<ConstantValue, Self> {
102        match self {
103            Self::Int(value) => Ok(ConstantValue::Int(value)),
104            Self::String(value) => Ok(ConstantValue::String(value)),
105            Self::Word(value) => Ok(ConstantValue::Word(value)),
106            Self::Hash(kind, value) => Ok(ConstantValue::Hash(kind, value)),
107            expr @ (Self::BinaryOp { .. } | Self::Var(_)) => Err(expr),
108        }
109    }
110
111    /// Get the [ConstantValue] representation of this expression, if it is a value.
112    ///
113    /// Returns `None` if the expression cannot be represented as a [ConstantValue].
114    pub fn as_value(&self) -> Option<ConstantValue> {
115        match self {
116            Self::Int(value) => Some(ConstantValue::Int(*value)),
117            Self::String(value) => Some(ConstantValue::String(value.clone())),
118            Self::Word(value) => Some(ConstantValue::Word(*value)),
119            Self::Hash(kind, value) => Some(ConstantValue::Hash(*kind, value.clone())),
120            Self::BinaryOp { .. } | Self::Var(_) => None,
121        }
122    }
123
124    /// Attempt to fold to a single value.
125    ///
126    /// This will only succeed if the expression has no references to other constants.
127    ///
128    /// # Errors
129    /// Returns an error if an invalid expression is found while folding, such as division by zero.
130    pub fn try_fold(self) -> Result<Self, ParsingError> {
131        match self {
132            Self::String(_) | Self::Word(_) | Self::Int(_) | Self::Var(_) | Self::Hash(..) => {
133                Ok(self)
134            },
135            Self::BinaryOp { span, op, lhs, rhs } => {
136                if rhs.is_literal() {
137                    let rhs = Self::into_inner(rhs).try_fold()?;
138                    match rhs {
139                        Self::String(ident) => {
140                            Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
141                        },
142                        Self::Int(rhs) => {
143                            let lhs = Self::into_inner(lhs).try_fold()?;
144                            match lhs {
145                                Self::String(ident) => {
146                                    Err(ParsingError::StringInArithmeticExpression {
147                                        span: ident.span(),
148                                    })
149                                },
150                                Self::Int(lhs) => {
151                                    let lhs = lhs.into_inner();
152                                    let rhs = rhs.into_inner();
153                                    let is_division =
154                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
155                                    let is_division_by_zero = is_division && rhs == Felt::ZERO;
156                                    if is_division_by_zero {
157                                        return Err(ParsingError::DivisionByZero { span });
158                                    }
159                                    match op {
160                                        ConstantOp::Add => {
161                                            Ok(Self::Int(Span::new(span, lhs + rhs)))
162                                        },
163                                        ConstantOp::Sub => {
164                                            Ok(Self::Int(Span::new(span, lhs - rhs)))
165                                        },
166                                        ConstantOp::Mul => {
167                                            Ok(Self::Int(Span::new(span, lhs * rhs)))
168                                        },
169                                        ConstantOp::Div => {
170                                            Ok(Self::Int(Span::new(span, lhs / rhs)))
171                                        },
172                                        ConstantOp::IntDiv => {
173                                            Ok(Self::Int(Span::new(span, lhs / rhs)))
174                                        },
175                                    }
176                                },
177                                lhs => Ok(Self::BinaryOp {
178                                    span,
179                                    op,
180                                    lhs: Box::new(lhs),
181                                    rhs: Box::new(Self::Int(rhs)),
182                                }),
183                            }
184                        },
185                        rhs => {
186                            let lhs = Self::into_inner(lhs).try_fold()?;
187                            Ok(Self::BinaryOp {
188                                span,
189                                op,
190                                lhs: Box::new(lhs),
191                                rhs: Box::new(rhs),
192                            })
193                        },
194                    }
195                } else {
196                    let lhs = Self::into_inner(lhs).try_fold()?;
197                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
198                }
199            },
200        }
201    }
202
203    /// Get any references to other symbols present in this expression
204    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
205        use alloc::collections::BTreeSet;
206
207        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
208        let mut references = BTreeSet::new();
209
210        while let Some(ty) = worklist.pop() {
211            match ty {
212                Self::Int(_) | Self::Word(_) | Self::String(_) | Self::Hash(..) => continue,
213                Self::Var(path) => {
214                    references.insert(path.clone());
215                },
216                Self::BinaryOp { lhs, rhs, .. } => {
217                    worklist.push(lhs);
218                    worklist.push(rhs);
219                },
220            }
221        }
222
223        references.into_iter().collect()
224    }
225
226    fn is_literal(&self) -> bool {
227        match self {
228            Self::Int(_) | Self::String(_) | Self::Word(_) | Self::Hash(..) => true,
229            Self::Var(_) => false,
230            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
231        }
232    }
233
234    #[inline(always)]
235    #[expect(clippy::boxed_local)]
236    fn into_inner(self: Box<Self>) -> Self {
237        *self
238    }
239}
240
241impl Eq for ConstantExpr {}
242
243impl PartialEq for ConstantExpr {
244    fn eq(&self, other: &Self) -> bool {
245        match (self, other) {
246            (Self::Int(x), Self::Int(y)) => x == y,
247            (Self::Int(_), _) => false,
248            (Self::Word(x), Self::Word(y)) => x == y,
249            (Self::Word(_), _) => false,
250            (Self::Var(x), Self::Var(y)) => x == y,
251            (Self::Var(_), _) => false,
252            (Self::String(x), Self::String(y)) => x == y,
253            (Self::String(_), _) => false,
254            (Self::Hash(x_hk, x_i), Self::Hash(y_hk, y_i)) => x_i == y_i && x_hk == y_hk,
255            (Self::Hash(..), _) => false,
256            (
257                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
258                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
259            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
260            (Self::BinaryOp { .. }, _) => false,
261        }
262    }
263}
264
265impl core::hash::Hash for ConstantExpr {
266    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
267        core::mem::discriminant(self).hash(state);
268        match self {
269            Self::Int(value) => value.hash(state),
270            Self::Word(value) => value.hash(state),
271            Self::String(value) => value.hash(state),
272            Self::Var(value) => value.hash(state),
273            Self::Hash(hash_kind, string) => {
274                hash_kind.hash(state);
275                string.hash(state);
276            },
277            Self::BinaryOp { op, lhs, rhs, .. } => {
278                op.hash(state);
279                lhs.hash(state);
280                rhs.hash(state);
281            },
282        }
283    }
284}
285
286impl fmt::Debug for ConstantExpr {
287    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
288        match self {
289            Self::Int(lit) => fmt::Debug::fmt(&**lit, f),
290            Self::Word(lit) => fmt::Debug::fmt(&**lit, f),
291            Self::Var(path) => fmt::Debug::fmt(path, f),
292            Self::String(name) => fmt::Debug::fmt(&**name, f),
293            Self::Hash(hash_kind, str) => {
294                f.debug_tuple("Hash").field(hash_kind).field(str).finish()
295            },
296            Self::BinaryOp { op, lhs, rhs, .. } => {
297                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
298            },
299        }
300    }
301}
302
303impl crate::prettier::PrettyPrint for ConstantExpr {
304    fn render(&self) -> crate::prettier::Document {
305        use crate::prettier::*;
306
307        match self {
308            Self::Int(literal) => literal.render(),
309            Self::Word(literal) => literal.render(),
310            Self::Var(path) => display(path),
311            Self::String(ident) => text(format!("\"{}\"", ident.as_str().escape_debug())),
312            Self::Hash(hash_kind, str) => flatten(
313                display(hash_kind)
314                    + const_text("(")
315                    + text(format!("\"{}\"", str.as_str().escape_debug()))
316                    + const_text(")"),
317            ),
318            Self::BinaryOp { op, lhs, rhs, .. } => {
319                let single_line = lhs.render() + display(op) + rhs.render();
320                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
321                single_line | multi_line
322            },
323        }
324    }
325}
326
327impl Spanned for ConstantExpr {
328    fn span(&self) -> SourceSpan {
329        match self {
330            Self::Int(spanned) => spanned.span(),
331            Self::Word(spanned) => spanned.span(),
332            Self::Hash(_, spanned) => spanned.span(),
333            Self::Var(spanned) => spanned.span(),
334            Self::String(spanned) => spanned.span(),
335            Self::BinaryOp { span, .. } => *span,
336        }
337    }
338}
339
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantExpr {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates one of the following shapes, each chosen with equal weight:
    /// - an integer literal;
    /// - a constant reference;
    /// - a single binary operation over two integer literals;
    /// - a string;
    /// - a word;
    /// - a hash.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};

        let int = any::<IntValue>().prop_map(|value| Self::Int(Span::unknown(value)));
        let var = crate::arbitrary::path::constant_path_random_length(0)
            .prop_map(|path| Self::Var(Span::unknown(path)));
        let binary_op =
            any::<(ConstantOp, IntValue, IntValue)>().prop_map(|(op, left, right)| {
                Self::BinaryOp {
                    span: SourceSpan::UNKNOWN,
                    op,
                    lhs: Box::new(ConstantExpr::Int(Span::unknown(left))),
                    rhs: Box::new(ConstantExpr::Int(Span::unknown(right))),
                }
            });
        let string = any::<Ident>().prop_map(Self::String);
        let word = any::<WordValue>().prop_map(|value| Self::Word(Span::unknown(value)));
        let hash = any::<(HashKind, Ident)>().prop_map(|(kind, ident)| Self::Hash(kind, ident));

        prop_oneof![int, var, binary_op, string, word, hash].boxed()
    }
}
366
367// CONSTANT OPERATION
368// ================================================================================================
369
/// Represents the set of binary arithmetic operators supported in Miden Assembly syntax.
///
/// The enum has a primitive `u8` representation, so its discriminant is a well-defined single
/// byte; that byte is what the serialized tag encodes.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[repr(u8)]
pub enum ConstantOp {
    /// Addition, `+`.
    Add,
    /// Subtraction, `-`.
    Sub,
    /// Multiplication, `*`.
    Mul,
    /// Division, `/`.
    Div,
    /// Integer division, `//`.
    IntDiv,
}
380
381impl ConstantOp {
382    const fn name(&self) -> &'static str {
383        match self {
384            Self::Add => "Add",
385            Self::Sub => "Sub",
386            Self::Mul => "Mul",
387            Self::Div => "Div",
388            Self::IntDiv => "IntDiv",
389        }
390    }
391}
392
393impl fmt::Display for ConstantOp {
394    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
395        match self {
396            Self::Add => f.write_str("+"),
397            Self::Sub => f.write_str("-"),
398            Self::Mul => f.write_str("*"),
399            Self::Div => f.write_str("/"),
400            Self::IntDiv => f.write_str("//"),
401        }
402    }
403}
404
405impl ConstantOp {
406    const fn tag(&self) -> u8 {
407        // SAFETY: This is safe because we have given this enum a
408        // primitive representation with #[repr(u8)], with the first
409        // field of the underlying union-of-structs the discriminant
410        //
411        // See the section on "accessing the numeric value of the discriminant"
412        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
413        unsafe { *(self as *const Self).cast::<u8>() }
414    }
415}
416
417impl Serializable for ConstantOp {
418    fn write_into<W: ByteWriter>(&self, target: &mut W) {
419        target.write_u8(self.tag());
420    }
421}
422
423impl Deserializable for ConstantOp {
424    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
425        const ADD: u8 = ConstantOp::Add.tag();
426        const SUB: u8 = ConstantOp::Sub.tag();
427        const MUL: u8 = ConstantOp::Mul.tag();
428        const DIV: u8 = ConstantOp::Div.tag();
429        const INT_DIV: u8 = ConstantOp::IntDiv.tag();
430
431        match source.read_u8()? {
432            ADD => Ok(Self::Add),
433            SUB => Ok(Self::Sub),
434            MUL => Ok(Self::Mul),
435            DIV => Ok(Self::Div),
436            INT_DIV => Ok(Self::IntDiv),
437            invalid => Err(DeserializationError::InvalidValue(format!(
438                "unexpected ConstantOp tag: '{invalid}'"
439            ))),
440        }
441    }
442}
443
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantOp {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Picks any of the five operators with equal weight.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let add = Just(Self::Add);
        let sub = Just(Self::Sub);
        let mul = Just(Self::Mul);
        let div = Just(Self::Div);
        let int_div = Just(Self::IntDiv);

        prop_oneof![add, sub, mul, div, int_div].boxed()
    }
}
466
467// HASH KIND
468// ================================================================================================
469
/// Represents the type of the final value to which some string value should be converted.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum HashKind {
    /// Reduce a string to a word using the Blake3 hash function.
    Word,
    /// Reduce a string to a felt using the Blake3 hash function (via 64-bit reduction).
    Event,
}
480
481impl HashKind {
482    const fn tag(&self) -> u8 {
483        // SAFETY: This is safe because we have given this enum a
484        // primitive representation with #[repr(u8)], with the first
485        // field of the underlying union-of-structs the discriminant
486        //
487        // See the section on "accessing the numeric value of the discriminant"
488        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
489        unsafe { *(self as *const Self).cast::<u8>() }
490    }
491}
492
493impl fmt::Display for HashKind {
494    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
495        match self {
496            Self::Word => f.write_str("word"),
497            Self::Event => f.write_str("event"),
498        }
499    }
500}
501
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for HashKind {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Picks either hash kind with equal weight.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let word = Just(Self::Word);
        let event = Just(Self::Event);

        prop_oneof![word, event].boxed()
    }
}
517
518impl Serializable for HashKind {
519    fn write_into<W: ByteWriter>(&self, target: &mut W) {
520        target.write_u8(self.tag());
521    }
522}
523
524impl Deserializable for HashKind {
525    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
526        const WORD: u8 = HashKind::Word.tag();
527        const EVENT: u8 = HashKind::Event.tag();
528
529        match source.read_u8()? {
530            WORD => Ok(Self::Word),
531            EVENT => Ok(Self::Event),
532            invalid => Err(DeserializationError::InvalidValue(format!(
533                "unexpected HashKind tag: '{invalid}'"
534            ))),
535        }
536    }
537}