1use alloc::boxed::Box;
2use core::{
3 fmt::{Display, Formatter},
4 ops::{Add, AddAssign, Bound, Index, IndexMut, Mul, RangeBounds, Sub, SubAssign},
5};
6
7use vm_core::Felt;
8
9#[derive(Debug, thiserror::Error)]
12pub enum RowIndexError {
13 #[error("value {0} is larger than u32::MAX so it cannot be converted into a RowIndex")]
15 InvalidSize(Box<str>),
16}
17
18#[derive(Debug, Copy, Clone, Eq, Ord, PartialOrd)]
23pub struct RowIndex(u32);
24
25impl RowIndex {
26 pub fn as_usize(&self) -> usize {
27 self.0 as usize
28 }
29
30 pub fn as_u32(&self) -> u32 {
31 self.0
32 }
33}
34
35impl Display for RowIndex {
36 fn fmt(&self, f: &mut Formatter) -> core::fmt::Result {
37 write!(f, "{}", self.0)
38 }
39}
40
41impl From<RowIndex> for u32 {
45 fn from(step: RowIndex) -> u32 {
46 step.0
47 }
48}
49
50impl From<RowIndex> for u64 {
51 fn from(step: RowIndex) -> u64 {
52 step.0 as u64
53 }
54}
55
56impl From<RowIndex> for usize {
57 fn from(step: RowIndex) -> usize {
58 step.0 as usize
59 }
60}
61
62impl From<RowIndex> for Felt {
63 fn from(step: RowIndex) -> Felt {
64 Felt::from(step.0)
65 }
66}
67
68impl From<usize> for RowIndex {
78 fn from(value: usize) -> Self {
79 let value = u32::try_from(value)
80 .map_err(|_| RowIndexError::InvalidSize(format!("{}_usize", value).into()))
81 .unwrap();
82 value.into()
83 }
84}
85
86impl TryFrom<u64> for RowIndex {
93 type Error = RowIndexError;
94
95 fn try_from(value: u64) -> Result<Self, Self::Error> {
96 let value = u32::try_from(value)
97 .map_err(|_| RowIndexError::InvalidSize(format!("{}_u64", value).into()))?;
98 Ok(RowIndex::from(value))
99 }
100}
101
102impl From<u32> for RowIndex {
103 fn from(value: u32) -> Self {
104 Self(value)
105 }
106}
107
108impl From<i32> for RowIndex {
114 fn from(value: i32) -> Self {
115 let value = u32::try_from(value)
116 .map_err(|_| RowIndexError::InvalidSize(format!("{}_i32", value).into()))
117 .unwrap();
118 RowIndex(value)
119 }
120}
121
122impl Sub<usize> for RowIndex {
132 type Output = RowIndex;
133
134 fn sub(self, rhs: usize) -> Self::Output {
135 let rhs = u32::try_from(rhs)
136 .map_err(|_| RowIndexError::InvalidSize(format!("{}_usize", rhs).into()))
137 .unwrap();
138 RowIndex(self.0 - rhs)
139 }
140}
141
142impl SubAssign<u32> for RowIndex {
143 fn sub_assign(&mut self, rhs: u32) {
144 self.0 -= rhs;
145 }
146}
147
148impl Sub<RowIndex> for RowIndex {
149 type Output = usize;
150
151 fn sub(self, rhs: RowIndex) -> Self::Output {
152 (self.0 - rhs.0) as usize
153 }
154}
155
156impl RowIndex {
157 pub fn saturating_sub(self, rhs: u32) -> Self {
158 RowIndex(self.0.saturating_sub(rhs))
159 }
160
161 pub fn max(self, other: RowIndex) -> Self {
162 RowIndex(self.0.max(other.0))
163 }
164}
165
166impl Add<usize> for RowIndex {
173 type Output = RowIndex;
174
175 fn add(self, rhs: usize) -> Self::Output {
176 let rhs = u32::try_from(rhs)
177 .map_err(|_| RowIndexError::InvalidSize(format!("{}_usize", rhs).into()))
178 .unwrap();
179 RowIndex(self.0 + rhs)
180 }
181}
182
183impl Add<RowIndex> for u32 {
184 type Output = RowIndex;
185
186 fn add(self, rhs: RowIndex) -> Self::Output {
187 RowIndex(self + rhs.0)
188 }
189}
190
191impl AddAssign<usize> for RowIndex {
198 fn add_assign(&mut self, rhs: usize) {
199 let rhs: RowIndex = rhs.into();
200 self.0 += rhs.0;
201 }
202}
203
204impl Mul<RowIndex> for usize {
205 type Output = RowIndex;
206
207 fn mul(self, rhs: RowIndex) -> Self::Output {
208 (self * rhs.0 as usize).into()
209 }
210}
211
212impl PartialEq<RowIndex> for RowIndex {
216 fn eq(&self, rhs: &RowIndex) -> bool {
217 self.0 == rhs.0
218 }
219}
220
221impl PartialEq<usize> for RowIndex {
222 fn eq(&self, rhs: &usize) -> bool {
223 self.0
224 == u32::try_from(*rhs)
225 .map_err(|_| RowIndexError::InvalidSize(format!("{}_usize", *rhs).into()))
226 .unwrap()
227 }
228}
229
230impl PartialEq<RowIndex> for i32 {
231 fn eq(&self, rhs: &RowIndex) -> bool {
232 *self as u32 == u32::from(*rhs)
233 }
234}
235
236impl PartialOrd<usize> for RowIndex {
237 fn partial_cmp(&self, rhs: &usize) -> Option<core::cmp::Ordering> {
238 let rhs = u32::try_from(*rhs)
239 .map_err(|_| RowIndexError::InvalidSize(format!("{}_usize", *rhs).into()))
240 .unwrap();
241 self.0.partial_cmp(&rhs)
242 }
243}
244
245impl<T> Index<RowIndex> for [T] {
246 type Output = T;
247 fn index(&self, i: RowIndex) -> &Self::Output {
248 &self[i.0 as usize]
249 }
250}
251
252impl<T> IndexMut<RowIndex> for [T] {
253 fn index_mut(&mut self, i: RowIndex) -> &mut Self::Output {
254 &mut self[i.0 as usize]
255 }
256}
257
258impl RangeBounds<RowIndex> for RowIndex {
259 fn start_bound(&self) -> Bound<&Self> {
260 Bound::Included(self)
261 }
262 fn end_bound(&self) -> Bound<&Self> {
263 Bound::Included(self)
264 }
265}
266
#[cfg(test)]
mod tests {
    use alloc::collections::BTreeMap;

