use crate::error::{KernelError, Result};
use crate::types::Kernel;
/// Computes the kernel-target alignment between a Gram matrix and a label vector.
///
/// The alignment is the cosine similarity, under the Frobenius inner product, between
/// `kernel_matrix` (K) and the ideal target kernel `y yᵀ` built from `labels` (y):
///
/// `A(K, y) = <K, y yᵀ>_F / (‖K‖_F · ‖y yᵀ‖_F)`
///
/// The ideal kernel is never materialized. `<K, y yᵀ>_F = Σᵢⱼ yᵢ Kᵢⱼ yⱼ` and
/// `‖y yᵀ‖_F = Σᵢ yᵢ²`, so only O(1) extra memory is needed instead of an n×n matrix.
///
/// # Returns
/// The alignment value, or `0.0` when either `K` or `y yᵀ` has a zero Frobenius norm.
///
/// # Errors
/// - [`KernelError::ComputationError`] if the matrix is empty or not square.
/// - [`KernelError::DimensionMismatch`] if `labels.len()` differs from the matrix size.
pub fn kernel_target_alignment(kernel_matrix: &[Vec<f64>], labels: &[f64]) -> Result<f64> {
    let n = kernel_matrix.len();
    if n == 0 {
        return Err(KernelError::ComputationError(
            "Kernel matrix cannot be empty".to_string(),
        ));
    }
    if labels.len() != n {
        return Err(KernelError::DimensionMismatch {
            expected: vec![n],
            got: vec![labels.len()],
            context: "kernel-target alignment".to_string(),
        });
    }
    if kernel_matrix.iter().any(|row| row.len() != n) {
        return Err(KernelError::ComputationError(
            "Kernel matrix must be square".to_string(),
        ));
    }

    // <K, y yᵀ>_F = Σᵢ yᵢ (Σⱼ Kᵢⱼ yⱼ), i.e. the quadratic form yᵀ K y.
    let inner_product: f64 = kernel_matrix
        .iter()
        .zip(labels)
        .map(|(row, &y_i)| {
            let k_y: f64 = row
                .iter()
                .zip(labels)
                .map(|(&k_ij, &y_j)| k_ij * y_j)
                .sum();
            y_i * k_y
        })
        .sum();

    let k_norm = frobenius_norm(kernel_matrix);
    // ‖y yᵀ‖_F = sqrt(Σᵢⱼ yᵢ² yⱼ²) = Σᵢ yᵢ²; this avoids squaring the products yᵢ yⱼ.
    let y_norm: f64 = labels.iter().map(|&y| y * y).sum();
    if k_norm == 0.0 || y_norm == 0.0 {
        return Ok(0.0);
    }
    Ok(inner_product / (k_norm * y_norm))
}
/// Returns the Frobenius norm of `matrix`: the square root of the sum of all squared entries.
///
/// Rows may have different lengths; every entry present contributes. An empty matrix has
/// norm zero.
fn frobenius_norm(matrix: &[Vec<f64>]) -> f64 {
    let sum_of_squares: f64 = matrix.iter().flatten().map(|&entry| entry * entry).sum();
    sum_of_squares.sqrt()
}
/// Converts a Gram matrix into the matrix of pairwise feature-space distances.
///
/// Each entry is `d(i, j) = sqrt(K[i][i] + K[j][j] - 2·K[i][j])`. A negative squared
/// distance (from rounding or a non-PSD input) is clamped to zero before the square root.
/// Every `(i, j)` is computed from `K[i][j]` directly, so an asymmetric input can give an
/// asymmetric result.
///
/// An empty input yields an empty output.
///
/// # Errors
/// [`KernelError::ComputationError`] if the matrix is not square.
pub fn distances_from_kernel(kernel_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    let size = kernel_matrix.len();
    let is_square = kernel_matrix.iter().all(|row| row.len() == size);
    if !is_square {
        return Err(KernelError::ComputationError(
            "Kernel matrix must be square".to_string(),
        ));
    }

    // Diagonal entries K[i][i], reused for every pair.
    let self_similarity: Vec<f64> = kernel_matrix
        .iter()
        .enumerate()
        .map(|(i, row)| row[i])
        .collect();

    let distances: Vec<Vec<f64>> = kernel_matrix
        .iter()
        .zip(&self_similarity)
        .map(|(row, &k_ii)| {
            row.iter()
                .zip(&self_similarity)
                .map(|(&k_ij, &k_jj)| (k_ii + k_jj - 2.0 * k_ij).max(0.0).sqrt())
                .collect()
        })
        .collect();
    Ok(distances)
}
/// Checks cheap necessary conditions for `kernel_matrix` to be a valid Gram matrix.
///
/// A matrix passes when all of the following hold:
/// - it is square;
/// - every entry is finite;
/// - it is symmetric within `tolerance` (`|K[i][j] - K[j][i]| <= tolerance`);
/// - no diagonal entry is below `-tolerance`, since a positive semi-definite matrix has a
///   non-negative diagonal.
///
/// This is **not** a full positive semi-definiteness test. No eigen-decomposition is
/// performed, so a symmetric matrix with a negative eigenvalue can still pass. `tolerance`
/// is expected to be non-negative. An empty matrix is considered valid.
///
/// # Errors
/// Never returns `Err`; every invalid input, including a non-square matrix, yields
/// `Ok(false)`.
pub fn is_valid_kernel_matrix(kernel_matrix: &[Vec<f64>], tolerance: f64) -> Result<bool> {
    let n = kernel_matrix.len();
    if kernel_matrix.iter().any(|row| row.len() != n) {
        return Ok(false);
    }
    // Reject NaN/inf explicitly: every comparison involving NaN is false (and inf - inf is
    // NaN), so the symmetry test below would otherwise let such entries through.
    if kernel_matrix.iter().flatten().any(|x| !x.is_finite()) {
        return Ok(false);
    }
    for (i, row) in kernel_matrix.iter().enumerate() {
        if row[i] < -tolerance {
            return Ok(false);
        }
        // Only the strict upper triangle needs comparing against its mirror.
        for (j, &k_ij) in row.iter().enumerate().skip(i + 1) {
            if (k_ij - kernel_matrix[j][i]).abs() > tolerance {
                return Ok(false);
            }
        }
    }
    Ok(true)
}
/// Estimates the effective rank of a Gram matrix from its diagonal.
///
/// The diagonal entries are sorted in descending order and accumulated until the running
/// sum reaches `variance_threshold` of their total. The number of entries needed is
/// returned. This heuristic uses the diagonal as a stand-in for the eigenvalue spectrum;
/// no eigen-decomposition is performed.
///
/// Returns `0` for an empty matrix (checked before the threshold is validated) or when the
/// diagonal sums to exactly zero. Returns `n` if the threshold is never reached, for example
/// when the total is NaN.
///
/// # Errors
/// - [`KernelError::InvalidParameter`] if `variance_threshold` is outside `[0, 1]` or NaN.
/// - [`KernelError::ComputationError`] if the matrix is not square.
pub fn estimate_kernel_rank(kernel_matrix: &[Vec<f64>], variance_threshold: f64) -> Result<usize> {
    let n = kernel_matrix.len();
    if n == 0 {
        return Ok(0);
    }
    let threshold_in_range = (0.0..=1.0).contains(&variance_threshold);
    if !threshold_in_range {
        return Err(KernelError::InvalidParameter {
            parameter: "variance_threshold".to_string(),
            value: variance_threshold.to_string(),
            reason: "must be in range [0, 1]".to_string(),
        });
    }
    let is_square = kernel_matrix.iter().all(|row| row.len() == n);
    if !is_square {
        return Err(KernelError::ComputationError(
            "Kernel matrix must be square".to_string(),
        ));
    }

    let mut spectrum: Vec<f64> = kernel_matrix
        .iter()
        .enumerate()
        .map(|(i, row)| row[i])
        .collect();
    // Descending order; incomparable values (NaN) are treated as ties.
    spectrum.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    let total: f64 = spectrum.iter().sum();
    if total == 0.0 {
        return Ok(0);
    }

    // Index of the first prefix whose running sum reaches the requested share.
    let first_sufficient = spectrum
        .iter()
        .scan(0.0, |running, &value| {
            *running += value;
            Some(*running)
        })
        .position(|running| running / total >= variance_threshold);
    Ok(first_sufficient.map_or(n, |index| index + 1))
}
/// Builds the Gram matrix of `data` under `kernel`.
///
/// This is a thin convenience wrapper that forwards to the kernel's `compute_matrix`.
///
/// # Errors
/// Propagates any error returned by `kernel.compute_matrix`.
pub fn compute_gram_matrix(data: &[Vec<f64>], kernel: &dyn Kernel) -> Result<Vec<Vec<f64>>> {
    let gram = kernel.compute_matrix(data)?;
    Ok(gram)
}
/// Scales every row of `data` to unit Euclidean (L2) norm.
///
/// A row whose norm is exactly zero is copied unchanged instead of being divided by zero.
/// An empty input yields an empty output.
///
/// # Errors
/// Never returns `Err`; the `Result` wrapper is part of the public signature.
pub fn normalize_rows(data: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
    let unit_rows: Vec<Vec<f64>> = data
        .iter()
        .map(|row| {
            let length = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
            if length == 0.0 {
                row.clone()
            } else {
                row.iter().map(|&x| x / length).collect()
            }
        })
        .collect();
    Ok(unit_rows)
}
/// Estimates an RBF `gamma` with the median heuristic in the feature space of `kernel`.
///
/// Pairwise distances come from the Gram matrix as
/// `d(i, j) = sqrt(K[i][i] + K[j][j] - 2·K[i][j])`; negative squared distances are clamped
/// to zero. With a plain dot-product kernel this is the Euclidean distance between rows.
/// Zero distances (duplicate points) are skipped. The median `m` of the collected distances
/// is converted to `gamma = 1 / (2·m²)`. The return value is this `gamma`, not `m`.
///
/// # Parameters
/// - `data`: samples, one per row.
/// - `kernel`: the kernel whose induced distance is used. Its Gram matrix is assumed to be
///   `n × n`.
/// - `sample_size`: maximum number of non-zero distances to collect. It defaults to all
///   `n(n-1)/2` pairs. Pairs are visited in row-major order `(0,1), (0,2), …, (1,2), …`, so
///   a limit keeps the *first* pairs rather than a random sample. `Some(0)` is treated as
///   `Some(1)`.
///
/// # Errors
/// - [`KernelError::ComputationError`] if fewer than two samples are given, or if every
///   visited pairwise distance is zero.
/// - Any error returned by `kernel.compute_matrix`.
pub fn median_heuristic_bandwidth(
    data: &[Vec<f64>],
    kernel: &dyn Kernel,
    sample_size: Option<usize>,
) -> Result<f64> {
    let n = data.len();
    if n < 2 {
        return Err(KernelError::ComputationError(
            "Need at least 2 samples for bandwidth estimation".to_string(),
        ));
    }
    let total_pairs = n * (n - 1) / 2;
    // Clamp to at least one. A limit of 0 would stop after the very first pair and could
    // report "all distances are zero" even when later pairs are distinct.
    let sample_size = sample_size.unwrap_or(total_pairs).max(1);

    let gram_matrix = kernel.compute_matrix(data)?;
    let diagonal: Vec<f64> = (0..n).map(|i| gram_matrix[i][i]).collect();

    // `min` keeps a huge caller-supplied limit from over-allocating.
    let mut distances = Vec::with_capacity(sample_size.min(total_pairs));
    'pairs: for i in 0..n {
        for j in (i + 1)..n {
            let sq_dist = diagonal[i] + diagonal[j] - 2.0 * gram_matrix[i][j];
            // `f64::max` ignores NaN, so a NaN squared distance becomes 0 and is skipped.
            let dist = sq_dist.max(0.0).sqrt();
            if dist > 0.0 {
                distances.push(dist);
                if distances.len() >= sample_size {
                    break 'pairs;
                }
            }
        }
    }
    if distances.is_empty() {
        return Err(KernelError::ComputationError(
            "All pairwise distances are zero".to_string(),
        ));
    }

    distances.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let mid = distances.len() / 2;
    let median = if distances.len() % 2 == 0 {
        (distances[mid - 1] + distances[mid]) / 2.0
    } else {
        distances[mid]
    };
    Ok(1.0 / (2.0 * median * median))
}
#[cfg(test)]
// Uppercase `K` follows the standard notation for a Gram matrix, and the index loops read like
// the matrix expressions they check; both lints are silenced deliberately.
#[allow(non_snake_case, clippy::needless_range_loop)]
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    #[test]
    fn test_kernel_target_alignment_good() {
        let K = vec![
            vec![1.0, 0.9, 0.1],
            vec![0.9, 1.0, 0.1],
            vec![0.1, 0.1, 1.0],
        ];
        let labels = vec![1.0, 1.0, -1.0];
        let alignment = kernel_target_alignment(&K, &labels).expect("unwrap");
        assert!((0.5..=1.0).contains(&alignment));
        // <K, yyᵀ>_F = 4.4, ‖K‖_F = sqrt(4.66), ‖yyᵀ‖_F = 3.
        let expected = 4.4 / (4.66_f64.sqrt() * 3.0);
        assert!((alignment - expected).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_target_alignment_poor() {
        let K = vec![
            vec![1.0, 0.5, 0.5],
            vec![0.5, 1.0, 0.5],
            vec![0.5, 0.5, 1.0],
        ];
        let labels = vec![1.0, 1.0, -1.0];
        let alignment = kernel_target_alignment(&K, &labels).expect("unwrap");
        assert!(alignment < 0.5);
        // <K, yyᵀ>_F = 2, ‖K‖_F = sqrt(4.5), ‖yyᵀ‖_F = 3.
        let expected = 2.0 / (4.5_f64.sqrt() * 3.0);
        assert!((alignment - expected).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_target_alignment_ideal_kernel_is_perfect() {
        let labels = vec![1.0, 1.0, -1.0];
        let K: Vec<Vec<f64>> = labels
            .iter()
            .map(|&yi| labels.iter().map(|&yj| yi * yj).collect())
            .collect();
        let alignment = kernel_target_alignment(&K, &labels).expect("unwrap");
        assert!((alignment - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_kernel_target_alignment_dimension_mismatch() {
        let K = vec![vec![1.0, 0.5], vec![0.5, 1.0]];
        let labels = vec![1.0, -1.0, 1.0];
        let result = kernel_target_alignment(&K, &labels);
        assert!(result.is_err());
    }

    #[test]
    fn test_kernel_target_alignment_rejects_bad_shapes() {
        assert!(kernel_target_alignment(&[], &[]).is_err());
        let K = vec![vec![1.0, 0.5], vec![0.5]];
        assert!(kernel_target_alignment(&K, &[1.0, -1.0]).is_err());
    }

    #[test]
    fn test_distances_from_kernel() {
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        let distances = distances_from_kernel(&K).expect("unwrap");
        assert!(distances[0][0].abs() < 1e-10);
        assert!(distances[1][1].abs() < 1e-10);
        assert!(distances[2][2].abs() < 1e-10);
        for i in 0..3 {
            for j in 0..3 {
                assert!((distances[i][j] - distances[j][i]).abs() < 1e-10);
            }
        }
        // d(i, j)² = K[i][i] + K[j][j] - 2·K[i][j]
        assert!((distances[0][1] - 0.4_f64.sqrt()).abs() < 1e-10);
        assert!((distances[1][2] - 0.6_f64.sqrt()).abs() < 1e-10);
        assert!((distances[0][2] - 0.8_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_distances_from_kernel_edge_cases() {
        assert!(distances_from_kernel(&[]).expect("unwrap").is_empty());
        let ragged = vec![vec![1.0, 0.5], vec![0.5]];
        assert!(distances_from_kernel(&ragged).is_err());
    }

    #[test]
    fn test_is_valid_kernel_matrix() {
        let K = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.8, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        assert!(is_valid_kernel_matrix(&K, 1e-10).expect("unwrap"));
        let K_bad = vec![
            vec![1.0, 0.8, 0.6],
            vec![0.7, 1.0, 0.7],
            vec![0.6, 0.7, 1.0],
        ];
        assert!(!is_valid_kernel_matrix(&K_bad, 1e-10).expect("unwrap"));
        let K_ragged = vec![vec![1.0, 0.5], vec![0.5]];
        assert!(!is_valid_kernel_matrix(&K_ragged, 1e-10).expect("unwrap"));
        assert!(is_valid_kernel_matrix(&[], 1e-10).expect("unwrap"));
    }

    #[test]
    fn test_estimate_kernel_rank() {
        let K = vec![
            vec![1.0, 0.1, 0.1],
            vec![0.1, 0.5, 0.1],
            vec![0.1, 0.1, 0.2],
        ];
        let rank = estimate_kernel_rank(&K, 0.9).expect("unwrap");
        assert!((1..=3).contains(&rank));
        // Sorted diagonal [1.0, 0.5, 0.2]: cumulative shares ≈ 0.59, 0.88, 1.0.
        assert_eq!(rank, 3);
        assert_eq!(estimate_kernel_rank(&K, 0.5).expect("unwrap"), 1);
        assert!(estimate_kernel_rank(&K, 1.5).is_err());
    }

    #[test]
    fn test_normalize_rows() {
        let data = vec![vec![3.0, 4.0], vec![5.0, 12.0]];
        let normalized = normalize_rows(&data).expect("unwrap");
        for row in &normalized {
            let norm: f64 = row.iter().map(|&x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }
        assert!((normalized[0][0] - 0.6).abs() < 1e-12);
        assert!((normalized[0][1] - 0.8).abs() < 1e-12);
    }

    #[test]
    fn test_normalize_rows_zero_vector() {
        let data = vec![vec![0.0, 0.0], vec![3.0, 4.0]];
        let normalized = normalize_rows(&data).expect("unwrap");
        assert!(normalized[0][0].abs() < 1e-10);
        assert!(normalized[0][1].abs() < 1e-10);
        let norm: f64 = normalized[1].iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_median_heuristic_bandwidth() {
        let data = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
        ];
        let kernel = LinearKernel::new();
        let gamma = median_heuristic_bandwidth(&data, &kernel, None).expect("unwrap");
        assert!(gamma > 0.0);
    }

    #[test]
    fn test_median_heuristic_bandwidth_errors() {
        let kernel = LinearKernel::new();
        assert!(median_heuristic_bandwidth(&[vec![1.0, 2.0]], &kernel, None).is_err());
        // Identical points give identical Gram entries, hence only zero distances.
        let duplicates = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
        assert!(median_heuristic_bandwidth(&duplicates, &kernel, None).is_err());
    }

    #[test]
    fn test_compute_gram_matrix() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let kernel = LinearKernel::new();
        let K = compute_gram_matrix(&data, &kernel).expect("unwrap");
        assert_eq!(K.len(), 3);
        assert_eq!(K[0].len(), 3);
        for i in 0..3 {
            for j in 0..3 {
                assert!((K[i][j] - K[j][i]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_frobenius_norm() {
        let matrix = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let norm = frobenius_norm(&matrix);
        assert!((norm - 30.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_kernel_target_alignment_binary_classification() {
        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("unwrap");
        let data = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![0.2, 0.2],
            vec![5.0, 5.0],
            vec![5.1, 5.1],
            vec![5.2, 5.2],
        ];
        let labels = vec![1.0, 1.0, 1.0, -1.0, -1.0, -1.0];
        let K = kernel.compute_matrix(&data).expect("unwrap");
        let alignment = kernel_target_alignment(&K, &labels).expect("unwrap");
        assert!((0.0..=1.0).contains(&alignment));
        // Two well-separated clusters give a nearly block-diagonal K that matches the labels.
        assert!(alignment > 0.5);
    }
}