use core::{
fmt,
ops::{Index, IndexMut},
};
use miden_core::{Felt, FieldElement};
use midenc_hir::{AttributeValue, Immediate, Type, ValueRef};
use smallvec::{SmallVec, smallvec};
/// A two-way choice between moving and copying an operand.
///
/// NOTE(review): presumably selects whether an operand is moved (consumed) or copied
/// (duplicated) into position when satisfying a stack constraint — confirm against the
/// users of this type.
#[derive(Debug, Copy, Clone)]
pub enum Constraint {
/// NOTE(review): presumably "move the operand" — confirm.
Move,
/// NOTE(review): presumably "copy the operand" — confirm.
Copy,
}
/// The payload carried by an [`Operand`]: a constant, an IR value, or only a type.
pub enum OperandType {
/// A constant attribute. [`OperandType::ty`] expects this to be an [`Immediate`] and
/// panics otherwise.
Const(Box<dyn AttributeValue>),
/// A reference to an IR value.
Value(ValueRef),
/// An operand represented only by its type.
Type(Type),
}
impl Clone for OperandType {
    /// Clones the payload: constants are duplicated through `clone_value` (the boxed
    /// attribute cannot be cloned directly), value references are copied, and types are
    /// cloned.
    fn clone(&self) -> Self {
        match *self {
            Self::Type(ref ty) => Self::Type(ty.clone()),
            Self::Value(value) => Self::Value(value),
            Self::Const(ref attr) => Self::Const(attr.clone_value()),
        }
    }
}
impl OperandType {
    /// Returns the type of this operand.
    ///
    /// # Panics
    ///
    /// Panics if this is a constant whose attribute is not an [`Immediate`].
    pub fn ty(&self) -> Type {
        match self {
            Self::Type(ty) => ty.clone(),
            Self::Value(value) => value.borrow().ty().clone(),
            Self::Const(attr) => {
                let imm = attr
                    .downcast_ref::<Immediate>()
                    .expect("unexpected constant value type");
                imm.ty()
            }
        }
    }
}
impl fmt::Debug for OperandType {
    /// Formats as `Const(..)` (debug form), `Value(..)` or `Type(..)` (display form).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            Self::Type(ref ty) => write!(f, "Type({ty})"),
            Self::Value(ref value) => write!(f, "Value({value})"),
            Self::Const(ref value) => write!(f, "Const({value:?})"),
        }
    }
}
impl Eq for OperandType {}
impl PartialEq for OperandType {
    /// Two operand payloads are equal only when they are the same variant with equal
    /// contents.
    fn eq(&self, other: &Self) -> bool {
        // Matching on `self` alone, without a wildcard arm, keeps this exhaustive: adding a
        // variant forces this impl to be revisited.
        match self {
            Self::Value(lhs) => matches!(other, Self::Value(rhs) if lhs == rhs),
            Self::Const(lhs) => matches!(other, Self::Const(rhs) if lhs == rhs),
            Self::Type(lhs) => matches!(other, Self::Type(rhs) if lhs == rhs),
        }
    }
}
impl PartialEq<Type> for OperandType {
    /// Only a [`OperandType::Type`] payload can equal a bare `Type`; constants and values
    /// never do, even if their type matches.
    fn eq(&self, other: &Type) -> bool {
        matches!(self, Self::Type(ty) if ty == other)
    }
}
impl PartialEq<Immediate> for OperandType {
    /// True when this is a constant whose attribute is an [`Immediate`] equal to `other`.
    fn eq(&self, other: &Immediate) -> bool {
        let Self::Const(attr) = self else {
            return false;
        };
        attr.downcast_ref::<Immediate>().is_some_and(|imm| imm == other)
    }
}
impl PartialEq<dyn AttributeValue> for OperandType {
    /// True when this is a constant whose attribute compares equal to `other`.
    fn eq(&self, other: &dyn AttributeValue) -> bool {
        if let Self::Const(attr) = self {
            attr.as_ref() == other
        } else {
            false
        }
    }
}
impl PartialEq<ValueRef> for OperandType {
    /// True when this payload refers to exactly the IR value `other`.
    fn eq(&self, other: &ValueRef) -> bool {
        matches!(self, Self::Value(this) if this == other)
    }
}
impl From<ValueRef> for OperandType {
fn from(value: ValueRef) -> Self {
Self::Value(value)
}
}
impl From<Type> for OperandType {
fn from(ty: Type) -> Self {
Self::Type(ty)
}
}
impl From<Immediate> for OperandType {
fn from(value: Immediate) -> Self {
Self::Const(Box::new(value))
}
}
impl From<Box<dyn AttributeValue>> for OperandType {
fn from(value: Box<dyn AttributeValue>) -> Self {
Self::Const(value)
}
}
/// An entry on the [`OperandStack`]: an operand payload plus its cached raw element layout.
///
/// The derived equality compares both the layout and the payload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Operand {
/// The raw element types of the operand. In [`Operand::new`] this is computed from
/// `Type::to_raw_parts` and then reversed; its length is the number of raw stack
/// elements the operand occupies (see [`Operand::size`]).
word: SmallVec<[Type; 4]>,
/// The operand payload.
operand: OperandType,
}
impl Default for Operand {
    /// Returns a single-element constant operand holding the field element zero.
    fn default() -> Self {
        let zero = Immediate::Felt(Felt::ZERO);
        Self {
            operand: OperandType::Const(Box::new(zero)),
            word: smallvec![Type::Felt],
        }
    }
}
impl PartialEq<ValueRef> for Operand {
#[inline(always)]
fn eq(&self, other: &ValueRef) -> bool {
self.operand.eq(other)
}
}
impl PartialEq<dyn AttributeValue> for Operand {
#[inline(always)]
fn eq(&self, other: &dyn AttributeValue) -> bool {
self.operand.eq(other)
}
}
impl PartialEq<Immediate> for Operand {
#[inline(always)]
fn eq(&self, other: &Immediate) -> bool {
self.operand.eq(other)
}
}
impl PartialEq<Immediate> for &Operand {
#[inline(always)]
fn eq(&self, other: &Immediate) -> bool {
self.operand.eq(other)
}
}
impl PartialEq<Type> for Operand {
#[inline(always)]
fn eq(&self, other: &Type) -> bool {
self.operand.eq(other)
}
}
impl PartialEq<Type> for &Operand {
#[inline(always)]
fn eq(&self, other: &Type) -> bool {
self.operand.eq(other)
}
}
impl From<Immediate> for Operand {
#[inline]
fn from(imm: Immediate) -> Self {
Self::new(imm.into())
}
}
impl From<u32> for Operand {
#[inline]
fn from(imm: u32) -> Self {
Self::new(Immediate::U32(imm).into())
}
}
impl TryFrom<&Operand> for ValueRef {
    type Error = ();

