1use std::{collections::HashSet, fmt};
27
28use miden_diagnostics::{SourceSpan, Spanned};
29
30use super::*;
31
32#[derive(Debug, PartialEq, Eq, Spanned)]
34pub enum Declaration {
35 Import(Span<Import>),
37 Buses(Span<Vec<Bus>>),
39 Constant(Constant),
41 EvaluatorFunction(EvaluatorFunction),
45 Function(Function),
49 PeriodicColumns(Span<Vec<PeriodicColumn>>),
53 PublicInputs(Span<Vec<PublicInput>>),
58 Trace(Span<Vec<TraceSegment>>),
63 BoundaryConstraints(Span<Vec<Statement>>),
68 IntegrityConstraints(Span<Vec<Statement>>),
73}
74
75#[derive(Debug, Clone, Spanned)]
77pub struct Bus {
78 #[span]
79 pub span: SourceSpan,
80 pub name: Identifier,
81 pub bus_type: BusType,
82}
83impl Bus {
84 pub const fn new(span: SourceSpan, name: Identifier, bus_type: BusType) -> Self {
86 Self {
87 span,
88 name,
89 bus_type,
90 }
91 }
92}
/// The kind of a bus. [`BusType::Multiset`] is the default.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum BusType {
    /// Multiset bus (the `Default` variant).
    #[default]
    Multiset,
    /// Logup bus.
    Logup,
}
101
/// An operation performed on a bus.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BusOperator {
    /// Displayed as `insert`.
    Insert,
    /// Displayed as `remove`.
    Remove,
}
impl fmt::Display for BusOperator {
    /// Writes the lowercase keyword for the operator. Width/fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let keyword = match self {
            BusOperator::Insert => "insert",
            BusOperator::Remove => "remove",
        };
        f.write_str(keyword)
    }
}
117
118impl Eq for Bus {}
119impl PartialEq for Bus {
120 fn eq(&self, other: &Self) -> bool {
121 self.name == other.name && self.bus_type == other.bus_type
122 }
123}
124
125#[derive(Debug, Clone, Spanned)]
131pub struct Constant {
132 #[span]
133 pub span: SourceSpan,
134 pub name: Identifier,
135 pub value: ConstantExpr,
136}
137impl Constant {
138 pub const fn new(span: SourceSpan, name: Identifier, value: ConstantExpr) -> Self {
140 Self { span, name, value }
141 }
142
143 pub fn ty(&self) -> Type {
145 self.value.ty()
146 }
147}
148impl Eq for Constant {}
149impl PartialEq for Constant {
150 fn eq(&self, other: &Self) -> bool {
151 self.name == other.name && self.value == other.value
152 }
153}
154
155#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
161pub enum ConstantExpr {
162 Scalar(u64),
163 Vector(Vec<u64>),
164 Matrix(Vec<Vec<u64>>),
165}
166impl ConstantExpr {
167 pub fn ty(&self) -> Type {
169 match self {
170 Self::Scalar(_) => Type::Felt,
171 Self::Vector(elems) => Type::Vector(elems.len()),
172 Self::Matrix(rows) => {
173 let num_rows = rows.len();
174 let num_cols = rows.first().unwrap().len();
175 Type::Matrix(num_rows, num_cols)
176 }
177 }
178 }
179
180 pub fn is_aggregate(&self) -> bool {
182 matches!(self, Self::Vector(_) | Self::Matrix(_))
183 }
184}
185impl fmt::Display for ConstantExpr {
186 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
187 match self {
188 Self::Scalar(value) => write!(f, "{value}"),
189 Self::Vector(values) => {
190 write!(f, "{}", DisplayList(values.as_slice()))
191 }
192 Self::Matrix(values) => write!(
193 f,
194 "{}",
195 DisplayBracketed(DisplayCsv::new(
196 values.iter().map(|vs| DisplayList(vs.as_slice()))
197 ))
198 ),
199 }
200 }
201}
202
203#[derive(Debug, Clone)]
207pub enum Import {
208 All { module: ModuleId },
210 Partial {
212 module: ModuleId,
213 items: HashSet<Identifier>,
214 },
215}
216impl Import {
217 pub fn module(&self) -> ModuleId {
218 match self {
219 Self::All { module } | Self::Partial { module, .. } => *module,
220 }
221 }
222}
223impl Eq for Import {}
224impl PartialEq for Import {
225 fn eq(&self, other: &Self) -> bool {
226 match (self, other) {
227 (Self::All { module: l }, Self::All { module: r }) => l == r,
228 (
229 Self::Partial {
230 module: l,
231 items: ls,
232 },
233 Self::Partial {
234 module: r,
235 items: rs,
236 },
237 ) if l == r => ls.difference(rs).next().is_none(),
238 _ => false,
239 }
240 }
241}
242
243#[derive(Debug, Copy, Clone, PartialEq, Eq)]
247pub enum Export<'a> {
248 Constant(&'a crate::ast::Constant),
249 Evaluator(&'a EvaluatorFunction),
250}
251impl Export<'_> {
252 pub fn name(&self) -> Identifier {
253 match self {
254 Self::Constant(item) => item.name,
255 Self::Evaluator(item) => item.name,
256 }
257 }
258
259 pub fn ty(&self) -> Option<Type> {
264 match self {
265 Self::Constant(item) => Some(item.ty()),
266 Self::Evaluator(_) => None,
267 }
268 }
269}
270
271#[derive(Debug, Clone, Spanned)]
278pub struct PeriodicColumn {
279 #[span]
280 pub span: SourceSpan,
281 pub name: Identifier,
282 pub values: Vec<u64>,
283}
284impl PeriodicColumn {
285 pub const fn new(span: SourceSpan, name: Identifier, values: Vec<u64>) -> Self {
286 Self { span, name, values }
287 }
288
289 pub fn period(&self) -> usize {
290 self.values.len()
291 }
292}
293impl Eq for PeriodicColumn {}
294impl PartialEq for PeriodicColumn {
295 fn eq(&self, other: &Self) -> bool {
296 self.name == other.name && self.values == other.values
297 }
298}
299
300#[derive(Debug, Clone, Spanned)]
307pub enum PublicInput {
308 Vector {
309 #[span]
310 span: SourceSpan,
311 name: Identifier,
312 size: usize,
313 },
314 Table {
315 #[span]
316 span: SourceSpan,
317 name: Identifier,
318 size: usize,
319 },
320}
321impl PublicInput {
322 #[inline]
323 pub fn new_vector(span: SourceSpan, name: Identifier, size: u64) -> Self {
324 Self::Vector {
325 span,
326 name,
327 size: size.try_into().unwrap(),
328 }
329 }
330 #[inline]
331 pub fn new_table(span: SourceSpan, name: Identifier, size: u64) -> Self {
332 Self::Table {
333 span,
334 name,
335 size: size.try_into().unwrap(),
336 }
337 }
338 #[inline]
339 pub fn name(&self) -> Identifier {
340 match self {
341 Self::Vector { name, .. } | Self::Table { name, .. } => *name,
342 }
343 }
344 #[inline]
345 pub fn size(&self) -> usize {
346 match self {
347 Self::Vector { size, .. } | Self::Table { size, .. } => *size,
348 }
349 }
350}
351impl Eq for PublicInput {}
352impl PartialEq for PublicInput {
353 fn eq(&self, other: &Self) -> bool {
354 match (self, other) {
355 (
356 Self::Vector {
357 name: l, size: ls, ..
358 },
359 Self::Vector {
360 name: r, size: rs, ..
361 },
362 ) => l == r && ls == rs,
363 (
364 Self::Table {
365 name: l, size: lc, ..
366 },
367 Self::Table {
368 name: r, size: rc, ..
369 },
370 ) => l == r && lc == rc,
371 _ => false,
372 }
373 }
374}
375
376#[derive(Debug, Clone, Spanned)]
380pub struct EvaluatorFunction {
381 #[span]
382 pub span: SourceSpan,
383 pub name: Identifier,
384 pub params: Vec<TraceSegment>,
385 pub body: Vec<Statement>,
386}
387impl EvaluatorFunction {
388 pub const fn new(
390 span: SourceSpan,
391 name: Identifier,
392 params: Vec<TraceSegment>,
393 body: Vec<Statement>,
394 ) -> Self {
395 Self {
396 span,
397 name,
398 params,
399 body,
400 }
401 }
402}
403impl Eq for EvaluatorFunction {}
404impl PartialEq for EvaluatorFunction {
405 fn eq(&self, other: &Self) -> bool {
406 self.name == other.name && self.params == other.params && self.body == other.body
407 }
408}
409
410#[derive(Debug, Clone, Spanned)]
416pub struct Function {
417 #[span]
418 pub span: SourceSpan,
419 pub name: Identifier,
420 pub params: Vec<(Identifier, Type)>,
421 pub return_type: Type,
422 pub body: Vec<Statement>,
423}
424impl Function {
425 pub const fn new(
427 span: SourceSpan,
428 name: Identifier,
429 params: Vec<(Identifier, Type)>,
430 return_type: Type,
431 body: Vec<Statement>,
432 ) -> Self {
433 Self {
434 span,
435 name,
436 params,
437 return_type,
438 body,
439 }
440 }
441
442 pub fn param_types(&self) -> Vec<Type> {
443 self.params.iter().map(|(_, ty)| *ty).collect::<Vec<_>>()
444 }
445}
446
447impl Eq for Function {}
448impl PartialEq for Function {
449 fn eq(&self, other: &Self) -> bool {
450 self.name == other.name
451 && self.params == other.params
452 && self.return_type == other.return_type
453 && self.body == other.body
454 }
455}