use alloc::{format, string::ToString};
use core::fmt;
use midenc_hir_type::PointerType;
use midenc_session::diagnostics::Severity;
use crate::{
CompactString, Context, Op, Operation, Report, Type, derive::operation_trait, ir::value::Value,
};
/// An interface for operations that can compute their own result types.
pub trait InferTypeOpInterface: Op {
    /// Infer the result types of this operation.
    ///
    /// This takes `&mut self`, so implementations may update the operation in place. Errors are
    /// reported via the returned [`Report`].
    fn infer_return_types(&mut self, context: &Context) -> Result<(), Report>;

    /// Returns true if the result type lists `lhs` and `rhs` are considered compatible.
    ///
    /// By default, the lists must have the same length and be pairwise equal.
    fn are_compatible_return_types(&self, lhs: &[Type], rhs: &[Type]) -> bool {
        lhs.len() == rhs.len() && lhs.iter().zip(rhs).all(|(l, r)| l == r)
    }
}
/// Op expects all of its operands to have the same type.
///
/// The type of the first operand is taken as the expected type. Every subsequent operand is
/// compared against it, and the first mismatch is reported as an error diagnostic. Operations
/// with no operands pass trivially.
#[operation_trait]
pub trait SameTypeOperands {
    #[verifier]
    fn operands_are_the_same_type(op: &Operation, context: &Context) -> Result<(), Report> {
        let mut remaining = op.operands().iter();
        let Some(head) = remaining.next() else {
            return Ok(());
        };

        // Copy out the reference type and span inside a scope, so the borrow of the first
        // operand is released before the remaining operands are visited
        let (expected_ty, set_by) = {
            let head_operand = head.borrow();
            let head_value = head_operand.value();
            (head_value.ty().clone(), head_value.span())
        };

        for other in remaining {
            let other_operand = other.borrow();
            let other_value = other_operand.value();
            let actual_ty = other_value.ty();
            if actual_ty == &expected_ty {
                continue;
            }

            return Err(context
                .session()
                .diagnostics
                .diagnostic(Severity::Error)
                .with_message(::alloc::format!("invalid operation {}", op.name()))
                .with_primary_label(
                    op.span,
                    "this operation expects all operands to be of the same type",
                )
                .with_secondary_label(set_by, "inferred the expected type from this value")
                .with_secondary_label(other_value.span(), "which differs from this value's type")
                .with_help(format!("expected '{expected_ty}', got '{actual_ty}'"))
                .into_report());
        }

        Ok(())
    }
}
/// Op expects its results to have the same type as its operands.
///
/// Agreement among the operands themselves is left to the [`SameTypeOperands`] supertrait. The
/// verifier here compares every result against the type of the first operand, and reports the
/// first mismatching result as an error diagnostic. Operations with no operands pass trivially.
///
/// # Panics
///
/// The verifier panics if the operation has at least one operand but no results, because that
/// indicates the trait was attached to an operation that cannot satisfy it.
#[operation_trait]
pub trait SameOperandsAndResultType: SameTypeOperands {
    #[verifier]
    fn operands_and_result_are_the_same_type(
        op: &Operation,
        context: &Context,
    ) -> Result<(), Report> {
        // Without any operands there is no reference type to compare the results against
        let mut operand_iter = op.operands().iter();
        let Some(reference) = operand_iter.next() else {
            return Ok(());
        };
        let (expected_ty, set_by) = {
            let reference_operand = reference.borrow();
            let reference_value = reference_operand.value();
            (reference_value.ty().clone(), reference_value.span())
        };

        let results = op.results();
        assert!(
            !results.is_empty(),
            "Operation: {} was marked as having SameOperandsAndResultType, however it has no \
             results.",
            op.name()
        );

        for result_ref in results.iter() {
            let result = result_ref.borrow();
            let result_value = result.as_value_ref().borrow();
            let result_ty = result.ty();
            if result_ty == &expected_ty {
                continue;
            }

            return Err(context
                .session()
                .diagnostics
                .diagnostic(Severity::Error)
                .with_message(::alloc::format!("invalid operation result {}", op.name()))
                .with_primary_label(
                    op.span,
                    "this operation expects the operands and the results to be of the same \
                     type",
                )
                .with_secondary_label(set_by, "inferred the expected type from this value")
                .with_secondary_label(
                    result_value.span(),
                    "which differs from this value's type",
                )
                .with_help(format!("expected '{expected_ty}', got '{result_ty}'"))
                .into_report());
        }

        Ok(())
    }
}
/// Op expects every one of its operands to satisfy the type constraint `T`.
///
/// The first operand whose type does not match `T` is reported as an error, labelled with the
/// constraint's description and the operand's actual type.
#[operation_trait]
pub trait Variadic<T: TypeConstraint> {
    #[verifier]
    fn all_operands_match_constraint<T: TypeConstraint>(
        op: &Operation,
        context: &Context,
    ) -> Result<(), Report> {
        for operand_ref in op.operands().iter() {
            let operand = operand_ref.borrow();
            let value = operand.value();
            let ty = value.ty();
            let constraint = <T as TypeConstraint>::get();
            if !constraint.matches(ty) {
                let description = constraint.description();
                return Err(context
                    .diagnostics()
                    .diagnostic(Severity::Error)
                    .with_message("invalid operand")
                    .with_primary_label(
                        value.span(),
                        format!("expected operand type to be {description}, but got {ty}"),
                    )
                    .into_report());
            }
        }

        Ok(())
    }
}
pub trait TypeConstraint: 'static {
fn get() -> Self
where
Self: Sized;
fn description(&self) -> CompactString;
fn matches(&self, ty: &crate::Type) -> bool;
}
pub trait BuildableTypeConstraint: TypeConstraint {
fn build(context: &Context) -> crate::Type;
}
/// Defines a unit marker type `$Constraint` implementing [`TypeConstraint`].
///
/// The matcher may be written in one of three forms:
///
/// * a literal (e.g. `true`), which is the result for every type;
/// * a path to a function taking `&Type` and returning `bool`, e.g. `crate::Type::is_integer`;
/// * a closure-like `|ty| expr`, where `ty` is bound to the `&Type` being tested.
///
/// The first two forms are rewritten into the third, which performs the actual expansion.
macro_rules! type_constraint {
    ($Constraint:ident, $description:literal, $matcher:literal) => {
        type_constraint!($Constraint, $description, |_ty| $matcher);
    };
    ($Constraint:ident, $description:literal, $matcher:path) => {
        type_constraint!($Constraint, $description, |ty| $matcher(ty));
    };
    ($Constraint:ident, $description:literal, |$matcher_input:ident| $matcher:expr) => {
        #[derive(Debug, Copy, Clone, PartialEq, Eq)]
        pub struct $Constraint;

        impl TypeConstraint for $Constraint {
            #[inline(always)]
            fn get() -> Self {
                Self
            }

            #[inline(always)]
            fn description(&self) -> $crate::CompactString {
                $crate::CompactString::const_new($description)
            }

            #[inline]
            fn matches(&self, $matcher_input: &$crate::Type) -> bool {
                $matcher
            }
        }
    };
}
// Category constraints: each accepts exactly the types for which the named `Type` predicate
// returns true (`AnyType` accepts everything).
type_constraint!(AnyType, "any type", true);
type_constraint!(AnyList, "any list type", |ty| ty.is_list());
type_constraint!(AnyArray, "any array type", |ty| ty.is_array());
type_constraint!(AnyStruct, "any struct type", |ty| ty.is_struct());
type_constraint!(AnyPointer, "a pointer type", |ty| ty.is_pointer());
type_constraint!(AnyInteger, "an integral type", |ty| ty.is_integer());
type_constraint!(AnyPointerOrInteger, "an integral or pointer type", |ty| ty.is_pointer()
    || ty.is_integer());
