#[allow(unused_imports)]
use crate::compute_len;
#[cfg(feature = "with-ndarray")]
use ndarray::ShapeBuilder;
/// Lifetime-free, type-erased description of a mutable strided 2-D matrix view.
///
/// It is the go-between for the conversions in this module:
/// - an [`IntoMataRefMut`] impl takes a backend view apart into these raw parts;
/// - a [`FromMataRefMut`] impl assembles a view of another backend from them.
///
/// The ndarray, strided-mdarray and faer conversions here read element `(i, j)`
/// at offset `i * strides.0 + j * strides.1` (in elements) from `data`.
///
/// The struct holds a raw pointer and carries no lifetime. Nothing in the type
/// ties it to the borrow it was created from.
#[allow(dead_code)]
#[derive(Debug)]
pub struct MataRefMut<T> {
    /// Pointer to the element at `(0, 0)`, as returned by the source backend's
    /// `as_mut_ptr`/`as_ptr_mut`.
    pub data: *mut T,
    /// `(rows, cols)`.
    pub shape: (usize, usize),
    /// `(row stride, column stride)`, measured in elements.
    pub strides: (isize, isize),
    /// Layout hint recorded by the producing `IntoMataRefMut` impl. Each
    /// backend sets it by its own rule.
    pub trans_strides: bool,
    /// Layout hint recorded by the producing `IntoMataRefMut` impl. When it is
    /// set, the nalgebra conversion swaps the row and column counts.
    pub trans_dims: bool,
}
/// Reassembles a backend-specific mutable matrix view from a [`MataRefMut`].
///
/// Implementations turn the raw pointer, shape and strides stored in the
/// [`MataRefMut`] back into their library's view type.
///
/// NOTE(review): implementors are generic over the output lifetime and take a
/// lifetime-free raw pointer, so the caller effectively picks the lifetime of
/// the produced view. Soundness rests on the producer of the `MataRefMut`.
pub trait FromMataRefMut<T>: Sized {
    /// Builds `Self` from the raw parts held in `mat`.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> Self;
}
/// Disassembles a backend-specific mutable matrix view into a [`MataRefMut`].
///
/// The result keeps only the raw pointer, shape, strides and layout flags. The
/// borrow that `self` represented is not tracked by the returned type.
pub trait IntoMataRefMut<T> {
    /// Consumes `self` and returns its raw parts.
    fn into_mata_ref_mut(self) -> MataRefMut<T>;
}
#[cfg(feature = "with-ndarray")]
impl<'a, T> FromMataRefMut<T> for ndarray::ArrayViewMut2<'a, T> {
    /// Rebuilds an ndarray mutable view from the raw parts in `mat`.
    ///
    /// The shape is used as `(rows, cols)`. The strides are forwarded as ndarray
    /// custom strides after an `as usize` cast. The layout flags are not consulted.
    ///
    /// # Panics
    ///
    /// Panics if ndarray's `from_shape` rejects the shape/stride combination for
    /// a slice of `compute_len(shape, strides)` elements starting at `data`.
    ///
    /// NOTE(review): negative strides go through an `as usize` cast. Confirm
    /// ndarray's `from_shape` handles them for a slice that starts at `data`.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> ndarray::ArrayViewMut2<'a, T> {
        let span = compute_len(mat.shape, mat.strides);
        let MataRefMut {
            data,
            shape,
            strides: (row_stride, col_stride),
            ..
        } = mat;
        let layout = shape.strides((row_stride as usize, col_stride as usize));
        // SAFETY (not checked here): sound only if the producer of `mat` guarantees
        // that `data` is valid for reads and writes of `span` elements for `'a`,
        // with no other live reference to them. This safe fn cannot verify that.
        let elements = unsafe { std::slice::from_raw_parts_mut(data, span) };
        ndarray::ArrayViewMut2::from_shape(layout, elements)
            .expect("MataRefMut: invalid shape for contiguous data")
    }
}
#[cfg(feature = "with-ndarray")]
impl<T> IntoMataRefMut<T> for ndarray::ArrayViewMut2<'_, T> {
    /// Takes an ndarray mutable view apart into a [`MataRefMut`].
    ///
    /// The shape and strides are copied verbatim from ndarray (strides in
    /// elements, as `isize`). A unit row stride sets `trans_strides`. Any other
    /// row stride sets `trans_dims` instead, so exactly one of the two flags is
    /// `true`.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let dims = self.dim();
        let row_stride = self.strides()[0];
        let col_stride = self.strides()[1];
        let unit_row_stride = row_stride == 1;
        MataRefMut {
            data: self.as_mut_ptr(),
            shape: dims,
            strides: (row_stride, col_stride),
            trans_strides: unit_row_stride,
            trans_dims: !unit_row_stride,
        }
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided>
{
    /// Rebuilds a strided `mdarray` mutable view from the raw parts in `mat`.
    ///
    /// The shape and both strides are handed to `mdarray::StridedMapping`
    /// unchanged. The layout flags are not consulted.
    fn from_mata_ref_mut(
        mat: MataRefMut<T>,
    ) -> mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided> {
        let MataRefMut {
            data,
            shape: (nrows, ncols),
            strides: (row_stride, col_stride),
            ..
        } = mat;
        let dims = [nrows, ncols];
        let steps = [row_stride, col_stride];
        let mapping = mdarray::StridedMapping::new(mdarray::Shape::from_dims(&dims), &steps);
        // SAFETY (not checked here): `new_unchecked` trusts `data` and `mapping`.
        // Soundness depends on the producer of `mat` describing valid, unaliased
        // memory for `'a`.
        unsafe { mdarray::ViewMut::<T, (D0, D1), mdarray::Strided>::new_unchecked(data, mapping) }
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided>
{
    /// Takes a strided `mdarray` mutable view apart into a [`MataRefMut`].
    ///
    /// The shape comes from the view's dimension sizes and the strides from
    /// `stride(0)` / `stride(1)`. Both layout flags are always `false`.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let (dim0, dim1) = *self.shape();
        let strides = (self.stride(0), self.stride(1));
        let shape = (dim0.size(), dim1.size());
        MataRefMut {
            data: self.as_mut_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Dense>
{
    /// Rebuilds a dense `mdarray` mutable view from the raw parts in `mat`.
    ///
    /// Only `mat.data` and `mat.shape` are used. `mat.strides` and the layout
    /// flags are ignored, because the mapping is derived from the shape alone.
    ///
    /// NOTE(review): a `MataRefMut` whose strides do not match mdarray's dense
    /// layout for this shape will be misread. Presumably callers only use this
    /// for dense data; confirm.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> mdarray::ViewMut<'a, T, (D0, D1), mdarray::Dense> {
        let dims = [mat.shape.0, mat.shape.1];
        let mapping = mdarray::DenseMapping::new(mdarray::Shape::from_dims(&dims));
        // SAFETY (not checked here): `new_unchecked` trusts the pointer and
        // mapping. Soundness depends on the producer of `mat`.
        unsafe { mdarray::ViewMut::<T, (D0, D1), mdarray::Dense>::new_unchecked(mat.data, mapping) }
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Dense>
{
    /// Takes a dense `mdarray` mutable view apart into a [`MataRefMut`].
    ///
    /// The shape comes from the view's dimension sizes and the strides from
    /// `stride(0)` / `stride(1)`. Both layout flags are always `false`.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let (dim0, dim1) = *self.shape();
        let strides = (self.stride(0), self.stride(1));
        let shape = (dim0.size(), dim1.size());
        MataRefMut {
            data: self.as_mut_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}
#[cfg(feature = "with-faer")]
impl<'a, T> FromMataRefMut<T> for faer::MatMut<'a, T> {
    /// Wraps the raw parts of `mat` in a faer mutable view.
    ///
    /// The shape and strides are forwarded to faer unchanged: `strides.0` as the
    /// row stride and `strides.1` as the column stride. The layout flags are not
    /// consulted.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> faer::MatMut<'a, T> {
        let MataRefMut {
            data,
            shape: (nrows, ncols),
            strides: (row_stride, col_stride),
            ..
        } = mat;
        // SAFETY (not checked here): faer trusts the pointer, shape and strides.
        // Soundness depends on the producer of `mat` describing valid, unaliased
        // memory for `'a`.
        unsafe { faer::mat::MatMut::from_raw_parts_mut(data, nrows, ncols, row_stride, col_stride) }
    }
}
#[cfg(feature = "with-faer")]
impl<'a, T> IntoMataRefMut<T> for faer::MatMut<'a, T> {
    /// Takes a faer mutable view apart into a [`MataRefMut`].
    ///
    /// The shape and `(row_stride, col_stride)` are copied from faer.
    /// `trans_strides` is always `true` and `trans_dims` always `false`.
    fn into_mata_ref_mut(self) -> MataRefMut<T> {
        let shape = self.shape();
        let strides = (self.row_stride(), self.col_stride());
        MataRefMut {
            shape,
            strides,
            trans_strides: true,
            trans_dims: false,
            // Taken last so every getter above runs before the pointer is extracted.
            data: self.as_ptr_mut(),
        }
    }
}
#[cfg(feature = "with-nalgebra")]
impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> FromMataRefMut<T>
    for nalgebra::DMatrixViewMut<'a, T>
{
    /// Rebuilds a nalgebra mutable view from the raw parts in `mat`.
    ///
    /// `nalgebra::DMatrixViewMut<'a, T>` uses nalgebra's default row-stride
    /// parameter `U1`, so the row stride handed to nalgebra must be exactly 1:
    ///
    /// * If `mat.strides.0 == 1`, the strides are forwarded unchanged
    ///   (`strides.0` as row stride, `strides.1` as column stride). The view
    ///   shows the same matrix as `mat`.
    /// * Otherwise the strides are swapped so that `strides.1` becomes the row
    ///   stride. The view then shows the transpose of `mat`.
    ///
    /// When `mat.trans_dims` is set, the row and column counts are swapped too.
    ///
    /// # Panics
    ///
    /// * If the matrix is not square and the conversion swaps strides or
    ///   dimensions (anything other than a unit row stride without `trans_dims`).
    /// * Inside nalgebra, if the chosen row stride is not 1, or if the strides do
    ///   not describe a valid, non-aliasing view of the computed slice.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> nalgebra::DMatrixViewMut<'a, T> {
        // Decide from the stride value itself rather than from `trans_strides`.
        // Not every producer sets that flag for a unit row stride (the mdarray
        // conversions always leave it `false`). Swapping a unit row stride away
        // would give nalgebra a non-unit value for its fixed `U1` row stride,
        // which nalgebra rejects with a panic.
        let unit_row_stride = mat.strides.0 == 1;
        // A unit row stride with unswapped dimensions maps every (i, j) to itself,
        // so any shape works. Every other combination swaps strides and/or
        // dimensions and keeps the square-only restriction.
        let preserves_orientation = unit_row_stride && !mat.trans_dims;
        if !preserves_orientation && mat.shape.1 != mat.shape.0 {
            panic!("Rectangular matrices conversions are not handled with nalgebra");
        }
        let (na_rows, na_cols) = if mat.trans_dims {
            (
                nalgebra::Dim::from_usize(mat.shape.1),
                nalgebra::Dim::from_usize(mat.shape.0),
            )
        } else {
            (
                nalgebra::Dim::from_usize(mat.shape.0),
                nalgebra::Dim::from_usize(mat.shape.1),
            )
        };
        // NOTE(review): negative strides wrap to huge values through `as usize`.
        // nalgebra's slice-size checks are expected to reject them; confirm.
        let (na_stride0, na_stride1) = if unit_row_stride {
            (
                nalgebra::Dim::from_usize(mat.strides.0 as usize),
                nalgebra::Dim::from_usize(mat.strides.1 as usize),
            )
        } else {
            (
                nalgebra::Dim::from_usize(mat.strides.1 as usize),
                nalgebra::Dim::from_usize(mat.strides.0 as usize),
            )
        };
        let len = compute_len(mat.shape, mat.strides);
        // SAFETY (not checked here): sound only if the producer of `mat` guarantees
        // that `mat.data` is valid for reads and writes of `len` elements for `'a`,
        // with no other live reference to them. This safe fn cannot verify that.
        let s = unsafe { core::slice::from_raw_parts_mut(mat.data, len) };
        nalgebra::base::DMatrixViewMut::from_slice_with_strides_generic(
            s, na_rows, na_cols, na_stride0, na_stride1,
        )
    }
}
#[cfg(feature = "with-nalgebra")]
impl<'a, T> IntoMataRefMut<T> for nalgebra::DMatrixViewMut<'a, T> {
    /// Takes a nalgebra mutable view apart into a [`MataRefMut`].
    ///
    /// The shape and strides come straight from nalgebra as `(rows, cols)` and
    /// `(row stride, column stride)`, with the strides cast to `isize`.
    /// `trans_strides` is set when the row stride is 1, the same rule the ndarray
    /// conversion uses. `trans_dims` is always `false`.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let (stride0, stride1) = self.strides();
        MataRefMut {
            data: self.as_mut_ptr(),
            shape: self.shape(),
            strides: (stride0 as isize, stride1 as isize),
            // Flag the unit row stride so consumers that branch on `trans_strides`
            // keep the strides as they are. Reporting `false` here made such
            // consumers swap a non-unit value into the row-stride slot.
            trans_strides: stride0 == 1,
            trans_dims: false,
        }
    }
}
/// Direct conversions from a mutable matrix view into the mutable view type of
/// each enabled backend.
///
/// Every method consumes `self`. Each item exists only when its backend's Cargo
/// feature (`with-ndarray`, `with-mdarray`, `with-faer`, `with-nalgebra`) is
/// enabled.
pub trait MataConvertMut<'a, T> {
    /// Layout parameter of the view returned by `to_mdarray`.
    #[cfg(feature = "with-mdarray")]
    type MdarrayLayout: mdarray::Layout;

    /// Converts into an `ndarray::ArrayViewMut2`.
    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayViewMut2<'a, T>;

    /// Converts into an `mdarray` mutable view with caller-chosen dimension types.
    #[cfg(feature = "with-mdarray")]
    fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
        self,
    ) -> mdarray::ViewMut<'a, T, (D0, D1), Self::MdarrayLayout>;

    /// Converts into a `faer::MatMut`.
    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatMut<'a, T>;

    /// Converts into a `nalgebra::DMatrixViewMut`.
    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixViewMut<'a, T>;
}
// Implements `MataConvertMut` for a source view type whose conversions all go
// through an intermediate `$mata_type` (in practice `MataRefMut<T>`).
//
// - `$source_type`: the view type to implement the trait for. It may mention the
//   `'a` and `T` generics introduced by the generated impl.
// - `$feature`: the Cargo feature that gates the whole impl (the source backend).
// - `$mata_type`: the type produced by `into_mata_ref_mut`.
//
// Each generated associated item carries the same `#[cfg]` as its declaration in
// `MataConvertMut`. The impl therefore never defines an item, or names a crate,
// that the trait does not declare for the enabled feature set. Without this,
// enabling fewer than all four backend features failed to compile.
macro_rules! impl_mut_conversions {
    ($source_type:ty, $feature:literal, $mata_type:ty) => {
        #[cfg(feature = $feature)]
        impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> MataConvertMut<'a, T>
            for $source_type
        {
            #[cfg(feature = "with-mdarray")]
            type MdarrayLayout = mdarray::Strided;

            #[cfg(feature = "with-ndarray")]
            fn to_ndarray(self) -> ndarray::ArrayViewMut2<'a, T> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }

            #[cfg(feature = "with-mdarray")]
            fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
                self,
            ) -> mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }

            #[cfg(feature = "with-faer")]
            fn to_faer(self) -> faer::MatMut<'a, T> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }

            #[cfg(feature = "with-nalgebra")]
            fn to_nalgebra(self) -> nalgebra::DMatrixViewMut<'a, T> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }
        }
    };
}
/// `MataConvertMut` for any two-dimensional `mdarray` mutable view that can be
/// taken apart into a `MataRefMut`.
///
/// `to_mdarray` always produces a strided view with caller-chosen dimension
/// types. The other conversions are compiled only when their target backend's
/// feature is enabled, mirroring the `#[cfg]`s on the trait declaration.
#[cfg(feature = "with-mdarray")]
impl<
        'a,
        T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static,
        D0: mdarray::Dim,
        D1: mdarray::Dim,
        L: mdarray::Layout,
    > MataConvertMut<'a, T> for mdarray::ViewMut<'a, T, (D0, D1), L>
where
    Self: IntoMataRefMut<T>,
{
    type MdarrayLayout = mdarray::Strided;

    // Gated like the trait item. Without the cfg, building without `with-ndarray`
    // fails with "method is not a member of trait".
    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayViewMut2<'a, T> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }

    fn to_mdarray<D0out: mdarray::Dim, D1out: mdarray::Dim>(
        self,
    ) -> mdarray::ViewMut<'a, T, (D0out, D1out), mdarray::Strided> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }

    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatMut<'a, T> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }

    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixViewMut<'a, T> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }
}
// `MataConvertMut` impls for the backends that share the macro-generated shape.
// mdarray has its own generic impl. Each expansion is compiled only when the
// source backend's feature (the second argument) is enabled.
impl_mut_conversions!(ndarray::ArrayViewMut2<'a, T>, "with-ndarray", MataRefMut<T>);
impl_mut_conversions!(faer::MatMut<'a, T>, "with-faer", MataRefMut<T>);
impl_mut_conversions!(nalgebra::DMatrixViewMut<'a, T>, "with-nalgebra", MataRefMut<T>);