use super::super::{CudaClient, CudaRuntime};
use crate::algorithm::linalg::tensor_decompose_core::{
TensorDecomposeDTypeSupport, cp_decompose_impl, cp_reconstruct_impl, fold_impl, hosvd_impl,
mode_n_product_impl, tensor_train_impl, tt_reconstruct_impl, tucker_impl,
tucker_reconstruct_impl, unfold_impl,
};
use crate::algorithm::linalg::{
CpDecomposition, CpOptions, TensorDecomposeAlgorithms, TensorTrainDecomposition,
TuckerDecomposition, TuckerOptions,
};
use crate::error::Result;
use crate::tensor::Tensor;
/// [`TensorDecomposeAlgorithms`] for the CUDA client.
///
/// Every method here is a thin forwarder to the matching `*_impl` function
/// imported from `tensor_decompose_core`. Where the core function accepts a
/// dtype-support argument, `TensorDecomposeDTypeSupport::FULL` is supplied.
impl TensorDecomposeAlgorithms<CudaRuntime> for CudaClient {
    /// Forwards to `unfold_impl` with `tensor`, `mode` and the `FULL`
    /// dtype-support value.
    fn unfold(&self, tensor: &Tensor<CudaRuntime>, mode: usize) -> Result<Tensor<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        unfold_impl(tensor, mode, dtypes)
    }

    /// Forwards to `fold_impl` with `matrix`, `mode` and the target `shape`.
    ///
    /// No dtype-support argument is passed on this path.
    fn fold(
        &self,
        matrix: &Tensor<CudaRuntime>,
        mode: usize,
        shape: &[usize],
    ) -> Result<Tensor<CudaRuntime>> {
        fold_impl(matrix, mode, shape)
    }

    /// Forwards to `mode_n_product_impl`, passing this client, `tensor`,
    /// `matrix`, `mode` and the `FULL` dtype-support value.
    fn mode_n_product(
        &self,
        tensor: &Tensor<CudaRuntime>,
        matrix: &Tensor<CudaRuntime>,
        mode: usize,
    ) -> Result<Tensor<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        mode_n_product_impl(self, tensor, matrix, mode, dtypes)
    }

    /// Forwards to `hosvd_impl`, passing this client, `tensor`, the requested
    /// `ranks` and the `FULL` dtype-support value.
    fn hosvd(
        &self,
        tensor: &Tensor<CudaRuntime>,
        ranks: &[usize],
    ) -> Result<TuckerDecomposition<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        hosvd_impl(self, tensor, ranks, dtypes)
    }

    /// Forwards to `tucker_impl`, passing this client, `tensor`, `ranks`,
    /// `options` and the `FULL` dtype-support value.
    fn tucker(
        &self,
        tensor: &Tensor<CudaRuntime>,
        ranks: &[usize],
        options: TuckerOptions,
    ) -> Result<TuckerDecomposition<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        tucker_impl(self, tensor, ranks, options, dtypes)
    }

    /// Forwards to `cp_decompose_impl`, passing this client, `tensor`, `rank`,
    /// `options` and the `FULL` dtype-support value.
    fn cp_decompose(
        &self,
        tensor: &Tensor<CudaRuntime>,
        rank: usize,
        options: CpOptions,
    ) -> Result<CpDecomposition<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        cp_decompose_impl(self, tensor, rank, options, dtypes)
    }

    /// Forwards to `tensor_train_impl`, passing this client, `tensor`,
    /// `max_rank`, `tolerance` and the `FULL` dtype-support value.
    fn tensor_train(
        &self,
        tensor: &Tensor<CudaRuntime>,
        max_rank: usize,
        tolerance: f64,
    ) -> Result<TensorTrainDecomposition<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        tensor_train_impl(self, tensor, max_rank, tolerance, dtypes)
    }

    /// Forwards to `tucker_reconstruct_impl`, passing this client, `decomp`
    /// and the `FULL` dtype-support value.
    fn tucker_reconstruct(
        &self,
        decomp: &TuckerDecomposition<CudaRuntime>,
    ) -> Result<Tensor<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        tucker_reconstruct_impl(self, decomp, dtypes)
    }

    /// Forwards to `cp_reconstruct_impl`, passing this client, `decomp`, the
    /// output `shape` and the `FULL` dtype-support value.
    fn cp_reconstruct(
        &self,
        decomp: &CpDecomposition<CudaRuntime>,
        shape: &[usize],
    ) -> Result<Tensor<CudaRuntime>> {
        let dtypes = TensorDecomposeDTypeSupport::FULL;
        cp_reconstruct_impl(self, decomp, shape, dtypes)
    }

    /// Forwards to `tt_reconstruct_impl` with this client and `decomp`.
    ///
    /// No dtype-support argument is passed on this path.
    fn tt_reconstruct(
        &self,
        decomp: &TensorTrainDecomposition<CudaRuntime>,
    ) -> Result<Tensor<CudaRuntime>> {
        tt_reconstruct_impl(self, decomp)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::Runtime;

    /// Returns the client that `CudaRuntime` provides for its default device.
    ///
    /// NOTE(review): presumably requires a usable CUDA device at test time —
    /// confirm how these tests are gated in CI.
    fn get_client() -> CudaClient {
        let device = CudaRuntime::default_device();
        CudaRuntime::default_client(&device)
    }

    /// Asserts that `recon` matches `orig` element-wise within `tol`.
    ///
    /// The lengths are compared first: `zip` stops at the shorter input, so
    /// without this check a truncated or empty reconstruction would make the
    /// element-wise loop pass without comparing anything.
    ///
    /// `label` names the decomposition and is used as the prefix of the
    /// failure message.
    fn assert_reconstruction_close(label: &str, orig: &[f32], recon: &[f32], tol: f32) {
        assert_eq!(
            orig.len(),
            recon.len(),
            "{} reconstruction has a different element count than the input",
            label
        );
        for (a, b) in orig.iter().zip(recon.iter()) {
            assert!(
                (a - b).abs() < tol,
                "{} reconstruction error: {} vs {}",
                label,
                a,
                b
            );
        }
    }

    /// `fold(unfold(t, mode), mode, shape)` must return the original data for
    /// every mode of a `[2, 3, 4]` tensor.
    #[test]
    fn test_unfold_fold_roundtrip() {
        let client = get_client();
        let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let tensor = Tensor::<CudaRuntime>::from_slice(&data, &[2, 3, 4], &client.device);
        // The reference values do not depend on the mode, so read them back once.
        let orig: Vec<f32> = tensor.to_vec();

        for mode in 0..3 {
            let unfolded = client.unfold(&tensor, mode).expect("unfold should succeed");
            let refolded = client
                .fold(&unfolded, mode, &[2, 3, 4])
                .expect("fold should succeed");

            let result: Vec<f32> = refolded.to_vec();
            assert_eq!(orig.len(), result.len());
            for (a, b) in orig.iter().zip(result.iter()) {
                assert!((a - b).abs() < 1e-6, "Mode {} roundtrip failed", mode);
            }
        }
    }

    /// A `[2, 3]` tensor multiplied along mode 0 by a `[4, 2]` matrix must
    /// produce a `[4, 3]` result.
    #[test]
    fn test_mode_n_product() {
        let client = get_client();
        let tensor = Tensor::<CudaRuntime>::from_slice(
            &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3],
            &client.device,
        );
        let matrix = Tensor::<CudaRuntime>::from_slice(
            &[1.0f32, 0.0, 0.0, 1.0, 1.0, 1.0, -1.0, 1.0],
            &[4, 2],
            &client.device,
        );

        let result = client
            .mode_n_product(&tensor, &matrix, 0)
            .expect("mode_n_product should succeed");
        assert_eq!(result.shape(), &[4, 3]);
    }

    /// HOSVD with ranks equal to the tensor's shape must yield a core and
    /// factors of the expected shapes and reconstruct the input.
    #[test]
    fn test_hosvd_basic() {
        let client = get_client();
        let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let tensor = Tensor::<CudaRuntime>::from_slice(&data, &[2, 3, 4], &client.device);

        let decomp = client
            .hosvd(&tensor, &[2, 3, 4])
            .expect("hosvd should succeed");
        assert_eq!(decomp.core.shape(), &[2, 3, 4]);
        assert_eq!(decomp.factors.len(), 3);
        assert_eq!(decomp.factors[0].shape(), &[2, 2]);
        assert_eq!(decomp.factors[1].shape(), &[3, 3]);
        assert_eq!(decomp.factors[2].shape(), &[4, 4]);

        let reconstructed = client
            .tucker_reconstruct(&decomp)
            .expect("reconstruction should succeed");
        let orig: Vec<f32> = tensor.to_vec();
        let recon: Vec<f32> = reconstructed.to_vec();
        assert_reconstruction_close("HOSVD", &orig, &recon, 1e-4);
    }

    /// Rank-1 CP decomposition of a `[2, 3]` input whose second row is twice
    /// the first (so the input itself has rank 1) must reconstruct it to
    /// within a loose tolerance.
    #[test]
    fn test_cp_decompose_basic() {
        let client = get_client();
        let tensor = Tensor::<CudaRuntime>::from_slice(
            &[1.0f32, 1.0, 1.0, 2.0, 2.0, 2.0],
            &[2, 3],
            &client.device,
        );

        let decomp = client
            .cp_decompose(&tensor, 1, CpOptions::default())
            .expect("cp_decompose should succeed");
        assert_eq!(decomp.factors.len(), 2);
        assert_eq!(decomp.factors[0].shape(), &[2, 1]);
        assert_eq!(decomp.factors[1].shape(), &[3, 1]);
        assert_eq!(decomp.weights.shape(), &[1]);

        let reconstructed = client
            .cp_reconstruct(&decomp, &[2, 3])
            .expect("reconstruction should succeed");
        let orig: Vec<f32> = tensor.to_vec();
        let recon: Vec<f32> = reconstructed.to_vec();
        assert_reconstruction_close("CP", &orig, &recon, 1e-2);
    }

    /// Tensor-train decomposition of a `[2, 3, 4]` tensor must produce three
    /// cores with boundary ranks of 1 and reconstruct the input.
    #[test]
    fn test_tensor_train_basic() {
        let client = get_client();
        let data: Vec<f32> = (0..24).map(|x| x as f32).collect();
        let tensor = Tensor::<CudaRuntime>::from_slice(&data, &[2, 3, 4], &client.device);

        // NOTE(review): `max_rank = 0` with `tolerance = 0.0` presumably means
        // "no truncation" — confirm against `tensor_train_impl`.
        let decomp = client
            .tensor_train(&tensor, 0, 0.0)
            .expect("tensor_train should succeed");
        assert_eq!(decomp.cores.len(), 3);
        assert_eq!(decomp.cores[0].shape()[0], 1);
        assert_eq!(decomp.cores[0].shape()[1], 2);
        assert_eq!(decomp.cores[2].shape()[1], 4);
        assert_eq!(decomp.cores[2].shape()[2], 1);

        let reconstructed = client
            .tt_reconstruct(&decomp)
            .expect("reconstruction should succeed");
        let orig: Vec<f32> = tensor.to_vec();
        let recon: Vec<f32> = reconstructed.to_vec();
        assert_reconstruction_close("TT", &orig, &recon, 1e-4);
    }
}