use std::{convert::AsRef, fmt};
use miden_diagnostics::{SourceSpan, Span, Spanned};
use crate::symbols::Symbol;
use super::*;
/// A half-open `usize` range; the concrete form of a [`RangeExpr`] whose bounds are both
/// constant (see `TryFrom<&RangeExpr> for Range`).
pub type Range = std::ops::Range<usize>;
/// A name appearing in the source: an interned [`Symbol`] paired with the span it came from.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Spanned)]
pub struct Identifier(pub Span<Symbol>);
impl Identifier {
    /// Creates an identifier for `name`, located at `span`.
    pub fn new(span: SourceSpan, name: Symbol) -> Self {
        let spanned = Span::new(span, name);
        Identifier(spanned)
    }

    /// Returns the interned symbol naming this identifier.
    pub fn name(&self) -> Symbol {
        let Self(spanned) = self;
        spanned.item
    }

    /// Returns the identifier's text.
    #[inline]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Returns true when every character satisfies `char::is_uppercase`.
    ///
    /// Characters such as `_` or digits are not uppercase, so they make this return false;
    /// an empty name is vacuously uppercase.
    pub fn is_uppercase(&self) -> bool {
        !self.as_str().chars().any(|c| !c.is_uppercase())
    }

    /// Returns true if the name begins with `%` (presumably a generated name — confirm).
    pub fn is_generated(&self) -> bool {
        self.as_str().starts_with('%')
    }

    /// Returns true if the name begins with `$`, like the `$builtin` module used by [`Call`].
    pub fn is_special(&self) -> bool {
        self.as_str().starts_with('$')
    }
}
/// Compares the identifier's symbol against a string slice (the span is ignored).
impl PartialEq<&str> for Identifier {
    #[inline]
    fn eq(&self, other: &&str) -> bool {
        self.name() == *other
    }
}
/// Allows comparing an identifier against a reference to another identifier.
impl PartialEq<&Identifier> for Identifier {
    #[inline]
    fn eq(&self, other: &&Self) -> bool {
        <Self as PartialEq<Self>>::eq(self, *other)
    }
}
impl fmt::Debug for Identifier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Render the symbol through `Display`, then debug-print that string.
        let rendered = self.0.item.to_string();
        f.debug_tuple("Identifier").field(&rendered).finish()
    }
}
impl fmt::Display for Identifier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.0)
    }
}
impl From<ResolvableIdentifier> for Identifier {
fn from(id: ResolvableIdentifier) -> Self {
match id {
ResolvableIdentifier::Local(id) => id,
ResolvableIdentifier::Global(id) => id,
ResolvableIdentifier::Resolved(qid) => qid.item.id(),
ResolvableIdentifier::Unresolved(nid) => nid.id(),
}
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)]
pub enum NamespacedIdentifier {
Function(#[span] Identifier),
Binding(#[span] Identifier),
}
impl NamespacedIdentifier {
pub fn id(&self) -> Identifier {
match self {
Self::Function(ident) | Self::Binding(ident) => *ident,
}
}
}
impl AsRef<Identifier> for NamespacedIdentifier {
fn as_ref(&self) -> &Identifier {
match self {
Self::Function(ident) | Self::Binding(ident) => ident,
}
}
}
impl From<ResolvableIdentifier> for NamespacedIdentifier {
fn from(id: ResolvableIdentifier) -> Self {
match id {
ResolvableIdentifier::Local(id) => Self::Binding(id),
ResolvableIdentifier::Global(id) => Self::Binding(id),
ResolvableIdentifier::Resolved(qid) => qid.item,
ResolvableIdentifier::Unresolved(nid) => nid,
}
}
}
impl fmt::Display for NamespacedIdentifier {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(self.as_ref(), f)
}
}
/// A namespaced identifier qualified by the module that defines it.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)]
pub struct QualifiedIdentifier {
    pub module: ModuleId,
    #[span]
    pub item: NamespacedIdentifier,
}
impl QualifiedIdentifier {
    /// Qualifies `item` with `module`.
    pub const fn new(module: ModuleId, item: NamespacedIdentifier) -> Self {
        QualifiedIdentifier { item, module }
    }

    /// Returns the unqualified, namespaced part of this identifier.
    pub const fn id(&self) -> NamespacedIdentifier {
        self.item
    }

    /// Returns the symbol of the item (not of the module).
    #[inline]
    pub fn name(&self) -> Symbol {
        self.item.id().name()
    }

