matamorph 0.2.0

Seamless conversions between Rust’s major matrix libraries: ndarray, faer, nalgebra, and mdarray.
Documentation
#[allow(unused_imports)]
use crate::compute_len;

#[cfg(feature = "with-ndarray")]
use ndarray::ShapeBuilder;

/// Library-neutral description of a mutable, strided 2-D matrix view.
///
/// Every `IntoMataRefMut` impl in this module fills one of these from a library's view, and
/// every `FromMataRefMut` impl rebuilds a view from it. The struct stores a raw pointer and
/// carries no lifetime, so nothing prevents the source view from being used alongside a view
/// rebuilt from this value — callers must ensure they do not alias.
// NOTE(review): presumably `dead_code` is allowed because some fields go unread under certain
// feature combinations — confirm, then document it here or drop the attribute.
#[allow(dead_code)]
#[derive(Debug)]
pub struct MataRefMut<T> {
    /// Pointer to element (0, 0); every other element is reached through `strides`.
    pub data: *mut T,
    /// `(rows, cols)`.
    pub shape: (usize, usize),
    /// `(row stride, column stride)`, measured in elements.
    pub strides: (isize, isize),
    /// Layout hint set by the producing conversion (the ndarray conversion sets it when the
    /// row stride is 1; the faer conversion always sets it).
    pub trans_strides: bool,
    /// Layout hint set by the producing conversion (the ndarray conversion sets it when the
    /// row stride is not 1).
    pub trans_dims: bool,
}

/// Rebuilds a library-specific mutable matrix view from a [`MataRefMut`].
///
/// The impls in this module construct the view from the raw pointer in `mat` with `unsafe`
/// constructors. `MataRefMut` carries no lifetime, so the lifetime of the returned view is
/// chosen by the caller and is not checked against the original borrow.
pub trait FromMataRefMut<T>: Sized {
    /// Builds `Self` over the memory described by `mat` (pointer, shape and strides).
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> Self;
}

/// Decomposes a library-specific mutable matrix view into a [`MataRefMut`].
///
/// The view is consumed, but the raw pointer in the result still refers to its memory.
pub trait IntoMataRefMut<T> {
    /// Records the data pointer, shape, strides and layout hints of `self`.
    fn into_mata_ref_mut(self) -> MataRefMut<T>;
}

// MATA <-> NDARRAY

#[cfg(feature = "with-ndarray")]
impl<'a, T> FromMataRefMut<T> for ndarray::ArrayViewMut2<'a, T> {
    /// Builds an ndarray view with `mat`'s shape and strides over the memory at `mat.data`.
    ///
    /// ndarray's `ShapeBuilder::strides` takes `usize` values, so the strides are cast with
    /// `as`; a negative stride therefore wraps to a large value and it is up to `from_shape`'s
    /// validation whether that is accepted.
    ///
    /// # Panics
    /// Panics if ndarray rejects the shape/stride combination for the slice length returned
    /// by `compute_len`.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> ndarray::ArrayViewMut2<'a, T> {
        // NOTE(review): assumes `compute_len` returns the number of elements spanned by
        // `shape`/`strides` starting at `data` — confirm against its definition.
        let len = compute_len(mat.shape, mat.strides);
        let strides = (mat.strides.0 as usize, mat.strides.1 as usize);

        // SAFETY: `mat.data` comes from a live mutable view (see `IntoMataRefMut`) that is
        // assumed valid for `len` elements; the caller must not use that source view while the
        // returned one is alive. Only the slice creation needs `unsafe`.
        let slice = unsafe { std::slice::from_raw_parts_mut(mat.data, len) };

        ndarray::ArrayViewMut2::from_shape(mat.shape.strides(strides), slice)
            .expect("MataRefMut: invalid shape for contiguous data")
    }
}

#[cfg(feature = "with-ndarray")]
impl<T> IntoMataRefMut<T> for ndarray::ArrayViewMut2<'_, T> {
    /// Captures the pointer, `(rows, cols)` and `(row, col)` strides of this ndarray view.
    ///
    /// `trans_strides` is set exactly when the row stride is 1, and `trans_dims` is its
    /// negation.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let shape = self.dim();
        let row_stride = self.strides()[0];
        let col_stride = self.strides()[1];
        let unit_row_stride = row_stride == 1;

        MataRefMut {
            data: self.as_mut_ptr(),
            shape,
            strides: (row_stride, col_stride),
            trans_strides: unit_row_stride,
            trans_dims: !unit_row_stride,
        }
    }
}

