use nalgebra::DMatrix;
use serde::{Deserialize, Serialize};
use super::error::NetworkError;
/// Dynamically sized `f32` matrix used by the network code.
///
/// A thin wrapper around [`nalgebra::DMatrix`]: the inner matrix stays private so the rest of
/// the crate goes through the small API below. Input data is supplied in row-major order
/// (see [`DMat::new`]).
#[derive(Clone, Serialize, Deserialize)]
pub struct DMat {
    data: DMatrix<f32>,
}

impl DMat {
    /// Creates a `rows x cols` matrix from `data` given in row-major order.
    ///
    /// # Panics
    /// Panics (inside nalgebra) if `data.len() != rows * cols`.
    pub(crate) fn new(rows: usize, cols: usize, data: &[f32]) -> Self {
        Self {
            data: DMatrix::from_row_slice(rows, cols, data),
        }
    }

    /// Returns the matrix product `other * another` as a new matrix.
    ///
    /// # Panics
    /// Panics (inside nalgebra) if `other.cols() != another.rows()`.
    pub(crate) fn mul_new(other: &DMat, another: &DMat) -> DMat {
        DMat {
            data: &other.data * &another.data,
        }
    }

    /// Creates a `rows x cols` matrix filled with `0.0`.
    pub(crate) fn zeros(rows: usize, cols: usize) -> Self {
        Self {
            data: DMatrix::zeros(rows, cols),
        }
    }

    /// Returns a new matrix holding the transpose of `self`; `self` is left unchanged.
    pub(crate) fn transpose(&self) -> DMat {
        Self {
            data: self.data.transpose(),
        }
    }

    /// Number of rows.
    #[inline]
    pub fn rows(&self) -> usize {
        self.data.nrows()
    }

    /// Number of columns.
    #[inline]
    pub fn cols(&self) -> usize {
        self.data.ncols()
    }

    /// Returns the element at row `i`, column `j`.
    ///
    /// # Panics
    /// Panics if `(i, j)` is out of bounds.
    #[inline]
    pub(crate) fn at(&self, i: usize, j: usize) -> f32 {
        self.data[(i, j)]
    }

    /// Overwrites the element at row `i`, column `j` with `value`.
    ///
    /// # Panics
    /// Panics if `(i, j)` is out of bounds.
    #[inline]
    pub(crate) fn set(&mut self, i: usize, j: usize, value: f32) {
        self.data[(i, j)] = value;
    }

    /// Adds `other` element-wise into `self`.
    ///
    /// # Panics
    /// Panics (inside nalgebra) if the two shapes differ.
    #[inline]
    pub(crate) fn add(&mut self, other: &DMat) {
        self.data += &other.data;
    }

    /// Multiplies every element by `factor`, in place.
    #[inline]
    pub(crate) fn scale(&mut self, factor: f32) {
        self.data *= factor;
    }

    /// Multiplies `self` element-wise by `other` (Hadamard product), in place.
    ///
    /// # Panics
    /// Panics (inside nalgebra) if the two shapes differ.
    #[inline]
    pub(crate) fn mul_elem(&mut self, other: &DMat) {
        self.data.component_mul_assign(&other.data);
    }

    /// Copies the sub-matrix covering rows `i..k` and columns `j..l` (half-open ranges) into a
    /// new `(k - i) x (l - j)` matrix.
    ///
    /// # Panics
    /// Panics if `i > k` or `j > l`, or if a non-empty range reaches past the matrix bounds.
    pub(crate) fn slice(&self, i: usize, k: usize, j: usize, l: usize) -> DMat {
        // Without this check `k - i` wraps in release builds and `from_fn` would be asked for
        // an enormous allocation instead of reporting the bad arguments.
        assert!(
            i <= k && j <= l,
            "slice bounds out of order: rows {}..{}, cols {}..{}",
            i,
            k,
            j,
            l
        );
        let rows = k - i;
        let cols = l - j;
        DMat {
            data: DMatrix::from_fn(rows, cols, |row, col| self.data[(i + row, j + col)]),
        }
    }

    /// Returns a copy of row `i`, left to right.
    ///
    /// # Panics
    /// Panics (inside nalgebra) if `i >= self.rows()`.
    pub(crate) fn get_row(&self, i: usize) -> Vec<f32> {
        self.data.row(i).iter().copied().collect()
    }

    /// Writes `src[j]` into element `(i, j)` for every index `j` of `src`.
    ///
    /// A `src` shorter than the row leaves the remaining columns untouched.
    ///
    /// # Panics
    /// Panics if `i` is out of bounds or `src.len() > self.cols()`.
    pub(crate) fn set_row(&mut self, i: usize, src: &[f32]) {
        for (j, &value) in src.iter().enumerate() {
            self.set(i, j, value);
        }
    }

