use crate::{Backend, TruenoError};
use super::super::Matrix;
impl std::ops::Index<(usize, usize)> for Matrix<f32> {
    type Output = f32;

    /// Returns a reference to the element at `(row, col)`.
    ///
    /// Elements are stored row-major, so `(row, col)` maps to the flat offset
    /// `row * cols + col`.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`. Checking both coordinates
    /// explicitly matters: checking only the flat offset would let an
    /// out-of-range column silently wrap into the next row.
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        assert!(
            row < self.rows && col < self.cols,
            "Matrix index ({}, {}) out of bounds for {}x{} matrix",
            row,
            col,
            self.rows,
            self.cols
        );
        &self.data[row * self.cols + col]
    }
}
/// Returns `rows * cols`, the number of elements a `rows x cols` matrix stores.
///
/// # Panics
///
/// Panics if the product overflows `usize`. Without this check a release build
/// would wrap the product and build a buffer whose length disagrees with the
/// matrix dimensions.
fn element_count(rows: usize, cols: usize) -> usize {
    rows.checked_mul(cols)
        .unwrap_or_else(|| panic!("Matrix dimensions {}x{} overflow usize", rows, cols))
}

/// Checks that a buffer of `len` elements exactly fills a `rows x cols` matrix.
///
/// # Errors
///
/// Returns [`TruenoError::InvalidInput`] if `rows * cols` overflows `usize` or
/// if `len` differs from `rows * cols`.
fn check_data_len(rows: usize, cols: usize, len: usize) -> Result<(), TruenoError> {
    let expected = rows.checked_mul(cols).ok_or_else(|| {
        TruenoError::InvalidInput(format!("Matrix dimensions {}x{} overflow usize", rows, cols))
    })?;
    if len != expected {
        return Err(TruenoError::InvalidInput(format!(
            "Data length {} does not match matrix dimensions {}x{} (expected {})",
            len, rows, cols, expected
        )));
    }
    Ok(())
}

impl Matrix<f32> {
    /// Creates a `rows x cols` matrix filled with zeros, using the backend
    /// returned by [`Backend::select_best`].
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub fn new(rows: usize, cols: usize) -> Self {
        let len = element_count(rows, cols);
        let backend = Backend::select_best();
        Matrix { rows, cols, data: vec![0.0; len], backend }
    }

    /// Creates a matrix from row-major `data`, using the backend returned by
    /// [`Backend::select_best`].
    ///
    /// # Errors
    ///
    /// Returns [`TruenoError::InvalidInput`] if `data.len() != rows * cols`, or
    /// if `rows * cols` overflows `usize`.
    pub fn from_vec(rows: usize, cols: usize, data: Vec<f32>) -> Result<Self, TruenoError> {
        check_data_len(rows, cols, data.len())?;
        let backend = Backend::select_best();
        Ok(Matrix { rows, cols, data, backend })
    }

    /// Creates a matrix from row-major `data` using an explicitly chosen `backend`.
    ///
    /// # Panics
    ///
    /// Panics if `data.len() != rows * cols`, or if `rows * cols` overflows `usize`.
    pub fn from_vec_with_backend(
        rows: usize,
        cols: usize,
        data: Vec<f32>,
        backend: Backend,
    ) -> Self {
        assert_eq!(
            data.len(),
            element_count(rows, cols),
            "Data length {} does not match matrix dimensions {}x{}",
            data.len(),
            rows,
            cols
        );
        Matrix { rows, cols, data, backend }
    }

    /// Creates a matrix by copying row-major `data` from a slice.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Matrix::from_vec`].
    pub fn from_slice(rows: usize, cols: usize, data: &[f32]) -> Result<Self, TruenoError> {
        // Validate before copying so a mismatched slice is rejected without allocating.
        check_data_len(rows, cols, data.len())?;
        Self::from_vec(rows, cols, data.to_vec())
    }

    /// Creates a `rows x cols` matrix filled with zeros. Equivalent to [`Matrix::new`].
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub fn zeros(rows: usize, cols: usize) -> Self {
        Matrix::new(rows, cols)
    }

    /// Creates a zero-filled `rows x cols` matrix bound to the given `backend`.
    ///
    /// # Panics
    ///
    /// Panics if `rows * cols` overflows `usize`.
    pub(crate) fn zeros_with_backend(rows: usize, cols: usize, backend: Backend) -> Self {
        Matrix { rows, cols, data: vec![0.0; element_count(rows, cols)], backend }
    }