// Signedness constraints (no restriction on width)
type_constraint!(AnySignedInteger, "a signed integral type", |ty| ty.is_signed_integer());
type_constraint!(AnyUnsignedInteger, "an unsigned integral type", |ty| ty.is_unsigned_integer());
type_constraint!(IntFelt, "a field element", |ty| ty.is_felt());
pub type Bool = SizedInt<1>;
pub type Int8 = SizedInt<8>;
pub type SInt8 = And<AnySignedInteger, SizedInt<8>>;
pub type UInt8 = And<AnyUnsignedInteger, SizedInt<8>>;
pub type Int16 = SizedInt<16>;
pub type SInt16 = And<AnySignedInteger, SizedInt<16>>;
pub type UInt16 = And<AnyUnsignedInteger, SizedInt<16>>;
pub type Int32 = SizedInt<32>;
pub type SInt32 = And<AnySignedInteger, SizedInt<32>>;
pub type UInt32 = And<AnyUnsignedInteger, SizedInt<32>>;
pub type Int64 = SizedInt<64>;
pub type SInt64 = And<AnySignedInteger, SizedInt<64>>;
pub type UInt64 = And<AnyUnsignedInteger, SizedInt<64>>;
pub type Int128 = SizedInt<128>;
pub type SInt128 = And<AnySignedInteger, SizedInt<128>>;
pub type UInt128 = And<AnyUnsignedInteger, SizedInt<128>>;
impl BuildableTypeConstraint for IntFelt {
fn build(_context: &Context) -> crate::Type {
crate::Type::Felt
}
}
impl BuildableTypeConstraint for Bool {
fn build(_context: &Context) -> crate::Type {
crate::Type::I1
}
}
impl BuildableTypeConstraint for UInt8 {
fn build(_context: &Context) -> crate::Type {
crate::Type::U8
}
}
impl BuildableTypeConstraint for SInt8 {
fn build(_context: &Context) -> crate::Type {
crate::Type::I8
}
}
impl BuildableTypeConstraint for UInt16 {
fn build(_context: &Context) -> crate::Type {
crate::Type::U16
}
}
impl BuildableTypeConstraint for SInt16 {
fn build(_context: &Context) -> crate::Type {
crate::Type::I16
}
}
impl BuildableTypeConstraint for UInt32 {
fn build(_context: &Context) -> crate::Type {
crate::Type::U32
}
}
impl BuildableTypeConstraint for SInt32 {
fn build(_context: &Context) -> crate::Type {
crate::Type::I32
}
}
impl BuildableTypeConstraint for UInt64 {
fn build(_context: &Context) -> crate::Type {
crate::Type::U64
}
}
impl BuildableTypeConstraint for SInt64 {
fn build(_context: &Context) -> crate::Type {
crate::Type::I64
}
}
impl BuildableTypeConstraint for UInt128 {
fn build(_context: &Context) -> crate::Type {
crate::Type::U128
}
}
impl BuildableTypeConstraint for SInt128 {
fn build(_context: &Context) -> crate::Type {
crate::Type::I128
}
}
/// A type constraint matching integral types that are exactly `N` bits wide.
///
/// Signedness is not constrained. Combine with [`AnySignedInteger`] or [`AnyUnsignedInteger`]
/// via [`And`] to restrict it, as the `SInt*`/`UInt*` aliases do.
pub struct SizedInt<const N: usize>(core::marker::PhantomData<[(); N]>);
impl<const N: usize> Copy for SizedInt<N> {}
impl<const N: usize> Clone for SizedInt<N> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<const N: usize> fmt::Debug for SizedInt<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(core::any::type_name::<Self>())
    }
}
impl<const N: usize> fmt::Display for SizedInt<N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{N}-bit integral type")
    }
}
impl<const N: usize> TypeConstraint for SizedInt<N> {
    #[inline(always)]
    fn get() -> Self {
        Self(core::marker::PhantomData)
    }

