use std::{collections::HashSet, fmt};
use miden_diagnostics::{SourceSpan, Spanned};
use super::*;
/// A top-level declaration in a module.
///
/// Groups of related items (buses, periodic columns, public inputs, trace
/// segments, boundary and integrity constraints) are stored as a single
/// `Span`-wrapped `Vec`, while constants, evaluator functions and functions
/// are stored one declaration per variant.
#[derive(Debug, PartialEq, Eq, Spanned)]
pub enum Declaration {
    /// An import of another module.
    Import(Span<Import>),
    /// A group of bus declarations.
    Buses(Span<Vec<Bus>>),
    /// A single named constant.
    Constant(Constant),
    /// A single evaluator function.
    EvaluatorFunction(EvaluatorFunction),
    /// A single function.
    Function(Function),
    /// A group of periodic column declarations.
    PeriodicColumns(Span<Vec<PeriodicColumn>>),
    /// A group of public input declarations.
    PublicInputs(Span<Vec<PublicInput>>),
    /// The trace segment declarations.
    Trace(Span<Vec<TraceSegment>>),
    /// The statements making up the boundary constraints.
    BoundaryConstraints(Span<Vec<Statement>>),
    /// The statements making up the integrity constraints.
    IntegrityConstraints(Span<Vec<Statement>>),
}
/// A named bus declaration of a particular [`BusType`].
#[derive(Debug, Clone, Spanned)]
pub struct Bus {
    /// Source location of the declaration.
    #[span]
    pub span: SourceSpan,
    /// Name the bus is declared under.
    pub name: Identifier,
    /// Kind of bus being declared.
    pub bus_type: BusType,
}

impl Bus {
    /// Builds a bus declaration from its parts.
    pub const fn new(span: SourceSpan, name: Identifier, bus_type: BusType) -> Self {
        Self { span, name, bus_type }
    }
}
/// The kind of a bus. `Multiset` is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BusType {
    /// Multiset bus; returned by `BusType::default()`.
    #[default]
    Multiset,
    /// Logup bus.
    Logup,
}
/// An operation applied to a bus.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BusOperator {
    /// Rendered as `insert`.
    Insert,
    /// Rendered as `remove`.
    Remove,
}

impl fmt::Display for BusOperator {
    /// Writes the operator's lowercase name: `insert` or `remove`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            Self::Insert => "insert",
            Self::Remove => "remove",
        };
        f.write_str(text)
    }
}
impl Eq for Bus {}

// Buses compare by name and bus type only; the source span is deliberately ignored.
impl PartialEq for Bus {
    fn eq(&self, other: &Self) -> bool {
        (&self.name, &self.bus_type) == (&other.name, &other.bus_type)
    }
}
#[derive(Debug, Clone, Spanned)]
pub struct Constant {
#[span]
pub span: SourceSpan,
pub name: Identifier,
pub value: ConstantExpr,
}
impl Constant {
pub const fn new(span: SourceSpan, name: Identifier, value: ConstantExpr) -> Self {
Self { span, name, value }
}
pub fn ty(&self) -> Type {
self.value.ty()
}
}
impl Eq for Constant {}
impl PartialEq for Constant {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.value == other.value
}
}
/// The value of a constant: a scalar, a vector, or a matrix of `u64`s.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum ConstantExpr {
    /// A single value.
    Scalar(u64),
    /// A one-dimensional list of values.
    Vector(Vec<u64>),
    /// A list of rows; each inner `Vec` is one row.
    Matrix(Vec<Vec<u64>>),
}

impl ConstantExpr {
    /// Returns the type of this value.
    ///
    /// For a matrix, the row count is the number of rows and the column count
    /// is taken from the first row. Row lengths are not checked for consistency
    /// here. An empty matrix yields `Matrix(0, 0)` instead of panicking.
    pub fn ty(&self) -> Type {
        match self {
            Self::Scalar(_) => Type::Felt,
            Self::Vector(elems) => Type::Vector(elems.len()),
            Self::Matrix(rows) => {
                let num_rows = rows.len();
                // With no rows there is no first row to measure; report zero
                // columns rather than unwrapping `None`.
                let num_cols = rows.first().map_or(0, Vec::len);
                Type::Matrix(num_rows, num_cols)
            }
        }
    }