    /// Returns true for the `sum` and `prod` functions of the `$builtin` module.
    pub fn is_builtin(&self) -> bool {
        use crate::symbols;
        // Only items of the `$builtin` module (see `Call::new_builtin`) can be builtins.
        if self.module.name() != "$builtin" {
            return false;
        }
        matches!(
            self.item,
            NamespacedIdentifier::Function(id) if matches!(id.name(), symbols::Sum | symbols::Prod)
        )
    }
}
impl AsRef<Identifier> for QualifiedIdentifier {
    #[inline]
    fn as_ref(&self) -> &Identifier {
        match &self.item {
            NamespacedIdentifier::Function(ident) | NamespacedIdentifier::Binding(ident) => ident,
        }
    }
}
impl fmt::Display for QualifiedIdentifier {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{module}::{item}", module = self.module, item = self.item)
    }
}
/// An identifier that may or may not have been bound to a definition yet.
///
/// The constructors in this module (e.g. [`SymbolAccess::new`], [`Call::new`] for
/// non-builtin callees) create `Unresolved` identifiers; builtin calls are created
/// directly as `Resolved`. The other variants are presumably produced by a later
/// name-resolution pass that is not visible here.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Spanned)]
pub enum ResolvableIdentifier {
    /// Presumably a reference to a locally-scoped binding — TODO confirm against the resolver.
    Local(#[span] Identifier),
    /// Presumably a reference to a module-level (global) binding — TODO confirm.
    Global(#[span] Identifier),
    /// An identifier qualified with the module that defines it.
    Resolved(#[span] QualifiedIdentifier),
    /// An identifier not yet resolved, tagged with the namespace to look it up in.
    Unresolved(#[span] NamespacedIdentifier),
}
impl ResolvableIdentifier {
    /// Returns true unless this identifier is still [`ResolvableIdentifier::Unresolved`].
    #[inline]
    pub fn is_resolved(&self) -> bool {
        matches!(self, Self::Local(_) | Self::Global(_) | Self::Resolved(_))
    }

    /// Returns true if this is a [`ResolvableIdentifier::Local`] identifier.
    pub fn is_local(&self) -> bool {
        matches!(self, Self::Local(_))
    }

    /// Returns true if this is a [`ResolvableIdentifier::Global`] identifier.
    pub fn is_global(&self) -> bool {
        matches!(self, Self::Global(_))
    }

    /// Returns true if this identifier resolved to a builtin function
    /// (see [`QualifiedIdentifier::is_builtin`]).
    pub fn is_builtin(&self) -> bool {
        match self {
            Self::Resolved(qid) => qid.is_builtin(),
            _ => false,
        }
    }

    /// Returns the module of a [`ResolvableIdentifier::Resolved`] identifier, or `None` for
    /// every other variant.
    pub fn module(&self) -> Option<ModuleId> {
        match self {
            // Read the `module` field directly. `qid.as_ref()` goes through
            // `AsRef<Identifier> for QualifiedIdentifier`, which yields the *item*
            // identifier, not the module.
            Self::Resolved(qid) => Some(qid.module),
            _ => None,
        }
    }

    /// Converts this identifier to its namespaced form; local and global identifiers map
    /// to the binding namespace.
    #[inline]
    pub fn namespaced(&self) -> NamespacedIdentifier {
        (*self).into()
    }

    /// Returns the qualified identifier if this identifier is `Resolved`.
    #[inline]
    pub fn resolved(&self) -> Option<QualifiedIdentifier> {
        match self {
            Self::Resolved(qid) => Some(*qid),
            _ => None,
        }
    }
}
/// Borrows the bare identifier, ignoring any namespace or module qualification.
impl AsRef<Identifier> for ResolvableIdentifier {
    #[inline]
    fn as_ref(&self) -> &Identifier {
        match self {
            Self::Local(ident) | Self::Global(ident) => ident,
            Self::Resolved(qualified) => qualified.item.as_ref(),
            Self::Unresolved(namespaced) => namespaced.as_ref(),
        }
    }
}
impl fmt::Display for ResolvableIdentifier {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Resolved identifiers render with their `module::` prefix; the rest render bare.
        match self {
            Self::Local(ident) | Self::Global(ident) => write!(f, "{ident}"),
            Self::Resolved(qualified) => write!(f, "{qualified}"),
            Self::Unresolved(namespaced) => write!(f, "{namespaced}"),
        }
    }
}
/// An expression, which may be scalar-, vector-, or matrix-valued (see [`Expr::ty`]).
#[derive(Clone, PartialEq, Eq, Spanned)]
pub enum Expr {
    /// A constant value.
    Const(Span<ConstantExpr>),
    /// A range expression, `start..end`.
    Range(RangeExpr),
    /// A vector literal.
    Vector(Span<Vec<Expr>>),
    /// A matrix literal, stored as a list of rows of scalar expressions.
    Matrix(Span<Vec<Vec<ScalarExpr>>>),
    /// An access to a named symbol, possibly indexed or sliced.
    SymbolAccess(SymbolAccess),
    /// A binary operation on two scalar operands.
    Binary(BinaryExpr),
    /// A function call.
    Call(Call),
    /// A list comprehension.
    ListComprehension(ListComprehension),
    /// A `let` expression.
    Let(Box<Let>),
    /// An operation on a bus.
    BusOperation(BusOperation),
    /// The `null` literal; typed as `Felt` by [`Expr::ty`].
    Null(Span<()>),
    /// The `unconstrained` literal; typed as `Felt` by [`Expr::ty`].
    Unconstrained(Span<()>),
}
impl Expr {
    /// Returns true if this is a constant, or a range whose bounds are both constant.
    pub fn is_constant(&self) -> bool {
        match self {
            Self::Const(_) => true,
            Self::Range(range) => range.is_constant(),
            _ => false,
        }
    }

    /// Returns the type of this expression, if it can be determined.
    ///
    /// Vectors are typed from their first element only: a vector of felts is a `Vector`, a
    /// vector of vectors is a `Matrix`, and an empty vector is `Vector(0)`. An empty matrix
    /// is typed `Matrix(0, 0)`.
    pub fn ty(&self) -> Option<Type> {
        match self {
            Self::Const(constant) => Some(constant.ty()),
            Self::Range(range) => range.ty(),
            Self::Vector(vector) => match vector.first().and_then(|e| e.ty()) {
                Some(Type::Felt) => Some(Type::Vector(vector.len())),
                Some(Type::Vector(n)) => Some(Type::Matrix(vector.len(), n)),
                Some(_) => None,
                None => Some(Type::Vector(0)),
            },
            Self::Matrix(matrix) => {
                let rows = matrix.len();
                // An empty matrix has no first row; report zero columns instead of
                // panicking on `matrix[0]`.
                let cols = matrix.first().map_or(0, Vec::len);
                Some(Type::Matrix(rows, cols))
            }
            Self::SymbolAccess(access) => access.ty,
            Self::Binary(_) => Some(Type::Felt),
            Self::Call(call) => call.ty,
            Self::ListComprehension(lc) => lc.ty,
            Self::Let(let_expr) => let_expr.ty(),
            Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => Some(Type::Felt),
        }
    }
}
impl fmt::Debug for Expr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Every variant except `Let` prints as a one-field tuple named after the variant.
        let (variant, payload): (&str, &dyn fmt::Debug) = match self {
            // `Let` bypasses the tuple wrapper and always pretty-prints itself.
            Self::Let(let_expr) => return write!(f, "{let_expr:#?}"),
            Self::Const(constant) => ("Const", &constant.item),
            Self::Range(range) => ("Range", range),
            Self::Vector(items) => ("Vector", &items.item),
            Self::Matrix(rows) => ("Matrix", &rows.item),
            Self::SymbolAccess(access) => ("SymbolAccess", access),
            Self::Binary(binary) => ("Binary", binary),
            Self::Call(call) => ("Call", call),
            Self::ListComprehension(lc) => ("ListComprehension", lc),
            Self::BusOperation(op) => ("BusOp", op),
            Self::Null(spanned) => ("Null", spanned),
            Self::Unconstrained(spanned) => ("Unconstrained", spanned),
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}
impl fmt::Display for Expr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Const(constant) => write!(f, "{constant}"),
            Self::Range(range) => write!(f, "{range}"),
            Self::Vector(items) => write!(f, "{}", DisplayList(items.as_slice())),
            Self::Matrix(rows) => {
                // Rows are rendered as lists, separated by ", " inside outer brackets.
                f.write_str("[")?;
                for (idx, row) in rows.iter().enumerate() {
                    let separator = if idx == 0 { "" } else { ", " };
                    write!(f, "{separator}{}", DisplayList(row.as_slice()))?;
                }
                f.write_str("]")
            }
            Self::SymbolAccess(access) => write!(f, "{access}"),
            Self::Binary(binary) => write!(f, "{binary}"),
            Self::Call(call) => write!(f, "{call}"),
            Self::ListComprehension(lc) => write!(f, "{}", DisplayBracketed(lc)),
            Self::Let(let_expr) => write!(
                f,
                "{}",
                DisplayLet {
                    let_expr,
                    indent: 0,
                    in_expr_position: true,
                }
            ),
            Self::BusOperation(op) => write!(f, "{op}"),
            Self::Null(_) => f.write_str("null"),
            Self::Unconstrained(_) => f.write_str("unconstrained"),
        }
    }
}
impl From<SymbolAccess> for Expr {
#[inline]
fn from(expr: SymbolAccess) -> Self {
Self::SymbolAccess(expr)
}
}
impl From<BinaryExpr> for Expr {
#[inline]
fn from(expr: BinaryExpr) -> Self {
Self::Binary(expr)
}
}
impl From<Call> for Expr {
#[inline]
fn from(expr: Call) -> Self {
Self::Call(expr)
}
}
impl From<BusOperation> for Expr {
#[inline]
fn from(expr: BusOperation) -> Self {
Self::BusOperation(expr)
}
}
impl From<ListComprehension> for Expr {
#[inline]
fn from(expr: ListComprehension) -> Self {
Self::ListComprehension(expr)
}
}
/// Wraps a `let` in an [`Expr`], provided its type can be determined.
impl TryFrom<Let> for Expr {
    type Error = InvalidExprError;
    fn try_from(expr: Let) -> Result<Self, Self::Error> {
        // A `let` with no known type is rejected rather than wrapped.
        if expr.ty().is_none() {
            return Err(InvalidExprError::InvalidLetExpr(expr.span()));
        }
        Ok(Self::Let(Box::new(expr)))
    }
}
/// Lifts a scalar expression into an [`Expr`]; bounded symbol accesses have no `Expr` form.
impl TryFrom<ScalarExpr> for Expr {
    type Error = InvalidExprError;
    #[inline]
    fn try_from(expr: ScalarExpr) -> Result<Self, Self::Error> {
        let converted = match expr {
            ScalarExpr::Const(value) => {
                let span = value.span();
                Self::Const(Span::new(span, ConstantExpr::Scalar(value.item)))
            }
            ScalarExpr::SymbolAccess(access) => Self::SymbolAccess(access),
            ScalarExpr::Binary(binary) => Self::Binary(binary),
            ScalarExpr::Call(call) => Self::Call(call),
            ScalarExpr::BoundedSymbolAccess(_) => {
                return Err(InvalidExprError::BoundedSymbolAccess(expr.span()));
            }
            ScalarExpr::Let(let_expr) => Self::Let(let_expr),
            ScalarExpr::BusOperation(op) => Self::BusOperation(op),
            ScalarExpr::Null(spanned) => Self::Null(spanned),
            ScalarExpr::Unconstrained(spanned) => Self::Unconstrained(spanned),
        };
        Ok(converted)
    }
}
/// Extracts an expression from a statement; only `let` and expression statements qualify.
impl TryFrom<Statement> for Expr {
    type Error = InvalidExprError;
    fn try_from(stmt: Statement) -> Result<Self, Self::Error> {
        match stmt {
            Statement::Expr(expr) => Ok(expr),
            Statement::Let(let_expr) => Ok(Self::Let(Box::new(let_expr))),
            other => Err(InvalidExprError::NotAnExpr(other.span())),
        }
    }
}
/// A scalar-valued expression, as used for the operands of [`BinaryExpr`].
#[derive(Clone, PartialEq, Eq, Spanned)]
pub enum ScalarExpr {
    /// An integer constant.
    Const(Span<u64>),
    /// An access to a named symbol.
    SymbolAccess(SymbolAccess),
    /// An access to a column at a [`Boundary`]; rendered as `column.boundary`.
    BoundedSymbolAccess(BoundedSymbolAccess),
    /// A binary operation.
    Binary(BinaryExpr),
    /// A function call.
    Call(Call),
    /// A `let` expression.
    Let(Box<Let>),
    /// An operation on a bus.
    BusOperation(BusOperation),
    /// The `null` literal; typed as `Felt` by [`ScalarExpr::ty`].
    Null(Span<()>),
    /// The `unconstrained` literal; typed as `Felt` by [`ScalarExpr::ty`].
    Unconstrained(Span<()>),
}
impl ScalarExpr {
    /// Returns true if this is an integer constant.
    pub fn is_constant(&self) -> bool {
        matches!(self, ScalarExpr::Const(_))
    }