    /// Creates the `n x n` identity matrix, using the backend returned by
    /// [`Backend::select_best`].
    ///
    /// # Panics
    ///
    /// Panics if `n * n` overflows `usize`.
    pub fn identity(n: usize) -> Self {
        let mut data = vec![0.0; element_count(n, n)];
        // In row-major storage consecutive diagonal entries are `n + 1` apart.
        // `n + 1` cannot overflow here: `element_count` already panicked for any
        // `n` large enough for `n * n` to overflow.
        data.iter_mut().step_by(n + 1).for_each(|x| *x = 1.0);
        let backend = Backend::select_best();
        Matrix { rows: n, cols: n, data, backend }
    }

    /// Number of rows.
    pub fn rows(&self) -> usize {
        self.rows
    }

    /// Number of columns.
    pub fn cols(&self) -> usize {
        self.cols
    }

    /// Dimensions as `(rows, cols)`.
    pub fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }

    /// Returns a reference to the element at `(row, col)`, or `None` if either
    /// coordinate is out of range.
    pub fn get(&self, row: usize, col: usize) -> Option<&f32> {
        if row >= self.rows || col >= self.cols {
            return None;
        }
        self.data.get(row * self.cols + col)
    }

    /// Returns a mutable reference to the element at `(row, col)`, or `None` if
    /// either coordinate is out of range.
    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut f32> {
        if row >= self.rows || col >= self.cols {
            return None;
        }
        let idx = row * self.cols + col;
        self.data.get_mut(idx)
    }

    /// The underlying elements in row-major order.
    pub fn as_slice(&self) -> &[f32] {
        &self.data
    }

    /// The compute backend associated with this matrix.
    pub fn backend(&self) -> Backend {
        self.backend
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_creates_zero_matrix() {
        let m = Matrix::new(3, 4);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 4);
        assert_eq!(m.shape(), (3, 4));
        assert_eq!(m.as_slice().len(), 12);
        assert!(m.as_slice().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_zeros_matches_new() {
        let m = Matrix::zeros(2, 5);
        assert_eq!(m.shape(), (2, 5));
        assert!(m.as_slice().iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_from_vec_success() {
        let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(m.get(0, 0), Some(&1.0));
        assert_eq!(m.get(1, 1), Some(&4.0));
    }

    #[test]
    fn test_from_vec_is_row_major() {
        // Non-square so that a row/column mix-up cannot go unnoticed.
        let m = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(m.get(0, 2), Some(&3.0));
        assert_eq!(m.get(1, 0), Some(&4.0));
        assert_eq!(m[(1, 2)], 6.0);
    }

    #[test]
    fn test_from_vec_dimension_mismatch() {
        let result = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0]);
        assert!(matches!(result, Err(TruenoError::InvalidInput(_))));
    }

    #[test]
    fn test_from_slice() {
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let m = Matrix::from_slice(3, 2, &data).unwrap();
        assert_eq!(m.shape(), (3, 2));
        assert_eq!(m.as_slice(), &data[..]);
        assert!(Matrix::from_slice(2, 2, &data).is_err());
    }

    #[test]
    #[should_panic]
    fn test_from_vec_with_backend_mismatch_panics() {
        let _ = Matrix::from_vec_with_backend(2, 2, vec![1.0], Backend::select_best());
    }

    #[test]
    fn test_identity() {
        let m = Matrix::identity(3);
        assert_eq!(m.shape(), (3, 3));
        for r in 0..3 {
            for c in 0..3 {
                let expected = if r == c { 1.0 } else { 0.0 };
                assert_eq!(m.get(r, c), Some(&expected), "entry ({r}, {c})");
            }
        }
    }

    #[test]
    fn test_identity_empty() {
        let m = Matrix::identity(0);
        assert_eq!(m.shape(), (0, 0));
        assert!(m.as_slice().is_empty());
    }

    #[test]
    fn test_index_operator() {
        let m = Matrix::from_vec(2, 2, vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert_eq!(m[(0, 0)], 1.0);
        assert_eq!(m[(1, 1)], 4.0);
    }

    #[test]
    #[should_panic]
    fn test_index_row_out_of_bounds_panics() {
        let m = Matrix::new(2, 2);
        let _ = m[(2, 0)];
    }

    #[test]
    fn test_get_out_of_bounds() {
        let m = Matrix::new(2, 2);
        assert_eq!(m.get(2, 0), None);
        assert_eq!(m.get(0, 2), None);
    }

    #[test]
    fn test_get_mut() {
        let mut m = Matrix::new(2, 2);
        if let Some(val) = m.get_mut(1, 1) {
            *val = 42.0;
        }
        assert_eq!(m.get(1, 1), Some(&42.0));
    }

    #[test]
    fn test_get_mut_out_of_bounds() {
        let mut m = Matrix::new(2, 2);
        assert!(m.get_mut(2, 0).is_none());
        assert!(m.get_mut(0, 2).is_none());
    }
}