use alloc::{string::ToString, vec::Vec};

use miden_crypto::utils::{
    ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
};
#[cfg(feature = "arbitrary")]
use proptest::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::{Idx, IndexVec, IndexedVecError};
19
#[derive(Debug, Clone, PartialEq, Eq)]
/// A compressed-sparse-row style container that maps row indices of type `I` to
/// variable-length rows of `D` values.
///
/// All row contents are stored back to back in `data`. Row `r` occupies
/// `data[indptr[r]..indptr[r + 1]]`. An empty `indptr` represents a matrix with no rows.
/// Otherwise `indptr` holds one more entry than there are rows, and its first entry is `0`
/// (see `push_row` and `validate`).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    all(feature = "arbitrary", test),
    miden_test_serde_macros::serde_test(binary_serde(true), types(crate::SerdeTestId, u32))
)]
pub struct CsrMatrix<I: Idx, D> {
    /// Concatenated contents of every row, in row order.
    data: Vec<D>,
    /// Row offsets into `data`: row `r` spans `indptr[r]..indptr[r + 1]`.
    indptr: IndexVec<I, usize>,
}
60
#[cfg(feature = "arbitrary")]
impl<I, D> Arbitrary for CsrMatrix<I, D>
where
    I: Idx + 'static,
    D: Arbitrary + 'static,
    D::Strategy: 'static,
{
    type Parameters = D::Parameters;
    type Strategy = BoxedStrategy<Self>;

    /// Generates matrices with 0 to 15 rows.
    ///
    /// Each row holds 0 to 7 values drawn from `any_with::<D>(args)`. Rows are appended in
    /// generation order via `push_row`.
    fn arbitrary_with(args: Self::Parameters) -> Self::Strategy {
        use proptest::collection::vec as vec_of;

        let row_lists = vec_of(vec_of(any_with::<D>(args), 0..8), 0..16);

        row_lists
            .prop_map(|rows| {
                rows.into_iter().fold(Self::new(), |mut matrix, row| {
                    matrix.push_row(row).expect("generated row count fits in u32");
                    matrix
                })
            })
            .boxed()
    }
}
85
86impl<I: Idx, D> Default for CsrMatrix<I, D> {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl<I: Idx, D> CsrMatrix<I, D> {
93 pub fn new() -> Self {
98 Self {
99 data: Vec::new(),
100 indptr: IndexVec::new(),
101 }
102 }
103
104 pub fn with_capacity(num_rows: usize, num_elements: usize) -> Self {
111 Self {
112 data: Vec::with_capacity(num_elements),
113 indptr: IndexVec::with_capacity(num_rows + 1),
114 }
115 }
116
117 pub fn push_row(&mut self, values: impl IntoIterator<Item = D>) -> Result<I, IndexedVecError> {
128 if self.indptr.is_empty() {
130 self.indptr.push(0)?;
131 }
132
133 let row_idx = self.num_rows();
135
136 self.data.extend(values);
138
139 self.indptr.push(self.data.len())?;
141
142 Ok(I::from(row_idx as u32))
143 }
144
145 pub fn push_empty_row(&mut self) -> Result<I, IndexedVecError> {
151 self.push_row(core::iter::empty())
152 }
153
154 pub fn fill_to_row(&mut self, target_row: I) -> Result<(), IndexedVecError> {
162 let target = target_row.to_usize();
163 while self.num_rows() < target {
164 self.push_empty_row()?;
165 }
166 Ok(())
167 }
168
169 pub fn is_empty(&self) -> bool {
174 self.indptr.is_empty()
175 }
176
177 pub fn num_rows(&self) -> usize {
179 if self.indptr.is_empty() {
180 0
181 } else {
182 self.indptr.len() - 1
183 }
184 }
185
186 pub fn num_elements(&self) -> usize {
188 self.data.len()
189 }
190
191 pub fn row(&self, row: I) -> Option<&[D]> {
193 let row_idx = row.to_usize();
194 if row_idx >= self.num_rows() {
195 return None;
196 }
197
198 let start = self.indptr[row];
199 let end = self.indptr[I::from((row_idx + 1) as u32)];
200 Some(&self.data[start..end])
201 }
202
203 pub fn row_expect(&self, row: I) -> &[D] {
209 self.row(row).expect("row index out of bounds")
210 }
211
212 pub fn iter(&self) -> impl Iterator<Item = (I, &[D])> {
214 (0..self.num_rows()).map(move |i| {
215 let row = I::from(i as u32);
216 (row, self.row_expect(row))
217 })
218 }
219
220 pub fn iter_enumerated(&self) -> impl Iterator<Item = (I, usize, &D)> {
222 self.iter()
223 .flat_map(|(row, data)| data.iter().enumerate().map(move |(pos, d)| (row, pos, d)))
224 }
225
226 pub fn data(&self) -> &[D] {
228 &self.data
229 }
230
231 pub fn indptr(&self) -> &IndexVec<I, usize> {
233 &self.indptr
234 }
235
236 pub fn validate(&self) -> Result<(), CsrValidationError> {
248 self.validate_with(|_| true)
249 }
250
251 pub fn validate_with<F>(&self, f: F) -> Result<(), CsrValidationError>
260 where
261 F: Fn(&D) -> bool,
262 {
263 let indptr = self.indptr.as_slice();
264
265 if indptr.is_empty() {
267 return Ok(());
268 }
269
270 if indptr[0] != 0 {
272 return Err(CsrValidationError::IndptrStartNotZero(indptr[0]));
273 }
274
275 for i in 1..indptr.len() {
277 if indptr[i - 1] > indptr[i] {
278 return Err(CsrValidationError::IndptrNotMonotonic {
279 index: i,
280 prev: indptr[i - 1],
281 curr: indptr[i],
282 });
283 }
284 }
285
286 let last = *indptr.last().expect("indptr is non-empty");
288 if last != self.data.len() {
289 return Err(CsrValidationError::IndptrDataMismatch {
290 indptr_end: last,
291 data_len: self.data.len(),
292 });
293 }
294
295 for (row, data) in self.iter() {
297 for (pos, d) in data.iter().enumerate() {
298 if !f(d) {
299 return Err(CsrValidationError::InvalidData {
300 row: row.to_usize(),
301 position: pos,
302 });
303 }
304 }
305 }
306
307 Ok(())
308 }
309}
310
#[derive(Debug, Clone, PartialEq, Eq, Error)]
/// A structural or value-level problem reported by [`CsrMatrix::validate`] and
/// [`CsrMatrix::validate_with`].
pub enum CsrValidationError {
    /// The first row offset is not `0`. Carries the offending value.
    #[error("indptr must start at 0, got {0}")]
    IndptrStartNotZero(usize),

    /// The row offsets decrease: `indptr[index - 1]` (`prev`) is greater than
    /// `indptr[index]` (`curr`).
    #[error("indptr not monotonic at index {index}: {prev} > {curr}")]
    IndptrNotMonotonic { index: usize, prev: usize, curr: usize },

    /// The offsets end at `indptr_end`, which differs from the number of stored values
    /// (`data_len`).
    #[error("indptr ends at {indptr_end}, but data.len() is {data_len}")]
    IndptrDataMismatch { indptr_end: usize, data_len: usize },

    /// The validation callback rejected the value at `position` within `row`.
    #[error("invalid data value at row {row}, position {position}")]
    InvalidData { row: usize, position: usize },
}
333
334impl<I, D> Serializable for CsrMatrix<I, D>
338where
339 I: Idx,
340 D: Serializable,
341{
342 fn write_into<W: ByteWriter>(&self, target: &mut W) {
343 target.write_usize(self.data.len());
345 for item in &self.data {
346 item.write_into(target);
347 }
348
349 target.write_usize(self.indptr.len());
351 for &ptr in self.indptr.as_slice() {
352 target.write_usize(ptr);
353 }
354 }
355}
356
357impl<I, D> Deserializable for CsrMatrix<I, D>
358where
359 I: Idx,
360 D: Deserializable,
361{
362 fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
363 let data_len = source.read_usize()?;
365 let data: Vec<D> = source.read_many_iter(data_len)?.collect::<Result<_, _>>()?;
366
367 let indptr_len = source.read_usize()?;
369 let indptr_vec: Vec<usize> =
370 source.read_many_iter(indptr_len)?.collect::<Result<_, _>>()?;
371 let indptr = IndexVec::try_from(indptr_vec).map_err(|_| {
372 DeserializationError::InvalidValue("indptr too large for IndexVec".into())
373 })?;
374
375 Ok(Self { data, indptr })
376 }
377
378 fn min_serialized_size() -> usize {
388 2
389 }
390}
391
#[cfg(test)]
mod tests {
    use alloc::vec;

