Skip to main content

miden_assembly_syntax/ast/constants/
expr.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::fmt;
3
4use miden_core::serde::{
5    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
6};
7use miden_debug_types::{SourceSpan, Span, Spanned};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    Felt, Path,
13    ast::{ConstantValue, Ident},
14    parser::{IntValue, LiteralErrorKind, ParsingError, WordValue},
15};
16
/// Maximum allowed nesting depth for constant-expression folding.
///
/// This limit is intended to prevent stack overflows from maliciously deep constant expressions
/// while remaining far above typical constant-expression complexity.
///
/// Folding starts at depth 0 and fails with [`ParsingError::ConstExprDepthExceeded`] as soon as
/// the recursion depth becomes strictly greater than this value.
const MAX_CONST_EXPR_FOLD_DEPTH: usize = 256;
22
23// CONSTANT EXPRESSION
24// ================================================================================================
25
/// Represents a constant expression or value in Miden Assembly syntax.
///
/// The `Int`, `String`, `Word` and `Hash` variants are concrete values, while `Var` and
/// `BinaryOp` still have to be resolved or folded (see [`ConstantExpr::is_value`]).
#[derive(Clone)]
#[repr(u8)]
pub enum ConstantExpr {
    /// A literal integer value, stored as a spanned [`IntValue`].
    Int(Span<IntValue>),
    /// A reference to another constant.
    Var(Span<Arc<Path>>),
    /// A binary arithmetic operator.
    BinaryOp {
        /// The source span reported for this operation by [`Spanned::span`].
        span: SourceSpan,
        /// The operator applied to `lhs` and `rhs`.
        op: ConstantOp,
        /// The left-hand operand.
        lhs: Box<ConstantExpr>,
        /// The right-hand operand.
        rhs: Box<ConstantExpr>,
    },
    /// A plain spanned string.
    String(Ident),
    /// A literal [`WordValue`].
    Word(Span<WordValue>),
    /// A spanned string with a [`HashKind`] showing to which type of value the given string should
    /// be hashed.
    Hash(HashKind, Ident),
}
49
50impl ConstantExpr {
51    /// Returns true if this expression is already evaluated to a concrete value
52    pub fn is_value(&self) -> bool {
53        matches!(self, Self::Int(_) | Self::Word(_) | Self::Hash(_, _) | Self::String(_))
54    }
55
56    /// Unwrap an [`IntValue`] from this expression or panic.
57    ///
58    /// This is used in places where we expect the expression to have been folded to an integer,
59    /// otherwise a bug occurred.
60    #[track_caller]
61    pub fn expect_int(&self) -> IntValue {
62        match self {
63            Self::Int(spanned) => spanned.into_inner(),
64            other => panic!("expected constant expression to be a literal, got {other:#?}"),
65        }
66    }
67
68    /// Unwrap a [`Felt`] value from this expression or panic.
69    ///
70    /// This is used in places where we expect the expression to have been folded to a felt value,
71    /// otherwise a bug occurred.
72    #[track_caller]
73    pub fn expect_felt(&self) -> Felt {
74        match self {
75            Self::Int(spanned) => Felt::new_unchecked(spanned.inner().as_int()),
76            other => panic!("expected constant expression to be a literal, got {other:#?}"),
77        }
78    }
79
80    /// Unwrap a [`Arc<str>`] value from this expression or panic.
81    ///
82    /// This is used in places where we expect the expression to have been folded to a string value,
83    /// otherwise a bug occurred.
84    #[track_caller]
85    pub fn expect_string(&self) -> Arc<str> {
86        match self {
87            Self::String(spanned) => spanned.clone().into_inner(),
88            other => panic!("expected constant expression to be a string, got {other:#?}"),
89        }
90    }
91
92    /// Unwrap a [ConstantValue] from this expression or panic.
93    ///
94    /// This is used in places where we expect the expression to have been folded to a concrete
95    /// value, otherwise a bug occurred.
96    #[track_caller]
97    pub fn expect_value(&self) -> ConstantValue {
98        self.as_value()
99            .unwrap_or_else(|| panic!("expected constant expression to be a value, got {self:#?}"))
100    }
101
102    /// Try to convert this expression into a [ConstantValue], if the expression is a value.
103    ///
104    /// Returns `Err` if the expression cannot be represented as a [ConstantValue].
105    pub fn into_value(self) -> Result<ConstantValue, Self> {
106        match self {
107            Self::Int(value) => Ok(ConstantValue::Int(value)),
108            Self::String(value) => Ok(ConstantValue::String(value)),
109            Self::Word(value) => Ok(ConstantValue::Word(value)),
110            Self::Hash(kind, value) => Ok(ConstantValue::Hash(kind, value)),
111            expr @ (Self::BinaryOp { .. } | Self::Var(_)) => Err(expr),
112        }
113    }
114
115    /// Get the [ConstantValue] representation of this expression, if it is a value.
116    ///
117    /// Returns `None` if the expression cannot be represented as a [ConstantValue].
118    pub fn as_value(&self) -> Option<ConstantValue> {
119        match self {
120            Self::Int(value) => Some(ConstantValue::Int(*value)),
121            Self::String(value) => Some(ConstantValue::String(value.clone())),
122            Self::Word(value) => Some(ConstantValue::Word(*value)),
123            Self::Hash(kind, value) => Some(ConstantValue::Hash(*kind, value.clone())),
124            Self::BinaryOp { .. } | Self::Var(_) => None,
125        }
126    }
127
128    /// Attempt to fold to a single value.
129    ///
130    /// This will only succeed if the expression has no references to other constants.
131    ///
132    /// # Errors
133    /// Returns an error if an invalid expression is found while folding, such as division by zero.
134    pub fn try_fold(self) -> Result<Self, ParsingError> {
135        self.try_fold_with_depth(0)
136    }
137
    /// Recursive worker behind [`ConstantExpr::try_fold`].
    ///
    /// `depth` is the nesting level of `self` relative to the expression on which folding was
    /// started (the root is at depth 0); each operand is folded at `depth + 1`.
    ///
    /// Non-`BinaryOp` expressions are returned unchanged. For a `BinaryOp`, the right operand is
    /// folded first, and only if it contains no constant references; the left operand is folded
    /// afterwards. When both operands fold to integers the operation is evaluated, otherwise the
    /// operation is rebuilt from the (partially) folded operands.
    ///
    /// # Errors
    /// - `ConstExprDepthExceeded` if `depth` is greater than [`MAX_CONST_EXPR_FOLD_DEPTH`].
    /// - `StringInArithmeticExpression` if an operand folds to a string.
    /// - `DivisionByZero` if the operator is `Div` or `IntDiv` and the right operand is zero.
    /// - `ConstantOverflow` if a checked integer operation fails.
    fn try_fold_with_depth(self, depth: usize) -> Result<Self, ParsingError> {
        // Enforce the depth limit before descending any further.
        if depth > MAX_CONST_EXPR_FOLD_DEPTH {
            return Err(ParsingError::ConstExprDepthExceeded {
                span: self.span(),
                max_depth: MAX_CONST_EXPR_FOLD_DEPTH,
            });
        }

        match self {
            // Leaves have nothing to fold.
            Self::String(_) | Self::Word(_) | Self::Int(_) | Self::Var(_) | Self::Hash(..) => {
                Ok(self)
            },
            Self::BinaryOp { span, op, lhs, rhs } => {
                if rhs.is_literal() {
                    // The right operand is folded before the left one, so an error in `rhs` is
                    // reported in preference to an error in `lhs`.
                    let rhs = Self::into_inner(rhs).try_fold_with_depth(depth + 1)?;
                    match rhs {
                        // Strings cannot take part in arithmetic; `lhs` is not folded here.
                        Self::String(ident) => {
                            Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
                        },
                        Self::Int(rhs) => {
                            let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
                            match lhs {
                                Self::String(ident) => {
                                    Err(ParsingError::StringInArithmeticExpression {
                                        span: ident.span(),
                                    })
                                },
                                // Both operands are integers: evaluate the operation.
                                Self::Int(lhs) => {
                                    let lhs = lhs.into_inner();
                                    let rhs = rhs.into_inner();
                                    // Reject a zero divisor up front for both division flavours.
                                    let is_division =
                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
                                    let is_division_by_zero = is_division && rhs.as_int() == 0;
                                    if is_division_by_zero {
                                        return Err(ParsingError::DivisionByZero { span });
                                    }
                                    match op {
                                        ConstantOp::Add => {
                                            let value = lhs
                                                .checked_add(rhs)
                                                .ok_or(ParsingError::ConstantOverflow { span })?;
                                            Some(Self::Int(Span::new(span, value)))
                                        },
                                        ConstantOp::Sub => {
                                            let value = lhs
                                                .checked_sub(rhs)
                                                .ok_or(ParsingError::ConstantOverflow { span })?;
                                            Some(Self::Int(Span::new(span, value)))
                                        },
                                        ConstantOp::Mul => {
                                            let value = lhs
                                                .checked_mul(rhs)
                                                .ok_or(ParsingError::ConstantOverflow { span })?;
                                            Some(Self::Int(Span::new(span, value)))
                                        },
                                        ConstantOp::IntDiv => {
                                            let value = lhs
                                                .checked_div(rhs)
                                                .ok_or(ParsingError::ConstantOverflow { span })?;
                                            Some(Self::Int(Span::new(span, value)))
                                        },
                                        // `/` is evaluated as a division of `Felt` values rather
                                        // than as an integer division.
                                        ConstantOp::Div => {
                                            let lhs = Felt::new_unchecked(lhs.as_int());
                                            let rhs = Felt::new_unchecked(rhs.as_int());
                                            let value = IntValue::from(lhs / rhs);
                                            Some(Self::Int(Span::new(span, value)))
                                        },
                                    }
                                    // NOTE(review): every arm above yields `Some(..)` (failures
                                    // leave early through `?`), so this fallback error appears to
                                    // be unreachable.
                                    .ok_or(
                                        ParsingError::InvalidLiteral {
                                            span,
                                            kind: LiteralErrorKind::FeltOverflow,
                                        },
                                    )
                                },
                                // `lhs` did not fold to an integer (or string): keep the
                                // operation, with both operands in their folded form.
                                lhs => Ok(Self::BinaryOp {
                                    span,
                                    op,
                                    lhs: Box::new(lhs),
                                    rhs: Box::new(Self::Int(rhs)),
                                }),
                            }
                        },
                        // `rhs` folded to something that is neither an integer nor a string:
                        // fold `lhs` and rebuild the operation.
                        rhs => {
                            let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
                            Ok(Self::BinaryOp {
                                span,
                                op,
                                lhs: Box::new(lhs),
                                rhs: Box::new(rhs),
                            })
                        },
                    }
                } else {
                    // `rhs` references another constant somewhere: only `lhs` is folded, and
                    // `rhs` is carried over untouched.
                    let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
                }
            },
        }
    }
238
239    /// Get any references to other symbols present in this expression
240    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
241        use alloc::collections::BTreeSet;
242
243        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
244        let mut references = BTreeSet::new();
245
246        while let Some(ty) = worklist.pop() {
247            match ty {
248                Self::Int(_) | Self::Word(_) | Self::String(_) | Self::Hash(..) => {},
249                Self::Var(path) => {
250                    references.insert(path.clone());
251                },
252                Self::BinaryOp { lhs, rhs, .. } => {
253                    worklist.push(lhs);
254                    worklist.push(rhs);
255                },
256            }
257        }
258
259        references.into_iter().collect()
260    }
261
262    fn is_literal(&self) -> bool {
263        match self {
264            Self::Int(_) | Self::String(_) | Self::Word(_) | Self::Hash(..) => true,
265            Self::Var(_) => false,
266            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
267        }
268    }
269
    /// Moves the expression out of its [`Box`].
    ///
    /// The receiver is deliberately `Box<Self>` so that boxed operands can be unwrapped with
    /// `Self::into_inner(boxed)`; clippy's `boxed_local` lint is therefore expected here.
    #[inline(always)]
    #[expect(clippy::boxed_local)]
    fn into_inner(self: Box<Self>) -> Self {
        *self
    }