// MATA <-> MDARRAY

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided>
{
    /// Builds a strided mdarray view whose mapping carries `mat`'s shape and strides.
    fn from_mata_ref_mut(
        mat: MataRefMut<T>,
    ) -> mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided> {
        let MataRefMut {
            data,
            shape: (rows, cols),
            strides: (row_stride, col_stride),
            ..
        } = mat;

        let mapping = mdarray::StridedMapping::new(
            mdarray::Shape::from_dims(&[rows, cols]),
            &[row_stride, col_stride],
        );

        // SAFETY: `data` comes from a live mutable view (see `IntoMataRefMut`); the caller must
        // not use that source view while the returned one is alive.
        unsafe { mdarray::ViewMut::<T, (D0, D1), mdarray::Strided>::new_unchecked(data, mapping) }
    }
}

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided>
{
    /// Captures the pointer, sizes and per-axis strides of this strided mdarray view.
    /// Both layout hints are left unset.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let (d0, d1) = *self.shape();
        let shape = (d0.size(), d1.size());
        let strides = (self.stride(0), self.stride(1));

        MataRefMut {
            data: self.as_mut_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Dense>
{
    /// Builds a dense mdarray view over `mat.data` with `mat`'s shape.
    ///
    /// A dense mapping cannot carry custom strides, so `mat.strides` must equal the strides
    /// mdarray's dense layout uses for this shape. Axes of length 0 or 1 are exempt because
    /// their stride never affects an element address. Convert to a `Strided` view for any
    /// other layout.
    ///
    /// # Panics
    /// Panics if `mat`'s strides do not match the dense layout.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> mdarray::ViewMut<'a, T, (D0, D1), mdarray::Dense> {
        let (rows, cols) = mat.shape;

        let map = mdarray::DenseMapping::new(mdarray::Shape::from_dims(&[rows, cols]));

        // SAFETY: `mat.data` comes from a live mutable view (see `IntoMataRefMut`); the caller
        // must not use that source view while the returned one is alive. The view is handed
        // back only after the stride check below; if the check fails it is dropped without
        // any element being read or written.
        // NOTE(review): confirm `new_unchecked` itself performs no element access.
        let view =
            unsafe { mdarray::ViewMut::<T, (D0, D1), mdarray::Dense>::new_unchecked(mat.data, map) };

        // Previously the strides were ignored, so column-major or sub-matrix sources were
        // silently reinterpreted (wrong elements, or reads past the source's extent).
        let axis_matches =
            |len: usize, actual: isize, expected: isize| len <= 1 || actual == expected;
        let is_dense = rows == 0
            || cols == 0
            || (axis_matches(rows, mat.strides.0, view.stride(0))
                && axis_matches(cols, mat.strides.1, view.stride(1)));
        assert!(
            is_dense,
            "MataRefMut: strides ({}, {}) do not match mdarray's dense layout; convert to a strided view instead",
            mat.strides.0,
            mat.strides.1
        );

        view
    }
}

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRefMut<T>
    for mdarray::ViewMut<'a, T, (D0, D1), mdarray::Dense>
{
    /// Captures the pointer, sizes and per-axis strides of this dense mdarray view.
    /// Both layout hints are left unset.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let (d0, d1) = *self.shape();
        let shape = (d0.size(), d1.size());
        let strides = (self.stride(0), self.stride(1));

        MataRefMut {
            data: self.as_mut_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}

// MATA <-> FAER

#[cfg(feature = "with-faer")]
impl<'a, T> FromMataRefMut<T> for faer::MatMut<'a, T> {
    /// Builds a faer view with `mat`'s shape and `(row, col)` strides over `mat.data`.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> faer::MatMut<'a, T> {
        let MataRefMut {
            data,
            shape: (nrows, ncols),
            strides: (row_stride, col_stride),
            ..
        } = mat;

        // SAFETY: `data` comes from a live mutable view (see `IntoMataRefMut`) laid out by
        // these strides; the caller must not use that source view while the returned one is
        // alive.
        unsafe { faer::mat::MatMut::from_raw_parts_mut(data, nrows, ncols, row_stride, col_stride) }
    }
}

#[cfg(feature = "with-faer")]
impl<'a, T> IntoMataRefMut<T> for faer::MatMut<'a, T> {
    /// Captures the pointer, shape and `(row, col)` strides of this faer view.
    /// `trans_strides` is always set and `trans_dims` never is.
    fn into_mata_ref_mut(self) -> MataRefMut<T> {
        let shape = self.shape();
        let strides = (self.row_stride(), self.col_stride());
        let data = self.as_ptr_mut();

        MataRefMut {
            data,
            shape,
            strides,
            trans_strides: true,
            trans_dims: false,
        }
    }
}

// MATA <-> NALGEBRA

#[cfg(feature = "with-nalgebra")]
impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> FromMataRefMut<T>
    for nalgebra::DMatrixViewMut<'a, T>
{
    /// Builds a nalgebra view over the memory described by `mat`.
    ///
    /// `DMatrixViewMut` fixes its row stride to 1 at the type level, so the orientation is
    /// chosen from `mat.strides`:
    /// * row stride 1 (or a single row): the view has `mat`'s shape, `rows x cols`;
    /// * otherwise, column stride 1 (or a single column): the view is the **transpose** of
    ///   `mat`, `cols x rows`. Row-major sources such as C-order ndarray take this path.
    ///
    /// Square and rectangular shapes are both accepted. The `trans_strides` / `trans_dims`
    /// hints are not consulted, because the strides alone determine the orientation. This also
    /// makes nalgebra -> nalgebra round trips work; previously they asked nalgebra for a row
    /// stride of `nrows` and panicked.
    ///
    /// # Panics
    /// Panics if neither axis has a unit stride, or if the stride passed to nalgebra is
    /// negative.
    fn from_mata_ref_mut(mat: MataRefMut<T>) -> nalgebra::DMatrixViewMut<'a, T> {
        let (rows, cols) = mat.shape;
        let (row_stride, col_stride) = mat.strides;

        let non_negative = |stride: isize| -> usize {
            usize::try_from(stride)
                .expect("MataRefMut: nalgebra views cannot represent negative strides")
        };

        // (nalgebra rows, nalgebra cols, distance between consecutive nalgebra columns).
        // An axis of length 0 or 1 never uses its stride, so those cases get the natural
        // column-major value `rows` instead of whatever the source reported.
        let (na_rows, na_cols, na_col_stride) = if rows == 0 || cols == 0 {
            (rows, cols, rows)
        } else if rows == 1 || row_stride == 1 {
            let outer = if cols == 1 { rows } else { non_negative(col_stride) };
            (rows, cols, outer)
        } else if cols == 1 || col_stride == 1 {
            (cols, rows, non_negative(row_stride))
        } else {
            panic!(
                "MataRefMut: nalgebra views need a unit stride along one axis, got strides ({}, {})",
                row_stride, col_stride
            )
        };

        // NOTE(review): assumes `compute_len` returns the number of elements spanned by
        // `shape`/`strides` starting at `data` — confirm against its definition.
        let len = compute_len(mat.shape, mat.strides);

        // SAFETY: `mat.data` comes from a live mutable view (see `IntoMataRefMut`) assumed
        // valid for `len` elements; the caller must not use that source view while the
        // returned one is alive.
        let data = unsafe { core::slice::from_raw_parts_mut(mat.data, len) };

        nalgebra::base::DMatrixViewMut::from_slice_with_strides_generic(
            data,
            nalgebra::Dim::from_usize(na_rows),
            nalgebra::Dim::from_usize(na_cols),
            nalgebra::Dim::from_usize(1),
            nalgebra::Dim::from_usize(na_col_stride),
        )
    }
}

