1use alloc::{boxed::Box, sync::Arc, vec::Vec};
2use core::fmt;
3
4use miden_core::serde::{
5 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
6};
7use miden_debug_types::{SourceSpan, Span, Spanned};
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 Felt, Path,
13 ast::{ConstantValue, Ident},
14 parser::{IntValue, LiteralErrorKind, ParsingError, WordValue},
15};
16
/// Maximum nesting depth that `ConstantExpr::try_fold` descends through before giving up with
/// `ParsingError::ConstExprDepthExceeded`.
const MAX_CONST_EXPR_FOLD_DEPTH: usize = 256;
22
/// A constant expression, either as written in source or after (partial) constant folding.
///
/// Leaves are literals (`Int`, `Word`, `String`, `Hash`) or references to other constants
/// (`Var`); `BinaryOp` combines two sub-expressions with a [`ConstantOp`].
#[derive(Clone)]
#[repr(u8)]
pub enum ConstantExpr {
    /// An integer literal.
    Int(Span<IntValue>),
    /// A reference to another constant, by path.
    Var(Span<Arc<Path>>),
    /// A binary arithmetic operation over two sub-expressions.
    BinaryOp {
        /// Source span of this operation.
        span: SourceSpan,
        /// The operator to apply.
        op: ConstantOp,
        /// Left-hand operand.
        lhs: Box<ConstantExpr>,
        /// Right-hand operand.
        rhs: Box<ConstantExpr>,
    },
    /// A string literal.
    String(Ident),
    /// A word literal.
    Word(Span<WordValue>),
    /// A hash of a string, rendered as e.g. `word("...")` or `event("...")`.
    Hash(HashKind, Ident),
}
49
50impl ConstantExpr {
51 pub fn is_value(&self) -> bool {
53 matches!(self, Self::Int(_) | Self::Word(_) | Self::Hash(_, _) | Self::String(_))
54 }
55
56 #[track_caller]
61 pub fn expect_int(&self) -> IntValue {
62 match self {
63 Self::Int(spanned) => spanned.into_inner(),
64 other => panic!("expected constant expression to be a literal, got {other:#?}"),
65 }
66 }
67
68 #[track_caller]
73 pub fn expect_felt(&self) -> Felt {
74 match self {
75 Self::Int(spanned) => Felt::new(spanned.inner().as_int()),
76 other => panic!("expected constant expression to be a literal, got {other:#?}"),
77 }
78 }
79
80 #[track_caller]
85 pub fn expect_string(&self) -> Arc<str> {
86 match self {
87 Self::String(spanned) => spanned.clone().into_inner(),
88 other => panic!("expected constant expression to be a string, got {other:#?}"),
89 }
90 }
91
92 #[track_caller]
97 pub fn expect_value(&self) -> ConstantValue {
98 self.as_value().unwrap_or_else(|| {
99 panic!("expected constant expression to be a value, got {:#?}", self)
100 })
101 }
102
103 pub fn into_value(self) -> Result<ConstantValue, Self> {
107 match self {
108 Self::Int(value) => Ok(ConstantValue::Int(value)),
109 Self::String(value) => Ok(ConstantValue::String(value)),
110 Self::Word(value) => Ok(ConstantValue::Word(value)),
111 Self::Hash(kind, value) => Ok(ConstantValue::Hash(kind, value)),
112 expr @ (Self::BinaryOp { .. } | Self::Var(_)) => Err(expr),
113 }
114 }
115
116 pub fn as_value(&self) -> Option<ConstantValue> {
120 match self {
121 Self::Int(value) => Some(ConstantValue::Int(*value)),
122 Self::String(value) => Some(ConstantValue::String(value.clone())),
123 Self::Word(value) => Some(ConstantValue::Word(*value)),
124 Self::Hash(kind, value) => Some(ConstantValue::Hash(*kind, value.clone())),
125 Self::BinaryOp { .. } | Self::Var(_) => None,
126 }
127 }
128
129 pub fn try_fold(self) -> Result<Self, ParsingError> {
136 self.try_fold_with_depth(0)
137 }
138
139 fn try_fold_with_depth(self, depth: usize) -> Result<Self, ParsingError> {
140 if depth > MAX_CONST_EXPR_FOLD_DEPTH {
141 return Err(ParsingError::ConstExprDepthExceeded {
142 span: self.span(),
143 max_depth: MAX_CONST_EXPR_FOLD_DEPTH,
144 });
145 }
146
147 match self {
148 Self::String(_) | Self::Word(_) | Self::Int(_) | Self::Var(_) | Self::Hash(..) => {
149 Ok(self)
150 },
151 Self::BinaryOp { span, op, lhs, rhs } => {
152 if rhs.is_literal() {
153 let rhs = Self::into_inner(rhs).try_fold_with_depth(depth + 1)?;
154 match rhs {
155 Self::String(ident) => {
156 Err(ParsingError::StringInArithmeticExpression { span: ident.span() })
157 },
158 Self::Int(rhs) => {
159 let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
160 match lhs {
161 Self::String(ident) => {
162 Err(ParsingError::StringInArithmeticExpression {
163 span: ident.span(),
164 })
165 },
166 Self::Int(lhs) => {
167 let lhs = lhs.into_inner();
168 let rhs = rhs.into_inner();
169 let is_division =
170 matches!(op, ConstantOp::Div | ConstantOp::IntDiv);
171 let is_division_by_zero = is_division && rhs.as_int() == 0;
172 if is_division_by_zero {
173 return Err(ParsingError::DivisionByZero { span });
174 }
175 match op {
176 ConstantOp::Add => {
177 let value = lhs
178 .checked_add(rhs)
179 .ok_or(ParsingError::ConstantOverflow { span })?;
180 Some(Self::Int(Span::new(span, value)))
181 },
182 ConstantOp::Sub => {
183 let value = lhs
184 .checked_sub(rhs)
185 .ok_or(ParsingError::ConstantOverflow { span })?;
186 Some(Self::Int(Span::new(span, value)))
187 },
188 ConstantOp::Mul => {
189 let value = lhs
190 .checked_mul(rhs)
191 .ok_or(ParsingError::ConstantOverflow { span })?;
192 Some(Self::Int(Span::new(span, value)))
193 },
194 ConstantOp::IntDiv => {
195 let value = lhs
196 .checked_div(rhs)
197 .ok_or(ParsingError::ConstantOverflow { span })?;
198 Some(Self::Int(Span::new(span, value)))
199 },
200 ConstantOp::Div => {
201 let lhs = Felt::new(lhs.as_int());
202 let rhs = Felt::new(rhs.as_int());
203 let value = IntValue::from(lhs / rhs);
204 Some(Self::Int(Span::new(span, value)))
205 },
206 }
207 .ok_or(
208 ParsingError::InvalidLiteral {
209 span,
210 kind: LiteralErrorKind::FeltOverflow,
211 },
212 )
213 },
214 lhs => Ok(Self::BinaryOp {
215 span,
216 op,
217 lhs: Box::new(lhs),
218 rhs: Box::new(Self::Int(rhs)),
219 }),
220 }
221 },
222 rhs => {
223 let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
224 Ok(Self::BinaryOp {
225 span,
226 op,
227 lhs: Box::new(lhs),
228 rhs: Box::new(rhs),
229 })
230 },
231 }
232 } else {
233 let lhs = Self::into_inner(lhs).try_fold_with_depth(depth + 1)?;
234 Ok(Self::BinaryOp { span, op, lhs: Box::new(lhs), rhs })
235 }
236 },
237 }
238 }
239
240 pub fn references(&self) -> Vec<Span<Arc<Path>>> {
242 use alloc::collections::BTreeSet;
243
244 let mut worklist = smallvec::SmallVec::<[_; 4]>::from_slice(&[self]);
245 let mut references = BTreeSet::new();
246
247 while let Some(ty) = worklist.pop() {
248 match ty {
249 Self::Int(_) | Self::Word(_) | Self::String(_) | Self::Hash(..) => {},
250 Self::Var(path) => {
251 references.insert(path.clone());
252 },
253 Self::BinaryOp { lhs, rhs, .. } => {
254 worklist.push(lhs);
255 worklist.push(rhs);
256 },
257 }
258 }
259
260 references.into_iter().collect()
261 }
262
263 fn is_literal(&self) -> bool {
264 match self {
265 Self::Int(_) | Self::String(_) | Self::Word(_) | Self::Hash(..) => true,
266 Self::Var(_) => false,
267 Self::BinaryOp { lhs, rhs, .. } => lhs.is_literal() && rhs.is_literal(),
268 }
269 }
270
271 #[inline(always)]
272 #[expect(clippy::boxed_local)]
273 fn into_inner(self: Box<Self>) -> Self {
274 *self
275 }
276}
277
/// Marker impl: `ConstantExpr`'s structural `PartialEq` is declared to be a full equivalence
/// relation.
impl Eq for ConstantExpr {}
279
280impl PartialEq for ConstantExpr {
281 fn eq(&self, other: &Self) -> bool {
282 match (self, other) {
283 (Self::Int(x), Self::Int(y)) => x == y,
284 (Self::Int(_), _) => false,
285 (Self::Word(x), Self::Word(y)) => x == y,
286 (Self::Word(_), _) => false,
287 (Self::Var(x), Self::Var(y)) => x == y,
288 (Self::Var(_), _) => false,
289 (Self::String(x), Self::String(y)) => x == y,
290 (Self::String(_), _) => false,
291 (Self::Hash(x_hk, x_i), Self::Hash(y_hk, y_i)) => x_i == y_i && x_hk == y_hk,
292 (Self::Hash(..), _) => false,
293 (
294 Self::BinaryOp { op: lop, lhs: llhs, rhs: lrhs, .. },
295 Self::BinaryOp { op: rop, lhs: rlhs, rhs: rrhs, .. },
296 ) => lop == rop && llhs == rlhs && lrhs == rrhs,
297 (Self::BinaryOp { .. }, _) => false,
298 }
299 }
300}
301
302impl core::hash::Hash for ConstantExpr {
303 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
304 core::mem::discriminant(self).hash(state);
305 match self {
306 Self::Int(value) => value.hash(state),
307 Self::Word(value) => value.hash(state),
308 Self::String(value) => value.hash(state),
309 Self::Var(value) => value.hash(state),
310 Self::Hash(hash_kind, string) => {
311 hash_kind.hash(state);
312 string.hash(state);
313 },
314 Self::BinaryOp { op, lhs, rhs, .. } => {
315 op.hash(state);
316 lhs.hash(state);
317 rhs.hash(state);
318 },
319 }
320 }
321}
322
323impl fmt::Debug for ConstantExpr {
324 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
325 match self {
326 Self::Int(lit) => fmt::Debug::fmt(&**lit, f),
327 Self::Word(lit) => fmt::Debug::fmt(&**lit, f),
328 Self::Var(path) => fmt::Debug::fmt(path, f),
329 Self::String(name) => fmt::Debug::fmt(&**name, f),
330 Self::Hash(hash_kind, str) => {
331 f.debug_tuple("Hash").field(hash_kind).field(str).finish()
332 },
333 Self::BinaryOp { op, lhs, rhs, .. } => {
334 f.debug_tuple(op.name()).field(lhs).field(rhs).finish()
335 },
336 }
337 }
338}
339
340impl crate::prettier::PrettyPrint for ConstantExpr {
341 fn render(&self) -> crate::prettier::Document {
342 use crate::prettier::*;
343
344 match self {
345 Self::Int(literal) => literal.render(),
346 Self::Word(literal) => literal.render(),
347 Self::Var(path) => display(path),
348 Self::String(ident) => text(format!("\"{}\"", ident.as_str().escape_debug())),
349 Self::Hash(hash_kind, str) => flatten(
350 display(hash_kind)
351 + const_text("(")
352 + text(format!("\"{}\"", str.as_str().escape_debug()))
353 + const_text(")"),
354 ),
355 Self::BinaryOp { op, lhs, rhs, .. } => {
356 let single_line = lhs.render() + display(op) + rhs.render();
357 let multi_line = lhs.render() + nl() + (display(op)) + rhs.render();
358 single_line | multi_line
359 },
360 }
361 }
362}
363
364impl Spanned for ConstantExpr {
365 fn span(&self) -> SourceSpan {
366 match self {
367 Self::Int(spanned) => spanned.span(),
368 Self::Word(spanned) => spanned.span(),
369 Self::Hash(_, spanned) => spanned.span(),
370 Self::Var(spanned) => spanned.span(),
371 Self::String(spanned) => spanned.span(),
372 Self::BinaryOp { span, .. } => *span,
373 }
374 }
375}
376
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantExpr {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Generates an integer literal, a variable reference, a binary operation over two integer
    /// literals, a string, a word, or a hash. Every generated node has an unknown span.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::{arbitrary::any, prop_oneof, strategy::Strategy};

        let int = any::<IntValue>().prop_map(|value| Self::Int(Span::unknown(value)));
        let var = crate::arbitrary::path::constant_path_random_length(0)
            .prop_map(|path| Self::Var(Span::unknown(path)));
        let binary =
            any::<(ConstantOp, IntValue, IntValue)>().prop_map(|(op, left, right)| {
                Self::BinaryOp {
                    span: SourceSpan::UNKNOWN,
                    op,
                    lhs: Box::new(ConstantExpr::Int(Span::unknown(left))),
                    rhs: Box::new(ConstantExpr::Int(Span::unknown(right))),
                }
            });
        let string = any::<Ident>().prop_map(Self::String);
        let word = any::<WordValue>().prop_map(|value| Self::Word(Span::unknown(value)));
        let hash = any::<(HashKind, Ident)>().prop_map(|(kind, ident)| Self::Hash(kind, ident));

        prop_oneof![int, var, binary, string, word, hash].boxed()
    }
}
403
/// A binary arithmetic operator usable in a constant expression.
///
/// Each variant's `u8` discriminant (`#[repr(u8)]`, fixed by declaration order) is what `tag`
/// returns and what serialization writes, so existing variants must not be reordered.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ConstantOp {
    /// `+` (checked addition when folded)
    Add,
    /// `-` (checked subtraction when folded)
    Sub,
    /// `*` (checked multiplication when folded)
    Mul,
    /// `/`: when folded, both operands are converted to `Felt` and divided in the field
    Div,
    /// `//`: when folded, checked integer division
    IntDiv,
}
418
419impl ConstantOp {
420 const fn name(&self) -> &'static str {
421 match self {
422 Self::Add => "Add",
423 Self::Sub => "Sub",
424 Self::Mul => "Mul",
425 Self::Div => "Div",
426 Self::IntDiv => "IntDiv",
427 }
428 }
429}
430
431impl fmt::Display for ConstantOp {
432 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
433 match self {
434 Self::Add => f.write_str("+"),
435 Self::Sub => f.write_str("-"),
436 Self::Mul => f.write_str("*"),
437 Self::Div => f.write_str("/"),
438 Self::IntDiv => f.write_str("//"),
439 }
440 }
441}
442
443impl ConstantOp {
444 const fn tag(&self) -> u8 {
445 unsafe { *(self as *const Self).cast::<u8>() }
452 }
453}
454
455impl Serializable for ConstantOp {
456 fn write_into<W: ByteWriter>(&self, target: &mut W) {
457 target.write_u8(self.tag());
458 }
459}
460
461impl Deserializable for ConstantOp {
462 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
463 const ADD: u8 = ConstantOp::Add.tag();
464 const SUB: u8 = ConstantOp::Sub.tag();
465 const MUL: u8 = ConstantOp::Mul.tag();
466 const DIV: u8 = ConstantOp::Div.tag();
467 const INT_DIV: u8 = ConstantOp::IntDiv.tag();
468
469 match source.read_u8()? {
470 ADD => Ok(Self::Add),
471 SUB => Ok(Self::Sub),
472 MUL => Ok(Self::Mul),
473 DIV => Ok(Self::Div),
474 INT_DIV => Ok(Self::IntDiv),
475 invalid => Err(DeserializationError::InvalidValue(format!(
476 "unexpected ConstantOp tag: '{invalid}'"
477 ))),
478 }
479 }
480}
481
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for ConstantOp {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Picks one of the five operators.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prop_oneof;
        use proptest::strategy::{Just, Strategy};

        let strategy = prop_oneof![
            Just(Self::Add),
            Just(Self::Sub),
            Just(Self::Mul),
            Just(Self::Div),
            Just(Self::IntDiv),
        ];
        strategy.boxed()
    }
}
504
/// The kind of a `ConstantExpr::Hash` expression, rendered as `word("...")` or `event("...")`.
///
/// Each variant's `u8` discriminant (`#[repr(u8)]`, fixed by declaration order) is what `tag`
/// returns and what serialization writes, so existing variants must not be reordered.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u8)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum HashKind {
    /// Displayed as `word`.
    Word,
    /// Displayed as `event`.
    Event,
}
518
519impl HashKind {
520 const fn tag(&self) -> u8 {
521 unsafe { *(self as *const Self).cast::<u8>() }
528 }
529}
530
531impl fmt::Display for HashKind {
532 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
533 match self {
534 Self::Word => f.write_str("word"),
535 Self::Event => f.write_str("event"),
536 }
537 }
538}
539
#[cfg(feature = "arbitrary")]
impl proptest::arbitrary::Arbitrary for HashKind {
    type Parameters = ();
    type Strategy = proptest::prelude::BoxedStrategy<Self>;

