use {
crate::{FieldElement, InternedFieldElement, Interner},
ark_std::Zero,
rayon::{
iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator},
slice::ParallelSliceMut,
},
serde::{
de::{SeqAccess, Visitor},
ser::SerializeStruct,
Deserialize, Deserializer, Serialize, Serializer,
},
std::{
fmt::{self, Debug},
ops::{Mul, Range},
},
};
/// Size comparison between absolute and per-row delta encoding of a
/// `SparseMatrix`'s column indices, as produced by
/// `SparseMatrix::delta_encoding_stats`.
#[derive(Debug, Clone, Copy)]
pub struct DeltaEncodingStats {
/// Number of stored entries (column indices) that were measured.
pub total_entries: usize,
/// Estimated varint bytes for the column indices written as absolute values.
pub absolute_bytes: usize,
/// Estimated varint bytes for the column indices written as per-row deltas.
pub delta_bytes: usize,
}
impl DeltaEncodingStats {
    /// Bytes saved by the delta form relative to the absolute form.
    ///
    /// Clamped at zero when the delta form is not smaller.
    pub const fn savings_bytes(&self) -> usize {
        if self.delta_bytes >= self.absolute_bytes {
            0
        } else {
            self.absolute_bytes - self.delta_bytes
        }
    }

    /// Savings as a percentage of the absolute size.
    ///
    /// Returns `0.0` when there is nothing to measure (`absolute_bytes == 0`).
    pub fn savings_percent(&self) -> f64 {
        match self.absolute_bytes {
            0 => 0.0,
            total => 100.0 * (self.savings_bytes() as f64 / total as f64),
        }
    }
}
/// Number of bytes `value` occupies as a varint with 7 payload bits per byte
/// (1 to 5 bytes for a `u32`).
const fn varint_size(value: u32) -> usize {
    let significant_bits = (u32::BITS - value.leading_zeros()) as usize;
    // Zero has no significant bits but still needs one byte.
    if significant_bits == 0 {
        1
    } else {
        (significant_bits + 6) / 7
    }
}
/// Sparse matrix in a compressed-sparse-row style layout with interned values.
///
/// The entries of row `r` occupy positions
/// `new_row_indices[r]..new_row_indices[r + 1]` of the side-by-side `col_indices`
/// and `values` arrays; the last row runs to `values.len()` (there is no
/// trailing end offset). `set` keeps column indices sorted within each row.
/// NOTE(review): `num_rows` is public, but the methods assume
/// `new_row_indices.len() == num_rows`; keep them in sync.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparseMatrix {
/// Number of rows.
pub num_rows: usize,
/// Number of columns.
pub num_cols: usize,
/// Start offset of each row's entries; one element per row.
new_row_indices: Vec<u32>,
/// Column index of each stored entry.
col_indices: Vec<u32>,
/// Interned value of each stored entry, side by side with `col_indices`.
values: Vec<InternedFieldElement>,
}
/// Delta-encodes column indices row by row.
///
/// Row `r` spans `new_row_indices[r]..new_row_indices[r + 1]` of `col_indices`;
/// the last row ends at `total_entries`. For each non-empty row, the first
/// column is emitted as-is and every later column as the difference from its
/// predecessor. Empty rows emit nothing, so the output has one element per
/// entry that belongs to a row.
///
/// Columns must be sorted within each row; this is checked only in debug builds.
fn encode_col_deltas(
    col_indices: &[u32],
    new_row_indices: &[u32],
    total_entries: usize,
) -> Vec<u32> {
    let mut encoded = Vec::with_capacity(col_indices.len());
    for (row, &row_start) in new_row_indices.iter().enumerate() {
        let row_end = new_row_indices
            .get(row + 1)
            .map_or(total_entries, |&next_start| next_start as usize);
        let cols = &col_indices[row_start as usize..row_end];
        let Some(&first) = cols.first() else {
            continue;
        };
        debug_assert!(
            cols.windows(2).all(|pair| pair[0] <= pair[1]),
            "Column indices must be sorted within each row"
        );
        encoded.push(first);
        encoded.extend(cols.windows(2).map(|pair| pair[1] - pair[0]));
    }
    encoded
}
/// Inverse of `encode_col_deltas`: rebuilds absolute column indices from
/// per-row deltas.
///
/// Row lengths come from `new_row_indices` (the last row ends at
/// `total_entries`). Deltas are consumed sequentially from the start of
/// `deltas`, and each row's columns are the running sum of its deltas.
///
/// # Panics
///
/// Panics if the row offsets decrease, or if `deltas` is too short for the
/// rows described. In debug builds it also panics when a running sum overflows
/// `u32`. Callers feeding untrusted data must validate first.
fn decode_col_deltas(deltas: &[u32], new_row_indices: &[u32], total_entries: usize) -> Vec<u32> {
    let mut decoded = Vec::with_capacity(deltas.len());
    let mut cursor = 0;
    for (row, &row_start) in new_row_indices.iter().enumerate() {
        let row_end = match new_row_indices.get(row + 1) {
            Some(&next_start) => next_start as usize,
            None => total_entries,
        };
        let row_len = row_end - row_start as usize;
        let row_deltas = &deltas[cursor..cursor + row_len];
        cursor += row_len;
        let mut column = 0u32;
        for &step in row_deltas {
            column += step;
            decoded.push(column);
        }
    }
    decoded
}
impl Serialize for SparseMatrix {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let col_deltas =
encode_col_deltas(&self.col_indices, &self.new_row_indices, self.values.len());
let mut state = serializer.serialize_struct("SparseMatrix", 5)?;
state.serialize_field("num_rows", &self.num_rows)?;
state.serialize_field("num_cols", &self.num_cols)?;
state.serialize_field("new_row_indices", &self.new_row_indices)?;
state.serialize_field("col_deltas", &col_deltas)?;
state.serialize_field("values", &self.values)?;
state.end()
}
}
impl<'de> Deserialize<'de> for SparseMatrix {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
NumRows,
NumCols,
NewRowIndices,
ColDeltas,
Values,
}
struct SparseMatrixVisitor;
impl<'de> Visitor<'de> for SparseMatrixVisitor {
type Value = SparseMatrix;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("struct SparseMatrix")
}
fn visit_seq<V>(self, mut seq: V) -> Result<SparseMatrix, V::Error>
where
V: SeqAccess<'de>,
{
let num_rows = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(0, &self))?;
let num_cols = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(1, &self))?;
let new_row_indices: Vec<u32> = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(2, &self))?;
let col_deltas: Vec<u32> = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(3, &self))?;
let values: Vec<InternedFieldElement> = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(4, &self))?;
let col_indices = decode_col_deltas(&col_deltas, &new_row_indices, values.len());
Ok(SparseMatrix {
num_rows,
num_cols,
new_row_indices,
col_indices,
values,
})
}
fn visit_map<V>(self, mut map: V) -> Result<SparseMatrix, V::Error>
where
V: serde::de::MapAccess<'de>,
{
let mut num_rows = None;
let mut num_cols = None;
let mut new_row_indices: Option<Vec<u32>> = None;
let mut col_deltas: Option<Vec<u32>> = None;
let mut values: Option<Vec<InternedFieldElement>> = None;
while let Some(key) = map.next_key()? {
match key {
Field::NumRows => {
if num_rows.is_some() {
return Err(serde::de::Error::duplicate_field("num_rows"));
}
num_rows = Some(map.next_value()?);
}
Field::NumCols => {
if num_cols.is_some() {
return Err(serde::de::Error::duplicate_field("num_cols"));
}
num_cols = Some(map.next_value()?);
}
Field::NewRowIndices => {
if new_row_indices.is_some() {
return Err(serde::de::Error::duplicate_field("new_row_indices"));
}
new_row_indices = Some(map.next_value()?);
}
Field::ColDeltas => {
if col_deltas.is_some() {
return Err(serde::de::Error::duplicate_field("col_deltas"));
}
col_deltas = Some(map.next_value()?);
}
Field::Values => {
if values.is_some() {
return Err(serde::de::Error::duplicate_field("values"));
}
values = Some(map.next_value()?);
}
}
}
let num_rows =
num_rows.ok_or_else(|| serde::de::Error::missing_field("num_rows"))?;
let num_cols =
num_cols.ok_or_else(|| serde::de::Error::missing_field("num_cols"))?;
let new_row_indices = new_row_indices
.ok_or_else(|| serde::de::Error::missing_field("new_row_indices"))?;
let col_deltas =
col_deltas.ok_or_else(|| serde::de::Error::missing_field("col_deltas"))?;
let values = values.ok_or_else(|| serde::de::Error::missing_field("values"))?;
let col_indices = decode_col_deltas(&col_deltas, &new_row_indices, values.len());
Ok(SparseMatrix {
num_rows,
num_cols,
new_row_indices,
col_indices,
values,
})
}
}
const FIELDS: &[&str] = &[
"num_rows",
"num_cols",
"new_row_indices",
"col_deltas",
"values",
];
deserializer.deserialize_struct("SparseMatrix", FIELDS, SparseMatrixVisitor)
}
}
/// A borrowed `SparseMatrix` paired with the `Interner` used to resolve its
/// interned values into `FieldElement`s; created by `SparseMatrix::hydrate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HydratedSparseMatrix<'a> {
/// The underlying matrix.
pub matrix: &'a SparseMatrix,
/// Interner queried (via `get`) for every stored value during iteration.
interner: &'a Interner,
}
impl SparseMatrix {
    /// Creates an empty `rows` x `cols` matrix with no stored entries.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            num_rows: rows,
            num_cols: cols,
            new_row_indices: vec![0; rows],
            col_indices: Vec::new(),
            values: Vec::new(),
        }
    }

    /// Borrows this matrix together with `interner`, so that stored values can
    /// be resolved to [`FieldElement`]s.
    pub const fn hydrate<'a>(&'a self, interner: &'a Interner) -> HydratedSparseMatrix<'a> {
        HydratedSparseMatrix {
            matrix: self,
            interner,
        }
    }

    /// Number of explicitly stored entries.
    pub fn num_entries(&self) -> usize {
        self.values.len()
    }

    /// Estimates the varint size of the column indices in absolute form versus
    /// the per-row delta form written by the `Serialize` impl.
    pub fn delta_encoding_stats(&self) -> DeltaEncodingStats {
        let deltas = encode_col_deltas(&self.col_indices, &self.new_row_indices, self.values.len());
        let absolute_bytes: usize = self.col_indices.iter().map(|&v| varint_size(v)).sum();
        let delta_bytes: usize = deltas.iter().map(|&v| varint_size(v)).sum();
        DeltaEncodingStats {
            total_entries: self.col_indices.len(),
            absolute_bytes,
            delta_bytes,
        }
    }

    /// Enlarges the matrix to `rows` x `cols`; added rows start out empty.
    ///
    /// # Panics
    ///
    /// Panics if either dimension would shrink, or if the number of stored
    /// entries does not fit in the `u32` row offsets.
    pub fn grow(&mut self, rows: usize, cols: usize) {
        assert!(rows >= self.num_rows);
        assert!(cols >= self.num_cols);
        // New rows begin (and end) after the last stored entry. The checked
        // conversion prevents silently truncated offsets past `u32::MAX`.
        let end = u32::try_from(self.values.len()).expect("number of entries exceeds u32::MAX");
        self.num_rows = rows;
        self.num_cols = cols;
        self.new_row_indices.resize(rows, end);
    }

    /// Stores `value` at (`row`, `col`), keeping the row's columns sorted.
    ///
    /// # Panics
    ///
    /// Panics if `row` or `col` is out of bounds, if `col` does not fit in `u32`,
    /// or if an entry already exists at (`row`, `col`).
    pub fn set(&mut self, row: usize, col: usize, value: InternedFieldElement) {
        assert!(row < self.num_rows, "row index out of bounds");
        assert!(col < self.num_cols, "column index out of bounds");
        let col_u32 = u32::try_from(col).expect("column index exceeds u32::MAX");
        let row_range = self.row_range(row);
        match self.col_indices[row_range.clone()].binary_search(&col_u32) {
            Ok(_) => {
                unreachable!("Duplicate column {col} in row {row}");
            }
            Err(offset) => {
                let i = row_range.start + offset;
                self.col_indices.insert(i, col_u32);
                self.values.insert(i, value);
                // Every later row now starts one slot further along.
                for index in &mut self.new_row_indices[row + 1..] {
                    *index += 1;
                }
            }
        }
    }

    /// Iterates over the stored entries of `row` as `(column, value)` pairs in
    /// storage order.
    ///
    /// # Panics
    ///
    /// Panics if `row` is out of bounds.
    pub fn iter_row(
        &self,
        row: usize,
    ) -> impl Iterator<Item = (usize, InternedFieldElement)> + use<'_> {
        let row_range = self.row_range(row);
        let cols = self.col_indices[row_range.clone()].iter().copied();
        let values = self.values[row_range].iter().copied();
        cols.zip(values).map(|(col, value)| (col as usize, value))
    }

    /// Iterates over all stored entries as `((row, column), value)`, row by row.
    pub fn iter(&self) -> impl Iterator<Item = ((usize, usize), InternedFieldElement)> + use<'_> {
        (0..self.new_row_indices.len()).flat_map(|row| {
            self.iter_row(row)
                .map(move |(col, value)| ((row, col), value))
        })
    }

    /// Position range of `row`'s entries in `col_indices` / `values`.
    /// The last row extends to the end of `values`.
    fn row_range(&self, row: usize) -> Range<usize> {
        let start = *self
            .new_row_indices
            .get(row)
            .expect("Row index out of bounds") as usize;
        let end = self
            .new_row_indices
            .get(row + 1)
            .map_or(self.values.len(), |&v| v as usize);
        start..end
    }

    /// Returns the transpose (`num_cols` x `num_rows`).
    ///
    /// Entries are collected as `(column, row, value)` triples and sorted in
    /// parallel by `(column, row)`, then regrouped so each old column becomes a
    /// new row with ascending column (old row) indices.
    pub fn transpose(&self) -> SparseMatrix {
        let nnz = self.values.len();
        let mut entries: Vec<(u32, u32, InternedFieldElement)> = Vec::with_capacity(nnz);
        for row in 0..self.num_rows {
            let range = self.row_range(row);
            for i in range {
                entries.push((self.col_indices[i], row as u32, self.values[i]));
            }
        }
        entries.par_sort_unstable_by_key(|&(new_row, new_col, _)| (new_row, new_col));
        debug_assert!(
            entries
                .windows(2)
                .all(|w| (w[0].0, w[0].1) != (w[1].0, w[1].1)),
            "Duplicate (row, col) entries in sparse matrix transpose"
        );
        let mut new_row_indices = Vec::with_capacity(self.num_cols);
        let mut col_indices = Vec::with_capacity(nnz);
        let mut values = Vec::with_capacity(nnz);
        let mut entry_idx = 0;
        for row in 0..self.num_cols {
            new_row_indices.push(entry_idx as u32);
            while entry_idx < entries.len() && entries[entry_idx].0 == row as u32 {
                col_indices.push(entries[entry_idx].1);
                values.push(entries[entry_idx].2);
                entry_idx += 1;
            }
        }
        SparseMatrix {
            num_rows: self.num_cols,
            num_cols: self.num_rows,
            new_row_indices,
            col_indices,
            values,
        }
    }

    /// Rewrites every column index through `remap_fn` (in parallel), then
    /// re-sorts each row by its new column indices. `num_cols` is unchanged.
    ///
    /// NOTE(review): `remap_fn` is presumably expected to be injective within a
    /// row and to map into `0..num_cols`; neither property is checked here.
    ///
    /// # Panics
    ///
    /// Panics if a remapped column index does not fit in `u32`.
    pub fn remap_columns<F>(&mut self, remap_fn: F)
    where
        F: Fn(usize) -> usize + Send + Sync,
    {
        self.col_indices.par_iter_mut().for_each(|col| {
            *col = u32::try_from(remap_fn(*col as usize))
                .expect("remapped column index exceeds u32::MAX");
        });
        // Remapping can break per-row ordering. Re-sort each row, reusing one
        // scratch buffer instead of allocating a fresh Vec per row.
        let mut scratch: Vec<(u32, InternedFieldElement)> = Vec::new();
        for row in 0..self.num_rows {
            let range = self.row_range(row);
            let row_cols = &mut self.col_indices[range.clone()];
            let row_vals = &mut self.values[range];
            scratch.clear();
            scratch.extend(row_cols.iter().copied().zip(row_vals.iter().copied()));
            scratch.sort_unstable_by_key(|&(c, _)| c);
            for ((dst_col, dst_val), (col, val)) in row_cols
                .iter_mut()
                .zip(row_vals.iter_mut())
                .zip(scratch.iter().copied())
            {
                *dst_col = col;
                *dst_val = val;
            }
        }
    }
}
impl HydratedSparseMatrix<'_> {
    /// Iterates over the stored entries of `row` as `(column, value)` pairs,
    /// with each interned value resolved through the interner.
    ///
    /// # Panics
    ///
    /// Panics if a stored value is missing from the interner, or if `row` is
    /// out of bounds for the underlying matrix.
    pub fn iter_row(&self, row: usize) -> impl Iterator<Item = (usize, FieldElement)> + use<'_> {
        let entries = self.matrix.iter_row(row);
        entries.map(|(col, interned)| {
            let resolved = self.interner.get(interned).expect("Value not in interner.");
            (col, resolved)
        })
    }

    /// Iterates over all stored entries as `((row, column), value)`, with each
    /// interned value resolved through the interner.
    ///
    /// # Panics
    ///
    /// Panics if a stored value is missing from the interner.
    pub fn iter(&self) -> impl Iterator<Item = ((usize, usize), FieldElement)> + use<'_> {
        let entries = self.matrix.iter();
        entries.map(|(position, interned)| {
            let resolved = self.interner.get(interned).expect("Value not in interner.");
            (position, resolved)
        })
    }
}
impl Mul<&[FieldElement]> for HydratedSparseMatrix<'_> {
    type Output = Vec<FieldElement>;

    /// Matrix-vector product `M * rhs`; rows are evaluated in parallel with rayon.
    ///
    /// # Panics
    ///
    /// Panics if `rhs.len()` differs from the number of columns.
    fn mul(self, rhs: &[FieldElement]) -> Self::Output {
        let (rows, cols) = (self.matrix.num_rows, self.matrix.num_cols);
        assert_eq!(
            cols,
            rhs.len(),
            "Vector length does not match number of columns."
        );
        (0..rows)
            .into_par_iter()
            .map(|row| {
                let mut dot = FieldElement::zero();
                for (col, coeff) in self.iter_row(row) {
                    dot += coeff * rhs[col];
                }
                dot
            })
            .collect()
    }
}
impl Mul<HydratedSparseMatrix<'_>> for &[FieldElement] {
    type Output = Vec<FieldElement>;

    /// Vector-matrix product `self * M`, computed by adding each stored entry's
    /// contribution into its column's slot.
    ///
    /// # Panics
    ///
    /// Panics if `self.len()` differs from the number of rows.
    fn mul(self, rhs: HydratedSparseMatrix<'_>) -> Self::Output {
        let row_count = rhs.matrix.num_rows;
        assert_eq!(
            self.len(),
            row_count,
            "Vector length does not match number of rows."
        );
        let mut acc = vec![FieldElement::zero(); rhs.matrix.num_cols];
        rhs.iter().for_each(|((row, col), value)| {
            acc[col] += value * self[row];
        });
        acc
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_delta_encoding_roundtrip() {
        let col_indices = vec![3, 15, 100, 5, 50, 200];
        let new_row_indices = vec![0, 3];
        let total_entries = 6;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_values() {
        let col_indices = vec![3, 15, 100];
        let new_row_indices = vec![0];
        let total_entries = 3;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert_eq!(deltas, vec![3, 12, 85]);
    }

    #[test]
    fn test_delta_encoding_multiple_rows() {
        let col_indices = vec![0, 10, 20, 5, 15];
        let new_row_indices = vec![0, 3];
        let total_entries = 5;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert_eq!(deltas, vec![0, 10, 10, 5, 10]);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_empty_row() {
        let col_indices = vec![5, 10];
        let new_row_indices = vec![0, 0, 2];
        let total_entries = 2;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_single_entry() {
        let col_indices = vec![42];
        let new_row_indices = vec![0];
        let total_entries = 1;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert_eq!(deltas, vec![42]);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_single_column_per_row() {
        let col_indices = vec![0, 5, 100];
        let new_row_indices = vec![0, 1, 2];
        let total_entries = 3;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert_eq!(deltas, vec![0, 5, 100]);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_consecutive_columns() {
        let col_indices = vec![10, 11, 12, 13];
        let new_row_indices = vec![0];
        let total_entries = 4;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert_eq!(deltas, vec![10, 1, 1, 1]);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_all_rows_empty() {
        let col_indices: Vec<u32> = vec![];
        let new_row_indices = vec![0, 0, 0];
        let total_entries = 0;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert!(deltas.is_empty());
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_delta_encoding_last_row_empty() {
        let col_indices = vec![1, 2, 7];
        let new_row_indices = vec![0, 2, 3];
        let total_entries = 3;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_only_last_row_non_empty() {
        let col_indices = vec![3, 8];
        let new_row_indices = vec![0, 0, 0, 2];
        let total_entries = 2;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert_eq!(deltas, vec![3, 5]);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    #[test]
    fn test_delta_encoding_large_column_indices() {
        let col_indices = vec![1_000_000, 1_000_001, 2_000_000];
        let new_row_indices = vec![0];
        let total_entries = 3;
        let deltas = encode_col_deltas(&col_indices, &new_row_indices, total_entries);
        assert_eq!(deltas, vec![1_000_000, 1, 999_999]);
        let decoded = decode_col_deltas(&deltas, &new_row_indices, total_entries);
        assert_eq!(col_indices, decoded);
    }

    /// Pins every byte-length boundary of the 7-bits-per-byte varint estimate.
    #[test]
    fn test_varint_size_boundaries() {
        assert_eq!(varint_size(0), 1);
        assert_eq!(varint_size(0x7f), 1);
        assert_eq!(varint_size(0x80), 2);
        assert_eq!(varint_size(0x3fff), 2);
        assert_eq!(varint_size(0x4000), 3);
        assert_eq!(varint_size(0x1f_ffff), 3);
        assert_eq!(varint_size(0x20_0000), 4);
        assert_eq!(varint_size(0xfff_ffff), 4);
        assert_eq!(varint_size(0x1000_0000), 5);
        assert_eq!(varint_size(u32::MAX), 5);
    }

    #[test]
    fn test_sparse_matrix_serde_roundtrip() {
        let mut interner = Interner::new();
        let val1 = interner.intern(FieldElement::from(1u64));
        let val2 = interner.intern(FieldElement::from(2u64));
        let val3 = interner.intern(FieldElement::from(3u64));
        let mut matrix = SparseMatrix::new(3, 100);
        matrix.grow(3, 100);
        matrix.set(0, 5, val1);
        matrix.set(0, 20, val2);
        matrix.set(1, 50, val3);
        let serialized = postcard::to_allocvec(&matrix).expect("serialization failed");
        let deserialized: SparseMatrix =
            postcard::from_bytes(&serialized).expect("deserialization failed");
        assert_eq!(matrix, deserialized);
    }

    #[test]
    fn test_delta_encoding_size_reduction() {
        let mut interner = Interner::new();
        let val = interner.intern(FieldElement::from(1u64));
        let mut matrix = SparseMatrix::new(10, 1000);
        matrix.grow(10, 1000);
        for row in 0..10 {
            for col_offset in 0..20 {
                matrix.set(row, row * 50 + col_offset, val);
            }
        }
        // Isolate the column encoding itself. Rows 0-2 have 1-byte columns and
        // rows 3-9 have 2-byte columns: 3*20 + 7*20*2 = 340 absolute bytes.
        // Deltas are one (1- or 2-byte) row start plus 19 one-byte steps per
        // row: 3*20 + 7*21 = 207 bytes.
        let stats = matrix.delta_encoding_stats();
        assert_eq!(stats.total_entries, 200);
        assert_eq!(stats.absolute_bytes, 340);
        assert_eq!(stats.delta_bytes, 207);
        assert!(stats.delta_bytes < stats.absolute_bytes);
        let serialized = postcard::to_allocvec(&matrix).expect("serialization failed");
        let col_count = matrix.col_indices.len();
        let naive_col_bytes = col_count * 4;
        let actual_bytes = serialized.len();
        assert!(
            actual_bytes < naive_col_bytes,
            "delta encoding should reduce size: actual {} vs naive col bytes {}",
            actual_bytes,
            naive_col_bytes
        );
    }

    /// Transposing regroups entries by column, and transposing twice restores the matrix.
    #[test]
    fn test_transpose_roundtrip() {
        let mut interner = Interner::new();
        let a = interner.intern(FieldElement::from(1u64));
        let b = interner.intern(FieldElement::from(2u64));
        let c = interner.intern(FieldElement::from(3u64));
        let mut matrix = SparseMatrix::new(2, 3);
        matrix.set(0, 2, a);
        matrix.set(1, 0, b);
        matrix.set(1, 2, c);
        let transposed = matrix.transpose();
        assert_eq!((transposed.num_rows, transposed.num_cols), (3, 2));
        let entries: Vec<_> = transposed.iter().collect();
        assert_eq!(entries, vec![((0, 1), b), ((2, 0), a), ((2, 1), c)]);
        assert_eq!(transposed.transpose(), matrix);
    }

    /// A column remap that reverses order must leave each row sorted again.
    #[test]
    fn test_remap_columns_restores_row_order() {
        let mut interner = Interner::new();
        let a = interner.intern(FieldElement::from(1u64));
        let b = interner.intern(FieldElement::from(2u64));
        let mut matrix = SparseMatrix::new(1, 4);
        matrix.set(0, 0, a);
        matrix.set(0, 3, b);
        matrix.remap_columns(|col| 3 - col);
        let row: Vec<_> = matrix.iter_row(0).collect();
        assert_eq!(row, vec![(0, b), (3, a)]);
    }
}