Skip to main content

miden_utils_indexing/
csr.rs

1//! Compressed Sparse Row (CSR) matrix for efficient sparse data storage.
2//!
3//! This module provides a generic [`CsrMatrix`] type that maps row indices to variable-length
4//! data. It's commonly used for storing decorator IDs, assembly operation IDs, and similar
5//! sparse mappings in the Miden VM.
6
7use alloc::vec::Vec;
8
9use miden_crypto::utils::{
10    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
11};
12#[cfg(feature = "arbitrary")]
13use proptest::prelude::*;
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16use thiserror::Error;
17
18use crate::{Idx, IndexVec, IndexedVecError};
19
20// CSR MATRIX
21// ================================================================================================
22
23/// Compressed Sparse Row matrix mapping row indices to variable-length data.
24///
25/// For row `i`, its data is at `data[indptr[i]..indptr[i+1]]`.
26///
27/// # Type Parameters
28///
29/// - `I`: The row index type, must implement [`Idx`].
30/// - `D`: The data type stored in each row.
31///
32/// # Example
33///
34/// ```ignore
35/// use miden_utils_indexing::{CsrMatrix, newtype_id};
36///
37/// newtype_id!(NodeId);
38///
39/// let mut csr = CsrMatrix::<NodeId, u32>::new();
40/// csr.push_row([1, 2, 3]);       // Row 0: [1, 2, 3]
41/// csr.push_empty_row();          // Row 1: []
42/// csr.push_row([4, 5]);          // Row 2: [4, 5]
43///
44/// assert_eq!(csr.row(NodeId::from(0)), Some(&[1, 2, 3][..]));
45/// assert_eq!(csr.row(NodeId::from(1)), Some(&[][..]));
46/// assert_eq!(csr.row(NodeId::from(2)), Some(&[4, 5][..]));
47/// ```
48#[derive(Debug, Clone, PartialEq, Eq)]
49#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
50#[cfg_attr(
51    all(feature = "arbitrary", test),
52    miden_test_serde_macros::serde_test(binary_serde(true), types(crate::SerdeTestId, u32))
53)]
54pub struct CsrMatrix<I: Idx, D> {
55    /// Flat storage of all data values.
56    data: Vec<D>,
57    /// Row pointers: row i's data is at `data[indptr[i]..indptr[i+1]]`.
58    indptr: IndexVec<I, usize>,
59}
60
#[cfg(feature = "arbitrary")]
impl<I, D> Arbitrary for CsrMatrix<I, D>
where
    I: Idx + 'static,
    D: Arbitrary + 'static,
    D::Strategy: 'static,
{
    type Parameters = D::Parameters;
    type Strategy = BoxedStrategy<Self>;

    /// Generates matrices of fewer than 16 rows, each row holding fewer than 8 arbitrary values.
    ///
    /// Every generated matrix is built through [`CsrMatrix::push_row`], one row at a time.
    fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
        let element = any_with::<D>(args);
        let single_row = proptest::collection::vec(element, 0..8);
        let all_rows = proptest::collection::vec(single_row, 0..16);

        all_rows
            .prop_map(|rows| {
                rows.into_iter().fold(Self::new(), |mut matrix, row| {
                    matrix.push_row(row).expect("generated row count fits in u32");
                    matrix
                })
            })
            .boxed()
    }
}
85
86impl<I: Idx, D> Default for CsrMatrix<I, D> {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl<I: Idx, D> CsrMatrix<I, D> {
93    // CONSTRUCTORS
94    // --------------------------------------------------------------------------------------------
95
96    /// Creates a new empty [`CsrMatrix`].
97    pub fn new() -> Self {
98        Self {
99            data: Vec::new(),
100            indptr: IndexVec::new(),
101        }
102    }
103
104    /// Creates a [`CsrMatrix`] with pre-allocated capacity.
105    ///
106    /// # Arguments
107    ///
108    /// - `num_rows`: Expected number of rows.
109    /// - `num_elements`: Expected total number of data elements across all rows.
110    pub fn with_capacity(num_rows: usize, num_elements: usize) -> Self {
111        Self {
112            data: Vec::with_capacity(num_elements),
113            indptr: IndexVec::with_capacity(num_rows + 1),
114        }
115    }
116
117    // MUTATION
118    // --------------------------------------------------------------------------------------------
119
120    /// Appends a new row with the given data values and returns the index of the newly added row.
121    ///
122    /// Rows must be added in sequential order starting from row 0.
123    ///
124    /// # Errors
125    ///
126    /// Returns an error if the number of rows would exceed `u32::MAX`.
127    pub fn push_row(&mut self, values: impl IntoIterator<Item = D>) -> Result<I, IndexedVecError> {
128        // Initialize indptr with 0 if this is the first row
129        if self.indptr.is_empty() {
130            self.indptr.push(0)?;
131        }
132
133        // The row index is the current number of rows (before adding)
134        let row_idx = self.num_rows();
135
136        // Add data
137        self.data.extend(values);
138
139        // Add end pointer for this row
140        self.indptr.push(self.data.len())?;
141
142        Ok(I::from(row_idx as u32))
143    }
144
145    /// Appends an empty row (no data for this row index).
146    ///
147    /// # Errors
148    ///
149    /// Returns an error if the number of rows would exceed `u32::MAX`.
150    pub fn push_empty_row(&mut self) -> Result<I, IndexedVecError> {
151        self.push_row(core::iter::empty())
152    }
153
154    /// Appends empty rows to fill gaps up to (but not including) `target_row`.
155    ///
156    /// If `target_row` is less than or equal to the current number of rows, this is a no-op.
157    ///
158    /// # Errors
159    ///
160    /// Returns an error if the number of rows would exceed `u32::MAX`.
161    pub fn fill_to_row(&mut self, target_row: I) -> Result<(), IndexedVecError> {
162        let target = target_row.to_usize();
163        while self.num_rows() < target {
164            self.push_empty_row()?;
165        }
166        Ok(())
167    }
168
169    // ACCESSORS
170    // --------------------------------------------------------------------------------------------
171
172    /// Returns `true` if this matrix has no rows.
173    pub fn is_empty(&self) -> bool {
174        self.indptr.is_empty()
175    }
176
177    /// Returns the number of rows in this matrix.
178    pub fn num_rows(&self) -> usize {
179        if self.indptr.is_empty() {
180            0
181        } else {
182            self.indptr.len() - 1
183        }
184    }
185
186    /// Returns the total number of data elements across all rows.
187    pub fn num_elements(&self) -> usize {
188        self.data.len()
189    }
190
191    /// Returns the data slice for the given row, or `None` if the row doesn't exist.
192    pub fn row(&self, row: I) -> Option<&[D]> {
193        let row_idx = row.to_usize();
194        if row_idx >= self.num_rows() {
195            return None;
196        }
197
198        let start = self.indptr[row];
199        let end = self.indptr[I::from((row_idx + 1) as u32)];
200        Some(&self.data[start..end])
201    }
202
203    /// Returns the data slice for the given row, panicking if the row doesn't exist.
204    ///
205    /// # Panics
206    ///
207    /// Panics if `row` is out of bounds.
208    pub fn row_expect(&self, row: I) -> &[D] {
209        self.row(row).expect("row index out of bounds")
210    }
211
212    /// Returns an iterator over all `(row_index, data_slice)` pairs.
213    pub fn iter(&self) -> impl Iterator<Item = (I, &[D])> {
214        (0..self.num_rows()).map(move |i| {
215            let row = I::from(i as u32);
216            (row, self.row_expect(row))
217        })
218    }
219
220    /// Returns an iterator over all data elements with their `(row_index, position_in_row, &data)`.
221    pub fn iter_enumerated(&self) -> impl Iterator<Item = (I, usize, &D)> {
222        self.iter()
223            .flat_map(|(row, data)| data.iter().enumerate().map(move |(pos, d)| (row, pos, d)))
224    }
225
226    /// Returns the underlying data slice.
227    pub fn data(&self) -> &[D] {
228        &self.data
229    }
230
231    /// Returns the underlying indptr.
232    pub fn indptr(&self) -> &IndexVec<I, usize> {
233        &self.indptr
234    }
235
236    // VALIDATION
237    // --------------------------------------------------------------------------------------------
238
239    /// Validates the CSR structural invariants.
240    ///
241    /// Checks:
242    /// - `indptr` starts at 0 (if non-empty)
243    /// - `indptr` is monotonically increasing
244    /// - `indptr` ends at `data.len()`
245    ///
246    /// For domain-specific validation of data values, use [`validate_with`](Self::validate_with).
247    pub fn validate(&self) -> Result<(), CsrValidationError> {
248        self.validate_with(|_| true)
249    }
250
251    /// Validates structural invariants plus domain-specific data constraints.
252    ///
253    /// The callback is invoked for each data element. Return `false` to indicate
254    /// an invalid value.
255    ///
256    /// # Arguments
257    ///
258    /// - `f`: A function that returns `true` if the data value is valid.
259    pub fn validate_with<F>(&self, f: F) -> Result<(), CsrValidationError>
260    where
261        F: Fn(&D) -> bool,
262    {
263        let indptr = self.indptr.as_slice();
264
265        // Empty matrix is valid
266        if indptr.is_empty() {
267            return Ok(());
268        }
269
270        // Check indptr starts at 0
271        if indptr[0] != 0 {
272            return Err(CsrValidationError::IndptrStartNotZero(indptr[0]));
273        }
274
275        // Check indptr is monotonic
276        for i in 1..indptr.len() {
277            if indptr[i - 1] > indptr[i] {
278                return Err(CsrValidationError::IndptrNotMonotonic {
279                    index: i,
280                    prev: indptr[i - 1],
281                    curr: indptr[i],
282                });
283            }
284        }
285
286        // Check indptr ends at data.len()
287        let last = *indptr.last().expect("indptr is non-empty");
288        if last != self.data.len() {
289            return Err(CsrValidationError::IndptrDataMismatch {
290                indptr_end: last,
291                data_len: self.data.len(),
292            });
293        }
294
295        // Validate data values
296        for (row, data) in self.iter() {
297            for (pos, d) in data.iter().enumerate() {
298                if !f(d) {
299                    return Err(CsrValidationError::InvalidData {
300                        row: row.to_usize(),
301                        position: pos,
302                    });
303                }
304            }
305        }
306
307        Ok(())
308    }
309}
310
311// CSR VALIDATION ERROR
312// ================================================================================================
313
314/// Errors that can occur during CSR validation.
315#[derive(Debug, Clone, PartialEq, Eq, Error)]
316pub enum CsrValidationError {
317    /// The indptr array must start at 0.
318    #[error("indptr must start at 0, got {0}")]
319    IndptrStartNotZero(usize),
320
321    /// The indptr array must be monotonically increasing.
322    #[error("indptr not monotonic at index {index}: {prev} > {curr}")]
323    IndptrNotMonotonic { index: usize, prev: usize, curr: usize },
324
325    /// The last indptr value must equal data.len().
326    #[error("indptr ends at {indptr_end}, but data.len() is {data_len}")]
327    IndptrDataMismatch { indptr_end: usize, data_len: usize },
328
329    /// A data value failed domain-specific validation.
330    #[error("invalid data value at row {row}, position {position}")]
331    InvalidData { row: usize, position: usize },
332}
333
334// SERIALIZATION
335// ================================================================================================
336
337impl<I, D> Serializable for CsrMatrix<I, D>
338where
339    I: Idx,
340    D: Serializable,
341{
342    fn write_into<W: ByteWriter>(&self, target: &mut W) {
343        // Write data
344        target.write_usize(self.data.len());
345        for item in &self.data {
346            item.write_into(target);
347        }
348
349        // Write indptr
350        target.write_usize(self.indptr.len());
351        for &ptr in self.indptr.as_slice() {
352            target.write_usize(ptr);
353        }
354    }
355}
356
357impl<I, D> Deserializable for CsrMatrix<I, D>
358where
359    I: Idx,
360    D: Deserializable,
361{
362    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
363        // Read data using read_many_iter for BudgetedReader integration
364        let data_len = source.read_usize()?;
365        let data: Vec<D> = source.read_many_iter(data_len)?.collect::<Result<_, _>>()?;
366
367        // Read indptr using read_many_iter for BudgetedReader integration
368        let indptr_len = source.read_usize()?;
369        let indptr_vec: Vec<usize> =
370            source.read_many_iter(indptr_len)?.collect::<Result<_, _>>()?;
371        let indptr = IndexVec::try_from(indptr_vec).map_err(|_| {
372            DeserializationError::InvalidValue("indptr too large for IndexVec".into())
373        })?;
374
375        Ok(Self { data, indptr })
376    }
377
378    /// Returns the minimum serialized size for a CsrMatrix.
379    ///
380    /// A CsrMatrix serializes as:
381    /// - data_len (vint, minimum 1 byte)
382    /// - data elements (minimum 0 if empty)
383    /// - indptr_len (vint, minimum 1 byte)
384    /// - indptr elements (minimum 0 if empty)
385    ///
386    /// Total minimum: 2 bytes (two vint length prefixes for empty matrix)
387    fn min_serialized_size() -> usize {
388        2
389    }
390}
391
392// TESTS
393// ================================================================================================
394
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::newtype_id;

    newtype_id!(TestRowId);

    /// The matrix type exercised by every test in this module.
    type TestMatrix = CsrMatrix<TestRowId, u32>;

    /// Shorthand for building a row id from its raw index.
    fn id(raw: u32) -> TestRowId {
        TestRowId::from(raw)
    }

    /// A freshly created matrix has no rows and no elements.
    #[test]
    fn test_new_is_empty() {
        let csr = TestMatrix::new();

        assert!(csr.is_empty());
        assert_eq!(csr.num_rows(), 0);
        assert_eq!(csr.num_elements(), 0);
    }

    /// Pushing rows returns consecutive ids and updates the counters.
    #[test]
    fn test_push_row() {
        let mut csr = TestMatrix::new();

        let first = csr.push_row([1, 2, 3]).unwrap();
        assert_eq!(first, id(0));
        assert_eq!(csr.num_rows(), 1);
        assert_eq!(csr.num_elements(), 3);
        assert_eq!(csr.row(id(0)), Some(&[1, 2, 3][..]));

        let second = csr.push_row([4, 5]).unwrap();
        assert_eq!(second, id(1));
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_elements(), 5);
        assert_eq!(csr.row(id(1)), Some(&[4, 5][..]));
    }

    /// An empty row between two populated rows reads back as an empty slice.
    #[test]
    fn test_push_empty_row() {
        let mut csr = TestMatrix::new();

        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        assert_eq!(csr.num_rows(), 3);
        assert_eq!(csr.row(id(0)), Some(&[1, 2][..]));
        assert_eq!(csr.row(id(1)), Some(&[][..]));
        assert_eq!(csr.row(id(2)), Some(&[3][..]));
    }

    /// `fill_to_row` pads with empty rows so the next push lands on the target row.
    #[test]
    fn test_fill_to_row() {
        let mut csr = TestMatrix::new();

        csr.push_row([1]).unwrap();
        csr.fill_to_row(id(3)).unwrap();
        csr.push_row([2]).unwrap();

        assert_eq!(csr.num_rows(), 4);
        assert_eq!(csr.row(id(0)), Some(&[1][..]));
        assert_eq!(csr.row(id(1)), Some(&[][..]));
        assert_eq!(csr.row(id(2)), Some(&[][..]));
        assert_eq!(csr.row(id(3)), Some(&[2][..]));
    }

    /// Rows past the end are reported as `None`.
    #[test]
    fn test_row_out_of_bounds() {
        let mut csr = TestMatrix::new();
        csr.push_row([1]).unwrap();

        assert_eq!(csr.row(id(0)), Some(&[1][..]));
        assert_eq!(csr.row(id(1)), None);
        assert_eq!(csr.row(id(100)), None);
    }

    /// `iter` yields every row, including empty ones, in order.
    #[test]
    fn test_iter() {
        let mut csr = TestMatrix::new();
        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        let rows: Vec<_> = csr.iter().collect();

        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], (id(0), &[1, 2][..]));
        assert_eq!(rows[1], (id(1), &[][..]));
        assert_eq!(rows[2], (id(2), &[3][..]));
    }

    /// `iter_enumerated` flattens rows into `(row, position, value)` triples.
    #[test]
    fn test_iter_enumerated() {
        let mut csr = TestMatrix::new();
        csr.push_row([10, 20]).unwrap();
        csr.push_row([30]).unwrap();

        let entries: Vec<_> = csr.iter_enumerated().collect();

        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0], (id(0), 0, &10));
        assert_eq!(entries[1], (id(0), 1, &20));
        assert_eq!(entries[2], (id(1), 0, &30));
    }

    /// An empty matrix passes validation.
    #[test]
    fn test_validate_empty() {
        let csr = TestMatrix::new();

        assert!(csr.validate().is_ok());
    }

    /// A matrix built through the public API passes validation.
    #[test]
    fn test_validate_valid() {
        let mut csr = TestMatrix::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4]).unwrap();

        assert!(csr.validate().is_ok());
    }

    /// The validation callback is applied to the data and the first rejected value is reported.
    #[test]
    fn test_validate_with_callback() {
        let mut csr = TestMatrix::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_row([4, 5]).unwrap();

        // All values < 10: valid
        assert!(csr.validate_with(|&v| v < 10).is_ok());

        // All values < 4: invalid (first failure is 4 at row 1, position 0)
        let outcome = csr.validate_with(|&v| v < 4);
        assert!(matches!(outcome, Err(CsrValidationError::InvalidData { row: 1, position: 0 })));
    }

    /// Writing a matrix and reading it back yields an equal matrix.
    #[test]
    fn test_serialization_roundtrip() {
        let mut csr = TestMatrix::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4, 5]).unwrap();

        // Serialize
        let mut bytes = vec![];
        csr.write_into(&mut bytes);

        // Deserialize
        let mut reader = miden_crypto::utils::SliceReader::new(&bytes);
        let restored: TestMatrix = CsrMatrix::read_from(&mut reader).unwrap();

        assert_eq!(csr, restored);
    }
}