use alloc::vec::Vec;
use miden_crypto::utils::{
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::{Idx, IndexVec, IndexedVecError};
/// Row-oriented sparse storage (CSR layout): the values of all rows are kept back to back in a
/// single vector, and a second vector of offsets marks where each row starts and ends.
///
/// Row `r` occupies `data[indptr[r]..indptr[r + 1]]`, so a matrix with `n` rows carries `n + 1`
/// offsets. A freshly created matrix has no offsets at all; the leading `0` is inserted when the
/// first row is pushed.
///
/// `I` is the typed row index and `D` is the element type stored in the rows.
///
/// NOTE: the optional serde implementations are derived, so they rebuild the fields as-is; use
/// `validate` / `validate_with` on matrices that come from untrusted input.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct CsrMatrix<I: Idx, D> {
/// Values of every row, concatenated in row order.
data: Vec<D>,
/// Offsets into `data`: entry `r` is where row `r` starts, entry `r + 1` is where it ends.
indptr: IndexVec<I, usize>,
}
impl<I: Idx, D> Default for CsrMatrix<I, D> {
fn default() -> Self {
Self::new()
}
}
impl<I: Idx, D> CsrMatrix<I, D> {
    /// Creates a matrix with no rows and no stored values.
    pub fn new() -> Self {
        Self {
            data: Vec::new(),
            indptr: IndexVec::new(),
        }
    }

    /// Creates an empty matrix with pre-allocated storage.
    ///
    /// Reserves room for `num_elements` values and for `num_rows + 1` offsets (one offset per
    /// row boundary, including the leading `0`).
    pub fn with_capacity(num_rows: usize, num_elements: usize) -> Self {
        Self {
            data: Vec::with_capacity(num_elements),
            indptr: IndexVec::with_capacity(num_rows + 1),
        }
    }

    /// Appends a new row holding `values` and returns the index assigned to it.
    ///
    /// # Errors
    ///
    /// Propagates the [`IndexedVecError`] reported by the offset vector when it cannot accept
    /// another entry. In that case the values appended for this row are removed again, so the
    /// stored data never extends past the last recorded offset.
    pub fn push_row(&mut self, values: impl IntoIterator<Item = D>) -> Result<I, IndexedVecError> {
        // The offset table always starts with 0; it is created lazily on the first row.
        if self.indptr.is_empty() {
            self.indptr.push(0)?;
        }

        let row_idx = self.num_rows();
        let row_start = self.data.len();
        self.data.extend(values);

        if let Err(err) = self.indptr.push(self.data.len()) {
            // Roll back: without a closing offset these values would silently leak into
            // whichever row gets pushed next.
            self.data.truncate(row_start);
            return Err(err);
        }

        // NOTE(review): the cast relies on the offset vector refusing to grow beyond the u32
        // index space (its push just succeeded) - confirm against `IndexVec::push`.
        Ok(I::from(row_idx as u32))
    }

    /// Appends a row without any values and returns its index.
    ///
    /// # Errors
    ///
    /// Same as [`CsrMatrix::push_row`].
    pub fn push_empty_row(&mut self) -> Result<I, IndexedVecError> {
        self.push_row(core::iter::empty())
    }

    /// Appends empty rows until the matrix holds `target_row.to_usize()` rows, so that the next
    /// pushed row receives the index `target_row`.
    ///
    /// Does nothing when the matrix already has at least that many rows.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error reported while pushing an empty row.
    pub fn fill_to_row(&mut self, target_row: I) -> Result<(), IndexedVecError> {
        let target = target_row.to_usize();
        while self.num_rows() < target {
            self.push_empty_row()?;
        }
        Ok(())
    }

    /// Returns `true` when the matrix has no rows.
    ///
    /// This always agrees with `num_rows() == 0`, including for an offset table that only
    /// contains the leading `0` (which can be produced by deserialization).
    pub fn is_empty(&self) -> bool {
        self.num_rows() == 0
    }

    /// Returns the number of rows, i.e. the number of offsets minus the leading one.
    pub fn num_rows(&self) -> usize {
        self.indptr.len().saturating_sub(1)
    }

    /// Returns the total number of values stored across all rows.
    pub fn num_elements(&self) -> usize {
        self.data.len()
    }

