#[allow(unused_imports)]
use crate::compute_len;
#[cfg(feature = "with-ndarray")]
use ndarray::ShapeBuilder;
/// Lifetime-free description of a 2-D strided matrix view.
///
/// This is the common intermediate representation used when converting between the `ndarray`,
/// `mdarray`, `faer` and `nalgebra` view types. It carries no lifetime: the `FromMataRef`
/// implementations that turn it back into a view trust that `data` is still valid for every
/// element addressed by `shape` and `strides`.
// NOTE(review): presumably this silences dead-code lints when no `with-*` feature constructs
// the type — confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
pub struct MataRef<T> {
    /// Pointer to the first element, as reported by the source view's `as_ptr()`.
    pub data: *const T,
    /// `(rows, cols)`.
    pub shape: (usize, usize),
    /// `(row_stride, col_stride)` in elements. The type admits negative values, but not every
    /// target conversion accepts them.
    pub strides: (isize, isize),
    /// Layout hint set by the `IntoMataRef` impls. ndarray sets it when its row stride is 1, and
    /// faer always sets it. NOTE(review): not uniform across sources — prefer `strides`.
    pub trans_strides: bool,
    /// Layout hint set by the `IntoMataRef` impls. ndarray sets it when its row stride is not 1.
    /// NOTE(review): not uniform across sources — prefer `strides`.
    pub trans_dims: bool,
}
/// Rebuilds a concrete view type from a [`MataRef`] description.
///
/// Implementations receive only a raw pointer plus shape and strides, so they must trust
/// whoever produced the `MataRef` that the memory is valid.
pub trait FromMataRef<T>
where
    Self: Sized,
{
    /// Creates `Self` viewing the memory described by `mat`.
    fn from_mata_ref(mat: MataRef<T>) -> Self;
}
/// Describes a view as a [`MataRef`]: pointer, shape, strides and layout hints.
///
/// The receiver is consumed, but the returned descriptor still points into the memory the view
/// borrowed, and nothing ties that memory's lifetime to the descriptor.
pub trait IntoMataRef<T> {
    /// Consumes `self` and returns a lifetime-free description of its memory layout.
    fn into_mata_ref(self) -> MataRef<T>;
}
#[cfg(feature = "with-ndarray")]
impl<'a, T> FromMataRef<T> for ndarray::ArrayView2<'a, T> {
    /// Rebuilds an `ndarray` view over the memory described by `mat`.
    ///
    /// The view borrows a slice of `compute_len(mat.shape, mat.strides)` elements that starts at
    /// `mat.data`. `ndarray` then validates the shape/stride pair against that slice.
    ///
    /// # Panics
    ///
    /// Panics if `ndarray::ArrayView2::from_shape` rejects the shape/stride combination.
    fn from_mata_ref(mat: MataRef<T>) -> ndarray::ArrayView2<'a, T> {
        let MataRef { data, shape, strides, .. } = mat;
        let extent = compute_len(shape, strides);
        // ndarray's shape builder takes `usize` strides, so `as` reinterprets negative values
        // bit-for-bit. NOTE(review): confirm negative strides are handled as intended.
        let unsigned_strides = (strides.0 as usize, strides.1 as usize);
        // SAFETY: assumes the producer of `mat` guarantees `data` is valid for reads of `extent`
        // elements for the lifetime `'a`; this impl has no way to check that.
        let elements = unsafe { std::slice::from_raw_parts(data, extent) };
        ndarray::ArrayView2::from_shape(shape.strides(unsigned_strides), elements)
            .expect("MataRef: invalid shape for contiguous data")
    }
}
#[cfg(feature = "with-ndarray")]
impl<T> IntoMataRef<T> for ndarray::ArrayView2<'_, T> {
    /// Describes this `ndarray` view as a `MataRef`, copying its shape and element strides.
    ///
    /// A unit stride along axis 0 sets `trans_strides`; any other axis-0 stride sets
    /// `trans_dims`. Exactly one of the two hints is therefore always `true`.
    fn into_mata_ref(self) -> MataRef<T> {
        let dims = self.shape();
        let steps = self.strides();
        let unit_row_step = steps[0] == 1;
        MataRef {
            data: self.as_ptr(),
            shape: (dims[0], dims[1]),
            strides: (steps[0], steps[1]),
            trans_strides: unit_row_step,
            trans_dims: !unit_row_step,
        }
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Strided>
{
    /// Rebuilds a strided `mdarray` view that uses `mat`'s shape and strides verbatim.
    ///
    /// No validation is performed, because the strided mapping can express any stride pair.
    fn from_mata_ref(mat: MataRef<T>) -> mdarray::View<'a, T, (D0, D1), mdarray::Strided> {
        let dims = [mat.shape.0, mat.shape.1];
        let steps = [mat.strides.0, mat.strides.1];
        let mapping = mdarray::StridedMapping::new(mdarray::Shape::from_dims(&dims), &steps);
        // SAFETY: assumes the producer of `mat` guarantees `data` is valid for every element the
        // mapping addresses, for the lifetime `'a`.
        unsafe { mdarray::View::<T, (D0, D1), mdarray::Strided>::new_unchecked(mat.data, mapping) }
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Strided>
{
    /// Describes this strided `mdarray` view as a `MataRef`. Both layout hints are left `false`.
    fn into_mata_ref(self) -> MataRef<T> {
        let (d0, d1) = *self.shape();
        let row_step = self.stride(0);
        let col_step = self.stride(1);
        MataRef {
            data: self.as_ptr(),
            shape: (d0.size(), d1.size()),
            strides: (row_step, col_step),
            trans_strides: false,
            trans_dims: false,
        }
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> FromMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Dense>
{
    /// Rebuilds a dense `mdarray` view from `mat`.
    ///
    /// A dense mapping is fully determined by the shape, so it cannot represent arbitrary
    /// strides. Before the view is returned, `mat.strides` is checked against the strides
    /// mdarray derives for a dense layout of `mat.shape`.
    ///
    /// Two cases are exempt from the check:
    /// * axes of length 1, because their stride never changes which element is addressed;
    /// * empty matrices, because nothing is addressed at all.
    ///
    /// Use the `mdarray::Strided` conversion for non-dense data.
    ///
    /// # Panics
    ///
    /// Panics if `mat.strides` disagrees with the dense layout on an axis of length > 1.
    fn from_mata_ref(mat: MataRef<T>) -> mdarray::View<'a, T, (D0, D1), mdarray::Dense> {
        let (rows, cols) = mat.shape;
        let map = mdarray::DenseMapping::new(mdarray::Shape::from_dims(&[rows, cols]));
        // SAFETY: assumes the producer of `mat` guarantees `data` is valid for every element the
        // dense mapping addresses, for the lifetime `'a`. The view is only inspected for its
        // strides below; if the check fails it is dropped without being read.
        let view =
            unsafe { mdarray::View::<T, (D0, D1), mdarray::Dense>::new_unchecked(mat.data, map) };
        // Previously the strides in `mat` were silently ignored, so a non-dense source produced a
        // view that read the wrong elements.
        let empty = rows == 0 || cols == 0;
        let axis_ok = |len: usize, dense: isize, given: isize| len <= 1 || dense == given;
        assert!(
            empty
                || (axis_ok(rows, view.stride(0), mat.strides.0)
                    && axis_ok(cols, view.stride(1), mat.strides.1)),
            "MataRef: strides {:?} do not match a dense mdarray layout of shape {:?}",
            mat.strides,
            mat.shape
        );
        view
    }
}
#[cfg(feature = "with-mdarray")]
impl<'a, T, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataRef<T>
    for mdarray::View<'a, T, (D0, D1), mdarray::Dense>
{
    /// Describes this dense `mdarray` view as a `MataRef`.
    ///
    /// The strides are taken from the view itself rather than assumed; both hints stay `false`.
    fn into_mata_ref(self) -> MataRef<T> {
        let strides = (self.stride(0), self.stride(1));
        let (d0, d1) = *self.shape();
        let shape = (d0.size(), d1.size());
        MataRef {
            data: self.as_ptr(),
            shape,
            strides,
            trans_strides: false,
            trans_dims: false,
        }
    }
}
#[cfg(feature = "with-faer")]
impl<'a, T> FromMataRef<T> for faer::MatRef<'a, T> {
    /// Rebuilds a `faer` matrix view from `mat`.
    ///
    /// The strides are forwarded unchanged as `(row_stride, col_stride)`; nothing is validated
    /// here.
    fn from_mata_ref(mat: MataRef<T>) -> faer::MatRef<'a, T> {
        let MataRef {
            data,
            shape: (nrows, ncols),
            strides: (row_stride, col_stride),
            ..
        } = mat;
        // SAFETY: assumes the producer of `mat` guarantees `data` is valid for every element
        // addressed by this shape and these strides, for the lifetime `'a`.
        unsafe { faer::mat::MatRef::from_raw_parts(data, nrows, ncols, row_stride, col_stride) }
    }
}
#[cfg(feature = "with-faer")]
impl<'a, T> IntoMataRef<T> for faer::MatRef<'a, T> {
    /// Describes this `faer` view as a `MataRef`.
    ///
    /// `trans_strides` is always set and `trans_dims` is always cleared.
    fn into_mata_ref(self) -> MataRef<T> {
        let shape = self.shape();
        let strides = (self.row_stride(), self.col_stride());
        MataRef {
            data: self.as_ptr(),
            shape,
            strides,
            trans_strides: true,
            trans_dims: false,
        }
    }
}
#[cfg(feature = "with-nalgebra")]
impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> FromMataRef<T>
    for nalgebra::DMatrixView<'a, T>
{
    /// Builds a `nalgebra` view over the memory described by `mat`.
    ///
    /// `nalgebra::DMatrixView<'a, T>` uses its default row-stride parameter `U1`, so the produced
    /// view's row stride must be exactly one element. Two layouts are accepted:
    ///
    /// * `mat.strides.0 == 1`: the view has shape `mat.shape` and column stride `mat.strides.1`.
    /// * otherwise `mat.strides.1` must be 1. The view is then the **transpose** of the matrix
    ///   described by `mat`: shape `(mat.shape.1, mat.shape.0)`, column stride `mat.strides.0`.
    ///   For example, a row-major `ndarray` view arrives transposed.
    ///
    /// Dimensions and strides are always swapped together. Every element therefore keeps its
    /// original address, and rectangular matrices are supported.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// * neither stride is 1;
    /// * the other stride is negative;
    /// * nalgebra's own bounds check against the `compute_len(mat.shape, mat.strides)`-element
    ///   slice fails.
    fn from_mata_ref(mat: MataRef<T>) -> nalgebra::DMatrixView<'a, T> {
        let (rows, cols) = mat.shape;
        let (row_stride, col_stride) = mat.strides;
        // Decide from the strides themselves, not the `trans_*` hints. Sources such as nalgebra
        // and mdarray leave those hints `false` even when the row stride is already 1. Swapping
        // the strides in that case feeds a non-unit value into the `U1` row stride, which panics.
        let (na_rows, na_cols, unit_stride, outer_stride) = if row_stride == 1 {
            (rows, cols, row_stride, col_stride)
        } else {
            (cols, rows, col_stride, row_stride)
        };
        assert!(
            unit_stride == 1 && outer_stride >= 0,
            "MataRef: nalgebra::DMatrixView needs a unit stride along one axis and a \
             non-negative stride along the other, got strides {:?}",
            mat.strides
        );
        let len = compute_len(mat.shape, mat.strides);
        // SAFETY: assumes the producer of `mat` guarantees `data` is valid for reads of `len`
        // elements for the lifetime `'a`.
        let s = unsafe { core::slice::from_raw_parts(mat.data, len) };
        nalgebra::base::MatrixView::from_slice_with_strides_generic(
            s,
            nalgebra::Dim::from_usize(na_rows),
            nalgebra::Dim::from_usize(na_cols),
            nalgebra::Dim::from_usize(1),
            nalgebra::Dim::from_usize(outer_stride as usize),
        )
    }
}
#[cfg(feature = "with-nalgebra")]
impl<'a, T> IntoMataRef<T> for nalgebra::DMatrixView<'a, T> {
    /// Describes this `nalgebra` view as a `MataRef`.
    ///
    /// nalgebra reports `(row_stride, col_stride)` as `usize`; they are cast to `isize`. Both
    /// layout hints are left `false`.
    fn into_mata_ref(self) -> MataRef<T> {
        let shape = self.shape();
        let (row_step, col_step) = self.strides();
        MataRef {
            data: self.as_ptr(),
            shape,
            strides: (row_step as isize, col_step as isize),
            trans_strides: false,
            trans_dims: false,
        }
    }
}
/// One-call conversions from a 2-D view into the view types of every enabled backend crate.
///
/// Each item exists only when its backend's `with-*` feature is enabled. The impls in this
/// module route every conversion through `MataRef`.
pub trait MataConvertRef<'a, T> {
    /// Memory layout of the view returned by [`MataConvertRef::to_mdarray`].
    #[cfg(feature = "with-mdarray")]
    type MdarrayLayout: mdarray::Layout;

    /// Converts into an `mdarray` view whose dimension types are chosen by the caller.
    #[cfg(feature = "with-mdarray")]
    fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
        self,
    ) -> mdarray::View<'a, T, (D0, D1), Self::MdarrayLayout>;

    /// Converts into an `ndarray` 2-D view.
    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayView2<'a, T>;

    /// Converts into a `faer` matrix view.
    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatRef<'a, T>;

    /// Converts into a `nalgebra` dynamically sized matrix view.
    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixView<'a, T>;
}
/// Implements `MataConvertRef` for a source view type.
///
/// * `$source_type`: the view type converted from. It may mention the lifetime `'a` and the
///   element type `T`, both of which the generated `impl` declares.
/// * `$feature`: the Cargo feature that enables `$source_type` and its `IntoMataRef` impl.
/// * `$mata_type`: the intermediate descriptor type produced by `into_mata_ref`.
///
/// Each generated item carries the same `#[cfg(feature = ...)]` as the matching item on the
/// `MataConvertRef` trait. Without those attributes, enabling only some `with-*` features breaks
/// the build in two ways:
/// * the impl defines items (`MdarrayLayout`, `to_mdarray`, ...) that the trait does not declare;
/// * it names crates that are not linked.
macro_rules! impl_ref_conversions {
    ($source_type:ty, $feature:literal, $mata_type:ty) => {
        #[cfg(feature = $feature)]
        impl<'a, T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> MataConvertRef<'a, T>
            for $source_type
        {
            #[cfg(feature = "with-mdarray")]
            type MdarrayLayout = mdarray::Strided;

            #[cfg(feature = "with-ndarray")]
            fn to_ndarray(self) -> ndarray::ArrayView2<'a, T> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }

            #[cfg(feature = "with-mdarray")]
            fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
                self,
            ) -> mdarray::View<'a, T, (D0, D1), mdarray::Strided> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }

            #[cfg(feature = "with-faer")]
            fn to_faer(self) -> faer::MatRef<'a, T> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }

            #[cfg(feature = "with-nalgebra")]
            fn to_nalgebra(self) -> nalgebra::DMatrixView<'a, T> {
                let mata: $mata_type = self.into_mata_ref();
                FromMataRef::from_mata_ref(mata)
            }
        }
    };
}
/// `MataConvertRef` for 2-D `mdarray` views of any layout that can describe itself as a
/// `MataRef`.
///
/// `to_mdarray` always produces a `Strided` view, so any stride pattern can be represented.
/// The conversions to other crates are compiled only when the matching `with-*` feature is
/// enabled, mirroring the `#[cfg]` attributes on the trait items.
#[cfg(feature = "with-mdarray")]
impl<
        'a,
        T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static,
        D0: mdarray::Dim,
        D1: mdarray::Dim,
        L: mdarray::Layout,
    > MataConvertRef<'a, T> for mdarray::View<'a, T, (D0, D1), L>
where
    Self: IntoMataRef<T>,
{
    type MdarrayLayout = mdarray::Strided;

    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(self) -> ndarray::ArrayView2<'a, T> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }

    fn to_mdarray<D0out: mdarray::Dim, D1out: mdarray::Dim>(
        self,
    ) -> mdarray::View<'a, T, (D0out, D1out), mdarray::Strided> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }

    #[cfg(feature = "with-faer")]
    fn to_faer(self) -> faer::MatRef<'a, T> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }

    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(self) -> nalgebra::DMatrixView<'a, T> {
        let mata: MataRef<T> = self.into_mata_ref();
        FromMataRef::from_mata_ref(mata)
    }
}
// `MataConvertRef` impls for the non-mdarray view types. Each expansion is compiled only when
// the feature named in its second argument is enabled; mdarray views have a hand-written impl
// because they are generic over their layout.
impl_ref_conversions!(faer::MatRef<'a, T>, "with-faer", MataRef<T>);
impl_ref_conversions!(nalgebra::DMatrixView<'a, T>, "with-nalgebra", MataRef<T>);
impl_ref_conversions!(ndarray::ArrayView2<'a, T>, "with-ndarray", MataRef<T>);