use crate::prelude::*;
use duplicate::duplicate_item;
use ndarray::prelude::*;
use ndarray::Data;
use opt_einsum_path::typing::SizeLimitType;
use opt_einsum_path::PathOptimizer;
/// Zero-copy conversion from an `ndarray` array (taken either by value or by
/// reference — the `duplicate_item` macro stamps out both impls) into a
/// `TblisTensor` view over the same memory.
///
/// NOTE(review): no data is copied. The returned tensor aliases the array's
/// buffer, so it must not be used after the array is dropped — TODO confirm
/// `TblisTensor` carries no lifetime to enforce this.
#[duplicate_item(ArrayItem; [ArrayBase<S, D>]; [&ArrayBase<S, D>])]
impl<S, T, D> ToTblisTensor<T> for ArrayItem
where
S: Data<Elem = T>,
T: TblisFloatAPI,
D: Dimension,
{
fn to_tblis_tensor(&self) -> TblisTensor<T> {
// Borrowed view: gives uniform access to shape/strides/pointer for any
// storage kind `S`.
let view = self.view();
// TBLIS expects extents as `isize`.
let shape = view.shape().iter().map(|&s| s as isize).collect::<Vec<isize>>();
// ndarray strides are already in element units and `isize`-typed.
let stride = ArrayBase::strides(&view).to_vec();
// NOTE(review): casting away const here — presumably sound because TBLIS
// only reads input operands, but verify against the TBLIS call sites.
let data_ptr = view.as_ptr() as *mut T;
TblisTensor::new(data_ptr, &shape, &stride)
}
}
/// Conversion of a TBLIS result back into an `ndarray` value.
pub trait ArrayFromTblisTensor {
/// The concrete array type produced by the conversion.
type Out;
/// Consumes `self` and builds the output array.
fn into_array(self) -> Self::Out;
}
/// Rebuilds an owned dynamic-dimensional array from the flat data buffer and
/// shape descriptor pair produced by a TBLIS einsum call.
impl<T> ArrayFromTblisTensor for (Vec<T>, TblisTensor<T>)
where
    T: TblisFloatAPI,
{
    type Out = ArrayD<T>;

    /// Consumes the `(buffer, tensor)` pair and returns an `ArrayD` with the
    /// tensor's shape.
    ///
    /// NOTE(review): the tensor's stride field is ignored — this assumes the
    /// buffer is in the default (row-major) layout `from_shape_vec` expects;
    /// confirm against how `tblis_einsum` fills the output buffer.
    ///
    /// # Panics
    /// Panics if the buffer length does not equal the product of the shape
    /// extents (an invariant violation by the producer of the pair).
    fn into_array(self) -> ArrayD<T> {
        let (vec, tsr) = self;
        // TBLIS stores extents as `isize`; ndarray wants `usize`.
        let shape = tsr.shape.iter().map(|&s| s as usize).collect::<Vec<usize>>();
        ArrayD::from_shape_vec(IxDyn(&shape), vec)
            .expect("buffer length must equal the product of the tensor's shape extents")
    }
}
pub fn array_from_tblis_tensor<T>(dat: (Vec<T>, TblisTensor<T>)) -> ArrayD<T>
where
T: TblisFloatAPI,
{
dat.into_array()
}
/// Panicking variant of [`tblis_einsum_ndarray_f`]: performs the einsum
/// contraction on `ndarray` operands and unwraps the result.
///
/// Returns `Some` with a newly allocated result array, or `None` when the
/// contraction wrote into the caller-supplied `out` view instead.
///
/// # Panics
/// Panics with the error string when the underlying fallible call fails.
pub fn tblis_einsum_ndarray<T, A>(
    subscripts: &str,
    operands: &[&A],
    optimize: impl PathOptimizer,
    memory_limit: impl Into<SizeLimitType>,
    row_major: bool,
    out: Option<ArrayViewMutD<T>>,
) -> Option<ArrayD<T>>
where
    T: TblisFloatAPI,
    A: ToTblisTensor<T>,
{
    let result =
        tblis_einsum_ndarray_f(subscripts, operands, optimize, memory_limit, row_major, out);
    result.unwrap()
}
pub fn tblis_einsum_ndarray_f<T, A>(
subscripts: &str,
operands: &[&A],
optimize: impl PathOptimizer,
memory_limit: impl Into<SizeLimitType>,
row_major: bool,
out: Option<ArrayViewMutD<T>>,
) -> Result<Option<ArrayD<T>>, String>
where
T: TblisFloatAPI,
A: ToTblisTensor<T>,
{
let tblis_operands: Vec<TblisTensor<T>> = operands.iter().map(|x| x.to_tblis_tensor()).collect();
let tblis_operands_ref: Vec<&TblisTensor<T>> = tblis_operands.iter().collect();
let mut out_tblis_tensor = out.map(|x| x.to_tblis_tensor());
let res = unsafe {
tblis_einsum_f(subscripts, &tblis_operands_ref, optimize, memory_limit, row_major, out_tblis_tensor.as_mut())
};
match res {
Ok(Some(out)) => Ok(Some(out.into_array())),
Ok(None) => Ok(None),
Err(e) => Err(e),
}
}
#[cfg(test)]
#[cfg(feature = "ndarray")]
mod test_ndarray_native {
    /// Smoke test: AO->MO integral transformation through the high-level
    /// ndarray-native einsum wrapper.
    #[test]
    #[allow(clippy::let_and_return)]
    fn test_ndarray_native() {
        use crate::prelude::*;
        use ndarray::prelude::*;

        let (nao, nmo): (usize, usize) = (3, 2);
        // Deterministic fixtures: 0, 1, 2, ... filled into the coefficient
        // matrix and the rank-4 integral tensor.
        let buf_c: Vec<f64> = (0..nao * nmo).map(|i| i as f64).collect();
        let buf_e: Vec<f64> = (0..nao.pow(4)).map(|i| i as f64).collect();
        let coeff = ArrayView2::from_shape((nao, nmo), &buf_c).unwrap();
        let eri = ArrayView4::from_shape((nao, nao, nao, nao), &buf_e).unwrap();

        fn ao2mo(coeff: ArrayView2<f64>, eri: ArrayView4<f64>) -> Array4<f64> {
            // The wrapper takes dynamic-dimensional views.
            let c = coeff.view().into_dyn();
            let e = eri.view().into_dyn();
            let ops = [&c, &c, &e, &c, &c];
            let g = tblis_einsum_ndarray(
                "μi,νa,μνκλ,κj,λb->iajb", &ops, "optimal", None, true, None, )
            .unwrap();
            g.into_dimensionality().unwrap()
        }

        let transformed = ao2mo(coeff, eri);
        println!("{:?}", transformed);
    }