    use super::RowIndex;

    #[test]
    fn row_index_conversions() {
        // Into `RowIndex` from each supported integer type.
        let _: RowIndex = 5.into();
        let _: RowIndex = 5u32.into();
        let _: RowIndex = (5usize).into();

        // Out of `RowIndex` into each supported integer type.
        let _: u32 = RowIndex(5).into();
        let _: u64 = RowIndex(5).into();
        let _: usize = RowIndex(5).into();
    }

    #[test]
    fn row_index_ops() {
        // Equality and ordering.
        assert_eq!(RowIndex(5), 5);
        assert_eq!(RowIndex(5), RowIndex(5));
        assert!(RowIndex(5) == RowIndex(5));
        assert!(RowIndex(5) >= RowIndex(5));
        assert!(RowIndex(6) >= RowIndex(5));
        assert!(RowIndex(5) > RowIndex(4));
        assert!(RowIndex(5) <= RowIndex(5));
        assert!(RowIndex(4) <= RowIndex(5));
        assert!(RowIndex(5) < RowIndex(6));

        // Arithmetic.
        assert_eq!(RowIndex(5) + 3, 8);
        assert_eq!(RowIndex(5) - 3, 2);
        assert_eq!(3 + RowIndex(5), 8);
        assert_eq!(2 * RowIndex(5), 10);

        // Compound assignment.
        let mut row = RowIndex(5);
        row += 5;
        assert_eq!(row, 10);
    }

    #[test]
    fn row_index_range() {
        let mut tree: BTreeMap<RowIndex, usize> = BTreeMap::new();
        for i in 0..3u32 {
            tree.insert(RowIndex(i), i as usize);
        }

        let upper = RowIndex::from(tree.len());
        let mut visited = 0usize;
        for (key, val) in tree.range(RowIndex::from(0)..upper) {
            assert_eq!(*key, RowIndex::from(visited));
            assert_eq!(*val, visited);
            visited += 1;
        }
        assert_eq!(visited, 3);
    }

    #[test]
    fn row_index_display() {
        assert_eq!(format!("{}", RowIndex(5)), "5");
    }
}