    /// Extracts the IR value an operand refers to; fails for constants and bare types.
    fn try_from(operand: &Operand) -> Result<Self, Self::Error> {
        if let OperandType::Value(value) = &operand.operand {
            Ok(*value)
        } else {
            Err(())
        }
    }
}
#[cfg(test)]
impl TryFrom<&Operand> for Immediate {
    type Error = ();

    /// Extracts the immediate from a constant operand; fails for other payloads or for
    /// constants that are not [`Immediate`]s.
    fn try_from(operand: &Operand) -> Result<Self, Self::Error> {
        let OperandType::Const(attr) = &operand.operand else {
            return Err(());
        };
        attr.downcast_ref::<Immediate>().copied().ok_or(())
    }
}
#[cfg(test)]
impl TryFrom<&Operand> for Type {
    type Error = ();

    /// Clones the type out of a bare-type operand; fails for constants and values.
    fn try_from(operand: &Operand) -> Result<Self, Self::Error> {
        match &operand.operand {
            OperandType::Type(ty) => Ok(ty.clone()),
            OperandType::Const(_) | OperandType::Value(_) => Err(()),
        }
    }
}
#[cfg(test)]
impl TryFrom<Operand> for Type {
    type Error = ();

    /// Moves the type out of a bare-type operand; fails for constants and values.
    fn try_from(operand: Operand) -> Result<Self, Self::Error> {
        if let OperandType::Type(ty) = operand.operand {
            Ok(ty)
        } else {
            Err(())
        }
    }
}
impl From<Type> for Operand {
#[inline]
fn from(ty: Type) -> Self {
Self::new(OperandType::Type(ty))
}
}
impl From<ValueRef> for Operand {
#[inline]
fn from(value: ValueRef) -> Self {
Self::new(OperandType::Value(value))
}
}
impl Operand {
pub fn new(operand: OperandType) -> Self {
let ty = operand.ty();
let mut word = ty.to_raw_parts().expect("invalid operand type");
assert!(!word.is_empty(), "invalid operand: must be a sized type");
assert!(word.len() <= 4, "invalid operand: must be smaller than or equal to a word");
if word.len() > 1 {
word.reverse();
}
Self { word, operand }
}
pub fn size(&self) -> usize {
self.word.len()
}
#[inline(always)]
pub fn value(&self) -> &OperandType {
&self.operand
}
#[inline]
pub fn as_value(&self) -> Option<ValueRef> {
self.try_into().ok()
}
#[inline]
pub fn ty(&self) -> Type {
self.operand.ty()
}
}
/// A model of the operand stack, tracking one [`Operand`] per entry.
///
/// The last element of `stack` is the top of the stack; the public index-taking methods
/// count from the top, starting at 0.
#[derive(Clone, PartialEq, Eq)]
pub struct OperandStack {
stack: Vec<Operand>,
}
impl Default for OperandStack {
    /// Creates an empty stack with capacity for 16 operands preallocated.
    fn default() -> Self {
        let stack = Vec::with_capacity(16);
        Self { stack }
    }
}
impl OperandStack {
    /// Replaces the operand `n` positions from the top of the stack (0 = top) with `value`.
    ///
    /// Only the operand payload is replaced; the cached raw element layout is left untouched,
    /// so the renamed operand keeps the old operand's [`Operand::size`]. Renaming to a value
    /// whose type spans a different number of stack elements would leave `raw_len` and the
    /// `effective_index*` results stale.
    ///
    /// Unlike indexing, this does not enforce the 16-element reachability limit.
    ///
    /// # Panics
    ///
    /// Panics if `n` is out of bounds.
    pub fn rename(&mut self, n: usize, value: ValueRef) {
        let len = self.stack.len();
        assert!(n < len, "invalid operand stack index ({n}), only {len} operands are available");
        let index = len - n - 1;
        match &mut self.stack[index].operand {
            OperandType::Value(prev_value) => {
                *prev_value = value;
            }
            prev => {
                *prev = OperandType::Value(value);
            }
        }
    }