    /// Returns the values of `row`, or `None` when `row` is not a row of this matrix.
    ///
    /// All lookups are checked, so a matrix whose offsets are inconsistent with its data (for
    /// example one deserialized from malformed input and never validated) yields `None` for the
    /// affected rows instead of panicking.
    pub fn row(&self, row: I) -> Option<&[D]> {
        let offsets = self.indptr.as_slice();
        let row_idx = row.to_usize();
        let start = *offsets.get(row_idx)?;
        // The end offset lives one slot further; it is absent exactly when `row` is past the
        // last row.
        let end = *offsets.get(row_idx.checked_add(1)?)?;
        self.data.get(start..end)
    }

    /// Returns the values of `row`.
    ///
    /// # Panics
    ///
    /// Panics when [`CsrMatrix::row`] returns `None` for `row`.
    pub fn row_expect(&self, row: I) -> &[D] {
        self.row(row).expect("row index out of bounds")
    }

    /// Iterates over all rows in index order, yielding each row index with its values.
    ///
    /// # Panics
    ///
    /// The iterator panics if a row cannot be resolved, which can only happen for a matrix
    /// whose offsets are inconsistent with its data (see [`CsrMatrix::validate`]).
    pub fn iter(&self) -> impl Iterator<Item = (I, &[D])> {
        (0..self.num_rows()).map(move |i| {
            let row = I::from(i as u32);
            (row, self.row_expect(row))
        })
    }

    /// Iterates over every stored value as `(row, position within the row, value)`, in row
    /// order and then in position order.
    pub fn iter_enumerated(&self) -> impl Iterator<Item = (I, usize, &D)> {
        self.iter()
            .flat_map(|(row, data)| data.iter().enumerate().map(move |(pos, d)| (row, pos, d)))
    }

    /// Returns all stored values as one contiguous slice, rows concatenated in order.
    pub fn data(&self) -> &[D] {
        &self.data
    }

    /// Returns the row offset table.
    pub fn indptr(&self) -> &IndexVec<I, usize> {
        &self.indptr
    }

    /// Checks the structural invariants of the matrix without inspecting the values.
    ///
    /// # Errors
    ///
    /// See [`CsrMatrix::validate_with`]; the `InvalidData` variant is never produced here.
    pub fn validate(&self) -> Result<(), CsrValidationError> {
        self.validate_with(|_| true)
    }

