use std::fmt;
use crate::deep_kernel::feature_extractor::NeuralFeatureMap;
use crate::error::Result;
use crate::types::Kernel;
/// A kernel evaluated on learned features.
///
/// For inputs `x` and `y` it computes `base(extractor(x), extractor(y))`: each
/// input is first mapped through the neural feature extractor, and the base
/// kernel is then applied to the two resulting feature vectors.
#[derive(Clone, Debug)]
pub struct DeepKernel<F, K>
where
    F: NeuralFeatureMap,
    K: Kernel,
{
    /// Maps raw inputs to feature vectors.
    extractor: F,
    /// Kernel applied to pairs of feature vectors.
    base: K,
}
impl<F: NeuralFeatureMap, K: Kernel> DeepKernel<F, K> {
    /// Builds a deep kernel that maps inputs through `extractor` and then
    /// evaluates `base` on the resulting feature vectors.
    pub fn new(extractor: F, base: K) -> Self {
        Self { extractor, base }
    }

    /// Returns a shared reference to the feature extractor.
    pub fn feature_extractor(&self) -> &F {
        &self.extractor
    }

    /// Returns a mutable reference to the feature extractor.
    pub fn feature_extractor_mut(&mut self) -> &mut F {
        &mut self.extractor
    }

    /// Returns the base kernel that is applied in feature space.
    pub fn base_kernel(&self) -> &K {
        &self.base
    }

    /// Maps a single input through the feature extractor.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the extractor's `forward` pass.
    pub fn features(&self, x: &[f64]) -> Result<Vec<f64>> {
        self.extractor.forward(x)
    }

    /// Evaluates `base(φ(x), φ(y))`, where `φ` is the feature extractor.
    ///
    /// # Errors
    ///
    /// Returns the first error from the forward pass of `x`, then of `y`, then
    /// from the base kernel.
    pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        let fx = self.features(x)?;
        let fy = self.features(y)?;
        self.base.compute(&fx, &fy)
    }

    /// Computes the cross Gram matrix `G[i][j] = k(xs[i], ys[j])`.
    ///
    /// The result has `xs.len()` rows of `ys.len()` entries each. Every input
    /// is passed through the extractor exactly once, rather than once per
    /// matrix entry.
    ///
    /// # Errors
    ///
    /// Returns the first error encountered, in this order:
    /// 1. feature extraction for `xs`;
    /// 2. feature extraction for `ys`;
    /// 3. base-kernel evaluations, in row-major order.
    pub fn compute_gram(&self, xs: &[&[f64]], ys: &[&[f64]]) -> Result<Vec<Vec<f64>>> {
        let fx = self.forward_all(xs)?;
        let fy = self.forward_all(ys)?;
        fx.iter()
            .map(|a| {
                fy.iter()
                    .map(|b| self.base.compute(a, b))
                    .collect::<Result<Vec<f64>>>()
            })
            .collect()
    }

    /// Computes the square Gram matrix `G[i][j] = k(xs[i], xs[j])`.
    ///
    /// Only the upper triangle (`j >= i`) is evaluated, and each value is
    /// mirrored into `G[j][i]`. This roughly halves the number of base-kernel
    /// calls.
    ///
    /// It also assumes the base kernel is symmetric (`base(a, b) == base(b, a)`).
    /// For a non-symmetric base, the lower triangle will differ from what
    /// [`Self::compute_gram`] returns.
    ///
    /// # Errors
    ///
    /// Returns the first feature-extraction error, or otherwise the first
    /// base-kernel error in upper-triangle row-major order.
    pub fn compute_symmetric_gram(&self, xs: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
        let fx = self.forward_all(xs)?;
        let n = fx.len();
        let mut matrix = vec![vec![0.0; n]; n];
        for (i, a) in fx.iter().enumerate() {
            for (j, b) in fx.iter().enumerate().skip(i) {
                let v = self.base.compute(a, b)?;
                matrix[i][j] = v;
                matrix[j][i] = v;
            }
        }
        Ok(matrix)
    }

    /// Runs the feature extractor over every input, in order, stopping at the
    /// first error.
    fn forward_all<T: AsRef<[f64]>>(&self, xs: &[T]) -> Result<Vec<Vec<f64>>> {
        xs.iter()
            .map(|x| self.extractor.forward(x.as_ref()))
            .collect()
    }
}
/// Lets a [`DeepKernel`] be used anywhere a [`Kernel`] is expected.
impl<F, K> Kernel for DeepKernel<F, K>
where
    F: NeuralFeatureMap,
    K: Kernel,
{
    /// Delegates to the inherent `DeepKernel::evaluate`.
    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
        Self::evaluate(self, x, y)
    }

    /// Fixed identifier for this kernel type.
    fn name(&self) -> &str {
        "DeepKernel"
    }

    /// PSD-ness is inherited from the base kernel.
    ///
    /// Any Gram matrix of this kernel is a Gram matrix of the base kernel,
    /// evaluated on the extracted feature vectors.
    fn is_psd(&self) -> bool {
        let base = &self.base;
        base.is_psd()
    }
}
/// Reports the output dimensionality of a feature map.
pub trait FeatureMapShape {
    /// Dimensionality of the features produced by this map.
    fn feature_dim(&self) -> usize;
}

