use core::{
fmt::{Display, Formatter},
ops::{Add, AddAssign, Bound, Index, IndexMut, Mul, RangeBounds, Sub, SubAssign},
};
use vm_core::Felt;
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum RowIndexError<T> {
#[error("value is too large to be converted into RowIndex: {0}")]
InvalidSize(T),
}
#[derive(Debug, Copy, Clone, Eq, Ord, PartialOrd)]
pub struct RowIndex(u32);
impl RowIndex {
pub fn as_usize(&self) -> usize {
self.0 as usize
}
pub fn as_u32(&self) -> u32 {
self.0
}
}
impl Display for RowIndex {
fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<RowIndex> for u32 {
fn from(step: RowIndex) -> u32 {
step.0
}
}
impl From<RowIndex> for u64 {
fn from(step: RowIndex) -> u64 {
step.0 as u64
}
}
impl From<RowIndex> for usize {
fn from(step: RowIndex) -> usize {
step.0 as usize
}
}
impl From<RowIndex> for Felt {
fn from(step: RowIndex) -> Felt {
Felt::from(step.0)
}
}
impl From<usize> for RowIndex {
fn from(value: usize) -> Self {
let value = u32::try_from(value).map_err(|_| RowIndexError::InvalidSize(value)).unwrap();
value.into()
}
}
impl TryFrom<u64> for RowIndex {
type Error = RowIndexError<u64>;
fn try_from(value: u64) -> Result<Self, Self::Error> {
let value = u32::try_from(value).map_err(|_| RowIndexError::InvalidSize(value))?;
Ok(RowIndex::from(value))
}
}
impl From<u32> for RowIndex {
fn from(value: u32) -> Self {
Self(value)
}
}
impl From<i32> for RowIndex {
fn from(value: i32) -> Self {
let value = u32::try_from(value).map_err(|_| RowIndexError::InvalidSize(value)).unwrap();
RowIndex(value)
}
}
impl Sub<usize> for RowIndex {
type Output = RowIndex;
fn sub(self, rhs: usize) -> Self::Output {
let rhs = u32::try_from(rhs).map_err(|_| RowIndexError::InvalidSize(rhs)).unwrap();
RowIndex(self.0 - rhs)
}
}
impl SubAssign<u32> for RowIndex {
fn sub_assign(&mut self, rhs: u32) {
self.0 -= rhs;
}
}
impl Sub<RowIndex> for RowIndex {
type Output = usize;
fn sub(self, rhs: RowIndex) -> Self::Output {
(self.0 - rhs.0) as usize
}
}
impl RowIndex {
pub fn saturating_sub(self, rhs: u32) -> Self {
RowIndex(self.0.saturating_sub(rhs))
}
pub fn max(self, other: RowIndex) -> Self {
RowIndex(self.0.max(other.0))
}
}
impl Add<usize> for RowIndex {
type Output = RowIndex;
fn add(self, rhs: usize) -> Self::Output {
let rhs = u32::try_from(rhs).map_err(|_| RowIndexError::InvalidSize(rhs)).unwrap();
RowIndex(self.0 + rhs)
}
}
impl Add<RowIndex> for u32 {
type Output = RowIndex;
fn add(self, rhs: RowIndex) -> Self::Output {
RowIndex(self + rhs.0)
}
}
impl AddAssign<usize> for RowIndex {
fn add_assign(&mut self, rhs: usize) {
let rhs: RowIndex = rhs.into();
self.0 += rhs.0;
}
}
impl Mul<RowIndex> for usize {
type Output = RowIndex;
fn mul(self, rhs: RowIndex) -> Self::Output {
(self * rhs.0 as usize).into()
}
}
impl PartialEq<RowIndex> for RowIndex {
fn eq(&self, rhs: &RowIndex) -> bool {
self.0 == rhs.0
}
}
impl PartialEq<usize> for RowIndex {
fn eq(&self, rhs: &usize) -> bool {
self.0 == u32::try_from(*rhs).map_err(|_| RowIndexError::InvalidSize(*rhs)).unwrap()
}
}
impl PartialEq<RowIndex> for i32 {
fn eq(&self, rhs: &RowIndex) -> bool {
*self as u32 == u32::from(*rhs)
}
}
impl PartialOrd<usize> for RowIndex {
fn partial_cmp(&self, rhs: &usize) -> Option<core::cmp::Ordering> {
let rhs = u32::try_from(*rhs).map_err(|_| RowIndexError::InvalidSize(*rhs)).unwrap();
self.0.partial_cmp(&rhs)
}
}
impl<T> Index<RowIndex> for [T] {
type Output = T;
fn index(&self, i: RowIndex) -> &Self::Output {
&self[i.0 as usize]
}
}
impl<T> IndexMut<RowIndex> for [T] {
fn index_mut(&mut self, i: RowIndex) -> &mut Self::Output {
&mut self[i.0 as usize]
}
}
impl RangeBounds<RowIndex> for RowIndex {
fn start_bound(&self) -> Bound<&Self> {
Bound::Included(self)
}
fn end_bound(&self) -> Bound<&Self> {
Bound::Included(self)
}
}
#[cfg(test)]
mod tests {
use alloc::collections::BTreeMap;
#[test]
fn row_index_conversions() {
use super::RowIndex;
let _: RowIndex = 5.into();
let _: RowIndex = 5u32.into();
let _: RowIndex = (5usize).into();
let _: u32 = RowIndex(5).into();
let _: u64 = RowIndex(5).into();
let _: usize = RowIndex(5).into();
}
#[test]
fn row_index_ops() {
use super::RowIndex;
assert_eq!(RowIndex(5), 5);
assert_eq!(RowIndex(5), RowIndex(5));
assert!(RowIndex(5) == RowIndex(5));
assert!(RowIndex(5) >= RowIndex(5));
assert!(RowIndex(6) >= RowIndex(5));
assert!(RowIndex(5) > RowIndex(4));
assert!(RowIndex(5) <= RowIndex(5));
assert!(RowIndex(4) <= RowIndex(5));
assert!(RowIndex(5) < RowIndex(6));
assert_eq!(RowIndex(5) + 3, 8);
assert_eq!(RowIndex(5) - 3, 2);
assert_eq!(3 + RowIndex(5), 8);
assert_eq!(2 * RowIndex(5), 10);
let mut step = RowIndex(5);
step += 5;
assert_eq!(step, 10);
}
#[test]
fn row_index_range() {
use super::RowIndex;
let mut tree: BTreeMap<RowIndex, usize> = BTreeMap::new();
tree.insert(RowIndex(0), 0);
tree.insert(RowIndex(1), 1);
tree.insert(RowIndex(2), 2);
let acc =
tree.range(RowIndex::from(0)..RowIndex::from(tree.len()))
.fold(0, |acc, (key, val)| {
assert_eq!(*key, RowIndex::from(acc));
assert_eq!(*val, acc);
acc + 1
});
assert_eq!(acc, 3);
}
#[test]
fn row_index_display() {
assert_eq!(format!("{}", super::RowIndex(5)), "5");
}
}