Skip to main content

miden_assembly_syntax/ast/constants/
expr.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::fmt;
3
4use miden_core::serde::{
5    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
6};
7use miden_debug_types::{SourceSpan, Span, Spanned};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    Felt, Path,
13    ast::{ConstantValue, Ident},
14    parser::{IntValue, ParsingError, WordValue},
15};
16
17// CONSTANT EXPRESSION
18// ================================================================================================
19
20/// Represents a constant expression or value in Miden Assembly syntax.
21#[derive(Clone)]
22#[repr(u8)]
23pub enum ConstantExpr {
24    /// A literal [`Felt`] value.
25    Int(Span<IntValue>),
26    /// A reference to another constant.
27    Var(Span<Arc<Path>>),
28    /// An binary arithmetic operator.
29    BinaryOp {
30        span: SourceSpan,
31        op: ConstantOp,
32        lhs: Box<ConstantExpr>,
33        rhs: Box<ConstantExpr>,
34    },
35    /// A plain spanned string.
36    String(Ident),
37    /// A literal ['WordValue'].
38    Word(Span<WordValue>),
39    /// A spanned string with a [`HashKind`] showing to which type of value the given string should
40    /// be hashed.
41    Hash(HashKind, Ident),
42}
43
44impl ConstantExpr {
45    /// Returns true if this expression is already evaluated to a concrete value
46    pub fn is_value(&self) -> bool {
47        matches!(self, Self::Int(_) | Self::Word(_) | Self::Hash(_, _) | Self::String(_))
48    }
49
50    /// Unwrap an [`IntValue`] from this expression or panic.
51    ///
52    /// This is used in places where we expect the expression to have been folded to an integer,
53    /// otherwise a bug occurred.
54    #[track_caller]
55    pub fn expect_int(&self) -> IntValue {
56        match self {
57            Self::Int(spanned) => spanned.into_inner(),
58            other => panic!("expected constant expression to be a literal, got {other:#?}"),
59        }
60    }
61
62    /// Unwrap a [`Felt`] value from this expression or panic.
63    ///
64    /// This is used in places where we expect the expression to have been folded to a felt value,
65    /// otherwise a bug occurred.
66    #[track_caller]
67    pub fn expect_felt(&self) -> Felt {
68        match self {
69            Self::Int(spanned) => Felt::new(spanned.inner().as_int()),
70            other => panic!("expected constant expression to be a literal, got {other:#?}"),
71        }
72    }
73
74    /// Unwrap a [`Arc<str>`] value from this expression or panic.
75    ///
76    /// This is used in places where we expect the expression to have been folded to a string value,
77    /// otherwise a bug occurred.
78    #[track_caller]
79    pub fn expect_string(&self) -> Arc<str> {
80        match self {
81            Self::String(spanned) => spanned.clone().into_inner(),
82            other => panic!("expected constant expression to be a string, got {other:#?}"),
83        }
84    }
85
86    /// Unwrap a [ConstantValue] from this expression or panic.
87    ///
88    /// This is used in places where we expect the expression to have been folded to a concrete
89    /// value, otherwise a bug occurred.
90    #[track_caller]
91    pub fn expect_value(&self) -> ConstantValue {
92        self.as_value().unwrap_or_else(|| {
93            panic!("expected constant expression to be a value, got {:#?}", self)
94        })
95    }
96
97    /// Try to convert this expression into a [ConstantValue], if the expression is a value.
98    ///
99    /// Returns `Err` if the expression cannot be represented as a [ConstantValue].
100    pub fn into_value(self) -> Result<ConstantValue, Self> {
101        match self {
102            Self::Int(value) => Ok(ConstantValue::Int(value)),
103            Self::String(value) => Ok(ConstantValue::String(value)),
104            Self::Word(value) => Ok(ConstantValue::Word(value)),
105            Self::Hash(kind, value) => Ok(ConstantValue::Hash(kind, value)),
106            expr @ (Self::BinaryOp { .. } | Self::Var(_)) => Err(expr),
107        }
108    }
109
110    /// Get the [ConstantValue] representation of this expression, if it is a value.
111    ///
112    /// Returns `None` if the expression cannot be represented as a [ConstantValue].
113    pub fn as_value(&self) -> Option<ConstantValue> {
114        match self {
115            Self::Int(value) => Some(ConstantValue::Int(*value)),
116            Self::String(value) => Some(ConstantValue::String(value.clone())),
117            Self::Word(value) => Some(ConstantValue::Word(*value)),
118            Self::Hash(kind, value) => Some(ConstantValue::Hash(*kind, value.clone())),
119            Self::BinaryOp { .. } | Self::Var(_) => None,
120        }
121    }
122
123    /// Attempt to fold to a single value.
124    ///
125    /// This will only succeed if the expression has no references to other constants.
126    ///
127    /// # Errors
128    /// Returns an error if an invalid expression is found while folding, such as division by zero.
129    pub fn try_fold(self) -> Result<Self, ParsingError> {
130        match self {
131            Self::String(_) | Self::Word(_) | Self::Int(_) | Self::Var(_) | Self::Hash(..) => {
132                Ok(self)
133            },
134            Self::BinaryOp { span, op, lhs, rhs } => {
135                if rhs.is_literal() {
136                    let rhs = Self::into_inner(rhs).try_fold()?;
137                    match rhs {
138                        Self::String(ident) => {
139                            Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
140                        },
141                        Self::Int(rhs) => {
142                            let lhs = Self::into_inner(lhs).try_fold()?;
143                            match lhs {
144                                Self::String(ident) => {
145                                    Err(ParsingError::StringInArithmeticExpression {
146                                        span: ident.span(),
147                                    })
148                                },
149                                Self::Int(lhs) => {
150                                    let lhs = lhs.into_inner();
151                                    let rhs = rhs.into_inner();
152                                    let is_division =
153                                        matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
154                                    let is_division_by_zero = is_division && rhs == Felt::ZERO;
155                                    if is_division_by_zero {
156                                        return Err(ParsingError::DivisionByZero { span });
157                                    }
158                                    match op {
159                                        ConstantOp::Add => {
160                                            Ok(Self::Int(Span::new(span, lhs + rhs)))
161                                        },
162                                        ConstantOp::Sub => {
163                                            Ok(Self::Int(Span::new(span, lhs - rhs)))
164                                        },
165                                        ConstantOp::Mul => {
166                                            Ok(Self::Int(Span::new(span, lhs * rhs)))
167                                        },
168                                        ConstantOp::Div => {
169                                            Ok(Self::Int(Span::new(span, lhs / rhs)))
170                                        },
171                                        ConstantOp::IntDiv => {
172                                            Ok(Self::Int(Span::new(span, lhs / rhs)))
173                                        },
174                                    }
175                                },
176                                lhs => Ok(Self::BinaryOp {
177                                    span,
178                                    op,
179                                    lhs: Box::new(lhs),
180                                    rhs: Box::new(Self::Int(rhs)),
181                                }),
182                            }
183                        },
184                        rhs => {
185                            let lhs = Self::into_inner(lhs).try_fold()?;
186                            Ok(Self::BinaryOp {
187                                span,
188                                op,
189                                lhs: Box::new(lhs),
190                                rhs: Box::new(rhs),
191                            })
192                        },
193                    }
194                } else {
195                    let lhs = Self::into_inner(lhs).try_fold()?;
196                    Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
197                }
198            },
199        }
200    }
201
202    /// Get any references to other symbols present in this expression
203    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
204        use alloc::collections::BTreeSet;
205
206        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
207        let mut references = BTreeSet::new();
208
209        while let Some(ty) = worklist.pop() {
210            match ty {
211                Self::Int(_) | Self::Word(_) | Self::String(_) | Self::Hash(..) => continue,
212                Self::Var(path) => {
213                    references.insert(path.clone());
214                },
215                Self::BinaryOp { lhs, rhs, .. } => {
216                    worklist.push(lhs);
217                    worklist.push(rhs);
218                },
219            }
220        }
221
222        references.into_iter().collect()
223    }
224
225    fn is_literal(&self) -> bool {
226        match self {
227            Self::Int(_) | Self::String(_) | Self::Word(_) | Self::Hash(..) => true,
228            Self::Var(_) => false,
229            Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
230        }
231    }
232
233    #[inline(always)]
234    #[expect(clippy::boxed_local)]
235    fn into_inner(self: Box<Self>) -> Self {
236        *self
237    }
238}
239
240impl Eq for ConstantExpr {}
241
242impl PartialEq for ConstantExpr {
243    fn eq(&self, other: &Self) -> bool {
244        match (self, other) {
245            (Self::Int(x), Self::Int(y)) => x == y,
246            (Self::Int(_), _) => false,
247            (Self::Word(x), Self::Word(y)) => x == y,
248            (Self::Word(_), _) => false,
249            (Self::Var(x), Self::Var(y)) => x == y,
250            (Self::Var(_), _) => false,
251            (Self::String(x), Self::String(y)) => x == y,
252            (Self::String(_), _) => false,
253            (Self::Hash(x_hk, x_i), Self::Hash(y_hk, y_i)) => x_i == y_i && x_hk == y_hk,
254            (Self::Hash(..), _) => false,
255            (
256                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
257                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
258            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
259            (Self::BinaryOp { .. }, _) => false,
260        }
261    }
262}
263
264impl core::hash::Hash for ConstantExpr {
265    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
266        core::mem::discriminant(self).hash(state);
267        match self {
268            Self::Int(value) => value.hash(state),
269            Self::Word(value) => value.hash(state),
270            Self::String(value) => value.hash(state),
271            Self::Var(value) => value.hash(state),
272            Self::Hash(hash_kind, string) => {
273                hash_kind.hash(state);
274                string.hash(state);
275            },
276            Self::BinaryOp { op, lhs, rhs, .. } => {
277                op.hash(state);
278                lhs.hash(state);
279                rhs.hash(state);
280            },
281        }
282    }
283}
284
285impl fmt::Debug for ConstantExpr {
286    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
287        match self {
288            Self::Int(lit) => fmt::Debug::fmt(&**lit, f),
289            Self::Word(lit) => fmt::Debug::fmt(&**lit, f),
290            Self::Var(path) => fmt::Debug::fmt(path, f),
291            Self::String(name) => fmt::Debug::fmt(&**name, f),
292            Self::Hash(hash_kind, str) => {
293                f.debug_tuple("Hash").field(hash_kind).field(str).finish()
294            },
295            Self::BinaryOp { op, lhs, rhs, .. } => {
296                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
297            },
298        }
299    }
300}
301
302impl crate::prettier::PrettyPrint for ConstantExpr {
303    fn render(&self) -> crate::prettier::Document {
304        use crate::prettier::*;
305
306        match self {
307            Self::Int(literal) => literal.render(),
308            Self::Word(literal) => literal.render(),
309            Self::Var(path) => display(path),
310            Self::String(ident) => text(format!("\"{}\"", ident.as_str().escape_debug())),
311            Self::Hash(hash_kind, str) => flatten(
312                display(hash_kind)
313                    + const_text("(")
314                    + text(format!("\"{}\"", str.as_str().escape_debug()))
315                    + const_text(")"),
316            ),
317            Self::BinaryOp { op, lhs, rhs, .. } => {
318                let single_line = lhs.render() + display(op) + rhs.render();
319                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
320                single_line | multi_line
321            },
322        }
323    }
324}
325
326impl Spanned for ConstantExpr {
327    fn span(&self) -> SourceSpan {
328        match self {
329            Self::Int(spanned) => spanned.span(),
330            Self::Word(spanned) => spanned.span(),
331            Self::Hash(_, spanned) => spanned.span(),
332            Self::Var(spanned) => spanned.span(),
333            Self::String(spanned) => spanned.span(),
334            Self::BinaryOp { span, .. } => *span,
335        }
336    }
337}
338
/// Property-testing support: generates one of each kind of expression. Binary operations are
/// generated one level deep, with integer literals on both sides.
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantExpr {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};

        let int = any::<IntValue>().prop_map(|value| Self::Int(Span::unknown(value)));
        let var = crate::arbitrary::path::constant_path_random_length(0)
            .prop_map(|path| Self::Var(Span::unknown(path)));
        let binary_op =
            any::<(ConstantOp, IntValue, IntValue)>().prop_map(|(op, left, right)| {
                let lhs = Box::new(ConstantExpr::Int(Span::unknown(left)));
                let rhs = Box::new(ConstantExpr::Int(Span::unknown(right)));
                Self::BinaryOp { span: SourceSpan::UNKNOWN, op, lhs, rhs }
            });
        let string = any::<Ident>().prop_map(Self::String);
        let word = any::<WordValue>().prop_map(|value| Self::Word(Span::unknown(value)));
        let hash = any::<(HashKind, Ident)>().prop_map(|(kind, name)| Self::Hash(kind, name));

        // The alternatives are listed in a fixed order; keep it stable.
        prop_oneof![int, var, binary_op, string, word, hash].boxed()
    }
}
365
366// CONSTANT OPERATION
367// ================================================================================================
368
/// Represents the set of binary arithmetic operators supported in Miden Assembly syntax.
///
/// The discriminants are spelled out because the `u8` representation is what gets written when
/// an operator is serialized.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ConstantOp {
    /// Displayed as `+`.
    Add = 0,
    /// Displayed as `-`.
    Sub = 1,
    /// Displayed as `*`.
    Mul = 2,
    /// Displayed as `/`.
    Div = 3,
    /// Displayed as `//`.
    IntDiv = 4,
}
380
381impl ConstantOp {
382    const fn name(&self) -> &'static str {
383        match self {
384            Self::Add => "Add",
385            Self::Sub => "Sub",
386            Self::Mul => "Mul",
387            Self::Div => "Div",
388            Self::IntDiv => "IntDiv",
389        }
390    }
391}
392
393impl fmt::Display for ConstantOp {
394    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
395        match self {
396            Self::Add => f.write_str("+"),
397            Self::Sub => f.write_str("-"),
398            Self::Mul => f.write_str("*"),
399            Self::Div => f.write_str("/"),
400            Self::IntDiv => f.write_str("//"),
401        }
402    }
403}
404
405impl ConstantOp {
406    const fn tag(&self) -> u8 {
407        // SAFETY: This is safe because we have given this enum a
408        // primitive representation with #[repr(u8)], with the first
409        // field of the underlying union-of-structs the discriminant
410        //
411        // See the section on "accessing the numeric value of the discriminant"
412        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
413        unsafe { *(self as *const Self).cast::<u8>() }
414    }
415}
416
417impl Serializable for ConstantOp {
418    fn write_into<W: ByteWriter>(&self, target: &mut W) {
419        target.write_u8(self.tag());
420    }
421}
422
423impl Deserializable for ConstantOp {
424    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
425        const ADD: u8 = ConstantOp::Add.tag();
426        const SUB: u8 = ConstantOp::Sub.tag();
427        const MUL: u8 = ConstantOp::Mul.tag();
428        const DIV: u8 = ConstantOp::Div.tag();
429        const INT_DIV: u8 = ConstantOp::IntDiv.tag();
430
431        match source.read_u8()? {
432            ADD => Ok(Self::Add),
433            SUB => Ok(Self::Sub),
434            MUL => Ok(Self::Mul),
435            DIV => Ok(Self::Div),
436            INT_DIV => Ok(Self::IntDiv),
437            invalid => Err(DeserializationError::InvalidValue(format!(
438                "unexpected ConstantOp tag: '{invalid}'"
439            ))),
440        }
441    }
442}
443
/// Property-testing support: picks one of the five operators.
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantOp {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let add = Just(Self::Add);
        let sub = Just(Self::Sub);
        let mul = Just(Self::Mul);
        let div = Just(Self::Div);
        let int_div = Just(Self::IntDiv);

        prop_oneof![add, sub, mul, div, int_div].boxed()
    }
}
466
467// HASH KIND
468// ================================================================================================
469
/// Represents the type of the final value to which some string value should be converted.
///
/// The discriminants are spelled out because the `u8` representation is what gets written when
/// a kind is serialized.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum HashKind {
    /// Reduce a string to a word using Blake3 hash function
    Word = 0,
    /// Reduce a string to a felt using Blake3 hash function (via 64-bit reduction)
    Event = 1,
}
480
481impl HashKind {
482    const fn tag(&self) -> u8 {
483        // SAFETY: This is safe because we have given this enum a
484        // primitive representation with #[repr(u8)], with the first
485        // field of the underlying union-of-structs the discriminant
486        //
487        // See the section on "accessing the numeric value of the discriminant"
488        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
489        unsafe { *(self as *const Self).cast::<u8>() }
490    }
491}
492
493impl fmt::Display for HashKind {
494    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
495        match self {
496            Self::Word => f.write_str("word"),
497            Self::Event => f.write_str("event"),
498        }
499    }
500}
501
/// Property-testing support: picks one of the two hash kinds.
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for HashKind {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let word = Just(Self::Word);
        let event = Just(Self::Event);

        prop_oneof![word, event].boxed()
    }
}
517
518impl Serializable for HashKind {
519    fn write_into<W: ByteWriter>(&self, target: &mut W) {
520        target.write_u8(self.tag());
521    }
522}
523
524impl Deserializable for HashKind {
525    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
526        const WORD: u8 = HashKind::Word.tag();
527        const EVENT: u8 = HashKind::Event.tag();
528
529        match source.read_u8()? {
530            WORD => Ok(Self::Word),
531            EVENT => Ok(Self::Event),
532            invalid => Err(DeserializationError::InvalidValue(format!(
533                "unexpected HashKind tag: '{invalid}'"
534            ))),
535        }
536    }
537}