use crate::error::{AprenderError, Result};
use crate::primitives::{Matrix, Vector};
/// Computes the population covariance of two vectors of equal length.
///
/// The estimator divides by `n` (not `n - 1`):
/// `sum((x[i] - mean(x)) * (y[i] - mean(y))) / n`.
///
/// # Errors
///
/// * [`AprenderError::DimensionMismatch`] if `x` and `y` have different lengths.
/// * [`AprenderError::Other`] if the vectors are empty.
pub fn cov(x: &Vector<f32>, y: &Vector<f32>) -> Result<f32> {
    let n = x.len();
    if y.len() != n {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{n} values in x"),
            actual: format!("{} values in y", y.len()),
        });
    }
    if n == 0 {
        return Err(AprenderError::Other(
            "Cannot compute covariance of empty vectors".into(),
        ));
    }

    let count = n as f32;
    let (xs, ys) = (x.as_slice(), y.as_slice());
    let mean = |values: &[f32]| values.iter().sum::<f32>() / count;
    let (mx, my) = (mean(xs), mean(ys));

    // Sum of cross-products of deviations from the respective means.
    let cross_sum = xs
        .iter()
        .zip(ys)
        .map(|(&a, &b)| (a - mx) * (b - my))
        .sum::<f32>();
    Ok(cross_sum / count)
}
/// Computes the population covariance matrix of the columns of `data`.
///
/// Each row of `data` is one observation and each column one variable. Entry
/// `(i, j)` of the returned `p x p` matrix (`p` = number of columns) is
/// `sum_k (data[k][i] - mean_i) * (data[k][j] - mean_j) / n`, where `n` is the
/// number of rows. This is the same divide-by-`n` convention used by [`cov`].
///
/// # Errors
///
/// Returns [`AprenderError::Other`] if `data` has no rows or no columns, or if
/// the resulting matrix cannot be constructed.
pub fn cov_matrix(data: &Matrix<f32>) -> Result<Matrix<f32>> {
    let n = data.n_rows();
    let p = data.n_cols();
    if n == 0 || p == 0 {
        return Err(AprenderError::Other(
            "Cannot compute covariance matrix for empty data".into(),
        ));
    }
    let n_f = n as f32;

    // Per-column means. The iterator form removes the need for the
    // `clippy::needless_range_loop` suppression the indexed loop required.
    let means: Vec<f32> = (0..p)
        .map(|j| (0..n).map(|i| data.get(i, j)).sum::<f32>() / n_f)
        .collect();

    // The matrix is symmetric, so only the lower triangle (j <= i) is computed
    // and each value is mirrored into the upper triangle.
    let mut cov_data = vec![0.0_f32; p * p];
    for i in 0..p {
        for j in 0..=i {
            let cov_sum: f32 = (0..n)
                .map(|k| (data.get(k, i) - means[i]) * (data.get(k, j) - means[j]))
                .sum();
            let cov_val = cov_sum / n_f;
            cov_data[i * p + j] = cov_val;
            cov_data[j * p + i] = cov_val;
        }
    }
    Matrix::from_vec(p, p, cov_data)
        .map_err(|e| AprenderError::Other(format!("Failed to create covariance matrix: {e}")))
}
/// Computes the Pearson correlation coefficient of two equal-length vectors.
///
/// Uses population (divide-by-`n`) moments, the same convention as [`cov`].
/// Floating-point rounding can push the raw ratio marginally outside the
/// mathematically valid range (e.g. `1.0000001` for perfectly correlated
/// data), so the result is clamped to `[-1, 1]`. A NaN result, such as one
/// produced by NaN inputs, is returned unchanged.
///
/// # Errors
///
/// * [`AprenderError::DimensionMismatch`] if `x` and `y` have different lengths.
/// * [`AprenderError::Other`] if the vectors are empty.
/// * [`AprenderError::Other`] if either vector's population standard deviation
///   is below `1e-10`, treated as zero variance.
pub fn corr(x: &Vector<f32>, y: &Vector<f32>) -> Result<f32> {
    let n = x.len();
    if n != y.len() {
        return Err(AprenderError::DimensionMismatch {
            expected: format!("{n} values in x"),
            actual: format!("{} values in y", y.len()),
        });
    }
    if n == 0 {
        return Err(AprenderError::Other(
            "Cannot compute correlation of empty vectors".into(),
        ));
    }
    let x_mean = x.as_slice().iter().sum::<f32>() / n as f32;
    let y_mean = y.as_slice().iter().sum::<f32>() / n as f32;

    // Single pass accumulating the cross-product and both squared-deviation sums.
    let mut cov_sum = 0.0_f32;
    let mut x_var_sum = 0.0_f32;
    let mut y_var_sum = 0.0_f32;
    for (&xi, &yi) in x.as_slice().iter().zip(y.as_slice().iter()) {
        let x_diff = xi - x_mean;
        let y_diff = yi - y_mean;
        cov_sum += x_diff * y_diff;
        x_var_sum += x_diff * x_diff;
        y_var_sum += y_diff * y_diff;
    }
    let x_std = (x_var_sum / n as f32).sqrt();
    let y_std = (y_var_sum / n as f32).sqrt();
    if x_std < 1e-10 || y_std < 1e-10 {
        return Err(AprenderError::Other(
            "Cannot compute correlation when variance is zero".into(),
        ));
    }
    let covariance = cov_sum / n as f32;
    // Clamp away rounding overshoot; `f32::clamp` passes NaN through.
    Ok((covariance / (x_std * y_std)).clamp(-1.0, 1.0))
}
/// Computes the Pearson correlation matrix of the columns of `data`.
///
/// Rows are observations and columns are variables. The result is a `p x p`
/// matrix, where `p` is the number of columns. Per-column means and standard
/// deviations come from `compute_feature_stats`, and the pairwise values from
/// `compute_correlation_values`.
///
/// # Errors
///
/// * [`AprenderError::Other`] if `data` has no rows or no columns.
/// * Propagates the error from `compute_feature_stats` when a column's
///   population standard deviation is below `1e-10`.
/// * [`AprenderError::Other`] if the resulting matrix cannot be constructed.
pub fn corr_matrix(data: &Matrix<f32>) -> Result<Matrix<f32>> {
    let (n, p) = (data.n_rows(), data.n_cols());
    if n == 0 || p == 0 {
        return Err(AprenderError::Other(
            "Cannot compute correlation matrix for empty data".into(),
        ));
    }

    let (means, stds) = compute_feature_stats(data, n, p)?;
    Matrix::from_vec(p, p, compute_correlation_values(data, &means, &stds, n, p))
        .map_err(|e| AprenderError::Other(format!("Failed to create correlation matrix: {e}")))
}
/// Returns the population mean and standard deviation of each of the first `p`
/// columns of `data`, computed over the first `n` rows.
///
/// The returned vectors are indexed by column: `(means, stds)`.
///
/// # Errors
///
/// Returns [`AprenderError::Other`] for the first column whose standard
/// deviation is below `1e-10`, i.e. a column treated as having zero variance.
fn compute_feature_stats(data: &Matrix<f32>, n: usize, p: usize) -> Result<(Vec<f32>, Vec<f32>)> {
    let count = n as f32;
    let mut means = Vec::with_capacity(p);
    let mut stds = Vec::with_capacity(p);

    for j in 0..p {
        let mean = (0..n).map(|row| data.get(row, j)).sum::<f32>() / count;
        let spread = (0..n)
            .map(|row| (data.get(row, j) - mean).powi(2))
            .sum::<f32>();
        let std = (spread / count).sqrt();
        if std < 1e-10 {
            return Err(AprenderError::Other(format!(
                "Feature {j} has zero variance"
            )));
        }
        means.push(mean);
        stds.push(std);
    }

    Ok((means, stds))
}
/// Builds the row-major `p x p` Pearson correlation values for the columns of `data`.
///
/// `means[j]` and `stds[j]` are the population mean and standard deviation of
/// column `j` over the first `n` rows, as supplied by `corr_matrix` from
/// `compute_feature_stats`. The diagonal is set to exactly `1.0`. Off-diagonal
/// values are computed once for the lower triangle and mirrored, since the
/// matrix is symmetric. They are also clamped to `[-1, 1]` so floating-point
/// rounding cannot yield an out-of-range coefficient.
fn compute_correlation_values(
    data: &Matrix<f32>,
    means: &[f32],
    stds: &[f32],
    n: usize,
    p: usize,
) -> Vec<f32> {
    let mut corr_data = vec![0.0_f32; p * p];
    for i in 0..p {
        // Every variable is perfectly correlated with itself.
        corr_data[i * p + i] = 1.0;
        for j in 0..i {
            let cov_sum: f32 = (0..n)
                .map(|k| (data.get(k, i) - means[i]) * (data.get(k, j) - means[j]))
                .sum();
            // Clamp away rounding overshoot; `f32::clamp` passes NaN through.
            let corr_val = (cov_sum / (n as f32 * stds[i] * stds[j])).clamp(-1.0, 1.0);
            corr_data[i * p + j] = corr_val;
            corr_data[j * p + i] = corr_val;
        }
    }
    corr_data
}
#[cfg(test)]
#[path = "covariance_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "tests_covariance_contract.rs"]
mod tests_covariance_contract;