use super::super::{WgpuClient, WgpuRuntime};
use crate::algorithm::linalg::{
LinearAlgebraAlgorithms, validate_linalg_dtype, validate_matrix_2d,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{
BinaryOps, CompareOps, ConditionalOps, LinalgOps, MatmulOps, ReduceOps, UnaryOps, UtilityOps,
};
use crate::runtime::RuntimeClient;
use crate::tensor::Tensor;
/// Computes the covariance matrix of `a` on the WGPU backend.
///
/// Rows of `a` are observations and columns are variables. For an input of shape
/// `[n_samples, n_features]` the result is the `[n_features, n_features]` matrix
///
/// `(a - mean)ᵀ · (a - mean) / (n_samples - ddof)`
///
/// where `mean` is the per-column mean of `a`.
///
/// # Arguments
/// * `client` - WGPU client used to run the tensor ops; its device receives the
///   scalar helper tensors.
/// * `a` - 2-D input matrix, one observation per row.
/// * `ddof` - "delta degrees of freedom" subtracted from `n_samples` in the
///   divisor; `None` means `1`.
///
/// # Errors
/// * Errors from `validate_linalg_dtype` / `validate_matrix_2d` for an invalid
///   dtype or a non-2-D shape.
/// * [`Error::UnsupportedDType`] if `a` is neither `F32` nor `F64`.
/// * [`Error::Internal`] if `n_samples <= ddof`, which would make the divisor
///   zero or negative.
/// * Any error propagated from the underlying tensor ops.
pub fn cov(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    ddof: Option<usize>,
) -> Result<Tensor<WgpuRuntime>> {
    validate_linalg_dtype(a.dtype())?;
    let (n_samples, _n_features) = validate_matrix_2d(a.shape())?;
    let dtype = a.dtype();
    let device = client.device();
    let ddof_val = ddof.unwrap_or(1);

    if dtype != DType::F32 && dtype != DType::F64 {
        return Err(Error::UnsupportedDType { dtype, op: "cov" });
    }
    if n_samples <= ddof_val {
        return Err(Error::Internal(format!(
            "cov: need at least {} samples for ddof={}, got {}",
            // Saturate: a plain `+ 1` overflows when ddof == usize::MAX.
            ddof_val.saturating_add(1),
            ddof_val,
            n_samples
        )));
    }

    // Builds a 0-d tensor of the input dtype, used as a broadcast divisor below.
    let scalar = |value: f64| match dtype {
        DType::F32 => Tensor::<WgpuRuntime>::from_slice(&[value as f32], &[], device),
        DType::F64 => Tensor::<WgpuRuntime>::from_slice(&[value], &[], device),
        // Excluded by the dtype check above.
        _ => unreachable!(),
    };

    // Column means. NOTE(review): assumes the trailing `true` keeps the reduced
    // axis (shape [1, n_features]) so the mean broadcasts over rows in `sub`.
    let sum = client.sum(a, &[0], true)?;
    let mean = client.div(&sum, &scalar(n_samples as f64))?;
    let centered = client.sub(a, &mean)?;

    // Unnormalised covariance Xcᵀ · Xc; both operands are materialised
    // contiguously before the matmul.
    let centered_t = centered.transpose(0, 1)?.contiguous();
    let centered_contig = centered.contiguous();
    let cov_unnorm = client.matmul(&centered_t, &centered_contig)?;

    // n_samples > ddof_val was checked above, so this divisor is positive.
    client.div(&cov_unnorm, &scalar((n_samples - ddof_val) as f64))
}
pub fn corrcoef(client: &WgpuClient, a: &Tensor<WgpuRuntime>) -> Result<Tensor<WgpuRuntime>> {
validate_linalg_dtype(a.dtype())?;
let (n_samples, n_features) = validate_matrix_2d(a.shape())?;
let dtype = a.dtype();
let device = client.device();
if dtype != DType::F32 && dtype != DType::F64 {
return Err(Error::UnsupportedDType {
dtype,
op: "corrcoef",
});
}
if n_samples < 2 {
return Err(Error::Internal(format!(
"corrcoef: need at least 2 samples, got {}",
n_samples
)));
}
if n_features == 0 {
return match dtype {
DType::F32 => Ok(Tensor::<WgpuRuntime>::from_slice::<f32>(
&[],
&[0, 0],
device,
)),
DType::F64 => Ok(Tensor::<WgpuRuntime>::from_slice::<f64>(
&[],
&[0, 0],
device,
)),
_ => unreachable!(),
};
}
let cov_mat = LinearAlgebraAlgorithms::cov(client, a, Some(1))?;
let variances = LinalgOps::diag(client, &cov_mat)?;
let std_devs = client.sqrt(&variances)?;
let std_col = std_devs.reshape(&[n_features, 1])?; let std_row = std_devs.reshape(&[1, n_features])?; let std_outer = client.mul(&std_col, &std_row)?;
let eps = match dtype {
DType::F32 => Tensor::<WgpuRuntime>::from_slice(&[f32::EPSILON], &[], device),
DType::F64 => Tensor::<WgpuRuntime>::from_slice(&[f64::EPSILON], &[], device),
_ => unreachable!(),
};
let mask = client.gt(&std_outer, &eps)?;
let std_outer_safe = client.add(&std_outer, &eps)?;
let corr_raw = client.div(&cov_mat, &std_outer_safe)?;
let corr_clamped = client.clamp(&corr_raw, -1.0, 1.0)?;
let zero = match dtype {
DType::F32 => Tensor::<WgpuRuntime>::from_slice(&[0.0f32], &[], device),
DType::F64 => Tensor::<WgpuRuntime>::from_slice(&[0.0f64], &[], device),
_ => unreachable!(),
};
let corr_masked = client.where_cond(&mask, &corr_clamped, &zero)?;
let std_positive = client.gt(&std_devs, &eps)?;
let identity = client.eye(n_features, Some(n_features), dtype)?;
let one = match dtype {
DType::F32 => Tensor::<WgpuRuntime>::from_slice(&[1.0f32], &[], device),
DType::F64 => Tensor::<WgpuRuntime>::from_slice(&[1.0f64], &[], device),
_ => unreachable!(),
};
let diag_ones = client.where_cond(&std_positive, &one, &zero)?; let diag_matrix = LinalgOps::diagflat(client, &diag_ones)?;
let corr_diag_zeroed = client.mul(&corr_masked, &identity)?;
let corr_no_diag = client.sub(&corr_masked, &corr_diag_zeroed)?;
let corr_final = client.add(&corr_no_diag, &diag_matrix)?;
Ok(corr_final)
}