    /// Same transformation through the lower-level `TblisTensor` interface.
    #[test]
    #[allow(clippy::let_and_return)]
    fn test_ndarray_workable() {
        use crate::prelude::*;
        use ndarray::prelude::*;

        let (nao, nmo): (usize, usize) = (3, 2);
        let buf_c: Vec<f64> = (0..nao * nmo).map(|i| i as f64).collect();
        let buf_e: Vec<f64> = (0..nao.pow(4)).map(|i| i as f64).collect();
        let coeff = ArrayView2::from_shape((nao, nmo), &buf_c).unwrap();
        let eri = ArrayView4::from_shape((nao, nao, nao, nao), &buf_e).unwrap();

        fn ao2mo(coeff: ArrayView2<f64>, eri: ArrayView4<f64>) -> Array4<f64> {
            // Build zero-copy tensor views over the ndarray buffers.
            let t_c = coeff.to_tblis_tensor();
            let t_e = eri.to_tblis_tensor();
            let ops = [&t_c, &t_c, &t_e, &t_c, &t_c];
            let result = unsafe {
                tblis_einsum(
                "μi,νa,μνκλ,κj,λb->iajb", &ops, "optimal", None, true, None, )
            };
            // Convert the (buffer, tensor) pair back to a typed Array4.
            let pair = result.unwrap();
            let g = pair.into_array().into_dimensionality().unwrap();
            g
        }

        let transformed = ao2mo(coeff, eri);
        println!("{:?}", transformed);
    }
}