1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::fmt;
3
4use miden_core::serde::{
5 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
6};
7use miden_debug_types::{SourceSpan, Span, Spanned};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 Felt, Path,
13 ast::{ConstantValue, Ident},
14 parser::{IntValue, LiteralErrorKind, ParsingError, WordValue},
15};
16
/// Maximum nesting depth that [`ConstantExpr::try_fold`] will descend into.
///
/// Folding returns `ParsingError::ConstExprDepthExceeded` once the depth exceeds this value.
/// This bounds the recursion performed by `try_fold_with_depth`.
const MAX_CONST_EXPR_FOLD_DEPTH: usize = 256;
22
/// A constant expression as written in source.
///
/// It is either a value (integer, word, string, or hash), a reference to another constant,
/// or an arithmetic operation over two sub-expressions.
#[derive(Clone)]
#[repr(u8)]
pub enum ConstantExpr {
    /// An integer literal.
    Int(Span<IntValue>),
    /// A reference to another constant, by path.
    Var(Span<Arc<Path>>),
    /// A binary arithmetic operation.
    BinaryOp {
        /// Span of the whole operation; folding errors for this node are reported here.
        span: SourceSpan,
        op: ConstantOp,
        lhs: Box<ConstantExpr>,
        rhs: Box<ConstantExpr>,
    },
    /// A string literal.
    String(Ident),
    /// A word literal.
    Word(Span<WordValue>),
    /// A hash expression over a string, rendered as e.g. `word("...")` or `event("...")`.
    Hash(HashKind, Ident),
}
49
50impl ConstantExpr {
51 pub fn is_value(&self) -> bool {
53 matches!(self, Self::Int(_) | Self::Word(_) | Self::Hash(_, _) | Self::String(_))
54 }
55
56 #[track_caller]
61 pub fn expect_int(&self) -> IntValue {
62 match self {
63 Self::Int(spanned) => spanned.into_inner(),
64 other => panic!("expected constant expression to be a literal, got {other:#?}"),
65 }
66 }
67
68 #[track_caller]
73 pub fn expect_felt(&self) -> Felt {
74 match self {
75 Self::Int(spanned) => Felt::new_unchecked(spanned.inner().as_int()),
76 other => panic!("expected constant expression to be a literal, got {other:#?}"),
77 }
78 }
79
80 #[track_caller]
85 pub fn expect_string(&self) -> Arc<str> {
86 match self {
87 Self::String(spanned) => spanned.clone().into_inner(),
88 other => panic!("expected constant expression to be a string, got {other:#?}"),
89 }
90 }
91
92 #[track_caller]
97 pub fn expect_value(&self) -> ConstantValue {
98 self.as_value()
99 .unwrap_or_else(|| panic!("expected constant expression to be a value, got {self:#?}"))
100 }
101
102 pub fn into_value(self) -> Result<ConstantValue, Self> {
106 match self {
107 Self::Int(value) => Ok(ConstantValue::Int(value)),
108 Self::String(value) => Ok(ConstantValue::String(value)),
109 Self::Word(value) => Ok(ConstantValue::Word(value)),
110 Self::Hash(kind, value) => Ok(ConstantValue::Hash(kind, value)),
111 expr @ (Self::BinaryOp { .. } | Self::Var(_)) => Err(expr),
112 }
113 }
114
115 pub fn as_value(&self) -> Option<ConstantValue> {
119 match self {
120 Self::Int(value) => Some(ConstantValue::Int(*value)),
121 Self::String(value) => Some(ConstantValue::String(value.clone())),
122 Self::Word(value) => Some(ConstantValue::Word(*value)),
123 Self::Hash(kind, value) => Some(ConstantValue::Hash(*kind, value.clone())),
124 Self::BinaryOp { .. } | Self::Var(_) => None,
125 }
126 }
127
128 pub fn try_fold(self) -> Result<Self, ParsingError> {
135 self.try_fold_with_depth(0)
136 }
137
138 fn try_fold_with_depth(self, depth: usize) -> Result<Self, ParsingError> {
139 if depth > MAX_CONST_EXPR_FOLD_DEPTH {
140 return Err(ParsingError::ConstExprDepthExceeded {
141 span: self.span(),
142 max_depth: MAX_CONST_EXPR_FOLD_DEPTH,
143 });
144 }
145
146 match self {
147 Self::String(_) | Self::Word(_) | Self::Int(_) | Self::Var(_) | Self::Hash(..) => {
148 Ok(self)
149 },
150 Self::BinaryOp { span, op, lhs, rhs } => {
151 if rhs.is_literal() {
152 let rhs = Self::into_inner(rhs).try_fold_with_depth(depth + 1)?;
153 match rhs {
154 Self::String(ident) => {
155 Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
156 },
157 Self::Int(rhs) => {
158 let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
159 match lhs {
160 Self::String(ident) => {
161 Err(ParsingError::StringInArithmeticExpression {
162 span: ident.span(),
163 })
164 },
165 Self::Int(lhs) => {
166 let lhs = lhs.into_inner();
167 let rhs = rhs.into_inner();
168 let is_division =
169 matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
170 let is_division_by_zero = is_division && rhs.as_int() == 0;
171 if is_division_by_zero {
172 return Err(ParsingError::DivisionByZero { span });
173 }
174 match op {
175 ConstantOp::Add => {
176 let value = lhs
177 .checked_add(rhs)
178 .ok_or(ParsingError::ConstantOverflow { span })?;
179 Some(Self::Int(Span::new(span, value)))
180 },
181 ConstantOp::Sub => {
182 let value = lhs
183 .checked_sub(rhs)
184 .ok_or(ParsingError::ConstantOverflow { span })?;
185 Some(Self::Int(Span::new(span, value)))
186 },
187 ConstantOp::Mul => {
188 let value = lhs
189 .checked_mul(rhs)
190 .ok_or(ParsingError::ConstantOverflow { span })?;
191 Some(Self::Int(Span::new(span, value)))
192 },
193 ConstantOp::IntDiv => {
194 let value = lhs
195 .checked_div(rhs)
196 .ok_or(ParsingError::ConstantOverflow { span })?;
197 Some(Self::Int(Span::new(span, value)))
198 },
199 ConstantOp::Div => {
200 let lhs = Felt::new_unchecked(lhs.as_int());
201 let rhs = Felt::new_unchecked(rhs.as_int());
202 let value = IntValue::from(lhs / rhs);
203 Some(Self::Int(Span::new(span, value)))
204 },
205 }
206 .ok_or(
207 ParsingError::InvalidLiteral {
208 span,
209 kind: LiteralErrorKind::FeltOverflow,
210 },
211 )
212 },
213 lhs => Ok(Self::BinaryOp {
214 span,
215 op,
216 lhs: Box::new(lhs),
217 rhs: Box::new(Self::Int(rhs)),
218 }),
219 }
220 },
221 rhs => {
222 let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
223 Ok(Self::BinaryOp {
224 span,
225 op,
226 lhs: Box::new(lhs),
227 rhs: Box::new(rhs),
228 })
229 },
230 }
231 } else {
232 let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
233 Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
234 }
235 },
236 }
237 }
238
239 pub fn references(&self) -> Vec<Span<Arc<Path>>> {
241 use alloc::collections::BTreeSet;
242
243 let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
244 let mut references = BTreeSet::new();
245
246 while let Some(ty) = worklist.pop() {
247 match ty {
248 Self::Int(_) | Self::Word(_) | Self::String(_) | Self::Hash(..) => {},
249 Self::Var(path) => {
250 references.insert(path.clone());
251 },
252 Self::BinaryOp { lhs, rhs, .. } => {
253 worklist.push(lhs);
254 worklist.push(rhs);
255 },
256 }
257 }
258
259 references.into_iter().collect()
260 }
261
262 fn is_literal(&self) -> bool {
263 match self {
264 Self::Int(_) | Self::String(_) | Self::Word(_) | Self::Hash(..) => true,
265 Self::Var(_) => false,
266 Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
267 }
268 }
269
270 #[inline(always)]
271 #[expect(clippy::boxed_local)]
272 fn into_inner(self: Box<Self>) -> Self {
273 *self
274 }
275}
276
// Marker impl asserting that the structural `PartialEq` for `ConstantExpr` (which ignores
// `BinaryOp` spans) is a full equivalence relation.
impl Eq for ConstantExpr {}
278
279impl PartialEq for ConstantExpr {
280 fn eq(&self, other: &Self) -> bool {
281 match (self, other) {
282 (Self::Int(x), Self::Int(y)) => x == y,
283 (Self::Int(_), _) => false,
284 (Self::Word(x), Self::Word(y)) => x == y,
285 (Self::Word(_), _) => false,
286 (Self::Var(x), Self::Var(y)) => x == y,
287 (Self::Var(_), _) => false,
288 (Self::String(x), Self::String(y)) => x == y,
289 (Self::String(_), _) => false,
290 (Self::Hash(x_hk, x_i), Self::Hash(y_hk, y_i)) => x_i == y_i && x_hk == y_hk,
291 (Self::Hash(..), _) => false,
292 (
293 Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
294 Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
295 ) => lop == rop && llhs == rlhs && lrhs == rrhs,
296 (Self::BinaryOp { .. }, _) => false,
297 }
298 }
299}
300
301impl core::hash::Hash for ConstantExpr {
302 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
303 core::mem::discriminant(self).hash(state);
304 match self {
305 Self::Int(value) => value.hash(state),
306 Self::Word(value) => value.hash(state),
307 Self::String(value) => value.hash(state),
308 Self::Var(value) => value.hash(state),
309 Self::Hash(hash_kind, string) => {
310 hash_kind.hash(state);
311 string.hash(state);
312 },
313 Self::BinaryOp { op, lhs, rhs, .. } => {
314 op.hash(state);
315 lhs.hash(state);
316 rhs.hash(state);
317 },
318 }
319 }
320}
321
322impl fmt::Debug for ConstantExpr {
323 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
324 match self {
325 Self::Int(lit) => fmt::Debug::fmt(&**lit, f),
326 Self::Word(lit) => fmt::Debug::fmt(&**lit, f),
327 Self::Var(path) => fmt::Debug::fmt(path, f),
328 Self::String(name) => fmt::Debug::fmt(&**name, f),
329 Self::Hash(hash_kind, str) => {
330 f.debug_tuple("Hash").field(hash_kind).field(str).finish()
331 },
332 Self::BinaryOp { op, lhs, rhs, .. } => {
333 f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
334 },
335 }
336 }
337}
338
339impl crate::prettier::PrettyPrint for ConstantExpr {
340 fn render(&self) -> crate::prettier::Document {
341 use crate::prettier::*;
342
343 match self {
344 Self::Int(literal) => literal.render(),
345 Self::Word(literal) => literal.render(),
346 Self::Var(path) => display(path),
347 Self::String(ident) => text(format!("\"{}\"", ident.as_str().escape_debug())),
348 Self::Hash(hash_kind, str) => flatten(
349 display(hash_kind)
350 + const_text("(")
351 + text(format!("\"{}\"", str.as_str().escape_debug()))
352 + const_text(")"),
353 ),
354 Self::BinaryOp { op, lhs, rhs, .. } => {
355 let single_line = lhs.render() + display(op) + rhs.render();
356 let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
357 single_line | multi_line
358 },
359 }
360 }
361}
362
363impl Spanned for ConstantExpr {
364 fn span(&self) -> SourceSpan {
365 match self {
366 Self::Int(spanned) => spanned.span(),
367 Self::Word(spanned) => spanned.span(),
368 Self::Hash(_, spanned) => spanned.span(),
369 Self::Var(spanned) => spanned.span(),
370 Self::String(spanned) => spanned.span(),
371 Self::BinaryOp { span, .. } => *span,
372 }
373 }
374}
375
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantExpr {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates one of the six expression shapes with equal weight.
    ///
    /// Binary operations are only ever generated over two integer literals.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};

        let int_literal = any::<IntValue>().prop_map(|value| Self::Int(Span::unknown(value)));
        let var_ref = crate::arbitrary::path::constant_path_random_length(0)
            .prop_map(|path| Self::Var(Span::unknown(path)));
        let int_binary_op =
            any::<(ConstantOp, IntValue, IntValue)>().prop_map(|(op, left, right)| {
                Self::BinaryOp {
                    span: SourceSpan::UNKNOWN,
                    op,
                    lhs: Box::new(ConstantExpr::Int(Span::unknown(left))),
                    rhs: Box::new(ConstantExpr::Int(Span::unknown(right))),
                }
            });
        let string_literal = any::<Ident>().prop_map(Self::String);
        let word_literal = any::<WordValue>().prop_map(|word| Self::Word(Span::unknown(word)));
        let hash_expr = any::<(HashKind, Ident)>().prop_map(|(kind, name)| Self::Hash(kind, name));

        // Alternative order matters for reproducibility of generated sequences; keep it stable.
        prop_oneof![int_literal, var_ref, int_binary_op, string_literal, word_literal, hash_expr]
            .boxed()
    }
}
402
/// A binary arithmetic operator usable in a [`ConstantExpr::BinaryOp`].
///
/// Each variant's `#[repr(u8)]` discriminant doubles as its binary serialization tag (see
/// `ConstantOp::tag`). Reordering or inserting variants therefore changes the encoding.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    all(feature = "arbitrary", test),
    miden_test_serde_macros::serde_test(binary_serde(true))
)]
pub enum ConstantOp {
    /// `+`; folded with checked integer addition.
    Add,
    /// `-`; folded with checked integer subtraction.
    Sub,
    /// `*`; folded with checked integer multiplication.
    Mul,
    /// `/`; folded as division of field elements.
    Div,
    /// `//`; folded with checked integer division.
    IntDiv,
}
421
422impl ConstantOp {
423 const fn name(self) -> &'static str {
424 match self {
425 Self::Add => "Add",
426 Self::Sub => "Sub",
427 Self::Mul => "Mul",
428 Self::Div => "Div",
429 Self::IntDiv => "IntDiv",
430 }
431 }
432}
433
434impl fmt::Display for ConstantOp {
435 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
436 match self {
437 Self::Add => f.write_str("+"),
438 Self::Sub => f.write_str("-"),
439 Self::Mul => f.write_str("*"),
440 Self::Div => f.write_str("/"),
441 Self::IntDiv => f.write_str("//"),
442 }
443 }
444}
445
446impl ConstantOp {
447 const fn tag(&self) -> u8 {
448 unsafe { *(self as *const Self).cast::<u8>() }
455 }
456}
457
458impl Serializable for ConstantOp {
459 fn write_into<W: ByteWriter>(&self, target: &mut W) {
460 target.write_u8(self.tag());
461 }
462}
463
464impl Deserializable for ConstantOp {
465 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
466 const ADD: u8 = ConstantOp::Add.tag();
467 const SUB: u8 = ConstantOp::Sub.tag();
468 const MUL: u8 = ConstantOp::Mul.tag();
469 const DIV: u8 = ConstantOp::Div.tag();
470 const INT_DIV: u8 = ConstantOp::IntDiv.tag();
471
472 match source.read_u8()? {
473 ADD => Ok(Self::Add),
474 SUB => Ok(Self::Sub),
475 MUL => Ok(Self::Mul),
476 DIV => Ok(Self::Div),
477 INT_DIV => Ok(Self::IntDiv),
478 invalid => Err(DeserializationError::InvalidValue(format!(
479 "unexpected ConstantOp tag: '{invalid}'"
480 ))),
481 }
482 }
483}
484
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantOp {
    type Parameters = ();

    /// Picks one of the five operators.
    ///
    /// `prop_oneof!` without explicit weights gives each alternative equal weight. The order
    /// of alternatives affects which values a given seed produces, so keep it stable.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        prop_oneof![
            Just(Self::Add),
            Just(Self::Sub),
            Just(Self::Mul),
            Just(Self::Div),
            Just(Self::IntDiv),
        ]
        .boxed()
    }

    type Strategy = proptest::prelude::BoxedStrategy<Self>;
}
507
/// The kind of hash requested by a [`ConstantExpr::Hash`] expression.
///
/// Each variant's `#[repr(u8)]` discriminant doubles as its binary serialization tag (see
/// `HashKind::tag`). Reordering or inserting variants therefore changes the encoding.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    all(feature = "arbitrary", test),
    miden_test_serde_macros::serde_test(binary_serde(true))
)]
pub enum HashKind {
    /// Displayed as `word`.
    Word,
    /// Displayed as `event`.
    Event,
}
525
526impl HashKind {
527 const fn tag(&self) -> u8 {
528 unsafe { *(self as *const Self).cast::<u8>() }
535 }
536}
537
538impl fmt::Display for HashKind {
539 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
540 match self {
541 Self::Word => f.write_str("word"),
542 Self::Event => f.write_str("event"),
543 }
544 }
545}
546
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for HashKind {
    type Parameters = ();

    /// Picks `Word` or `Event`, each with equal weight (no explicit `prop_oneof!` weights).
    ///
    /// The order of alternatives affects which values a given seed produces, so keep it
    /// stable.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{
            prop_oneof,
            strategy::{Just, Strategy},
        };

        prop_oneof![Just(Self::Word), Just(Self::Event),].boxed()
    }

    type Strategy = proptest::prelude::BoxedStrategy<Self>;
}
562
563impl Serializable for HashKind {
564 fn write_into<W: ByteWriter>(&self, target: &mut W) {
565 target.write_u8(self.tag());
566 }
567}
568
569impl Deserializable for HashKind {
570 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
571 const WORD: u8 = HashKind::Word.tag();
572 const EVENT: u8 = HashKind::Event.tag();
573
574 match source.read_u8()? {
575 WORD => Ok(Self::Word),
576 EVENT => Ok(Self::Event),
577 invalid => Err(DeserializationError::InvalidValue(format!(
578 "unexpected HashKind tag: '{invalid}'"
579 ))),
580 }
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `((1 + 1) + 1) + ...` with exactly `depth` `Add` nodes.
    ///
    /// Nesting happens on the left, so the innermost literal sits `depth` levels deep.
    fn nested_add_expr(depth: usize) -> ConstantExpr {
        let one = || ConstantExpr::Int(Span::unknown(IntValue::from(1u8)));
        (0..depth).fold(one(), |acc, _| ConstantExpr::BinaryOp {
            span: SourceSpan::UNKNOWN,
            op: ConstantOp::Add,
            lhs: Box::new(acc),
            rhs: Box::new(one()),
        })
    }

    #[test]
    fn const_expr_fold_depth_boundary() {
        // Exactly at the limit: folding must succeed.
        let at_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH);
        assert!(at_limit.try_fold().is_ok());

        // One level past the limit: folding must fail and report the configured maximum.
        let past_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH + 1);
        let err = past_limit.try_fold().expect_err("expected depth-exceeded error");
        assert!(matches!(
            err,
            ParsingError::ConstExprDepthExceeded { max_depth, .. }
                if max_depth == MAX_CONST_EXPR_FOLD_DEPTH
        ));
    }
}
611}