    /// Returns `true` for vector and matrix values, `false` for scalars.
    pub fn is_aggregate(&self) -> bool {
        matches!(self, Self::Vector(_) | Self::Matrix(_))
    }
}
// Renders a constant value. Scalars print as their decimal value. Vectors go
// through `DisplayList`. Matrices print each row as a `DisplayList`, joined with
// `DisplayCsv` and wrapped in `DisplayBracketed`.
impl fmt::Display for ConstantExpr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Scalar(value) => write!(f, "{value}"),
            Self::Vector(values) => write!(f, "{}", DisplayList(values.as_slice())),
            Self::Matrix(rows) => {
                let rendered_rows = rows.iter().map(|row| DisplayList(row.as_slice()));
                write!(f, "{}", DisplayBracketed(DisplayCsv::new(rendered_rows)))
            }
        }
    }
}
/// An import of another module.
///
/// `All` carries only the imported module. `Partial` also carries the set of
/// item names it lists.
#[derive(Debug, Clone)]
pub enum Import {
    /// Import referring to `module` as a whole.
    All { module: ModuleId },
    /// Import of the specific `items` from `module`.
    Partial {
        module: ModuleId,
        items: HashSet<Identifier>,
    },
}

impl Import {
    /// Returns the module this import refers to, whichever variant it is.
    pub fn module(&self) -> ModuleId {
        let (Self::All { module } | Self::Partial { module, .. }) = self;
        *module
    }
}
impl Eq for Import {}

// Two imports are equal when they are the same variant over the same module.
// For `Partial`, they must also name exactly the same set of items. Item order
// is irrelevant because `items` is a `HashSet`.
impl PartialEq for Import {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::All { module: l }, Self::All { module: r }) => l == r,
            (
                Self::Partial {
                    module: l,
                    items: ls,
                },
                Self::Partial {
                    module: r,
                    items: rs,
                },
            ) => {
                // Full set equality. A one-way `ls.difference(rs)` emptiness test
                // only proves `ls ⊆ rs`, which made `eq` asymmetric (a ⊂ b gave
                // a == b but b != a) and violated the `Eq` contract.
                l == r && ls == rs
            }
            _ => false,
        }
    }
}
/// A borrowed reference to an exportable item: a constant or an evaluator function.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum Export<'a> {
    Constant(&'a crate::ast::Constant),
    Evaluator(&'a EvaluatorFunction),
}

impl Export<'_> {
    /// Returns the name of the referenced item.
    pub fn name(&self) -> Identifier {
        match *self {
            Self::Constant(constant) => constant.name,
            Self::Evaluator(evaluator) => evaluator.name,
        }
    }

    /// Returns the item's type: `Some` for constants, `None` for evaluators.
    pub fn ty(&self) -> Option<Type> {
        if let Self::Constant(constant) = self {
            Some(constant.ty())
        } else {
            None
        }
    }
}
/// A named periodic column defined by its list of `values`.
///
/// Equality compares `name` and `values`; `span` is ignored.
#[derive(Debug, Clone, Spanned)]
pub struct PeriodicColumn {
    /// Source location of the declaration.
    #[span]
    pub span: SourceSpan,
    /// Name of the column.
    pub name: Identifier,
    /// The column's values; their count is the column's period.
    pub values: Vec<u64>,
}

impl PeriodicColumn {
    /// Creates a periodic column from its parts.
    pub const fn new(span: SourceSpan, name: Identifier, values: Vec<u64>) -> Self {
        Self { span, name, values }
    }

    /// Returns the period of the column, i.e. the number of values it holds.
    pub fn period(&self) -> usize {
        let Self { values, .. } = self;
        values.len()
    }
}

impl Eq for PeriodicColumn {}

impl PartialEq for PeriodicColumn {
    fn eq(&self, other: &Self) -> bool {
        (&self.name, &self.values) == (&other.name, &other.values)
    }
}
/// A public input declaration.
///
/// Both variants record their source `span`, a `name` and a `size`. The
/// variant tells a vector input apart from a table input.
#[derive(Debug, Clone, Spanned)]
pub enum PublicInput {
    /// A vector public input.
    Vector {
        #[span]
        span: SourceSpan,
        name: Identifier,
        // NOTE(review): presumably the number of elements — confirm.
        size: usize,
    },
    /// A table public input.
    Table {
        #[span]
        span: SourceSpan,
        name: Identifier,
        // NOTE(review): what `size` measures for a table (e.g. columns vs. rows)
        // is not established here — confirm.
        size: usize,
    },
}
impl PublicInput {
    /// Creates a vector public input with the given `size`.
    ///
    /// # Panics
    /// Panics if `size` does not fit in a `usize`.
    #[inline]
    pub fn new_vector(span: SourceSpan, name: Identifier, size: u64) -> Self {
        let size = usize::try_from(size).unwrap();
        Self::Vector { span, name, size }
    }

