use crate::deep_kernel::feature_extractor::MLPFeatureExtractor;
use crate::deep_kernel::kernel::DeepKernel;
use crate::deep_kernel::layer::Activation;
use crate::error::{KernelError, Result};
use crate::types::Kernel;
/// Fluent builder that assembles an [`MLPFeatureExtractor`] and pairs it with a
/// base kernel to form a [`DeepKernel`].
#[derive(Debug, Clone, Default)]
pub struct DeepKernelBuilder {
    /// Layer widths in order. `input_dim` stores its value at index 0, while
    /// `hidden_layer` / `output_dim` append to the end.
    widths: Vec<usize>,
    /// One activation per layer added via `hidden_layer` / `output_dim`.
    activations: Vec<Activation>,
    /// Initialisation seed; `None` falls back to 0 when the extractor is built.
    seed: Option<u64>,
    /// Becomes `true` after the first `output_dim` call; further
    /// `hidden_layer` / `output_dim` calls are then ignored.
    has_output: bool,
}
impl DeepKernelBuilder {
    /// Creates an empty builder: no layers, no seed, and no output layer yet.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the dimension of the input features.
    ///
    /// May be called before or after any `hidden_layer` call; calling it again
    /// replaces the previously configured input dimension.
    #[must_use]
    pub fn input_dim(mut self, dim: usize) -> Self {
        if self.has_input_dim() {
            self.widths[0] = dim;
        } else {
            // Insert (rather than overwrite index 0) so that layers added before
            // `input_dim` keep their widths and stay aligned with their activations.
            self.widths.insert(0, dim);
        }
        self
    }

    /// Appends a hidden layer of `width` units using `activation`.
    ///
    /// Ignored once `output_dim` has been called.
    #[must_use]
    pub fn hidden_layer(mut self, width: usize, activation: Activation) -> Self {
        if !self.has_output {
            self.widths.push(width);
            self.activations.push(activation);
        }
        self
    }

    /// Appends the final layer of `width` units using `activation`.
    ///
    /// Only the first call takes effect; any later `output_dim` or
    /// `hidden_layer` call is ignored.
    #[must_use]
    pub fn output_dim(mut self, width: usize, activation: Activation) -> Self {
        if !self.has_output {
            self.widths.push(width);
            self.activations.push(activation);
            self.has_output = true;
        }
        self
    }

    /// Sets the seed passed to weight initialisation (0 is used when unset).
    #[must_use]
    pub fn seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Builds the feature extractor via [`MLPFeatureExtractor::xavier_init`]
    /// using the configured widths, activations and seed.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::InvalidParameter`] when fewer than two widths are
    /// configured, when `output_dim` was never called, or when `input_dim` was
    /// never called. Errors from `xavier_init` are propagated unchanged.
    pub fn build_extractor(&self) -> Result<MLPFeatureExtractor> {
        if self.widths.len() < 2 {
            return Err(KernelError::InvalidParameter {
                parameter: "widths".to_string(),
                value: format!("{:?}", self.widths),
                reason: "builder needs at least input_dim + output_dim".to_string(),
            });
        }
        if !self.has_output {
            return Err(KernelError::InvalidParameter {
                parameter: "output_dim".to_string(),
                value: "unset".to_string(),
                reason: "call output_dim before build".to_string(),
            });
        }
        // Without an input width, `widths` and `activations` have equal length and
        // the first hidden width would be misread as the input dimension.
        if !self.has_input_dim() {
            return Err(KernelError::InvalidParameter {
                parameter: "input_dim".to_string(),
                value: "unset".to_string(),
                reason: "call input_dim before build".to_string(),
            });
        }
        let seed = self.seed.unwrap_or(0);
        MLPFeatureExtractor::xavier_init(&self.widths, &self.activations, seed)
    }

    /// Consumes the builder and wraps the built extractor together with `base`
    /// in a [`DeepKernel`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`DeepKernelBuilder::build_extractor`].
    pub fn build<K: Kernel>(self, base: K) -> Result<DeepKernel<MLPFeatureExtractor, K>> {
        let extractor = self.build_extractor()?;
        Ok(DeepKernel::new(extractor, base))
    }

    /// Reports whether `input_dim` has been set.
    ///
    /// `hidden_layer` and `output_dim` each push one width *and* one activation,
    /// whereas the input dimension occupies a width slot with no activation, so
    /// an input dimension is present exactly when widths outnumber activations.
    fn has_input_dim(&self) -> bool {
        self.widths.len() > self.activations.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::RbfKernelConfig;
    use crate::RbfKernel;

    /// Asserts that `result` failed with `KernelError::InvalidParameter`,
    /// panicking with `ok_msg` if the build unexpectedly succeeded.
    fn assert_invalid_parameter<T>(result: Result<T>, ok_msg: &str) {
        match result {
            Err(KernelError::InvalidParameter { .. }) => {}
            Err(other) => panic!("unexpected error variant: {}", other),
            Ok(_) => panic!("{}", ok_msg),
        }
    }

    #[test]
    fn builder_assembles_three_layer_mlp() {
        let base = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let builder = DeepKernelBuilder::new()
            .input_dim(3)
            .hidden_layer(5, Activation::ReLU)
            .output_dim(2, Activation::Identity)
            .seed(123);
        let dkl = builder.build(base).expect("valid build");
        // Widths 3 -> 5 -> 2 give two weight layers.
        assert_eq!(dkl.feature_extractor().num_layers(), 2);
    }

    #[test]
    fn builder_fails_without_output_dim() {
        let base = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let builder = DeepKernelBuilder::new()
            .input_dim(3)
            .hidden_layer(5, Activation::ReLU);
        assert_invalid_parameter(builder.build(base), "missing output_dim must fail");
    }

    #[test]
    fn builder_fails_when_only_input_set() {
        let base = RbfKernel::new(RbfKernelConfig::new(0.5)).expect("valid");
        let builder = DeepKernelBuilder::new().input_dim(3);
        assert_invalid_parameter(builder.build(base), "only input_dim set must fail");
    }
}