Skip to main content

miden_assembly_syntax/ast/constants/
expr.rs

1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::fmt;
3
4use miden_core::serde::{
5    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
6};
7use miden_debug_types::{SourceSpan, Span, Spanned};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    Felt, Path,
13    ast::{ConstantValue, Ident},
14    parser::{IntValue, WordValue},
15};
16
17// CONSTANT EXPRESSION
18// ================================================================================================
19
20/// Represents a constant expression or value in Miden Assembly syntax.
21#[derive(Clone)]
22#[repr(u8)]
23pub enum ConstantExpr {
24    /// A literal [`Felt`] value.
25    Int(Span<IntValue>),
26    /// A reference to another constant.
27    Var(Span<Arc<Path>>),
28    /// An binary arithmetic operator.
29    BinaryOp {
30        span: SourceSpan,
31        op: ConstantOp,
32        lhs: Box<ConstantExpr>,
33        rhs: Box<ConstantExpr>,
34    },
35    /// A plain spanned string.
36    String(Ident),
37    /// A literal ['WordValue'].
38    Word(Span<WordValue>),
39    /// A spanned string with a [`HashKind`] showing to which type of value the given string should
40    /// be hashed.
41    Hash(HashKind, Ident),
42}
43
44impl ConstantExpr {
45    /// Returns true if this expression is already evaluated to a concrete value
46    pub fn is_value(&self) -> bool {
47        matches!(self, Self::Int(_) | Self::Word(_) | Self::Hash(_, _) | Self::String(_))
48    }
49
50    /// Unwrap an [`IntValue`] from this expression or panic.
51    ///
52    /// This is used in places where we expect the expression to have been folded to an integer,
53    /// otherwise a bug occurred.
54    #[track_caller]
55    pub fn expect_int(&self) -> IntValue {
56        match self {
57            Self::Int(spanned) => spanned.into_inner(),
58            other => panic!("expected constant expression to be a literal, got {other:#?}"),
59        }
60    }
61
62    /// Unwrap a [`Felt`] value from this expression or panic.
63    ///
64    /// This is used in places where we expect the expression to have been folded to a felt value,
65    /// otherwise a bug occurred.
66    #[track_caller]
67    pub fn expect_felt(&self) -> Felt {
68        match self {
69            Self::Int(spanned) => Felt::new_unchecked(spanned.inner().as_int()),
70            other => panic!("expected constant expression to be a literal, got {other:#?}"),
71        }
72    }
73
74    /// Unwrap a [`Arc<str>`] value from this expression or panic.
75    ///
76    /// This is used in places where we expect the expression to have been folded to a string value,
77    /// otherwise a bug occurred.
78    #[track_caller]
79    pub fn expect_string(&self) -> Arc<str> {
80        match self {
81            Self::String(spanned) => spanned.clone().into_inner(),
82            other => panic!("expected constant expression to be a string, got {other:#?}"),
83        }
84    }
85
86    /// Unwrap a [ConstantValue] from this expression or panic.
87    ///
88    /// This is used in places where we expect the expression to have been folded to a concrete
89    /// value, otherwise a bug occurred.
90    #[track_caller]
91    pub fn expect_value(&self) -> ConstantValue {
92        self.as_value()
93            .unwrap_or_else(|| panic!("expected constant expression to be a value, got {self:#?}"))
94    }
95
96    /// Try to convert this expression into a [ConstantValue], if the expression is a value.
97    ///
98    /// Returns `Err` if the expression cannot be represented as a [ConstantValue].
99    pub fn into_value(self) -> Result<ConstantValue, Self> {
100        match self {
101            Self::Int(value) => Ok(ConstantValue::Int(value)),
102            Self::String(value) => Ok(ConstantValue::String(value)),
103            Self::Word(value) => Ok(ConstantValue::Word(value)),
104            Self::Hash(kind, value) => Ok(ConstantValue::Hash(kind, value)),
105            expr @ (Self::BinaryOp { .. } | Self::Var(_)) => Err(expr),
106        }
107    }
108
109    /// Get the [ConstantValue] representation of this expression, if it is a value.
110    ///
111    /// Returns `None` if the expression cannot be represented as a [ConstantValue].
112    pub fn as_value(&self) -> Option<ConstantValue> {
113        match self {
114            Self::Int(value) => Some(ConstantValue::Int(*value)),
115            Self::String(value) => Some(ConstantValue::String(value.clone())),
116            Self::Word(value) => Some(ConstantValue::Word(*value)),
117            Self::Hash(kind, value) => Some(ConstantValue::Hash(*kind, value.clone())),
118            Self::BinaryOp { .. } | Self::Var(_) => None,
119        }
120    }
121
122    /// Get any references to other symbols present in this expression
123    pub fn references(&self) -> Vec<Span<Arc<Path>>> {
124        use alloc::collections::BTreeSet;
125
126        let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
127        let mut references = BTreeSet::new();
128
129        while let Some(ty) = worklist.pop() {
130            match ty {
131                Self::Int(_) | Self::Word(_) | Self::String(_) | Self::Hash(..) => {},
132                Self::Var(path) => {
133                    references.insert(path.clone());
134                },
135                Self::BinaryOp { lhs, rhs, .. } => {
136                    worklist.push(lhs);
137                    worklist.push(rhs);
138                },
139            }
140        }
141
142        references.into_iter().collect()
143    }
144}
145
146impl Eq for ConstantExpr {}
147
148impl PartialEq for ConstantExpr {
149    fn eq(&self, other: &Self) -> bool {
150        match (self, other) {
151            (Self::Int(x), Self::Int(y)) => x == y,
152            (Self::Int(_), _) => false,
153            (Self::Word(x), Self::Word(y)) => x == y,
154            (Self::Word(_), _) => false,
155            (Self::Var(x), Self::Var(y)) => x == y,
156            (Self::Var(_), _) => false,
157            (Self::String(x), Self::String(y)) => x == y,
158            (Self::String(_), _) => false,
159            (Self::Hash(x_hk, x_i), Self::Hash(y_hk, y_i)) => x_i == y_i && x_hk == y_hk,
160            (Self::Hash(..), _) => false,
161            (
162                Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
163                Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
164            ) => lop == rop && llhs == rlhs && lrhs == rrhs,
165            (Self::BinaryOp { .. }, _) => false,
166        }
167    }
168}
169
170impl core::hash::Hash for ConstantExpr {
171    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
172        core::mem::discriminant(self).hash(state);
173        match self {
174            Self::Int(value) => value.hash(state),
175            Self::Word(value) => value.hash(state),
176            Self::String(value) => value.hash(state),
177            Self::Var(value) => value.hash(state),
178            Self::Hash(hash_kind, string) => {
179                hash_kind.hash(state);
180                string.hash(state);
181            },
182            Self::BinaryOp { op, lhs, rhs, .. } => {
183                op.hash(state);
184                lhs.hash(state);
185                rhs.hash(state);
186            },
187        }
188    }
189}
190
191impl fmt::Debug for ConstantExpr {
192    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193        match self {
194            Self::Int(lit) => fmt::Debug::fmt(&**lit, f),
195            Self::Word(lit) => fmt::Debug::fmt(&**lit, f),
196            Self::Var(path) => fmt::Debug::fmt(path, f),
197            Self::String(name) => fmt::Debug::fmt(&**name, f),
198            Self::Hash(hash_kind, str) => {
199                f.debug_tuple("Hash").field(hash_kind).field(str).finish()
200            },
201            Self::BinaryOp { op, lhs, rhs, .. } => {
202                f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
203            },
204        }
205    }
206}
207
208impl crate::prettier::PrettyPrint for ConstantExpr {
209    fn render(&self) -> crate::prettier::Document {
210        use crate::prettier::*;
211
212        match self {
213            Self::Int(literal) => literal.render(),
214            Self::Word(literal) => literal.render(),
215            Self::Var(path) => display(path),
216            Self::String(ident) => text(format!("\"{}\"", ident.as_str().escape_debug())),
217            Self::Hash(hash_kind, str) => flatten(
218                display(hash_kind)
219                    + const_text("(")
220                    + text(format!("\"{}\"", str.as_str().escape_debug()))
221                    + const_text(")"),
222            ),
223            Self::BinaryOp { op, lhs, rhs, .. } => {
224                let single_line = lhs.render() + display(op) + rhs.render();
225                let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
226                single_line | multi_line
227            },
228        }
229    }
230}
231
232impl Spanned for ConstantExpr {
233    fn span(&self) -> SourceSpan {
234        match self {
235            Self::Int(spanned) => spanned.span(),
236            Self::Word(spanned) => spanned.span(),
237            Self::Hash(_, spanned) => spanned.span(),
238            Self::Var(spanned) => spanned.span(),
239            Self::String(spanned) => spanned.span(),
240            Self::BinaryOp { span, .. } => *span,
241        }
242    }
243}
244
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantExpr {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates one of: an integer literal, a variable reference, a single binary operation
    /// over two integer literals, a string, a word literal, or a hashed string. All generated
    /// nodes use unknown spans.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};

        // Builds an integer leaf with an unknown span.
        fn int_leaf(value: IntValue) -> ConstantExpr {
            ConstantExpr::Int(Span::unknown(value))
        }

        let ints = any::<IntValue>().prop_map(int_leaf);
        let vars = crate::arbitrary::path::constant_path_random_length(0)
            .prop_map(|path| Self::Var(Span::unknown(path)));
        let binary_ops =
            any::<(ConstantOp, IntValue, IntValue)>().prop_map(|(op, left, right)| {
                Self::BinaryOp {
                    span: SourceSpan::UNKNOWN,
                    op,
                    lhs: Box::new(int_leaf(left)),
                    rhs: Box::new(int_leaf(right)),
                }
            });
        let strings = any::<Ident>().prop_map(Self::String);
        let words = any::<WordValue>().prop_map(|word| Self::Word(Span::unknown(word)));
        let hashes =
            any::<(HashKind, Ident)>().prop_map(|(kind, ident)| Self::Hash(kind, ident));

        prop_oneof![ints, vars, binary_ops, strings, words, hashes].boxed()
    }
}
271
272// CONSTANT OPERATION
273// ================================================================================================
274
/// Represents the set of binary arithmetic operators supported in Miden Assembly syntax.
///
/// The declaration order of the variants is significant: each variant's `#[repr(u8)]`
/// discriminant is written as its tag in the binary serialization format, so reordering or
/// inserting variants changes the encoding.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    all(feature = "arbitrary", test),
    miden_test_serde_macros::serde_test(binary_serde(true))
)]
pub enum ConstantOp {
    /// Addition, written `+`.
    Add,
    /// Subtraction, written `-`.
    Sub,
    /// Multiplication, written `*`.
    Mul,
    /// Division, written `/`.
    Div,
    /// Integer division, written `//`.
    IntDiv,
}
290
291impl ConstantOp {
292    const fn name(self) -> &'static str {
293        match self {
294            Self::Add => "Add",
295            Self::Sub => "Sub",
296            Self::Mul => "Mul",
297            Self::Div => "Div",
298            Self::IntDiv => "IntDiv",
299        }
300    }
301}
302
303impl fmt::Display for ConstantOp {
304    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
305        match self {
306            Self::Add => f.write_str("+"),
307            Self::Sub => f.write_str("-"),
308            Self::Mul => f.write_str("*"),
309            Self::Div => f.write_str("/"),
310            Self::IntDiv => f.write_str("//"),
311        }
312    }
313}
314
315impl ConstantOp {
316    const fn tag(&self) -> u8 {
317        // SAFETY: This is safe because we have given this enum a
318        // primitive representation with #[repr(u8)], with the first
319        // field of the underlying union-of-structs the discriminant
320        //
321        // See the section on "accessing the numeric value of the discriminant"
322        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
323        unsafe { *(self as *const Self).cast::<u8>() }
324    }
325}
326
327impl Serializable for ConstantOp {
328    fn write_into<W: ByteWriter>(&self, target: &mut W) {
329        target.write_u8(self.tag());
330    }
331}
332
333impl Deserializable for ConstantOp {
334    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
335        const ADD: u8 = ConstantOp::Add.tag();
336        const SUB: u8 = ConstantOp::Sub.tag();
337        const MUL: u8 = ConstantOp::Mul.tag();
338        const DIV: u8 = ConstantOp::Div.tag();
339        const INT_DIV: u8 = ConstantOp::IntDiv.tag();
340
341        match source.read_u8()? {
342            ADD => Ok(Self::Add),
343            SUB => Ok(Self::Sub),
344            MUL => Ok(Self::Mul),
345            DIV => Ok(Self::Div),
346            INT_DIV => Ok(Self::IntDiv),
347            invalid => Err(DeserializationError::InvalidValue(format!(
348                "unexpected ConstantOp tag: '{invalid}'"
349            ))),
350        }
351    }
352}
353
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantOp {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Picks one of the five operators via `prop_oneof!`.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let strategy = prop_oneof![
            Just(ConstantOp::Add),
            Just(ConstantOp::Sub),
            Just(ConstantOp::Mul),
            Just(ConstantOp::Div),
            Just(ConstantOp::IntDiv),
        ];
        strategy.boxed()
    }
}
376
377// HASH KIND
378// ================================================================================================
379
/// Represents the type of the final value to which some string value should be converted.
///
/// The declaration order of the variants is significant: each variant's `#[repr(u8)]`
/// discriminant is written as its tag in the binary serialization format, so reordering or
/// inserting variants changes the encoding.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    all(feature = "arbitrary", test),
    miden_test_serde_macros::serde_test(binary_serde(true))
)]
pub enum HashKind {
    /// Reduce a string to a word using the Blake3 hash function. Displayed as `word`.
    Word,
    /// Reduce a string to a felt using the Blake3 hash function (via 64-bit reduction).
    /// Displayed as `event`.
    Event,
}
394
395impl HashKind {
396    const fn tag(&self) -> u8 {
397        // SAFETY: This is safe because we have given this enum a
398        // primitive representation with #[repr(u8)], with the first
399        // field of the underlying union-of-structs the discriminant
400        //
401        // See the section on "accessing the numeric value of the discriminant"
402        // here: https://doc.rust-lang.org/std/mem/fn.discriminant.html
403        unsafe { *(self as *const Self).cast::<u8>() }
404    }
405}
406
407impl fmt::Display for HashKind {
408    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
409        match self {
410            Self::Word => f.write_str("word"),
411            Self::Event => f.write_str("event"),
412        }
413    }
414}
415
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for HashKind {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Picks one of the two hash kinds via `prop_oneof!`.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        let word = Just(HashKind::Word);
        let event = Just(HashKind::Event);
        prop_oneof![word, event].boxed()
    }
}
431
432impl Serializable for HashKind {
433    fn write_into<W: ByteWriter>(&self, target: &mut W) {
434        target.write_u8(self.tag());
435    }
436}
437
438impl Deserializable for HashKind {
439    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
440        const WORD: u8 = HashKind::Word.tag();
441        const EVENT: u8 = HashKind::Event.tag();
442
443        match source.read_u8()? {
444            WORD => Ok(Self::Word),
445            EVENT => Ok(Self::Event),
446            invalid => Err(DeserializationError::InvalidValue(format!(
447                "unexpected HashKind tag: '{invalid}'"
448            ))),
449        }
450    }
451}