#![allow(clippy::ptr_arg)]
use ff::PrimeField;
use serde::{Deserialize, Serialize};
use super::{
matrix,
matrix::{
invert, is_identity, is_invertible, is_square, left_apply_matrix, mat_mul, minor, transpose,
Matrix,
},
};
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MdsMatrices<F: PrimeField> {
pub m: Matrix<F>,
pub m_inv: Matrix<F>,
pub m_hat: Matrix<F>,
pub m_hat_inv: Matrix<F>,
pub m_prime: Matrix<F>,
pub m_double_prime: Matrix<F>,
}
pub(crate) fn derive_mds_matrices<F: PrimeField>(m: Matrix<F>) -> MdsMatrices<F> {
let m_inv = invert(&m).unwrap(); let m_hat = minor(&m, 0, 0);
let m_hat_inv = invert(&m_hat).unwrap(); let m_prime = make_prime(&m);
let m_double_prime = make_double_prime(&m, &m_hat_inv);
MdsMatrices {
m,
m_inv,
m_hat,
m_hat_inv,
m_prime,
m_double_prime,
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SparseMatrix<F: PrimeField> {
pub w_hat: Vec<F>,
pub v_rest: Vec<F>,
}
impl<F: PrimeField> SparseMatrix<F> {
pub fn new_from_ref(m_double_prime: &Matrix<F>) -> Self {
assert!(Self::is_sparse_matrix(m_double_prime));
let size = matrix::rows(m_double_prime);
let w_hat = (0..size).map(|i| m_double_prime[i][0]).collect::<Vec<_>>();
let v_rest = m_double_prime[0][1..].to_vec();
Self { w_hat, v_rest }
}
pub fn is_sparse_matrix(m: &Matrix<F>) -> bool {
is_square(m) && is_identity(&minor(m, 0, 0))
}
}
pub(crate) fn factor_to_sparse_matrixes<F: PrimeField>(
base_matrix: &Matrix<F>,
n: usize,
) -> (Matrix<F>, Vec<SparseMatrix<F>>) {
let (pre_sparse, sparse_matrices) = factor_to_sparse_matrices(base_matrix, n);
let sparse_matrixes = sparse_matrices
.iter()
.map(|m| SparseMatrix::<F>::new_from_ref(m))
.collect::<Vec<_>>();
(pre_sparse, sparse_matrixes)
}
pub(crate) fn factor_to_sparse_matrices<F: PrimeField>(
base_matrix: &Matrix<F>,
n: usize,
) -> (Matrix<F>, Vec<Matrix<F>>) {
let (pre_sparse, mut all) =
(0..n).fold((base_matrix.clone(), Vec::new()), |(curr, mut acc), _| {
let derived = derive_mds_matrices(curr);
acc.push(derived.m_double_prime);
let new = mat_mul(base_matrix, &derived.m_prime).unwrap();
(new, acc)
});
all.reverse();
(pre_sparse, all)
}
pub(crate) fn generate_mds<F: PrimeField>(t: usize) -> Matrix<F> {
let xs: Vec<F> = (0..t as u64).map(F::from).collect();
let ys: Vec<F> = (t as u64..2 * t as u64).map(F::from).collect();
let matrix = xs
.iter()
.map(|xs_item| {
ys.iter()
.map(|ys_item| {
let mut tmp = *xs_item;
tmp.add_assign(ys_item);
tmp.invert().unwrap()
})
.collect()
})
.collect();
assert!(is_invertible(&matrix));
assert_eq!(matrix, transpose(&matrix));
matrix
}
fn make_prime<F: PrimeField>(m: &Matrix<F>) -> Matrix<F> {
m.iter()
.enumerate()
.map(|(i, row)| match i {
0 => {
let mut new_row = vec![F::ZERO; row.len()];
new_row[0] = F::ONE;
new_row
}
_ => {
let mut new_row = vec![F::ZERO; row.len()];
new_row[1..].copy_from_slice(&row[1..]);
new_row
}
})
.collect()
}
fn make_double_prime<F: PrimeField>(m: &Matrix<F>, m_hat_inv: &Matrix<F>) -> Matrix<F> {
let (v, w) = make_v_w(m);
let w_hat = left_apply_matrix(m_hat_inv, &w);
m.iter()
.enumerate()
.map(|(i, row)| match i {
0 => {
let mut new_row = Vec::with_capacity(row.len());
new_row.push(row[0]);
new_row.extend(&v);
new_row
}
_ => {
let mut new_row = vec![F::ZERO; row.len()];
new_row[0] = w_hat[i - 1];
new_row[i] = F::ONE;
new_row
}
})
.collect()
}
fn make_v_w<F: PrimeField>(m: &Matrix<F>) -> (Vec<F>, Vec<F>) {
let v = m[0][1..].to_vec();
let w = m.iter().skip(1).map(|column| column[0]).collect();
(v, w)
}