#[cfg(feature = "with-ndarray")]
use ndarray::ShapeBuilder;
/// An owned 2-D matrix buffer used as the hand-off format between matrix libraries.
///
/// `data` points to `shape.0 * shape.1` contiguous elements. The holder of a `MataOwn`
/// is responsible for that buffer. This type has no `Drop` impl, so dropping it without
/// passing it to a [`FromMataOwn`] implementation leaks the elements and their allocation.
// Fields are only read by the feature-gated conversion impls; with no backend feature
// enabled they would otherwise trigger `dead_code` warnings.
#[allow(dead_code)]
#[derive(Debug)]
pub struct MataOwn<T> {
    /// Pointer to the first element of the element buffer.
    pub data: *mut T,
    /// Matrix dimensions as `(rows, cols)`.
    pub shape: (usize, usize),
    /// Element strides `(row_stride, col_stride)`, intended to describe the buffer `data`
    /// points to.
    pub strides: (isize, isize),
    /// `true` if `data` is stored row-major (C order), `false` if column-major.
    pub row_major: bool,
}
/// Builds a matrix value by taking ownership of the buffer described by a [`MataOwn`].
///
/// NOTE(review): the implementations in this file rebuild a `Vec<T>` with
/// `Vec::from_raw_parts(mat.data, rows * cols, rows * cols)`. A `mat` passed in must
/// therefore come from an `IntoMataOwn` implementation, or otherwise uphold the same
/// allocation invariant. The method is safe to call, but that invariant is not checked.
/// Consider an `unsafe fn` (API change).
pub trait FromMataOwn<T>: Sized {
/// Consumes `mat` and returns a `Self` holding the same elements.
fn from_mata_own(mat: MataOwn<T>) -> Self;
}
/// Copies a matrix into a freshly allocated [`MataOwn`] buffer.
///
/// NOTE(review): despite the `into_` prefix, this borrows `self` and clones every element.
/// The caller owns the returned buffer and should hand it to a [`FromMataOwn`]
/// implementation; otherwise the buffer leaks.
pub trait IntoMataOwn<T> {
/// Returns an owned, packed copy of `self`'s elements together with its layout metadata.
fn into_mata_own(&self) -> MataOwn<T>;
}
#[cfg(feature = "with-ndarray")]
impl<T: Clone> FromMataOwn<T> for ndarray::Array2<T> {
    /// Rebuilds an `Array2` that adopts the buffer owned by `mat` without copying.
    ///
    /// The memory order (C vs. Fortran) is taken from `mat.row_major`, the same field the
    /// nalgebra and faer consumers use. The `isize` strides are deliberately not cast to
    /// `usize`, which would wrap for negative values.
    ///
    /// # Panics
    ///
    /// Panics if ndarray rejects the shape. This cannot happen for a buffer of exactly
    /// `rows * cols` elements.
    fn from_mata_own(mat: MataOwn<T>) -> ndarray::Array2<T> {
        let (rows, cols) = mat.shape;
        let len = rows * cols;
        // SAFETY: assumes `mat.data` owns exactly `rows * cols` initialised elements,
        // allocated by the global allocator as a `Vec<T>`/`Box<[T]>` of that exact length.
        // This is the layout produced by the `IntoMataOwn` impls in this file. Ownership of
        // the buffer moves into `data`.
        let data = unsafe { Vec::from_raw_parts(mat.data, len, len) };
        // `set_f(true)` selects column-major (Fortran) layout; both layouts are packed.
        let shape = (rows, cols).set_f(!mat.row_major);
        ndarray::Array2::from_shape_vec(shape, data)
            .expect("MataOwn: invalid shape for contiguous data")
    }
}
#[cfg(feature = "with-ndarray")]
impl<T: Clone> IntoMataOwn<T> for ndarray::Array2<T> {
    /// Copies the array into a freshly allocated, packed **row-major** buffer.
    ///
    /// The metadata always describes that packed copy: strides are `(cols, 1)`,
    /// whatever the memory layout of `self` (C order, Fortran order, or a sliced view).
    fn into_mata_own(&self) -> MataOwn<T> {
        let (rows, cols) = self.dim();
        // `iter()` visits elements in logical index order (row-major), independent of
        // the source's memory layout, so the copy is packed C order.
        let data: Vec<T> = self.iter().cloned().collect();
        // Converting through `Box<[T]>` guarantees capacity == len == rows * cols. The
        // consumers rely on that when they call `Vec::from_raw_parts(ptr, len, len)`.
        let ptr = Box::into_raw(data.into_boxed_slice()).cast::<T>();
        MataOwn {
            data: ptr,
            shape: (rows, cols),
            // Strides of the packed copy, not of `self` (the two differ for F-order input).
            strides: (cols as isize, 1),
            row_major: true,
        }
    }
}
#[cfg(feature = "with-nalgebra")]
impl<T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> FromMataOwn<T>
    for nalgebra::DMatrix<T>
{
    /// Rebuilds a `DMatrix` from the buffer owned by `mat`.
    ///
    /// Column-major input is handed to `DMatrix::from_vec` by value. Row-major input is
    /// re-laid-out with `DMatrix::from_fn`, which clones each element; the source buffer
    /// is dropped afterwards.
    fn from_mata_own(mat: MataOwn<T>) -> nalgebra::DMatrix<T> {
        let mat = std::mem::ManuallyDrop::new(mat);
        let (nrows, ncols) = mat.shape;
        let count = nrows * ncols;
        // SAFETY: assumes `mat.data` owns exactly `nrows * ncols` initialised elements,
        // allocated as a `Vec<T>`/`Box<[T]>` of that exact length (the layout produced by
        // the `IntoMataOwn` impls in this file). Ownership moves into `elements`.
        let elements = unsafe { Vec::from_raw_parts(mat.data, count, count) };
        if !mat.row_major {
            return nalgebra::DMatrix::from_vec(nrows, ncols, elements);
        }
        nalgebra::DMatrix::from_fn(nrows, ncols, |r, c| elements[r * ncols + c].clone())
    }
}
#[cfg(feature = "with-nalgebra")]
impl<T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static> IntoMataOwn<T>
    for nalgebra::DMatrix<T>
{
    /// Copies the matrix into a freshly allocated, packed **column-major** buffer.
    fn into_mata_own(&self) -> MataOwn<T> {
        let (rows, cols) = self.shape();
        // nalgebra's `iter()` visits elements in column-major order, matching
        // `row_major: false` below.
        let data: Vec<T> = self.iter().cloned().collect();
        // Converting through `Box<[T]>` guarantees capacity == len == rows * cols. The
        // consumers rely on that when they call `Vec::from_raw_parts(ptr, len, len)`.
        let ptr = Box::into_raw(data.into_boxed_slice()).cast::<T>();
        MataOwn {
            data: ptr,
            shape: (rows, cols),
            // Strides of the packed column-major copy.
            strides: (1, rows as isize),
            row_major: false,
        }
    }
}
#[cfg(feature = "with-mdarray")]
impl<T: Clone, D0: mdarray::Dim, D1: mdarray::Dim> FromMataOwn<T>
    for mdarray::Tensor<T, (D0, D1)>
{
    /// Rebuilds a 2-D `Tensor` from the buffer owned by `mat`.
    ///
    /// This impl treats mdarray's dense layout as row-major, the same assumption the
    /// `Tensor` `IntoMataOwn` impl makes. Column-major input (e.g. from nalgebra) is
    /// reordered first. The elements are then cloned into the new tensor, and the
    /// intermediate buffer is freed.
    fn from_mata_own(mat: MataOwn<T>) -> mdarray::Tensor<T, (D0, D1)> {
        let (rows, cols) = mat.shape;
        let len = rows * cols;
        // SAFETY: assumes `mat.data` owns exactly `rows * cols` initialised elements,
        // allocated as a `Vec<T>`/`Box<[T]>` of that exact length (the layout produced by
        // the `IntoMataOwn` impls in this file). Ownership moves into `buffer`.
        let buffer = unsafe { Vec::from_raw_parts(mat.data, len, len) };
        let mut row_major: Vec<T> = if mat.row_major {
            buffer
        } else {
            // Row-major slot k = i * cols + j takes column-major element j * rows + i.
            // When cols == 0, len == 0 and the closure never runs, so `k % cols` is safe.
            (0..len)
                .map(|k| buffer[(k % cols) * rows + k / cols].clone())
                .collect()
        };
        let shape = <(D0, D1) as mdarray::Shape>::from_dims(&[rows, cols]);
        let map = mdarray::DenseMapping::new(shape);
        // SAFETY: `row_major` holds exactly `rows * cols` initialised elements in dense
        // row-major order and outlives the temporary view, which is only read by
        // `to_owned`.
        let tensor = unsafe {
            mdarray::ViewMut::<T, (D0, D1), mdarray::Dense>::new_unchecked(
                row_major.as_mut_ptr(),
                map,
            )
            .to_owned()
        };
        // `to_owned` cloned every element, so the intermediate buffer is released here
        // instead of leaking.
        drop(row_major);
        tensor
    }
}
#[cfg(feature = "with-mdarray")]
impl<T: Clone, D0: mdarray::Dim, D1: mdarray::Dim> IntoMataOwn<T>
    for mdarray::Tensor<T, (D0, D1)>
{
    /// Copies the tensor into a freshly allocated, packed **row-major** buffer.
    fn into_mata_own(&self) -> MataOwn<T> {
        let shape = *self.shape();
        let (rows, cols) = (shape.0.size(), shape.1.size());
        let mut data = Vec::with_capacity(rows * cols);
        for i in 0..rows {
            for j in 0..cols {
                data.push(self[[i, j]].clone());
            }
        }
        // `with_capacity` may over-allocate. Converting through `Box<[T]>` guarantees
        // capacity == len, which consumers rely on in `Vec::from_raw_parts(ptr, len, len)`.
        let ptr = Box::into_raw(data.into_boxed_slice()).cast::<T>();
        MataOwn {
            data: ptr,
            shape: (rows, cols),
            strides: (cols as isize, 1),
            row_major: true,
        }
    }
}
#[cfg(feature = "with-faer")]
impl<T: Clone> FromMataOwn<T> for faer::Mat<T> {
    /// Rebuilds a `faer::Mat` from the buffer owned by `mat`.
    ///
    /// Every element is cloned into the new matrix through `Mat::from_fn`, and the source
    /// buffer is dropped afterwards. Both row-major and column-major input are accepted.
    fn from_mata_own(mat: MataOwn<T>) -> faer::Mat<T> {
        let mat = std::mem::ManuallyDrop::new(mat);
        let (nrows, ncols) = mat.shape;
        let count = nrows * ncols;
        // SAFETY: assumes `mat.data` owns exactly `nrows * ncols` initialised elements,
        // allocated as a `Vec<T>`/`Box<[T]>` of that exact length (the layout produced by
        // the `IntoMataOwn` impls in this file). Ownership moves into `flat`.
        let flat = unsafe { Vec::from_raw_parts(mat.data, count, count) };
        let row_major = mat.row_major;
        // Linear offset of element (i, j) in `flat` for the input's memory order.
        let offset = move |i: usize, j: usize| {
            if row_major {
                i * ncols + j
            } else {
                j * nrows + i
            }
        };
        faer::Mat::from_fn(nrows, ncols, |i, j| flat[offset(i, j)].clone())
    }
}
#[cfg(feature = "with-faer")]
impl<T: Clone> IntoMataOwn<T> for faer::Mat<T> {
    /// Copies the matrix into a freshly allocated, packed **row-major** buffer.
    fn into_mata_own(&self) -> MataOwn<T> {
        let (rows, cols) = self.shape();
        let mut data = Vec::with_capacity(rows * cols);
        for i in 0..rows {
            for j in 0..cols {
                data.push(self[(i, j)].clone());
            }
        }
        // `with_capacity` may over-allocate. Converting through `Box<[T]>` guarantees
        // capacity == len, which consumers rely on in `Vec::from_raw_parts(ptr, len, len)`.
        let ptr = Box::into_raw(data.into_boxed_slice()).cast::<T>();
        MataOwn {
            data: ptr,
            shape: (rows, cols),
            strides: (cols as isize, 1),
            row_major: true,
        }
    }
}
/// Owned conversions from one matrix type into each enabled backend's matrix type.
///
/// Each method exists only when its backend's cargo feature is enabled. Implementors
/// must gate their method definitions with the same `#[cfg(feature = ...)]`. Every
/// method copies the elements (via [`IntoMataOwn`] then [`FromMataOwn`]).
pub trait MataConvertOwn<T> {
/// Returns an owned `mdarray::Tensor` copy of `self`.
#[cfg(feature = "with-mdarray")]
fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(&self) -> mdarray::Tensor<T, (D0, D1)>;
/// Returns an owned `ndarray::Array2` copy of `self`.
#[cfg(feature = "with-ndarray")]
fn to_ndarray(&self) -> ndarray::Array2<T>;
/// Returns an owned `nalgebra::DMatrix` copy of `self`.
#[cfg(feature = "with-nalgebra")]
fn to_nalgebra(&self) -> nalgebra::DMatrix<T>;
/// Returns an owned `faer::Mat` copy of `self`.
#[cfg(feature = "with-faer")]
fn to_faer(&self) -> faer::Mat<T>;
}
/// Implements [`MataConvertOwn`] for `$source_type`, compiled only when `$feature` is on.
///
/// Each generated method carries the same `#[cfg]` as its declaration in
/// `MataConvertOwn`. The impl therefore provides exactly the methods that exist for the
/// enabled feature set, and never names a backend crate whose feature is off.
macro_rules! impl_owned_conversions {
    ($source_type:ty, $feature:literal) => {
        #[cfg(feature = $feature)]
        impl<T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static + num_traits::Zero>
            MataConvertOwn<T> for $source_type
        {
            #[cfg(feature = "with-mdarray")]
            fn to_mdarray<D0: mdarray::Dim, D1: mdarray::Dim>(
                &self,
            ) -> mdarray::Tensor<T, (D0, D1)> {
                let mat_own: MataOwn<T> = self.into_mata_own();
                FromMataOwn::from_mata_own(mat_own)
            }
            #[cfg(feature = "with-ndarray")]
            fn to_ndarray(&self) -> ndarray::Array2<T> {
                let mat_own: MataOwn<T> = self.into_mata_own();
                FromMataOwn::from_mata_own(mat_own)
            }
            #[cfg(feature = "with-nalgebra")]
            fn to_nalgebra(&self) -> nalgebra::DMatrix<T> {
                let mat_own: MataOwn<T> = self.into_mata_own();
                FromMataOwn::from_mata_own(mat_own)
            }
            #[cfg(feature = "with-faer")]
            fn to_faer(&self) -> faer::Mat<T> {
                let mat_own: MataOwn<T> = self.into_mata_own();
                FromMataOwn::from_mata_own(mat_own)
            }
        }
    };
}
/// Owned conversions out of a 2-D `mdarray::Tensor`.
///
/// This impl is written by hand rather than with `impl_owned_conversions!` because the
/// source type carries the extra `D0`/`D1` dimension parameters. Each conversion other
/// than `to_mdarray` is gated on its backend feature, mirroring the trait declaration.
#[cfg(feature = "with-mdarray")]
impl<
        T: Clone + std::fmt::Debug + std::cmp::PartialEq + 'static + num_traits::Zero,
        D0: mdarray::Dim,
        D1: mdarray::Dim,
    > MataConvertOwn<T> for mdarray::Tensor<T, (D0, D1)>
{
    fn to_mdarray<D0out: mdarray::Dim, D1out: mdarray::Dim>(
        &self,
    ) -> mdarray::Tensor<T, (D0out, D1out)> {
        let mat_own: MataOwn<T> = self.into_mata_own();
        FromMataOwn::from_mata_own(mat_own)
    }
    #[cfg(feature = "with-ndarray")]
    fn to_ndarray(&self) -> ndarray::Array2<T> {
        let mat_own: MataOwn<T> = self.into_mata_own();
        FromMataOwn::from_mata_own(mat_own)
    }
    #[cfg(feature = "with-nalgebra")]
    fn to_nalgebra(&self) -> nalgebra::DMatrix<T> {
        let mat_own: MataOwn<T> = self.into_mata_own();
        FromMataOwn::from_mata_own(mat_own)
    }
    #[cfg(feature = "with-faer")]
    fn to_faer(&self) -> faer::Mat<T> {
        let mat_own: MataOwn<T> = self.into_mata_own();
        FromMataOwn::from_mata_own(mat_own)
    }
}
// `MataConvertOwn` impls for the backends whose matrix type has no extra generic
// parameters. Each expansion is compiled only when the named feature is enabled.
impl_owned_conversions! { ndarray::Array2<T>, "with-ndarray" }
impl_owned_conversions! { nalgebra::DMatrix<T>, "with-nalgebra" }
impl_owned_conversions! { faer::Mat<T>, "with-faer" }