Skip to main content

miden_assembly_syntax/ast/constants/
expr.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::fmt;
3
4use miden_core::serde::{
5    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
6};
7use miden_debug_types::{SourceSpan, Span, Spanned};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    Felt, Path,
13    ast::{ConstantValue, Ident},
14    parser::{IntValue, LiteralErrorKind, ParsingError, WordValue},
15};
16
/// Maximum allowed nesting depth for constant-expression folding.
///
/// Folding tracks how deeply it has recursed into an expression; once that depth exceeds this
/// value, folding stops and reports `ParsingError::ConstExprDepthExceeded` instead of recursing
/// further.
///
/// This limit is intended to prevent stack overflows from maliciously deep constant expressions
/// while remaining far above typical constant-expression complexity.
const MAX_CONST_EXPR_FOLD_DEPTH: usize = 256;
22
23// CONSTANT EXPRESSION
24// ================================================================================================
25
26/// Represents a constant expression or value in Miden Assembly syntax.
27#[derive(Clone)]
28#[repr(u8)]
29pub enum ConstantExpr {
30    /// A literal [`Felt`] value.
31    Int(Span<IntValue>),
32    /// A reference to another constant.
33    Var(Span<Arc<Path>>),
34    /// An binary arithmetic operator.
35    BinaryOp {
36        span: SourceSpan,
37        op: ConstantOp,
38        lhs: Box<ConstantExpr>,
39        rhs: Box<ConstantExpr>,
40    },
41    /// A plain spanned string.
42    String(Ident),
43    /// A literal ['WordValue'].
44    Word(Span<WordValue>),
45    /// A spanned string with a [`HashKind`] showing to which type of value the given string should
46    /// be hashed.
47    Hash(HashKind, Ident),
48}
49
50impl ConstantExpr {
51    /// Returns true if this expression is already evaluated to a concrete value
52    pub fn is_value(&self) -> bool {
53        matches!(self, Self::Int(_) | Self::Word(_) | Self::Hash(_, _) | Self::String(_))
54    }
55
56    /// Unwrap an [`IntValue`] from this expression or panic.
57    ///
58    /// This is used in places where we expect the expression to have been folded to an integer,
59    /// otherwise a bug occurred.
60    #[track_caller]
61    pub fn expect_int(&self) -> IntValue {
62        match self {
63            Self::Int(spanned) => spanned.into_inner(),
64            other => panic!("expected constant expression to be a literal, got {other:#?}"),
65        }
66    }
67
68    /// Unwrap a [`Felt`] value from this expression or panic.
69    ///
70    /// This is used in places where we expect the expression to have been folded to a felt value,
71    /// otherwise a bug occurred.
72    #[track_caller]
73    pub fn expect_felt(&self) -> Felt {
74        match self {
75            Self::Int(spanned) => Felt::new(spanned.inner().as_int()),
76            other => panic!("expected constant expression to be a literal, got {other:#?}"),
77        }
78    }
79
80    /// Unwrap a [`Arc<str>`] value from this expression or panic.
81    ///
82    /// This is used in places where we expect the expression to have been folded to a string value,
83    /// otherwise a bug occurred.
84    #[track_caller]
85    pub fn expect_string(&self) -> Arc<str> {
86        match self {
87            Self::String(spanned) => spanned.clone().into_inner(),
88            other => panic!("expected constant expression to be a string, got {other:#?}"),
89        }
90    }
91
92    /// Unwrap a [ConstantValue] from this expression or panic.
93    ///
94    /// This is used in places where we expect the expression to have been folded to a concrete
95    /// value, otherwise a bug occurred.
96    #[track_caller]
97    pub fn expect_value(&self) -> ConstantValue {
98        self.as_value().unwrap_or_else(|| {
99            panic!("expected constant expression to be a value, got {:#?}", self)
100        })
101    }
102
103    /// Try to convert this expression into a [ConstantValue], if the expression is a value.
104    ///
105    /// Returns `Err` if the expression cannot be represented as a [ConstantValue].
106    pub fn into_value(self) -> Result<ConstantValue, Self> {
107        match self {
108            Self::Int(value) => Ok(ConstantValue::Int(value)),
109            Self::String(value) => Ok(ConstantValue::String(value)),
110            Self::Word(value) => Ok(ConstantValue::Word(value)),
111            Self::Hash(kind, value) => Ok(ConstantValue::Hash(kind, value)),
112            expr @ (Self::BinaryOp { .. } | Self::Var(_)) => Err(expr),
113        }
114    }
115
116    /// Get the [ConstantValue] representation of this expression, if it is a value.
117    ///
118    /// Returns `None` if the expression cannot be represented as a [ConstantValue].
119    pub fn as_value(&self) -> Option<ConstantValue> {
120        match self {
121            Self::Int(value) => Some(ConstantValue::Int(*value)),
122            Self::String(value) => Some(ConstantValue::String(value.clone())),
123            Self::Word(value) => Some(ConstantValue::Word(*value)),
124            Self::Hash(kind, value) => Some(ConstantValue::Hash(*kind, value.clone())),
125            Self::BinaryOp { .. } | Self::Var(_) => None,
126        }
127    }
128
129    /// Attempt to fold to a single value.
130    ///
131    /// This will only succeed if the expression has no references to other constants.
132    ///
133    /// # Errors
134    /// Returns an error if an invalid expression is found while folding, such as division by zero.
135    pub fn try_fold(self) -> Result<Self, ParsingError> {
136        self.try_fold_with_depth(0)
137    }
138
139    fn try_fold_with_depth(self, depth: usize) -> Result<Self, ParsingError> {
140        if depth > MAX_CONST_EXPR_FOLD_DEPTH {
141            return Err(ParsingError::ConstExprDepthExceeded {
142                span: self.span(),
143                max_depth: MAX_CONST_EXPR_FOLD_DEPTH,
144            });
145        }
146
147        match self {
148            Self::String(_) | Self::Word(_) | Self::Int(_) | Self::Var(_) | Self::Hash(..) => {
149                Ok(self)
150            },
151            Self::BinaryOp { span, op, lhs, rhs } => {
152                if rhs.is_literal() {
153                    let rhs = Self::into_inner(rhs).try_fold_with_depth(depth + 1)?;
154                    match rhs {
155                        Self::String(ident) => {
156                            Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
157                        },
158                        Self::Int(rhs) => {
159                            let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
160                            match lhs {
161                                Self::String(ident) => {
162                                    Err(ParsingError::StringInArithmeticExpression {
163                                        span: ident.span(),
164                                    })
165                                },
166                                Self::Int(lhs) => {
167                                    let lhs = lhs.into_inner();
168                                    let rhs = rhs.into_inner();
169                                    let is_division =
170                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
171                                    let is_division_by_zero = is_division && rhs.as_int() == 0;
172                                    if is_division_by_zero {
173                                        return Err(ParsingError::DivisionByZero { span });
174                                    }
175                                    match op {
176                                        ConstantOp::Add => {
177                                            let value = lhs
178                                                .checked_add(rhs)
179                                                .ok_or(ParsingError::ConstantOverflow { span })?;
180                                            Some(Self::Int(Span::new(span, value)))
181                                        },
182                                        ConstantOp::Sub => {
183                                            let value = lhs
184                                                .checked_sub(rhs)
185                                                .ok_or(ParsingError::ConstantOverflow { span })?;
186                                            Some(Self::Int(Span::new(span, value)))
187                                        },
188                                        ConstantOp::Mul => {
189                                            let value = lhs
190                                                .checked_mul(rhs)
191                                                .ok_or(ParsingError::ConstantOverflow { span })?;
192                                            Some(Self::Int(Span::new(span, value)))
193                                        },
194                                        ConstantOp::IntDiv => {
195                                            let value = lhs
196                                                .checked_div(rhs)
197                                                .ok_or(ParsingError::ConstantOverflow { span })?;
198                                            Some(Self::Int(Span::new(span, value)))
199                                        },
200                                        ConstantOp::Div => {
201                                            let lhs = Felt::new(lhs.as_int());
202                                            let rhs = Felt::new(rhs.as_int());
203                                            let value = IntValue::from(lhs / rhs);
204                                            Some(Self::Int(Span::new(span, value)))
205                                        },
206                                    }
207                                    .ok_or(
208                                        ParsingError::InvalidLiteral {
209                                            span,
210                                            kind: LiteralErrorKind::FeltOverflow,
211                                        },
212                                    )
213                                },
214                                lhs => Ok(Self::BinaryOp {
215                                    span,
216                                    op,
217                                    lhs: Box::new(lhs),
218                                    rhs: Box::new(Self::Int(rhs)),
219                                }),
220                            }
221                        },
222                        rhs => {
223                            let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
224                            Ok(Self::BinaryOp {
225                                span,
226                                op,
227                                lhs: Box::new(lhs),
228                                rhs: Box::new(rhs),
229                            })
230                        },
231                    }
232                } else {
233                    let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
234                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
235                }
236            },
237        }
238    }
239
240    /// Get any references to other symbols present in this expression
241    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
242        use alloc::collections::BTreeSet;
243
244        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
245        let mut references = BTreeSet::new();
246
247        while let Some(ty) = worklist.pop() {
248            match ty {
249                Self::Int(_) | Self::Word(_) | Self::String(_) | Self::Hash(..) => {},
250                Self::Var(path) => {
251                    references.insert(path.clone());
252                },
253                Self::BinaryOp { lhs, rhs, .. } => {
254                    worklist.push(lhs);
255                    worklist.push(rhs);
256                },
257            }
258        }
259
260        references.into_iter().collect()
261    }
262
263    fn is_literal(&self) -> bool {
264        match self {
265            Self::Int(_) | Self::String(_) | Self::Word(_) | Self::Hash(..) => true,
266            Self::Var(_) => false,
267            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
268        }
269    }
270
271    #[inline(always)]
272    #[expect(clippy::boxed_local)]
273    fn into_inner(self: Box<Self>) -> Self {
274        *self
275    }
276}
277
278impl Eq for ConstantExpr {}
279
280impl PartialEq for ConstantExpr {
281    fn eq(&self, other: &Self) -> bool {
282        match (self, other) {
283            (Self::Int(x), Self::Int(y)) => x == y,
284            (Self::Int(_), _) => false,
285            (Self::Word(x), Self::Word(y)) => x == y,
286            (Self::Word(_), _) => false,
287            (Self::Var(x), Self::Var(y)) => x == y,
288            (Self::Var(_), _) => false,
289            (Self::String(x), Self::String(y)) => x == y,
290            (Self::String(_), _) => false,
291            (Self::Hash(x_hk, x_i), Self::Hash(y_hk, y_i)) => x_i == y_i && x_hk == y_hk,
292            (Self::Hash(..), _) => false,
293            (
294                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
295                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
296            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
297            (Self::BinaryOp { .. }, _) => false,
298        }
299    }
300}
301
302impl core::hash::Hash for ConstantExpr {
303    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
304        core::mem::discriminant(self).hash(state);
305        match self {
306            Self::Int(value) => value.hash(state),
307            Self::Word(value) => value.hash(state),
308            Self::String(value) => value.hash(state),
309            Self::Var(value) => value.hash(state),
310            Self::Hash(hash_kind, string) => {
311                hash_kind.hash(state);
312                string.hash(state);
313            },
314            Self::BinaryOp { op, lhs, rhs, .. } => {
315                op.hash(state);
316                lhs.hash(state);
317                rhs.hash(state);
318            },
319        }
320    }
321}
322
323impl fmt::Debug for ConstantExpr {
324    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
325        match self {
326            Self::Int(lit) => fmt::Debug::fmt(&**lit, f),
327            Self::Word(lit) => fmt::Debug::fmt(&**lit, f),
328            Self::Var(path) => fmt::Debug::fmt(path, f),
329            Self::String(name) => fmt::Debug::fmt(&**name, f),
330            Self::Hash(hash_kind, str) => {
331                f.debug_tuple("Hash").field(hash_kind).field(str).finish()
332            },
333            Self::BinaryOp { op, lhs, rhs, .. } => {
334                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
335            },
336        }
337    }
338}
339
340impl crate::prettier::PrettyPrint for ConstantExpr {
341    fn render(&self) -> crate::prettier::Document {
342        use crate::prettier::*;
343
344        match self {
345            Self::Int(literal) => literal.render(),
346            Self::Word(literal) => literal.render(),
347            Self::Var(path) => display(path),
348            Self::String(ident) => text(format!("\"{}\"", ident.as_str().escape_debug())),
349            Self::Hash(hash_kind, str) => flatten(
350                display(hash_kind)
351                    + const_text("(")
352                    + text(format!("\"{}\"", str.as_str().escape_debug()))
353                    + const_text(")"),
354            ),
355            Self::BinaryOp { op, lhs, rhs, .. } => {
356                let single_line = lhs.render() + display(op) + rhs.render();
357                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
358                single_line | multi_line
359            },
360        }
361    }
362}
363
364impl Spanned for ConstantExpr {
365    fn span(&self) -> SourceSpan {
366        match self {
367            Self::Int(spanned) => spanned.span(),
368            Self::Word(spanned) => spanned.span(),
369            Self::Hash(_, spanned) => spanned.span(),
370            Self::Var(spanned) => spanned.span(),
371            Self::String(spanned) => spanned.span(),
372            Self::BinaryOp { span, .. } => *span,
373        }
374    }
375}
376
/// Generates arbitrary expressions with unknown spans: one of every variant, where binary
/// operators always have two integer literal operands.
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantExpr {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};

        /// Wraps an integer in a literal expression with an unknown span.
        fn int_literal(value: IntValue) -> ConstantExpr {
            ConstantExpr::Int(Span::unknown(value))
        }

        let ints = any::<IntValue>().prop_map(int_literal);
        let vars = crate::arbitrary::path::constant_path_random_length(0)
            .prop_map(|path| Self::Var(Span::unknown(path)));
        let binary_ops =
            any::<(ConstantOp, IntValue, IntValue)>().prop_map(|(op, left, right)| {
                Self::BinaryOp {
                    span: SourceSpan::UNKNOWN,
                    op,
                    lhs: Box::new(int_literal(left)),
                    rhs: Box::new(int_literal(right)),
                }
            });
        let strings = any::<Ident>().prop_map(Self::String);
        let words = any::<WordValue>().prop_map(|word| Self::Word(Span::unknown(word)));
        let hashes = any::<(HashKind, Ident)>().prop_map(|(kind, ident)| Self::Hash(kind, ident));

        prop_oneof![ints, vars, binary_ops, strings, words, hashes].boxed()
    }
}
403
404// CONSTANT OPERATION
405// ================================================================================================
406
/// Represents the set of binary arithmetic operators supported in Miden Assembly syntax.
///
/// The enum is field-less with a `u8` representation; the discriminant doubles as the tag
/// written by the binary serialization of this type.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ConstantOp {
    /// Addition, displayed as `+`.
    Add,
    /// Subtraction, displayed as `-`.
    Sub,
    /// Multiplication, displayed as `*`.
    Mul,
    /// Division, displayed as `/`.
    Div,
    /// Integer division, displayed as `//`.
    IntDiv,
}

