#![allow(unsafe_code)]
use std::{
marker::PhantomData,
ops::{Index, IndexMut},
ptr, slice,
};
pub struct MatrixMut<'a, T> {
data: *mut T,
rows: usize,
cols: usize,
row_stride: usize,
_lifetime: PhantomData<&'a mut T>,
}
unsafe impl<T: Send> Send for MatrixMut<'_, T> {}
unsafe impl<T: Sync> Sync for MatrixMut<'_, T> {}
impl<'a, T> MatrixMut<'a, T> {
pub fn from_mut_slice(slice: &'a mut [T], rows: usize, cols: usize) -> Self {
assert_eq!(slice.len(), rows * cols);
Self {
data: slice.as_mut_ptr(),
rows,
cols,
row_stride: cols,
_lifetime: PhantomData,
}
}
pub const fn rows(&self) -> usize {
self.rows
}
pub const fn cols(&self) -> usize {
self.cols
}
pub const fn is_square(&self) -> bool {
self.rows == self.cols
}
#[allow(dead_code)]
pub fn row(&mut self, row: usize) -> &mut [T] {
assert!(row < self.rows);
unsafe { slice::from_raw_parts_mut(self.data.add(row * self.row_stride), self.cols) }
}
pub fn split_vertical(self, row: usize) -> (Self, Self) {
assert!(row <= self.rows);
(
Self {
data: self.data,
rows: row,
cols: self.cols,
row_stride: self.row_stride,
_lifetime: PhantomData,
},
Self {
data: unsafe { self.data.add(row * self.row_stride) },
rows: self.rows - row,
cols: self.cols,
row_stride: self.row_stride,
_lifetime: PhantomData,
},
)
}
pub fn split_horizontal(self, col: usize) -> (Self, Self) {
assert!(col <= self.cols);
(
Self {
data: self.data,
rows: self.rows,
cols: col,
row_stride: self.row_stride,
_lifetime: PhantomData,
},
Self {
data: unsafe { self.data.add(col) },
rows: self.rows,
cols: self.cols - col,
row_stride: self.row_stride,
_lifetime: PhantomData,
},
)
}
pub fn split_quadrants(self, row: usize, col: usize) -> (Self, Self, Self, Self) {
let (u, l) = self.split_vertical(row); let (a, b) = u.split_horizontal(col);
let (c, d) = l.split_horizontal(col);
(a, b, c, d)
}
pub unsafe fn swap(&mut self, a: (usize, usize), b: (usize, usize)) {
if a != b {
unsafe {
let a = self.ptr_at_mut(a.0, a.1);
let b = self.ptr_at_mut(b.0, b.1);
ptr::swap_nonoverlapping(a, b, 1);
}
}
}
pub(crate) const unsafe fn ptr_at(&self, row: usize, col: usize) -> *const T {
self.data.add(row * self.row_stride + col)
}
pub(crate) const unsafe fn ptr_at_mut(&mut self, row: usize, col: usize) -> *mut T {
self.data.add(row * self.row_stride + col)
}
}
impl<T> Index<(usize, usize)> for MatrixMut<'_, T> {
type Output = T;
fn index(&self, (row, col): (usize, usize)) -> &T {
assert!(row < self.rows);
assert!(col < self.cols);
unsafe { &*self.ptr_at(row, col) }
}
}
impl<T> IndexMut<(usize, usize)> for MatrixMut<'_, T> {
fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut T {
assert!(row < self.rows);
assert!(col < self.cols);
unsafe { &mut *self.ptr_at_mut(row, col) }
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_matrix_creation() {
let mut data = vec![1, 2, 3, 4, 5, 6];
let matrix = MatrixMut::from_mut_slice(&mut data, 2, 3);
assert_eq!(matrix.rows(), 2);
assert_eq!(matrix.cols(), 3);
assert!(!matrix.is_square());
assert_eq!(matrix[(0, 0)], 1);
assert_eq!(matrix[(1, 0)], 4);
assert_eq!(matrix[(0, 1)], 2);
assert_eq!(matrix[(1, 1)], 5);
assert_eq!(matrix[(0, 2)], 3);
assert_eq!(matrix[(1, 2)], 6);
}
#[test]
fn test_row_access() {
let mut data = vec![1, 2, 3, 4, 5, 6];
let mut matrix = MatrixMut::from_mut_slice(&mut data, 2, 3);
assert_eq!(matrix.row(0), &[1, 2, 3]);
assert_eq!(matrix.row(1), &[4, 5, 6]);
}
#[test]
fn test_split_vertical() {
let mut data = vec![1, 2, 3, 4, 5, 6];
let matrix = MatrixMut::from_mut_slice(&mut data, 2, 3);
let (top, bottom) = matrix.split_vertical(1);
assert_eq!(top.rows(), 1);
assert_eq!(top[(0, 0)], 1);
assert_eq!(top[(0, 1)], 2);
assert_eq!(top[(0, 2)], 3);
assert_eq!(bottom.rows(), 1);
assert_eq!(bottom[(0, 0)], 4);
assert_eq!(bottom[(0, 1)], 5);
assert_eq!(bottom[(0, 2)], 6);
}
#[test]
fn test_split_horizontal() {
let mut data = vec![1, 2, 3, 4, 5, 6];
let matrix = MatrixMut::from_mut_slice(&mut data, 2, 3);
let (left, right) = matrix.split_horizontal(1);
assert_eq!(left.cols(), 1);
assert_eq!(left[(0, 0)], 1);
assert_eq!(left[(1, 0)], 4);
assert_eq!(right.cols(), 2);
assert_eq!(right[(0, 0)], 2);
assert_eq!(right[(0, 1)], 3);
assert_eq!(right[(1, 0)], 5);
assert_eq!(right[(1, 1)], 6);
}
#[test]
fn test_element_access() {
let mut data = vec![1, 2, 3, 4, 5, 6];
let matrix = MatrixMut::from_mut_slice(&mut data, 2, 3);
assert_eq!(matrix[(0, 1)], 2);
}
#[test]
fn test_swap() {
let mut data = vec![1, 2, 3, 4, 5, 6];
let mut matrix = MatrixMut::from_mut_slice(&mut data, 2, 3);
unsafe {
matrix.swap((0, 0), (1, 1));
}
assert_eq!(matrix[(0, 0)], 5);
assert_eq!(matrix[(1, 1)], 1);
}
#[test]
fn test_split_quadrants_even() {
let mut data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
let matrix = MatrixMut::from_mut_slice(&mut data, 4, 4);
let (a, b, c, d) = matrix.split_quadrants(2, 2);
assert_eq!(a.rows(), 2);
assert_eq!(a.cols(), 2);
assert_eq!(b.rows(), 2);
assert_eq!(b.cols(), 2);
assert_eq!(c.rows(), 2);
assert_eq!(c.cols(), 2);
assert_eq!(d.rows(), 2);
assert_eq!(d.cols(), 2);
assert_eq!(a[(0, 0)], 1);
assert_eq!(a[(0, 1)], 2);
assert_eq!(a[(1, 0)], 5);
assert_eq!(a[(1, 1)], 6);
assert_eq!(b[(0, 0)], 3);
assert_eq!(b[(0, 1)], 4);
assert_eq!(b[(1, 0)], 7);
assert_eq!(b[(1, 1)], 8);
assert_eq!(c[(0, 0)], 9);
assert_eq!(c[(0, 1)], 10);
assert_eq!(c[(1, 0)], 13);
assert_eq!(c[(1, 1)], 14);
assert_eq!(d[(0, 0)], 11);
assert_eq!(d[(0, 1)], 12);
assert_eq!(d[(1, 0)], 15);
assert_eq!(d[(1, 1)], 16);
}
#[test]
fn test_split_quadrants_odd_rows() {
let mut data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
let matrix = MatrixMut::from_mut_slice(&mut data, 3, 3);
let (a, b, c, d) = matrix.split_quadrants(1, 1);
assert_eq!(a.rows(), 1);
assert_eq!(a.cols(), 1);
assert_eq!(b.rows(), 1);
assert_eq!(b.cols(), 2);
assert_eq!(c.rows(), 2);
assert_eq!(c.cols(), 1);
assert_eq!(d.rows(), 2);
assert_eq!(d.cols(), 2);
assert_eq!(a[(0, 0)], 1);
assert_eq!(b[(0, 0)], 2);
assert_eq!(b[(0, 1)], 3);
assert_eq!(c[(0, 0)], 4);
assert_eq!(c[(1, 0)], 7);
assert_eq!(d[(0, 0)], 5);
assert_eq!(d[(0, 1)], 6);
assert_eq!(d[(1, 0)], 8);
assert_eq!(d[(1, 1)], 9);
}
#[test]
fn test_split_quadrants_odd_cols() {
let mut data = vec![1, 2, 3, 4, 5, 6];
let matrix = MatrixMut::from_mut_slice(&mut data, 3, 2);
let (a, b, c, d) = matrix.split_quadrants(1, 1);
assert_eq!(a.rows(), 1);
assert_eq!(a.cols(), 1);
assert_eq!(b.rows(), 1);
assert_eq!(b.cols(), 1);
assert_eq!(c.rows(), 2);
assert_eq!(c.cols(), 1);
assert_eq!(d.rows(), 2);
assert_eq!(d.cols(), 1);
assert_eq!(a[(0, 0)], 1);
assert_eq!(b[(0, 0)], 2);
assert_eq!(c[(0, 0)], 3);
assert_eq!(c[(1, 0)], 5);
assert_eq!(d[(0, 0)], 4);
assert_eq!(d[(1, 0)], 6);
}
}