1use alloc::vec::Vec;
8
9use miden_crypto::utils::{
10 ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
11};
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15
16use crate::{Idx, IndexVec, IndexedVecError};
17
18#[derive(Debug, Clone, PartialEq, Eq)]
47#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
48pub struct CsrMatrix<I: Idx, D> {
49 data: Vec<D>,
51 indptr: IndexVec<I, usize>,
53}
54
55impl<I: Idx, D> Default for CsrMatrix<I, D> {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl<I: Idx, D> CsrMatrix<I, D> {
62 pub fn new() -> Self {
67 Self {
68 data: Vec::new(),
69 indptr: IndexVec::new(),
70 }
71 }
72
73 pub fn with_capacity(num_rows: usize, num_elements: usize) -> Self {
80 Self {
81 data: Vec::with_capacity(num_elements),
82 indptr: IndexVec::with_capacity(num_rows + 1),
83 }
84 }
85
86 pub fn push_row(&mut self, values: impl IntoIterator<Item = D>) -> Result<I, IndexedVecError> {
97 if self.indptr.is_empty() {
99 self.indptr.push(0)?;
100 }
101
102 let row_idx = self.num_rows();
104
105 self.data.extend(values);
107
108 self.indptr.push(self.data.len())?;
110
111 Ok(I::from(row_idx as u32))
112 }
113
114 pub fn push_empty_row(&mut self) -> Result<I, IndexedVecError> {
120 self.push_row(core::iter::empty())
121 }
122
123 pub fn fill_to_row(&mut self, target_row: I) -> Result<(), IndexedVecError> {
131 let target = target_row.to_usize();
132 while self.num_rows() < target {
133 self.push_empty_row()?;
134 }
135 Ok(())
136 }
137
138 pub fn is_empty(&self) -> bool {
143 self.indptr.is_empty()
144 }
145
146 pub fn num_rows(&self) -> usize {
148 if self.indptr.is_empty() {
149 0
150 } else {
151 self.indptr.len() - 1
152 }
153 }
154
155 pub fn num_elements(&self) -> usize {
157 self.data.len()
158 }
159
160 pub fn row(&self, row: I) -> Option<&[D]> {
162 let row_idx = row.to_usize();
163 if row_idx >= self.num_rows() {
164 return None;
165 }
166
167 let start = self.indptr[row];
168 let end = self.indptr[I::from((row_idx + 1) as u32)];
169 Some(&self.data[start..end])
170 }
171
172 pub fn row_expect(&self, row: I) -> &[D] {
178 self.row(row).expect("row index out of bounds")
179 }
180
181 pub fn iter(&self) -> impl Iterator<Item = (I, &[D])> {
183 (0..self.num_rows()).map(move |i| {
184 let row = I::from(i as u32);
185 (row, self.row_expect(row))
186 })
187 }
188
189 pub fn iter_enumerated(&self) -> impl Iterator<Item = (I, usize, &D)> {
191 self.iter()
192 .flat_map(|(row, data)| data.iter().enumerate().map(move |(pos, d)| (row, pos, d)))
193 }
194
195 pub fn data(&self) -> &[D] {
197 &self.data
198 }
199
200 pub fn indptr(&self) -> &IndexVec<I, usize> {
202 &self.indptr
203 }
204
205 pub fn validate(&self) -> Result<(), CsrValidationError> {
217 self.validate_with(|_| true)
218 }
219
220 pub fn validate_with<F>(&self, f: F) -> Result<(), CsrValidationError>
229 where
230 F: Fn(&D) -> bool,
231 {
232 let indptr = self.indptr.as_slice();
233
234 if indptr.is_empty() {
236 return Ok(());
237 }
238
239 if indptr[0] != 0 {
241 return Err(CsrValidationError::IndptrStartNotZero(indptr[0]));
242 }
243
244 for i in 1..indptr.len() {
246 if indptr[i - 1] > indptr[i] {
247 return Err(CsrValidationError::IndptrNotMonotonic {
248 index: i,
249 prev: indptr[i - 1],
250 curr: indptr[i],
251 });
252 }
253 }
254
255 let last = *indptr.last().expect("indptr is non-empty");
257 if last != self.data.len() {
258 return Err(CsrValidationError::IndptrDataMismatch {
259 indptr_end: last,
260 data_len: self.data.len(),
261 });
262 }
263
264 for (row, data) in self.iter() {
266 for (pos, d) in data.iter().enumerate() {
267 if !f(d) {
268 return Err(CsrValidationError::InvalidData {
269 row: row.to_usize(),
270 position: pos,
271 });
272 }
273 }
274 }
275
276 Ok(())
277 }
278}
279
280#[derive(Debug, Clone, PartialEq, Eq, Error)]
285pub enum CsrValidationError {
286 #[error("indptr must start at 0, got {0}")]
288 IndptrStartNotZero(usize),
289
290 #[error("indptr not monotonic at index {index}: {prev} > {curr}")]
292 IndptrNotMonotonic { index: usize, prev: usize, curr: usize },
293
294 #[error("indptr ends at {indptr_end}, but data.len() is {data_len}")]
296 IndptrDataMismatch { indptr_end: usize, data_len: usize },
297
298 #[error("invalid data value at row {row}, position {position}")]
300 InvalidData { row: usize, position: usize },
301}
302
303impl<I, D> Serializable for CsrMatrix<I, D>
307where
308 I: Idx,
309 D: Serializable,
310{
311 fn write_into<W: ByteWriter>(&self, target: &mut W) {
312 target.write_usize(self.data.len());
314 for item in &self.data {
315 item.write_into(target);
316 }
317
318 target.write_usize(self.indptr.len());
320 for &ptr in self.indptr.as_slice() {
321 target.write_usize(ptr);
322 }
323 }
324}
325
326impl<I, D> Deserializable for CsrMatrix<I, D>
327where
328 I: Idx,
329 D: Deserializable,
330{
331 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
332 let data_len = source.read_usize()?;
334 let data: Vec<D> = source.read_many_iter(data_len)?.collect::<Result<_, _>>()?;
335
336 let indptr_len = source.read_usize()?;
338 let indptr_vec: Vec<usize> =
339 source.read_many_iter(indptr_len)?.collect::<Result<_, _>>()?;
340 let indptr = IndexVec::try_from(indptr_vec).map_err(|_| {
341 DeserializationError::InvalidValue("indptr too large for IndexVec".into())
342 })?;
343
344 Ok(Self { data, indptr })
345 }
346
347 fn min_serialized_size() -> usize {
357 2
358 }
359}
360
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::newtype_id;

    newtype_id!(TestRowId);

    /// Builds a matrix directly from raw parts, bypassing `push_row`, so tests can exercise
    /// `validate` on deliberately malformed layouts.
    fn from_raw_parts(data: alloc::vec::Vec<u32>, offsets: &[usize]) -> CsrMatrix<TestRowId, u32> {
        let mut indptr: IndexVec<TestRowId, usize> = IndexVec::new();
        for &offset in offsets {
            indptr.push(offset).unwrap();
        }
        CsrMatrix { data, indptr }
    }

    #[test]
    fn test_new_is_empty() {
        let csr = CsrMatrix::<TestRowId, u32>::new();
        assert!(csr.is_empty());
        assert_eq!(csr.num_rows(), 0);
        assert_eq!(csr.num_elements(), 0);
    }

    #[test]
    fn test_default_matches_new() {
        assert_eq!(CsrMatrix::<TestRowId, u32>::default(), CsrMatrix::new());
    }

    #[test]
    fn test_with_capacity_is_empty() {
        let csr = CsrMatrix::<TestRowId, u32>::with_capacity(4, 16);
        assert!(csr.is_empty());
        assert_eq!(csr.num_rows(), 0);
        assert_eq!(csr.num_elements(), 0);
    }

    #[test]
    fn test_push_row() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();

        let id0 = csr.push_row([1, 2, 3]).unwrap();
        assert_eq!(id0, TestRowId::from(0));
        assert_eq!(csr.num_rows(), 1);
        assert_eq!(csr.num_elements(), 3);
        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1, 2, 3][..]));

        let id1 = csr.push_row([4, 5]).unwrap();
        assert_eq!(id1, TestRowId::from(1));
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_elements(), 5);
        assert_eq!(csr.row(TestRowId::from(1)), Some(&[4, 5][..]));
    }

    #[test]
    fn test_push_empty_row() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();

        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        assert_eq!(csr.num_rows(), 3);
        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1, 2][..]));
        assert_eq!(csr.row(TestRowId::from(1)), Some(&[][..]));
        assert_eq!(csr.row(TestRowId::from(2)), Some(&[3][..]));
    }

    #[test]
    fn test_fill_to_row() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();

        csr.push_row([1]).unwrap();
        csr.fill_to_row(TestRowId::from(3)).unwrap();
        csr.push_row([2]).unwrap();

        assert_eq!(csr.num_rows(), 4);
        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1][..]));
        assert_eq!(csr.row(TestRowId::from(1)), Some(&[][..]));
        assert_eq!(csr.row(TestRowId::from(2)), Some(&[][..]));
        assert_eq!(csr.row(TestRowId::from(3)), Some(&[2][..]));
    }

    #[test]
    fn test_row_out_of_bounds() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1]).unwrap();

        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1][..]));
        assert_eq!(csr.row(TestRowId::from(1)), None);
        assert_eq!(csr.row(TestRowId::from(100)), None);
    }

    #[test]
    #[should_panic(expected = "row index out of bounds")]
    fn test_row_expect_panics_out_of_bounds() {
        let csr = CsrMatrix::<TestRowId, u32>::new();
        let _ = csr.row_expect(TestRowId::from(0));
    }

    #[test]
    fn test_iter() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        let items: alloc::vec::Vec<_> = csr.iter().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0], (TestRowId::from(0), &[1, 2][..]));
        assert_eq!(items[1], (TestRowId::from(1), &[][..]));
        assert_eq!(items[2], (TestRowId::from(2), &[3][..]));
    }

    #[test]
    fn test_iter_enumerated() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([10, 20]).unwrap();
        csr.push_row([30]).unwrap();

        let items: alloc::vec::Vec<_> = csr.iter_enumerated().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0], (TestRowId::from(0), 0, &10));
        assert_eq!(items[1], (TestRowId::from(0), 1, &20));
        assert_eq!(items[2], (TestRowId::from(1), 0, &30));
    }

    #[test]
    fn test_validate_empty() {
        let csr = CsrMatrix::<TestRowId, u32>::new();
        assert!(csr.validate().is_ok());
    }

    #[test]
    fn test_validate_valid() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4]).unwrap();

        assert!(csr.validate().is_ok());
    }

    #[test]
    fn test_validate_rejects_nonzero_start() {
        let csr = from_raw_parts(vec![7], &[1, 1]);
        assert_eq!(csr.validate(), Err(CsrValidationError::IndptrStartNotZero(1)));
    }

    #[test]
    fn test_validate_rejects_non_monotonic_indptr() {
        let csr = from_raw_parts(vec![1, 2], &[0, 2, 1]);
        assert_eq!(
            csr.validate(),
            Err(CsrValidationError::IndptrNotMonotonic { index: 2, prev: 2, curr: 1 })
        );
    }

    #[test]
    fn test_validate_rejects_indptr_data_mismatch() {
        let csr = from_raw_parts(vec![1, 2, 3], &[0, 2]);
        assert_eq!(
            csr.validate(),
            Err(CsrValidationError::IndptrDataMismatch { indptr_end: 2, data_len: 3 })
        );
    }

    #[test]
    fn test_validate_with_callback() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_row([4, 5]).unwrap();

        assert!(csr.validate_with(|&v| v < 10).is_ok());

        let result = csr.validate_with(|&v| v < 4);
        assert!(matches!(result, Err(CsrValidationError::InvalidData { row: 1, position: 0 })));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4, 5]).unwrap();

        let mut bytes = vec![];
        csr.write_into(&mut bytes);

        let mut reader = miden_crypto::utils::SliceReader::new(&bytes);
        let restored: CsrMatrix<TestRowId, u32> = CsrMatrix::read_from(&mut reader).unwrap();

        assert_eq!(csr, restored);
    }

    #[test]
    fn test_serialization_roundtrip_empty() {
        let csr = CsrMatrix::<TestRowId, u32>::new();

        let mut bytes = vec![];
        csr.write_into(&mut bytes);

        let mut reader = miden_crypto::utils::SliceReader::new(&bytes);
        let restored: CsrMatrix<TestRowId, u32> = CsrMatrix::read_from(&mut reader).unwrap();

        assert_eq!(csr, restored);
    }
}