    fn description(&self) -> CompactString {
        CompactString::from(self.to_string())
    }

    /// Returns true if `ty` is an integral type whose width is exactly `N` bits.
    fn matches(&self, ty: &crate::Type) -> bool {
        // The width must be checked as well as integrality. Otherwise `UInt8` would accept
        // `u32`, and `Bool` would accept any integer, contradicting the `{N}-bit` description.
        // `is_integer()` is checked first, so `size_in_bits()` is only queried on integers.
        // NOTE(review): if `Type::is_integer` includes `Felt`, a felt matches whichever
        // `SizedInt<N>` equals its `size_in_bits()`. Confirm that this is intended.
        ty.is_integer() && ty.size_in_bits() == N
    }
}
// Each width-only `SizedInt<N>` constraint builds the signed integer type of that width, since
// the constraint itself places no restriction on signedness. `SizedInt<1>` is the `Bool` alias,
// whose impl builds `I1` separately.
impl BuildableTypeConstraint for SizedInt<8> {
    fn build(_context: &Context) -> crate::Type {
        crate::Type::I8
    }
}
impl BuildableTypeConstraint for SizedInt<16> {
    fn build(_context: &Context) -> crate::Type {
        crate::Type::I16
    }
}
impl BuildableTypeConstraint for SizedInt<32> {
    fn build(_context: &Context) -> crate::Type {
        crate::Type::I32
    }
}
impl BuildableTypeConstraint for SizedInt<64> {
    fn build(_context: &Context) -> crate::Type {
        crate::Type::I64
    }
}
// `Int128 = SizedInt<128>` is declared alongside `Int8`..`Int64`, so it is buildable as well
impl BuildableTypeConstraint for SizedInt<128> {
    fn build(_context: &Context) -> crate::Type {
        crate::Type::I128
    }
}
/// A type constraint matching pointer types whose pointee satisfies the constraint `T`.
pub struct PointerOf<T>(core::marker::PhantomData<T>);