    /// Rescales the whole matrix so that its L2 norm over all entries (see [`DMat::norm`]) does
    /// not exceed `threshold`. The direction is preserved; only the magnitude changes.
    ///
    /// A non-positive `threshold` disables clipping.
    pub(crate) fn clip(&mut self, threshold: f32) {
        if threshold > 0.0 {
            let norm = self.norm(2.0);
            if norm > threshold {
                let scale = threshold / norm;
                self.scale(scale);
            }
        }
    }

    /// Entry-wise norm over all elements of the matrix.
    ///
    /// * `1.0`: sum of absolute values.
    /// * `2.0`: square root of the sum of squares (Frobenius norm).
    ///
    /// # Panics
    /// Panics for any other `norm_type`.
    pub(crate) fn norm(&self, norm_type: f32) -> f32 {
        if norm_type == 1.0 {
            self.data.iter().map(|x| x.abs()).sum::<f32>()
        } else if norm_type == 2.0 {
            self.data.iter().map(|x| x * x).sum::<f32>().sqrt()
        } else {
            panic!("Unsupported norm type");
        }
    }

    /// For every row `i` of `self`, stores the sum of column `i` of `other` into `self[(i, 0)]`.
    ///
    /// Only column 0 of `self` is written. Columns of `other` beyond `self.rows()` are ignored.
    /// NOTE(review): presumably `self` is a column vector with one entry per column of `other`;
    /// confirm against callers.
    ///
    /// # Panics
    /// Panics if `self` has no columns or if `other.cols() < self.rows()`.
    pub(crate) fn set_column_sum(&mut self, other: &DMat) {
        for i in 0..self.rows() {
            self.data[(i, 0)] = other.data.column(i).iter().sum();
        }
    }

    /// Replaces every element `x` with `func(x)`, in place.
    pub(crate) fn apply<F>(&mut self, func: F)
    where
        F: Fn(f32) -> f32,
    {
        self.data.apply(|x| *x = func(*x));
    }

    /// Calls `f(row, col, &mut element)` for every element, in place.
    ///
    /// Elements are visited row by row, left to right within each row.
    pub(crate) fn apply_with_indices<F>(&mut self, mut f: F)
    where
        F: FnMut(usize, usize, &mut f32),
    {
        for i in 0..self.rows() {
            for j in 0..self.cols() {
                f(i, j, &mut self.data[(i, j)]);
            }
        }
    }
}
/// Builder that validates dimensions and contents before producing a [`DMat`].
///
/// Start with [`DenseMatrix::new`], supply values with [`DenseMatrix::data`] or
/// [`DenseMatrix::random`], then call [`DenseMatrix::build`].
pub struct DenseMatrix {
    rows: usize,
    cols: usize,
    // Row-major values; must contain exactly `rows * cols` entries for `build` to succeed.
    data: Vec<f32>,
}

impl DenseMatrix {
    /// Starts a builder for a `rows x cols` matrix that has no data yet.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            data: Vec::new(),
        }
    }

    /// Uses a copy of `data` (row-major) as the matrix contents, replacing any previous data.
    pub fn data(mut self, data: &[f32]) -> Self {
        self.data = data.to_vec();
        self
    }

    /// Fills the matrix with `rows * cols` values from `rand::random::<f32>()`, replacing any
    /// previous data.
    ///
    /// If `rows * cols` overflows `usize`, no values are generated; [`DenseMatrix::build`]
    /// then reports the overflow as an error.
    pub fn random(mut self) -> Self {
        let len = self.rows.checked_mul(self.cols).unwrap_or(0);
        self.data = (0..len).map(|_| rand::random::<f32>()).collect();
        self
    }

    /// Checks that both dimensions are non-zero and that the data length equals `rows * cols`.
    ///
    /// # Errors
    /// Returns [`NetworkError::ConfigError`] in any of these cases:
    /// * either dimension is zero;
    /// * `rows * cols` overflows `usize`;
    /// * the number of values differs from `rows * cols`.
    fn validate(&self) -> Result<(), NetworkError> {
        if self.rows == 0 || self.cols == 0 {
            return Err(NetworkError::ConfigError(format!(
                "Rows:{} and columns:{} must be greater than zero",
                self.rows, self.cols
            )));
        }
        // Checked so an overflowing product can neither panic (debug) nor wrap around and
        // accidentally "match" a short data buffer (release).
        let expected = self.rows.checked_mul(self.cols).ok_or_else(|| {
            NetworkError::ConfigError(format!(
                "Matrix dimensions {}x{} overflow usize",
                self.rows, self.cols
            ))
        })?;
        if self.data.len() != expected {
            return Err(NetworkError::ConfigError(format!(
                "Data length:{} does not match matrix dimensions:{}",
                self.data.len(),
                expected
            )));
        }
        Ok(())
    }

    /// Validates the builder and produces the [`DMat`].
    ///
    /// # Errors
    /// Returns [`NetworkError::ConfigError`] when validation fails (see the checks above).
    pub fn build(self) -> Result<DMat, NetworkError> {
        self.validate()?;
        Ok(DMat::new(self.rows, self.cols, &self.data))
    }
}
#[cfg(test)]
mod tests {
    use crate::util;