    /// Returns true if this expression contains a call or `let`, looking through binary
    /// operations but not into other variants.
    pub fn has_block_like_expansion(&self) -> bool {
        if let Self::Binary(binary) = self {
            return binary.has_block_like_expansion();
        }
        matches!(self, Self::Call(_) | Self::Let(_))
    }

    /// Returns the type of this expression.
    ///
    /// `Ok(None)` means the type is not (yet) known. `Err(span)` is returned, with the span
    /// of the offending binary expression, when both operand types are known but differ.
    pub fn ty(&self) -> Result<Option<Type>, SourceSpan> {
        let ty = match self {
            Self::SymbolAccess(access) => access.ty,
            Self::BoundedSymbolAccess(bounded) => bounded.column.ty,
            Self::Call(call) => call.ty,
            Self::Let(let_expr) => let_expr.ty(),
            Self::Binary(binary) => {
                let lhs = binary.lhs.ty()?;
                let rhs = binary.rhs.ty()?;
                match (lhs, rhs) {
                    (Some(l), Some(r)) if l != r => return Err(binary.span()),
                    (Some(l), Some(_)) => Some(l),
                    _ => None,
                }
            }
            Self::Const(_) | Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => {
                Some(Type::Felt)
            }
        };
        Ok(ty)
    }
}
/// Narrows an [`Expr`] to a [`ScalarExpr`].
///
/// # Errors
///
/// Returns [`InvalidExprError::InvalidScalarExpr`] for non-scalar constants, ranges,
/// vectors, matrices, list comprehensions, and `let` expressions whose type is unknown.
impl TryFrom<Expr> for ScalarExpr {
    type Error = InvalidExprError;
    fn try_from(expr: Expr) -> Result<Self, Self::Error> {
        match expr {
            Expr::Const(constant) => {
                let span = constant.span();
                match constant.item {
                    ConstantExpr::Scalar(v) => Ok(Self::Const(Span::new(span, v))),
                    _ => Err(InvalidExprError::InvalidScalarExpr(span)),
                }
            }
            Expr::SymbolAccess(sym) => Ok(Self::SymbolAccess(sym)),
            Expr::Binary(bin) => Ok(Self::Binary(bin)),
            Expr::Call(call) => Ok(Self::Call(call)),
            Expr::Let(let_expr) => {
                if let_expr.ty().is_none() {
                    Err(InvalidExprError::InvalidScalarExpr(let_expr.span()))
                } else {
                    Ok(Self::Let(let_expr))
                }
            }
            // These are felt-typed (see `Expr::ty`) and have direct `ScalarExpr` counterparts.
            // `TryFrom<ScalarExpr> for Expr` maps them the other way, so accepting them here
            // keeps the two conversions round-tripping.
            Expr::BusOperation(op) => Ok(Self::BusOperation(op)),
            Expr::Null(spanned) => Ok(Self::Null(spanned)),
            Expr::Unconstrained(spanned) => Ok(Self::Unconstrained(spanned)),
            // Listed explicitly so that adding an `Expr` variant forces a decision here.
            Expr::Range(_) | Expr::Vector(_) | Expr::Matrix(_) | Expr::ListComprehension(_) => {
                Err(InvalidExprError::InvalidScalarExpr(expr.span()))
            }
        }
    }
}
/// Extracts a scalar expression from a `let` or expression statement.
impl TryFrom<Statement> for ScalarExpr {
    type Error = InvalidExprError;
    fn try_from(stmt: Statement) -> Result<Self, Self::Error> {
        // First lift the statement to an `Expr`, then narrow it to a scalar.
        let expr = match stmt {
            Statement::Expr(expr) => expr,
            Statement::Let(let_expr) => Expr::Let(Box::new(let_expr)),
            other => return Err(InvalidExprError::InvalidScalarExpr(other.span())),
        };
        Self::try_from(expr)
    }
}
impl From<u64> for ScalarExpr {
fn from(value: u64) -> Self {
Self::Const(Span::new(SourceSpan::UNKNOWN, value))
}
}
impl fmt::Debug for ScalarExpr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Every variant except `Let` prints as a one-field tuple named after the variant.
        let (variant, payload): (&str, &dyn fmt::Debug) = match self {
            // `Let` bypasses the tuple wrapper and always pretty-prints itself.
            Self::Let(let_expr) => return write!(f, "{let_expr:#?}"),
            Self::Const(value) => ("Const", &value.item),
            Self::SymbolAccess(access) => ("SymbolAccess", access),
            Self::BoundedSymbolAccess(bounded) => ("BoundedSymbolAccess", bounded),
            Self::Binary(binary) => ("Binary", binary),
            Self::Call(call) => ("Call", call),
            Self::BusOperation(op) => ("BusOp", op),
            Self::Null(spanned) => ("Null", spanned),
            Self::Unconstrained(spanned) => ("Unconstrained", spanned),
        };
        f.debug_tuple(variant).field(payload).finish()
    }
}
impl fmt::Display for ScalarExpr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Const(value) => write!(f, "{value}"),
            Self::SymbolAccess(access) => write!(f, "{access}"),
            // Bounded accesses render as `column.boundary`, e.g. `a.first`.
            Self::BoundedSymbolAccess(bounded) => {
                write!(f, "{}.{}", bounded.column, bounded.boundary)
            }
            Self::Binary(binary) => write!(f, "{binary}"),
            Self::Call(call) => write!(f, "{call}"),
            Self::Let(let_expr) => write!(
                f,
                "{}",
                DisplayLet {
                    let_expr,
                    indent: 0,
                    in_expr_position: true,
                }
            ),
            Self::BusOperation(op) => write!(f, "{op}"),
            Self::Null(_) => f.write_str("null"),
            Self::Unconstrained(_) => f.write_str("unconstrained"),
        }
    }
}
#[derive(Clone, Spanned, Debug)]
pub struct ConstSymbolAccess {
#[span]
pub span: SourceSpan,
pub name: ResolvableIdentifier,
pub ty: Option<Type>,
}
impl ConstSymbolAccess {
pub fn new(span: SourceSpan, name: Identifier) -> Self {
Self {
span,
name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(name)),
ty: None,
}
}
}
impl Eq for ConstSymbolAccess {}
impl PartialEq for ConstSymbolAccess {
fn eq(&self, other: &Self) -> bool {
self.name.eq(&other.name) && self.ty.eq(&other.ty)
}
}
impl std::hash::Hash for ConstSymbolAccess {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.ty.hash(state);
}
}
impl fmt::Display for ConstSymbolAccess {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", &self.name)
}
}
#[derive(Debug, Clone, Spanned)]
pub struct RangeExpr {
#[span]
pub span: SourceSpan,
pub start: RangeBound,
pub end: RangeBound,
}
impl TryFrom<&RangeExpr> for Range {
type Error = InvalidExprError;
#[inline]
fn try_from(expr: &RangeExpr) -> Result<Self, InvalidExprError> {
match (&expr.start, &expr.end) {
(RangeBound::Const(lhs), RangeBound::Const(rhs)) => Ok(lhs.item..rhs.item),
_ => Err(InvalidExprError::NonConstantRangeExpr(expr.span)),
}
}
}
impl RangeExpr {
pub fn is_constant(&self) -> bool {
self.start.is_constant() && self.end.is_constant()
}
pub fn to_slice_range(&self) -> Range {
self.try_into()
.expect("attempted to convert non-constant range expression to constant")
}
pub fn ty(&self) -> Option<Type> {
match (&self.start, &self.end) {
(RangeBound::Const(start), RangeBound::Const(end)) => {
Some(Type::Vector(end.item.abs_diff(start.item)))
}
_ => None,
}
}
}
impl From<Range> for RangeExpr {
fn from(range: Range) -> Self {
Self {
span: SourceSpan::default(),
start: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.start)),
end: RangeBound::Const(Span::new(SourceSpan::UNKNOWN, range.end)),
}
}
}
impl Eq for RangeExpr {}
impl PartialEq for RangeExpr {
fn eq(&self, other: &Self) -> bool {
self.start.eq(&other.start) && self.end.eq(&other.end)
}
}
impl std::hash::Hash for RangeExpr {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.start.hash(state);
self.end.hash(state);
}
}
impl fmt::Display for RangeExpr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}..{}", &self.start, &self.end)
}
}
/// One bound of a [`RangeExpr`]: either a named constant or an integer literal.
#[derive(Hash, Clone, Spanned, PartialEq, Eq, Debug)]
pub enum RangeBound {
    SymbolAccess(ConstSymbolAccess),
    Const(Span<usize>),
}
impl RangeBound {
    /// Returns true if this bound is an integer literal.
    pub fn is_constant(&self) -> bool {
        !matches!(self, Self::SymbolAccess(_))
    }
}
/// Uses the identifier's own span for the resulting access.
impl From<Identifier> for RangeBound {
    fn from(name: Identifier) -> Self {
        let span = name.span();
        Self::SymbolAccess(ConstSymbolAccess::new(span, name))
    }
}
/// Builds a literal bound with an unknown span.
impl From<usize> for RangeBound {
    fn from(constant: usize) -> Self {
        RangeBound::Const(Span::new(SourceSpan::UNKNOWN, constant))
    }
}
impl fmt::Display for RangeBound {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Const(value) => write!(f, "{value}"),
            Self::SymbolAccess(access) => write!(f, "{access}"),
        }
    }
}
/// A binary operation `lhs op rhs` over scalar operands.
#[derive(Clone, Spanned)]
pub struct BinaryExpr {
    #[span]
    pub span: SourceSpan,
    pub op: BinaryOp,
    pub lhs: Box<ScalarExpr>,
    pub rhs: Box<ScalarExpr>,
}
impl BinaryExpr {
    /// Creates a binary expression, boxing both operands.
    pub fn new(span: SourceSpan, op: BinaryOp, lhs: ScalarExpr, rhs: ScalarExpr) -> Self {
        let (lhs, rhs) = (Box::new(lhs), Box::new(rhs));
        Self { span, op, lhs, rhs }
    }

