use crate::error::Result;
use crate::learned_composition::mixture::LearnedMixtureKernel;
/// Trainable wrapper around a [`LearnedMixtureKernel`].
///
/// The wrapper owns the kernel and forwards every operation to it. The
/// kernel's mixture logits are what this type exposes as its trainable
/// parameter vector.
#[derive(Debug, Clone)]
pub struct TrainableKernelMixture {
    /// The wrapped mixture kernel; all evaluation and update calls delegate here.
    inner: LearnedMixtureKernel,
}
impl TrainableKernelMixture {
pub fn new(inner: LearnedMixtureKernel) -> Self {
Self { inner }
}
pub fn num_parameters(&self) -> usize {
self.inner.num_kernels()
}
pub fn parameters(&self) -> &[f64] {
self.inner.logits()
}
pub fn weights(&self) -> Vec<f64> {
self.inner.weights()
}
pub fn evaluate(&self, x: &[f64], y: &[f64]) -> Result<f64> {
self.inner.evaluate(x, y)
}
pub fn evaluate_with_gradient(&self, x: &[f64], y: &[f64]) -> Result<(f64, Vec<f64>)> {
self.inner.evaluate_with_gradient(x, y)
}
pub fn gradient(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>> {
self.inner.gradient_wrt_logits(x, y)
}
pub fn step(&mut self, gradient: &[f64], learning_rate: f64) -> Result<()> {
self.inner.apply_gradient_step(gradient, learning_rate)
}
pub fn set_parameters(&mut self, new_logits: Vec<f64>) -> Result<()> {
self.inner.set_logits(new_logits)
}
pub fn inner(&self) -> &LearnedMixtureKernel {
&self.inner
}
pub fn into_inner(self) -> LearnedMixtureKernel {
self.inner
}
}
impl From<LearnedMixtureKernel> for TrainableKernelMixture {
fn from(inner: LearnedMixtureKernel) -> Self {
Self::new(inner)
}
}