    /// Checks the structural invariants of the matrix and then applies `f` to every value.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// - `IndptrDataMismatch` if there are no offsets but values are stored;
    /// - `IndptrStartNotZero` if the first offset is not `0`;
    /// - `IndptrNotMonotonic` if an offset is smaller than its predecessor;
    /// - `IndptrDataMismatch` if the last offset differs from the number of stored values;
    /// - `InvalidData` for the first value (in row, then position order) rejected by `f`.
    pub fn validate_with<F>(&self, f: F) -> Result<(), CsrValidationError>
    where
        F: Fn(&D) -> bool,
    {
        let indptr = self.indptr.as_slice();

        let (first, last) = match (indptr.first(), indptr.last()) {
            (Some(&first), Some(&last)) => (first, last),
            _ => {
                // No offsets means no rows, so any stored value would belong to no row.
                return if self.data.is_empty() {
                    Ok(())
                } else {
                    Err(CsrValidationError::IndptrDataMismatch {
                        indptr_end: 0,
                        data_len: self.data.len(),
                    })
                };
            },
        };

        if first != 0 {
            return Err(CsrValidationError::IndptrStartNotZero(first));
        }

        for (i, pair) in indptr.windows(2).enumerate() {
            let (prev, curr) = (pair[0], pair[1]);
            if prev > curr {
                // Report the position of the offending (second) offset of the pair.
                return Err(CsrValidationError::IndptrNotMonotonic { index: i + 1, prev, curr });
            }
        }

        if last != self.data.len() {
            return Err(CsrValidationError::IndptrDataMismatch {
                indptr_end: last,
                data_len: self.data.len(),
            });
        }

        // The structure is sound from here on, so walking the rows cannot fail.
        for (row, position, value) in self.iter_enumerated() {
            if !f(value) {
                return Err(CsrValidationError::InvalidData { row: row.to_usize(), position });
            }
        }

        Ok(())
    }
}
/// Reasons a [`CsrMatrix`] can fail `validate` / `validate_with`.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum CsrValidationError {
/// The first offset is not `0`; carries the value that was found instead.
#[error("indptr must start at 0, got {0}")]
IndptrStartNotZero(usize),
/// The offset at `index` (`curr`) is smaller than the one before it (`prev`).
#[error("indptr not monotonic at index {index}: {prev} > {curr}")]
IndptrNotMonotonic { index: usize, prev: usize, curr: usize },
/// The last offset (`indptr_end`) does not equal the number of stored values (`data_len`).
#[error("indptr ends at {indptr_end}, but data.len() is {data_len}")]
IndptrDataMismatch { indptr_end: usize, data_len: usize },
/// The caller-supplied predicate rejected the value at `position` within `row`.
#[error("invalid data value at row {row}, position {position}")]
InvalidData { row: usize, position: usize },
}
impl<I, D> Serializable for CsrMatrix<I, D>
where
    I: Idx,
    D: Serializable,
{
    /// Writes the matrix as: value count, every value in storage order, offset count, every
    /// offset in order.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self { data, indptr } = self;

        // Values section: length prefix followed by the elements themselves.
        target.write_usize(data.len());
        data.iter().for_each(|value| value.write_into(target));

        // Offsets section: length prefix followed by the raw offsets.
        target.write_usize(indptr.len());
        indptr.as_slice().iter().for_each(|&offset| target.write_usize(offset));
    }
}
impl<I, D> Deserializable for CsrMatrix<I, D>
where
    I: Idx,
    D: Deserializable,
{
    /// Reads a matrix in the layout produced by `write_into`: value count, values, offset
    /// count, offsets.
    ///
    /// # Errors
    ///
    /// Returns a [`DeserializationError`] when the reader fails, when the offsets cannot be
    /// stored in an `IndexVec`, or when the decoded offsets and values do not form a
    /// structurally valid matrix. Rejecting such input here prevents later panics when rows
    /// are accessed.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        let data_len = source.read_usize()?;
        let data: Vec<D> = source.read_many_iter(data_len)?.collect::<Result<_, _>>()?;

        let indptr_len = source.read_usize()?;
        let indptr_vec: Vec<usize> =
            source.read_many_iter(indptr_len)?.collect::<Result<_, _>>()?;
        let indptr = IndexVec::try_from(indptr_vec).map_err(|_| {
            DeserializationError::InvalidValue("indptr too large for IndexVec".into())
        })?;

        let matrix = Self { data, indptr };

        // Values without any offsets belong to no row; treat that as corrupt input.
        if matrix.indptr.is_empty() && !matrix.data.is_empty() {
            return Err(DeserializationError::InvalidValue(
                "CSR matrix has data but no indptr".into(),
            ));
        }

        // The input is untrusted: make sure the offsets start at 0, never decrease, and end
        // exactly at the number of values before handing the matrix out.
        matrix.validate().map_err(|err| {
            DeserializationError::InvalidValue(alloc::format!("invalid CSR matrix: {}", err))
        })?;

        Ok(matrix)
    }

    /// Lower bound on the encoded size: the two length prefixes written by `write_into`.
    // NOTE(review): assumes each `write_usize` length prefix takes at least one byte - confirm
    // against the `ByteWriter` encoding.
    fn min_serialized_size() -> usize {
        2
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::newtype_id;

    newtype_id!(TestRowId);

    /// Matrix type exercised by every test in this module.
    type Matrix = CsrMatrix<TestRowId, u32>;

    /// A row with no values, typed so it can be compared against row slices.
    const EMPTY: &[u32] = &[];

    /// Shorthand for building a row identifier.
    fn rid(n: u32) -> TestRowId {
        TestRowId::from(n)
    }

    /// Asserts that row `row` exists and holds exactly `expected`.
    fn assert_row(csr: &Matrix, row: u32, expected: &[u32]) {
        assert_eq!(csr.row(rid(row)), Some(expected));
    }

    /// A new matrix reports itself as empty in every respect.
    #[test]
    fn test_new_is_empty() {
        let csr = Matrix::new();

        assert!(csr.is_empty());
        assert_eq!(csr.num_rows(), 0);
        assert_eq!(csr.num_elements(), 0);
    }

    /// Pushing rows hands out consecutive ids and updates the counters.
    #[test]
    fn test_push_row() {
        let mut csr = Matrix::new();

        let first = csr.push_row([1, 2, 3]).unwrap();
        assert_eq!(first, rid(0));
        assert_eq!(csr.num_rows(), 1);
        assert_eq!(csr.num_elements(), 3);
        assert_row(&csr, 0, &[1, 2, 3]);

        let second = csr.push_row([4, 5]).unwrap();
        assert_eq!(second, rid(1));
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_elements(), 5);
        assert_row(&csr, 1, &[4, 5]);
    }

    /// An empty row sits between its neighbours without disturbing them.
    #[test]
    fn test_push_empty_row() {
        let mut csr = Matrix::new();
        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        assert_eq!(csr.num_rows(), 3);
        assert_row(&csr, 0, &[1, 2]);
        assert_row(&csr, 1, EMPTY);
        assert_row(&csr, 2, &[3]);
    }

    /// Filling to row 3 pads with empty rows so the next push lands on row 3.
    #[test]
    fn test_fill_to_row() {
        let mut csr = Matrix::new();
        csr.push_row([1]).unwrap();
        csr.fill_to_row(rid(3)).unwrap();
        csr.push_row([2]).unwrap();

        assert_eq!(csr.num_rows(), 4);
        assert_row(&csr, 0, &[1]);
        assert_row(&csr, 1, EMPTY);
        assert_row(&csr, 2, EMPTY);
        assert_row(&csr, 3, &[2]);
    }

    /// Looking up a row past the end yields `None`.
    #[test]
    fn test_row_out_of_bounds() {
        let mut csr = Matrix::new();
        csr.push_row([1]).unwrap();

        assert_row(&csr, 0, &[1]);
        for missing in [1, 100] {
            assert_eq!(csr.row(rid(missing)), None);
        }
    }

    /// `iter` visits every row, including empty ones, in index order.
    #[test]
    fn test_iter() {
        let mut csr = Matrix::new();
        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        let mut rows = csr.iter();
        assert_eq!(rows.next(), Some((rid(0), &[1, 2][..])));
        assert_eq!(rows.next(), Some((rid(1), EMPTY)));
        assert_eq!(rows.next(), Some((rid(2), &[3][..])));
        assert!(rows.next().is_none());
    }

    /// `iter_enumerated` flattens rows into (row, position, value) triples.
    #[test]
    fn test_iter_enumerated() {
        let mut csr = Matrix::new();
        csr.push_row([10, 20]).unwrap();
        csr.push_row([30]).unwrap();

        let mut entries = csr.iter_enumerated();
        assert_eq!(entries.next(), Some((rid(0), 0, &10)));
        assert_eq!(entries.next(), Some((rid(0), 1, &20)));
        assert_eq!(entries.next(), Some((rid(1), 0, &30)));
        assert!(entries.next().is_none());
    }

    /// An empty matrix is structurally valid.
    #[test]
    fn test_validate_empty() {
        assert_eq!(Matrix::new().validate(), Ok(()));
    }

    /// A matrix built purely through the push API is structurally valid.
    #[test]
    fn test_validate_valid() {
        let mut csr = Matrix::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4]).unwrap();

        assert_eq!(csr.validate(), Ok(()));
    }

    /// The value predicate is applied to every element and reports the first offender.
    #[test]
    fn test_validate_with_callback() {
        let mut csr = Matrix::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_row([4, 5]).unwrap();

        assert_eq!(csr.validate_with(|&v| v < 10), Ok(()));
        assert_eq!(
            csr.validate_with(|&v| v < 4),
            Err(CsrValidationError::InvalidData { row: 1, position: 0 })
        );
    }

    /// Writing a matrix and reading it back yields an equal matrix.
    #[test]
    fn test_serialization_roundtrip() {
        let mut csr = Matrix::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4, 5]).unwrap();

        let mut bytes = vec![];
        csr.write_into(&mut bytes);

        let mut reader = miden_crypto::utils::SliceReader::new(&bytes);
        let restored: Matrix = CsrMatrix::read_from(&mut reader).unwrap();
        assert_eq!(csr, restored);
    }
}