Skip to main content

miden_utils_indexing/
csr.rs

1//! Compressed Sparse Row (CSR) matrix for efficient sparse data storage.
2//!
3//! This module provides a generic [`CsrMatrix`] type that maps row indices to variable-length
4//! data. It's commonly used for storing decorator IDs, assembly operation IDs, and similar
5//! sparse mappings in the Miden VM.
6
7use alloc::vec::Vec;
8
9use miden_crypto::utils::{
10    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
11};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15
16use crate::{Idx, IndexVec, IndexedVecError};
17
18// CSR MATRIX
19// ================================================================================================
20
21/// Compressed Sparse Row matrix mapping row indices to variable-length data.
22///
23/// For row `i`, its data is at `data[indptr[i]..indptr[i+1]]`.
24///
25/// # Type Parameters
26///
27/// - `I`: The row index type, must implement [`Idx`].
28/// - `D`: The data type stored in each row.
29///
30/// # Example
31///
32/// ```ignore
33/// use miden_utils_indexing::{CsrMatrix, newtype_id};
34///
35/// newtype_id!(NodeId);
36///
37/// let mut csr = CsrMatrix::<NodeId, u32>::new();
38/// csr.push_row([1, 2, 3]);       // Row 0: [1, 2, 3]
39/// csr.push_empty_row();          // Row 1: []
40/// csr.push_row([4, 5]);          // Row 2: [4, 5]
41///
42/// assert_eq!(csr.row(NodeId::from(0)), Some(&[1, 2, 3][..]));
43/// assert_eq!(csr.row(NodeId::from(1)), Some(&[][..]));
44/// assert_eq!(csr.row(NodeId::from(2)), Some(&[4, 5][..]));
45/// ```
46#[derive(Debug, Clone, PartialEq, Eq)]
47#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
48pub struct CsrMatrix<I: Idx, D> {
49    /// Flat storage of all data values.
50    data: Vec<D>,
51    /// Row pointers: row i's data is at `data[indptr[i]..indptr[i+1]]`.
52    indptr: IndexVec<I, usize>,
53}
54
55impl<I: Idx, D> Default for CsrMatrix<I, D> {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl<I: Idx, D> CsrMatrix<I, D> {
62    // CONSTRUCTORS
63    // --------------------------------------------------------------------------------------------
64
65    /// Creates a new empty [`CsrMatrix`].
66    pub fn new() -> Self {
67        Self {
68            data: Vec::new(),
69            indptr: IndexVec::new(),
70        }
71    }
72
73    /// Creates a [`CsrMatrix`] with pre-allocated capacity.
74    ///
75    /// # Arguments
76    ///
77    /// - `num_rows`: Expected number of rows.
78    /// - `num_elements`: Expected total number of data elements across all rows.
79    pub fn with_capacity(num_rows: usize, num_elements: usize) -> Self {
80        Self {
81            data: Vec::with_capacity(num_elements),
82            indptr: IndexVec::with_capacity(num_rows + 1),
83        }
84    }
85
86    // MUTATION
87    // --------------------------------------------------------------------------------------------
88
89    /// Appends a new row with the given data values and returns the index of the newly added row.
90    ///
91    /// Rows must be added in sequential order starting from row 0.
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if the number of rows would exceed `u32::MAX`.
96    pub fn push_row(&mut self, values: impl IntoIterator<Item = D>) -> Result<I, IndexedVecError> {
97        // Initialize indptr with 0 if this is the first row
98        if self.indptr.is_empty() {
99            self.indptr.push(0)?;
100        }
101
102        // The row index is the current number of rows (before adding)
103        let row_idx = self.num_rows();
104
105        // Add data
106        self.data.extend(values);
107
108        // Add end pointer for this row
109        self.indptr.push(self.data.len())?;
110
111        Ok(I::from(row_idx as u32))
112    }
113
114    /// Appends an empty row (no data for this row index).
115    ///
116    /// # Errors
117    ///
118    /// Returns an error if the number of rows would exceed `u32::MAX`.
119    pub fn push_empty_row(&mut self) -> Result<I, IndexedVecError> {
120        self.push_row(core::iter::empty())
121    }
122
123    /// Appends empty rows to fill gaps up to (but not including) `target_row`.
124    ///
125    /// If `target_row` is less than or equal to the current number of rows, this is a no-op.
126    ///
127    /// # Errors
128    ///
129    /// Returns an error if the number of rows would exceed `u32::MAX`.
130    pub fn fill_to_row(&mut self, target_row: I) -> Result<(), IndexedVecError> {
131        let target = target_row.to_usize();
132        while self.num_rows() < target {
133            self.push_empty_row()?;
134        }
135        Ok(())
136    }
137
138    // ACCESSORS
139    // --------------------------------------------------------------------------------------------
140
141    /// Returns `true` if this matrix has no rows.
142    pub fn is_empty(&self) -> bool {
143        self.indptr.is_empty()
144    }
145
146    /// Returns the number of rows in this matrix.
147    pub fn num_rows(&self) -> usize {
148        if self.indptr.is_empty() {
149            0
150        } else {
151            self.indptr.len() - 1
152        }
153    }
154
155    /// Returns the total number of data elements across all rows.
156    pub fn num_elements(&self) -> usize {
157        self.data.len()
158    }
159
160    /// Returns the data slice for the given row, or `None` if the row doesn't exist.
161    pub fn row(&self, row: I) -> Option<&[D]> {
162        let row_idx = row.to_usize();
163        if row_idx >= self.num_rows() {
164            return None;
165        }
166
167        let start = self.indptr[row];
168        let end = self.indptr[I::from((row_idx + 1) as u32)];
169        Some(&self.data[start..end])
170    }
171
172    /// Returns the data slice for the given row, panicking if the row doesn't exist.
173    ///
174    /// # Panics
175    ///
176    /// Panics if `row` is out of bounds.
177    pub fn row_expect(&self, row: I) -> &[D] {
178        self.row(row).expect("row index out of bounds")
179    }
180
181    /// Returns an iterator over all `(row_index, data_slice)` pairs.
182    pub fn iter(&self) -> impl Iterator<Item = (I, &[D])> {
183        (0..self.num_rows()).map(move |i| {
184            let row = I::from(i as u32);
185            (row, self.row_expect(row))
186        })
187    }
188
189    /// Returns an iterator over all data elements with their `(row_index, position_in_row, &data)`.
190    pub fn iter_enumerated(&self) -> impl Iterator<Item = (I, usize, &D)> {
191        self.iter()
192            .flat_map(|(row, data)| data.iter().enumerate().map(move |(pos, d)| (row, pos, d)))
193    }
194
195    /// Returns the underlying data slice.
196    pub fn data(&self) -> &[D] {
197        &self.data
198    }
199
200    /// Returns the underlying indptr.
201    pub fn indptr(&self) -> &IndexVec<I, usize> {
202        &self.indptr
203    }
204
205    // VALIDATION
206    // --------------------------------------------------------------------------------------------
207
208    /// Validates the CSR structural invariants.
209    ///
210    /// Checks:
211    /// - `indptr` starts at 0 (if non-empty)
212    /// - `indptr` is monotonically increasing
213    /// - `indptr` ends at `data.len()`
214    ///
215    /// For domain-specific validation of data values, use [`validate_with`](Self::validate_with).
216    pub fn validate(&self) -> Result<(), CsrValidationError> {
217        self.validate_with(|_| true)
218    }
219
220    /// Validates structural invariants plus domain-specific data constraints.
221    ///
222    /// The callback is invoked for each data element. Return `false` to indicate
223    /// an invalid value.
224    ///
225    /// # Arguments
226    ///
227    /// - `f`: A function that returns `true` if the data value is valid.
228    pub fn validate_with<F>(&self, f: F) -> Result<(), CsrValidationError>
229    where
230        F: Fn(&D) -> bool,
231    {
232        let indptr = self.indptr.as_slice();
233
234        // Empty matrix is valid
235        if indptr.is_empty() {
236            return Ok(());
237        }
238
239        // Check indptr starts at 0
240        if indptr[0] != 0 {
241            return Err(CsrValidationError::IndptrStartNotZero(indptr[0]));
242        }
243
244        // Check indptr is monotonic
245        for i in 1..indptr.len() {
246            if indptr[i - 1] > indptr[i] {
247                return Err(CsrValidationError::IndptrNotMonotonic {
248                    index: i,
249                    prev: indptr[i - 1],
250                    curr: indptr[i],
251                });
252            }
253        }
254
255        // Check indptr ends at data.len()
256        let last = *indptr.last().expect("indptr is non-empty");
257        if last != self.data.len() {
258            return Err(CsrValidationError::IndptrDataMismatch {
259                indptr_end: last,
260                data_len: self.data.len(),
261            });
262        }
263
264        // Validate data values
265        for (row, data) in self.iter() {
266            for (pos, d) in data.iter().enumerate() {
267                if !f(d) {
268                    return Err(CsrValidationError::InvalidData {
269                        row: row.to_usize(),
270                        position: pos,
271                    });
272                }
273            }
274        }
275
276        Ok(())
277    }
278}
279
280// CSR VALIDATION ERROR
281// ================================================================================================
282
283/// Errors that can occur during CSR validation.
284#[derive(Debug, Clone, PartialEq, Eq, Error)]
285pub enum CsrValidationError {
286    /// The indptr array must start at 0.
287    #[error("indptr must start at 0, got {0}")]
288    IndptrStartNotZero(usize),
289
290    /// The indptr array must be monotonically increasing.
291    #[error("indptr not monotonic at index {index}: {prev} > {curr}")]
292    IndptrNotMonotonic { index: usize, prev: usize, curr: usize },
293
294    /// The last indptr value must equal data.len().
295    #[error("indptr ends at {indptr_end}, but data.len() is {data_len}")]
296    IndptrDataMismatch { indptr_end: usize, data_len: usize },
297
298    /// A data value failed domain-specific validation.
299    #[error("invalid data value at row {row}, position {position}")]
300    InvalidData { row: usize, position: usize },
301}
302
303// SERIALIZATION
304// ================================================================================================
305
306impl<I, D> Serializable for CsrMatrix<I, D>
307where
308    I: Idx,
309    D: Serializable,
310{
311    fn write_into<W: ByteWriter>(&self, target: &mut W) {
312        // Write data
313        target.write_usize(self.data.len());
314        for item in &self.data {
315            item.write_into(target);
316        }
317
318        // Write indptr
319        target.write_usize(self.indptr.len());
320        for &ptr in self.indptr.as_slice() {
321            target.write_usize(ptr);
322        }
323    }
324}
325
326impl<I, D> Deserializable for CsrMatrix<I, D>
327where
328    I: Idx,
329    D: Deserializable,
330{
331    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
332        // Read data using read_many_iter for BudgetedReader integration
333        let data_len = source.read_usize()?;
334        let data: Vec<D> = source.read_many_iter(data_len)?.collect::<Result<_, _>>()?;
335
336        // Read indptr using read_many_iter for BudgetedReader integration
337        let indptr_len = source.read_usize()?;
338        let indptr_vec: Vec<usize> =
339            source.read_many_iter(indptr_len)?.collect::<Result<_, _>>()?;
340        let indptr = IndexVec::try_from(indptr_vec).map_err(|_| {
341            DeserializationError::InvalidValue("indptr too large for IndexVec".into())
342        })?;
343
344        Ok(Self { data, indptr })
345    }
346
347    /// Returns the minimum serialized size for a CsrMatrix.
348    ///
349    /// A CsrMatrix serializes as:
350    /// - data_len (vint, minimum 1 byte)
351    /// - data elements (minimum 0 if empty)
352    /// - indptr_len (vint, minimum 1 byte)
353    /// - indptr elements (minimum 0 if empty)
354    ///
355    /// Total minimum: 2 bytes (two vint length prefixes for empty matrix)
356    fn min_serialized_size() -> usize {
357        2
358    }
359}
360
361// TESTS
362// ================================================================================================
363
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::newtype_id;

    newtype_id!(TestRowId);

    /// Matrix type exercised by every test in this module.
    type TestMatrix = CsrMatrix<TestRowId, u32>;

    /// Expected contents of a row that holds no data.
    const EMPTY_ROW: &[u32] = &[];

    /// Shorthand for building a row id from a raw index.
    fn id(raw: u32) -> TestRowId {
        TestRowId::from(raw)
    }

    #[test]
    fn test_new_is_empty() {
        let matrix = TestMatrix::new();

        assert!(matrix.is_empty());
        assert_eq!(matrix.num_rows(), 0);
        assert_eq!(matrix.num_elements(), 0);
    }

    #[test]
    fn test_push_row() {
        let mut matrix = TestMatrix::new();

        // First row.
        let first = matrix.push_row([1, 2, 3]).unwrap();
        assert_eq!(first, id(0));
        assert_eq!(matrix.num_rows(), 1);
        assert_eq!(matrix.num_elements(), 3);
        assert_eq!(matrix.row(id(0)), Some(&[1, 2, 3][..]));

        // Second row is appended after the first one.
        let second = matrix.push_row([4, 5]).unwrap();
        assert_eq!(second, id(1));
        assert_eq!(matrix.num_rows(), 2);
        assert_eq!(matrix.num_elements(), 5);
        assert_eq!(matrix.row(id(1)), Some(&[4, 5][..]));
    }

    #[test]
    fn test_push_empty_row() {
        let mut matrix = TestMatrix::new();
        matrix.push_row([1, 2]).unwrap();
        matrix.push_empty_row().unwrap();
        matrix.push_row([3]).unwrap();

        assert_eq!(matrix.num_rows(), 3);
        assert_eq!(matrix.row(id(0)), Some(&[1, 2][..]));
        assert_eq!(matrix.row(id(1)), Some(EMPTY_ROW));
        assert_eq!(matrix.row(id(2)), Some(&[3][..]));
    }

    #[test]
    fn test_fill_to_row() {
        let mut matrix = TestMatrix::new();
        matrix.push_row([1]).unwrap();
        // Rows 1 and 2 are created empty so that the next push lands on row 3.
        matrix.fill_to_row(id(3)).unwrap();
        matrix.push_row([2]).unwrap();

        assert_eq!(matrix.num_rows(), 4);
        assert_eq!(matrix.row(id(0)), Some(&[1][..]));
        assert_eq!(matrix.row(id(1)), Some(EMPTY_ROW));
        assert_eq!(matrix.row(id(2)), Some(EMPTY_ROW));
        assert_eq!(matrix.row(id(3)), Some(&[2][..]));
    }

    #[test]
    fn test_row_out_of_bounds() {
        let mut matrix = TestMatrix::new();
        matrix.push_row([1]).unwrap();

        assert_eq!(matrix.row(id(0)), Some(&[1][..]));
        assert_eq!(matrix.row(id(1)), None);
        assert_eq!(matrix.row(id(100)), None);
    }

    #[test]
    fn test_iter() {
        let mut matrix = TestMatrix::new();
        matrix.push_row([1, 2]).unwrap();
        matrix.push_empty_row().unwrap();
        matrix.push_row([3]).unwrap();

        let rows: Vec<_> = matrix.iter().collect();

        assert_eq!(rows.len(), 3);
        assert_eq!(rows[0], (id(0), &[1, 2][..]));
        assert_eq!(rows[1], (id(1), EMPTY_ROW));
        assert_eq!(rows[2], (id(2), &[3][..]));
    }

    #[test]
    fn test_iter_enumerated() {
        let mut matrix = TestMatrix::new();
        matrix.push_row([10, 20]).unwrap();
        matrix.push_row([30]).unwrap();

        let entries: Vec<_> = matrix.iter_enumerated().collect();

        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0], (id(0), 0, &10));
        assert_eq!(entries[1], (id(0), 1, &20));
        assert_eq!(entries[2], (id(1), 0, &30));
    }

    #[test]
    fn test_validate_empty() {
        let matrix = TestMatrix::new();
        assert!(matrix.validate().is_ok());
    }

    #[test]
    fn test_validate_valid() {
        let mut matrix = TestMatrix::new();
        matrix.push_row([1, 2, 3]).unwrap();
        matrix.push_empty_row().unwrap();
        matrix.push_row([4]).unwrap();

        assert!(matrix.validate().is_ok());
    }

    #[test]
    fn test_validate_with_callback() {
        let mut matrix = TestMatrix::new();
        matrix.push_row([1, 2, 3]).unwrap();
        matrix.push_row([4, 5]).unwrap();

        // Every value is below 10, so this constraint holds.
        assert!(matrix.validate_with(|&value| value < 10).is_ok());

        // Requiring values below 4 fails first on the 4 stored at row 1, position 0.
        let outcome = matrix.validate_with(|&value| value < 4);
        assert!(matches!(outcome, Err(CsrValidationError::InvalidData { row: 1, position: 0 })));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut original = TestMatrix::new();
        original.push_row([1, 2, 3]).unwrap();
        original.push_empty_row().unwrap();
        original.push_row([4, 5]).unwrap();

        // Serialize into an in-memory buffer.
        let mut buffer = vec![];
        original.write_into(&mut buffer);

        // Deserialize from that buffer and compare with the original.
        let mut reader = miden_crypto::utils::SliceReader::new(&buffer);
        let decoded = TestMatrix::read_from(&mut reader).unwrap();

        assert_eq!(original, decoded);
    }
}
511}