use std::{
marker::PhantomData,
ops::{Index, IndexMut},
};
use crate::traits::Simd;
use super::{
conversion::{simd_container_flat_slice, simd_container_flat_slice_mut},
packed::PackedMxN,
};
#[doc(hidden)]
pub trait AccessStrategy {
fn flat_to_packed(x: usize, y: usize) -> (usize, usize);
}
#[derive(Clone, Debug)]
#[doc(hidden)]
pub struct Rows;
#[derive(Clone, Debug)]
#[doc(hidden)]
pub struct Columns;
impl AccessStrategy for Rows {
#[inline]
fn flat_to_packed(x: usize, y: usize) -> (usize, usize) { (x, y) }
}
impl AccessStrategy for Columns {
#[inline]
fn flat_to_packed(x: usize, y: usize) -> (usize, usize) { (y, x) }
}
#[derive(Clone, Debug)]
pub struct MatrixD<T, A>
where
T: Simd + Default + Clone,
A: AccessStrategy,
{
pub(crate) simd_rows: PackedMxN<T>,
phantom: PhantomData<A>,
}
impl<T, O> MatrixD<T, O>
where
T: Simd + Default + Clone,
O: AccessStrategy,
{
#[inline]
pub fn with_dimension(width: usize, height: usize) -> Self {
let (x, y) = O::flat_to_packed(width, height);
Self {
simd_rows: PackedMxN::with(T::default(), x, y),
phantom: PhantomData,
}
}
pub fn dimension(&self) -> (usize, usize) { O::flat_to_packed(self.simd_rows.rows, self.simd_rows.row_length) }
#[inline]
pub fn flat(&self) -> MatrixFlat<'_, T, O> {
MatrixFlat {
matrix: self,
phantom: PhantomData,
}
}
#[inline]
pub fn flat_mut(&mut self) -> MatrixFlatMut<'_, T, O> {
MatrixFlatMut {
matrix: self,
phantom: PhantomData,
}
}
}
impl<T> MatrixD<T, Rows>
where
T: Simd + Default + Clone,
{
#[inline]
pub fn row(&self, i: usize) -> &[T] {
let range = self.simd_rows.range_for_row(i);
&self.simd_rows.data[range]
}
#[inline]
pub fn row_iter(&self) -> Matrix2DIter<'_, T, Rows> { Matrix2DIter { matrix: self, index: 0 } }
#[inline]
pub fn row_mut(&mut self, i: usize) -> &mut [T] {
let range = self.simd_rows.range_for_row(i);
&mut self.simd_rows.data[range]
}
#[inline]
pub fn row_as_flat(&self, i: usize) -> &[T::Element] {
let row = self.row(i);
simd_container_flat_slice(row, self.simd_rows.row_length)
}
#[inline]
pub fn row_as_flat_mut(&mut self, i: usize) -> &mut [T::Element] {
let length = self.simd_rows.row_length;
let row = self.row_mut(i);
simd_container_flat_slice_mut(row, length)
}
}
impl<T> MatrixD<T, Columns>
where
T: Simd + Default + Clone,
{
#[inline]
pub fn column(&self, i: usize) -> &[T] {
let range = self.simd_rows.range_for_row(i);
&self.simd_rows.data[range]
}
#[inline]
pub fn column_iter(&self) -> Matrix2DIter<'_, T, Columns> { Matrix2DIter { matrix: self, index: 0 } }
#[inline]
pub fn column_mut(&mut self, i: usize) -> &mut [T] {
let range = self.simd_rows.range_for_row(i);
&mut self.simd_rows.data[range]
}
#[inline]
pub fn column_as_flat(&self, i: usize) -> &[T::Element] {
let column = self.column(i);
simd_container_flat_slice(column, self.simd_rows.row_length)
}
#[inline]
pub fn column_as_flat_mut(&mut self, i: usize) -> &mut [T::Element] {
let length = self.simd_rows.row_length;
let column = self.column_mut(i);
simd_container_flat_slice_mut(column, length)
}
}
pub struct MatrixFlat<'a, T: 'a, A: 'a>
where
T: Simd + Default + Clone,
A: AccessStrategy,
{
matrix: &'a MatrixD<T, A>,
phantom: PhantomData<A>,
}
pub struct MatrixFlatMut<'a, T: 'a, A: 'a>
where
T: Simd + Default + Clone,
A: AccessStrategy,
{
matrix: &'a mut MatrixD<T, A>,
phantom: PhantomData<A>,
}
impl<'a, T, A> Index<(usize, usize)> for MatrixFlat<'a, T, A>
where
T: Simd + Default + Clone,
A: AccessStrategy,
{
type Output = T::Element;
#[inline]
fn index(&self, index: (usize, usize)) -> &Self::Output {
let (row, x) = A::flat_to_packed(index.0, index.1);
let row_slice = self.matrix.simd_rows.row_as_flat(row);
&row_slice[x]
}
}
impl<'a, T, A> Index<(usize, usize)> for MatrixFlatMut<'a, T, A>
where
T: Simd + Default + Clone,
A: AccessStrategy,
{
type Output = T::Element;
#[inline]
fn index(&self, index: (usize, usize)) -> &Self::Output {
let (row, x) = A::flat_to_packed(index.0, index.1);
let row_slice = self.matrix.simd_rows.row_as_flat(row);
&row_slice[x]
}
}
impl<'a, T, A> IndexMut<(usize, usize)> for MatrixFlatMut<'a, T, A>
where
T: Simd + Default + Clone,
A: AccessStrategy,
{
#[inline]
fn index_mut(&mut self, index: (usize, usize)) -> &mut Self::Output {
let (row, x) = A::flat_to_packed(index.0, index.1);
let row_slice = self.matrix.simd_rows.row_as_flat_mut(row);
&mut row_slice[x]
}
}
#[derive(Clone, Debug)]
pub struct Matrix2DIter<'a, T: 'a, O: 'a>
where
T: Simd + Default + Clone,
O: AccessStrategy,
{
pub(crate) matrix: &'a MatrixD<T, O>,
pub(crate) index: usize,
}
impl<'a, T, O> Iterator for Matrix2DIter<'a, T, O>
where
T: Simd + Default + Clone,
O: AccessStrategy,
{
type Item = &'a [T];
#[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.matrix.simd_rows.rows {
None
} else {
let range = self.matrix.simd_rows.range_for_row(self.index);
self.index += 1;
Some(&self.matrix.simd_rows.data[range])
}
}
}
#[cfg(test)]
mod test {
use super::{Columns, MatrixD, Rows};
use crate::*;
#[test]
fn allocation_size() {
let m_1_1_r = MatrixD::<f32x4, Rows>::with_dimension(1, 1);
let m_1_1_c = MatrixD::<f32x4, Columns>::with_dimension(1, 1);
let m_5_5_r = MatrixD::<f32x4, Rows>::with_dimension(5, 5);
let m_5_5_c = MatrixD::<f32x4, Columns>::with_dimension(5, 5);
let m_1_4_r = MatrixD::<f32x4, Rows>::with_dimension(1, 4);
let m_4_1_c = MatrixD::<f32x4, Columns>::with_dimension(4, 1);
let m_4_1_r = MatrixD::<f32x4, Rows>::with_dimension(4, 1);
let m_1_4_c = MatrixD::<f32x4, Columns>::with_dimension(1, 4);
assert_eq!(m_1_1_r.simd_rows.data.len(), 1);
assert_eq!(m_1_1_c.simd_rows.data.len(), 1);
assert_eq!(m_5_5_r.simd_rows.data.len(), 10);
assert_eq!(m_5_5_c.simd_rows.data.len(), 10);
assert_eq!(m_1_4_r.simd_rows.data.len(), 1);
assert_eq!(m_4_1_c.simd_rows.data.len(), 1);
assert_eq!(m_4_1_r.simd_rows.data.len(), 4);
assert_eq!(m_1_4_c.simd_rows.data.len(), 4);
}
#[test]
fn access() {
let mut m_5_5_r = MatrixD::<f32x4, Rows>::with_dimension(5, 5);
let mut m_5_5_c = MatrixD::<f32x4, Columns>::with_dimension(5, 5);
assert_eq!(m_5_5_c.column(0).len(), 2);
assert_eq!(m_5_5_c.column_mut(0).len(), 2);
assert_eq!(m_5_5_c.column_as_flat(0).len(), 5);
assert_eq!(m_5_5_c.column_as_flat_mut(0).len(), 5);
assert_eq!(m_5_5_r.row(0).len(), 2);
assert_eq!(m_5_5_r.row_mut(0).len(), 2);
assert_eq!(m_5_5_r.row_as_flat(0).len(), 5);
assert_eq!(m_5_5_r.row_as_flat_mut(0).len(), 5);
let mut count = 0;
for _ in m_5_5_c.column_iter() {
count += 1
}
for _ in m_5_5_r.row_iter() {
count += 1
}
let r1 = m_5_5_r.row(3);
let r2 = m_5_5_r.row(4);
let mut sum = f32x4::splat(0_f32);
for (x, y) in r1.iter().zip(r2) {
sum += *x + *y;
}
assert_eq!(count, 10);
}
#[test]
fn flattened() {
let mut m_1_5_r = MatrixD::<f32x4, Rows>::with_dimension(1, 5);
let mut m_1_5_c = MatrixD::<f32x4, Columns>::with_dimension(1, 5);
let mut m_5_1_c = MatrixD::<f32x4, Columns>::with_dimension(5, 1);
let mut m_5_1_r = MatrixD::<f32x4, Rows>::with_dimension(5, 1);
let mut m_1_5_r_flat = m_1_5_r.flat_mut();
let mut m_1_5_c_flat = m_1_5_c.flat_mut();
let mut m_5_1_c_flat = m_5_1_c.flat_mut();
let mut m_5_1_r_flat = m_5_1_r.flat_mut();
m_1_5_r_flat[(0, 4)] = 1.0;
m_1_5_c_flat[(0, 4)] = 1.0;
m_5_1_r_flat[(4, 0)] = 1.0;
m_5_1_c_flat[(4, 0)] = 1.0;
}
}