use thiserror::Error;
use crate::{HashableMessage, Integer, elgamal::Ciphertext};
/// Dense matrix stored as a vector of rows.
///
/// The type itself does not force rows to share a length; `Matrix::from_rows` rejects
/// ragged input and `Matrix::is_malformed` reports it.
#[derive(Debug, Clone)]
pub struct Matrix<T: Clone + Default + std::fmt::Debug> {
    /// Row storage: `rows[i][j]` is the entry at row `i`, column `j`.
    rows: Vec<Vec<T>>,
}
#[derive(Error, Debug)]
pub enum MatrixError {
#[error("The size {0} of the vector must the product of m={1} et n={2}")]
WrongVectorSize(usize, usize, usize),
#[error("The Matrix is malformed")]
MalformedMatrix,
#[error("The Matrices have different size")]
NotSameSize,
}
impl<T: Clone + Default + std::fmt::Debug> Matrix<T> {
    /// Chooses the dimensions `(m, n)` used to lay out `upper_n` elements as a matrix.
    ///
    /// `m` is the largest divisor of `upper_n` with `2 <= m <= sqrt(upper_n)` and
    /// `n = upper_n / m`, so `m * n == upper_n` and `m <= n`. When no such divisor
    /// exists (e.g. `upper_n` is prime, `0` or `1`) the result is `(1, upper_n)`.
    pub fn get_matrix_dimensions(upper_n: usize) -> (usize, usize) {
        // `isqrt` is exact for every `usize`. An `f64` round-trip loses precision above
        // 2^53 and could then start the search below the true integer square root.
        (2..=upper_n.isqrt())
            .rev()
            .find(|&m| upper_n.is_multiple_of(m))
            .map_or((1, upper_n), |m| (m, upper_n / m))
    }

    /// Creates an `m x n` matrix whose entries are all `T::default()`.
    fn new(m: usize, n: usize) -> Self {
        Self {
            rows: vec![vec![T::default(); n]; m],
        }
    }

    /// Builds an `m x n` matrix from `v` read in row-major order
    /// (entry `(i, j)` is `v[i * n + j]`).
    ///
    /// # Errors
    /// Returns [`MatrixError::WrongVectorSize`] if `v.len() != m * n`, including when
    /// `m * n` would overflow `usize`.
    pub fn to_matrix(v: &[T], (m, n): (usize, usize)) -> Result<Self, MatrixError> {
        // `checked_mul` turns an overflowing `m * n` into a size error. Without it, debug
        // builds panic and release builds compare against a wrapped product.
        if m.checked_mul(n) != Some(v.len()) {
            return Err(MatrixError::WrongVectorSize(v.len(), m, n));
        }
        // Explicit ranges rather than `chunks(n)`, which would panic for `n == 0`.
        let rows = (0..m).map(|i| v[i * n..(i + 1) * n].to_vec()).collect();
        Ok(Self { rows })
    }

    /// Builds a matrix from a copy of `rows`.
    ///
    /// # Errors
    /// Returns [`MatrixError::MalformedMatrix`] if the rows do not all have the same length.
    pub fn from_rows(rows: &[Vec<T>]) -> Result<Self, MatrixError> {
        let res = Self {
            rows: rows.to_vec(),
        };
        if res.is_malformed() {
            Err(MatrixError::MalformedMatrix)
        } else {
            Ok(res)
        }
    }

