#![allow(clippy::excessive_precision)]
extern crate tblis_src;
use num::complex::ComplexFloat;
use rstest::rstest;
use tblis::prelude::*;
/// Double-precision complex number; the element type of every buffer and
/// tensor built by the helpers in this file.
///
/// The lowercase name does not follow Rust's `UpperCamelCase` convention for
/// types, so the corresponding lint is silenced for this alias only.
#[allow(non_camel_case_types)]
type c64 = num::Complex<f64>;
/// Produces deterministic complex test data for a tensor of the given shape.
///
/// The number of elements is the product of all extents in `shape`; an empty
/// shape therefore yields exactly one element. Element `k` is
/// `cos(k + 0.2) - 0.5 * sin(k + 0.4) * i`.
fn gen_array(shape: &[isize]) -> Vec<c64> {
    let count: isize = shape.iter().product();
    let mut values = Vec::new();
    for k in 0..count {
        let t = k as f64;
        let re = (t + 0.2).cos();
        let im = -0.5 * (t + 0.4).sin();
        values.push(c64::new(re, im));
    }
    values
}
/// Computes a scalar "fingerprint" of a buffer: the sum over all positions
/// `k` of `cos(k) * vec[k]`.
///
/// Weighting by position makes the result sensitive to element order, not
/// only to the multiset of values.
fn fp(vec: &[c64]) -> c64 {
    // One real weight per position, paired with the element at that position.
    let weights = (0..vec.len()).map(|k| (k as f64).cos());
    weights.zip(vec).map(|(w, &value)| w * value).sum()
}
/// Translates a string of index labels into tensor extents, one per character.
///
/// A label's extent is the entry of the size table at the label's position in
/// the label alphabet. Characters outside the alphabet get extent `1`.
fn build_shape(s: &str) -> Vec<isize> {
    // Label alphabet; a character's position here selects its extent below.
    const LABELS: &str = "abcdefghijklmnopqABC";
    // The table is longer than the alphabet, so only its first
    // `LABELS.len()` entries can ever be selected.
    const EXTENTS: [isize; 27] =
        [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4, 9, 10, 2, 4, 5, 3, 2, 6];

    let mut shape = Vec::with_capacity(s.len());
    for label in s.chars() {
        let extent = match LABELS.find(label) {
            Some(pos) => EXTENTS[pos],
            None => 1,
        };
        shape.push(extent);
    }
    shape
}
/// Computes row-major strides for `shape`, measured in elements.
///
/// The last axis has stride `1`; every other axis has the stride of the axis
/// after it multiplied by that axis's extent. An empty shape gives an empty
/// stride vector.
fn build_strides(shape: &[isize]) -> Vec<isize> {
    let mut strides = vec![1isize; shape.len()];
    // Walk from the innermost axis outwards; axis 0 has no predecessor to
    // fill in, so it is skipped (this also makes the empty shape a no-op).
    for (axis, &extent) in shape.iter().enumerate().skip(1).rev() {
        strides[axis - 1] = strides[axis] * extent;
    }
    strides
}
/// Builds a test tensor for the index labels in `s`.
///
/// The shape comes from `build_shape`, the strides from `build_strides`, and
/// the contents from `gen_array`. Returns the backing buffer together with
/// the `TblisTensor` created from a raw pointer into that buffer.
///
/// The tensor is constructed from `as_mut_ptr()`, so the caller should keep
/// the returned `Vec` alive and untouched for as long as the tensor is used
/// (presumably the tensor does not own the data; confirm against
/// `TblisTensor::new`). Moving the `Vec` out of this function does not
/// relocate its heap allocation, so the pointer stays valid.
fn build_tblis_tensor(s: &str) -> (Vec<c64>, TblisTensor<c64>) {
    let dims = build_shape(s);
    let mut buffer = gen_array(&dims);
    let steps = build_strides(&dims);
    let view = TblisTensor::new(buffer.as_mut_ptr(), &dims, &steps);
    (buffer, view)
}
/// Runs `tblis_einsum` on generated operands and compares the result with
/// reference values.
///
/// For each case:
/// * `einsum_str` is the einsum expression; the comma-separated terms before
///   `->` (or the whole string when there is no `->`) name one operand each.
/// * `ref_shape` is the expected shape of the output tensor.
/// * `ref_fp` is the expected `(re, im)` fingerprint of the output buffer as
///   computed by `fp`, compared with an absolute tolerance of `1e-10`.
#[rstest]
#[case("a,->a" , vec![2] , ( 1.0626633964750658, -0.6806902927156129))]
#[case("ab,->ab" , vec![2, 3] , ( 2.6791295309972907, -1.1926934244521328))]
#[case(",ab,->ab" , vec![2, 3] , ( 2.3934969629216369, -1.6905700533121117))]
#[case(",,->" , vec![] , ( 0.8299159805637262, -0.5536905426638488))]
#[case("a,ab,abc->abc" , vec![2, 3, 4] , ( 1.5336587277409919, -0.8487366664981919))]
#[case("a,b,ab->ab" , vec![2, 3] , ( 0.5769339284293409, -1.5532266817002940))]
#[case("ea,fb,gc,hd,abcd->efgh" , vec![4, 3, 2, 6] , ( 9.2102305248174847, 6.6918481319956342))]
#[case("ea,fb,abcd,gc,hd->efgh" , vec![4, 3, 2, 6] , ( 9.2102305248174847, 6.6918481319956342))]
#[case("abcd,ea,fb,gc,hd->efgh" , vec![4, 3, 2, 6] , ( 9.2102305248174883, 6.6918481319956307))]
#[case("acdf,jbje,gihb,hfac,gfac,gifabc,hfac" , vec![5, 4] , ( -0.2116097242816026, 0.0454024855951956))]
#[case("cd,bdhe,aidb,hgca,gc,hgibcd,hgac" , vec![4] , ( -7.7635629821390308, 8.5154775536332945))]
#[case("abhe,hidj,jgba,hiab,gab" , vec![5, 4] , (-21.3091287330690449, -31.1676144095302625))]
#[case("bde,cdh,agdb,hica,ibd,hgicd,hiac" , vec![4] , ( -3.1103334757846168, 0.9019517948241862))]
#[case("chd,bde,agbc,hiad,hgc,hgi,hiad" , vec![4] , ( -3.9154496906531921, -0.3719312577909981))]
#[case("chd,bde,agbc,hiad,bdi,cgh,agdb" , vec![4] , ( 21.8169594386688530, -0.8392270694101933))]
#[case("bdhe,acad,hiab,agac,hibd" , vec![4, 2] , ( 0.0714978291829508, 0.0179331909288974))]
#[case("ab,ab,c->" , vec![] , ( -1.1440117473937714, -1.7203308328855442))]
#[case("ab,ab,c->c" , vec![4] , ( 4.7647955571612179, -2.6302597509298922))]
#[case("ab,ab,cd,cd->" , vec![] , ( 15.2981766805646124, -10.6028689530129476))]
#[case("ab,ab,cd,cd->ac" , vec![2, 4] , ( 2.4818565990824020, -1.6167222863631672))]
#[case("ab,ab,cd,cd->cd" , vec![4, 5] , ( 1.6820837207751058, -1.6635100382781292))]
#[case("ab,ab,cd,cd,ef,ef->" , vec![] , ( 50.6236785641111808, -68.0228858281206357))]
#[case("ab,cd,ef->abcdef" , vec![2, 3, 4, 5, 4, 3] , ( 8.0160220339714989, 9.8106351580935769))]
#[case("ab,cd,ef->acdf" , vec![2, 4, 5, 3] , ( -0.3189002168829093, -0.5698001035414051))]
#[case("ab,cd,de->abcde" , vec![2, 3, 4, 5, 4] , ( 3.3311677876018635, -0.9646556190738309))]
#[case("ab,cd,de->be" , vec![3, 4] , ( 0.0060456268678287, 0.0315858071478905))]
#[case("ab,bcd,cd->abcd" , vec![2, 3, 4, 5] , ( 0.7947390762604263, -1.8313657121740694))]
#[case("ab,bcd,cd->abd" , vec![2, 3, 5] , ( 1.1177831190650014, -0.4536529682239876))]
#[case("eb,cb,fb->cef" , vec![4, 4, 3] , ( -0.3417214715065875, -0.2989636671893672))]
#[case("dd,fb,be,cdb->cef" , vec![4, 4, 3] , ( 0.0614607960345380, -0.0292966395473457))]
#[case("bca,cdb,dbf,afc->" , vec![] , (-11.1976419920901726, 2.6675195489355876))]
#[case("dcc,fce,ea,dbf->ab" , vec![2, 3] , ( -0.0254992770945275, 0.1291919328619402))]
#[case("fdf,cdd,ccd,afe->ae" , vec![2, 4] , ( -2.3807100283394025, -1.7309388790228271))]
#[case("abcd,ad" , vec![3, 4] , ( -4.2814335847278109, -0.2423568704494487))]
#[case("ed,fcd,ff,bcf->be" , vec![3, 4] , ( 0.0845024493295495, -1.8943171230788358))]
#[case("baa,dcf,af,cde->be" , vec![3, 4] , ( 3.8735271235695636, -5.4627841204381227))]
#[case("bd,db,eac->ace" , vec![2, 4, 4] , ( 0.7557248790243568, 0.9129313989409814))]
#[case("fff,fae,bef,def->abd" , vec![2, 3, 5] , ( -3.2545155626795133, 1.4559659254341435))]
#[case("efc,dbc,acf,fd->abe" , vec![2, 3, 4] , ( 0.0046658477715147, -1.0444897081753224))]
#[case("ab,ab" , vec![] , ( 2.1078679480925979, -0.7008159330640431))]
#[case("ab,ba" , vec![] , ( 1.1922381853320942, 0.0341064650010012))]
#[case("abc,abc" , vec![] , ( 9.0550386293676404, -2.9217348330317572))]
#[case("abc,bac" , vec![] , ( 0.5665941613556231, 0.1918424986448617))]
#[case("abc,cba" , vec![] , ( 0.7427667439638136, -8.0266281535986987))]
#[case("ab,bc" , vec![2, 4] , ( 3.0636698684468722, 2.4485748049856646))]
#[case("ab,cb" , vec![2, 4] , ( 1.7712916026623675, -0.5978413636325484))]
#[case("ba,bc" , vec![2, 4] , ( 1.0343214702468668, 0.9034628404673370))]
#[case("ba,cb" , vec![2, 4] , ( 1.0008955702062043, 0.0746914622620528))]
#[case("abcd,cd" , vec![2, 3] , ( 21.7844254710988672, -6.8920102056717578))]
#[case("abcd,ab" , vec![4, 5] , ( 22.3929921840377375, -5.3370424046114540))]
#[case("abcd,cdef" , vec![2, 3, 4, 3] , (-17.5109719823750396, -3.7862458273559318))]
#[case("abcd,cdef->feba" , vec![3, 4, 3, 2] , ( 0.8325862854248018, -2.9856704587480039))]
#[case("abcd,efdc" , vec![2, 3, 4, 3] , ( 4.1316137974856151, -2.1183100696273223))]
#[case("aab,bc->ac" , vec![2, 4] , ( 2.6122198655941657, 2.4070577835099365))]
#[case("ab,bcc->ac" , vec![2, 4] , ( 0.5536064032919279, 0.8165245266829665))]
#[case("aab,bcc->ac" , vec![2, 4] , ( 0.4283036047645465, 0.8328438191931101))]
#[case("baa,bcc->ac" , vec![2, 4] , ( 4.0702573007696854, -4.8086334738516854))]
#[case("aab,ccb->ac" , vec![2, 4] , ( 0.2599457596087409, -0.0972866664359448))]
#[case("aab,fa,df,ecc->bde" , vec![3, 5, 4] , ( -0.6452535857516806, 0.2701805091737071))]
#[case("ecb,fef,bad,ed->ac" , vec![2, 4] , ( -2.0991536924824938, -0.6101631190896243))]
#[case("bcf,bbb,fbf,fc->" , vec![] , ( 0.3689725389646942, -0.4863879440986705))]
#[case("bb,ff,be->e" , vec![4] , ( 0.1216715120397994, -0.5219532921091159))]
#[case("bcb,bb,fc,fff->" , vec![] , ( 0.0496840348532896, 0.0876400640479318))]
#[case("fbb,dfd,fc,fc->" , vec![] , ( -0.2019588423252376, 0.0520794854329894))]
#[case("afd,ba,cc,dc->bf" , vec![3, 3] , ( -0.2164303183208181, 1.7498299453944053))]
#[case("adb,bc,fa,cfc->d" , vec![5] , ( 0.4821755077243144, -0.2166447771592562))]
#[case("bbd,bda,fc,db->acf" , vec![2, 4, 3] , ( 4.4800763340690981, 2.9909850932868034))]
#[case("dba,ead,cad->bce" , vec![3, 4, 4] , ( 1.2455232197722861, 0.6959454118311303))]
#[case("aef,fbc,dca->bde" , vec![3, 5, 4] , ( 2.6115342498937100, 0.1558997986008280))]
#[case("ea->ea" , vec![4, 2] , ( 4.2779003835689338, -1.0430003580983946))]
#[case("fb->fb" , vec![3, 3] , ( 4.3272474273084374, -0.9808282730885702))]
#[case("abcd->dcab" , vec![5, 4, 2, 3] , ( 5.0109395102756977, -0.9966085998460459))]
#[case("gc->cg" , vec![4, 2] , ( 0.2680112599856153, -0.1269641467186604))]
#[case("hd->dh" , vec![5, 6] , ( 4.4819580717183714, -3.3460245221870828))]
#[case("efgh->hfge" , vec![6, 3, 2, 4] , ( 2.2237855229147128, 3.8833637852830840))]
#[case("acdf->afcd" , vec![2, 3, 4, 5] , ( 3.2392076661644174, -0.6296964164425971))]
#[case("gihb->ghib" , vec![2, 6, 5, 3] , ( -2.1102327279825022, 0.4224646504787080))]
#[case("hfac->cfah" , vec![4, 3, 2, 6] , ( 2.2237855229146861, 3.8833637852830756))]
#[case("gfac->cgaf" , vec![4, 2, 2, 3] , ( 4.0858816458932772, -0.5554982976507628))]
#[case("gifabc->abifcg" , vec![2, 3, 5, 3, 4, 2] , ( -0.8209082897628119, 0.1283066691603100))]
#[case("hfac->cfha" , vec![4, 3, 6, 2] , ( -1.3035489268162332, -2.2966898260907733))]
fn test_einsum(#[case] einsum_str: &str, #[case] ref_shape: Vec<isize>, #[case] ref_fp: (f64, f64)) {
    let ref_fp = c64::new(ref_fp.0, ref_fp.1);

    // Operand labels are everything before "->"; with no "->" the whole
    // expression lists the operands.
    let inputs = einsum_str.split_once("->").map_or(einsum_str, |(lhs, _)| lhs);

    // Each entry keeps the backing buffer next to the tensor built from it,
    // so the buffers stay alive for the whole test.
    let tensors = inputs.split(',').map(build_tblis_tensor).collect::<Vec<_>>();
    let tblis_tensors = tensors.iter().map(|(_, t)| t).collect::<Vec<_>>();

    // SAFETY: every tensor in `tblis_tensors` was created by
    // `build_tblis_tensor` from a pointer into the `Vec` stored alongside it
    // in `tensors`, with shape and strides describing exactly that buffer.
    // `tensors` is neither moved nor mutated until after this call returns.
    // NOTE(review): the meaning of the `true, None, true, None` arguments is
    // defined by `tblis_einsum` and is not visible here.
    let (out_data, out_tensor) =
        unsafe { tblis_einsum(einsum_str, &tblis_tensors, true, None, true, None).unwrap() };

    assert_eq!(out_tensor.shape, ref_shape, "output shape mismatch for `{}`", einsum_str);

    let out_fp = fp(&out_data);
    let deviation = (out_fp - ref_fp).abs();
    assert!(
        deviation < 1e-10,
        "fingerprint mismatch for `{}`: got {}, expected {} (|diff| = {:e})",
        einsum_str,
        out_fp,
        ref_fp,
        deviation
    );
}