use crate::Float;
use ndarray::{ArrayBase, Data, Dimension, OwnedRepr, ViewRepr};
pub trait WithLapack<D: Data + WithLapackData, I: Dimension> {
fn with_lapack(self) -> ArrayBase<D::D, I>;
}
impl<F: Float, D, I> WithLapack<D, I> for ArrayBase<D, I>
where
D: Data<Elem = F> + WithLapackData,
I: Dimension,
{
fn with_lapack(self) -> ArrayBase<D::D, I> {
D::with_lapack(self)
}
}
pub trait WithoutLapack<F: Float, D: Data + WithoutLapackData<F>, I: Dimension> {
fn without_lapack(self) -> ArrayBase<D::D, I>;
}
impl<F: Float, D, I> WithoutLapack<F, D, I> for ArrayBase<D, I>
where
D: Data<Elem = F::Lapack> + WithoutLapackData<F>,
I: Dimension,
{
fn without_lapack(self) -> ArrayBase<D::D, I> {
D::without_lapack(self)
}
}
unsafe fn transmute<A, B>(a: A) -> B {
let b = std::ptr::read(&a as *const A as *const B);
std::mem::forget(a);
b
}
pub trait WithLapackData
where
Self: Data,
{
type D: Data;
fn with_lapack<I>(x: ArrayBase<Self, I>) -> ArrayBase<Self::D, I>
where
I: Dimension,
{
unsafe { transmute(x) }
}
}
impl<F: Float> WithLapackData for OwnedRepr<F> {
type D = OwnedRepr<F::Lapack>;
}
impl<'a, F: Float> WithLapackData for ViewRepr<&'a F> {
type D = ViewRepr<&'a F::Lapack>;
}
impl<'a, F: Float> WithLapackData for ViewRepr<&'a mut F> {
type D = ViewRepr<&'a mut F::Lapack>;
}
pub trait WithoutLapackData<F: Float>
where
Self: Data,
{
type D: Data<Elem = F>;
fn without_lapack<I>(x: ArrayBase<Self, I>) -> ArrayBase<Self::D, I>
where
I: Dimension,
{
unsafe { transmute(x) }
}
}
impl<F: Float> WithoutLapackData<F> for OwnedRepr<F::Lapack> {
type D = OwnedRepr<F>;
}
impl<'a, F: Float> WithoutLapackData<F> for ViewRepr<&'a F::Lapack> {
type D = ViewRepr<&'a F>;
}
impl<'a, F: Float> WithoutLapackData<F> for ViewRepr<&'a mut F::Lapack> {
type D = ViewRepr<&'a mut F>;
}
#[cfg(test)]
mod tests {
use ndarray::Array2;
#[cfg(feature = "ndarray-linalg")]
use ndarray_linalg::eig::*;
use super::{WithLapack, WithoutLapack};
#[test]
fn memory_check() {
let a: Array2<f32> = Array2::zeros((20, 20));
let a: Array2<f32> = a.with_lapack();
assert_eq!(a.shape(), &[20, 20]);
let b: Array2<f32> = a.clone().without_lapack();
assert_eq!(b, a);
}
#[cfg(feature = "ndarray-linalg")]
#[test]
fn lapack_exists() {
let a: Array2<f32> = Array2::zeros((4, 4));
let a: Array2<f32> = a.with_lapack();
let (_a, _b) = a.eig().unwrap();
}
}