    /// Returns the transpose: entry `(i, j)` of `self` becomes entry `(j, i)`.
    ///
    /// # Errors
    /// Returns [`MatrixError::MalformedMatrix`] if the rows do not all have the same length.
    pub fn transpose(&self) -> Result<Self, MatrixError> {
        if self.is_malformed() {
            return Err(MatrixError::MalformedMatrix);
        }
        let mut res = Self::new(self.nb_columns(), self.nb_rows());
        for (i, row) in self.rows_iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                res.set_elt(value, j, i);
            }
        }
        Ok(res)
    }

    /// Returns a reference to the entry at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if `i` or `j` is out of bounds.
    pub fn elt(&self, i: usize, j: usize) -> &T {
        &self.rows[i][j]
    }

    /// Returns a mutable reference to the entry at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if `i` or `j` is out of bounds.
    pub fn elt_mut(&mut self, i: usize, j: usize) -> &mut T {
        &mut self.rows[i][j]
    }

    /// Overwrites the entry at row `i`, column `j` with a clone of `value`.
    ///
    /// # Panics
    /// Panics if `i` or `j` is out of bounds.
    pub fn set_elt(&mut self, value: &T, i: usize, j: usize) {
        // `clone_from` allows `T` to reuse the existing entry's resources where it can.
        self.elt_mut(i, j).clone_from(value)
    }

    /// Number of rows.
    pub fn nb_rows(&self) -> usize {
        self.rows.len()
    }

    /// Number of columns, read from the first row; `0` for a matrix without rows.
    pub fn nb_columns(&self) -> usize {
        // Indexing `rows[0]` would panic on a matrix that has no rows.
        self.rows.first().map_or(0, Vec::len)
    }

    /// Iterates over the columns from left to right, each as a vector of references.
    pub fn columns_iter(&self) -> impl Iterator<Item = Vec<&T>> + '_ {
        ColIter {
            matrix: self,
            index: 0,
        }
    }

    /// Iterates over the columns from left to right, each as an owned vector.
    pub fn columns_cloned_iter(&self) -> impl Iterator<Item = Vec<T>> + '_ {
        self.columns_iter()
            .map(|e| e.into_iter().cloned().collect::<Vec<T>>())
    }

    /// Iterates over the rows from top to bottom.
    pub fn rows_iter(&self) -> impl Iterator<Item = &Vec<T>> + '_ {
        self.rows.iter()
    }

    /// Iterates over owned copies of the rows from top to bottom.
    pub fn rows_cloned_iter(&self) -> impl Iterator<Item = Vec<T>> + '_ {
        self.rows_iter().cloned()
    }

    /// Returns references to the entries of column `j`, top to bottom.
    ///
    /// # Panics
    /// Panics if some row has no index `j`.
    pub fn column(&self, j: usize) -> Vec<&T> {
        self.rows_iter().map(|r| &r[j]).collect()
    }

    /// Returns references to the entries of row `i`, left to right.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds.
    pub fn row(&self, i: usize) -> Vec<&T> {
        self.rows[i].iter().collect::<Vec<_>>()
    }

    /// Returns an owned copy of row `i`.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds.
    #[allow(dead_code)]
    pub fn row_cloned(&self, i: usize) -> Vec<T> {
        self.row(i).into_iter().cloned().collect()
    }

    /// Returns an owned copy of column `i`.
    ///
    /// # Panics
    /// Panics if some row has no index `i`.
    #[allow(dead_code)]
    pub fn column_cloned(&self, i: usize) -> Vec<T> {
        self.column(i).into_iter().cloned().collect()
    }

    /// Returns `true` when the rows do not all have the same length.
    ///
    /// A matrix without rows is considered well formed.
    pub fn is_malformed(&self) -> bool {
        let Some(first) = self.rows.first() else {
            return false;
        };
        let size = first.len();
        !self.rows_iter().all(|r| r.len() == size)
    }
}
impl Matrix<Integer> {
    /// Entrywise (Hadamard) product: entry `(i, j)` of the result is
    /// `self[i][j] * other[i][j]`.
    ///
    /// # Errors
    /// Returns [`MatrixError::NotSameSize`] if the matrices have a different number of rows
    /// or any pair of corresponding rows differs in length.
    #[allow(dead_code)]
    pub fn entrywise_product(&self, other: &Self) -> Result<Self, MatrixError> {
        // Compare shapes row by row instead of via `nb_columns`. This way a matrix without
        // rows needs no first row to inspect, and ragged input cannot index out of bounds.
        let same_shape = self.nb_rows() == other.nb_rows()
            && self
                .rows_iter()
                .zip(other.rows_iter())
                .all(|(a, b)| a.len() == b.len());
        if !same_shape {
            return Err(MatrixError::NotSameSize);
        }
        // Zipping rows and entries visits every (i, j) pair, including row 0 and
        // column 0, and stores each product at the same (i, j) position.
        let rows: Vec<Vec<Integer>> = self
            .rows_iter()
            .zip(other.rows_iter())
            .map(|(a, b)| {
                a.iter()
                    .zip(b)
                    .map(|(x, y)| Integer::from(x * y))
                    .collect::<Vec<_>>()
            })
            .collect();
        Ok(Self { rows })
    }
}
/// Converts a ciphertext matrix into a `HashableMessage` built from the list of its rows,
/// in row order, each row converted with `HashableMessage::from`.
impl<'a> From<&'a Matrix<Ciphertext>> for HashableMessage<'a> {
    fn from(value: &'a Matrix<Ciphertext>) -> Self {
        let row_messages: Vec<_> = value.rows_iter().map(HashableMessage::from).collect();
        HashableMessage::from(row_messages)
    }
}
/// Cursor over the columns of a borrowed `Matrix`.
struct ColIter<'a, T: Clone + Default + std::fmt::Debug> {
    /// Matrix whose columns are walked.
    matrix: &'a Matrix<T>,
    /// Index of the next column to produce.
    index: usize,
}
/// Yields the matrix columns from left to right, each as a `Vec` of entry references.
impl<'a, T: Clone + Default + std::fmt::Debug> Iterator for ColIter<'a, T> {
    type Item = Vec<&'a T>;

