use crate::gf;
/// Dense matrix of `u16` elements stored in row-major order.
///
/// The fields are public, so nothing enforces that `data.len()` equals
/// `rows * cols`; the constructors in this file always allocate exactly that
/// many elements.
///
/// Two matrices compare equal when their dimensions and all elements match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GfMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns; also the stride between consecutive rows in `data`.
    pub cols: usize,
    /// Row-major element storage: element `(r, c)` lives at `r * cols + c`.
    pub data: Vec<u16>,
}
impl GfMatrix {
    /// Creates a `rows` x `cols` matrix with every element set to zero.
    pub fn zeros(rows: usize, cols: usize) -> Self {
        let len = rows * cols;
        Self {
            rows,
            cols,
            data: vec![0u16; len],
        }
    }

    /// Creates the `n` x `n` matrix with ones on the diagonal and zeros
    /// everywhere else.
    pub fn identity(n: usize) -> Self {
        let mut m = Self::zeros(n, n);
        (0..n).for_each(|d| m.set(d, d, 1));
        m
    }

    /// Returns the element at (`row`, `col`).
    ///
    /// # Panics
    ///
    /// Panics if the row-major offset `row * cols + col` lies outside `data`.
    /// `col` is not checked against `cols` on its own, so an oversized `col`
    /// whose offset is still in range reads from a later row.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> u16 {
        let offset = row * self.cols + col;
        self.data[offset]
    }

    /// Stores `val` at (`row`, `col`).
    ///
    /// # Panics
    ///
    /// Panics if the row-major offset `row * cols + col` lies outside `data`.
    /// `col` is not checked against `cols` on its own, so an oversized `col`
    /// whose offset is still in range writes into a later row.
    #[inline]
    pub fn set(&mut self, row: usize, col: usize, val: u16) {
        let offset = row * self.cols + col;
        self.data[offset] = val;
    }

    /// Builds the `(input_count + recovery_exponents.len())` x `input_count`
    /// encoding matrix.
    ///
    /// The top `input_count` rows form an identity matrix. Row
    /// `input_count + r` holds `gf::pow(constant, recovery_exponents[r])` for
    /// each constant returned by `par2_input_constants(input_count)`, in
    /// column order.
    pub fn par2_encoding_matrix(input_count: usize, recovery_exponents: &[u32]) -> Self {
        let mut m = Self::zeros(input_count + recovery_exponents.len(), input_count);
        let constants = par2_input_constants(input_count);

        // Identity part: each input maps to itself.
        (0..input_count).for_each(|d| m.set(d, d, 1));

        // Recovery part: one row per exponent, written through a row slice.
        for (r, &exp) in recovery_exponents.iter().enumerate() {
            let start = (input_count + r) * input_count;
            let row = &mut m.data[start..start + input_count];
            for (cell, &constant) in row.iter_mut().zip(constants.iter()) {
                *cell = gf::pow(constant, exp);
            }
        }
        m
    }

    /// Returns a new matrix made of the rows named in `row_indices`, in the
    /// given order. An index may appear more than once.
    ///
    /// # Panics
    ///
    /// Panics if a requested row extends past the end of `data`.
    pub fn select_rows(&self, row_indices: &[usize]) -> Self {
        let width = self.cols;
        let mut data = Vec::with_capacity(row_indices.len() * width);
        for &src_row in row_indices {
            let start = src_row * width;
            data.extend_from_slice(&self.data[start..start + width]);
        }
        Self {
            rows: row_indices.len(),
            cols: width,
            data,
        }
    }

    /// Computes the inverse by Gauss-Jordan elimination on the augmented
    /// matrix `[self | I]`, delegating element arithmetic to `gf::inv`,
    /// `gf::mul` and `gf::add`.
    ///
    /// Returns `None` when some column has no non-zero entry on or below the
    /// diagonal, i.e. the matrix is singular.
    ///
    /// # Panics
    ///
    /// Panics if the matrix is not square.
    pub fn invert(&self) -> Option<Self> {
        assert_eq!(self.rows, self.cols, "Can only invert square matrices");
        let n = self.rows;
        let width = 2 * n;

        // Left half: a copy of `self`. Right half: the identity.
        let mut aug = Self::zeros(n, width);
        for r in 0..n {
            let base = r * width;
            for c in 0..n {
                aug.data[base + c] = self.get(r, c);
            }
            aug.data[base + n + r] = 1;
        }

        for pivot in 0..n {
            // First row at or below the diagonal with a non-zero entry in this
            // column; if there is none the matrix cannot be inverted.
            let source = (pivot..n).find(|&r| aug.data[r * width + pivot] != 0)?;
            if source != pivot {
                // `source > pivot`, so the pivot row sits entirely in `upper`
                // and the source row is the first row of `lower`.
                let (upper, lower) = aug.data.split_at_mut(source * width);
                upper[pivot * width..(pivot + 1) * width].swap_with_slice(&mut lower[..width]);
            }

            // Borrow the pivot row separately from the rows above and below it
            // so they can be updated while it is being read.
            let (above, rest) = aug.data.split_at_mut(pivot * width);
            let (pivot_cells, below) = rest.split_at_mut(width);

            // Scale the pivot row so that its diagonal entry becomes 1.
            let scale = gf::inv(pivot_cells[pivot]);
            for cell in pivot_cells.iter_mut() {
                *cell = gf::mul(*cell, scale);
            }

            // Clear this column in every other row.
            // NOTE(review): `gf::add` is used where ordinary elimination would
            // subtract; presumably the two coincide in this field. Confirm
            // against `crate::gf`.
            for row in above
                .chunks_exact_mut(width)
                .chain(below.chunks_exact_mut(width))
            {
                let factor = row[pivot];
                if factor == 0 {
                    continue;
                }
                for (cell, &p) in row.iter_mut().zip(pivot_cells.iter()) {
                    *cell = gf::add(*cell, gf::mul(factor, p));
                }
            }
        }

        // The right half of the augmented matrix now holds the inverse.
        let mut data = Vec::with_capacity(n * n);
        for r in 0..n {
            data.extend_from_slice(&aug.data[r * width + n..(r + 1) * width]);
        }
        Some(Self {
            rows: n,
            cols: n,
            data,
        })
    }
}
/// Returns the first `count` values of `gf::exp2(n)` for `n = 1, 2, 3, ...`,
/// skipping every `n` that is divisible by 3, 5, 17 or 257.
///
/// Those four primes multiply to 65535 (`2^16 - 1`).
/// NOTE(review): presumably this selects the per-input base constants of the
/// PAR2 Reed-Solomon code; confirm against the specification.
pub fn par2_input_constants(count: usize) -> Vec<u16> {
    let mut constants = Vec::with_capacity(count);
    constants.extend(
        (1u32..)
            .filter(|&n| [3u32, 5, 17, 257].iter().all(|&p| n % p != 0))
            .map(|n| gf::exp2(n))
            .take(count),
    );
    constants
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `rows` x `cols` matrix from row-major `cells` using `set`.
    fn matrix_from(rows: usize, cols: usize, cells: &[u16]) -> GfMatrix {
        let mut m = GfMatrix::zeros(rows, cols);
        for (i, &value) in cells.iter().enumerate() {
            m.set(i / cols, i % cols, value);
        }
        m
    }

    /// Inverting the identity must give the identity back.
    #[test]
    fn test_identity_inverse() {
        let inverse = GfMatrix::identity(4).invert().unwrap();
        for r in 0..4 {
            for c in 0..4 {
                assert_eq!(inverse.get(r, c), u16::from(r == c));
            }
        }
    }

    /// Multiplying a matrix by its inverse must give the identity.
    #[test]
    fn test_inverse_roundtrip() {
        let m = matrix_from(3, 3, &[1, 2, 3, 4, 5, 6, 7, 8, 10]);
        let inv = m.invert().unwrap();
        for r in 0..3 {
            for c in 0..3 {
                // Dot product of row `r` of `m` with column `c` of `inv`.
                let sum = (0..3).fold(0u16, |acc, k| {
                    gf::add(acc, gf::mul(m.get(r, k), inv.get(k, c)))
                });
                let expected = u16::from(r == c);
                assert_eq!(sum, expected, "M*M^-1 [{r},{c}] should be {expected}");
            }
        }
    }

    /// The recovery rows for exponents 0, 1, 2 must be invertible.
    #[test]
    fn test_vandermonde_invertible() {
        let exponents: [u32; 3] = [0, 1, 2];
        let encoding = GfMatrix::par2_encoding_matrix(3, &exponents);
        let recovery = encoding.select_rows(&[3, 4, 5]);
        assert!(
            recovery.invert().is_some(),
            "Vandermonde submatrix should be invertible"
        );
    }

    /// `select_rows` keeps the requested rows, in order, at full width.
    #[test]
    fn test_select_rows() {
        // Element (r, c) holds r * 10 + c so each cell identifies its origin.
        let cells: Vec<u16> = (0..4usize)
            .flat_map(|r| (0..3usize).map(move |c| (r * 10 + c) as u16))
            .collect();
        let m = matrix_from(4, 3, &cells);

        let sub = m.select_rows(&[1, 3]);
        assert_eq!(sub.rows, 2);
        assert_eq!(sub.cols, 3);
        assert_eq!(sub.get(0, 0), 10);
        assert_eq!(sub.get(1, 2), 32);
    }

    /// Singular inputs (all zeros, or two identical rows) must give `None`.
    #[test]
    fn test_singular_matrix() {
        assert!(GfMatrix::zeros(3, 3).invert().is_none());

        let duplicate_rows = matrix_from(2, 2, &[1, 2, 1, 2]);
        assert!(duplicate_rows.invert().is_none());
    }
}