impl<T> Copy for PointerOf<T> {}

impl<T> Clone for PointerOf<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> fmt::Debug for PointerOf<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Marker types carry no data, so their fully-qualified type name is the useful rendering
        let name = core::any::type_name::<Self>();
        f.write_str(name)
    }
}

impl<T: TypeConstraint> fmt::Display for PointerOf<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pointee = T::get().description();
        f.write_str("a pointer to ")?;
        f.write_str(&pointee)
    }
}

impl<T: TypeConstraint> TypeConstraint for PointerOf<T> {
    #[inline(always)]
    fn get() -> Self {
        Self(core::marker::PhantomData)
    }

    fn description(&self) -> CompactString {
        self.to_string().into()
    }

    fn matches(&self, ty: &crate::Type) -> bool {
        // Non-pointer types have no pointee and therefore never match
        match ty.pointee() {
            Some(pointee) => T::get().matches(pointee),
            None => false,
        }
    }
}

impl<T: BuildableTypeConstraint> BuildableTypeConstraint for PointerOf<T> {
    fn build(context: &Context) -> crate::Type {
        let pointee = T::build(context);
        PointerType::new(pointee).into()
    }
}
/// A type constraint matching array types of any length whose element type satisfies `T`.
pub struct AnyArrayOf<T>(core::marker::PhantomData<T>);

impl<T> Copy for AnyArrayOf<T> {}

impl<T> Clone for AnyArrayOf<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> fmt::Debug for AnyArrayOf<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = core::any::type_name::<Self>();
        f.write_str(name)
    }
}

impl<T: TypeConstraint> fmt::Display for AnyArrayOf<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let element = T::get().description();
        f.write_str("an array of ")?;
        f.write_str(&element)
    }
}

impl<T: TypeConstraint> TypeConstraint for AnyArrayOf<T> {
    #[inline(always)]
    fn get() -> Self {
        Self(core::marker::PhantomData)
    }

