air_parser/ast/
declarations.rs

1//! This module provides AST structures which represent declarations permitted at module scope.
2//!
3//! There are no expressions/statements permitted in the top-level of a module, only declarations.
4//! These declarations define named items which are used by functions/constraints during evaluation.
5//!
6//! Some declarations introduce identifiers at global scope, i.e. they are implicitly defined in all
7//! modules regardless of imports.
8//!
9//! Certain declarations are only permitted in the root module of an AirScript program, as they are
10//! also effectively global:
11//!
12//! * `trace_columns`
13//! * `public_inputs`
14//! * `boundary_constraints`
15//! * `integrity_constraints`
16//!
17//! All other declarations are module-scoped, and must be explicitly imported by a module which wishes
18//! to reference them. Not all items are importable however, only the following:
19//!
20//! * constants
21//! * evaluators
22//! * pure functions
23//!
//! There is no notion of public/private visibility, so any declaration of the above types may be
25//! imported into another module, and "wildcard" imports will import all importable items.
26use std::{collections::HashSet, fmt};
27
28use miden_diagnostics::{SourceSpan, Spanned};
29
30use super::*;
31
/// Represents all of the top-level items permitted at module scope.
///
/// Each variant wraps the AST node (or list of nodes) parsed for one kind of declaration.
#[derive(Debug, PartialEq, Eq, Spanned)]
pub enum Declaration {
    /// Import one or more items from the specified AirScript module to the current module
    Import(Span<Import>),
    /// A bus section declaration, holding every bus declared in that section
    Buses(Span<Vec<Bus>>),
    /// A constant value declaration
    Constant(Constant),
    /// An evaluator function definition
    ///
    /// Evaluator functions can be defined in any module of the program
    EvaluatorFunction(EvaluatorFunction),
    /// A pure function definition
    ///
    /// Pure functions can be defined in any module of the program
    Function(Function),
    /// A `periodic_columns` section declaration
    ///
    /// This may appear any number of times in the program, and may be declared in any module.
    PeriodicColumns(Span<Vec<PeriodicColumn>>),
    /// A `public_inputs` section declaration
    ///
    /// There may only be one of these in the entire program, and it must
    /// appear in the root AirScript module, i.e. in a module declared with `def`
    PublicInputs(Span<Vec<PublicInput>>),
    /// A `trace_columns` section declaration, holding the declared trace segments
    ///
    /// There may only be one of these in the entire program, and it must
    /// appear in the root AirScript module, i.e. in a module declared with `def`
    Trace(Span<Vec<TraceSegment>>),
    /// A `boundary_constraints` section declaration
    ///
    /// There may only be one of these in the entire program, and it must
    /// appear in the root AirScript module, i.e. in a module declared with `def`
    BoundaryConstraints(Span<Vec<Statement>>),
    /// An `integrity_constraints` section declaration
    ///
    /// There may only be one of these in the entire program, and it must
    /// appear in the root AirScript module, i.e. in a module declared with `def`
    IntegrityConstraints(Span<Vec<Statement>>),
}
74
75/// Represents a bus declaration in an AirScript module.
76#[derive(Debug, Clone, Spanned)]
77pub struct Bus {
78    #[span]
79    pub span: SourceSpan,
80    pub name: Identifier,
81    pub bus_type: BusType,
82}
83impl Bus {
84    /// Creates a new bus declaration
85    pub const fn new(span: SourceSpan, name: Identifier, bus_type: BusType) -> Self {
86        Self {
87            span,
88            name,
89            bus_type,
90        }
91    }
92}
/// The kind of a bus declared in an AirScript module.
///
/// The default kind is [`BusType::Multiset`].
#[derive(Default, Copy, Hash, Debug, Clone, PartialEq, Eq)]
pub enum BusType {
    /// A multiset bus
    #[default]
    Multiset,
    /// A logup bus
    Logup,
}
101
/// The operations that can be applied to a bus.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BusOperator {
    /// Insert a tuple to the bus
    Insert,
    /// Remove a tuple from the bus
    Remove,
}
impl std::fmt::Display for BusOperator {
    /// Writes the lowercase keyword for the operator: `insert` or `remove`
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let keyword = match self {
            Self::Insert => "insert",
            Self::Remove => "remove",
        };
        f.write_str(keyword)
    }
}
117
118impl Eq for Bus {}
119impl PartialEq for Bus {
120    fn eq(&self, other: &Self) -> bool {
121        self.name == other.name && self.bus_type == other.bus_type
122    }
123}
124
125/// Stores a constant's name and value. There are three types of constants:
126///
127/// * Scalar: 123
128/// * Vector: \[1, 2, 3\]
129/// * Matrix: \[\[1, 2, 3\], \[4, 5, 6\]\]
130#[derive(Debug, Clone, Spanned)]
131pub struct Constant {
132    #[span]
133    pub span: SourceSpan,
134    pub name: Identifier,
135    pub value: ConstantExpr,
136}
137impl Constant {
138    /// Returns a new instance of a [Constant]
139    pub const fn new(span: SourceSpan, name: Identifier, value: ConstantExpr) -> Self {
140        Self { span, name, value }
141    }
142
143    /// Gets the type of the value associated with this constant
144    pub fn ty(&self) -> Type {
145        self.value.ty()
146    }
147}
148impl Eq for Constant {}
149impl PartialEq for Constant {
150    fn eq(&self, other: &Self) -> bool {
151        self.name == other.name && self.value == other.value
152    }
153}
154
155/// Value of a constant. Constants can be of 3 value types:
156///
157/// * Scalar: 123
158/// * Vector: \[1, 2, 3\]
159/// * Matrix: \[\[1, 2, 3\], \[4, 5, 6\]\]
160#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
161pub enum ConstantExpr {
162    Scalar(u64),
163    Vector(Vec<u64>),
164    Matrix(Vec<Vec<u64>>),
165}
166impl ConstantExpr {
167    /// Gets the type of this expression
168    pub fn ty(&self) -> Type {
169        match self {
170            Self::Scalar(_) => Type::Felt,
171            Self::Vector(elems) => Type::Vector(elems.len()),
172            Self::Matrix(rows) => {
173                let num_rows = rows.len();
174                let num_cols = rows.first().unwrap().len();
175                Type::Matrix(num_rows, num_cols)
176            }
177        }
178    }
179
180    /// Returns true if this expression is of aggregate type
181    pub fn is_aggregate(&self) -> bool {
182        matches!(self, Self::Vector(_) | Self::Matrix(_))
183    }
184}
185impl fmt::Display for ConstantExpr {
186    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
187        match self {
188            Self::Scalar(value) => write!(f, "{value}"),
189            Self::Vector(values) => {
190                write!(f, "{}", DisplayList(values.as_slice()))
191            }
192            Self::Matrix(values) => write!(
193                f,
194                "{}",
195                DisplayBracketed(DisplayCsv::new(
196                    values.iter().map(|vs| DisplayList(vs.as_slice()))
197                ))
198            ),
199        }
200    }
201}
202
203/// An import declaration
204///
205/// There can be multiple of these in a given module
206#[derive(Debug, Clone)]
207pub enum Import {
208    /// Imports all items from `module`
209    All { module: ModuleId },
210    /// Imports `items` from `module`
211    Partial {
212        module: ModuleId,
213        items: HashSet<Identifier>,
214    },
215}
216impl Import {
217    pub fn module(&self) -> ModuleId {
218        match self {
219            Self::All { module } | Self::Partial { module, .. } => *module,
220        }
221    }
222}
223impl Eq for Import {}
224impl PartialEq for Import {
225    fn eq(&self, other: &Self) -> bool {
226        match (self, other) {
227            (Self::All { module: l }, Self::All { module: r }) => l == r,
228            (
229                Self::Partial {
230                    module: l,
231                    items: ls,
232                },
233                Self::Partial {
234                    module: r,
235                    items: rs,
236                },
237            ) if l == r => ls.difference(rs).next().is_none(),
238            _ => false,
239        }
240    }
241}
242
243/// Represents an item exported from a module
244///
245/// Currently, only constants and functions are exported.
246#[derive(Debug, Copy, Clone, PartialEq, Eq)]
247pub enum Export<'a> {
248    Constant(&'a crate::ast::Constant),
249    Evaluator(&'a EvaluatorFunction),
250}
251impl Export<'_> {
252    pub fn name(&self) -> Identifier {
253        match self {
254            Self::Constant(item) => item.name,
255            Self::Evaluator(item) => item.name,
256        }
257    }
258
259    /// Returns the type of the value associated with this export
260    ///
261    /// NOTE: Evaluator functions have no return value, so they have no type associated.
262    /// For this reason, this function returns `Option<Type>` rather than `Type`.
263    pub fn ty(&self) -> Option<Type> {
264        match self {
265            Self::Constant(item) => Some(item.ty()),
266            Self::Evaluator(_) => None,
267        }
268    }
269}
270
271/// Declaration of a periodic column in an AirScript module.
272///
273/// Periodic columns are columns with repeating cycles of values. The values declared
274/// for the periodic column should be the cycle of values that will be repeated. The
275/// length of the values vector is expected to be a power of 2 with a minimum length of 2,
276/// which is enforced during semantic analysis.
277#[derive(Debug, Clone, Spanned)]
278pub struct PeriodicColumn {
279    #[span]
280    pub span: SourceSpan,
281    pub name: Identifier,
282    pub values: Vec<u64>,
283}
284impl PeriodicColumn {
285    pub const fn new(span: SourceSpan, name: Identifier, values: Vec<u64>) -> Self {
286        Self { span, name, values }
287    }
288
289    pub fn period(&self) -> usize {
290        self.values.len()
291    }
292}
293impl Eq for PeriodicColumn {}
294impl PartialEq for PeriodicColumn {
295    fn eq(&self, other: &Self) -> bool {
296        self.name == other.name && self.values == other.values
297    }
298}
299
300/// Declaration of a public input for an AirScript program.
301///
302/// This declaration is only permitted in the root module.
303///
304/// Public inputs are represented by a named identifier which is used to identify a fixed
305/// size array of length `size`.
306#[derive(Debug, Clone, Spanned)]
307pub enum PublicInput {
308    Vector {
309        #[span]
310        span: SourceSpan,
311        name: Identifier,
312        size: usize,
313    },
314    Table {
315        #[span]
316        span: SourceSpan,
317        name: Identifier,
318        size: usize,
319    },
320}
321impl PublicInput {
322    #[inline]
323    pub fn new_vector(span: SourceSpan, name: Identifier, size: u64) -> Self {
324        Self::Vector {
325            span,
326            name,
327            size: size.try_into().unwrap(),
328        }
329    }
330    #[inline]
331    pub fn new_table(span: SourceSpan, name: Identifier, size: u64) -> Self {
332        Self::Table {
333            span,
334            name,
335            size: size.try_into().unwrap(),
336        }
337    }
338    #[inline]
339    pub fn name(&self) -> Identifier {
340        match self {
341            Self::Vector { name, .. } | Self::Table { name, .. } => *name,
342        }
343    }
344    #[inline]
345    pub fn size(&self) -> usize {
346        match self {
347            Self::Vector { size, .. } | Self::Table { size, .. } => *size,
348        }
349    }
350}
351impl Eq for PublicInput {}
352impl PartialEq for PublicInput {
353    fn eq(&self, other: &Self) -> bool {
354        match (self, other) {
355            (
356                Self::Vector {
357                    name: l, size: ls, ..
358                },
359                Self::Vector {
360                    name: r, size: rs, ..
361                },
362            ) => l == r && ls == rs,
363            (
364                Self::Table {
365                    name: l, size: lc, ..
366                },
367                Self::Table {
368                    name: r, size: rc, ..
369                },
370            ) => l == r && lc == rc,
371            _ => false,
372        }
373    }
374}
375
376/// Evaluator functions take a vector of trace bindings as parameters where each trace binding
377/// represents one or a group of columns in the execution trace that are passed to the evaluator
378/// function, and enforce integrity constraints on those trace columns.
379#[derive(Debug, Clone, Spanned)]
380pub struct EvaluatorFunction {
381    #[span]
382    pub span: SourceSpan,
383    pub name: Identifier,
384    pub params: Vec<TraceSegment>,
385    pub body: Vec<Statement>,
386}
387impl EvaluatorFunction {
388    /// Creates a new function.
389    pub const fn new(
390        span: SourceSpan,
391        name: Identifier,
392        params: Vec<TraceSegment>,
393        body: Vec<Statement>,
394    ) -> Self {
395        Self {
396            span,
397            name,
398            params,
399            body,
400        }
401    }
402}
403impl Eq for EvaluatorFunction {}
404impl PartialEq for EvaluatorFunction {
405    fn eq(&self, other: &Self) -> bool {
406        self.name == other.name && self.params == other.params && self.body == other.body
407    }
408}
409
410/// Functions take a group of expressions as parameters and returns a value.
411///
412/// The result value of a function may be a felt, vector, or a matrix.
413///
414/// NOTE: Functions do not take trace bindings as parameters.
415#[derive(Debug, Clone, Spanned)]
416pub struct Function {
417    #[span]
418    pub span: SourceSpan,
419    pub name: Identifier,
420    pub params: Vec<(Identifier, Type)>,
421    pub return_type: Type,
422    pub body: Vec<Statement>,
423}
424impl Function {
425    /// Creates a new function.
426    pub const fn new(
427        span: SourceSpan,
428        name: Identifier,
429        params: Vec<(Identifier, Type)>,
430        return_type: Type,
431        body: Vec<Statement>,
432    ) -> Self {
433        Self {
434            span,
435            name,
436            params,
437            return_type,
438            body,
439        }
440    }
441
442    pub fn param_types(&self) -> Vec<Type> {
443        self.params.iter().map(|(_, ty)| *ty).collect::<Vec<_>>()
444    }
445}
446
447impl Eq for Function {}
448impl PartialEq for Function {
449    fn eq(&self, other: &Self) -> bool {
450        self.name == other.name
451            && self.params == other.params
452            && self.return_type == other.return_type
453            && self.body == other.body
454    }
455}