impl ConstantOp {
    /// Returns the variant name, used as the label in debug output of binary expressions.
    const fn name(&self) -> &'static str {
        match self {
            Self::Add => "Add",
            Self::Sub => "Sub",
            Self::Mul => "Mul",
            Self::Div => "Div",
            Self::IntDiv => "IntDiv",
        }
    }
}

/// Displays the operator using its source-level symbol.
impl fmt::Display for ConstantOp {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Add => f.write_str("+"),
            Self::Sub => f.write_str("-"),
            Self::Mul => f.write_str("*"),
            Self::Div => f.write_str("/"),
            Self::IntDiv => f.write_str("//"),
        }
    }
}

impl ConstantOp {
    /// Returns the `u8` discriminant of this operator.
    const fn tag(&self) -> u8 {
        // A field-less `Copy` enum with a primitive representation can be cast to that primitive
        // directly, so no `unsafe` pointer read is needed to obtain the discriminant.
        *self as u8
    }
}
454
455impl Serializable for ConstantOp {
456    fn write_into<W: ByteWriter>(&self, target: &mut W) {
457        target.write_u8(self.tag());
458    }
459}
460
461impl Deserializable for ConstantOp {
462    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
463        const ADD: u8 = ConstantOp::Add.tag();
464        const SUB: u8 = ConstantOp::Sub.tag();
465        const MUL: u8 = ConstantOp::Mul.tag();
466        const DIV: u8 = ConstantOp::Div.tag();
467        const INT_DIV: u8 = ConstantOp::IntDiv.tag();
468
469        match source.read_u8()? {
470            ADD => Ok(Self::Add),
471            SUB => Ok(Self::Sub),
472            MUL => Ok(Self::Mul),
473            DIV => Ok(Self::Div),
474            INT_DIV => Ok(Self::IntDiv),
475            invalid => Err(DeserializationError::InvalidValue(format!(
476                "unexpected ConstantOp tag: '{invalid}'"
477            ))),
478        }
479    }
480}
481
/// Generates any one of the five operators.
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantOp {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let any_operator = prop_oneof![
            Just(Self::Add),
            Just(Self::Sub),
            Just(Self::Mul),
            Just(Self::Div),
            Just(Self::IntDiv),
        ];
        any_operator.boxed()
    }
}
504
505// HASH KIND
506// ================================================================================================
507
/// Represents the type of the final value to which some string value should be converted.
///
/// The enum is field-less with a `u8` representation; the discriminant doubles as the tag
/// written by the binary serialization of this type.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum HashKind {
    /// Reduce a string to a word using Blake3 hash function
    Word,
    /// Reduce a string to a felt using Blake3 hash function (via 64-bit reduction)
    Event,
}

