matamorph 0.2.0

Seamless conversions between Rust’s major matrix libraries: ndarray, faer, nalgebra, and mdarray.
Documentation
#[allow(unused_imports)]
use crate::compute_len;

#[cfg(feature = "with-ndarray")]
use ndarray::ShapeBuilder;

/// Backend-agnostic description of a borrowed 2-D matrix: a raw pointer to the element at
/// row 0, column 0, plus the shape and per-axis strides needed to address every other element.
///
/// `MataRef` carries no lifetime, so it does not keep the source storage alive. It is only an
/// intermediate value between an `IntoMataRef` producer and a `FromMataRef` consumer.
// Not every field is read by every (feature-gated) conversion in this module.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
pub struct MataRef<T> {
    /// Pointer to the element at row 0, column 0.
    pub data: *const T,
    /// `(rows, cols)`.
    pub shape: (usize, usize),
    /// `(row_stride, col_stride)`, measured in elements; signed.
    pub strides: (isize, isize),
    /// Layout hint set by producers; not interpreted by every consumer.
    pub trans_strides: bool,
    /// Layout hint set by producers; not interpreted by every consumer.
    pub trans_dims: bool,
}

/// Construction of a matrix view from a backend-agnostic [`MataRef`] description.
pub trait FromMataRef<T>
where
    Self: Sized,
{
    /// Builds `Self` over the memory described by `source`.
    fn from_mata_ref(source: MataRef<T>) -> Self;
}

/// Description of a matrix view as a backend-agnostic [`MataRef`].
pub trait IntoMataRef<T> {
    /// Consumes the view and reports its base pointer, shape and strides.
    fn into_mata_ref(self) -> MataRef<T>;
}

// MATA <-> NDARRAY

#[cfg(feature = "with-ndarray")]
impl<'a, T> FromMataRef<T> for ndarray::ArrayView2<'a, T> {
    /// Builds an `ndarray` view over the memory described by `mat`.
    ///
    /// The shape and strides are forwarded as-is. `trans_strides`/`trans_dims` are not consulted.
    ///
    /// # Panics
    ///
    /// Panics if `ndarray` rejects the shape/stride combination for a slice of
    /// `compute_len(mat.shape, mat.strides)` elements. The panic message includes ndarray's error.
    fn from_mata_ref(mat: MataRef<T>) -> ndarray::ArrayView2<'a, T> {
        let (rows, cols) = mat.shape;
        let len = compute_len(mat.shape, mat.strides);

        // `ShapeBuilder::strides` takes unsigned strides, so a negative stride wraps here.
        // NOTE(review): for negative strides the slice below starts at element (0, 0), which is
        // not necessarily the lowest address the view touches. Confirm against `compute_len`
        // and ndarray's negative-stride rules before relying on that case.
        let strides = (mat.strides.0 as usize, mat.strides.1 as usize);

        // SAFETY: `MataRef` carries no lifetime. This relies on the producer of `mat` having taken
        // `data` from a live view whose storage covers `len` elements of `T` for the chosen `'a`.
        // NOTE(review): nothing here verifies that; `FromMataRef` is a safe trait.
        let slice = unsafe { std::slice::from_raw_parts(mat.data, len) };

        ndarray::ArrayView2::from_shape((rows, cols).strides(strides), slice).unwrap_or_else(|err| {
            panic!(
                "MataRef: shape {:?} with strides {:?} is invalid for an ndarray view: {err}",
                mat.shape, mat.strides
            )
        })
    }
}

#[cfg(feature = "with-ndarray")]
impl<T> IntoMataRef<T> for ndarray::ArrayView2<'_, T> {
    /// Reports the view's base pointer, `(rows, cols)` shape and `(row, col)` strides.
    ///
    /// `trans_strides` is set exactly when axis 0 has unit stride, and `trans_dims` is its negation.
    fn into_mata_ref(self) -> MataRef<T> {
        let shape = self.dim();
        let all_strides = self.strides();
        let (row_stride, col_stride) = (all_strides[0], all_strides[1]);
        let unit_row_stride = row_stride == 1;

        MataRef {
            data: self.as_ptr(),
            shape,
            strides: (row_stride, col_stride),
            trans_strides: unit_row_stride,
            trans_dims: !unit_row_stride,
        }
    }
}