#[cfg(feature = "with-nalgebra")]
impl<'a, T> IntoMataRefMut<T> for nalgebra::DMatrixViewMut<'a, T> {
    /// Captures the pointer, shape and `(row, col)` strides of this nalgebra view.
    /// Both layout hints are left unset.
    fn into_mata_ref_mut(mut self) -> MataRefMut<T> {
        let shape = self.shape();
        let (row_stride, col_stride) = self.strides();
        let data = self.as_mut_ptr();

        MataRefMut {
            data,
            shape,
            strides: (row_stride as isize, col_stride as isize),
            trans_strides: false,
            trans_dims: false,
        }
    }
}

/// One-call conversions from a mutable 2-D view into each supported library's view.
///
/// Each method exists only when the target library's feature is enabled. In this module,
/// every implementation goes through [`MataRefMut`], so the returned view refers to the same
/// memory as `self`.
pub trait MataConvertMut<'a, T> {
    /// Layout of the view returned by `to_mdarray`.
    #[cfg(feature = "with-mdarray")]
    type MdarrayLayout: mdarray::Layout;
    /// Converts into an ndarray `ArrayViewMut2`.
    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayViewMut2<'a, T>;
    /// Converts into an mdarray `ViewMut` with dimension types `D0` and `D1`.
    #[cfg(feature = "with-mdarray")]
    fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
        self,
    ) -> mdarray::ViewMut<'a, T, (D0, D1), Self::MdarrayLayout>;
    /// Converts into a faer `MatMut`.
    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatMut<'a, T>;
    /// Converts into a nalgebra `DMatrixViewMut`.
    ///
    /// NOTE(review): for sources whose row stride is not 1 (e.g. row-major data), the nalgebra
    /// view can be the transpose of `self`; see the nalgebra `FromMataRefMut` impl.
    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixViewMut<'a, T>;
}