    /// Returns true if either operand has a block-like expansion.
    #[inline]
    pub fn has_block_like_expansion(&self) -> bool {
        [&self.lhs, &self.rhs]
            .iter()
            .any(|operand| operand.has_block_like_expansion())
    }
}
impl Eq for BinaryExpr {}
/// Equality ignores `span`.
impl PartialEq for BinaryExpr {
    fn eq(&self, other: &Self) -> bool {
        self.op == other.op && (&self.lhs, &self.rhs) == (&other.lhs, &other.rhs)
    }
}
impl fmt::Debug for BinaryExpr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("BinaryExpr");
        builder.field("op", &self.op);
        builder.field("lhs", &*self.lhs);
        builder.field("rhs", &*self.rhs);
        builder.finish()
    }
}
impl fmt::Display for BinaryExpr {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Operands are printed without parentheses.
        write!(f, "{lhs} {op} {rhs}", lhs = self.lhs, op = self.op, rhs = self.rhs)
    }
}
/// The binary operators supported in scalar expressions.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum BinaryOp {
    Add,
    Sub,
    Mul,
    Exp,
    Eq,
}
impl fmt::Display for BinaryOp {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let symbol = match self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Exp => "^",
            Self::Eq => "=",
        };
        f.write_str(symbol)
    }
}
/// Which end of a column a bounded access refers to; defaults to `First`.
#[derive(Debug, Copy, Clone, PartialEq, Default, Eq)]
pub enum Boundary {
    #[default]
    First,
    Last,
}
impl fmt::Display for Boundary {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(match self {
            Self::First => "first",
            Self::Last => "last",
        })
    }
}
/// How a symbol is accessed: as a whole, by element, by slice, or by matrix cell.
#[derive(Hash, Debug, Clone, Eq, PartialEq, Default)]
pub enum AccessType {
    /// The whole value, referenced by name.
    #[default]
    Default,
    /// A slice of elements (or rows) over a range.
    Slice(RangeExpr),
    /// A single element (or row) at an index.
    Index(usize),
    /// A single cell of a matrix, as `(row, column)`.
    Matrix(usize, usize),
}
impl fmt::Display for AccessType {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Default => f.write_str("direct reference by name"),
            Self::Index(idx) => write!(f, "reference to element at index {idx}"),
            Self::Slice(range) => {
                let (start, end) = (&range.start, &range.end);
                write!(f, "slice of elements at indices {start}..{end}")
            }
            Self::Matrix(row, col) => write!(f, "reference to value in matrix at [{row}][{col}]"),
        }
    }
}
/// Reasons an access to a symbol can be invalid.
///
/// [`SymbolAccess::access`] produces the slice/index/bounds variants; `UndefinedVariable`
/// and `InvalidBinding` are not produced in this file (presumably raised during name
/// resolution — confirm with callers).
#[derive(Debug, Clone, thiserror::Error)]
pub enum InvalidAccessError {
    #[error("attempted to access undefined variable")]
    UndefinedVariable,
    #[error("attempted to access a function as a variable")]
    InvalidBinding,
    #[error("attempted to take a slice of a scalar value")]
    SliceOfScalar,
    #[error("attempted to take a slice of a matrix value")]
    SliceOfMatrix,
    #[error("attempted to index into a scalar value")]
    IndexIntoScalar,
    #[error("attempted to access an index which is out of bounds")]
    IndexOutOfBounds,
}
/// A reference to a named symbol, possibly indexed, sliced, and/or offset.
#[derive(Clone, Spanned)]
pub struct SymbolAccess {
    #[span]
    pub span: SourceSpan,
    /// The symbol being accessed; `Unresolved` when built via [`SymbolAccess::new`].
    pub name: ResolvableIdentifier,
    /// How the symbol is accessed (whole value, element, slice, or matrix cell).
    pub access_type: AccessType,
    /// Rendered by `Display` as that many trailing `'` marks (presumably a row offset — confirm).
    pub offset: usize,
    /// The type of the accessed value once known; `None` when built via [`SymbolAccess::new`].
    pub ty: Option<Type>,
}
impl SymbolAccess {
    /// Creates an access of the binding `name` with the given access type and offset.
    ///
    /// The name starts out as an unresolved binding, and `ty` is `None`.
    pub const fn new(
        span: SourceSpan,
        name: Identifier,
        access_type: AccessType,
        offset: usize,
    ) -> Self {
        Self {
            span,
            name: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(name)),
            access_type,
            offset,
            ty: None,
        }
    }

    /// Applies `access_type` on top of this access and returns the resulting access.
    ///
    /// The new access is composed with the existing one. Indexing or slicing a slice yields
    /// positions relative to the underlying symbol. Indexing a matrix row yields a
    /// matrix-cell access. The result's `ty` describes the accessed value.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidAccessError`] if the access does not fit the current type (e.g.
    /// slicing or indexing a scalar) or is out of bounds.
    ///
    /// # Panics
    ///
    /// Panics if `self.ty` is `None` (unless the current access is already a matrix cell),
    /// or if any slice range involved is not constant.
    pub fn access(&self, access_type: AccessType) -> Result<Self, InvalidAccessError> {
        match &self.access_type {
            AccessType::Default => self.access_default(access_type),
            AccessType::Slice(base_range) => {
                self.access_slice(base_range.to_slice_range(), access_type)
            }
            AccessType::Index(base_idx) => self.access_index(*base_idx, access_type),
            // A matrix cell is already a scalar; only a no-op access is allowed.
            AccessType::Matrix(_, _) => match access_type {
                AccessType::Default => Ok(self.clone()),
                _ => Err(InvalidAccessError::IndexIntoScalar),
            },
        }
    }

    /// Applies `access_type` to a symbol that is currently accessed as a whole value.
    fn access_default(&self, access_type: AccessType) -> Result<Self, InvalidAccessError> {
        let ty = self.ty.unwrap();
        match access_type {
            AccessType::Default => Ok(self.clone()),
            AccessType::Index(idx) => match ty {
                Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
                Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
                Type::Vector(_) => Ok(Self {
                    access_type: AccessType::Index(idx),
                    ty: Some(Type::Felt),
                    ..self.clone()
                }),
                Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
                // Indexing a matrix selects a row.
                Type::Matrix(_, cols) => Ok(Self {
                    access_type: AccessType::Index(idx),
                    ty: Some(Type::Vector(cols)),
                    ..self.clone()
                }),
            },
            AccessType::Slice(range) => {
                let slice_range = range.to_slice_range();
                // `len()` is zero for an inverted range instead of overflowing the subtraction.
                let rlen = slice_range.len();
                match ty {
                    // A dedicated error exists for slicing a scalar.
                    Type::Felt => Err(InvalidAccessError::SliceOfScalar),
                    Type::Vector(len) if slice_range.end > len => {
                        Err(InvalidAccessError::IndexOutOfBounds)
                    }
                    Type::Vector(_) => Ok(Self {
                        access_type: AccessType::Slice(range),
                        ty: Some(Type::Vector(rlen)),
                        ..self.clone()
                    }),
                    Type::Matrix(rows, _) if slice_range.end > rows => {
                        Err(InvalidAccessError::IndexOutOfBounds)
                    }
                    // Slicing a matrix selects a range of rows.
                    Type::Matrix(_, cols) => Ok(Self {
                        access_type: AccessType::Slice(range),
                        ty: Some(Type::Matrix(rlen, cols)),
                        ..self.clone()
                    }),
                }
            }
            AccessType::Matrix(row, col) => match ty {
                Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar),
                Type::Matrix(rows, cols) if row >= rows || col >= cols => {
                    Err(InvalidAccessError::IndexOutOfBounds)
                }
                Type::Matrix(_, _) => Ok(Self {
                    access_type: AccessType::Matrix(row, col),
                    ty: Some(Type::Felt),
                    ..self.clone()
                }),
            },
        }
    }

    /// Applies `access_type` to a symbol currently accessed through the constant slice
    /// `base_range`. Resulting indices and ranges are rebased onto the underlying symbol.
    fn access_slice(
        &self,
        base_range: Range,
        access_type: AccessType,
    ) -> Result<Self, InvalidAccessError> {
        let ty = self.ty.unwrap();
        match access_type {
            AccessType::Default => Ok(self.clone()),
            AccessType::Index(idx) => match ty {
                // `access_default` only ever types a slice as a vector or a matrix.
                Type::Felt => unreachable!(),
                Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
                Type::Vector(_) => Ok(Self {
                    access_type: AccessType::Index(base_range.start + idx),
                    ty: Some(Type::Felt),
                    ..self.clone()
                }),
                Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
                Type::Matrix(_, cols) => Ok(Self {
                    access_type: AccessType::Index(base_range.start + idx),
                    ty: Some(Type::Vector(cols)),
                    ..self.clone()
                }),
            },
            AccessType::Slice(range) => {
                let slice_range = range.to_slice_range();
                let blen = base_range.end - base_range.start;
                let rlen = slice_range.len();
                // Both bounds of the nested slice are relative to the start of the base slice.
                let start = base_range.start + slice_range.start;
                let end = base_range.start + slice_range.end;
                let shifted = RangeExpr {
                    span: range.span,
                    start: RangeBound::Const(Span::new(range.start.span(), start)),
                    end: RangeBound::Const(Span::new(range.end.span(), end)),
                };
                match ty {
                    // `access_default` only ever types a slice as a vector or a matrix.
                    Type::Felt => unreachable!(),
                    Type::Vector(_) if slice_range.end > blen => {
                        Err(InvalidAccessError::IndexOutOfBounds)
                    }
                    Type::Vector(_) => Ok(Self {
                        access_type: AccessType::Slice(shifted),
                        ty: Some(Type::Vector(rlen)),
                        ..self.clone()
                    }),
                    Type::Matrix(rows, _) if slice_range.end > rows => {
                        Err(InvalidAccessError::IndexOutOfBounds)
                    }
                    Type::Matrix(_, cols) => Ok(Self {
                        access_type: AccessType::Slice(shifted),
                        ty: Some(Type::Matrix(rlen, cols)),
                        ..self.clone()
                    }),
                }
            }
            AccessType::Matrix(row, col) => match ty {
                Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar),
                Type::Matrix(rows, cols) if row >= rows || col >= cols => {
                    Err(InvalidAccessError::IndexOutOfBounds)
                }
                // `row` is relative to the sliced rows. Rebase it onto the underlying matrix,
                // as the `Index` arm does.
                Type::Matrix(_, _) => Ok(Self {
                    access_type: AccessType::Matrix(base_range.start + row, col),
                    ty: Some(Type::Felt),
                    ..self.clone()
                }),
            },
        }
    }

    /// Applies `access_type` to a symbol currently accessed at index `base_idx`.
    fn access_index(
        &self,
        base_idx: usize,
        access_type: AccessType,
    ) -> Result<Self, InvalidAccessError> {
        let ty = self.ty.unwrap();
        match access_type {
            AccessType::Default => Ok(self.clone()),
            AccessType::Index(idx) => match ty {
                Type::Felt => Err(InvalidAccessError::IndexIntoScalar),
                Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds),
                // A vector-typed indexed access is a matrix row, so a second index selects a
                // cell.
                Type::Vector(_) => Ok(Self {
                    access_type: AccessType::Matrix(base_idx, idx),
                    ty: Some(Type::Felt),
                    ..self.clone()
                }),
                Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds),
                Type::Matrix(_, cols) => Ok(Self {
                    access_type: AccessType::Matrix(base_idx, idx),
                    ty: Some(Type::Vector(cols)),
                    ..self.clone()
                }),
            },
            AccessType::Slice(_) => Err(InvalidAccessError::SliceOfMatrix),
            AccessType::Matrix(_, _) => Err(InvalidAccessError::IndexIntoScalar),
        }
    }
}
impl Eq for SymbolAccess {}
/// Equality ignores `span`, comparing name, access type, offset, and type in that order.
impl PartialEq for SymbolAccess {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (&self.name, &self.access_type, self.offset, &self.ty);
        let rhs = (&other.name, &other.access_type, other.offset, &other.ty);
        lhs == rhs
    }
}
impl fmt::Debug for SymbolAccess {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("SymbolAccess");
        builder.field("name", &self.name);
        builder.field("access_type", &self.access_type);
        builder.field("offset", &self.offset);
        builder.field("ty", &self.ty);
        builder.finish()
    }
}
impl fmt::Display for SymbolAccess {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.name)?;
        match &self.access_type {
            AccessType::Default => {}
            AccessType::Index(idx) => write!(f, "[{idx}]")?,
            AccessType::Slice(range) => write!(f, "[{}..{}]", range.start, range.end)?,
            AccessType::Matrix(row, col) => write!(f, "[{row}][{col}]")?,
        }
        // One trailing `'` per unit of offset.
        (0..self.offset).try_for_each(|_| f.write_str("'"))
    }
}
/// An access to `column` at one of its boundaries (rendered as `column.boundary`).
#[derive(Clone, Spanned)]
pub struct BoundedSymbolAccess {
    #[span]
    pub span: SourceSpan,
    pub boundary: Boundary,
    pub column: SymbolAccess,
}
impl BoundedSymbolAccess {
    /// Creates a bounded access of `column` at `boundary`.
    pub const fn new(span: SourceSpan, column: SymbolAccess, boundary: Boundary) -> Self {
        Self {
            span,
            column,
            boundary,
        }
    }
}
impl Eq for BoundedSymbolAccess {}
/// Equality ignores `span`.
impl PartialEq for BoundedSymbolAccess {
    fn eq(&self, other: &Self) -> bool {
        (self.boundary, &self.column) == (other.boundary, &other.column)
    }
}
impl fmt::Debug for BoundedSymbolAccess {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("BoundedSymbolAccess");
        builder.field("boundary", &self.boundary);
        builder.field("column", &self.column);
        builder.finish()
    }
}
/// `(binding, iterable)` pairs consumed by [`ListComprehension::new`].
pub type ComprehensionContext = Vec<(Identifier, Expr)>;
/// A list comprehension: `body for bindings in iterables [when selector]`.
#[derive(Clone, Spanned)]
pub struct ListComprehension {
    #[span]
    pub span: SourceSpan,
    /// The bound names, parallel to `iterables`.
    pub bindings: Vec<Identifier>,
    /// The iterated expressions, parallel to `bindings`.
    pub iterables: Vec<Expr>,
    pub body: Box<ScalarExpr>,
    /// An optional `when` filter.
    pub selector: Option<ScalarExpr>,
    /// The comprehension's type; `None` when built via [`ListComprehension::new`].
    pub ty: Option<Type>,
}
impl ListComprehension {
    /// Creates a comprehension, splitting `context` into parallel binding/iterable lists
    /// while preserving order.
    pub fn new(
        span: SourceSpan,
        body: ScalarExpr,
        context: ComprehensionContext,
        selector: Option<ScalarExpr>,
    ) -> Self {
        let (bindings, iterables): (Vec<Identifier>, Vec<Expr>) = context.into_iter().unzip();
        Self {
            span,
            bindings,
            iterables,
            body: Box::new(body),
            selector,
            ty: None,
        }
    }
}
impl Eq for ListComprehension {}
/// Equality ignores `span` and `ty`.
impl PartialEq for ListComprehension {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (&self.bindings, &self.iterables, &self.body, &self.selector);
        let rhs = (&other.bindings, &other.iterables, &other.body, &other.selector);
        lhs == rhs
    }
}
impl fmt::Debug for ListComprehension {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("ListComprehension");
        builder.field("bindings", &self.bindings);
        builder.field("iterables", &self.iterables);
        builder.field("body", &*self.body);
        builder.field("selector", &self.selector);
        builder.field("ty", &self.ty);
        builder.finish()
    }
}
impl fmt::Display for ListComprehension {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // A single binding is printed bare; any other count is printed as tuples.
        match self.bindings.as_slice() {
            [binding] => write!(
                f,
                "{} for {} in {}",
                self.body, binding, self.iterables[0]
            )?,
            bindings => write!(
                f,
                "{} for {} in {}",
                self.body,
                DisplayTuple(bindings),
                DisplayTuple(self.iterables.as_slice())
            )?,
        }
        match &self.selector {
            Some(selector) => write!(f, " when {selector}"),
            None => Ok(()),
        }
    }
}
#[derive(Clone, Spanned)]
pub struct BusOperation {
#[span]
pub span: SourceSpan,
pub bus: ResolvableIdentifier,
pub op: BusOperator,
pub args: Vec<Expr>,
}
impl BusOperation {
pub fn new(span: SourceSpan, bus: Identifier, op: BusOperator, args: Vec<Expr>) -> Self {
Self {
span,
bus: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(bus)),
op,
args,
}
}
}
impl Eq for BusOperation {}
impl PartialEq for BusOperation {
fn eq(&self, other: &Self) -> bool {
self.bus == other.bus && self.args == other.args && self.op == other.op
}
}
impl fmt::Debug for BusOperation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("BusOperation")
.field("bus", &self.bus)
.field("op", &self.op)
.field("args", &self.args)
.finish()
}
}
impl fmt::Display for BusOperation {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{}{}{}",
self.bus,
self.op,
DisplayTuple(self.args.as_slice())
)
}
}
#[derive(Clone, Spanned)]
pub struct Call {
#[span]
pub span: SourceSpan,
pub callee: ResolvableIdentifier,
pub args: Vec<Expr>,
pub ty: Option<Type>,
}
impl Call {
pub fn new(span: SourceSpan, callee: Identifier, args: Vec<Expr>) -> Self {
use crate::symbols;
match callee.name() {
symbols::Sum => Self::sum(span, args),
symbols::Prod => Self::prod(span, args),
_ => Self {
span,
callee: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Function(callee)),
args,
ty: None,
},
}
}
#[inline]
pub fn is_builtin(&self) -> bool {
self.callee.is_builtin()
}
#[inline]
pub fn sum(span: SourceSpan, args: Vec<Expr>) -> Self {
Self::new_builtin(span, "sum", args, Type::Felt)
}
#[inline]
pub fn prod(span: SourceSpan, args: Vec<Expr>) -> Self {
Self::new_builtin(span, "prod", args, Type::Felt)
}
fn new_builtin(span: SourceSpan, name: &str, args: Vec<Expr>, ty: Type) -> Self {
let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin"));
let name = Identifier::new(span, Symbol::intern(name));
let id = QualifiedIdentifier::new(builtin_module, NamespacedIdentifier::Function(name));
Self {
span,
callee: ResolvableIdentifier::Resolved(id),
args,
ty: Some(ty),
}
}
}
impl Eq for Call {}
/// Equality ignores `span`.
impl PartialEq for Call {
    fn eq(&self, other: &Self) -> bool {
        (&self.callee, &self.args, &self.ty) == (&other.callee, &other.args, &other.ty)
    }
}
impl fmt::Debug for Call {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_struct("Call");
        builder.field("callee", &self.callee);
        builder.field("args", &self.args);
        builder.field("ty", &self.ty);
        builder.finish()
    }
}
impl fmt::Display for Call {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "{callee}{args}",
            callee = self.callee,
            args = DisplayTuple(self.args.as_slice())
        )
    }
}