use crate::prelude_dev::*;
use num::ToPrimitive;
impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    D: DimAPI + DimSmallerOneAPI,
    D::SmallerOne: DimAPI,
    B: DeviceAPI<T> + OpPackTriAPI<T> + DeviceCreationAnyAPI<T>,
{
    /// Pack one triangle (`uplo`) of the square matrix dimensions into a
    /// single linear dimension of length `n * (n + 1) / 2`.
    ///
    /// For a row-major device default order the *last two* dimensions are
    /// packed; for col-major, the *first two*. The two packed dimensions must
    /// be equal, otherwise an `InvalidLayout` error is returned.
    ///
    /// # Errors
    ///
    /// Returns `InvalidLayout` when the two matrix dimensions differ, and
    /// propagates any layout-conversion or allocation failure.
    pub fn pack_tri_f(&self, uplo: FlagUpLo) -> Result<Tensor<T, B, D::SmallerOne>> {
        let lb = self.layout().to_dim::<IxD>()?;
        let default_order = self.device().default_order();
        // Output shape: the two square dimensions collapse into one
        // triangular dimension of length n * (n + 1) / 2.
        let la_shape = match default_order {
            RowMajor => {
                let (lb_rest, lb_inner) = lb.dim_split_at(-2)?;
                rstsr_assert_eq!(
                    lb_inner.shape()[0],
                    lb_inner.shape()[1],
                    InvalidLayout,
                    "Last two dimensions should be the same for pack_tri."
                )?;
                let n: usize = lb_inner.shape()[0];
                let n_tp = n * (n + 1) / 2;
                let mut la_shape = lb_rest.shape().to_vec();
                la_shape.push(n_tp);
                la_shape
            },
            ColMajor => {
                let (lb_inner, lb_rest) = lb.dim_split_at(2)?;
                rstsr_assert_eq!(
                    lb_inner.shape()[0],
                    lb_inner.shape()[1],
                    InvalidLayout,
                    "First two dimensions should be the same for pack_tri."
                )?;
                let n: usize = lb_inner.shape()[0];
                let n_tp = n * (n + 1) / 2;
                let mut la_shape = vec![n_tp];
                la_shape.append(&mut lb_rest.shape().to_vec());
                la_shape
            },
        };
        // Follow the input's memory preference; when the input is neither
        // c- nor f-preferred (or both), fall back to the device default.
        let la = match (lb.c_prefer(), lb.f_prefer()) {
            (true, false) => la_shape.c(),
            (false, true) => la_shape.f(),
            _ => match default_order {
                RowMajor => la_shape.c(),
                ColMajor => la_shape.f(),
            },
        };
        let device = self.device();
        let mut storage_a = device.uninit_impl(la.bounds_index()?.1)?;
        device.pack_tri(storage_a.raw_mut(), &la, self.raw(), &lb, uplo)?;
        // SAFETY: assumes `pack_tri` writes every element within the bounds
        // of the contiguous output layout `la` — TODO confirm against the
        // device implementation.
        let storage_a = unsafe { B::assume_init_impl(storage_a)? };
        Tensor::new_f(storage_a, la.into_dim()?)
    }
    /// Panicking variant of [`Self::pack_tri_f`].
    pub fn pack_tri(&self, uplo: FlagUpLo) -> Tensor<T, B, D::SmallerOne> {
        self.pack_tri_f(uplo).rstsr_unwrap()
    }
    /// Pack the lower triangle; see [`Self::pack_tri_f`].
    pub fn pack_tril_f(&self) -> Result<Tensor<T, B, D::SmallerOne>> {
        self.pack_tri_f(FlagUpLo::L)
    }
    /// Panicking variant of [`Self::pack_tril_f`].
    pub fn pack_tril(&self) -> Tensor<T, B, D::SmallerOne> {
        self.pack_tril_f().rstsr_unwrap()
    }
    /// Pack the upper triangle; see [`Self::pack_tri_f`].
    pub fn pack_triu_f(&self) -> Result<Tensor<T, B, D::SmallerOne>> {
        self.pack_tri_f(FlagUpLo::U)
    }
    /// Panicking variant of [`Self::pack_triu_f`].
    pub fn pack_triu(&self) -> Tensor<T, B, D::SmallerOne> {
        self.pack_triu_f().rstsr_unwrap()
    }
}
impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    D: DimAPI + DimLargerOneAPI,
    D::LargerOne: DimAPI,
    B: DeviceAPI<T> + OpUnpackTriAPI<T> + DeviceCreationAnyAPI<T>,
{
    /// Unpack a triangular-packed dimension of length `n * (n + 1) / 2` back
    /// into two square matrix dimensions `n x n`.
    ///
    /// For a row-major device default order the *last* dimension is unpacked;
    /// for col-major, the *first*. `uplo` selects which triangle the packed
    /// data represents, and `symm` how the opposite triangle is filled.
    ///
    /// # Errors
    ///
    /// Returns `InvalidLayout` when the packed dimension is not a triangular
    /// number, and propagates layout-conversion or allocation failures.
    pub fn unpack_tri_f(&self, uplo: FlagUpLo, symm: FlagSymm) -> Result<Tensor<T, B, D::LargerOne>> {
        let lb = self.layout().to_dim::<IxD>()?;
        let default_order = self.device().default_order();
        let la_shape = match default_order {
            RowMajor => {
                let (lb_rest, lb_inner) = lb.dim_split_at(-1)?;
                let n_tp: usize = lb_inner.shape()[0];
                // Invert the triangular number: n = floor(sqrt(2 * n_tp)).
                // f64 sqrt is exact enough here (2 * n_tp fits in the f64
                // mantissa for any realistic size); the assert below rejects
                // any n_tp that is not a triangular number.
                let n: usize = (2 * n_tp).to_f64().unwrap().sqrt().floor().to_usize().unwrap();
                rstsr_assert_eq!(
                    n * (n + 1) / 2,
                    n_tp,
                    InvalidLayout,
                    "Last dimension should be triangular number for unpack_tri."
                )?;
                let mut la_shape = lb_rest.shape().to_vec();
                la_shape.append(&mut vec![n, n]);
                la_shape
            },
            ColMajor => {
                let (lb_inner, lb_rest) = lb.dim_split_at(1)?;
                let n_tp: usize = lb_inner.shape()[0];
                // Same triangular-number inversion as the row-major branch.
                let n: usize = (2 * n_tp).to_f64().unwrap().sqrt().floor().to_usize().unwrap();
                rstsr_assert_eq!(
                    n * (n + 1) / 2,
                    n_tp,
                    InvalidLayout,
                    "First dimension should be triangular number for unpack_tri."
                )?;
                let mut la_shape = vec![n, n];
                la_shape.append(&mut lb_rest.shape().to_vec());
                la_shape
            },
        };
        // Follow the input's memory preference; when the input is neither
        // c- nor f-preferred (or both), fall back to the device default.
        let la = match (lb.c_prefer(), lb.f_prefer()) {
            (true, false) => la_shape.c(),
            (false, true) => la_shape.f(),
            _ => match default_order {
                RowMajor => la_shape.c(),
                ColMajor => la_shape.f(),
            },
        };
        let device = self.device();
        let mut storage_a = device.uninit_impl(la.bounds_index()?.1)?;
        device.unpack_tri(storage_a.raw_mut(), &la, self.raw(), &lb, uplo, symm)?;
        // SAFETY: assumes `unpack_tri` writes every element within the bounds
        // of the contiguous output layout `la` — TODO confirm against the
        // device implementation.
        let storage_a = unsafe { B::assume_init_impl(storage_a)? };
        Tensor::new_f(storage_a, la.into_dim()?)
    }
    /// Panicking variant of [`Self::unpack_tri_f`].
    pub fn unpack_tri(&self, uplo: FlagUpLo, symm: FlagSymm) -> Tensor<T, B, D::LargerOne> {
        self.unpack_tri_f(uplo, symm).rstsr_unwrap()
    }
    /// Unpack lower-triangle-packed data; panicking variant.
    pub fn unpack_tril(&self, symm: FlagSymm) -> Tensor<T, B, D::LargerOne> {
        self.unpack_tri_f(FlagUpLo::L, symm).rstsr_unwrap()
    }
    /// Unpack upper-triangle-packed data; panicking variant.
    pub fn unpack_triu(&self, symm: FlagSymm) -> Tensor<T, B, D::LargerOne> {
        self.unpack_tri_f(FlagUpLo::U, symm).rstsr_unwrap()
    }
    /// Unpack lower-triangle-packed data; see [`Self::unpack_tri_f`].
    pub fn unpack_tril_f(&self, symm: FlagSymm) -> Result<Tensor<T, B, D::LargerOne>> {
        self.unpack_tri_f(FlagUpLo::L, symm)
    }
    /// Unpack upper-triangle-packed data; see [`Self::unpack_tri_f`].
    pub fn unpack_triu_f(&self, symm: FlagSymm) -> Result<Tensor<T, B, D::LargerOne>> {
        self.unpack_tri_f(FlagUpLo::U, symm)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_pack_tri() {
        #[cfg(not(feature = "col_major"))]
        {
            // Row-major: pack the lower triangle of the last two (4x4) dims.
            let a = {
                let a = arange((48., &DeviceCpuSerial::default()));
                let storage_a = a.into_raw_parts().0;
                Tensor::new(storage_a, [3, 4, 4].f())
            };
            // NOTE: fixed misleading name — this packs the LOWER triangle.
            let a_tril = a.pack_tril();
            println!("{a_tril:?}");
            println!("{:?}", a.slice(0));
            println!("{:?}", a_tril.slice(0).to_vec());
            assert_eq!(a_tril.slice(1).to_vec(), [1., 4., 16., 7., 19., 31., 10., 22., 34., 46.]);
            // Round-trip: unpack symmetrically and spot-check one row.
            let b = a_tril.unpack_tril(FlagSymm::Sy);
            println!("{b:?}");
            assert_eq!(b.slice((0, 1)).to_vec(), [3., 15., 18., 21.]);
        }
        #[cfg(feature = "col_major")]
        {
            // Col-major: pack the upper triangle of the first two (4x4) dims.
            let a = {
                let a = arange((48., &DeviceCpuSerial::default()));
                let storage_a = a.into_raw_parts().0;
                Tensor::new(storage_a, [4, 4, 3].c())
            };
            let a_triu = a.pack_triu();
            println!("{a_triu:?}");
            println!("{:?}", a.slice((.., 0)));
            println!("{:?}", a_triu.slice((.., 0)).to_vec());
            assert_eq!(a_triu.slice((.., 1)).to_vec(), [1., 4., 16., 7., 19., 31., 10., 22., 34., 46.]);
            let b = a_triu.unpack_triu(FlagSymm::Sy);
            println!("{b:?}");
            assert_eq!(b.slice((.., 1, 0)).to_vec(), [3., 15., 18., 21.]);
        }
    }
    #[test]
    #[cfg(feature = "rayon")]
    fn test_par_pack_tril_compiles() {
        // Smoke test: large complex tensors exercise the parallel path;
        // only checks that pack/unpack run without panicking.
        #[cfg(not(feature = "col_major"))]
        {
            use num::complex::c64;
            let a = linspace((c64(-2.0, 1.5), c64(1.7, -2.3), 256 * 256 * 256)).into_layout([4, 64, 256, 256].f());
            let a_tril = a.pack_tril();
            println!("{a_tril:20.5}");
            let b = a_tril.unpack_tril(FlagSymm::Ah);
            println!("{b:20.5}");
        }
        #[cfg(feature = "col_major")]
        {
            use num::complex::c64;
            let a = linspace((c64(-2.0, 1.5), c64(1.7, -2.3), 256 * 256 * 256)).into_layout([256, 256, 64, 4].c());
            let a_tril = a.pack_tril();
            println!("{a_tril:20.5}");
            let b = a_tril.unpack_tril(FlagSymm::Ah);
            println!("{b:20.5}");
        }
    }
    #[test]
    fn test_correctness() {
        // Packing must give the same result regardless of the input's
        // memory order (C vs F contiguity).
        let a = {
            let a = arange((16., &DeviceCpuSerial::default()));
            let storage_a = a.into_raw_parts().0;
            Tensor::new(storage_a, [4, 4].c())
        };
        println!("{a:}");
        let a_tril1 = a.pack_tril();
        println!("{a_tril1:}");
        let a = a.to_contig(FlagOrder::F);
        let a_tril2 = a.pack_tril();
        println!("{a_tril2:}");
        assert!((&a_tril1 - &a_tril2).l2_norm_all() < 1e-6);
    }
}