// MATA <-> MDARRAY

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Strided>
{
    /// Builds a strided `mdarray` view using `mat`'s shape and both of its strides.
    ///
    /// The trans flags are ignored.
    fn from_mata_ref(mat: MataRef<T>) -> mdarray::View<'a, T, (D0, D1), mdarray::Strided> {
        let MataRef {
            data,
            shape: (rows, cols),
            strides: (row_stride, col_stride),
            ..
        } = mat;

        let mapping = mdarray::StridedMapping::new(
            mdarray::Shape::from_dims(&[rows, cols]),
            &[row_stride, col_stride],
        );

        // SAFETY: `MataRef` has no lifetime, so validity of `data` for every element addressed by
        // `mapping` over `'a` is assumed to have been guaranteed by whoever built `mat`.
        unsafe { mdarray::View::<T, (D0, D1), mdarray::Strided>::new_unchecked(data, mapping) }
    }
}

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Strided>
{
    /// Reports the view's base pointer, the sizes of both dimensions and their strides.
    ///
    /// Both trans flags are cleared.
    fn into_mata_ref(self) -> MataRef<T> {
        let (dim0, dim1) = *self.shape();
        let strides = (self.stride(0), self.stride(1));
        let shape = (dim0.size(), dim1.size());

        MataRef {
            data: self.as_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Dense>
{
    /// Builds a dense `mdarray` view from `mat`'s pointer and shape only.
    ///
    /// `mat.strides` is not read. The memory is assumed to already be packed the way mdarray's
    /// `Dense` mapping expects.
    /// NOTE(review): no check enforces that assumption; strided sources would be misread.
    fn from_mata_ref(mat: MataRef<T>) -> mdarray::View<'a, T, (D0, D1), mdarray::Dense> {
        let MataRef {
            data,
            shape: (rows, cols),
            ..
        } = mat;

        let mapping = mdarray::DenseMapping::new(mdarray::Shape::from_dims(&[rows, cols]));

        // SAFETY: relies on the producer of `mat` guaranteeing that `data` is valid for
        // `rows * cols` densely packed elements over `'a`; nothing here checks it.
        unsafe { mdarray::View::<T, (D0, D1), mdarray::Dense>::new_unchecked(data, mapping) }
    }
}

#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Dense>
{
    /// Reports the view's base pointer, the sizes of both dimensions and the strides mdarray
    /// reports for them.
    ///
    /// Both trans flags are cleared.
    fn into_mata_ref(self) -> MataRef<T> {
        let (dim0, dim1) = *self.shape();
        let strides = (self.stride(0), self.stride(1));
        let shape = (dim0.size(), dim1.size());

        MataRef {
            data: self.as_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}

// MATA <-> FAER

#[cfg(feature = "with-faer")]
impl<'a, T> FromMataRef<T> for faer::MatRef<'a, T> {
    /// Builds a `faer` view from `mat`'s pointer, shape and `(row, col)` strides.
    ///
    /// The trans flags are ignored.
    fn from_mata_ref(mat: MataRef<T>) -> faer::MatRef<'a, T> {
        let MataRef {
            data,
            shape: (nrows, ncols),
            strides: (row_stride, col_stride),
            ..
        } = mat;

        // SAFETY: `MataRef` carries no lifetime. This assumes the producer of `mat` took `data`
        // from a view that stays valid for every element these strides address during `'a`.
        unsafe { faer::mat::MatRef::from_raw_parts(data, nrows, ncols, row_stride, col_stride) }
    }
}

#[cfg(feature = "with-faer")]
impl<'a, T> IntoMataRef<T> for faer::MatRef<'a, T> {
    /// Reports the view's base pointer, `(nrows, ncols)` shape and `(row, col)` strides.
    ///
    /// `trans_strides` is always set and `trans_dims` is always cleared.
    fn into_mata_ref(self) -> MataRef<T> {
        let strides = (self.row_stride(), self.col_stride());
        let shape = self.shape();

        MataRef {
            data: self.as_ptr(),
            shape,
            strides,
            trans_strides: true,
            trans_dims: false,
        }
    }
}

// MATA <-> NALGEBRA

#[cfg(feature = "with-nalgebra")]
impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> FromMataRef<T>
    for nalgebra::DMatrixView<'a, T>
{
    /// Builds a nalgebra view over the memory described by `mat`.
    ///
    /// `nalgebra::DMatrixView<'a, T>` fixes its row stride to `U1` at the type level. It can
    /// therefore only describe data whose vertically adjacent elements are contiguous
    /// (column-major). Shape and strides are forwarded unchanged, never swapped, so element
    /// `(i, j)` of the result is element `(i, j)` of the source. This holds for rectangular and
    /// square matrices alike. The trans flags are not consulted.
    ///
    /// # Panics
    ///
    /// - if `mat` has more than one row and its row stride is not `1` (for example, row-major
    ///   data), because that layout cannot be expressed as a `DMatrixView` without transposing;
    /// - if the column stride is negative (nalgebra strides are unsigned);
    /// - if nalgebra rejects the buffer length computed by `compute_len`.
    fn from_mata_ref(mat: MataRef<T>) -> nalgebra::DMatrixView<'a, T> {
        let (rows, cols) = mat.shape;
        let (row_stride, col_stride) = mat.strides;

        // With at most one row, the row stride never addresses memory, so any value is fine.
        assert!(
            rows <= 1 || row_stride == 1,
            "nalgebra::DMatrixView requires a unit row stride (column-major data), got row stride {row_stride}"
        );
        assert!(
            col_stride >= 0,
            "nalgebra views do not support negative strides, got column stride {col_stride}"
        );

        let len = compute_len(mat.shape, mat.strides);

        // SAFETY: `MataRef` carries no lifetime. This relies on the producer of `mat` having taken
        // `data` from a live view whose storage covers `len` elements of `T` for the chosen `'a`.
        // NOTE(review): nothing here verifies that; `FromMataRef` is a safe trait.
        let data = unsafe { core::slice::from_raw_parts(mat.data, len) };

        nalgebra::base::MatrixView::from_slice_with_strides_generic(
            data,
            nalgebra::Dim::from_usize(rows),
            nalgebra::Dim::from_usize(cols),
            // Inferred as `U1`, the row-stride type of `DMatrixView`.
            nalgebra::Dim::from_usize(1),
            nalgebra::Dim::from_usize(col_stride as usize),
        )
    }
}

#[cfg(feature = "with-nalgebra")]
impl<'a, T> IntoMataRef<T> for nalgebra::DMatrixView<'a, T> {
    /// Reports the view's base pointer, `(nrows, ncols)` shape and `(row, col)` strides.
    ///
    /// nalgebra's unsigned strides are cast to `isize`. Both trans flags are cleared.
    fn into_mata_ref(self) -> MataRef<T> {
        let shape = self.shape();
        let (row_stride, col_stride) = self.strides();
        let strides = (row_stride as isize, col_stride as isize);

        MataRef {
            data: self.as_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}

/// Conversion of a borrowed 2-D matrix view into the view type of each enabled backend.
///
/// Each item is only declared when its backend's `with-*` Cargo feature is enabled.
pub trait MataConvertRef<'a, T> {
    /// Memory layout of the `mdarray` view returned by `to_mdarray`.
    #[cfg(feature = "with-mdarray")]
    type MdarrayLayout: mdarray::Layout;

    /// Re-expresses `self` as an `ndarray` 2-D view.
    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayView2<'a, T>;

    /// Re-expresses `self` as an `mdarray` view with caller-chosen dimension types.
    #[cfg(feature = "with-mdarray")]
    fn to_mdarray<D0, D1>(self) -> mdarray::View<'a, T, (D0, D1), Self::MdarrayLayout>
    where
        D0: mdarray::Dim,
        D1: mdarray::Dim;

    /// Re-expresses `self` as a `faer` matrix view.
    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatRef<'a, T>;

    /// Re-expresses `self` as a `nalgebra` dynamic matrix view.
    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixView<'a, T>;
}

/// Implements `MataConvertRef` for a 2-D view type by routing every conversion through
/// `into_mata_ref` followed by `FromMataRef::from_mata_ref`.
///
/// `MataConvertRef` declares each target only under its own feature. Each generated item is
/// gated the same way, so the impl compiles for any combination of enabled backends, not just
/// when all of them are on.
///
/// - `$source_type`: the view type receiving the impl (must implement `IntoMataRef<T>`).
/// - `$feature`: the Cargo feature that enables `$source_type`'s backend.
/// - `$mata_type`: the intermediate type returned by `into_mata_ref`.
macro_rules! impl_ref_conversions {
    ($source_type:ty, $feature:literal, $mata_type:ty) => {
        #[cfg(feature = $feature)]
        impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> MataConvertRef<'a, T>
            for $source_type
        {
            #[cfg(feature = "with-mdarray")]
            type MdarrayLayout = mdarray::Strided;

            #[cfg(feature = "with-ndarray")]
            fn to_ndarray(self) -> ndarray::ArrayView2<'a, T> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }

            #[cfg(feature = "with-mdarray")]
            fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
                self,
            ) -> mdarray::View<'a, T, (D0, D1), mdarray::Strided> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }

            #[cfg(feature = "with-faer")]
            fn to_faer(self) -> faer::MatRef<'a, T> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }

            #[cfg(feature = "with-nalgebra")]
            fn to_nalgebra(self) -> nalgebra::DMatrixView<'a, T> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }
        }
    };
}

// Every mdarray 2-D view that can describe itself as a `MataRef` converts through the same path.
// The non-mdarray targets are each gated on their own feature, matching the trait's
// declarations, so this impl compiles whichever other backends are enabled.
#[cfg(feature = "with-mdarray")]
impl<
    'a,
    T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static,
    D0: mdarray::Dim,
    D1: mdarray::Dim,
    L: mdarray::Layout,
> MataConvertRef<'a, T> for mdarray::View<'a, T, (D0, D1), L>
where
    Self: IntoMataRef<T>,
{
    // `to_mdarray` always yields a strided view, whatever the source layout `L` is.
    type MdarrayLayout = mdarray::Strided;

    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayView2<'a, T> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }

    fn to_mdarray<D0out: mdarray::Dim, D1out: mdarray::Dim>(
        self,
    ) -> mdarray::View<'a, T, (D0out, D1out), mdarray::Strided> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }

    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatRef<'a, T> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }

    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixView<'a, T> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }
}

// `MataConvertRef` impls for the non-mdarray view types. Each generated impl is only emitted
// when the backend feature named in its invocation is enabled.
impl_ref_conversions!(faer::MatRef<'a, T>, "with-faer", MataRef<T>);
impl_ref_conversions!(nalgebra::DMatrixView<'a, T>, "with-nalgebra", MataRef<T>);
impl_ref_conversions!(ndarray::ArrayView2<'a, T>, "with-ndarray", MataRef<T>);