use crate::DType;
use crate::spatial::impl_generic::rotation::rotation_from_matrix_impl;
use crate::spatial::traits::procrustes::ProcrustesResult;
use crate::spatial::{validate_matching_dims, validate_points_2d, validate_points_dtype};
use numr::algorithm::linalg::LinearAlgebraAlgorithms;
use numr::error::{Error, Result};
use numr::ops::{ReduceOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Finds the similarity transform (rotation, optional uniform scale, translation) that best
/// maps `source` onto `target` in the least-squares sense.
///
/// Minimises `sum_i || scale * R * source_i + t - target_i ||^2` with the SVD-based
/// Kabsch construction, using Umeyama's formula for the scale.
///
/// # Arguments
/// * `client` - runtime client providing tensor, reduction and linear-algebra ops.
/// * `source` - `[n, d]` point set to transform.
/// * `target` - `[n, d]` point set to align to; must have exactly the same shape as `source`.
/// * `scaling` - when `true`, estimate a uniform scale factor; otherwise the scale is `1.0`.
/// * `reflection` - when `false` and the unconstrained solution has `det(R) < 0`, the solution
///   is corrected to a proper rotation; when `true`, a reflecting orthogonal matrix is kept.
///
/// # Returns
/// A [`ProcrustesResult`] holding:
/// * the rotation built from `R` via `rotation_from_matrix_impl`;
/// * the translation `t = mean(target) - scale * R * mean(source)`;
/// * the scale;
/// * the transformed source points `scale * source @ R^T + t`;
/// * the disparity, i.e. the sum of squared residuals between the transformed points and
///   `target`.
///
/// # Errors
/// Returns an error if:
/// * the inputs fail dtype/shape validation;
/// * their shapes differ;
/// * `scaling` is requested but the centred source points have zero squared norm (the scale
///   is undefined);
/// * any underlying tensor operation fails.
pub fn procrustes_impl<R, C>(
    client: &C,
    source: &Tensor<R>,
    target: &Tensor<R>,
    scaling: bool,
    reflection: bool,
) -> Result<ProcrustesResult<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + ReduceOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    validate_points_dtype(source.dtype(), "procrustes")?;
    validate_points_dtype(target.dtype(), "procrustes")?;
    validate_points_2d(source.shape(), "procrustes")?;
    validate_points_2d(target.shape(), "procrustes")?;
    validate_matching_dims(source.shape(), target.shape(), "procrustes")?;
    // Point i of `source` corresponds to point i of `target`, so the full shapes
    // (including the point count) must agree.
    if source.shape() != target.shape() {
        return Err(Error::InvalidArgument {
            arg: "source/target",
            reason: format!(
                "Source and target must have same shape. Got {:?} and {:?}",
                source.shape(),
                target.shape()
            ),
        });
    }
    let d = source.shape()[1];
    let device = source.device();

    // Centre both point clouds so the rotation is estimated about their centroids.
    let source_mean = client.mean(source, &[0], true)?;
    let target_mean = client.mean(target, &[0], true)?;
    let source_centered = client.sub(source, &source_mean.broadcast_to(source.shape())?)?;
    let target_centered = client.sub(target, &target_mean.broadcast_to(target.shape())?)?;

    // Cross-covariance H = Xc^T Yc (d x d) and its SVD H = U S V^T.
    let source_t = source_centered.transpose(0, 1)?;
    let h = client.matmul(&source_t, &target_centered)?;
    let svd = client.svd_decompose(&h)?;
    let u = svd.u;
    let s = svd.s;
    let vt = svd.vt;
    let ut = u.transpose(0, 1)?;

    // Unconstrained optimum for column vectors: R = V U^T.
    let v = vt.transpose(0, 1)?;
    let mut r = client.matmul(&v, &ut)?;

    // NOTE(review): host reads below use f64; this assumes validate_points_dtype only
    // admits F64 tensors — confirm.
    let det = LinearAlgebraAlgorithms::det(client, &r)?;
    let det_val: Vec<f64> = det.to_vec();
    // det(R) < 0 means R is a reflection. If reflections are not allowed, use
    // R = V diag(1, ..., 1, -1) U^T, flipping the singular vector paired with the last
    // (smallest, assuming descending order) singular value.
    let reflected = det_val[0] < 0.0 && !reflection;
    if reflected {
        // Negate the last row of V^T (== the last column of V). Working on the SVD's own
        // V^T output avoids depending on the host layout of a transposed view.
        let mut vt_data: Vec<f64> = vt.to_vec();
        let last_row_start = vt_data.len().saturating_sub(d);
        for x in &mut vt_data[last_row_start..] {
            *x = -*x;
        }
        let vt_new = Tensor::<R>::from_slice(&vt_data, &[d, d], device);
        let v_new = vt_new.transpose(0, 1)?;
        r = client.matmul(&v_new, &ut)?;
    }
    let rotation = rotation_from_matrix_impl(client, &r)?;

    // Umeyama scale: tr(S D) / ||Xc||_F^2, with D = diag(1, ..., 1, -1) when the reflection
    // was corrected above and D = I otherwise.
    let scale = if scaling {
        let s_data: Vec<f64> = s.to_vec();
        let mut trace_sd: f64 = s_data.iter().sum();
        if reflected {
            if let Some(&s_min) = s_data.last() {
                trace_sd -= 2.0 * s_min;
            }
        }
        let source_sq = client.mul(&source_centered, &source_centered)?;
        let source_norm_sq = client.sum(&source_sq, &[0, 1], false)?;
        let source_norm_sq_val: Vec<f64> = source_norm_sq.to_vec();
        let norm_sq = source_norm_sq_val[0];
        if norm_sq <= 0.0 {
            return Err(Error::InvalidArgument {
                arg: "source",
                reason: "Cannot estimate scale: all source points coincide (zero variance)"
                    .to_string(),
            });
        }
        trace_sd / norm_sq
    } else {
        1.0
    };

    // t = mean(target) - scale * R * mean(source)
    let source_mean_col = source_mean.reshape(&[d, 1])?;
    let rotated_mean = client.matmul(&r, &source_mean_col)?;
    let rotated_mean = rotated_mean.reshape(&[d])?;
    let scaled_rotated_mean = client.mul_scalar(&rotated_mean, scale)?;
    let target_mean_flat = target_mean.reshape(&[d])?;
    let translation = client.sub(&target_mean_flat, &scaled_rotated_mean)?;

    // Points are stored as rows, so apply R as `source @ R^T`.
    let r_t = r.transpose(0, 1)?;
    let rotated = client.matmul(source, &r_t)?;
    let scaled = client.mul_scalar(&rotated, scale)?;
    let transformed = client.add(&scaled, &translation.broadcast_to(scaled.shape())?)?;

    // Disparity: sum of squared residuals over all points and coordinates.
    let diff = client.sub(&transformed, target)?;
    let diff_sq = client.mul(&diff, &diff)?;
    let disparity_tensor = client.sum(&diff_sq, &[0, 1], false)?;
    let disparity_val: Vec<f64> = disparity_tensor.to_vec();
    let disparity = disparity_val[0];

    Ok(ProcrustesResult {
        rotation,
        translation,
        scale,
        transformed,
        disparity,
    })
}
/// Solves the orthogonal Procrustes problem: finds the orthogonal matrix `R` that minimises
/// the Frobenius norm `||a @ R - b||_F`.
///
/// With the SVD `a^T b = U S V^T`, the minimiser is `R = U V^T`.
///
/// # Arguments
/// * `client` - runtime client providing tensor, reduction and linear-algebra ops.
/// * `a` - 2-D matrix to be mapped (multiplied on the right by `R`).
/// * `b` - 2-D target matrix; must have exactly the same shape as `a`.
///
/// # Returns
/// `(R, residual)`, where `R` is square with side `a.shape()[1]` and `residual` is
/// `||a @ R - b||_F`.
///
/// # Errors
/// Returns an error if `a` and `b` differ in shape, are not 2-D, or if an underlying tensor
/// operation fails.
pub fn orthogonal_procrustes_impl<R, C>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
) -> Result<(Tensor<R>, f64)>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + ReduceOps<R> + LinearAlgebraAlgorithms<R> + RuntimeClient<R>,
{
    if a.shape() != b.shape() {
        return Err(Error::InvalidArgument {
            arg: "a/b",
            reason: format!(
                "A and B must have same shape. Got {:?} and {:?}",
                a.shape(),
                b.shape()
            ),
        });
    }
    // The transpose and matmul below are only meaningful for matrices.
    if a.shape().len() != 2 {
        return Err(Error::InvalidArgument {
            arg: "a/b",
            reason: format!("A and B must be 2-D matrices. Got shape {:?}", a.shape()),
        });
    }

    // M = a^T b, with SVD M = U S V^T.
    let at = a.transpose(0, 1)?;
    let m = client.matmul(&at, b)?;
    let svd = client.svd_decompose(&m)?;
    let u = svd.u;
    let vt = svd.vt;

    // Maximising tr(R^T a^T b) = tr(R^T U S V^T) over orthogonal R gives R = U V^T.
    // (V U^T would be its transpose, i.e. the inverse map.)
    let r = client.matmul(&u, &vt)?;

    // Residual: Frobenius norm of a @ R - b.
    let ar = client.matmul(a, &r)?;
    let diff = client.sub(&ar, b)?;
    let diff_sq = client.mul(&diff, &diff)?;
    let residual_tensor = client.sum(&diff_sq, &[0, 1], false)?;
    let residual_val: Vec<f64> = residual_tensor.to_vec();
    let residual = residual_val[0].sqrt();
    Ok((r, residual))
}