use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Computes a Tucker decomposition of `input` (HOSVD-style), returning the
/// core tensor and one factor matrix per mode.
///
/// Each factor holds the leading left-singular vectors of the corresponding
/// mode-n unfolding, truncated to the requested rank. When `ranks` is `None`,
/// every mode rank defaults to `max(dim / 2, 1)`.
///
/// # Errors
/// Fails if the tensor has fewer than 2 dimensions or `ranks` does not have
/// one entry per mode.
pub fn tucker_decomposition(
    input: &Tensor<f32>,
    ranks: Option<&[usize]>,
) -> TorshResult<(Tensor<f32>, Vec<Tensor<f32>>)> {
    let shape_binding = input.shape();
    let dims = shape_binding.dims();
    let n_modes = dims.len();
    if n_modes < 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Tucker decomposition requires at least 2D tensor",
            "tucker_decomposition",
        ));
    }
    let target_ranks: Vec<usize> = match ranks {
        Some(r) => {
            if r.len() != n_modes {
                return Err(TorshError::invalid_argument_with_context(
                    &format!(
                        "ranks length {} must match tensor dimensions {}",
                        r.len(),
                        n_modes
                    ),
                    "tucker_decomposition",
                ));
            }
            r.to_vec()
        }
        None => dims.iter().map(|&s| (s / 2).max(1)).collect(),
    };
    // One factor per mode: truncated left-singular basis of the unfolding.
    let mut factor_matrices = Vec::with_capacity(n_modes);
    for (mode, &rank) in target_ranks.iter().enumerate() {
        let unfolded = unfold_tensor(input, mode)?;
        let (u, _s, _vt) = crate::linalg::svd(&unfolded, false)?;
        factor_matrices.push(truncate_matrix(&u, rank)?);
    }
    // Core tensor: project the input onto every factor via transposed
    // mode products, shrinking each mode to its target rank.
    let mut core = input.clone();
    for (mode, factor) in factor_matrices.iter().enumerate() {
        core = mode_product(&core, factor, mode, true)?;
    }
    Ok((core, factor_matrices))
}
/// Mode-n matricization: brings `mode` to the front and flattens the
/// remaining modes (kept in their original order) into the columns,
/// producing a `[dims[mode], prod(other dims)]` matrix.
fn unfold_tensor(tensor: &Tensor<f32>, mode: usize) -> TorshResult<Tensor<f32>> {
    let shape_binding = tensor.shape();
    let dims = shape_binding.dims();
    let n_modes = dims.len();
    if mode >= n_modes {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "mode {} out of range for tensor with {} dimensions",
                mode, n_modes
            ),
            "unfold_tensor",
        ));
    }
    let rows = dims[mode];
    // Column count: product of every dimension except `mode`.
    let cols: usize = dims
        .iter()
        .enumerate()
        .filter_map(|(i, &s)| (i != mode).then_some(s))
        .product();
    // Permutation [mode, everything else in ascending order].
    let perm: Vec<i32> = std::iter::once(mode as i32)
        .chain((0..n_modes).filter(|&i| i != mode).map(|i| i as i32))
        .collect();
    tensor.permute(&perm)?.reshape(&[rows as i32, cols as i32])
}
/// Keeps the first `rank` columns of a 2D matrix, copying the leading
/// `rank` entries of every row (row-major layout).
fn truncate_matrix(matrix: &Tensor<f32>, rank: usize) -> TorshResult<Tensor<f32>> {
    let shape_binding = matrix.shape();
    let dims = shape_binding.dims();
    if dims.len() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "truncate_matrix requires 2D tensor",
            "truncate_matrix",
        ));
    }
    let (rows, cols) = (dims[0], dims[1]);
    if rank > cols {
        return Err(TorshError::invalid_argument_with_context(
            &format!("rank {} exceeds matrix columns {}", rank, cols),
            "truncate_matrix",
        ));
    }
    let data = matrix.data()?;
    let mut truncated_data = Vec::with_capacity(rows * rank);
    for row in 0..rows {
        // Each row starts at row * cols; keep only its first `rank` values.
        let start = row * cols;
        truncated_data.extend_from_slice(&data[start..start + rank]);
    }
    Tensor::from_data(truncated_data, vec![rows, rank], matrix.device())
}
/// n-mode product: multiplies `matrix` (transposed when `transpose` is set)
/// into `tensor` along `mode`. The result has the same shape as `tensor`
/// except that dimension `mode` becomes the matrix's output dimension.
fn mode_product(
    tensor: &Tensor<f32>,
    matrix: &Tensor<f32>,
    mode: usize,
    transpose: bool,
) -> TorshResult<Tensor<f32>> {
    let tensor_shape_binding = tensor.shape();
    let tensor_dims = tensor_shape_binding.dims();
    let matrix_shape_binding = matrix.shape();
    let matrix_dims = matrix_shape_binding.dims();
    if matrix_dims.len() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "matrix must be 2D for mode product",
            "mode_product",
        ));
    }
    // With transpose the matrix contracts over its rows; otherwise over
    // its columns.
    let (expected_size, output_size) = if transpose {
        (matrix_dims[0], matrix_dims[1])
    } else {
        (matrix_dims[1], matrix_dims[0])
    };
    if tensor_dims[mode] != expected_size {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "tensor mode {} size {} doesn't match matrix dimension {}",
                mode, tensor_dims[mode], expected_size
            ),
            "mode_product",
        ));
    }
    // Multiply against the mode-n unfolding, then fold the matrix result
    // back into tensor form. bmm requires a batch dimension, hence the
    // unsqueeze/squeeze pair.
    let unfolded = unfold_tensor(tensor, mode)?;
    let lhs = if transpose {
        matrix.transpose(0, 1)?
    } else {
        matrix.clone()
    };
    let result_matrix =
        crate::linalg::bmm(&lhs.unsqueeze(0)?, &unfolded.unsqueeze(0)?)?.squeeze(0)?;
    let mut result_shape = tensor_dims.to_vec();
    result_shape[mode] = output_size;
    fold_tensor(&result_matrix, mode, &result_shape)
}
/// Inverse of `unfold_tensor`: reshapes a mode-n matricization back into a
/// tensor with `target_shape`.
///
/// The matrix rows correspond to `target_shape[mode]` and its columns to the
/// remaining dimensions in their original order, so we reshape to the
/// mode-first layout `[mode_size, other dims...]` and then permute back to
/// the target ordering.
fn fold_tensor(
    matrix: &Tensor<f32>,
    mode: usize,
    target_shape: &[usize],
) -> TorshResult<Tensor<f32>> {
    let ndim = target_shape.len();
    if mode >= ndim {
        return Err(TorshError::invalid_argument_with_context(
            &format!(
                "mode {} out of range for target shape with {} dimensions",
                mode, ndim
            ),
            "fold_tensor",
        ));
    }
    // Mode-first intermediate shape, matching unfold_tensor's layout.
    let mut intermediate_shape = vec![target_shape[mode]];
    for (i, &size) in target_shape.iter().enumerate() {
        if i != mode {
            intermediate_shape.push(size);
        }
    }
    let reshaped = matrix.reshape(
        &intermediate_shape
            .iter()
            .map(|&x| x as i32)
            .collect::<Vec<_>>(),
    )?;
    // unfold_tensor permuted with P = [mode, other dims...] (torch-style
    // gather semantics: out dim k = in dim P[k], which is what unfold_tensor
    // relies on). To undo P we must permute with P^{-1}, i.e.
    // inv_perm[d] = position of original dim d in the mode-first layout.
    //
    // BUG FIX: the previous version inverted inv_perm a second time, which
    // re-applied P instead of undoing it, yielding wrongly ordered dimensions
    // for ndim >= 3 whenever P is not self-inverse (e.g. mode == ndim - 1).
    let mut inv_perm = vec![0i32; ndim];
    inv_perm[mode] = 0;
    let mut idx = 1;
    for i in 0..ndim {
        if i != mode {
            inv_perm[i] = idx;
            idx += 1;
        }
    }
    reshaped.permute(&inv_perm)
}
/// CP (CANDECOMP/PARAFAC) decomposition via alternating least squares (ALS).
///
/// Approximates `input` as a sum of `rank` rank-one tensors. Returns the
/// per-component weights (column norms pulled out of the factors by
/// `normalize_factors`) and one `[dim, rank]` factor matrix per mode.
///
/// Runs exactly `max_iter` ALS sweeps — there is no convergence check.
///
/// # Errors
/// Fails if the tensor has fewer than 2 dimensions or `rank == 0`.
pub fn cp_decomposition(
    input: &Tensor<f32>,
    rank: usize,
    max_iter: usize,
) -> TorshResult<(Tensor<f32>, Vec<Tensor<f32>>)> {
    let input_shape = input.shape();
    let shape = input_shape.dims();
    let ndim = shape.len();
    if ndim < 2 {
        return Err(TorshError::invalid_argument_with_context(
            "CP decomposition requires at least 2D tensor",
            "cp_decomposition",
        ));
    }
    if rank == 0 {
        return Err(TorshError::invalid_argument_with_context(
            "rank must be positive",
            "cp_decomposition",
        ));
    }
    // Random initialization of each factor matrix in [-0.5, 0.5).
    // NOTE(review): thread_rng is unseeded, so results are nondeterministic
    // across runs — confirm that is acceptable for callers.
    use scirs2_core::random::thread_rng;
    let mut rng = thread_rng();
    let mut factors: Vec<Tensor<f32>> = Vec::with_capacity(ndim);
    for &size in shape.iter() {
        let factor_data: Vec<f32> = (0..size * rank).map(|_| rng.gen_range(-0.5..0.5)).collect();
        let factor = Tensor::from_data(factor_data, vec![size, rank], input.device())?;
        factors.push(factor);
    }
    // ALS sweeps: update one mode at a time while holding the others fixed.
    for _iter in 0..max_iter {
        for mode in 0..ndim {
            // Khatri-Rao product of all factors except the current mode.
            let kr = khatri_rao_product_except(&factors, mode)?;
            let unfolded = unfold_tensor(input, mode)?;
            // Gram matrix kr^T * kr ([rank, rank]). bmm needs a batch
            // dimension, hence the unsqueeze/squeeze pairs throughout.
            let kr_t = kr.transpose(0, 1)?;
            let kr_t_unsq = kr_t.unsqueeze(0)?;
            let kr_unsq = kr.unsqueeze(0)?;
            let gram = crate::linalg::bmm(&kr_t_unsq, &kr_unsq)?;
            let gram_squeezed = gram.squeeze(0)?;
            // Regularize: add 1e-6 to the diagonal so the Gram matrix stays
            // invertible even when it is near-singular.
            let mut gram_data = gram_squeezed.data()?.to_vec();
            let gram_shape = gram_squeezed.shape();
            let rank_val = gram_shape.dims()[0];
            for i in 0..rank_val {
                gram_data[i * rank_val + i] += 1e-6;
            }
            let gram_reg = Tensor::from_data(gram_data, vec![rank_val, rank_val], input.device())?;
            let gram_inv = crate::linalg::inv(&gram_reg)?;
            // Least-squares update: factor = X_(mode) * kr * (kr^T kr)^-1.
            let unfolded_unsq = unfolded.unsqueeze(0)?;
            let unfolded_kr = crate::linalg::bmm(&unfolded_unsq, &kr_unsq)?;
            let unfolded_kr_squeezed = unfolded_kr.squeeze(0)?;
            let unfolded_kr_unsq = unfolded_kr_squeezed.unsqueeze(0)?;
            let gram_inv_unsq = gram_inv.unsqueeze(0)?;
            let new_factor_result = crate::linalg::bmm(&unfolded_kr_unsq, &gram_inv_unsq)?;
            let new_factor = new_factor_result.squeeze(0)?;
            factors[mode] = new_factor;
        }
    }
    // Pull column norms out of the factors into an explicit weight vector.
    let weights = normalize_factors(&mut factors)?;
    Ok((weights, factors))
}
/// Chained Khatri-Rao product of every factor matrix except `except_mode`,
/// taken in increasing mode order (matching `unfold_tensor`'s column layout).
///
/// Assumes `factors` is non-empty and shares a common rank (second dim);
/// callers in this module guarantee at least two factors.
fn khatri_rao_product_except(
    factors: &[Tensor<f32>],
    except_mode: usize,
) -> TorshResult<Tensor<f32>> {
    let first_shape = factors[0].shape();
    let rank = first_shape.dims()[1];
    let remaining: Vec<&Tensor<f32>> = factors
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != except_mode)
        .map(|(_, f)| f)
        .collect();
    if remaining.len() == 1 {
        return Ok(remaining[0].clone());
    }
    let mut product = remaining[0].clone();
    for factor in &remaining[1..] {
        product = khatri_rao_product(&product, factor)?;
    }
    // Sanity check: chaining must preserve the shared rank.
    let product_shape = product.shape();
    if product_shape.dims()[1] != rank {
        return Err(TorshError::InvalidOperation(
            "Khatri-Rao product rank mismatch (khatri_rao_product_except)".to_string(),
        ));
    }
    Ok(product)
}
/// Column-wise Khatri-Rao product of two 2D matrices with matching column
/// counts: each output column is the Kronecker product of the corresponding
/// input columns. Result shape is `[rows_a * rows_b, cols]`.
///
/// NOTE(review): entries are pushed column-by-column (outer loop over `col`),
/// while other helpers in this file (`truncate_matrix`, `unfold_tensor`)
/// treat `from_data` buffers as row-major — verify the intended element
/// order; the unit test below pins the current order.
fn khatri_rao_product(a: &Tensor<f32>, b: &Tensor<f32>) -> TorshResult<Tensor<f32>> {
    let a_shape = a.shape();
    let dims_a = a_shape.dims();
    let b_shape = b.shape();
    let dims_b = b_shape.dims();
    if dims_a.len() != 2 || dims_b.len() != 2 {
        return Err(TorshError::invalid_argument_with_context(
            "Khatri-Rao product requires 2D tensors",
            "khatri_rao_product",
        ));
    }
    let (rows_a, cols_a) = (dims_a[0], dims_a[1]);
    let (rows_b, cols_b) = (dims_b[0], dims_b[1]);
    if cols_a != cols_b {
        return Err(TorshError::invalid_argument_with_context(
            &format!("column dimensions must match: {} vs {}", cols_a, cols_b),
            "khatri_rao_product",
        ));
    }
    let data_a = a.data()?;
    let data_b = b.data()?;
    let mut result_data = Vec::with_capacity(rows_a * rows_b * cols_a);
    for col in 0..cols_a {
        for i in 0..rows_a {
            // Scale b's column by one entry of a's column at a time.
            let scale = data_a[i * cols_a + col];
            for j in 0..rows_b {
                result_data.push(scale * data_b[j * cols_b + col]);
            }
        }
    }
    Tensor::from_data(result_data, vec![rows_a * rows_b, cols_a], a.device())
}
/// Rescales every factor column to unit L2 norm, multiplying the pulled-out
/// norms into a per-component weight vector (one weight per rank column,
/// accumulated across all modes). Columns with norm <= 1e-10 are left
/// untouched to avoid dividing by zero.
///
/// Returns the weights as a 1D tensor of length `rank`.
fn normalize_factors(factors: &mut [Tensor<f32>]) -> TorshResult<Tensor<f32>> {
    let first_shape = factors[0].shape();
    let rank = first_shape.dims()[1];
    let mut weights = vec![1.0f32; rank];
    for factor in factors.iter_mut() {
        let shape_binding = factor.shape();
        let dims = shape_binding.dims();
        let (rows, cols) = (dims[0], dims[1]);
        let mut values = factor.data()?.to_vec();
        for col in 0..cols {
            // L2 norm of this column (row-major layout: stride = cols).
            let norm = (0..rows)
                .map(|row| {
                    let v = values[row * cols + col];
                    v * v
                })
                .sum::<f32>()
                .sqrt();
            if norm > 1e-10 {
                weights[col] *= norm;
                for row in 0..rows {
                    values[row * cols + col] /= norm;
                }
            }
        }
        *factor = Tensor::from_data(values, vec![rows, cols], factor.device())?;
    }
    Tensor::from_data(weights, vec![rank], factors[0].device())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Mode-0 unfolding of a 2x2x2 tensor keeps mode 0 as rows and flattens
    // the remaining two modes into 4 columns.
    #[test]
    fn test_unfold_tensor() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_data(data, vec![2, 2, 2], torsh_core::device::DeviceType::Cpu)
            .expect("failed to create tensor");
        let unfolded = unfold_tensor(&tensor, 0).expect("unfold failed");
        assert_eq!(unfolded.shape().dims(), &[2, 4]);
    }

    // Rank-(1,1,1) Tucker of a 2x2x2 tensor: core collapses to a single
    // element and each factor is a 2x1 column. Only shapes are checked, not
    // reconstruction accuracy.
    #[test]
    fn test_tucker_decomposition() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = Tensor::from_data(data, vec![2, 2, 2], torsh_core::device::DeviceType::Cpu)
            .expect("failed to create tensor");
        let (core, factors) =
            tucker_decomposition(&tensor, Some(&[1, 1, 1])).expect("tucker decomposition failed");
        assert_eq!(core.shape().dims(), &[1, 1, 1]);
        assert_eq!(factors.len(), 3);
        assert_eq!(factors[0].shape().dims(), &[2, 1]);
        assert_eq!(factors[1].shape().dims(), &[2, 1]);
        assert_eq!(factors[2].shape().dims(), &[2, 1]);
    }

    // Pins the element order produced by khatri_rao_product: the first four
    // flat entries are column 0 of a crossed with column 0 of b
    // (1*5, 1*7, 3*5, 3*7).
    #[test]
    fn test_khatri_rao_product() {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");
        let b = Tensor::from_data(
            vec![5.0, 6.0, 7.0, 8.0],
            vec![2, 2],
            torsh_core::device::DeviceType::Cpu,
        )
        .expect("failed to create tensor");
        let result = khatri_rao_product(&a, &b).expect("khatri-rao failed");
        assert_eq!(result.shape().dims(), &[4, 2]);
        let result_data = result.data().expect("failed to get data");
        assert_relative_eq!(result_data[0], 5.0, epsilon = 1e-6);
        assert_relative_eq!(result_data[1], 7.0, epsilon = 1e-6);
        assert_relative_eq!(result_data[2], 15.0, epsilon = 1e-6);
        assert_relative_eq!(result_data[3], 21.0, epsilon = 1e-6);
    }

    // Rank-1 CP of a 2x2 matrix over 10 ALS sweeps: checks output shapes
    // only (values are nondeterministic due to random initialization).
    #[test]
    fn test_cp_decomposition() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = Tensor::from_data(data, vec![2, 2], torsh_core::device::DeviceType::Cpu)
            .expect("failed to create tensor");
        let (weights, factors) = cp_decomposition(&tensor, 1, 10).expect("cp decomposition failed");
        assert_eq!(weights.shape().dims(), &[1]);
        assert_eq!(factors.len(), 2);
        assert_eq!(factors[0].shape().dims(), &[2, 1]);
        assert_eq!(factors[1].shape().dims(), &[2, 1]);
    }
}