    /// Returns the position from the top (0 = top) of the topmost operand referring to
    /// `value`, or `None` if no operand does.
    pub fn find(&self, value: &ValueRef) -> Option<usize> {
        self.stack.iter().rev().position(|v| v == value)
    }

    /// Returns `true` if there are no operands on the stack.
    #[allow(unused)]
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }

    /// Returns the total number of raw stack elements occupied by all operands.
    #[inline]
    pub fn raw_len(&self) -> usize {
        self.stack.iter().map(|operand| operand.size()).sum()
    }

    /// Returns the raw element position of the topmost element of the operand at `index`,
    /// i.e. the number of raw elements occupied by the `index` operands above it.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[track_caller]
    pub fn effective_index(&self, index: usize) -> usize {
        assert!(
            index < self.stack.len(),
            "expected {} to be less than {}",
            index,
            self.stack.len()
        );
        self.stack.iter().rev().take(index).map(|o| o.size()).sum()
    }

    /// Returns the raw element position of the bottommost element of the operand at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds.
    #[track_caller]
    pub fn effective_index_inclusive(&self, index: usize) -> usize {
        assert!(
            index < self.stack.len(),
            "expected {} to be less than {}",
            index,
            self.stack.len()
        );
        // Every operand occupies at least one element, so the sum is at least 1 here.
        self.stack.iter().rev().take(index + 1).map(|o| o.size()).sum::<usize>() - 1
    }

    /// Returns the number of operands (not raw elements) on the stack.
    #[inline]
    pub fn len(&self) -> usize {
        self.stack.len()
    }

    /// Returns the operand `index` positions from the top of the stack (0 = top), or `None`
    /// if `index` is out of bounds.
    ///
    /// # Panics
    ///
    /// Panics if `index` is in bounds but reaching that operand would require access to more
    /// than 16 stack elements.
    #[inline]
    pub fn get(&self, index: usize) -> Option<&Operand> {
        let len = self.stack.len();
        // Bounds-check first: an out-of-range index yields `None` instead of tripping the
        // reachability assertion below (or overflowing `index + 1` for `usize::MAX`).
        if index >= len {
            return None;
        }
        let effective_len: usize = self.stack.iter().rev().take(index + 1).map(|o| o.size()).sum();
        assert!(
            effective_len <= 16,
            "invalid operand stack index ({index}): requires access to more than 16 elements, \
             which is not supported in Miden"
        );
        self.stack.get(len - index - 1)
    }

    /// Returns the operand on top of the stack, if any.
    #[inline]
    pub fn peek(&self) -> Option<&Operand> {
        self.stack.last()
    }

    /// Pushes `value` on top of the stack.
    #[inline]
    pub fn push<V: Into<Operand>>(&mut self, value: V) {
        self.stack.push(value.into());
    }

    /// Removes and returns the top operand, or `None` if the stack is empty.
    #[inline]
    #[track_caller]
    pub fn pop(&mut self) -> Option<Operand> {
        self.stack.pop()
    }

    /// Discards the top operand.
    ///
    /// # Panics
    ///
    /// Panics if the stack is empty.
    #[allow(clippy::should_implement_trait)]
    #[track_caller]
    pub fn drop(&mut self) {
        self.stack.pop().expect("operand stack is empty");
    }

    /// Discards the top `n` operands.
    ///
    /// # Panics
    ///
    /// Panics if fewer than `n` operands are on the stack.
    #[inline]
    #[track_caller]
    pub fn dropn(&mut self, n: usize) {
        let len = self.stack.len();
        assert!(n <= len, "unable to drop {n} operands, operand stack only has {len}");
        self.stack.truncate(len - n);
    }

    /// Pushes a copy of the operand `n` positions from the top.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as indexing: `n` out of bounds, or more than 16
    /// stack elements needed to reach it.
    #[track_caller]
    pub fn dup(&mut self, n: usize) {
        let operand = self[n].clone();
        self.stack.push(operand);
    }

    /// Swaps the top operand with the operand `n` positions below it.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero or out of bounds.
    #[track_caller]
    pub fn swap(&mut self, n: usize) {
        assert_ne!(n, 0, "invalid swap, index must be in the range 1..=15");
        let len = self.stack.len();
        assert!(n < len, "invalid operand stack index ({n}), only {len} operands are available");
        let a = len - 1;
        let b = a - n;
        self.stack.swap(a, b);
    }

    /// Moves the operand `n` positions from the top to the top, shifting the operands that
    /// were above it down by one.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero or out of bounds.
    pub fn movup(&mut self, n: usize) {
        assert_ne!(n, 0, "invalid move, index must be in the range 1..=15");
        let len = self.stack.len();
        assert!(n < len, "invalid operand stack index ({n}), only {len} operands are available");
        let mid = len - (n + 1);
        let (_, r) = self.stack.split_at_mut(mid);
        r.rotate_left(1);
    }

    /// Moves the top operand down to position `n`, shifting the operands that were above
    /// position `n` up by one.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero or out of bounds.
    pub fn movdn(&mut self, n: usize) {
        assert_ne!(n, 0, "invalid move, index must be in the range 1..=15");
        let len = self.stack.len();
        assert!(n < len, "invalid operand stack index ({n}), only {len} operands are available");
        let mid = len - (n + 1);
        let (_, r) = self.stack.split_at_mut(mid);
        r.rotate_right(1);
    }

    /// Iterates over the operands from the bottom of the stack to the top.
    #[allow(unused)]
    #[inline(always)]
    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &Operand> {
        self.stack.iter()
    }
}
impl Index<usize> for OperandStack {
type Output = Operand;
#[track_caller]
fn index(&self, index: usize) -> &Self::Output {
let len = self.stack.len();
assert!(
index < len,
"invalid operand stack index ({index}): only {len} operands are available"
);
let effective_len: usize = self.stack.iter().rev().take(index + 1).map(|o| o.size()).sum();
assert!(
effective_len <= 16,
"invalid operand stack index ({index}): requires access to more than 16 elements, \
which is not supported in Miden"
);
&self.stack[len - index - 1]
}
}
impl IndexMut<usize> for OperandStack {
    /// Returns a mutable reference to the operand `index` positions from the top (0 = top).
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds, or if reaching the operand would require access
    /// to more than 16 stack elements.
    #[track_caller]
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let len = self.stack.len();
        // `len` counts operands, not raw elements, so the message says "operands" (the same
        // wording the immutable `Index` impl uses).
        assert!(
            index < len,
            "invalid operand stack index ({index}): only {len} operands are available"
        );
        let effective_len: usize = self.stack.iter().rev().take(index + 1).map(|o| o.size()).sum();
        assert!(
            effective_len <= 16,
            "invalid operand stack index ({index}): requires access to more than 16 elements, \
             which is not supported in Miden"
        );
        &mut self.stack[len - index - 1]
    }
}
impl fmt::Debug for OperandStack {
    /// Formats the stack as a list from top to bottom, each entry rendered as
    /// `index => operand`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut builder = f.debug_list();
        self.stack.iter().rev().enumerate().for_each(|(index, value)| {
            builder.entry_with(|f| write!(f, "{index} => {value:?}"));
        });
        builder.finish()
    }
}
// Unit tests for `OperandStack`. In every assertion, index 0 is the top of the stack.
#[cfg(test)]
mod tests {
use alloc::rc::Rc;
use midenc_hir::{ArrayType, BuilderExt, Context, PointerType, StructType};
use super::*;
/// Exercises each stack operation on `u32` immediates, checking the full stack contents
/// after every step.
#[test]
fn operand_stack_homogenous_operand_sizes_test() {
let mut stack = OperandStack::default();
let zero = Immediate::U32(0);
let one = Immediate::U32(1);
let two = Immediate::U32(2);
let three = Immediate::U32(3);
// Converts a word of constant operands into immediates (currently unused).
#[inline]
#[allow(unused)]
fn as_imms(word: [Operand; 4]) -> [Immediate; 4] {
[
(&word[0]).try_into().unwrap(),
(&word[1]).try_into().unwrap(),
(&word[2]).try_into().unwrap(),
(&word[3]).try_into().unwrap(),
]
}
// Extracts the immediate from a constant operand; panics otherwise.
#[inline]
fn as_imm(operand: Operand) -> Immediate {
(&operand).try_into().unwrap()
}
stack.push(zero);
stack.push(one);
stack.push(two);
stack.push(three);
assert_eq!(stack.len(), 4);
assert_eq!(stack[0], three);
assert_eq!(stack[1], two);
assert_eq!(stack[2], one);
assert_eq!(stack[3], zero);
assert_eq!(stack.peek().unwrap(), three);
// dup(0) pushes a copy of the top operand.
stack.dup(0);
assert_eq!(stack.len(), 5);
assert_eq!(stack[0], three);
assert_eq!(stack[1], three);
assert_eq!(stack[2], two);
assert_eq!(stack[3], one);
assert_eq!(stack[4], zero);
// dup(3) copies the operand at index 3 (`one`) to the top.
stack.dup(3);
assert_eq!(stack.len(), 6);
assert_eq!(stack[0], one);
assert_eq!(stack[1], three);
assert_eq!(stack[2], three);
assert_eq!(stack[3], two);
assert_eq!(stack[4], one);
assert_eq!(stack[5], zero);
stack.drop();
assert_eq!(stack.len(), 5);
assert_eq!(stack[0], three);
assert_eq!(stack[1], three);
assert_eq!(stack[2], two);
assert_eq!(stack[3], one);
assert_eq!(stack[4], zero);
// swap(2) exchanges the top operand with the one at index 2.
stack.swap(2);
assert_eq!(stack.len(), 5);
assert_eq!(stack[0], two);
assert_eq!(stack[1], three);
assert_eq!(stack[2], three);
assert_eq!(stack[3], one);
assert_eq!(stack[4], zero);
stack.swap(1);
assert_eq!(stack.len(), 5);
assert_eq!(stack[0], three);
assert_eq!(stack[1], two);
assert_eq!(stack[2], three);
assert_eq!(stack[3], one);
assert_eq!(stack[4], zero);
// movup(2) brings the operand at index 2 to the top.
stack.movup(2);
assert_eq!(stack.len(), 5);
assert_eq!(stack[0], three);
assert_eq!(stack[1], three);
assert_eq!(stack[2], two);
assert_eq!(stack[3], one);
assert_eq!(stack[4], zero);
// movdn(3) sinks the top operand down to index 3.
stack.movdn(3);
assert_eq!(stack.len(), 5);
assert_eq!(stack[0], three);
assert_eq!(stack[1], two);
assert_eq!(stack[2], one);
assert_eq!(stack[3], three);
assert_eq!(stack[4], zero);
assert_eq!(stack.pop().map(as_imm), Some(three));
assert_eq!(stack.len(), 4);
assert_eq!(stack[0], two);
assert_eq!(stack[1], one);
assert_eq!(stack[2], three);
assert_eq!(stack[3], zero);
stack.dropn(2);
assert_eq!(stack.len(), 2);
assert_eq!(stack[0], three);
assert_eq!(stack[1], zero);
}
/// Exercises stack operations on IR values (block arguments), tracking their positions
/// with `find`, plus `rename` of a constant to a newly created value.
#[test]
fn operand_stack_values_test() {
use midenc_dialect_hir::Load;
let mut stack = OperandStack::default();
let context = Rc::new(Context::default());
let ptr_u8 = Type::from(PointerType::new(Type::U8));
let array_u8 = Type::from(ArrayType::new(Type::U8, 4));
let struct_ty = Type::from(StructType::new([Type::U64, Type::U8]));
let block = context.create_block_with_params([ptr_u8, array_u8, Type::U32, struct_ty]);
let block = block.borrow();
let values = block.arguments();
let zero = values[0] as ValueRef;
let one = values[1] as ValueRef;
let two = values[2] as ValueRef;
let three = values[3] as ValueRef;
// Release the borrow guard on `block` taken above.
drop(block);
stack.push(zero);
stack.push(one);
stack.push(two);
stack.push(three);
assert_eq!(stack.len(), 4);
assert_eq!(stack.raw_len(), 6);
assert_eq!(stack.find(&zero), Some(3));
assert_eq!(stack.find(&one), Some(2));
assert_eq!(stack.find(&two), Some(1));
assert_eq!(stack.find(&three), Some(0));
stack.dup(0);
assert_eq!(stack.find(&three), Some(0));
stack.dup(3);
assert_eq!(stack.find(&one), Some(0));
stack.drop();
// `find` reports the topmost match; index 1 holds the duplicate of `three`.
assert_eq!(stack.find(&one), Some(3));
assert_eq!(stack.find(&three), Some(0));
assert_eq!(stack[1], three);
stack.push(Immediate::Felt(Felt::ZERO));
stack.push(Immediate::Felt(Felt::ZERO));
stack.push(Immediate::Felt(Felt::ZERO));
stack.push(Immediate::Felt(Felt::ZERO));
assert_eq!(stack.find(&one), Some(7));
assert_eq!(stack.find(&three), Some(4));
// Create a new IR value (the result of a `Load` from `zero`) to rename an operand to.
let four = {
let mut builder = midenc_hir::OpBuilder::new(context.clone());
let load_builder = builder.create::<Load, _>(Default::default());
let load = load_builder(zero).unwrap();
load.borrow().result().as_value_ref()
};
// Replace the constant at index 1 with `four`.
stack.rename(1, four);
assert_eq!(stack.find(&four), Some(1));
assert_eq!(stack.find(&three), Some(4));
let top = stack.pop().unwrap();
assert_eq!((&top).try_into(), Ok(Immediate::Felt(Felt::ZERO)));
assert_eq!(stack.find(&four), Some(0));
assert_eq!(stack[1], Immediate::Felt(Felt::ZERO));
assert_eq!(stack[2], Immediate::Felt(Felt::ZERO));
assert_eq!(stack.find(&three), Some(3));
stack.dropn(3);
assert_eq!(stack.find(&four), None);
assert_eq!(stack.find(&three), Some(0));
assert_eq!(stack[1], three);
assert_eq!(stack.find(&two), Some(2));
assert_eq!(stack.find(&one), Some(3));
assert_eq!(stack.find(&zero), Some(4));
stack.swap(3);
assert_eq!(stack.find(&one), Some(0));
assert_eq!(stack.find(&three), Some(1));
assert_eq!(stack.find(&two), Some(2));
assert_eq!(stack[3], three);
stack.swap(1);
assert_eq!(stack.find(&three), Some(0));
assert_eq!(stack.find(&one), Some(1));
assert_eq!(stack.find(&two), Some(2));
assert_eq!(stack.find(&zero), Some(4));
stack.movup(2);
assert_eq!(stack.find(&two), Some(0));
assert_eq!(stack.find(&three), Some(1));
assert_eq!(stack.find(&one), Some(2));
assert_eq!(stack.find(&zero), Some(4));
stack.movdn(3);
assert_eq!(stack.find(&three), Some(0));
assert_eq!(stack.find(&one), Some(1));
assert_eq!(stack[2], three);
assert_eq!(stack.find(&two), Some(3));
assert_eq!(stack.find(&zero), Some(4));
}
/// Exercises stack operations on operands whose types occupy differing numbers of raw
/// stack elements, checking `raw_len` alongside operand positions.
#[test]
fn operand_stack_heterogenous_operand_sizes_test() {
let mut stack = OperandStack::default();
let zero = Immediate::U32(0);
let one = Immediate::U32(1);
let two = Type::U64;
let three = Type::U64;
let struct_a = Type::from(StructType::new([
Type::from(PointerType::new(Type::U8)),
Type::U16,
Type::U32,
]));
stack.push(zero);
stack.push(one);
stack.push(two.clone());
stack.push(three.clone());
stack.push(struct_a.clone());
assert_eq!(stack.len(), 5);
assert_eq!(stack.raw_len(), 9);
assert_eq!(stack[0], struct_a);
assert_eq!(stack[1], three);
assert_eq!(stack[2], two);
assert_eq!(stack[3], one);
assert_eq!(stack[4], zero);
assert_eq!(stack.peek().unwrap(), struct_a);
stack.dup(0);
assert_eq!(stack.len(), 6);
assert_eq!(stack.raw_len(), 12);
assert_eq!(stack[0], struct_a);
assert_eq!(stack[1], struct_a);
assert_eq!(stack[2], three);
assert_eq!(stack[3], two);
assert_eq!(stack[4], one);
assert_eq!(stack[5], zero);
// Counts the raw elements of the three operands above index 3 (two copies of
// `struct_a` and `three`).
assert_eq!(stack.effective_index(3), 8);
stack.dup(3);
assert_eq!(stack.len(), 7);
assert_eq!(stack.raw_len(), 14);
assert_eq!(stack[0], two);
assert_eq!(stack[1], struct_a);
assert_eq!(stack[2], struct_a);
stack.drop();
assert_eq!(stack.len(), 6);
assert_eq!(stack.raw_len(), 12);
assert_eq!(stack[0], struct_a);
assert_eq!(stack[1], struct_a);
assert_eq!(stack[2], three);
assert_eq!(stack[3], two);
assert_eq!(stack[4], one);
assert_eq!(stack[5], zero);
stack.swap(2);
assert_eq!(stack.len(), 6);
assert_eq!(stack.raw_len(), 12);
assert_eq!(stack[0], three);
assert_eq!(stack[1], struct_a);
assert_eq!(stack[2], struct_a);
assert_eq!(stack[3], two);
assert_eq!(stack[4], one);
stack.swap(1);
assert_eq!(stack.len(), 6);
assert_eq!(stack.raw_len(), 12);
assert_eq!(stack[0], struct_a);
assert_eq!(stack[1], three);
assert_eq!(stack[2], struct_a);
assert_eq!(stack[3], two);
assert_eq!(stack[4], one);
assert_eq!(stack[5], zero);
stack.movup(4);
assert_eq!(stack.len(), 6);
assert_eq!(stack.raw_len(), 12);
assert_eq!(stack[0], one);
assert_eq!(stack[1], struct_a);
assert_eq!(stack[2], three);
assert_eq!(stack[3], struct_a);
assert_eq!(stack[4], two);
assert_eq!(stack[5], zero);
stack.movdn(3);
assert_eq!(stack.len(), 6);
assert_eq!(stack.raw_len(), 12);
assert_eq!(stack[0], struct_a);
assert_eq!(stack[1], three);
assert_eq!(stack[2], struct_a);
assert_eq!(stack[3], one);
assert_eq!(stack[4], two);
// Popping yields the `struct_a` operand, which converts back into its type.
let operand: Type = stack.pop().unwrap().try_into().unwrap();
assert_eq!(operand, struct_a);
assert_eq!(stack.len(), 5);
assert_eq!(stack.raw_len(), 9);
assert_eq!(stack[0], three);
assert_eq!(stack[1], struct_a);
assert_eq!(stack[2], one);
assert_eq!(stack[3], two);
stack.dropn(2);
assert_eq!(stack.len(), 3);
assert_eq!(stack.raw_len(), 4);
assert_eq!(stack[0], one);
assert_eq!(stack[1], two);
assert_eq!(stack[2], zero);
}
}