    fn description(&self) -> CompactString {
        self.to_string().into()
    }

    fn matches(&self, ty: &crate::Type) -> bool {
        if let crate::Type::Array(array_ty) = ty {
            T::get().matches(array_ty.element_type())
        } else {
            false
        }
    }
}
/// A type constraint matching array types of exactly `N` elements whose element type satisfies
/// `T`.
pub struct ArrayOf<const N: usize, T>(core::marker::PhantomData<[T; N]>);

impl<const N: usize, T> Copy for ArrayOf<N, T> {}

impl<const N: usize, T> Clone for ArrayOf<N, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<const N: usize, T> fmt::Debug for ArrayOf<N, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = core::any::type_name::<Self>();
        f.write_str(name)
    }
}

impl<const N: usize, T: TypeConstraint> fmt::Display for ArrayOf<N, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let element = T::get().description();
        write!(f, "an array of {} {}", N, element)
    }
}

impl<const N: usize, T: TypeConstraint> TypeConstraint for ArrayOf<N, T> {
    #[inline(always)]
    fn get() -> Self {
        Self(core::marker::PhantomData)
    }

    fn description(&self) -> CompactString {
        self.to_string().into()
    }

    fn matches(&self, ty: &crate::Type) -> bool {
        let crate::Type::Array(array_ty) = ty else {
            return false;
        };
        // The length is checked first, so the element constraint is only consulted on a match
        array_ty.len() == N && T::get().matches(array_ty.element_type())
    }
}

impl<const N: usize, T: BuildableTypeConstraint> BuildableTypeConstraint for ArrayOf<N, T> {
    fn build(context: &Context) -> crate::Type {
        crate::Type::from(crate::ArrayType::new(T::build(context), N))
    }
}
/// A type constraint satisfied only when both `T` and `U` are satisfied.
pub struct And<T, U> {
    _left: core::marker::PhantomData<T>,
    _right: core::marker::PhantomData<U>,
}

impl<T, U> Copy for And<T, U> {}

impl<T, U> Clone for And<T, U> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, U> fmt::Debug for And<T, U> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = core::any::type_name::<Self>();
        f.write_str(name)
    }
}

impl<T: TypeConstraint, U: TypeConstraint> TypeConstraint for And<T, U> {
    #[inline(always)]
    fn get() -> Self {
        use core::marker::PhantomData;
        Self {
            _left: PhantomData,
            _right: PhantomData,
        }
    }

    fn description(&self) -> CompactString {
        let left = T::get().description();
        let right = U::get().description();
        format!("both {left} and {right}").into()
    }

    #[inline]
    fn matches(&self, ty: &crate::Type) -> bool {
        // Short-circuit: `U` is only consulted once `T` has matched
        if !T::get().matches(ty) {
            return false;
        }
        U::get().matches(ty)
    }
}
/// A type constraint satisfied when either `T` or `U` (or both) is satisfied.
pub struct Or<T, U> {
    _left: core::marker::PhantomData<T>,
    _right: core::marker::PhantomData<U>,
}

impl<T, U> Copy for Or<T, U> {}

impl<T, U> Clone for Or<T, U> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, U> fmt::Debug for Or<T, U> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = core::any::type_name::<Self>();
        f.write_str(name)
    }
}

impl<T: TypeConstraint, U: TypeConstraint> TypeConstraint for Or<T, U> {
    #[inline(always)]
    fn get() -> Self {
        use core::marker::PhantomData;
        Self {
            _left: PhantomData,
            _right: PhantomData,
        }
    }

    fn description(&self) -> CompactString {
        let left = T::get().description();
        let right = U::get().description();
        format!("either {left} or {right}").into()
    }

    #[inline]
    fn matches(&self, ty: &crate::Type) -> bool {
        // Short-circuit: `U` is only consulted when `T` did not match
        if T::get().matches(ty) {
            return true;
        }
        U::get().matches(ty)
    }
}