1use std::{convert::AsRef, fmt};
13
14use miden_diagnostics::{SourceSpan, Span, Spanned};
15
16use crate::symbols::Symbol;
17
18use super::*;
19
/// A half-open `start..end` range of `usize` indices, used for constant slice bounds.
pub type Range = core::ops::Range<usize>;
22
23#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Spanned)]
25pub struct Identifier(pub Span<Symbol>);
26impl Identifier {
27 pub fn new(span: SourceSpan, name: Symbol) -> Self {
28 Self(Span::new(span, name))
29 }
30
31 pub fn name(&self) -> Symbol {
33 self.0.item
34 }
35
36 #[inline]
37 pub fn as_str(&self) -> &str {
38 self.0.as_str()
39 }
40
41 pub fn is_uppercase(&self) -> bool {
43 self.0.as_str().chars().all(char::is_uppercase)
44 }
45
46 pub fn is_generated(&self) -> bool {
48 self.0.as_str().starts_with('%')
49 }
50
51 pub fn is_special(&self) -> bool {
53 self.0.as_str().starts_with('$')
54 }
55}
56impl PartialEq<&str> for Identifier {
57 #[inline]
58 fn eq(&self, other: &&str) -> bool {
59 self.0.item == *other
60 }
61}
62impl PartialEq<&Identifier> for Identifier {
63 #[inline]
64 fn eq(&self, other: &&Self) -> bool {
65 self == *other
66 }
67}
68impl fmt::Debug for Identifier {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_tuple("Identifier")
71 .field(&format!("{}", &self.0.item))
72 .finish()
73 }
74}
75impl fmt::Display for Identifier {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 write!(f, "{}", &self.0)
78 }
79}
80impl From<ResolvableIdentifier> for Identifier {
81 fn from(id: ResolvableIdentifier) -> Self {
82 match id {
83 ResolvableIdentifier::Local(id) => id,
84 ResolvableIdentifier::Global(id) => id,
85 ResolvableIdentifier::Resolved(qid) => qid.item.id(),
86 ResolvableIdentifier::Unresolved(nid) => nid.id(),
87 }
88 }
89}
90
91#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)]
101pub enum NamespacedIdentifier {
102 Function(#[span] Identifier),
103 Binding(#[span] Identifier),
104}
105impl NamespacedIdentifier {
106 pub fn id(&self) -> Identifier {
107 match self {
108 Self::Function(ident) | Self::Binding(ident) => *ident,
109 }
110 }
111}
112impl AsRef<Identifier> for NamespacedIdentifier {
113 fn as_ref(&self) -> &Identifier {
114 match self {
115 Self::Function(ident) | Self::Binding(ident) => ident,
116 }
117 }
118}
119impl From<ResolvableIdentifier> for NamespacedIdentifier {
120 fn from(id: ResolvableIdentifier) -> Self {
121 match id {
122 ResolvableIdentifier::Local(id) => Self::Binding(id),
123 ResolvableIdentifier::Global(id) => Self::Binding(id),
124 ResolvableIdentifier::Resolved(qid) => qid.item,
125 ResolvableIdentifier::Unresolved(nid) => nid,
126 }
127 }
128}
129impl fmt::Display for NamespacedIdentifier {
130 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
131 fmt::Display::fmt(self.as_ref(), f)
132 }
133}
134
135#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)]
139pub struct QualifiedIdentifier {
140 pub module: ModuleId,
141 #[span]
142 pub item: NamespacedIdentifier,
143}
144impl QualifiedIdentifier {
145 pub const fn new(module: ModuleId, item: NamespacedIdentifier) -> Self {
146 Self { module, item }
147 }
148
149 pub const fn id(&self) -> NamespacedIdentifier {
150 self.item
151 }
152
153 #[inline]
155 pub fn name(&self) -> Symbol {
156 self.as_ref().name()
157 }
158
159 pub fn is_builtin(&self) -> bool {
161 use crate::symbols;
162
163 if self.module.name() == "$builtin" {
164 match self.item {
165 NamespacedIdentifier::Function(id) => {
166 matches!(id.name(), symbols::Sum | symbols::Prod)
167 }
168 _ => false,
169 }
170 } else {
171 false
172 }
173 }
174}
175impl AsRef<Identifier> for QualifiedIdentifier {
176 #[inline]
177 fn as_ref(&self) -> &Identifier {
178 self.item.as_ref()
179 }
180}
181impl fmt::Display for QualifiedIdentifier {
182 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 write!(f, "{}::{}", &self.module, &self.item)
184 }
185}
186
187#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Spanned)]
189pub enum ResolvableIdentifier {
190 Local(#[span] Identifier),
192 Global(#[span] Identifier),
194 Resolved(#[span] QualifiedIdentifier),
196 Unresolved(#[span] NamespacedIdentifier),
198}
199impl ResolvableIdentifier {
200 #[inline]
202 pub fn is_resolved(&self) -> bool {
203 matches!(self, Self::Local(_) | Self::Global(_) | Self::Resolved(_))
204 }
205
206 pub fn is_local(&self) -> bool {
208 matches!(self, Self::Local(_))
209 }
210
211 pub fn is_global(&self) -> bool {
213 matches!(self, Self::Global(_))
214 }
215
216 pub fn is_builtin(&self) -> bool {
218 match self {
219 Self::Resolved(qid) => qid.is_builtin(),
220 _ => false,
221 }
222 }
223
224 pub fn module(&self) -> Option<ModuleId> {
230 match self {
231 Self::Resolved(qid) => Some(*qid.as_ref()),
232 _ => None,
233 }
234 }
235
236 #[inline]
238 pub fn namespaced(&self) -> NamespacedIdentifier {
239 (*self).into()
240 }
241
242 #[inline]
244 pub fn resolved(&self) -> Option<QualifiedIdentifier> {
245 match self {
246 Self::Resolved(qid) => Some(*qid),
247 _ => None,
248 }
249 }
250}
251impl AsRef<Identifier> for ResolvableIdentifier {
252 #[inline]
253 fn as_ref(&self) -> &Identifier {
254 match self {
255 Self::Local(id) => id,
256 Self::Global(id) => id,
257 Self::Resolved(qid) => qid.item.as_ref(),
258 Self::Unresolved(nid) => nid.as_ref(),
259 }
260 }
261}
262impl fmt::Display for ResolvableIdentifier {
263 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
264 match self {
265 Self::Local(id) => write!(f, "{id}"),
266 Self::Global(id) => write!(f, "{id}"),
267 Self::Resolved(qid) => write!(f, "{qid}"),
268 Self::Unresolved(nid) => write!(f, "{nid}"),
269 }
270 }
271}
272
273#[derive(Clone, PartialEq, Eq, Spanned)]
275pub enum Expr {
276 Const(Span<ConstantExpr>),
278 Range(RangeExpr),
280 Vector(Span<Vec<Expr>>),
286 Matrix(Span<Vec<Vec<ScalarExpr>>>),
288 SymbolAccess(SymbolAccess),
290 Binary(BinaryExpr),
292 Call(Call),
298 ListComprehension(ListComprehension),
300 Let(Box<Let>),
305 BusOperation(BusOperation),
307 Null(Span<()>),
309 Unconstrained(Span<()>),
311}
312impl Expr {
313 pub fn is_constant(&self) -> bool {
317 match self {
318 Self::Const(_) => true,
319 Self::Range(range) => range.is_constant(),
320 _ => false,
321 }
322 }
323
324 pub fn ty(&self) -> Option<Type> {
326 match self {
327 Self::Const(constant) => Some(constant.ty()),
328 Self::Range(range) => range.ty(),
329 Self::Vector(vector) => match vector.first().and_then(|e| e.ty()) {
330 Some(Type::Felt) => Some(Type::Vector(vector.len())),
331 Some(Type::Vector(n)) => Some(Type::Matrix(vector.len(), n)),
332 Some(_) => None,
333 None => Some(Type::Vector(0)),
334 },
335 Self::Matrix(matrix) => {
336 let rows = matrix.len();
337 let cols = matrix[0].len();
338 Some(Type::Matrix(rows, cols))
339 }
340 Self::SymbolAccess(access) => access.ty,
341 Self::Binary(_) => Some(Type::Felt),
342 Self::Call(call) => call.ty,
343 Self::ListComprehension(lc) => lc.ty,
344 Self::Let(let_expr) => let_expr.ty(),
345 Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => Some(Type::Felt),
346 }
347 }
348}
349impl fmt::Debug for Expr {
350 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
351 match self {
352 Self::Const(expr) => f.debug_tuple("Const").field(&expr.item).finish(),
353 Self::Range(expr) => f.debug_tuple("Range").field(&expr).finish(),
354 Self::Vector(expr) => f.debug_tuple("Vector").field(&expr.item).finish(),
355 Self::Matrix(expr) => f.debug_tuple("Matrix").field(&expr.item).finish(),
356 Self::SymbolAccess(expr) => f.debug_tuple("SymbolAccess").field(expr).finish(),
357 Self::Binary(expr) => f.debug_tuple("Binary").field(expr).finish(),
358 Self::Call(expr) => f.debug_tuple("Call").field(expr).finish(),
359 Self::ListComprehension(expr) => {
360 f.debug_tuple("ListComprehension").field(expr).finish()
361 }
362 Self::Let(let_expr) => write!(f, "{let_expr:#?}"),
363 Self::BusOperation(expr) => f.debug_tuple("BusOp").field(expr).finish(),
364 Self::Null(expr) => f.debug_tuple("Null").field(expr).finish(),
365 Self::Unconstrained(expr) => f.debug_tuple("Unconstrained").field(expr).finish(),
366 }
367 }
368}
369impl fmt::Display for Expr {
370 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
371 match self {
372 Self::Const(expr) => write!(f, "{}", &expr),
373 Self::Range(range) => write!(f, "{range}"),
374 Self::Vector(expr) => write!(f, "{}", DisplayList(expr.as_slice())),
375 Self::Matrix(expr) => {
376 f.write_str("[")?;
377 for (i, col) in expr.iter().enumerate() {
378 if i > 0 {
379 f.write_str(", ")?;
380 }
381 write!(f, "{}", DisplayList(col.as_slice()))?;
382 }
383 f.write_str("]")
384 }
385 Self::SymbolAccess(expr) => write!(f, "{expr}"),
386 Self::Binary(expr) => write!(f, "{expr}"),
387 Self::Call(expr) => write!(f, "{expr}"),
388 Self::ListComprehension(expr) => write!(f, "{}", DisplayBracketed(expr)),
389 Self::Let(let_expr) => {
390 let display = DisplayLet {
391 let_expr,
392 indent: 0,
393 in_expr_position: true,
394 };
395 write!(f, "{display}")
396 }
397 Self::BusOperation(expr) => write!(f, "{expr}"),
398 Self::Null(_expr) => write!(f, "null"),
399 Self::Unconstrained(_expr) => write!(f, "unconstrained"),
400 }
401 }
402}
403impl From<SymbolAccess> for Expr {
404 #[inline]
405 fn from(expr: SymbolAccess) -> Self {
406 Self::SymbolAccess(expr)
407 }
408}
409impl From<BinaryExpr> for Expr {
410 #[inline]
411 fn from(expr: BinaryExpr) -> Self {
412 Self::Binary(expr)
413 }
414}
415impl From<Call> for Expr {
416 #[inline]
417 fn from(expr: Call) -> Self {
418 Self::Call(expr)
419 }
420}
421impl From<BusOperation> for Expr {
422 #[inline]
423 fn from(expr: BusOperation) -> Self {
424 Self::BusOperation(expr)
425 }
426}
427impl From<ListComprehension> for Expr {
428 #[inline]
429 fn from(expr: ListComprehension) -> Self {
430 Self::ListComprehension(expr)
431 }
432}
433impl TryFrom<Let> for Expr {
434 type Error = InvalidExprError;
435
436 fn try_from(expr: Let) -> Result<Self, Self::Error> {
437 if expr.ty().is_some() {
438 Ok(Self::Let(Box::new(expr)))
439 } else {
440 Err(InvalidExprError::InvalidLetExpr(expr.span()))
441 }
442 }
443}
444impl TryFrom<ScalarExpr> for Expr {
445 type Error = InvalidExprError;
446
447 #[inline]
448 fn try_from(expr: ScalarExpr) -> Result<Self, Self::Error> {
449 match expr {
450 ScalarExpr::Const(spanned) => Ok(Self::Const(Span::new(
451 spanned.span(),
452 ConstantExpr::Scalar(spanned.item),
453 ))),
454 ScalarExpr::SymbolAccess(access) => Ok(Self::SymbolAccess(access)),
455 ScalarExpr::Binary(expr) => Ok(Self::Binary(expr)),
456 ScalarExpr::Call(expr) => Ok(Self::Call(expr)),
457 ScalarExpr::BoundedSymbolAccess(_) => {
458 Err(InvalidExprError::BoundedSymbolAccess(expr.span()))
459 }
460 ScalarExpr::Let(expr) => Ok(Self::Let(expr)),
461 ScalarExpr::BusOperation(expr) => Ok(Self::BusOperation(expr)),
462 ScalarExpr::Null(spanned) => Ok(Self::Null(spanned)),
463 ScalarExpr::Unconstrained(spanned) => Ok(Self::Unconstrained(spanned)),
464 }
465 }
466}
467impl TryFrom<Statement> for Expr {
468 type Error = InvalidExprError;
469
470 fn try_from(stmt: Statement) -> Result<Self, Self::Error> {
471 match stmt {
472 Statement::Let(let_expr) => Ok(Self::Let(Box::new(let_expr))),
473 Statement::Expr(expr) => Ok(expr),
474 _ => Err(InvalidExprError::NotAnExpr(stmt.span())),
475 }
476 }
477}
478
479#[derive(Clone, PartialEq, Eq, Spanned)]
483pub enum ScalarExpr {
484 Const(Span<u64>),
486 SymbolAccess(SymbolAccess),
490 BoundedSymbolAccess(BoundedSymbolAccess),
494 Binary(BinaryExpr),
496 Call(Call),
506 Let(Box<Let>),
512 BusOperation(BusOperation),
514 Null(Span<()>),
516 Unconstrained(Span<()>),
518}
519impl ScalarExpr {
520 pub fn is_constant(&self) -> bool {
522 matches!(self, Self::Const(_))
523 }
524
525 pub fn has_block_like_expansion(&self) -> bool {
527 match self {
528 Self::Binary(expr) => expr.has_block_like_expansion(),
529 Self::Call(_) | Self::Let(_) => true,
530 _ => false,
531 }
532 }
533
534 pub fn ty(&self) -> Result<Option<Type>, SourceSpan> {
541 match self {
542 Self::Const(_) => Ok(Some(Type::Felt)),
543 Self::SymbolAccess(sym) => Ok(sym.ty),
544 Self::BoundedSymbolAccess(sym) => Ok(sym.column.ty),
545 Self::Binary(expr) => match (expr.lhs.ty()?, expr.rhs.ty()?) {
546 (None, _) | (_, None) => Ok(None),
547 (Some(lty), Some(rty)) if lty == rty => Ok(Some(lty)),
548 _ => Err(expr.span()),
549 },
550 Self::Call(expr) => Ok(expr.ty),
551 Self::Let(expr) => Ok(expr.ty()),
552 Self::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
553 Ok(Some(Type::Felt))
554 }
555 }
556 }
557}
558impl TryFrom<Expr> for ScalarExpr {
559 type Error = InvalidExprError;
560
561 fn try_from(expr: Expr) -> Result<Self, Self::Error> {
562 match expr {
563 Expr::Const(constant) => {
564 let span = constant.span();
565 match constant.item {
566 ConstantExpr::Scalar(v) => Ok(Self::Const(Span::new(span, v))),
567 _ => Err(InvalidExprError::InvalidScalarExpr(span)),
568 }
569 }
570 Expr::SymbolAccess(sym) => Ok(Self::SymbolAccess(sym)),
571 Expr::Binary(bin) => Ok(Self::Binary(bin)),
572 Expr::Call(call) => Ok(Self::Call(call)),
573 Expr::Let(let_expr) => {
574 if let_expr.ty().is_none() {
575 Err(InvalidExprError::InvalidScalarExpr(let_expr.span()))
576 } else {
577 Ok(Self::Let(let_expr))
578 }
579 }
580 invalid => Err(InvalidExprError::InvalidScalarExpr(invalid.span())),
581 }
582 }
583}
584impl TryFrom<Statement> for ScalarExpr {
585 type Error = InvalidExprError;
586
587 fn try_from(stmt: Statement) -> Result<Self, Self::Error> {
588 match stmt {
589 Statement::Let(let_expr) => Self::try_from(Expr::Let(Box::new(let_expr))),
590 Statement::Expr(expr) => Self::try_from(expr),
591 stmt => Err(InvalidExprError::InvalidScalarExpr(stmt.span())),
592 }
593 }
594}
595impl From<u64> for ScalarExpr {
596 fn from(value: u64) -> Self {
597 Self::Const(Span::new(SourceSpan::UNKNOWN, value))
598 }
599}
600impl fmt::Debug for ScalarExpr {
601 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
602 match self {
603 Self::Const(i) => f.debug_tuple("Const").field(&i.item).finish(),
604 Self::SymbolAccess(expr) => f.debug_tuple("SymbolAccess").field(expr).finish(),
605 Self::BoundedSymbolAccess(expr) => {
606 f.debug_tuple("BoundedSymbolAccess").field(expr).finish()
607 }
608 Self::Binary(expr) => f.debug_tuple("Binary").field(expr).finish(),
609 Self::Call(expr) => f.debug_tuple("Call").field(expr).finish(),
610 Self::Let(expr) => write!(f, "{expr:#?}"),
611 Self::BusOperation(expr) => f.debug_tuple("BusOp").field(expr).finish(),
612 Self::Null(expr) => f.debug_tuple("Null").field(expr).finish(),
613 Self::Unconstrained(expr) => f.debug_tuple("Unconstrained").field(expr).finish(),
614 }
615 }
616}
617impl fmt::Display for ScalarExpr {
618 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
619 match self {
620 Self::Const(value) => write!(f, "{value}"),
621 Self::SymbolAccess(expr) => write!(f, "{expr}"),
622 Self::BoundedSymbolAccess(expr) => write!(f, "{}.{}", &expr.column, &expr.boundary),
623 Self::Binary(expr) => write!(f, "{expr}"),
624 Self::Call(call) => write!(f, "{call}"),
625 Self::Let(let_expr) => {
626 let display = DisplayLet {
627 let_expr,
628 indent: 0,
629 in_expr_position: true,
630 };
631 write!(f, "{display}")
632 }
633 Self::BusOperation(expr) => write!(f, "{expr}"),
634 Self::Null(_value) => write!(f, "null"),
635 Self::Unconstrained(_value) => write!(f, "unconstrained"),
636 }
637 }
638}
639
640#[derive(Clone, Spanned, Debug)]
642pub struct ConstSymbolAccess {
643 #[span]
644 pub span: SourceSpan,
645 pub name: ResolvableIdentifier,
646 pub ty: Option<Type>,
647}
648impl ConstSymbolAccess {
649 pub fn new(span: SourceSpan, name: Identifier) -> Self {
650 Self {
651 span,
652 name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(name)),
653 ty: None,
654 }
655 }
656}
657impl Eq for ConstSymbolAccess {}
658impl PartialEq for ConstSymbolAccess {
659 fn eq(&self, other: &Self) -> bool {
660 self.name.eq(&other.name) && self.ty.eq(&other.ty)
661 }
662}
663impl std::hash::Hash for ConstSymbolAccess {
664 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
665 self.name.hash(state);
666 self.ty.hash(state);
667 }
668}
669impl fmt::Display for ConstSymbolAccess {
670 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
671 write!(f, "{}", &self.name)
672 }
673}
674
675#[derive(Debug, Clone, Spanned)]
676pub struct RangeExpr {
677 #[span]
678 pub span: SourceSpan,
679 pub start: RangeBound,
680 pub end: RangeBound,
681}
682
683impl TryFrom<&RangeExpr> for Range {
684 type Error = InvalidExprError;
685
686 #[inline]
687 fn try_from(expr: &RangeExpr) -> Result<Self, InvalidExprError> {
688 match (&expr.start, &expr.end) {
689 (RangeBound::Const(lhs), RangeBound::Const(rhs)) => Ok(lhs.item..rhs.item),
690 _ => Err(InvalidExprError::NonConstantRangeExpr(expr.span)),
691 }
692 }
693}
694
695impl RangeExpr {
696 pub fn is_constant(&self) -> bool {
697 self.start.is_constant() && self.end.is_constant()
698 }
699
700 pub fn to_slice_range(&self) -> Range {
703 self.try_into()
704 .expect("attempted to convert non-constant range expression to constant")
705 }
706
707 pub fn ty(&self) -> Option<Type> {
708 match (&self.start, &self.end) {
709 (RangeBound::Const(start), RangeBound::Const(end)) => {
710 Some(Type::Vector(end.item.abs_diff(start.item)))
711 }
712 _ => None,
713 }
714 }
715}
716impl From<Range> for RangeExpr {
717 fn from(range: Range) -> Self {
718 Self {
719 span: SourceSpan::default(),
720 start: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.start)),
721 end: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.end)),
722 }
723 }
724}
725impl Eq for RangeExpr {}
726impl PartialEq for RangeExpr {
727 fn eq(&self, other: &Self) -> bool {
728 self.start.eq(&other.start) && self.end.eq(&other.end)
729 }
730}
731impl std::hash::Hash for RangeExpr {
732 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
733 self.start.hash(state);
734 self.end.hash(state);
735 }
736}
737impl fmt::Display for RangeExpr {
738 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
739 write!(f, "{}..{}", &self.start, &self.end)
740 }
741}
742
743#[derive(Hash, Clone, Spanned, PartialEq, Eq, Debug)]
744pub enum RangeBound {
745 SymbolAccess(ConstSymbolAccess),
746 Const(Span<usize>),
747}
748impl RangeBound {
749 pub fn is_constant(&self) -> bool {
750 matches!(self, Self::Const(_))
751 }
752}
753impl From<Identifier> for RangeBound {
754 fn from(name: Identifier) -> Self {
755 Self::SymbolAccess(ConstSymbolAccess::new(name.span(), name))
756 }
757}
758impl From<usize> for RangeBound {
759 fn from(constant: usize) -> Self {
760 Self::Const(Span::new(SourceSpan::UNKNOWN, constant))
761 }
762}
763impl fmt::Display for RangeBound {
764 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
765 match self {
766 Self::SymbolAccess(sym) => write!(f, "{sym}"),
767 Self::Const(constant) => write!(f, "{constant}"),
768 }
769 }
770}
771
772#[derive(Clone, Spanned)]
774pub struct BinaryExpr {
775 #[span]
776 pub span: SourceSpan,
777 pub op: BinaryOp,
778 pub lhs: Box<ScalarExpr>,
779 pub rhs: Box<ScalarExpr>,
780}
781impl BinaryExpr {
782 pub fn new(span: SourceSpan, op: BinaryOp, lhs: ScalarExpr, rhs: ScalarExpr) -> Self {
783 Self {
784 span,
785 op,
786 lhs: Box::new(lhs),
787 rhs: Box::new(rhs),
788 }
789 }
790
791 #[inline]
793 pub fn has_block_like_expansion(&self) -> bool {
794 self.lhs.has_block_like_expansion() || self.rhs.has_block_like_expansion()
795 }
796}
797impl Eq for BinaryExpr {}
798impl PartialEq for BinaryExpr {
799 fn eq(&self, other: &Self) -> bool {
800 self.op == other.op && self.lhs == other.lhs && self.rhs == other.rhs
801 }
802}
803impl fmt::Debug for BinaryExpr {
804 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
805 f.debug_struct("BinaryExpr")
806 .field("op", &self.op)
807 .field("lhs", self.lhs.as_ref())
808 .field("rhs", self.rhs.as_ref())
809 .finish()
810 }
811}
812impl fmt::Display for BinaryExpr {
813 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
814 write!(f, "{} {} {}", &self.lhs, &self.op, &self.rhs)
815 }
816}
817
/// The binary operators available in scalar expressions.
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
pub enum BinaryOp {
    /// Addition, `+`.
    Add,
    /// Subtraction, `-`.
    Sub,
    /// Multiplication, `*`.
    Mul,
    /// Exponentiation, `^`.
    Exp,
    /// Equality, `=`.
    Eq,
}
impl fmt::Display for BinaryOp {
    /// Renders the operator's source symbol.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Exp => "^",
            Self::Eq => "=",
        };
        f.write_str(symbol)
    }
}
844
/// A boundary at which a [`BoundedSymbolAccess`] applies; defaults to `First`.
#[derive(Default, Eq, PartialEq, Clone, Copy, Debug)]
pub enum Boundary {
    /// The first boundary, rendered `first`.
    #[default]
    First,
    /// The last boundary, rendered `last`.
    Last,
}
impl fmt::Display for Boundary {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let keyword = match self {
            Self::First => "first",
            Self::Last => "last",
        };
        f.write_str(keyword)
    }
}
860
861#[derive(Hash, Debug, Clone, Eq, PartialEq, Default)]
863pub enum AccessType {
864 #[default]
866 Default,
867 Slice(RangeExpr),
869 Index(usize),
873 Matrix(usize, usize),
875}
876impl fmt::Display for AccessType {
877 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
878 match self {
879 Self::Default => write!(f, "direct reference by name"),
880 Self::Slice(range) => write!(
881 f,
882 "slice of elements at indices {}..{}",
883 range.start, range.end
884 ),
885 Self::Index(idx) => write!(f, "reference to element at index {idx}"),
886 Self::Matrix(row, col) => write!(f, "reference to value in matrix at [{row}][{col}]"),
887 }
888 }
889}
890
891#[derive(Debug, Clone, thiserror::Error)]
892pub enum InvalidAccessError {
893 #[error("attempted to access undefined variable")]
894 UndefinedVariable,
895 #[error("attempted to access a function as a variable")]
896 InvalidBinding,
897 #[error("attempted to take a slice of a scalar value")]
898 SliceOfScalar,
899 #[error("attempted to take a slice of a matrix value")]
900 SliceOfMatrix,
901 #[error("attempted to index into a scalar value")]
902 IndexIntoScalar,
903 #[error("attempted to access an index which is out of bounds")]
904 IndexOutOfBounds,
905}
906
907#[derive(Clone, Spanned)]
915pub struct SymbolAccess {
916 #[span]
917 pub span: SourceSpan,
918 pub name: ResolvableIdentifier,
920 pub access_type: AccessType,
922 pub offset: usize,
930 pub ty: Option<Type>,
934}
935impl SymbolAccess {
936 pub const fn new(
937 span: SourceSpan,
938 name: Identifier,
939 access_type: AccessType,
940 offset: usize,
941 ) -> Self {
942 Self {
943 span,
944 name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(name)),
945 access_type,
946 offset,
947 ty: None,
948 }
949 }
950
951 pub fn access(&self, access_type: AccessType) -> Result<Self, InvalidAccessError> {
960 match &self.access_type {
961 AccessType::Default => self.access_default(access_type),
962 AccessType::Slice(base_range) => {
963 self.access_slice(base_range.to_slice_range(), access_type)
964 }
965 AccessType::Index(base_idx) => self.access_index(*base_idx, access_type),
966 AccessType::Matrix(_, _) => match access_type {
967 AccessType::Default => Ok(self.clone()),
968 _ => Err(InvalidAccessError::IndexIntoScalar),
969 },
970 }
971 }
972
973 fn access_default(&self, access_type: AccessType) -> Result<Self, InvalidAccessError> {
974 let ty = self.ty.unwrap();
975 match access_type {
976 AccessType::Default => Ok(self.clone()),
977 AccessType::Index(idx) => match ty {
978 Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
979 Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
980 Type::Vector(_) => Ok(Self {
981 access_type: AccessType::Index(idx),
982 ty: Some(Type::Felt),
983 ..self.clone()
984 }),
985 Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
986 Type::Matrix(_, cols) => Ok(Self {
987 access_type: AccessType::Index(idx),
988 ty: Some(Type::Vector(cols)),
989 ..self.clone()
990 }),
991 },
992 AccessType::Slice(range) => {
993 let slice_range = range.to_slice_range();
994 let rlen = slice_range.end - slice_range.start;
995 match ty {
996 Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
997 Type::Vector(len) if slice_range.end > len => {
998 Err(InvalidAccessError::IndexOutOfBounds)
999 }
1000 Type::Vector(_) => Ok(Self {
1001 access_type: AccessType::Slice(range),
1002 ty: Some(Type::Vector(rlen)),
1003 ..self.clone()
1004 }),
1005 Type::Matrix(rows, _) if slice_range.end > rows => {
1006 Err(InvalidAccessError::IndexOutOfBounds)
1007 }
1008 Type::Matrix(_, cols) => Ok(Self {
1009 access_type: AccessType::Slice(range),
1010 ty: Some(Type::Matrix(rlen, cols)),
1011 ..self.clone()
1012 }),
1013 }
1014 }
1015 AccessType::Matrix(row, col) => match ty {
1016 Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar),
1017 Type::Matrix(rows, cols) if row >= rows || col >= cols => {
1018 Err(InvalidAccessError::IndexOutOfBounds)
1019 }
1020 Type::Matrix(_, _) => Ok(Self {
1021 access_type: AccessType::Matrix(row, col),
1022 ty: Some(Type::Felt),
1023 ..self.clone()
1024 }),
1025 },
1026 }
1027 }
1028
1029 fn access_slice(
1030 &self,
1031 base_range: Range,
1032 access_type: AccessType,
1033 ) -> Result<Self, InvalidAccessError> {
1034 let ty = self.ty.unwrap();
1035 match access_type {
1036 AccessType::Default => Ok(self.clone()),
1037 AccessType::Index(idx) => match ty {
1038 Type::Felt => unreachable!(),
1039 Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
1040 Type::Vector(_) => Ok(Self {
1041 access_type: AccessType::Index(base_range.start + idx),
1042 ty: Some(Type::Felt),
1043 ..self.clone()
1044 }),
1045 Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
1046 Type::Matrix(_, cols) => Ok(Self {
1047 access_type: AccessType::Index(base_range.start + idx),
1048 ty: Some(Type::Vector(cols)),
1049 ..self.clone()
1050 }),
1051 },
1052 AccessType::Slice(range) => {
1053 let slice_range = range.to_slice_range();
1054 let blen = base_range.end - base_range.start;
1055 let rlen = slice_range.len();
1056 let start = base_range.start + slice_range.start;
1057 let end = slice_range.start + slice_range.end;
1058 let shifted = RangeExpr {
1059 span: range.span,
1060 start: RangeBound::Const(Span::new(range.start.span(), start)),
1061 end: RangeBound::Const(Span::new(range.end.span(), end)),
1062 };
1063 match ty {
1064 Type::Felt => unreachable!(),
1065 Type::Vector(_) if slice_range.end > blen => {
1066 Err(InvalidAccessError::IndexOutOfBounds)
1067 }
1068 Type::Vector(_) => Ok(Self {
1069 access_type: AccessType::Slice(shifted),
1070 ty: Some(Type::Vector(rlen)),
1071 ..self.clone()
1072 }),
1073 Type::Matrix(rows, _) if slice_range.end > rows => {
1074 Err(InvalidAccessError::IndexOutOfBounds)
1075 }
1076 Type::Matrix(_, cols) => Ok(Self {
1077 access_type: AccessType::Slice(shifted),
1078 ty: Some(Type::Matrix(rlen, cols)),
1079 ..self.clone()
1080 }),
1081 }
1082 }
1083 AccessType::Matrix(row, col) => match ty {
1084 Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar),
1085 Type::Matrix(rows, cols) if row >= rows || col >= cols => {
1086 Err(InvalidAccessError::IndexOutOfBounds)
1087 }
1088 Type::Matrix(_, _) => Ok(Self {
1089 access_type: AccessType::Matrix(row, col),
1090 ty: Some(Type::Felt),
1091 ..self.clone()
1092 }),
1093 },
1094 }
1095 }
1096
1097 fn access_index(
1098 &self,
1099 base_idx: usize,
1100 access_type: AccessType,
1101 ) -> Result<Self, InvalidAccessError> {
1102 let ty = self.ty.unwrap();
1103 match access_type {
1104 AccessType::Default => Ok(self.clone()),
1105 AccessType::Index(idx) => match ty {
1106 Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
1107 Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
1108 Type::Vector(_) => Ok(Self {
1109 access_type: AccessType::Matrix(base_idx, idx),
1110 ty: Some(Type::Felt),
1111 ..self.clone()
1112 }),
1113 Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
1114 Type::Matrix(_, cols) => Ok(Self {
1115 access_type: AccessType::Matrix(base_idx, idx),
1116 ty: Some(Type::Vector(cols)),
1117 ..self.clone()
1118 }),
1119 },
1120 AccessType::Slice(_) => Err(InvalidAccessError::SliceOfMatrix),
1121 AccessType::Matrix(_, _) => Err(InvalidAccessError::IndexIntoScalar),
1122 }
1123 }
1124}
1125impl Eq for SymbolAccess {}
1126impl PartialEq for SymbolAccess {
1127 fn eq(&self, other: &Self) -> bool {
1128 self.name == other.name
1129 && self.access_type == other.access_type
1130 && self.offset == other.offset
1131 && self.ty == other.ty
1132 }
1133}
1134impl fmt::Debug for SymbolAccess {
1135 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1136 f.debug_struct("SymbolAccess")
1137 .field("name", &self.name)
1138 .field("access_type", &self.access_type)
1139 .field("offset", &self.offset)
1140 .field("ty", &self.ty)
1141 .finish()
1142 }
1143}
1144impl fmt::Display for SymbolAccess {
1145 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1146 write!(f, "{}", self.name)?;
1147 match &self.access_type {
1148 AccessType::Default => (),
1149 AccessType::Index(idx) => write!(f, "[{idx}]")?,
1150 AccessType::Slice(range) => write!(f, "[{}..{}]", range.start, range.end)?,
1151 AccessType::Matrix(row, col) => write!(f, "[{row}][{col}]")?,
1152 }
1153 for _ in 0..self.offset {
1155 f.write_str("'")?;
1156 }
1157 Ok(())
1158 }
1159}
1160
1161#[derive(Clone, Spanned)]
1165pub struct BoundedSymbolAccess {
1166 #[span]
1167 pub span: SourceSpan,
1168 pub boundary: Boundary,
1170 pub column: SymbolAccess,
1172}
1173impl BoundedSymbolAccess {
1174 pub const fn new(span: SourceSpan, column: SymbolAccess, boundary: Boundary) -> Self {
1175 Self {
1176 span,
1177 boundary,
1178 column,
1179 }
1180 }
1181}
1182impl Eq for BoundedSymbolAccess {}
1183impl PartialEq for BoundedSymbolAccess {
1184 fn eq(&self, other: &Self) -> bool {
1185 self.boundary == other.boundary && self.column == other.column
1186 }
1187}
1188impl fmt::Debug for BoundedSymbolAccess {
1189 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1190 f.debug_struct("BoundedSymbolAccess")
1191 .field("boundary", &self.boundary)
1192 .field("column", &self.column)
1193 .finish()
1194 }
1195}
1196
1197pub type ComprehensionContext = Vec<(Identifier, Expr)>;
1202
1203#[derive(Clone, Spanned)]
1204pub struct ListComprehension {
1205 #[span]
1206 pub span: SourceSpan,
1207 pub bindings: Vec<Identifier>,
1211 pub iterables: Vec<Expr>,
1217 pub body: Box<ScalarExpr>,
1219 pub selector: Option<ScalarExpr>,
1225 pub ty: Option<Type>,
1229}
1230impl ListComprehension {
1231 pub fn new(
1233 span: SourceSpan,
1234 body: ScalarExpr,
1235 mut context: ComprehensionContext,
1236 selector: Option<ScalarExpr>,
1237 ) -> Self {
1238 let bindings = context.iter().map(|(name, _)| name).copied().collect();
1239 let iterables = context.drain(..).map(|(_, iterable)| iterable).collect();
1240 Self {
1241 span,
1242 bindings,
1243 iterables,
1244 body: Box::new(body),
1245 selector,
1246 ty: None,
1247 }
1248 }
1249}
1250impl Eq for ListComprehension {}
1251impl PartialEq for ListComprehension {
1252 fn eq(&self, other: &Self) -> bool {
1253 self.bindings == other.bindings
1254 && self.iterables == other.iterables
1255 && self.body == other.body
1256 && self.selector == other.selector
1257 }
1258}
1259impl fmt::Debug for ListComprehension {
1260 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1261 f.debug_struct("ListComprehension")
1262 .field("bindings", &self.bindings)
1263 .field("iterables", &self.iterables)
1264 .field("body", self.body.as_ref())
1265 .field("selector", &self.selector)
1266 .field("ty", &self.ty)
1267 .finish()
1268 }
1269}
1270impl fmt::Display for ListComprehension {
1271 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1272 if self.bindings.len() == 1 {
1273 write!(
1274 f,
1275 "{} for {} in {}",
1276 &self.body, &self.bindings[0], &self.iterables[0]
1277 )?;
1278 } else {
1279 write!(
1280 f,
1281 "{} for {} in {}",
1282 &self.body,
1283 DisplayTuple(self.bindings.as_slice()),
1284 DisplayTuple(self.iterables.as_slice())
1285 )?;
1286 }
1287
1288 if let Some(selector) = self.selector.as_ref() {
1289 write!(f, " when {selector}")
1290 } else {
1291 Ok(())
1292 }
1293 }
1294}
1295
1296#[derive(Clone, Spanned)]
1297pub struct BusOperation {
1298 #[span]
1299 pub span: SourceSpan,
1300 pub bus: ResolvableIdentifier,
1301 pub op: BusOperator,
1302 pub args: Vec<Expr>,
1303}
1304
1305impl BusOperation {
1306 pub fn new(span: SourceSpan, bus: Identifier, op: BusOperator, args: Vec<Expr>) -> Self {
1307 Self {
1308 span,
1309 bus: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(bus)),
1310 op,
1311 args,
1312 }
1313 }
1314}
1315
1316impl Eq for BusOperation {}
1317impl PartialEq for BusOperation {
1318 fn eq(&self, other: &Self) -> bool {
1319 self.bus == other.bus && self.args == other.args && self.op == other.op
1320 }
1321}
1322impl fmt::Debug for BusOperation {
1323 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1324 f.debug_struct("BusOperation")
1325 .field("bus", &self.bus)
1326 .field("op", &self.op)
1327 .field("args", &self.args)
1328 .finish()
1329 }
1330}
1331impl fmt::Display for BusOperation {
1332 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1333 write!(
1334 f,
1335 "{}{}{}",
1336 self.bus,
1337 self.op,
1338 DisplayTuple(self.args.as_slice())
1339 )
1340 }
1341}
1342
1343#[derive(Clone, Spanned)]
1362pub struct Call {
1363 #[span]
1364 pub span: SourceSpan,
1365 pub callee: ResolvableIdentifier,
1366 pub args: Vec<Expr>,
1367 pub ty: Option<Type>,
1376}
1377impl Call {
1378 pub fn new(span: SourceSpan, callee: Identifier, args: Vec<Expr>) -> Self {
1379 use crate::symbols;
1380
1381 match callee.name() {
1382 symbols::Sum => Self::sum(span, args),
1383 symbols::Prod => Self::prod(span, args),
1384 _ => Self {
1385 span,
1386 callee: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Function(callee)),
1387 args,
1388 ty: None,
1389 },
1390 }
1391 }
1392
1393 #[inline]
1395 pub fn is_builtin(&self) -> bool {
1396 self.callee.is_builtin()
1397 }
1398
1399 #[inline]
1401 pub fn sum(span: SourceSpan, args: Vec<Expr>) -> Self {
1402 Self::new_builtin(span, "sum", args, Type::Felt)
1403 }
1404
1405 #[inline]
1407 pub fn prod(span: SourceSpan, args: Vec<Expr>) -> Self {
1408 Self::new_builtin(span, "prod", args, Type::Felt)
1409 }
1410
1411 fn new_builtin(span: SourceSpan, name: &str, args: Vec<Expr>, ty: Type) -> Self {
1412 let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin"));
1413 let name = Identifier::new(span, Symbol::intern(name));
1414 let id = QualifiedIdentifier::new(builtin_module, NamespacedIdentifier::Function(name));
1415 Self {
1416 span,
1417 callee: ResolvableIdentifier::Resolved(id),
1418 args,
1419 ty: Some(ty),
1420 }
1421 }
1422}
1423impl Eq for Call {}
1424impl PartialEq for Call {
1425 fn eq(&self, other: &Self) -> bool {
1426 self.callee == other.callee && self.args == other.args && self.ty == other.ty
1427 }
1428}
1429impl fmt::Debug for Call {
1430 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1431 f.debug_struct("Call")
1432 .field("callee", &self.callee)
1433 .field("args", &self.args)
1434 .field("ty", &self.ty)
1435 .finish()
1436 }
1437}
1438impl fmt::Display for Call {
1439 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1440 write!(f, "{}{}", self.callee, DisplayTuple(self.args.as_slice()))
1441 }
1442}