use std::fmt::Debug;
use crate::error::{UtilsError, UtilsResult};
/// A dense 2-D matrix stored in a single flat buffer.
///
/// Cell (`column`, `row`) lives at index `row * columns + column` of `data`,
/// i.e. rows are laid out one after another. The constructor and `set_data`
/// keep `data.len()` equal to `columns * rows`.
pub struct Matrix<T> {
    /// Number of columns (the length of each stored row).
    columns: usize,
    /// Number of rows.
    rows: usize,
    /// Cell values, row after row.
    data: Vec<T>,
}
impl<T: Default + Clone + Debug> Matrix<T> {
    /// Creates a matrix with `columns` columns and `rows` rows, every cell set to `T::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `columns * rows` overflows `usize`.
    pub fn new(columns: usize, rows: usize) -> Self {
        // A wrapped product (release builds) would leave `data` shorter than the
        // shape implies and break every later index computation, so fail loudly.
        let len = columns
            .checked_mul(rows)
            .expect("matrix dimensions overflow usize");
        Matrix {
            columns,
            rows,
            data: vec![T::default(); len],
        }
    }

    /// Returns a clone of the value stored at (`column`, `row`).
    ///
    /// # Errors
    ///
    /// * `UtilsError::WrongPosition` if `column >= columns` or `row >= rows`.
    /// * `UtilsError::NotFound` if the backing storage is shorter than the shape implies.
    pub fn get(&self, column: usize, row: usize) -> UtilsResult<T> {
        let pos = self.get_pos(column, row)?;
        self.data.get(pos).cloned().ok_or(UtilsError::NotFound)
    }

    /// Stores `val` at (`column`, `row`), replacing the previous value.
    ///
    /// # Errors
    ///
    /// * `UtilsError::WrongPosition` if `column >= columns` or `row >= rows`.
    /// * `UtilsError::NotFound` if the backing storage is shorter than the shape implies.
    pub fn set(&mut self, column: usize, row: usize, val: T) -> UtilsResult<()> {
        let pos = self.get_pos(column, row)?;
        let slot = self.data.get_mut(pos).ok_or(UtilsError::NotFound)?;
        *slot = val;
        Ok(())
    }

    /// Returns the shape as `[columns, rows]`.
    pub fn get_shape(&self) -> Vec<usize> {
        vec![self.columns, self.rows]
    }

    /// Returns a copy of the flat cell buffer, laid out row after row.
    pub fn get_data(&self) -> UtilsResult<Vec<T>> {
        Ok(self.data.clone())
    }

    /// Replaces the whole cell buffer (row after row) without changing the shape.
    ///
    /// # Errors
    ///
    /// `UtilsError::WrongSize` if `data.len() != columns * rows`.
    pub fn set_data(&mut self, data: Vec<T>) -> UtilsResult<()> {
        if data.len() != self.columns * self.rows {
            return Err(UtilsError::WrongSize);
        }
        self.data = data;
        Ok(())
    }

    /// Returns a new matrix whose (`row`, `column`) cell holds this matrix's (`column`, `row`) cell.
    ///
    /// The result has `self.rows` columns and `self.columns` rows.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get`/`set`; none is expected for a well-formed matrix.
    pub fn transpose(&self) -> UtilsResult<Matrix<T>> {
        let mut matrix = Matrix::new(self.rows, self.columns);
        for column in 0..self.columns {
            for row in 0..self.rows {
                // Propagate rather than silently dropping the `Result` of `set`.
                matrix.set(row, column, self.get(column, row)?)?;
            }
        }
        Ok(matrix)
    }

    /// Returns a copy of rows `from..to` (half-open), keeping every column.
    ///
    /// # Errors
    ///
    /// * `UtilsError::WrongPosition` if `from > to` or `to > rows`.
    /// * `UtilsError::NotFound` if the backing storage is shorter than the shape implies.
    pub fn slice(&self, from: usize, to: usize) -> UtilsResult<Matrix<T>> {
        // Checked up front: `to - from` would underflow when `from > to`.
        if from > to || to > self.rows {
            return Err(UtilsError::WrongPosition);
        }
        // Rows are stored one after another (see `get_pos`), so rows `from..to`
        // form a single contiguous run of `data`.
        let start = from * self.columns;
        let end = to * self.columns;
        let data = self
            .data
            .get(start..end)
            .ok_or(UtilsError::NotFound)?
            .to_vec();
        Ok(Matrix {
            columns: self.columns,
            rows: to - from,
            data,
        })
    }

    /// Maps (`column`, `row`) to its index in `data`.
    ///
    /// # Errors
    ///
    /// `UtilsError::WrongPosition` if `column >= columns` or `row >= rows`.
    fn get_pos(&self, column: usize, row: usize) -> UtilsResult<usize> {
        if column >= self.columns || row >= self.rows {
            return Err(UtilsError::WrongPosition);
        }
        Ok(row * self.columns + column)
    }
}
#[cfg(test)]
mod tests {
    use crate::matrix::Matrix;

    /// Round-trips a single cell and checks out-of-range lookups are rejected.
    #[test]
    fn it_works() {
        let mut matrix = Matrix::new(4, 3);
        assert!(matrix.set(1, 2, 1.0).is_ok());
        assert_eq!(matrix.get(1, 2).unwrap(), 1.0);
        // Column 3 is valid (4 columns) but row 3 is one past the last row.
        assert!(matrix.get(3, 3).is_err());
        // Column 2 is valid but row 4 is out of range.
        assert!(matrix.get(2, 4).is_err());
    }

    /// Transposing a 3x2 matrix yields a 2x3 matrix with cells swapped across the diagonal.
    #[test]
    fn transpose_swaps_axes() {
        let mut matrix: Matrix<i32> = Matrix::new(3, 2);
        assert!(matrix.set_data(vec![1, 2, 3, 4, 5, 6]).is_ok());
        let transposed = matrix.transpose().unwrap();
        assert_eq!(transposed.get_shape(), vec![2, 3]);
        assert_eq!(transposed.get_data().unwrap(), vec![1, 4, 2, 5, 3, 6]);
    }

    /// `set_data` only accepts buffers whose length matches `columns * rows`.
    #[test]
    fn set_data_rejects_wrong_length() {
        let mut matrix: Matrix<i32> = Matrix::new(2, 2);
        assert!(matrix.set_data(vec![1, 2, 3]).is_err());
        assert!(matrix.set_data(vec![1, 2, 3, 4]).is_ok());
    }
}