use std::fmt;
use std::sync::Arc;
use crate::error::{KernelError, Result};
use crate::types::Kernel;
/// A kernel formed as a weighted sum of base kernels, where the weights are
/// the softmax of a vector of stored logits (one logit per base kernel).
///
/// Cloning copies the logits and the `Arc` handles, so a clone shares the
/// underlying base kernel objects with the original.
#[derive(Clone)]
pub struct LearnedMixtureKernel {
/// The kernels being mixed; the constructor rejects an empty list.
base_kernels: Vec<Arc<dyn Kernel>>,
/// Unnormalized mixture parameters; `weights()` maps them through softmax.
/// The constructor and `set_logits` only accept finite values here.
logits: Vec<f64>,
}
impl LearnedMixtureKernel {
/// Creates a mixture from `base_kernels` and one logit per kernel.
///
/// # Errors
///
/// * `KernelError::InvalidParameter` if `base_kernels` is empty.
/// * `KernelError::DimensionMismatch` if `logits.len()` differs from the
///   number of kernels.
/// * `KernelError::InvalidParameter` naming the first logit that is not
///   finite (NaN or infinite).
pub fn new(base_kernels: Vec<Arc<dyn Kernel>>, logits: Vec<f64>) -> Result<Self> {
    let kernel_count = base_kernels.len();
    if kernel_count == 0 {
        return Err(KernelError::InvalidParameter {
            parameter: "base_kernels".to_string(),
            value: "[]".to_string(),
            reason: "learned mixture requires at least one base kernel".to_string(),
        });
    }

    let logit_count = logits.len();
    if logit_count != kernel_count {
        return Err(KernelError::DimensionMismatch {
            expected: vec![kernel_count],
            got: vec![logit_count],
            context: "LearnedMixtureKernel logits length".to_string(),
        });
    }

    // Report the first offending entry, scanning in index order.
    let first_bad = logits
        .iter()
        .copied()
        .enumerate()
        .find(|(_, value)| !value.is_finite());
    if let Some((idx, bad)) = first_bad {
        return Err(KernelError::InvalidParameter {
            parameter: format!("logits[{}]", idx),
            value: bad.to_string(),
            reason: "logits must be finite".to_string(),
        });
    }

    Ok(Self {
        base_kernels,
        logits,
    })
}
/// Creates a mixture whose logits are all zero, so the softmax assigns every
/// base kernel the same weight.
///
/// # Errors
///
/// Propagates the errors of [`Self::new`]; in particular an empty
/// `base_kernels` list is rejected.
pub fn uniform(base_kernels: Vec<Arc<dyn Kernel>>) -> Result<Self> {
    let zero_logits = vec![0.0; base_kernels.len()];
    Self::new(base_kernels, zero_logits)
}
/// Returns the number of base kernels in the mixture.
pub fn num_kernels(&self) -> usize {
self.base_kernels.len()
}
/// Returns the raw (pre-softmax) logits, one per base kernel.
pub fn logits(&self) -> &[f64] {
    self.logits.as_slice()
}
/// Returns the mixture weights: the softmax of the stored logits, freshly
/// computed (and allocated) on every call.
pub fn weights(&self) -> Vec<f64> {
    let raw = self.logits.as_slice();
    softmax(raw)
}
/// Replaces the stored logits with `new_logits`.
///
/// The stored logits are left untouched when an error is returned.
///
/// # Errors
///
/// * `KernelError::DimensionMismatch` if `new_logits.len()` differs from the
///   number of base kernels.
/// * `KernelError::InvalidParameter` naming the first entry that is not
///   finite.
pub fn set_logits(&mut self, new_logits: Vec<f64>) -> Result<()> {
    let expected_len = self.base_kernels.len();
    if new_logits.len() != expected_len {
        return Err(KernelError::DimensionMismatch {
            expected: vec![expected_len],
            got: vec![new_logits.len()],
            context: "LearnedMixtureKernel::set_logits".to_string(),
        });
    }

    match new_logits.iter().position(|value| !value.is_finite()) {
        Some(idx) => Err(KernelError::InvalidParameter {
            parameter: format!("logits[{}]", idx),
            value: new_logits[idx].to_string(),
            reason: "logits must be finite".to_string(),
        }),
        None => {
            self.logits = new_logits;
            Ok(())
        }
    }
}
/// Performs one gradient-descent update: `logits[i] -= learning_rate * gradient[i]`.
///
/// The update is all-or-nothing: every input is validated and every new
/// logit is computed before anything is stored, so the logits are unchanged
/// whenever an error is returned.
///
/// # Errors
///
/// * `KernelError::DimensionMismatch` if `gradient.len()` differs from the
///   number of logits.
/// * `KernelError::InvalidParameter` if `learning_rate` is not finite, if any
///   gradient entry is not finite, or if the step would produce a non-finite
///   logit (the product of two finite values can still overflow).
pub fn apply_gradient_step(&mut self, gradient: &[f64], learning_rate: f64) -> Result<()> {
    if gradient.len() != self.logits.len() {
        return Err(KernelError::DimensionMismatch {
            expected: vec![self.logits.len()],
            got: vec![gradient.len()],
            context: "LearnedMixtureKernel::apply_gradient_step".to_string(),
        });
    }
    if !learning_rate.is_finite() {
        return Err(KernelError::InvalidParameter {
            parameter: "learning_rate".to_string(),
            value: learning_rate.to_string(),
            reason: "must be finite".to_string(),
        });
    }

    // Validate the whole gradient up front so a bad entry in the middle
    // cannot leave the logits half-updated.
    if let Some(&bad) = gradient.iter().find(|g| !g.is_finite()) {
        return Err(KernelError::InvalidParameter {
            parameter: "gradient".to_string(),
            value: bad.to_string(),
            reason: "gradient entries must be finite".to_string(),
        });
    }

    // Stage the new values; commit only if all of them stay finite, which
    // preserves the invariant enforced by the constructor and `set_logits`.
    let mut updated = Vec::with_capacity(self.logits.len());
    for (i, (&w, &g)) in self.logits.iter().zip(gradient.iter()).enumerate() {
        let next = w - learning_rate * g;
        if !next.is_finite() {
            return Err(KernelError::InvalidParameter {
                parameter: format!("logits[{}]", i),
                value: next.to_string(),
                reason: "gradient step would make logits non-finite".to_string(),
            });
        }
        updated.push(next);
    }

    self.logits = updated;
    Ok(())
}
/// Evaluates every base kernel at `(x, y)`, in kernel order.
///
/// Stops at, and returns, the first error reported by a base kernel.
fn per_kernel_values(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
    self.base_kernels
        .iter()
        .map(|kernel| kernel.compute(x, y))
        .collect()
}
/// Evaluates the mixture at `(x, y)`: the sum over base kernels of
/// `weight_i * k_i(x, y)`, with weights taken from [`Self::weights`].
///
/// # Errors
///
/// Returns the first error produced by a base kernel's `compute`.
pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
    let mix = self.weights();
    self.base_kernels
        .iter()
        .zip(mix.iter())
        .try_fold(0.0, |total, (kernel, &w)| -> Result<f64> {
            let value = kernel.compute(x, y)?;
            Ok(total + w * value)
        })
}
/// Returns the gradient of the mixture value at `(x, y)` with respect to each
/// logit: entry `i` is `p_i * (k_i - k_mix)`, where `p_i` is the softmax
/// weight, `k_i` the i-th base kernel value and `k_mix` the mixture value.
///
/// # Errors
///
/// Returns the first error produced by a base kernel's `compute`.
pub fn gradient_wrt_logits(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
    // Same computation as `evaluate_with_gradient`; only the gradient is kept.
    let (_, grad) = self.evaluate_with_gradient(x, y)?;
    Ok(grad)
}
/// Computes the mixture value at `(x, y)` together with its gradient with
/// respect to the logits, evaluating each base kernel only once.
///
/// Returns `(k_mix, grad)` where `k_mix = sum_i p_i * k_i` and
/// `grad[i] = p_i * (k_i - k_mix)`, with `p` the softmax weights.
///
/// # Errors
///
/// Returns the first error produced by a base kernel's `compute`.
pub fn evaluate_with_gradient(&self, x: &[f64], y: &[f64]) -> Result<(f64, Vec<f64>)> {
    let probs = self.weights();
    let values = self.per_kernel_values(x, y)?;

    let mixture = probs
        .iter()
        .zip(values.iter())
        .map(|(p, k)| p * k)
        .sum::<f64>();

    let mut grad = Vec::with_capacity(probs.len());
    for (p, k) in probs.iter().zip(values.iter()) {
        grad.push(p * (k - mixture));
    }

    Ok((mixture, grad))
}
/// Builds the `xs.len()` by `ys.len()` matrix whose `(i, j)` entry is the
/// mixture kernel evaluated at `(xs[i], ys[j])`.
///
/// The softmax weights are computed once for the whole matrix rather than
/// once per entry; each entry is accumulated exactly as in `evaluate`.
///
/// # Errors
///
/// Returns the first error produced by a base kernel's `compute`; entries are
/// visited row by row.
pub fn compute_gram(&self, xs: &[&[f64]], ys: &[&[f64]]) -> Result<Vec<Vec<f64>>> {
    // Loop-invariant: the weights depend only on the logits.
    let weights = self.weights();

    let mut matrix = Vec::with_capacity(xs.len());
    for &xi in xs {
        let mut row = Vec::with_capacity(ys.len());
        for &yj in ys {
            let mut acc = 0.0;
            for (kernel, &w) in self.base_kernels.iter().zip(weights.iter()) {
                acc += w * kernel.compute(xi, yj)?;
            }
            row.push(acc);
        }
        matrix.push(row);
    }
    Ok(matrix)
}
}
/// Manual `Debug` implementation: base kernels are shown by their `name()`
/// strings, followed by the raw logits.
impl fmt::Debug for LearnedMixtureKernel {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut kernel_names: Vec<&str> = Vec::with_capacity(self.base_kernels.len());
        for kernel in &self.base_kernels {
            kernel_names.push(kernel.name());
        }

        let mut out = f.debug_struct("LearnedMixtureKernel");
        out.field("base_kernels", &kernel_names);
        out.field("logits", &self.logits);
        out.finish()
    }
}
impl Kernel for LearnedMixtureKernel {
fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
self.evaluate(x, y)
}
fn name(&self) -> &str {
"LearnedMixture"
}
fn is_psd(&self) -> bool {
self.base_kernels.iter().all(|k| k.is_psd())
}
}
/// Numerically stable softmax of `logits`.
///
/// The maximum logit is subtracted before exponentiating so large inputs do
/// not overflow. An empty slice yields an empty vector. If the normalizer is
/// not a positive finite number (for example when an input is NaN or
/// infinite), the function falls back to uniform weights `1 / n`.
pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
    let count = logits.len();
    if count == 0 {
        return Vec::new();
    }

    let peak = logits.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    let mut out: Vec<f64> = logits.iter().map(|&z| (z - peak).exp()).collect();
    let total: f64 = out.iter().sum();

    if total.is_finite() && total > 0.0 {
        for e in &mut out {
            *e /= total;
        }
        out
    } else {
        // Degenerate normalizer: fall back to an even split.
        vec![1.0 / count as f64; count]
    }
}