// Implements `MataConvertMut` for a source view type by routing every conversion through
// `IntoMataRefMut` / `FromMataRefMut`.
//
// * `$source_type` — the view type to implement the trait for (may mention `'a` and `T`).
// * `$feature`     — the cargo feature that enables `$source_type`'s library.
// * `$mata_type`   — the intermediate type (always `MataRefMut<T>` in this file).
//
// Every associated item is gated on its target library's feature, mirroring the `#[cfg]`s on
// the trait declaration. Without these gates, the expansion names items and crates that the
// trait drops, so any build with only some features enabled would fail to compile.
macro_rules! impl_mut_conversions {
    ($source_type:ty, $feature:literal, $mata_type:ty) => {
        #[cfg(feature = $feature)]
        impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> MataConvertMut<'a, T>
            for $source_type
        {
            #[cfg(feature = "with-mdarray")]
            type MdarrayLayout = mdarray::Strided;

            #[cfg(feature = "with-ndarray")]
            fn to_ndarray(self) -> ndarray::ArrayViewMut2<'a, T> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }

            #[cfg(feature = "with-mdarray")]
            fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
                self,
            ) -> mdarray::ViewMut<'a, T, (D0, D1), mdarray::Strided> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }

            #[cfg(feature = "with-faer")]
            fn to_faer(self) -> faer::MatMut<'a, T> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }

            #[cfg(feature = "with-nalgebra")]
            fn to_nalgebra(self) -> nalgebra::DMatrixViewMut<'a, T> {
                let mata: $mata_type = self.into_mata_ref_mut();
                FromMataRefMut::from_mata_ref_mut(mata)
            }
        }
    };
}

// `MataConvertMut` for every 2-D mdarray view whose layout has an `IntoMataRefMut` impl.
// This impl is written by hand rather than through `impl_mut_conversions!` because it is
// generic over the dimension types and the layout. Each target conversion is gated on its
// library's feature, matching the trait declaration; otherwise enabling `with-mdarray`
// without the other three features would not compile.
#[cfg(feature = "with-mdarray")]
impl<
    'a,
    T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static,
    D0: mdarray::Dim,
    D1: mdarray::Dim,
    L: mdarray::Layout,
> MataConvertMut<'a, T> for mdarray::ViewMut<'a, T, (D0, D1), L>
where
    Self: IntoMataRefMut<T>,
{
    type MdarrayLayout = mdarray::Strided;

    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayViewMut2<'a, T> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }

    fn to_mdarray<D0out: mdarray::Dim, D1out: mdarray::Dim>(
        self,
    ) -> mdarray::ViewMut<'a, T, (D0out, D1out), mdarray::Strided> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }

    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatMut<'a, T> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }

    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixViewMut<'a, T> {
        let mata: MataRefMut<T> = self.into_mata_ref_mut();
        FromMataRefMut::from_mata_ref_mut(mata)
    }
}

// `MataConvertMut` for the ndarray, faer and nalgebra mutable 2-D views. Each impl is
// compiled only when the feature named in the second argument is enabled. The mdarray impl
// above is written by hand because it is generic over the dimension types and layout.
impl_mut_conversions!(ndarray::ArrayViewMut2<'a, T>, "with-ndarray", MataRefMut<T>);
impl_mut_conversions!(faer::MatMut<'a, T>, "with-faer", MataRefMut<T>);
impl_mut_conversions!(
    nalgebra::DMatrixViewMut<'a, T>,
    "with-nalgebra",
    MataRefMut<T>
);