275}
276
277impl Eq for ConstantExpr {}
278
279impl PartialEq for ConstantExpr {
280    fn eq(&self, other: &Self) -> bool {
281        match (self, other) {
282            (Self::Int(x), Self::Int(y)) => x == y,
283            (Self::Int(_), _) => false,
284            (Self::Word(x), Self::Word(y)) => x == y,
285            (Self::Word(_), _) => false,
286            (Self::Var(x), Self::Var(y)) => x == y,
287            (Self::Var(_), _) => false,
288            (Self::String(x), Self::String(y)) => x == y,
289            (Self::String(_), _) => false,
290            (Self::Hash(x_hk, x_i), Self::Hash(y_hk, y_i)) => x_i == y_i && x_hk == y_hk,
291            (Self::Hash(..), _) => false,
292            (
293                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
294                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
295            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
296            (Self::BinaryOp { .. }, _) => false,
297        }
298    }
299}
300
301impl core::hash::Hash for ConstantExpr {
302    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
303        core::mem::discriminant(self).hash(state);
304        match self {
305            Self::Int(value) => value.hash(state),
306            Self::Word(value) => value.hash(state),
307            Self::String(value) => value.hash(state),
308            Self::Var(value) => value.hash(state),
309            Self::Hash(hash_kind, string) => {
310                hash_kind.hash(state);
311                string.hash(state);
312            },
313            Self::BinaryOp { op, lhs, rhs, .. } => {
314                op.hash(state);
315                lhs.hash(state);
316                rhs.hash(state);
317            },
318        }
319    }
320}
321
322impl fmt::Debug for ConstantExpr {
323    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
324        match self {
325            Self::Int(lit) => fmt::Debug::fmt(&**lit, f),
326            Self::Word(lit) => fmt::Debug::fmt(&**lit, f),
327            Self::Var(path) => fmt::Debug::fmt(path, f),
328            Self::String(name) => fmt::Debug::fmt(&**name, f),
329            Self::Hash(hash_kind, str) => {
330                f.debug_tuple("Hash").field(hash_kind).field(str).finish()
331            },
332            Self::BinaryOp { op, lhs, rhs, .. } => {
333                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
334            },
335        }
336    }
337}
338
339impl crate::prettier::PrettyPrint for ConstantExpr {
340    fn render(&self) -> crate::prettier::Document {
341        use crate::prettier::*;
342
343        match self {
344            Self::Int(literal) => literal.render(),
345            Self::Word(literal) => literal.render(),
346            Self::Var(path) => display(path),
347            Self::String(ident) => text(format!("\"{}\"", ident.as_str().escape_debug())),
348            Self::Hash(hash_kind, str) => flatten(
349                display(hash_kind)
350                    + const_text("(")
351                    + text(format!("\"{}\"", str.as_str().escape_debug()))
352                    + const_text(")"),
353            ),
354            Self::BinaryOp { op, lhs, rhs, .. } => {
355                let single_line = lhs.render() + display(op) + rhs.render();
356                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
357                single_line | multi_line
358            },
359        }
360    }
361}
362
363impl Spanned for ConstantExpr {
364    fn span(&self) -> SourceSpan {
365        match self {
366            Self::Int(spanned) => spanned.span(),
367            Self::Word(spanned) => spanned.span(),
368            Self::Hash(_, spanned) => spanned.span(),
369            Self::Var(spanned) => spanned.span(),
370            Self::String(spanned) => spanned.span(),
371            Self::BinaryOp { span, .. } => *span,
372        }
373    }
374}
375
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantExpr {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates arbitrary expressions, one alternative per variant.
    ///
    /// All generated spans are unknown, and binary operations are always a single operator
    /// applied to two integer literals.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};

        let int = any::<IntValue>().prop_map(|n| Self::Int(Span::unknown(n)));
        let var = crate::arbitrary::path::constant_path_random_length(0)
            .prop_map(|p| Self::Var(Span::unknown(p)));
        let binary_op =
            any::<(ConstantOp, IntValue, IntValue)>().prop_map(|(op, lhs, rhs)| {
                let lhs = Box::new(ConstantExpr::Int(Span::unknown(lhs)));
                let rhs = Box::new(ConstantExpr::Int(Span::unknown(rhs)));
                Self::BinaryOp { span: SourceSpan::UNKNOWN, op, lhs, rhs }
            });
        let string = any::<Ident>().prop_map(Self::String);
        let word = any::<WordValue>().prop_map(|word| Self::Word(Span::unknown(word)));
        let hash = any::<(HashKind, Ident)>().prop_map(|(kind, s)| Self::Hash(kind, s));

        prop_oneof![int, var, binary_op, string, word, hash].boxed()
    }
}
402
403// CONSTANT OPERATION
404// ================================================================================================
405
/// Represents the set of binary arithmetic operators supported in Miden Assembly syntax.
///
/// The binary serialization of an operator is its `u8` discriminant, so the declaration order of
/// the variants is part of the serialized format.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    all(feature = "arbitrary", test),
    miden_test_serde_macros::serde_test(binary_serde(true))
)]
pub enum ConstantOp {
    /// Addition, displayed as `+`.
    Add,
    /// Subtraction, displayed as `-`.
    Sub,
    /// Multiplication, displayed as `*`.
    Mul,
    /// Division, displayed as `/`; constant folding evaluates it on [`Felt`] values.
    Div,
    /// Integer division, displayed as `//`; constant folding evaluates it as a checked integer
    /// division.
    IntDiv,
}
421
422impl ConstantOp {
423    const fn name(self) -> &'static str {
424        match self {
425            Self::Add => "Add",
426            Self::Sub => "Sub",
427            Self::Mul => "Mul",
428            Self::Div => "Div",
429            Self::IntDiv => "IntDiv",
430        }
431    }
432}
433
434impl fmt::Display for ConstantOp {
435    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
436        match self {
437            Self::Add => f.write_str("+"),
438            Self::Sub => f.write_str("-"),
439            Self::Mul => f.write_str("*"),
440            Self::Div => f.write_str("/"),
441            Self::IntDiv => f.write_str("//"),
442        }
443    }
444}
445
446impl ConstantOp {
447    const fn tag(&self) -> u8 {
448        // SAFETY: This is safe because we have given this enum a
449        // primitive representation with #[repr(u8)], with the first
450        // field of the underlying union-of-structs the discriminant
451        //
452        // See the section on "accessing the numeric value of the discriminant"
453        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
454        unsafe { *(self as *const Self).cast::<u8>() }
455    }
456}
457
458impl Serializable for ConstantOp {
459    fn write_into<W: ByteWriter>(&self, target: &mut W) {
460        target.write_u8(self.tag());
461    }
462}
463
464impl Deserializable for ConstantOp {
465    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
466        const ADD: u8 = ConstantOp::Add.tag();
467        const SUB: u8 = ConstantOp::Sub.tag();
468        const MUL: u8 = ConstantOp::Mul.tag();
469        const DIV: u8 = ConstantOp::Div.tag();
470        const INT_DIV: u8 = ConstantOp::IntDiv.tag();
471
472        match source.read_u8()? {
473            ADD => Ok(Self::Add),
474            SUB => Ok(Self::Sub),
475            MUL => Ok(Self::Mul),
476            DIV => Ok(Self::Div),
477            INT_DIV => Ok(Self::IntDiv),
478            invalid => Err(DeserializationError::InvalidValue(format!(
479                "unexpected ConstantOp tag: '{invalid}'"
480            ))),
481        }
482    }
483}
484
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantOp {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates any one of the five operators.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let add = Just(Self::Add);
        let sub = Just(Self::Sub);
        let mul = Just(Self::Mul);
        let div = Just(Self::Div);
        let int_div = Just(Self::IntDiv);

        prop_oneof![add, sub, mul, div, int_div].boxed()
    }
}
507
508// HASH KIND
509// ================================================================================================
510
/// Represents the type of the final value to which some string value should be converted.
///
/// The binary serialization of a kind is its `u8` discriminant, so the declaration order of the
/// variants is part of the serialized format.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    all(feature = "arbitrary", test),
    miden_test_serde_macros::serde_test(binary_serde(true))
)]
pub enum HashKind {
    /// Reduce a string to a word using the Blake3 hash function. Displayed as `word`.
    Word,
    /// Reduce a string to a felt using the Blake3 hash function (via 64-bit reduction).
    /// Displayed as `event`.
    Event,
}
525
526impl HashKind {
527    const fn tag(&self) -> u8 {
528        // SAFETY: This is safe because we have given this enum a
529        // primitive representation with #[repr(u8)], with the first
530        // field of the underlying union-of-structs the discriminant
531        //
532        // See the section on "accessing the numeric value of the discriminant"
533        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
534        unsafe { *(self as *const Self).cast::<u8>() }
535    }
536}
537
538impl fmt::Display for HashKind {
539    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
540        match self {
541            Self::Word => f.write_str("word"),
542            Self::Event => f.write_str("event"),
543        }
544    }
545}
546
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for HashKind {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates either of the two hash kinds.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let word = Just(Self::Word);
        let event = Just(Self::Event);

        prop_oneof![word, event].boxed()
    }
}
562
563impl Serializable for HashKind {
564    fn write_into<W: ByteWriter>(&self, target: &mut W) {
565        target.write_u8(self.tag());
566    }
567}
568
569impl Deserializable for HashKind {
570    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
571        const WORD: u8 = HashKind::Word.tag();
572        const EVENT: u8 = HashKind::Event.tag();
573
574        match source.read_u8()? {
575            WORD => Ok(Self::Word),
576            EVENT => Ok(Self::Event),
577            invalid => Err(DeserializationError::InvalidValue(format!(
578                "unexpected HashKind tag: '{invalid}'"
579            ))),
580        }
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the integer literal `1` with an unknown span.
    fn one() -> ConstantExpr {
        ConstantExpr::Int(Span::unknown(IntValue::from(1u8)))
    }

    /// Builds a left-nested chain of `depth` additions: `((1 + 1) + 1) + ...`.
    fn nested_add_expr(depth: usize) -> ConstantExpr {
        (0..depth).fold(one(), |lhs, _| ConstantExpr::BinaryOp {
            span: SourceSpan::UNKNOWN,
            op: ConstantOp::Add,
            lhs: Box::new(lhs),
            rhs: Box::new(one()),
        })
    }

    /// Folding succeeds at exactly the maximum nesting depth and fails one level beyond it.
    #[test]
    fn const_expr_fold_depth_boundary() {
        let at_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH).try_fold();
        assert!(at_limit.is_ok());

        let past_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH + 1).try_fold();
        let err = past_limit.expect_err("expected depth-exceeded error");
        assert!(matches!(err, ParsingError::ConstExprDepthExceeded { max_depth, .. }
                if max_depth == MAX_CONST_EXPR_FOLD_DEPTH));
    }
}