impl HashKind {
    /// Returns the `u8` discriminant of this kind.
    const fn tag(&self) -> u8 {
        // A field-less `Copy` enum with a primitive representation can be cast to that primitive
        // directly, so no `unsafe` pointer read is needed to obtain the discriminant.
        *self as u8
    }
}

/// Displays the kind using its lowercase keyword (`word` or `event`).
impl fmt::Display for HashKind {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Word => f.write_str("word"),
            Self::Event => f.write_str("event"),
        }
    }
}
539
/// Generates either of the two hash kinds.
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for HashKind {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let any_kind = prop_oneof![Just(Self::Word), Just(Self::Event),];
        any_kind.boxed()
    }
}
555
556impl Serializable for HashKind {
557    fn write_into<W: ByteWriter>(&self, target: &mut W) {
558        target.write_u8(self.tag());
559    }
560}
561
562impl Deserializable for HashKind {
563    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
564        const WORD: u8 = HashKind::Word.tag();
565        const EVENT: u8 = HashKind::Event.tag();
566
567        match source.read_u8()? {
568            WORD => Ok(Self::Word),
569            EVENT => Ok(Self::Event),
570            invalid => Err(DeserializationError::InvalidValue(format!(
571                "unexpected HashKind tag: '{invalid}'"
572            ))),
573        }
574    }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// The integer literal `1` with an unknown span.
    fn one() -> ConstantExpr {
        ConstantExpr::Int(Span::unknown(IntValue::from(1u8)))
    }

    /// Builds a left-nested chain of `depth` additions: `((1 + 1) + 1) + ...`.
    fn nested_add_expr(depth: usize) -> ConstantExpr {
        (0..depth).fold(one(), |inner, _| ConstantExpr::BinaryOp {
            span: SourceSpan::UNKNOWN,
            op: ConstantOp::Add,
            lhs: Box::new(inner),
            rhs: Box::new(one()),
        })
    }

    /// Nesting exactly at the limit folds; one level more is rejected with the depth error.
    #[test]
    fn const_expr_fold_depth_boundary() {
        let at_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH);
        assert!(at_limit.try_fold().is_ok());

        let past_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH + 1);
        let err = past_limit.try_fold().expect_err("expected depth-exceeded error");
        assert!(matches!(err, ParsingError::ConstExprDepthExceeded { max_depth, .. }
                if max_depth == MAX_CONST_EXPR_FOLD_DEPTH));
    }
}