1use {
2 crate::{FieldElement, InternedFieldElement, Interner},
3 ark_std::Zero,
4 rayon::{
5 iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
6 slice::ParallelSliceMut,
7 },
8 serde::{
9 de::{SeqAccess, Visitor},
10 ser::SerializeStruct,
11 Deserialize, Deserializer, Serialize, Serializer,
12 },
13 std::{
14 fmt::{self, Debug},
15 ops::{Mul, Range},
16 },
17};
18
/// Size comparison between storing column indices as absolute values and as
/// per-row deltas, measured in varint bytes.
#[derive(Debug, Clone, Copy)]
pub struct DeltaEncodingStats {
    /// Number of column indices that were measured.
    pub total_entries: usize,
    /// Total varint bytes when every column index is stored as-is.
    pub absolute_bytes: usize,
    /// Total varint bytes when column indices are stored as deltas.
    pub delta_bytes: usize,
}

impl DeltaEncodingStats {
    /// Bytes saved by the delta form; zero when the delta form is not smaller.
    pub const fn savings_bytes(&self) -> usize {
        if self.delta_bytes >= self.absolute_bytes {
            0
        } else {
            self.absolute_bytes - self.delta_bytes
        }
    }

    /// Savings expressed as a percentage of `absolute_bytes`.
    ///
    /// Returns `0.0` when `absolute_bytes` is zero so the division is never
    /// performed on an empty measurement.
    pub fn savings_percent(&self) -> f64 {
        match self.absolute_bytes {
            0 => 0.0,
            baseline => {
                let saved = self.savings_bytes() as f64;
                saved / baseline as f64 * 100.0
            }
        }
    }
}
39
/// Number of bytes needed to store `value` as a varint that carries seven
/// payload bits per byte (between 1 and 5 bytes for a `u32`).
const fn varint_size(value: u32) -> usize {
    let significant_bits = (u32::BITS - value.leading_zeros()) as usize;
    if significant_bits == 0 {
        // Zero still occupies one byte.
        1
    } else {
        significant_bits.div_ceil(7)
    }
}
49
50#[derive(Debug, Clone, PartialEq, Eq)]
56pub struct SparseMatrix {
57 pub num_rows: usize,
59
60 pub num_cols: usize,
62
63 new_row_indices: Vec<u32>,
65
66 col_indices: Vec<u32>,
68
69 values: Vec<InternedFieldElement>,
71}
72
/// Converts absolute column indices into per-row deltas.
///
/// Within every row the first column index is emitted unchanged and each
/// following one is replaced by its distance from the previous index. Rows are
/// delimited by `new_row_indices`; the last row ends at `total_entries`.
///
/// Column indices are expected to be sorted within each row (checked in debug
/// builds).
fn encode_col_deltas(
    col_indices: &[u32],
    new_row_indices: &[u32],
    total_entries: usize,
) -> Vec<u32> {
    let mut encoded = Vec::with_capacity(col_indices.len());

    for (row, &offset) in new_row_indices.iter().enumerate() {
        let begin = offset as usize;
        let finish = match new_row_indices.get(row + 1) {
            Some(&next) => next as usize,
            None => total_entries,
        };

        let columns = &col_indices[begin..finish];
        let Some(&leading) = columns.first() else {
            // Empty rows contribute nothing to the output.
            continue;
        };

        debug_assert!(
            columns.windows(2).all(|pair| pair[0] <= pair[1]),
            "Column indices must be sorted within each row"
        );

        encoded.push(leading);
        encoded.extend(columns.windows(2).map(|pair| pair[1] - pair[0]));
    }

    encoded
}
109
/// Rebuilds absolute column indices from per-row deltas.
///
/// Rows are delimited by `new_row_indices`; the last row ends at
/// `total_entries`. Deltas are consumed in order, one per entry, and the
/// running sum restarts at the beginning of every row.
///
/// # Panics
///
/// Panics if a row offset is smaller than the one before it or if `deltas`
/// holds fewer values than the rows require.
fn decode_col_deltas(deltas: &[u32], new_row_indices: &[u32], total_entries: usize) -> Vec<u32> {
    let mut decoded = Vec::with_capacity(deltas.len());
    // Position of the next unread delta.
    let mut cursor = 0;

    for (row, &offset) in new_row_indices.iter().enumerate() {
        let begin = offset as usize;
        let finish = match new_row_indices.get(row + 1) {
            Some(&next) => next as usize,
            None => total_entries,
        };

        let width = finish - begin;
        if width == 0 {
            continue;
        }

        let row_deltas = &deltas[cursor..cursor + width];
        cursor += width;

        // Starting from zero makes the first delta act as the absolute index.
        let mut running = 0u32;
        decoded.extend(row_deltas.iter().map(|&step| {
            running += step;
            running
        }));
    }

    decoded
}
144
145impl Serialize for SparseMatrix {
146 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
147 where
148 S: Serializer,
149 {
150 let col_deltas =
151 encode_col_deltas(&self.col_indices, &self.new_row_indices, self.values.len());
152
153 let mut state = serializer.serialize_struct("SparseMatrix", 5)?;
154 state.serialize_field("num_rows", &self.num_rows)?;
155 state.serialize_field("num_cols", &self.num_cols)?;
156 state.serialize_field("new_row_indices", &self.new_row_indices)?;
157 state.serialize_field("col_deltas", &col_deltas)?;
158 state.serialize_field("values", &self.values)?;
159 state.end()
160 }
161}
162
163impl<'de> Deserialize<'de> for SparseMatrix {
164 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
165 where
166 D: Deserializer<'de>,
167 {
168 #[derive(Deserialize)]
169 #[serde(field_identifier, rename_all = "snake_case")]
170 enum Field {
171 NumRows,
172 NumCols,
173 NewRowIndices,
174 ColDeltas,
175 Values,
176 }
177
178 struct SparseMatrixVisitor;
179
180 impl<'de> Visitor<'de> for SparseMatrixVisitor {
181 type Value = SparseMatrix;
182
183 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
184 formatter.write_str("struct SparseMatrix")
185 }
186
187 fn visit_seq<V>(self, mut seq: V) -> Result<SparseMatrix, V::Error>
188 where
189 V: SeqAccess<'de>,
190 {
191 let num_rows = seq
192 .next_element()?
193 .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
194 let num_cols = seq
195 .next_element()?
196 .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
197 let new_row_indices: Vec<u32> = seq
198 .next_element()?
199 .ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
200 let col_deltas: Vec<u32> = seq
201 .next_element()?
202 .ok_or_else(|| serde::de::Error::invalid_length(3, &self))?;
203 let values: Vec<InternedFieldElement> = seq
204 .next_element()?
205 .ok_or_else(|| serde::de::Error::invalid_length(4, &self))?;
206
207 let col_indices = decode_col_deltas(&col_deltas, &new_row_indices, values.len());
208
209 Ok(SparseMatrix {
210 num_rows,
211 num_cols,
212 new_row_indices,
213 col_indices,
214 values,
215 })
216 }
217
218 fn visit_map<V>(self, mut map: V) -> Result<SparseMatrix, V::Error>
219 where
220 V: serde::de::MapAccess<'de>,
221 {
222 let mut num_rows = None;
223 let mut num_cols = None;
224 let mut new_row_indices: Option<Vec<u32>> = None;
225 let mut col_deltas: Option<Vec<u32>> = None;
226 let mut values: Option<Vec<InternedFieldElement>> = None;
227
228 while let Some(key) = map.next_key()? {
229 match key {
230 Field::NumRows => {
231 if num_rows.is_some() {
232 return Err(serde::de::Error::duplicate_field("num_rows"));
233 }
234 num_rows = Some(map.next_value()?);
235 }
236 Field::NumCols => {
237 if num_cols.is_some() {
238 return Err(serde::de::Error::duplicate_field("num_cols"));
239 }
240 num_cols = Some(map.next_value()?);
241 }
242 Field::NewRowIndices => {
243 if new_row_indices.is_some() {
244 return Err(serde::de::Error::duplicate_field("new_row_indices"));
245 }
246 new_row_indices = Some(map.next_value()?);
247 }
248 Field::ColDeltas => {
249 if col_deltas.is_some() {
250 return Err(serde::de::Error::duplicate_field("col_deltas"));
251 }
252 col_deltas = Some(map.next_value()?);
253 }
254 Field::Values => {
255 if values.is_some() {
256 return Err(serde::de::Error::duplicate_field("values"));
257 }
258 values = Some(map.next_value()?);
259 }
260 }
261 }
262
263 let num_rows =
264 num_rows.ok_or_else(|| serde::de::Error::missing_field("num_rows"))?;
265 let num_cols =
266 num_cols.ok_or_else(|| serde::de::Error::missing_field("num_cols"))?;
267 let new_row_indices = new_row_indices
268 .ok_or_else(|| serde::de::Error::missing_field("new_row_indices"))?;
269 let col_deltas =
270 col_deltas.ok_or_else(|| serde::de::Error::missing_field("col_deltas"))?;
271 let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
272
273 let col_indices = decode_col_deltas(&col_deltas, &new_row_indices, values.len());
274
275 Ok(SparseMatrix {
276 num_rows,
277 num_cols,
278 new_row_indices,
279 col_indices,
280 values,
281 })
282 }
283 }
284
285 const FIELDS: &[&str] = &[
286 "num_rows",
287 "num_cols",
288 "new_row_indices",
289 "col_deltas",
290 "values",
291 ];
292 deserializer.deserialize_struct("SparseMatrix", FIELDS, SparseMatrixVisitor)
293 }
294}
295
296#[derive(Debug, Clone, Copy, PartialEq, Eq)]
298pub struct HydratedSparseMatrix<'a> {
299 pub matrix: &'a SparseMatrix,
300 interner: &'a Interner,
301}
302
303impl SparseMatrix {
304 pub fn new(rows: usize, cols: usize) -> Self {
305 Self {
306 num_rows: rows,
307 num_cols: cols,
308 new_row_indices: vec![0; rows],
309 col_indices: Vec::new(),
310 values: Vec::new(),
311 }
312 }
313
314 pub const fn hydrate<'a>(&'a self, interner: &'a Interner) -> HydratedSparseMatrix<'a> {
315 HydratedSparseMatrix {
316 matrix: self,
317 interner,
318 }
319 }
320
321 pub fn num_entries(&self) -> usize {
322 self.values.len()
323 }
324
325 pub fn delta_encoding_stats(&self) -> DeltaEncodingStats {
326 let deltas = encode_col_deltas(&self.col_indices, &self.new_row_indices, self.values.len());
327
328 let absolute_bytes: usize = self.col_indices.iter().map(|&v| varint_size(v)).sum();
329 let delta_bytes: usize = deltas.iter().map(|&v| varint_size(v)).sum();
330
331 DeltaEncodingStats {
332 total_entries: self.col_indices.len(),
333 absolute_bytes,
334 delta_bytes,
335 }
336 }
337
338 pub fn grow(&mut self, rows: usize, cols: usize) {
339 assert!(rows >= self.num_rows);
341 assert!(cols >= self.num_cols);
342 self.num_rows = rows;
343 self.num_cols = cols;
344 self.new_row_indices.resize(rows, self.values.len() as u32);
345 }
346
347 pub fn set(&mut self, row: usize, col: usize, value: InternedFieldElement) {
349 assert!(row < self.num_rows, "row index out of bounds");
350 assert!(col < self.num_cols, "column index out of bounds");
351
352 let row_range = self.row_range(row);
354 let cols = &self.col_indices[row_range.clone()];
355
356 match cols.binary_search(&(col as u32)) {
358 Ok(_) => {
359 unreachable!("Duplicate column {col} in row {row}");
360 }
361 Err(i) => {
362 let i = i + row_range.start;
364 self.col_indices.insert(i, col as u32);
365 self.values.insert(i, value);
366 for index in &mut self.new_row_indices[row + 1..] {
367 *index += 1;
368 }
369 }
370 }
371 }
372
373 pub fn iter_row(
375 &self,
376 row: usize,
377 ) -> impl Iterator<Item = (usize, InternedFieldElement)> + use<'_> {
378 let row_range = self.row_range(row);
379 let cols = self.col_indices[row_range.clone()].iter().copied();
380 let values = self.values[row_range].iter().copied();
381 cols.zip(values).map(|(col, value)| (col as usize, value))
382 }
383
384 pub fn iter(&self) -> impl Iterator<Item = ((usize, usize), InternedFieldElement)> + use<'_> {
386 (0..self.new_row_indices.len()).flat_map(|row| {
387 self.iter_row(row)
388 .map(move |(col, value)| ((row, col), value))
389 })
390 }
391
392 fn row_range(&self, row: usize) -> Range<usize> {
393 let start = *self
394 .new_row_indices
395 .get(row)
396 .expect("Row index out of bounds") as usize;
397 let end = self
398 .new_row_indices
399 .get(row + 1)
400 .map_or(self.values.len(), |&v| v as usize);
401 start..end
402 }
403
404 pub fn transpose(&self) -> SparseMatrix {
410 let nnz = self.values.len();
411
412 let mut entries: Vec<(u32, u32, InternedFieldElement)> = Vec::with_capacity(nnz);
413 for row in 0..self.num_rows {
414 let range = self.row_range(row);
415 for i in range {
416 entries.push((self.col_indices[i], row as u32, self.values[i]));
417 }
418 }
419
420 entries.par_sort_unstable_by_key(|&(new_row, new_col, _)| (new_row, new_col));
421 debug_assert!(
422 entries
423 .windows(2)
424 .all(|w| (w[0].0, w[0].1) != (w[1].0, w[1].1)),
425 "Duplicate (row, col) entries in sparse matrix transpose"
426 );
427
428 let mut new_row_indices = Vec::with_capacity(self.num_cols);
429 let mut col_indices = Vec::with_capacity(nnz);
430 let mut values = Vec::with_capacity(nnz);
431
432 let mut entry_idx = 0;
433 for row in 0..self.num_cols {
434 new_row_indices.push(entry_idx as u32);
435 while entry_idx < entries.len() && entries[entry_idx].0 == row as u32 {
436 col_indices.push(entries[entry_idx].1);
437 values.push(entries[entry_idx].2);
438 entry_idx += 1;
439 }
440 }
441
442 SparseMatrix {
443 num_rows: self.num_cols,
444 num_cols: self.num_rows,
445 new_row_indices,
446 col_indices,
447 values,
448 }
449 }
450
451 pub fn remap_columns<F>(&mut self, remap_fn: F)
454 where
455 F: Fn(usize) -> usize + Send + Sync,
456 {
457 self.col_indices.par_iter_mut().for_each(|col| {
459 *col = remap_fn(*col as usize) as u32;
460 });
461
462 for row in 0..self.num_rows {
464 let start = self.new_row_indices[row] as usize;
465 let end = self
466 .new_row_indices
467 .get(row + 1)
468 .map_or(self.col_indices.len(), |&v| v as usize);
469
470 let row_cols = &mut self.col_indices[start..end];
471 let row_vals = &mut self.values[start..end];
472
473 let mut pairs: Vec<_> = row_cols
474 .iter()
475 .zip(row_vals.iter())
476 .map(|(&c, &v)| (c, v))
477 .collect();
478 pairs.sort_unstable_by_key(|(c, _)| *c);
479
480 for (i, (c, v)) in pairs.into_iter().enumerate() {
481 row_cols[i] = c;
482 row_vals[i] = v;
483 }
484 }
485 }
486}
487
488impl HydratedSparseMatrix<'_> {
489 pub fn iter_row(&self, row: usize) -> impl Iterator<Item = (usize, FieldElement)> + use<'_> {
491 self.matrix.iter_row(row).map(|(col, value)| {
492 (
493 col,
494 self.interner.get(value).expect("Value not in interner."),
495 )
496 })
497 }
498
499 pub fn iter(&self) -> impl Iterator<Item = ((usize, usize), FieldElement)> + use<'_> {
501 self.matrix.iter().map(|((i, j), v)| {
502 (
503 (i, j),
504 self.interner.get(v).expect("Value not in interner."),
505 )
506 })
507 }
508}
509
510impl Mul<&[FieldElement]> for HydratedSparseMatrix<'_> {
512 type Output = Vec<FieldElement>;
513
514 fn mul(self, rhs: &[FieldElement]) -> Self::Output {
515 assert_eq!(
516 self.matrix.num_cols,
517 rhs.len(),
518 "Vector length does not match number of columns."
519 );
520 (0..self.matrix.num_rows)
521 .into_par_iter()
522 .map(|row| {
523 self.iter_row(row)
524 .map(|(col, value)| value * rhs[col])
525 .fold(FieldElement::zero(), |acc, x| acc + x)
526 })
527 .collect()
528 }
529}
530
531impl Mul<HydratedSparseMatrix<'_>> for &[FieldElement] {
536 type Output = Vec<FieldElement>;
537
538 fn mul(self, rhs: HydratedSparseMatrix<'_>) -> Self::Output {
539 assert_eq!(
540 self.len(),
541 rhs.matrix.num_rows,
542 "Vector length does not match number of rows."
543 );
544 let mut result = vec![FieldElement::zero(); rhs.matrix.num_cols];
545 for ((i, j), value) in rhs.iter() {
546 result[j] += value * self[i];
547 }
548 result
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `cols`, decodes the result again, checks that the original
    /// indices come back, and returns the intermediate deltas.
    fn roundtrip(cols: &[u32], row_starts: &[u32]) -> Vec<u32> {
        let deltas = encode_col_deltas(cols, row_starts, cols.len());
        let decoded = decode_col_deltas(&deltas, row_starts, cols.len());
        assert_eq!(cols, decoded.as_slice());
        deltas
    }

    #[test]
    fn test_delta_encoding_roundtrip() {
        roundtrip(&[3, 15, 100, 5, 50, 200], &[0, 3]);
    }

    #[test]
    fn test_delta_encoding_values() {
        let deltas = encode_col_deltas(&[3, 15, 100], &[0], 3);
        assert_eq!(deltas, vec![3, 12, 85]);
    }

    #[test]
    fn test_delta_encoding_multiple_rows() {
        // The second row restarts from its own first column.
        let deltas = roundtrip(&[0, 10, 20, 5, 15], &[0, 3]);
        assert_eq!(deltas, vec![0, 10, 10, 5, 10]);
    }

    #[test]
    fn test_delta_encoding_empty_row() {
        roundtrip(&[5, 10], &[0, 0, 2]);
    }

    #[test]
    fn test_delta_encoding_single_entry() {
        let deltas = roundtrip(&[42], &[0]);
        assert_eq!(deltas, vec![42]);
    }

    #[test]
    fn test_delta_encoding_single_column_per_row() {
        // With one entry per row every delta is an absolute index.
        let deltas = roundtrip(&[0, 5, 100], &[0, 1, 2]);
        assert_eq!(deltas, vec![0, 5, 100]);
    }

    #[test]
    fn test_delta_encoding_consecutive_columns() {
        let deltas = roundtrip(&[10, 11, 12, 13], &[0]);
        assert_eq!(deltas, vec![10, 1, 1, 1]);
    }

    #[test]
    fn test_delta_encoding_all_rows_empty() {
        let deltas = roundtrip(&[], &[0, 0, 0]);
        assert!(deltas.is_empty());
    }

    #[test]
    fn test_delta_encoding_last_row_empty() {
        roundtrip(&[1, 2, 7], &[0, 2, 3]);
    }

    #[test]
    fn test_delta_encoding_only_last_row_non_empty() {
        let deltas = roundtrip(&[3, 8], &[0, 0, 0, 2]);
        assert_eq!(deltas, vec![3, 5]);
    }

    #[test]
    fn test_delta_encoding_large_column_indices() {
        let deltas = roundtrip(&[1_000_000, 1_000_001, 2_000_000], &[0]);
        assert_eq!(deltas, vec![1_000_000, 1, 999_999]);
    }

    #[test]
    fn test_sparse_matrix_serde_roundtrip() {
        let mut interner = Interner::new();
        let one = interner.intern(FieldElement::from(1u64));
        let two = interner.intern(FieldElement::from(2u64));
        let three = interner.intern(FieldElement::from(3u64));

        let mut original = SparseMatrix::new(3, 100);
        original.grow(3, 100);
        original.set(0, 5, one);
        original.set(0, 20, two);
        original.set(1, 50, three);

        let bytes = postcard::to_allocvec(&original).expect("serialization failed");
        let restored: SparseMatrix =
            postcard::from_bytes(&bytes).expect("deserialization failed");

        assert_eq!(original, restored);
    }

    #[test]
    fn test_delta_encoding_size_reduction() {
        let mut interner = Interner::new();
        let val = interner.intern(FieldElement::from(1u64));

        let mut matrix = SparseMatrix::new(10, 1000);
        matrix.grow(10, 1000);

        // Ten rows, each holding a run of twenty adjacent columns.
        for row in 0..10 {
            for col_offset in 0..20 {
                matrix.set(row, row * 50 + col_offset, val);
            }
        }

        let actual_bytes = postcard::to_allocvec(&matrix)
            .expect("serialization failed")
            .len();
        // Four bytes per column index is the fixed-width baseline.
        let naive_col_bytes = matrix.col_indices.len() * 4;

        assert!(
            actual_bytes < naive_col_bytes,
            "delta encoding should reduce size: actual {} vs naive col bytes {}",
            actual_bytes,
            naive_col_bytes
        );
    }
}