Skip to main content

air_parser/ast/
expression.rs

//! This module provides AST structures which represent the various types of
//! expressions, or types which are used primarily in expressions.
//!
//! Expressions always evaluate to a value, unlike statements, and as a result
//! they can generally be nested arbitrarily to represent complex computations.
//!
//! However, in AirScript, the evaluation of constraints places limits on the types
//! of values on which constraints can be enforced. Correspondingly, certain expressions
//! are only usable in intermediate contexts (e.g. those which produce vectors/matrices),
//! and must be reduced to scalars in constraints. As a result, we distinguish between
//! scalar and non-scalar expression types.
12use std::{convert::AsRef, fmt};
13
14use miden_diagnostics::{SourceSpan, Span, Spanned};
15
16use crate::symbols::Symbol;
17
18use super::*;
19
/// A constant range literal, equivalent to the half-open interval `[start, end)`.
///
/// A [RangeExpr] whose two bounds are both integer literals converts into this type.
pub type Range = std::ops::Range<usize>;
22
23/// Represents any type of identifier in AirScript
24#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Spanned)]
25pub struct Identifier(pub Span<Symbol>);
26impl Identifier {
27    pub fn new(span: SourceSpan, name: Symbol) -> Self {
28        Self(Span::new(span, name))
29    }
30
31    /// Returns the underlying symbol of the identifier.
32    pub fn name(&self) -> Symbol {
33        self.0.item
34    }
35
36    #[inline]
37    pub fn as_str(&self) -> &str {
38        self.0.as_str()
39    }
40
41    /// Returns true if all characters of this identifier are uppercase
42    pub fn is_uppercase(&self) -> bool {
43        self.0.as_str().chars().all(char::is_uppercase)
44    }
45
46    /// Returns true if this identifier was generated by the compiler
47    pub fn is_generated(&self) -> bool {
48        self.0.as_str().starts_with('%')
49    }
50
51    /// Returns true if this identifier has the `$` prefix associated with special identifiers
52    pub fn is_special(&self) -> bool {
53        self.0.as_str().starts_with('$')
54    }
55}
56impl PartialEq<&str> for Identifier {
57    #[inline]
58    fn eq(&self, other: &&str) -> bool {
59        self.0.item == *other
60    }
61}
62impl PartialEq<&Identifier> for Identifier {
63    #[inline]
64    fn eq(&self, other: &&Self) -> bool {
65        self == *other
66    }
67}
68impl fmt::Debug for Identifier {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_tuple("Identifier")
71            .field(&format!("{}", &self.0.item))
72            .finish()
73    }
74}
75impl fmt::Display for Identifier {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        write!(f, "{}", &self.0)
78    }
79}
80impl From<ResolvableIdentifier> for Identifier {
81    fn from(id: ResolvableIdentifier) -> Self {
82        match id {
83            ResolvableIdentifier::Local(id) => id,
84            ResolvableIdentifier::Global(id) => id,
85            ResolvableIdentifier::Resolved(qid) => qid.item.id(),
86            ResolvableIdentifier::Unresolved(nid) => nid.id(),
87        }
88    }
89}
90
91/// Represents an identifier qualified with its namespace.
92///
93/// Identifiers in AirScript are separated into two namespaces: one for functions,
94/// and one for buses and bindings. This is because functions cannot be bound, added to or remove from,
95/// while buses and bindings cannot be called.
96/// So we can always disambiguate identifiers based on its usage.
97///
98/// It is still probably best practice to avoid having name conflicts between functions,
99/// buses and bindings, but that is a matter of style rather than one of necessity.
100#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)]
101pub enum NamespacedIdentifier {
102    Function(#[span] Identifier),
103    Binding(#[span] Identifier),
104}
105impl NamespacedIdentifier {
106    pub fn id(&self) -> Identifier {
107        match self {
108            Self::Function(ident) | Self::Binding(ident) => *ident,
109        }
110    }
111}
112impl AsRef<Identifier> for NamespacedIdentifier {
113    fn as_ref(&self) -> &Identifier {
114        match self {
115            Self::Function(ident) | Self::Binding(ident) => ident,
116        }
117    }
118}
119impl From<ResolvableIdentifier> for NamespacedIdentifier {
120    fn from(id: ResolvableIdentifier) -> Self {
121        match id {
122            ResolvableIdentifier::Local(id) => Self::Binding(id),
123            ResolvableIdentifier::Global(id) => Self::Binding(id),
124            ResolvableIdentifier::Resolved(qid) => qid.item,
125            ResolvableIdentifier::Unresolved(nid) => nid,
126        }
127    }
128}
129impl fmt::Display for NamespacedIdentifier {
130    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
131        fmt::Display::fmt(self.as_ref(), f)
132    }
133}
134
135/// Represents an identifier qualified with both its parent module and namespace.
136///
137/// This represents a globally-unique identity for a declaration
138#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)]
139pub struct QualifiedIdentifier {
140    pub module: ModuleId,
141    #[span]
142    pub item: NamespacedIdentifier,
143}
144impl QualifiedIdentifier {
145    pub const fn new(module: ModuleId, item: NamespacedIdentifier) -> Self {
146        Self { module, item }
147    }
148
149    pub const fn id(&self) -> NamespacedIdentifier {
150        self.item
151    }
152
153    /// Returns the name of the item in its [Symbol] form
154    #[inline]
155    pub fn name(&self) -> Symbol {
156        self.as_ref().name()
157    }
158
159    /// Returns true if this identifier refers to a known builtin function
160    pub fn is_builtin(&self) -> bool {
161        use crate::symbols;
162
163        if self.module.name() == "$builtin" {
164            match self.item {
165                NamespacedIdentifier::Function(id) => {
166                    matches!(id.name(), symbols::Sum | symbols::Prod)
167                }
168                _ => false,
169            }
170        } else {
171            false
172        }
173    }
174}
175impl AsRef<Identifier> for QualifiedIdentifier {
176    #[inline]
177    fn as_ref(&self) -> &Identifier {
178        self.item.as_ref()
179    }
180}
181impl fmt::Display for QualifiedIdentifier {
182    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
183        write!(f, "{}::{}", &self.module, &self.item)
184    }
185}
186
187/// Represents an identifier which requires name resolution at some stage during lowering.
188#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Spanned)]
189pub enum ResolvableIdentifier {
190    /// This identifier is resolved to a local binding (i.e. function parameter or let-bound var)
191    Local(#[span] Identifier),
192    /// This identifier is resolved to a global binding
193    Global(#[span] Identifier),
194    /// This identifier is resolved to a non-local item (i.e. module-level declaration or imported item)
195    Resolved(#[span] QualifiedIdentifier),
196    /// This identifier is not yet resolved or is undefined in the current scope
197    Unresolved(#[span] NamespacedIdentifier),
198}
199impl ResolvableIdentifier {
200    /// Returns true if this identifier has been resolved locally or otherwise
201    #[inline]
202    pub fn is_resolved(&self) -> bool {
203        matches!(self, Self::Local(_) | Self::Global(_) | Self::Resolved(_))
204    }
205
206    /// Returns true if this is a locally-resolved identifier
207    pub fn is_local(&self) -> bool {
208        matches!(self, Self::Local(_))
209    }
210
211    /// Returns true if this is a globally-resolved identifier
212    pub fn is_global(&self) -> bool {
213        matches!(self, Self::Global(_))
214    }
215
216    /// Returns true if this identifier refers to a known builtin function
217    pub fn is_builtin(&self) -> bool {
218        match self {
219            Self::Resolved(qid) => qid.is_builtin(),
220            _ => false,
221        }
222    }
223
224    /// The module to which this identifier is resolved
225    ///
226    /// For locally-resolved identifiers, this returns `None`, same as
227    /// unresolved identifiers, check `is_resolved` to distinguish between
228    /// resolved/unresolved states
229    pub fn module(&self) -> Option<ModuleId> {
230        match self {
231            Self::Resolved(qid) => Some(*qid.as_ref()),
232            _ => None,
233        }
234    }
235
236    /// Obtains a [NamespacedIdentifier] from this identifier
237    #[inline]
238    pub fn namespaced(&self) -> NamespacedIdentifier {
239        (*self).into()
240    }
241
242    /// Gets the [QualifiedIdentifier] if this identifier is of type `Resolved`
243    #[inline]
244    pub fn resolved(&self) -> Option<QualifiedIdentifier> {
245        match self {
246            Self::Resolved(qid) => Some(*qid),
247            _ => None,
248        }
249    }
250}
251impl AsRef<Identifier> for ResolvableIdentifier {
252    #[inline]
253    fn as_ref(&self) -> &Identifier {
254        match self {
255            Self::Local(id) => id,
256            Self::Global(id) => id,
257            Self::Resolved(qid) => qid.item.as_ref(),
258            Self::Unresolved(nid) => nid.as_ref(),
259        }
260    }
261}
262impl fmt::Display for ResolvableIdentifier {
263    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
264        match self {
265            Self::Local(id) => write!(f, "{id}"),
266            Self::Global(id) => write!(f, "{id}"),
267            Self::Resolved(qid) => write!(f, "{qid}"),
268            Self::Unresolved(nid) => write!(f, "{nid}"),
269        }
270    }
271}
272
/// Expressions which are valid in the body of a `let` statement, or in a function call.
///
/// Unlike [ScalarExpr], an `Expr` may evaluate to a non-scalar value such as a vector or
/// matrix; use `ScalarExpr::try_from` to narrow it where only scalars are permitted.
#[derive(Clone, PartialEq, Eq, Spanned)]
pub enum Expr {
    /// A constant expression
    Const(Span<ConstantExpr>),
    /// An expression which evaluates to a vector of integers in the given range
    Range(RangeExpr),
    /// A vector of expressions
    ///
    /// A vector may be used to represent matrices in some situations, but such matrices
    /// must always be composed of scalar values. It is not permitted to have arbitrarily
    /// deep vectors.
    Vector(Span<Vec<Expr>>),
    /// A matrix of scalar expressions, stored as a list of rows
    Matrix(Span<Vec<Vec<ScalarExpr>>>),
    /// A reference to a named value of any type
    SymbolAccess(SymbolAccess),
    /// A binary operator over scalar values
    Binary(BinaryExpr),
    /// A call to a pure function
    ///
    /// NOTE: This expression is only valid when the call is a pure function;
    /// calls to evaluators are not permitted in an `Expr` context, as they do
    /// not produce a value.
    Call(Call),
    /// A generator expression which produces a vector or matrix of values
    ListComprehension(ListComprehension),
    /// A `let` expression, used to bind temporaries in expression position during compilation.
    ///
    /// NOTE: The AirScript syntax only permits `let` in statement position, so this variant
    /// is only present in the AST as the result of an explicit transformation.
    Let(Box<Let>),
    /// A bus operation (`p.insert(...)` or `p.remove(...)`)
    BusOperation(BusOperation),
    /// An empty bus (displayed as `null`)
    Null(Span<()>),
    /// An unconstrained bus (displayed as `unconstrained`)
    Unconstrained(Span<()>),
}
312impl Expr {
313    /// Returns true if this expression is constant
314    ///
315    /// NOTE: This only returns true for the `Const` and `Range` variants
316    pub fn is_constant(&self) -> bool {
317        match self {
318            Self::Const(_) => true,
319            Self::Range(range) => range.is_constant(),
320            _ => false,
321        }
322    }
323
324    /// Returns the resolved type of this expression, if known
325    pub fn ty(&self) -> Option<Type> {
326        match self {
327            Self::Const(constant) => Some(constant.ty()),
328            Self::Range(range) => range.ty(),
329            Self::Vector(vector) => match vector.first().and_then(|e| e.ty()) {
330                Some(Type::Felt) => Some(Type::Vector(vector.len())),
331                Some(Type::Vector(n)) => Some(Type::Matrix(vector.len(), n)),
332                Some(_) => None,
333                None => Some(Type::Vector(0)),
334            },
335            Self::Matrix(matrix) => {
336                let rows = matrix.len();
337                let cols = matrix[0].len();
338                Some(Type::Matrix(rows, cols))
339            }
340            Self::SymbolAccess(access) => access.ty,
341            Self::Binary(_) => Some(Type::Felt),
342            Self::Call(call) => call.ty,
343            Self::ListComprehension(lc) => lc.ty,
344            Self::Let(let_expr) => let_expr.ty(),
345            Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => Some(Type::Felt),
346        }
347    }
348}
349impl fmt::Debug for Expr {
350    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
351        match self {
352            Self::Const(expr) => f.debug_tuple("Const").field(&expr.item).finish(),
353            Self::Range(expr) => f.debug_tuple("Range").field(&expr).finish(),
354            Self::Vector(expr) => f.debug_tuple("Vector").field(&expr.item).finish(),
355            Self::Matrix(expr) => f.debug_tuple("Matrix").field(&expr.item).finish(),
356            Self::SymbolAccess(expr) => f.debug_tuple("SymbolAccess").field(expr).finish(),
357            Self::Binary(expr) => f.debug_tuple("Binary").field(expr).finish(),
358            Self::Call(expr) => f.debug_tuple("Call").field(expr).finish(),
359            Self::ListComprehension(expr) => {
360                f.debug_tuple("ListComprehension").field(expr).finish()
361            }
362            Self::Let(let_expr) => write!(f, "{let_expr:#?}"),
363            Self::BusOperation(expr) => f.debug_tuple("BusOp").field(expr).finish(),
364            Self::Null(expr) => f.debug_tuple("Null").field(expr).finish(),
365            Self::Unconstrained(expr) => f.debug_tuple("Unconstrained").field(expr).finish(),
366        }
367    }
368}
369impl fmt::Display for Expr {
370    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
371        match self {
372            Self::Const(expr) => write!(f, "{}", &expr),
373            Self::Range(range) => write!(f, "{range}"),
374            Self::Vector(expr) => write!(f, "{}", DisplayList(expr.as_slice())),
375            Self::Matrix(expr) => {
376                f.write_str("[")?;
377                for (i, col) in expr.iter().enumerate() {
378                    if i > 0 {
379                        f.write_str(", ")?;
380                    }
381                    write!(f, "{}", DisplayList(col.as_slice()))?;
382                }
383                f.write_str("]")
384            }
385            Self::SymbolAccess(expr) => write!(f, "{expr}"),
386            Self::Binary(expr) => write!(f, "{expr}"),
387            Self::Call(expr) => write!(f, "{expr}"),
388            Self::ListComprehension(expr) => write!(f, "{}", DisplayBracketed(expr)),
389            Self::Let(let_expr) => {
390                let display = DisplayLet {
391                    let_expr,
392                    indent: 0,
393                    in_expr_position: true,
394                };
395                write!(f, "{display}")
396            }
397            Self::BusOperation(expr) => write!(f, "{expr}"),
398            Self::Null(_expr) => write!(f, "null"),
399            Self::Unconstrained(_expr) => write!(f, "unconstrained"),
400        }
401    }
402}
403impl From<SymbolAccess> for Expr {
404    #[inline]
405    fn from(expr: SymbolAccess) -> Self {
406        Self::SymbolAccess(expr)
407    }
408}
409impl From<BinaryExpr> for Expr {
410    #[inline]
411    fn from(expr: BinaryExpr) -> Self {
412        Self::Binary(expr)
413    }
414}
415impl From<Call> for Expr {
416    #[inline]
417    fn from(expr: Call) -> Self {
418        Self::Call(expr)
419    }
420}
421impl From<BusOperation> for Expr {
422    #[inline]
423    fn from(expr: BusOperation) -> Self {
424        Self::BusOperation(expr)
425    }
426}
427impl From<ListComprehension> for Expr {
428    #[inline]
429    fn from(expr: ListComprehension) -> Self {
430        Self::ListComprehension(expr)
431    }
432}
433impl TryFrom<Let> for Expr {
434    type Error = InvalidExprError;
435
436    fn try_from(expr: Let) -> Result<Self, Self::Error> {
437        if expr.ty().is_some() {
438            Ok(Self::Let(Box::new(expr)))
439        } else {
440            Err(InvalidExprError::InvalidLetExpr(expr.span()))
441        }
442    }
443}
444impl TryFrom<ScalarExpr> for Expr {
445    type Error = InvalidExprError;
446
447    #[inline]
448    fn try_from(expr: ScalarExpr) -> Result<Self, Self::Error> {
449        match expr {
450            ScalarExpr::Const(spanned) => Ok(Self::Const(Span::new(
451                spanned.span(),
452                ConstantExpr::Scalar(spanned.item),
453            ))),
454            ScalarExpr::SymbolAccess(access) => Ok(Self::SymbolAccess(access)),
455            ScalarExpr::Binary(expr) => Ok(Self::Binary(expr)),
456            ScalarExpr::Call(expr) => Ok(Self::Call(expr)),
457            ScalarExpr::BoundedSymbolAccess(_) => {
458                Err(InvalidExprError::BoundedSymbolAccess(expr.span()))
459            }
460            ScalarExpr::Let(expr) => Ok(Self::Let(expr)),
461            ScalarExpr::BusOperation(expr) => Ok(Self::BusOperation(expr)),
462            ScalarExpr::Null(spanned) => Ok(Self::Null(spanned)),
463            ScalarExpr::Unconstrained(spanned) => Ok(Self::Unconstrained(spanned)),
464        }
465    }
466}
467impl TryFrom<Statement> for Expr {
468    type Error = InvalidExprError;
469
470    fn try_from(stmt: Statement) -> Result<Self, Self::Error> {
471        match stmt {
472            Statement::Let(let_expr) => Ok(Self::Let(Box::new(let_expr))),
473            Statement::Expr(expr) => Ok(expr),
474            _ => Err(InvalidExprError::NotAnExpr(stmt.span())),
475        }
476    }
477}
478
/// Scalar expressions are expressions which evaluate to a single scalar value,
/// i.e. they have no vector or matrix elements. Only scalar expressions are valid
/// in a constraint statement.
///
/// Every variant except `BoundedSymbolAccess` has a counterpart in [Expr], and the two
/// types convert into each other via `TryFrom`.
#[derive(Clone, PartialEq, Eq, Spanned)]
pub enum ScalarExpr {
    /// A constant scalar value, i.e. an integer
    Const(Span<u64>),
    /// A reference to a named value
    ///
    /// NOTE: Symbol accesses in a `ScalarExpr` context must produce scalar values.
    SymbolAccess(SymbolAccess),
    /// A reference to a trace column on a particular boundary of the trace, which must produce a scalar
    ///
    /// NOTE: This is only a valid expression in boundary constraints
    BoundedSymbolAccess(BoundedSymbolAccess),
    /// A binary operator over scalar values
    Binary(BinaryExpr),
    /// A call to a pure function or evaluator
    ///
    /// NOTE: This is only a valid scalar expression when one of the following hold:
    ///
    /// 1. The call is the top-level expression of a constraint, and is to an evaluator function
    /// 2. The call is not the top-level expression of a constraint, and is to a pure function
    ///    that produces a scalar value type.
    ///
    /// If neither of the above are true, the call is invalid in a `ScalarExpr` context
    Call(Call),
    /// An expression that binds a local variable to a temporary value during evaluation.
    ///
    /// NOTE: This is only a valid scalar expression during the inlining phase, when we expand
    /// binary expressions or function calls to a block of statements, and only when the result
    /// of evaluating the `let` produces a valid scalar expression.
    Let(Box<Let>),
    /// A bus operation
    BusOperation(BusOperation),
    /// An empty bus (displayed as `null`)
    Null(Span<()>),
    /// An unconstrained bus (displayed as `unconstrained`)
    Unconstrained(Span<()>),
}
519impl ScalarExpr {
520    /// Returns true if this is a constant value
521    pub fn is_constant(&self) -> bool {
522        matches!(self, Self::Const(_))
523    }
524
525    /// Returns true if this scalar expression could expand to a block, e.g. due to a function call being inlined.
526    pub fn has_block_like_expansion(&self) -> bool {
527        match self {
528            Self::Binary(expr) => expr.has_block_like_expansion(),
529            Self::Call(_) | Self::Let(_) => true,
530            _ => false,
531        }
532    }
533
534    /// Returns the resolved type of this expression, if known.
535    ///
536    /// Returns `Ok(Some)` if the type could be resolved without conflict.
537    /// Returns `Ok(None)` if type information was missing.
538    /// Returns `Err` if the type could not be resolved due to a conflict,
539    /// with a span covering the source of the conflict.
540    pub fn ty(&self) -> Result<Option<Type>, SourceSpan> {
541        match self {
542            Self::Const(_) => Ok(Some(Type::Felt)),
543            Self::SymbolAccess(sym) => Ok(sym.ty),
544            Self::BoundedSymbolAccess(sym) => Ok(sym.column.ty),
545            Self::Binary(expr) => match (expr.lhs.ty()?, expr.rhs.ty()?) {
546                (None, _) | (_, None) => Ok(None),
547                (Some(lty), Some(rty)) if lty == rty => Ok(Some(lty)),
548                _ => Err(expr.span()),
549            },
550            Self::Call(expr) => Ok(expr.ty),
551            Self::Let(expr) => Ok(expr.ty()),
552            Self::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => {
553                Ok(Some(Type::Felt))
554            }
555        }
556    }
557}
558impl TryFrom<Expr> for ScalarExpr {
559    type Error = InvalidExprError;
560
561    fn try_from(expr: Expr) -> Result<Self, Self::Error> {
562        match expr {
563            Expr::Const(constant) => {
564                let span = constant.span();
565                match constant.item {
566                    ConstantExpr::Scalar(v) => Ok(Self::Const(Span::new(span, v))),
567                    _ => Err(InvalidExprError::InvalidScalarExpr(span)),
568                }
569            }
570            Expr::SymbolAccess(sym) => Ok(Self::SymbolAccess(sym)),
571            Expr::Binary(bin) => Ok(Self::Binary(bin)),
572            Expr::Call(call) => Ok(Self::Call(call)),
573            Expr::Let(let_expr) => {
574                if let_expr.ty().is_none() {
575                    Err(InvalidExprError::InvalidScalarExpr(let_expr.span()))
576                } else {
577                    Ok(Self::Let(let_expr))
578                }
579            }
580            invalid => Err(InvalidExprError::InvalidScalarExpr(invalid.span())),
581        }
582    }
583}
584impl TryFrom<Statement> for ScalarExpr {
585    type Error = InvalidExprError;
586
587    fn try_from(stmt: Statement) -> Result<Self, Self::Error> {
588        match stmt {
589            Statement::Let(let_expr) => Self::try_from(Expr::Let(Box::new(let_expr))),
590            Statement::Expr(expr) => Self::try_from(expr),
591            stmt => Err(InvalidExprError::InvalidScalarExpr(stmt.span())),
592        }
593    }
594}
595impl From<u64> for ScalarExpr {
596    fn from(value: u64) -> Self {
597        Self::Const(Span::new(SourceSpan::UNKNOWN, value))
598    }
599}
600impl fmt::Debug for ScalarExpr {
601    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
602        match self {
603            Self::Const(i) => f.debug_tuple("Const").field(&i.item).finish(),
604            Self::SymbolAccess(expr) => f.debug_tuple("SymbolAccess").field(expr).finish(),
605            Self::BoundedSymbolAccess(expr) => {
606                f.debug_tuple("BoundedSymbolAccess").field(expr).finish()
607            }
608            Self::Binary(expr) => f.debug_tuple("Binary").field(expr).finish(),
609            Self::Call(expr) => f.debug_tuple("Call").field(expr).finish(),
610            Self::Let(expr) => write!(f, "{expr:#?}"),
611            Self::BusOperation(expr) => f.debug_tuple("BusOp").field(expr).finish(),
612            Self::Null(expr) => f.debug_tuple("Null").field(expr).finish(),
613            Self::Unconstrained(expr) => f.debug_tuple("Unconstrained").field(expr).finish(),
614        }
615    }
616}
617impl fmt::Display for ScalarExpr {
618    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
619        match self {
620            Self::Const(value) => write!(f, "{value}"),
621            Self::SymbolAccess(expr) => write!(f, "{expr}"),
622            Self::BoundedSymbolAccess(expr) => write!(f, "{}.{}", &expr.column, &expr.boundary),
623            Self::Binary(expr) => write!(f, "{expr}"),
624            Self::Call(call) => write!(f, "{call}"),
625            Self::Let(let_expr) => {
626                let display = DisplayLet {
627                    let_expr,
628                    indent: 0,
629                    in_expr_position: true,
630                };
631                write!(f, "{display}")
632            }
633            Self::BusOperation(expr) => write!(f, "{expr}"),
634            Self::Null(_value) => write!(f, "null"),
635            Self::Unconstrained(_value) => write!(f, "unconstrained"),
636        }
637    }
638}
639
640/// Represents a symbol access to a named constant.
641#[derive(Clone, Spanned, Debug)]
642pub struct ConstSymbolAccess {
643    #[span]
644    pub span: SourceSpan,
645    pub name: ResolvableIdentifier,
646    pub ty: Option<Type>,
647}
648impl ConstSymbolAccess {
649    pub fn new(span: SourceSpan, name: Identifier) -> Self {
650        Self {
651            span,
652            name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(name)),
653            ty: None,
654        }
655    }
656}
657impl Eq for ConstSymbolAccess {}
658impl PartialEq for ConstSymbolAccess {
659    fn eq(&self, other: &Self) -> bool {
660        self.name.eq(&other.name) && self.ty.eq(&other.ty)
661    }
662}
663impl std::hash::Hash for ConstSymbolAccess {
664    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
665        self.name.hash(state);
666        self.ty.hash(state);
667    }
668}
669impl fmt::Display for ConstSymbolAccess {
670    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
671        write!(f, "{}", &self.name)
672    }
673}
674
675#[derive(Debug, Clone, Spanned)]
676pub struct RangeExpr {
677    #[span]
678    pub span: SourceSpan,
679    pub start: RangeBound,
680    pub end: RangeBound,
681}
682
683impl TryFrom<&RangeExpr> for Range {
684    type Error = InvalidExprError;
685
686    #[inline]
687    fn try_from(expr: &RangeExpr) -> Result<Self, InvalidExprError> {
688        match (&expr.start, &expr.end) {
689            (RangeBound::Const(lhs), RangeBound::Const(rhs)) => Ok(lhs.item..rhs.item),
690            _ => Err(InvalidExprError::NonConstantRangeExpr(expr.span)),
691        }
692    }
693}
694
695impl RangeExpr {
696    pub fn is_constant(&self) -> bool {
697        self.start.is_constant() && self.end.is_constant()
698    }
699
700    /// Converts this range expression to a `Range` type, assuming it is constant.
701    /// Panics if the range is not constant.
702    pub fn to_slice_range(&self) -> Range {
703        self.try_into()
704            .expect("attempted to convert non-constant range expression to constant")
705    }
706
707    pub fn ty(&self) -> Option<Type> {
708        match (&self.start, &self.end) {
709            (RangeBound::Const(start), RangeBound::Const(end)) => {
710                Some(Type::Vector(end.item.abs_diff(start.item)))
711            }
712            _ => None,
713        }
714    }
715}
716impl From<Range> for RangeExpr {
717    fn from(range: Range) -> Self {
718        Self {
719            span: SourceSpan::default(),
720            start: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.start)),
721            end: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.end)),
722        }
723    }
724}
725impl Eq for RangeExpr {}
726impl PartialEq for RangeExpr {
727    fn eq(&self, other: &Self) -> bool {
728        self.start.eq(&other.start) && self.end.eq(&other.end)
729    }
730}
731impl std::hash::Hash for RangeExpr {
732    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
733        self.start.hash(state);
734        self.end.hash(state);
735    }
736}
737impl fmt::Display for RangeExpr {
738    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
739        write!(f, "{}..{}", &self.start, &self.end)
740    }
741}
742
743#[derive(Hash, Clone, Spanned, PartialEq, Eq, Debug)]
744pub enum RangeBound {
745    SymbolAccess(ConstSymbolAccess),
746    Const(Span<usize>),
747}
748impl RangeBound {
749    pub fn is_constant(&self) -> bool {
750        matches!(self, Self::Const(_))
751    }
752}
753impl From<Identifier> for RangeBound {
754    fn from(name: Identifier) -> Self {
755        Self::SymbolAccess(ConstSymbolAccess::new(name.span(), name))
756    }
757}
758impl From<usize> for RangeBound {
759    fn from(constant: usize) -> Self {
760        Self::Const(Span::new(SourceSpan::UNKNOWN, constant))
761    }
762}
763impl fmt::Display for RangeBound {
764    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
765        match self {
766            Self::SymbolAccess(sym) => write!(f, "{sym}"),
767            Self::Const(constant) => write!(f, "{constant}"),
768        }
769    }
770}
771
772/// Represents an expression requiring evaluation of a binary operator
773#[derive(Clone, Spanned)]
774pub struct BinaryExpr {
775    #[span]
776    pub span: SourceSpan,
777    pub op: BinaryOp,
778    pub lhs: Box<ScalarExpr>,
779    pub rhs: Box<ScalarExpr>,
780}
781impl BinaryExpr {
782    pub fn new(span: SourceSpan, op: BinaryOp, lhs: ScalarExpr, rhs: ScalarExpr) -> Self {
783        Self {
784            span,
785            op,
786            lhs: Box::new(lhs),
787            rhs: Box::new(rhs),
788        }
789    }
790
791    /// Returns true if this binary expression could expand to a block, e.g. due to a function call being inlined.
792    #[inline]
793    pub fn has_block_like_expansion(&self) -> bool {
794        self.lhs.has_block_like_expansion() || self.rhs.has_block_like_expansion()
795    }
796}
797impl Eq for BinaryExpr {}
798impl PartialEq for BinaryExpr {
799    fn eq(&self, other: &Self) -> bool {
800        self.op == other.op && self.lhs == other.lhs && self.rhs == other.rhs
801    }
802}
803impl fmt::Debug for BinaryExpr {
804    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
805        f.debug_struct("BinaryExpr")
806            .field("op", &self.op)
807            .field("lhs", self.lhs.as_ref())
808            .field("rhs", self.rhs.as_ref())
809            .finish()
810    }
811}
812impl fmt::Display for BinaryExpr {
813    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
814        write!(f, "{} {} {}", &self.lhs, &self.op, &self.rhs)
815    }
816}
817
/// The operator of a [BinaryExpr].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum BinaryOp {
    /// Addition
    Add,
    /// Subtraction
    Sub,
    /// Multiplication
    Mul,
    /// Exponentiation
    Exp,
    /// Equality
    ///
    /// NOTE: This is only used in constraints to assert equality, it is invalid in other contexts
    Eq,
}
impl fmt::Display for BinaryOp {
    /// Writes the operator's surface-syntax symbol.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Exp => "^",
            Self::Eq => "=",
        };
        f.write_str(symbol)
    }
}
844
/// The boundary of the trace that a boundary constraint applies to.
///
/// Defaults to `First`.
#[derive(Debug, Copy, Clone, PartialEq, Default, Eq)]
pub enum Boundary {
    #[default]
    First,
    Last,
}
impl fmt::Display for Boundary {
    /// Writes the boundary keyword: `first` or `last`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Self::First => "first",
            Self::Last => "last",
        })
    }
}
860
861/// Represents the way an identifier is accessed/referenced in the source.
862#[derive(Hash, Debug, Clone, Eq, PartialEq, Default)]
863pub enum AccessType {
864    /// Access refers to the entire bound value
865    #[default]
866    Default,
867    /// Access binds a sub-slice of a vector
868    Slice(RangeExpr),
869    /// Access binds the value at a specific index of an aggregate value (i.e. vector or matrix)
870    ///
871    /// The result type may be either a scalar or a vector, depending on the type of the aggregate
872    Index(usize),
873    /// Access binds the value at a specific row and column of a matrix value
874    Matrix(usize, usize),
875}
876impl fmt::Display for AccessType {
877    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
878        match self {
879            Self::Default => write!(f, "direct reference by name"),
880            Self::Slice(range) => write!(
881                f,
882                "slice of elements at indices {}..{}",
883                range.start, range.end
884            ),
885            Self::Index(idx) => write!(f, "reference to element at index {idx}"),
886            Self::Matrix(row, col) => write!(f, "reference to value in matrix at [{row}][{col}]"),
887        }
888    }
889}
890
891#[derive(Debug, Clone, thiserror::Error)]
892pub enum InvalidAccessError {
893    #[error("attempted to access undefined variable")]
894    UndefinedVariable,
895    #[error("attempted to access a function as a variable")]
896    InvalidBinding,
897    #[error("attempted to take a slice of a scalar value")]
898    SliceOfScalar,
899    #[error("attempted to take a slice of a matrix value")]
900    SliceOfMatrix,
901    #[error("attempted to index into a scalar value")]
902    IndexIntoScalar,
903    #[error("attempted to access an index which is out of bounds")]
904    IndexOutOfBounds,
905}
906
907/// [SymbolAccess] represents access to a named item in the source code; one of the following:
908///
909/// * A global name associated with trace columns or public inputs
910/// * A named constant
911/// * A module-local name associated with periodic columns
912/// * A evaluator/function parameter
913/// * A let-bound variable
914#[derive(Clone, Spanned)]
915pub struct SymbolAccess {
916    #[span]
917    pub span: SourceSpan,
918    /// The symbol being accessed
919    pub name: ResolvableIdentifier,
920    /// The type of access
921    pub access_type: AccessType,
922    /// Used when the accessing a trace column with `'`, indicates the offset from
923    /// the current row in the trace. Defaults to zero.
924    ///
925    /// NOTE: When accessed with an offset, trace columns are treated as scalar values,
926    /// not as trace columns proper. What this means is that such an access cannot be
927    /// used in a context where a trace column is expected, only where a scalar value
928    /// is expected.
929    pub offset: usize,
930    /// Used during name resolution/type checking to store the type associated with
931    /// the value produced by the symbol access. If unset, it simply means that the
932    /// type has not been checked/resolved.
933    pub ty: Option<Type>,
934}
935impl SymbolAccess {
936    pub const fn new(
937        span: SourceSpan,
938        name: Identifier,
939        access_type: AccessType,
940        offset: usize,
941    ) -> Self {
942        Self {
943            span,
944            name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(name)),
945            access_type,
946            offset,
947            ty: None,
948        }
949    }
950
951    /// Generates a new [SymbolAccess] that represents accessing this access, i.e.
952    /// nesting accesses. For example, if called with `AccessType::Index`, and
953    /// the current access type is `Default`, a new [SymbolAccess] is returned which
954    /// has an access type of `Index`. However, if the current access type was `Index`,
955    /// then the resulting [SymbolAccess] will have an access type of `Matrix`
956    ///
957    /// It is expected that the type of this access has been resolved already, and this
958    /// function will panic if that is not the case.
959    pub fn access(&self, access_type: AccessType) -> Result<Self, InvalidAccessError> {
960        match &self.access_type {
961            AccessType::Default => self.access_default(access_type),
962            AccessType::Slice(base_range) => {
963                self.access_slice(base_range.to_slice_range(), access_type)
964            }
965            AccessType::Index(base_idx) => self.access_index(*base_idx, access_type),
966            AccessType::Matrix(_, _) => match access_type {
967                AccessType::Default => Ok(self.clone()),
968                _ => Err(InvalidAccessError::IndexIntoScalar),
969            },
970        }
971    }
972
973    fn access_default(&self, access_type: AccessType) -> Result<Self, InvalidAccessError> {
974        let ty = self.ty.unwrap();
975        match access_type {
976            AccessType::Default => Ok(self.clone()),
977            AccessType::Index(idx) => match ty {
978                Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
979                Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
980                Type::Vector(_) => Ok(Self {
981                    access_type: AccessType::Index(idx),
982                    ty: Some(Type::Felt),
983                    ..self.clone()
984                }),
985                Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
986                Type::Matrix(_, cols) => Ok(Self {
987                    access_type: AccessType::Index(idx),
988                    ty: Some(Type::Vector(cols)),
989                    ..self.clone()
990                }),
991            },
992            AccessType::Slice(range) => {
993                let slice_range = range.to_slice_range();
994                let rlen = slice_range.end - slice_range.start;
995                match ty {
996                    Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
997                    Type::Vector(len) if slice_range.end > len => {
998                        Err(InvalidAccessError::IndexOutOfBounds)
999                    }
1000                    Type::Vector(_) => Ok(Self {
1001                        access_type: AccessType::Slice(range),
1002                        ty: Some(Type::Vector(rlen)),
1003                        ..self.clone()
1004                    }),
1005                    Type::Matrix(rows, _) if slice_range.end > rows => {
1006                        Err(InvalidAccessError::IndexOutOfBounds)
1007                    }
1008                    Type::Matrix(_, cols) => Ok(Self {
1009                        access_type: AccessType::Slice(range),
1010                        ty: Some(Type::Matrix(rlen, cols)),
1011                        ..self.clone()
1012                    }),
1013                }
1014            }
1015            AccessType::Matrix(row, col) => match ty {
1016                Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar),
1017                Type::Matrix(rows, cols) if row >= rows || col >= cols => {
1018                    Err(InvalidAccessError::IndexOutOfBounds)
1019                }
1020                Type::Matrix(_, _) => Ok(Self {
1021                    access_type: AccessType::Matrix(row, col),
1022                    ty: Some(Type::Felt),
1023                    ..self.clone()
1024                }),
1025            },
1026        }
1027    }
1028
1029    fn access_slice(
1030        &self,
1031        base_range: Range,
1032        access_type: AccessType,
1033    ) -> Result<Self, InvalidAccessError> {
1034        let ty = self.ty.unwrap();
1035        match access_type {
1036            AccessType::Default => Ok(self.clone()),
1037            AccessType::Index(idx) => match ty {
1038                Type::Felt => unreachable!(),
1039                Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
1040                Type::Vector(_) => Ok(Self {
1041                    access_type: AccessType::Index(base_range.start + idx),
1042                    ty: Some(Type::Felt),
1043                    ..self.clone()
1044                }),
1045                Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
1046                Type::Matrix(_, cols) => Ok(Self {
1047                    access_type: AccessType::Index(base_range.start + idx),
1048                    ty: Some(Type::Vector(cols)),
1049                    ..self.clone()
1050                }),
1051            },
1052            AccessType::Slice(range) => {
1053                let slice_range = range.to_slice_range();
1054                let blen = base_range.end - base_range.start;
1055                let rlen = slice_range.len();
1056                let start = base_range.start + slice_range.start;
1057                let end = slice_range.start + slice_range.end;
1058                let shifted = RangeExpr {
1059                    span: range.span,
1060                    start: RangeBound::Const(Span::new(range.start.span(), start)),
1061                    end: RangeBound::Const(Span::new(range.end.span(), end)),
1062                };
1063                match ty {
1064                    Type::Felt => unreachable!(),
1065                    Type::Vector(_) if slice_range.end > blen => {
1066                        Err(InvalidAccessError::IndexOutOfBounds)
1067                    }
1068                    Type::Vector(_) => Ok(Self {
1069                        access_type: AccessType::Slice(shifted),
1070                        ty: Some(Type::Vector(rlen)),
1071                        ..self.clone()
1072                    }),
1073                    Type::Matrix(rows, _) if slice_range.end > rows => {
1074                        Err(InvalidAccessError::IndexOutOfBounds)
1075                    }
1076                    Type::Matrix(_, cols) => Ok(Self {
1077                        access_type: AccessType::Slice(shifted),
1078                        ty: Some(Type::Matrix(rlen, cols)),
1079                        ..self.clone()
1080                    }),
1081                }
1082            }
1083            AccessType::Matrix(row, col) => match ty {
1084                Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar),
1085                Type::Matrix(rows, cols) if row >= rows || col >= cols => {
1086                    Err(InvalidAccessError::IndexOutOfBounds)
1087                }
1088                Type::Matrix(_, _) => Ok(Self {
1089                    access_type: AccessType::Matrix(row, col),
1090                    ty: Some(Type::Felt),
1091                    ..self.clone()
1092                }),
1093            },
1094        }
1095    }
1096
1097    fn access_index(
1098        &self,
1099        base_idx: usize,
1100        access_type: AccessType,
1101    ) -> Result<Self, InvalidAccessError> {
1102        let ty = self.ty.unwrap();
1103        match access_type {
1104            AccessType::Default => Ok(self.clone()),
1105            AccessType::Index(idx) => match ty {
1106                Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
1107                Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
1108                Type::Vector(_) => Ok(Self {
1109                    access_type: AccessType::Matrix(base_idx, idx),
1110                    ty: Some(Type::Felt),
1111                    ..self.clone()
1112                }),
1113                Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
1114                Type::Matrix(_, cols) => Ok(Self {
1115                    access_type: AccessType::Matrix(base_idx, idx),
1116                    ty: Some(Type::Vector(cols)),
1117                    ..self.clone()
1118                }),
1119            },
1120            AccessType::Slice(_) => Err(InvalidAccessError::SliceOfMatrix),
1121            AccessType::Matrix(_, _) => Err(InvalidAccessError::IndexIntoScalar),
1122        }
1123    }
1124}
1125impl Eq for SymbolAccess {}
1126impl PartialEq for SymbolAccess {
1127    fn eq(&self, other: &Self) -> bool {
1128        self.name == other.name
1129            && self.access_type == other.access_type
1130            && self.offset == other.offset
1131            && self.ty == other.ty
1132    }
1133}
1134impl fmt::Debug for SymbolAccess {
1135    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1136        f.debug_struct("SymbolAccess")
1137            .field("name", &self.name)
1138            .field("access_type", &self.access_type)
1139            .field("offset", &self.offset)
1140            .field("ty", &self.ty)
1141            .finish()
1142    }
1143}
1144impl fmt::Display for SymbolAccess {
1145    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1146        write!(f, "{}", self.name)?;
1147        match &self.access_type {
1148            AccessType::Default => (),
1149            AccessType::Index(idx) => write!(f, "[{idx}]")?,
1150            AccessType::Slice(range) => write!(f, "[{}..{}]", range.start, range.end)?,
1151            AccessType::Matrix(row, col) => write!(f, "[{row}][{col}]")?,
1152        }
1153        // TODO: When we change the syntax to support arbitrary offsets, we'll need to update this
1154        for _ in 0..self.offset {
1155            f.write_str("'")?;
1156        }
1157        Ok(())
1158    }
1159}
1160
1161/// A [SymbolAccess] on a specific [Boundary] of a trace column.
1162///
1163/// The underlying symbol must refer to a trace column, or the access is invalid.
1164#[derive(Clone, Spanned)]
1165pub struct BoundedSymbolAccess {
1166    #[span]
1167    pub span: SourceSpan,
1168    /// The boundary on which this access will be evaluated
1169    pub boundary: Boundary,
1170    /// The column access metadata
1171    pub column: SymbolAccess,
1172}
1173impl BoundedSymbolAccess {
1174    pub const fn new(span: SourceSpan, column: SymbolAccess, boundary: Boundary) -> Self {
1175        Self {
1176            span,
1177            boundary,
1178            column,
1179        }
1180    }
1181}
1182impl Eq for BoundedSymbolAccess {}
1183impl PartialEq for BoundedSymbolAccess {
1184    fn eq(&self, other: &Self) -> bool {
1185        self.boundary == other.boundary && self.column == other.column
1186    }
1187}
1188impl fmt::Debug for BoundedSymbolAccess {
1189    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1190        f.debug_struct("BoundedSymbolAccess")
1191            .field("boundary", &self.boundary)
1192            .field("column", &self.column)
1193            .finish()
1194    }
1195}
1196
1197/// Represents the bindings in a comprehension expression.
1198///
1199/// Each element consists of a name, and an expression which evaluates to an iterable,
1200/// where the name will be bound to each element of the iterable in the comprehension body.
1201pub type ComprehensionContext = Vec<(Identifier, Expr)>;
1202
1203#[derive(Clone, Spanned)]
1204pub struct ListComprehension {
1205    #[span]
1206    pub span: SourceSpan,
1207    /// The names to be bound to each element of their corresponding iterable in `iterables`
1208    ///
1209    /// NOTE: There must be the same number of bindings as iterables.
1210    pub bindings: Vec<Identifier>,
1211    /// The generators for this comprehension.
1212    ///
1213    /// Each iterable must produce the same number of elements as the others.
1214    ///
1215    /// NOTE: There must be the same number of iterables as bindings.
1216    pub iterables: Vec<Expr>,
1217    /// The expression which will be evaluated at each step of the comprehension
1218    pub body: Box<ScalarExpr>,
1219    /// An optional filter applied to the generator expression at each iteration, which
1220    /// skips values for which the selector evaluates to zero (false).
1221    ///
1222    /// When the comprehension is used as a constraint, this field is only valid for
1223    /// use in integrity constraints.
1224    pub selector: Option<ScalarExpr>,
1225    /// The type of the result of this list comprehension, e.g. `vector[5]`
1226    ///
1227    /// This is set during semantic analysis
1228    pub ty: Option<Type>,
1229}
1230impl ListComprehension {
1231    /// Creates a new list comprehension.
1232    pub fn new(
1233        span: SourceSpan,
1234        body: ScalarExpr,
1235        mut context: ComprehensionContext,
1236        selector: Option<ScalarExpr>,
1237    ) -> Self {
1238        let bindings = context.iter().map(|(name, _)| name).copied().collect();
1239        let iterables = context.drain(..).map(|(_, iterable)| iterable).collect();
1240        Self {
1241            span,
1242            bindings,
1243            iterables,
1244            body: Box::new(body),
1245            selector,
1246            ty: None,
1247        }
1248    }
1249}
1250impl Eq for ListComprehension {}
1251impl PartialEq for ListComprehension {
1252    fn eq(&self, other: &Self) -> bool {
1253        self.bindings == other.bindings
1254            && self.iterables == other.iterables
1255            && self.body == other.body
1256            && self.selector == other.selector
1257    }
1258}
1259impl fmt::Debug for ListComprehension {
1260    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1261        f.debug_struct("ListComprehension")
1262            .field("bindings", &self.bindings)
1263            .field("iterables", &self.iterables)
1264            .field("body", self.body.as_ref())
1265            .field("selector", &self.selector)
1266            .field("ty", &self.ty)
1267            .finish()
1268    }
1269}
1270impl fmt::Display for ListComprehension {
1271    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1272        if self.bindings.len() == 1 {
1273            write!(
1274                f,
1275                "{} for {} in {}",
1276                &self.body, &self.bindings[0], &self.iterables[0]
1277            )?;
1278        } else {
1279            write!(
1280                f,
1281                "{} for {} in {}",
1282                &self.body,
1283                DisplayTuple(self.bindings.as_slice()),
1284                DisplayTuple(self.iterables.as_slice())
1285            )?;
1286        }
1287
1288        if let Some(selector) = self.selector.as_ref() {
1289            write!(f, " when {selector}")
1290        } else {
1291            Ok(())
1292        }
1293    }
1294}
1295
1296#[derive(Clone, Spanned)]
1297pub struct BusOperation {
1298    #[span]
1299    pub span: SourceSpan,
1300    pub bus: ResolvableIdentifier,
1301    pub op: BusOperator,
1302    pub args: Vec<Expr>,
1303}
1304
1305impl BusOperation {
1306    pub fn new(span: SourceSpan, bus: Identifier, op: BusOperator, args: Vec<Expr>) -> Self {
1307        Self {
1308            span,
1309            bus: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(bus)),
1310            op,
1311            args,
1312        }
1313    }
1314}
1315
1316impl Eq for BusOperation {}
1317impl PartialEq for BusOperation {
1318    fn eq(&self, other: &Self) -> bool {
1319        self.bus == other.bus && self.args == other.args && self.op == other.op
1320    }
1321}
1322impl fmt::Debug for BusOperation {
1323    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1324        f.debug_struct("BusOperation")
1325            .field("bus", &self.bus)
1326            .field("op", &self.op)
1327            .field("args", &self.args)
1328            .finish()
1329    }
1330}
1331impl fmt::Display for BusOperation {
1332    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1333        write!(
1334            f,
1335            "{}{}{}",
1336            self.bus,
1337            self.op,
1338            DisplayTuple(self.args.as_slice())
1339        )
1340    }
1341}
1342
1343/// Represents a function call (either a pure function or an evaluator).
1344///
1345/// Calls are permitted in a scalar expression context, but arguments to the
1346/// callee may be non-scalar expressions - it is expected that the callee produces
1347/// a scalar in such contexts.
1348///
1349/// Calls to pure functions return scalar or non-scalar values, and so they may be
1350/// used any place other scalar expressions are supported.
1351///
1352/// Calls to evaluators are restricted, and have different semantics, though they appear
1353/// much the same in the language syntax. In particular, evaluators are effectivley callable
1354/// constraints, and may only appear as the sole expression of a constraint, e.g. `enf foo([a, b])`.
1355///
1356/// Because evaluators behave like constraints, they produce no value, so the "type" of a call
1357/// expression which invokes an evaluator is void. For this reason, calls to evaluators
1358/// always have a type of `None`. Pure functions on the other hand must always produce a value,
1359/// so such calls will always have a valid type. The only time when calls to pure functions will
1360/// have a `None` type is prior to name resolution in the semantic analysis pass.
1361#[derive(Clone, Spanned)]
1362pub struct Call {
1363    #[span]
1364    pub span: SourceSpan,
1365    pub callee: ResolvableIdentifier,
1366    pub args: Vec<Expr>,
1367    /// Used to store the type produced by a call to a pure function
1368    ///
1369    /// The reason this field is an `Option` is two-fold:
1370    ///
1371    /// * Calls to evaluators produce no value, and thus have no type
1372    /// * When parsed, the callee has not yet been resolved, so we don't know the
1373    ///   type of the function being called. During semantic analysis, the callee is
1374    ///   resolved and this field is set to the result type of that function.
1375    pub ty: Option<Type>,
1376}
1377impl Call {
1378    pub fn new(span: SourceSpan, callee: Identifier, args: Vec<Expr>) -> Self {
1379        use crate::symbols;
1380
1381        match callee.name() {
1382            symbols::Sum => Self::sum(span, args),
1383            symbols::Prod => Self::prod(span, args),
1384            _ => Self {
1385                span,
1386                callee: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Function(callee)),
1387                args,
1388                ty: None,
1389            },
1390        }
1391    }
1392
1393    /// Returns true if the callee is a builtin function, e.g. `sum`
1394    #[inline]
1395    pub fn is_builtin(&self) -> bool {
1396        self.callee.is_builtin()
1397    }
1398
1399    /// Constructs a function call for the `sum` reducer/fold
1400    #[inline]
1401    pub fn sum(span: SourceSpan, args: Vec<Expr>) -> Self {
1402        Self::new_builtin(span, "sum", args, Type::Felt)
1403    }
1404
1405    /// Constructs a function call for the `prod` reducer/fold
1406    #[inline]
1407    pub fn prod(span: SourceSpan, args: Vec<Expr>) -> Self {
1408        Self::new_builtin(span, "prod", args, Type::Felt)
1409    }
1410
1411    fn new_builtin(span: SourceSpan, name: &str, args: Vec<Expr>, ty: Type) -> Self {
1412        let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin"));
1413        let name = Identifier::new(span, Symbol::intern(name));
1414        let id = QualifiedIdentifier::new(builtin_module, NamespacedIdentifier::Function(name));
1415        Self {
1416            span,
1417            callee: ResolvableIdentifier::Resolved(id),
1418            args,
1419            ty: Some(ty),
1420        }
1421    }
1422}
1423impl Eq for Call {}
1424impl PartialEq for Call {
1425    fn eq(&self, other: &Self) -> bool {
1426        self.callee == other.callee && self.args == other.args && self.ty == other.ty
1427    }
1428}
1429impl fmt::Debug for Call {
1430    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1431        f.debug_struct("Call")
1432            .field("callee", &self.callee)
1433            .field("args", &self.args)
1434            .field("ty", &self.ty)
1435            .finish()
1436    }
1437}
1438impl fmt::Display for Call {
1439    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1440        write!(f, "{}{}", self.callee, DisplayTuple(self.args.as_slice()))
1441    }
1442}