    use super::*;

    #[test]
    fn test_at() {
        let matrix = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(matrix.at(1, 1), 4.0);
    }

    #[test]
    fn test_add() {
        let mut matrix = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let other = DMat::new(2, 2, &[4.0, 3.0, 2.0, 1.0]);
        matrix.add(&other);
        assert_eq!(util::flatten(&matrix), &[5.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_scale() {
        let mut matrix = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        matrix.scale(2.0);
        assert_eq!(util::flatten(&matrix), &[2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_transpose() {
        let matrix = DMat::new(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let transposed = matrix.transpose();
        assert_eq!((transposed.rows(), transposed.cols()), (3, 2));
        assert_eq!(transposed.at(0, 1), 4.0);
    }

    #[test]
    fn test_norm() {
        let matrix = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(matrix.norm(1.0), 10.0);
        assert!((matrix.norm(2.0) - 5.477).abs() < 0.001);
    }

    #[test]
    #[should_panic(expected = "Unsupported norm type")]
    fn test_norm_unsupported_type_panics() {
        let matrix = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let _ = matrix.norm(3.0);
    }

    #[test]
    fn test_slice() {
        let matrix = DMat::new(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let submatrix = matrix.slice(1, 2, 1, 2);
        assert_eq!(util::flatten(&submatrix), &[5.0]);

        // Multi-element block: rows 0..2, columns 1..3.
        let block = matrix.slice(0, 2, 1, 3);
        assert_eq!((block.rows(), block.cols()), (2, 2));
        assert_eq!(util::flatten(&block), &[2.0, 3.0, 5.0, 6.0]);
    }

    #[test]
    fn test_mul_elem() {
        let mut matrix_a = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let matrix_b = DMat::new(2, 2, &[5.0, 6.0, 7.0, 8.0]);
        matrix_a.mul_elem(&matrix_b);
        assert_eq!(util::flatten(&matrix_a), &[5.0, 12.0, 21.0, 32.0]);

        let mut matrix_c = DMat::new(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let matrix_d = DMat::new(2, 3, &[6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);
        matrix_c.mul_elem(&matrix_d);
        assert_eq!(util::flatten(&matrix_c), &[6.0, 10.0, 12.0, 12.0, 10.0, 6.0]);

        let mut matrix_e = DMat::new(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        let matrix_ones = DMat::new(3, 3, &[1.0; 9]);
        matrix_e.mul_elem(&matrix_ones);
        assert_eq!(
            util::flatten(&matrix_e),
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
        );

        let mut matrix_f = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let zero_matrix = DMat::zeros(2, 2);
        matrix_f.mul_elem(&zero_matrix);
        assert_eq!(util::flatten(&matrix_f), &[0.0, 0.0, 0.0, 0.0]);

        let mut matrix_g = DMat::new(1, 4, &[1.0, 2.0, 3.0, 4.0]);
        let matrix_h = DMat::new(1, 4, &[4.0, 3.0, 2.0, 1.0]);
        matrix_g.mul_elem(&matrix_h);
        assert_eq!(util::flatten(&matrix_g), &[4.0, 6.0, 6.0, 4.0]);

        let mut matrix_i = DMat::new(4, 1, &[1.0, 2.0, 3.0, 4.0]);
        let matrix_j = DMat::new(4, 1, &[4.0, 3.0, 2.0, 1.0]);
        matrix_i.mul_elem(&matrix_j);
        assert_eq!(util::flatten(&matrix_i), &[4.0, 6.0, 6.0, 4.0]);
    }

    #[test]
    fn test_clone() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let mut matrix = DMat::new(2, 2, &data);
        let mut clone = matrix.clone();
        assert_eq!((matrix.rows(), matrix.cols()), (clone.rows(), clone.cols()));
        assert_eq!(matrix.at(0, 0), clone.at(0, 0));

        // Mutating the clone must not affect the original...
        clone.data[(0, 0)] = 5.0;
        assert_eq!(clone.at(0, 0), 5.0);
        assert_eq!(matrix.at(0, 0), 1.0);

        // ...and vice versa.
        matrix.data[(1, 1)] = 6.0;
        assert_eq!(matrix.at(1, 1), 6.0);
        assert_eq!(clone.at(1, 1), 4.0);
    }

    #[test]
    fn test_set_column_sum() {
        let mut result_matrix = DMat::new(3, 1, &[0.0, 0.0, 0.0]);
        let source_matrix = DMat::new(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        result_matrix.set_column_sum(&source_matrix);
        assert_eq!(util::flatten(&result_matrix), &[12.0, 15.0, 18.0]);
    }

    #[test]
    fn test_zeros() {
        let matrix = DMat::zeros(2, 2);
        assert_eq!((matrix.rows(), matrix.cols()), (2, 2));
        for i in 0..2 {
            for j in 0..2 {
                assert_eq!(matrix.at(i, j), 0.0);
            }
        }
    }

    #[test]
    fn test_clip() {
        // Norm exactly at the threshold: unchanged.
        let mut matrix = DMat::new(2, 2, &[3.0, 4.0, 0.0, 0.0]);
        matrix.clip(5.0);
        let clipped = DMat::new(2, 2, &[3.0, 4.0, 0.0, 0.0]);
        assert_eq!(util::flatten(&matrix), util::flatten(&clipped));

        // Norm 10 clipped to 5: every entry halves.
        let mut matrix = DMat::new(2, 2, &[6.0, 8.0, 0.0, 0.0]);
        matrix.clip(5.0);
        let clipped = DMat::new(2, 2, &[3.0, 4.0, 0.0, 0.0]);
        assert_eq!(util::flatten(&matrix), util::flatten(&clipped));
    }

    #[test]
    fn test_apply() {
        let mut matrix = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        matrix.apply(|x| x * 2.0);
        assert_eq!(util::flatten(&matrix), &[2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_apply_with_indices() {
        let mut matrix = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        matrix.apply_with_indices(|i, j, v| {
            *v = (i + j) as f32;
        });
        assert_eq!(util::flatten(&matrix), &[0.0, 1.0, 1.0, 2.0]);
    }

    #[test]
    fn test_set_row() {
        let mut matrix = DMat::new(3, 3, &[0.0; 9]);
        matrix.set_row(1, &[1.0, 2.0, 3.0]);
        assert_eq!(matrix.get_row(1), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_get_row() {
        let matrix = DMat::new(3, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);
        assert_eq!(matrix.get_row(1), vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_mul_new() {
        let matrix_a = DMat::new(2, 2, &[1.0, 2.0, 3.0, 4.0]);
        let matrix_b = DMat::new(2, 2, &[5.0, 6.0, 7.0, 8.0]);
        let result = DMat::mul_new(&matrix_a, &matrix_b);
        assert_eq!(util::flatten(&result), &[19.0, 22.0, 43.0, 50.0]);
    }

    #[test]
    fn test_dense_matrix() {
        let matrix = DenseMatrix::new(2, 2).data(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(matrix.rows, 2);
        assert_eq!(matrix.cols, 2);
        assert_eq!(matrix.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_dense_matrix_random() {
        let matrix = DenseMatrix::new(2, 2).random().build().unwrap();
        assert_eq!(matrix.rows(), 2);
        assert_eq!(matrix.cols(), 2);
        assert_eq!(matrix.data.len(), 4);
        // `rand::random::<f32>()` samples the half-open interval [0, 1).
        assert!(matrix.data.iter().all(|&v| (0.0..1.0).contains(&v)));
    }

    #[test]
    fn test_dense_matrix_validate() {
        let matrix = DenseMatrix::new(2, 2).data(&[1.0, 2.0, 3.0, 4.0]).build();
        assert!(matrix.is_ok());

        let invalid_matrix = DenseMatrix::new(2, 2).data(&[1.0, 2.0]).build();
        assert!(invalid_matrix.is_err());
        if let Err(err) = invalid_matrix {
            assert_eq!(
                err.to_string(),
                "Configuration error: Data length:2 does not match matrix dimensions:4"
            );
        }

        let zero_rows = DenseMatrix::new(0, 2).data(&[]).build();
        assert!(zero_rows.is_err());
        if let Err(err) = zero_rows {
            assert!(err
                .to_string()
                .contains("Rows:0 and columns:2 must be greater than zero"));
        }
    }
}