    fn next(&mut self) -> Option<Self::Item> {
        let current = self.index;
        // Once every column has been produced, stay put and keep returning `None`.
        if current >= self.matrix.nb_columns() {
            return None;
        }
        self.index = current + 1;
        Some(self.matrix.column(current))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// The chosen layout is the most balanced factorisation `(m, n)` with `m <= n`.
    #[test]
    fn test_get_matrix_dimensions() {
        let cases = [(12, (3, 4)), (18, (3, 6)), (23, (1, 23))];
        for (size, expected) in cases {
            assert_eq!(Matrix::<Integer>::get_matrix_dimensions(size), expected);
        }
    }

    /// Rectangular rows are accepted and kept in order.
    #[test]
    fn test_from_rows() {
        let rows = vec![vec![1, 2, 3], vec![4, 5, 6]];
        let matrix = Matrix::from_rows(&rows).expect("rectangular rows are well formed");
        assert_eq!((matrix.nb_rows(), matrix.nb_columns()), (2, 3));
        let collected: Vec<&Vec<i32>> = matrix.rows_iter().collect();
        assert_eq!(collected, vec![&rows[0], &rows[1]]);
    }

    /// Row/column accessors agree with the row-major input, and transposition swaps them.
    #[test]
    fn test_matrix() {
        let matrix = Matrix::to_matrix(&[1, 2, 3, 4, 5, 6], (2, 3)).unwrap();
        assert!(!matrix.is_malformed());
        assert_eq!((matrix.nb_rows(), matrix.nb_columns()), (2, 3));
        let rows = [[1, 2, 3], [4, 5, 6]];
        let columns = [[1, 4], [2, 5], [3, 6]];
        for (j, column) in columns.iter().enumerate() {
            assert_eq!(matrix.column(j), column.iter().collect::<Vec<_>>());
        }
        for (i, row) in rows.iter().enumerate() {
            assert_eq!(matrix.row(i), row.iter().collect::<Vec<_>>());
        }
        let transposed = matrix.transpose().unwrap();
        assert_eq!((transposed.nb_rows(), transposed.nb_columns()), (3, 2));
        for (i, row) in columns.iter().enumerate() {
            assert_eq!(transposed.row(i), row.iter().collect::<Vec<_>>());
        }
        for (j, column) in rows.iter().enumerate() {
            assert_eq!(transposed.column(j), column.iter().collect::<Vec<_>>());
        }
    }

    /// Column and row iterators yield every line exactly once, in order.
    #[test]
    fn test_matrix_iter() {
        let matrix = Matrix::to_matrix(&[1, 2, 3, 4, 5, 6], (2, 3)).unwrap();
        let columns: Vec<Vec<&i32>> = matrix.columns_iter().collect();
        assert_eq!(columns, vec![vec![&1, &4], vec![&2, &5], vec![&3, &6]]);
        let rows: Vec<&Vec<i32>> = matrix.rows_iter().collect();
        assert_eq!(rows, vec![&vec![1, 2, 3], &vec![4, 5, 6]]);
    }
}