/// Every [`NeuralFeatureMap`] reports its `output_dim` as its feature dimension.
///
/// The `?Sized` relaxation lets unsized implementors, such as
/// `dyn NeuralFeatureMap` behind a reference or `Box`, use `feature_dim` too.
/// Sized implementors behave exactly as before.
impl<M: NeuralFeatureMap + ?Sized> FeatureMapShape for M {
    fn feature_dim(&self) -> usize {
        self.output_dim()
    }
}
/// Borrowing wrapper used to format a one-line, human-readable summary of a
/// [`DeepKernel`] via `Display`.
pub struct DeepKernelSummary<'a, F: NeuralFeatureMap, K: Kernel> {
    /// The kernel being summarized.
    pub kernel: &'a DeepKernel<F, K>,
}
impl<F, K> fmt::Display for DeepKernelSummary<'_, F, K>
where
F: NeuralFeatureMap,
K: Kernel,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"DeepKernel(in={}, features={}, base={})",
self.kernel.extractor.input_dim(),
self.kernel.extractor.output_dim(),
self.kernel.base.name()
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deep_kernel::feature_extractor::MLPFeatureExtractor;
    use crate::deep_kernel::layer::{Activation, DenseLayer};
    use crate::types::RbfKernelConfig;
    use crate::{LinearKernel, RbfKernel};

    /// Single dense layer with weight 1, bias 0 and identity activation, so
    /// the resulting 1-d feature map should act as the identity.
    fn identity_mlp_1x1() -> MLPFeatureExtractor {
        let layer =
            DenseLayer::new(vec![vec![1.0]], vec![0.0], Activation::Identity).expect("valid");
        MLPFeatureExtractor::from_layers(vec![layer]).expect("valid")
    }

    #[test]
    fn deep_kernel_with_identity_equals_base() {
        let linear = LinearKernel::new();
        let dkl = DeepKernel::new(identity_mlp_1x1(), linear);
        let expected = LinearKernel::new().compute(&[3.0], &[4.0]).expect("linear");
        let got = dkl.compute(&[3.0], &[4.0]).expect("deep");
        assert!((got - expected).abs() < 1e-12);
    }

    #[test]
    fn deep_kernel_propagates_psd_from_base() {
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let dkl = DeepKernel::new(identity_mlp_1x1(), rbf);
        assert!(dkl.is_psd());
    }

    #[test]
    fn identity_extractor_features_pass_through() {
        let dkl = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());
        let feats = dkl.features(&[2.5]).expect("features");
        assert_eq!(feats.len(), 1);
        assert!((feats[0] - 2.5).abs() < 1e-12);
    }

    #[test]
    fn cross_gram_matches_pairwise_compute() {
        let dkl = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());
        let xs: [&[f64]; 2] = [&[1.0], &[-2.0]];
        let ys: [&[f64]; 3] = [&[0.5], &[3.0], &[-1.5]];
        let gram = dkl.compute_gram(&xs, &ys).expect("gram");
        assert_eq!(gram.len(), xs.len());
        for (row, x) in gram.iter().zip(xs.iter()) {
            assert_eq!(row.len(), ys.len());
            for (value, y) in row.iter().zip(ys.iter()) {
                let expected = dkl.compute(x, y).expect("pairwise");
                assert!((value - expected).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn symmetric_gram_is_symmetric_and_matches_cross_gram() {
        let dkl = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());
        let points = vec![vec![1.0], vec![-2.0], vec![0.25]];
        let sym = dkl.compute_symmetric_gram(&points).expect("symmetric gram");
        let slices: Vec<&[f64]> = points.iter().map(Vec::as_slice).collect();
        let cross = dkl.compute_gram(&slices, &slices).expect("cross gram");
        assert_eq!(sym.len(), points.len());
        for (i, row) in sym.iter().enumerate() {
            assert_eq!(row.len(), points.len());
            for (j, value) in row.iter().enumerate() {
                assert!((value - sym[j][i]).abs() < 1e-12);
                assert!((value - cross[i][j]).abs() < 1e-12);
            }
        }
    }

    #[test]
    fn empty_inputs_give_empty_grams() {
        let dkl = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());
        let none: [&[f64]; 0] = [];
        assert!(dkl.compute_gram(&none, &none).expect("gram").is_empty());
        assert!(dkl.compute_symmetric_gram(&[]).expect("sym").is_empty());
    }

    #[test]
    fn summary_reports_dims_and_base_name() {
        let dkl = DeepKernel::new(identity_mlp_1x1(), LinearKernel::new());
        let text = DeepKernelSummary { kernel: &dkl }.to_string();
        let expected = format!(
            "DeepKernel(in=1, features=1, base={})",
            LinearKernel::new().name()
        );
        assert_eq!(text, expected);
        assert_eq!(FeatureMapShape::feature_dim(dkl.feature_extractor()), 1);
    }
}