    /// Creates a table public input with the given `size`.
    ///
    /// # Panics
    /// Panics if `size` does not fit in a `usize`.
    #[inline]
    pub fn new_table(span: SourceSpan, name: Identifier, size: u64) -> Self {
        let size = usize::try_from(size).unwrap();
        Self::Table { span, name, size }
    }

    /// Returns the declared name, whichever variant this is.
    #[inline]
    pub fn name(&self) -> Identifier {
        let (Self::Vector { name, .. } | Self::Table { name, .. }) = self;
        *name
    }

    /// Returns the declared size, whichever variant this is.
    #[inline]
    pub fn size(&self) -> usize {
        let (Self::Vector { size, .. } | Self::Table { size, .. }) = self;
        *size
    }
}
impl Eq for PublicInput {}

// Public inputs are equal when they are the same variant with the same name and
// size; spans are ignored. Mixed variants never compare equal.
impl PartialEq for PublicInput {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (
                Self::Vector {
                    name: lhs_name,
                    size: lhs_size,
                    ..
                },
                Self::Vector {
                    name: rhs_name,
                    size: rhs_size,
                    ..
                },
            )
            | (
                Self::Table {
                    name: lhs_name,
                    size: lhs_size,
                    ..
                },
                Self::Table {
                    name: rhs_name,
                    size: rhs_size,
                    ..
                },
            ) => lhs_name == rhs_name && lhs_size == rhs_size,
            _ => false,
        }
    }
}
/// An evaluator function: a name, trace-segment parameters, and a body of statements.
///
/// Equality compares `name`, `params` and `body`; `span` is ignored.
#[derive(Debug, Clone, Spanned)]
pub struct EvaluatorFunction {
    /// Source location of the definition.
    #[span]
    pub span: SourceSpan,
    /// Name of the evaluator.
    pub name: Identifier,
    /// Parameters, each declared as a trace segment.
    pub params: Vec<TraceSegment>,
    /// Statements making up the evaluator body.
    pub body: Vec<Statement>,
}

impl EvaluatorFunction {
    /// Creates an evaluator function from its parts.
    pub const fn new(
        span: SourceSpan,
        name: Identifier,
        params: Vec<TraceSegment>,
        body: Vec<Statement>,
    ) -> Self {
        Self { span, name, params, body }
    }
}

impl Eq for EvaluatorFunction {}

impl PartialEq for EvaluatorFunction {
    fn eq(&self, other: &Self) -> bool {
        (&self.name, &self.params, &self.body) == (&other.name, &other.params, &other.body)
    }
}
/// A function: a name, typed parameters, a return type and a body of statements.
///
/// Equality compares everything except `span`.
#[derive(Debug, Clone, Spanned)]
pub struct Function {
    /// Source location of the definition.
    #[span]
    pub span: SourceSpan,
    /// Name of the function.
    pub name: Identifier,
    /// Parameters as `(name, type)` pairs, in declaration order.
    pub params: Vec<(Identifier, Type)>,
    /// Declared return type.
    pub return_type: Type,
    /// Statements making up the function body.
    pub body: Vec<Statement>,
}

impl Function {
    /// Creates a function from its parts.
    pub const fn new(
        span: SourceSpan,
        name: Identifier,
        params: Vec<(Identifier, Type)>,
        return_type: Type,
        body: Vec<Statement>,
    ) -> Self {
        Self { span, name, params, return_type, body }
    }

    /// Returns the parameter types in declaration order, without their names.
    pub fn param_types(&self) -> Vec<Type> {
        self.params.iter().map(|&(_, ty)| ty).collect()
    }
}

impl Eq for Function {}

impl PartialEq for Function {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (&self.name, &self.params, &self.return_type, &self.body);
        let rhs = (&other.name, &other.params, &other.return_type, &other.body);
        lhs == rhs
    }
}