    /// Picks one of the two hash kinds.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        use proptest::prop_oneof;
        use proptest::strategy::{Just, Strategy};

        let strategy = prop_oneof![Just(Self::Word), Just(Self::Event)];
        strategy.boxed()
    }
}
555
556impl Serializable for HashKind {
557 fn write_into<W: ByteWriter>(&self, target: &mut W) {
558 target.write_u8(self.tag());
559 }
560}
561
562impl Deserializable for HashKind {
563 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
564 const WORD: u8 = HashKind::Word.tag();
565 const EVENT: u8 = HashKind::Event.tag();
566
567 match source.read_u8()? {
568 WORD => Ok(Self::Word),
569 EVENT => Ok(Self::Event),
570 invalid => Err(DeserializationError::InvalidValue(format!(
571 "unexpected HashKind tag: '{invalid}'"
572 ))),
573 }
574 }
575}
576
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the left-nested chain `((1 + 1) + 1) + ...` containing `depth` `BinaryOp` nodes.
    fn nested_add_expr(depth: usize) -> ConstantExpr {
        let one = || ConstantExpr::Int(Span::unknown(IntValue::from(1u8)));
        (0..depth).fold(one(), |acc, _| ConstantExpr::BinaryOp {
            span: SourceSpan::UNKNOWN,
            op: ConstantOp::Add,
            lhs: Box::new(acc),
            rhs: Box::new(one()),
        })
    }

    #[test]
    fn const_expr_fold_depth_boundary() {
        // Exactly at the limit: folding must succeed.
        let at_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH);
        assert!(at_limit.try_fold().is_ok());

        // One level past the limit: folding must fail and report the configured maximum.
        let past_limit = nested_add_expr(MAX_CONST_EXPR_FOLD_DEPTH + 1);
        let err = past_limit.try_fold().expect_err("expected depth-exceeded error");
        let ParsingError::ConstExprDepthExceeded { max_depth, .. } = err else {
            panic!("expected depth-exceeded error");
        };
        assert_eq!(max_depth, MAX_CONST_EXPR_FOLD_DEPTH);
    }
}