    use super::*;
    use crate::newtype_id;

    newtype_id!(TestRowId);

    /// Builds a matrix straight from its raw parts, bypassing `push_row`, so that malformed
    /// layouts can be passed to `validate`.
    fn from_raw_parts(data: Vec<u32>, indptr: Vec<usize>) -> CsrMatrix<TestRowId, u32> {
        let Ok(indptr) = IndexVec::<TestRowId, usize>::try_from(indptr) else {
            panic!("test indptr must fit in IndexVec");
        };
        CsrMatrix { data, indptr }
    }

    #[test]
    fn test_new_is_empty() {
        let csr = CsrMatrix::<TestRowId, u32>::new();
        assert!(csr.is_empty());
        assert_eq!(csr.num_rows(), 0);
        assert_eq!(csr.num_elements(), 0);
    }

    #[test]
    fn test_push_row() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();

        let id0 = csr.push_row([1, 2, 3]).unwrap();
        assert_eq!(id0, TestRowId::from(0));
        assert_eq!(csr.num_rows(), 1);
        assert_eq!(csr.num_elements(), 3);
        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1, 2, 3][..]));

        let id1 = csr.push_row([4, 5]).unwrap();
        assert_eq!(id1, TestRowId::from(1));
        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_elements(), 5);
        assert_eq!(csr.row(TestRowId::from(1)), Some(&[4, 5][..]));
    }

    #[test]
    fn test_push_empty_row() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();

        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        assert_eq!(csr.num_rows(), 3);
        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1, 2][..]));
        assert_eq!(csr.row(TestRowId::from(1)), Some(&[][..]));
        assert_eq!(csr.row(TestRowId::from(2)), Some(&[3][..]));
    }

    #[test]
    fn test_fill_to_row() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();

        csr.push_row([1]).unwrap();
        csr.fill_to_row(TestRowId::from(3)).unwrap();
        csr.push_row([2]).unwrap();

        assert_eq!(csr.num_rows(), 4);
        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1][..]));
        assert_eq!(csr.row(TestRowId::from(1)), Some(&[][..]));
        assert_eq!(csr.row(TestRowId::from(2)), Some(&[][..]));
        assert_eq!(csr.row(TestRowId::from(3)), Some(&[2][..]));
    }

    #[test]
    fn test_fill_to_row_is_noop_when_target_reached() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1]).unwrap();
        csr.push_row([2]).unwrap();

        csr.fill_to_row(TestRowId::from(1)).unwrap();

        assert_eq!(csr.num_rows(), 2);
        assert_eq!(csr.num_elements(), 2);
    }

    #[test]
    fn test_row_out_of_bounds() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1]).unwrap();

        assert_eq!(csr.row(TestRowId::from(0)), Some(&[1][..]));
        assert_eq!(csr.row(TestRowId::from(1)), None);
        assert_eq!(csr.row(TestRowId::from(100)), None);
    }

    #[test]
    fn test_iter() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([3]).unwrap();

        let items: Vec<_> = csr.iter().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0], (TestRowId::from(0), &[1, 2][..]));
        assert_eq!(items[1], (TestRowId::from(1), &[][..]));
        assert_eq!(items[2], (TestRowId::from(2), &[3][..]));
    }

    #[test]
    fn test_iter_enumerated() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([10, 20]).unwrap();
        csr.push_row([30]).unwrap();

        let items: Vec<_> = csr.iter_enumerated().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items[0], (TestRowId::from(0), 0, &10));
        assert_eq!(items[1], (TestRowId::from(0), 1, &20));
        assert_eq!(items[2], (TestRowId::from(1), 0, &30));
    }

    #[test]
    fn test_validate_empty() {
        let csr = CsrMatrix::<TestRowId, u32>::new();
        assert!(csr.validate().is_ok());
    }

    #[test]
    fn test_validate_valid() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4]).unwrap();

        assert!(csr.validate().is_ok());
    }

    #[test]
    fn test_validate_rejects_nonzero_start() {
        let csr = from_raw_parts(vec![1, 2], vec![1, 2]);
        assert_eq!(csr.validate(), Err(CsrValidationError::IndptrStartNotZero(1)));
    }

    #[test]
    fn test_validate_rejects_non_monotonic_indptr() {
        let csr = from_raw_parts(vec![1, 2], vec![0, 2, 1]);
        assert_eq!(
            csr.validate(),
            Err(CsrValidationError::IndptrNotMonotonic { index: 2, prev: 2, curr: 1 })
        );
    }

    #[test]
    fn test_validate_rejects_indptr_data_mismatch() {
        let csr = from_raw_parts(vec![1, 2], vec![0, 1]);
        assert_eq!(
            csr.validate(),
            Err(CsrValidationError::IndptrDataMismatch { indptr_end: 1, data_len: 2 })
        );
    }

    #[test]
    fn test_validate_with_callback() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_row([4, 5]).unwrap();

        assert!(csr.validate_with(|&v| v < 10).is_ok());

        let result = csr.validate_with(|&v| v < 4);
        assert!(matches!(result, Err(CsrValidationError::InvalidData { row: 1, position: 0 })));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let mut csr = CsrMatrix::<TestRowId, u32>::new();
        csr.push_row([1, 2, 3]).unwrap();
        csr.push_empty_row().unwrap();
        csr.push_row([4, 5]).unwrap();

        let mut bytes = vec![];
        csr.write_into(&mut bytes);

        let mut reader = miden_crypto::utils::SliceReader::new(&bytes);
        let restored: CsrMatrix<TestRowId, u32> = CsrMatrix::read_from(&mut reader).unwrap();

        assert_eq!(csr, restored);
    }

    #[test]
    fn test_serialization_roundtrip_empty() {
        let csr = CsrMatrix::<TestRowId, u32>::new();

        let mut bytes = vec![];
        csr.write_into(&mut bytes);

        let mut reader = miden_crypto::utils::SliceReader::new(&bytes);
        let restored: CsrMatrix<TestRowId, u32> = CsrMatrix::read_from(&mut reader).unwrap();

        assert_eq!(csr, restored);
        assert!(restored.is_empty());
    }
}