use std::sync::Arc;
use crate::error::{KernelError, Result};
use crate::learned_composition::mixture::LearnedMixtureKernel;
use crate::types::Kernel;
/// Incrementally assembles the base kernels and per-kernel logits for a
/// [`LearnedMixtureKernel`].
///
/// Every kernel added through [`push_kernel`](Self::push_kernel),
/// [`push_kernel_with_logit`](Self::push_kernel_with_logit) or
/// [`extend_kernels`](Self::extend_kernels) appends exactly one logit, so the
/// two vectors stay in lock-step unless [`with_logits`](Self::with_logits)
/// replaces the logit vector with one of a different length.
#[derive(Default)]
#[must_use = "builder methods consume `self`; call `build` to obtain the kernel"]
pub struct LearnedMixtureBuilder {
    base_kernels: Vec<Arc<dyn Kernel>>,
    logits: Vec<f64>,
}

impl LearnedMixtureBuilder {
    /// Logit assigned to kernels added without an explicit one.
    const DEFAULT_LOGIT: f64 = 0.0;

    /// Creates an empty builder (no kernels, no logits).
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `kernel` with a logit of `0.0`.
    pub fn push_kernel(self, kernel: Arc<dyn Kernel>) -> Self {
        self.push_kernel_with_logit(kernel, Self::DEFAULT_LOGIT)
    }

    /// Adds `kernel` together with its initial `logit`.
    pub fn push_kernel_with_logit(mut self, kernel: Arc<dyn Kernel>, logit: f64) -> Self {
        self.base_kernels.push(kernel);
        self.logits.push(logit);
        self
    }

    /// Adds every kernel yielded by `kernels`, each with a logit of `0.0`.
    pub fn extend_kernels<I>(mut self, kernels: I) -> Self
    where
        I: IntoIterator<Item = Arc<dyn Kernel>>,
    {
        let before = self.base_kernels.len();
        self.base_kernels.extend(kernels);
        let added = self.base_kernels.len() - before;
        // Append exactly one default logit per new kernel. Resizing to
        // `base_kernels.len()` instead would silently hide a length mismatch
        // introduced by an earlier `with_logits`, which `build` must report.
        self.logits
            .resize(self.logits.len() + added, Self::DEFAULT_LOGIT);
        self
    }

    /// Replaces the whole logit vector with `logits`.
    ///
    /// Kernels added *after* this call still append a `0.0` logit each, so
    /// this is normally called once all kernels have been added. Any length
    /// disagreement with the kernel list is reported by [`build`](Self::build).
    pub fn with_logits(mut self, logits: Vec<f64>) -> Self {
        self.logits = logits;
        self
    }

    /// Number of base kernels added so far.
    #[must_use]
    pub fn len(&self) -> usize {
        self.base_kernels.len()
    }

    /// Returns `true` if no base kernel has been added yet.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.base_kernels.is_empty()
    }

    /// Consumes the builder and constructs the [`LearnedMixtureKernel`],
    /// passing the kernels and logits unchanged to `LearnedMixtureKernel::new`.
    ///
    /// # Errors
    ///
    /// - [`KernelError::InvalidParameter`] if no kernel was added.
    /// - [`KernelError::DimensionMismatch`] if the number of logits differs
    ///   from the number of kernels (only possible after `with_logits`).
    /// - Any error returned by `LearnedMixtureKernel::new`.
    pub fn build(self) -> Result<LearnedMixtureKernel> {
        if self.base_kernels.is_empty() {
            return Err(KernelError::InvalidParameter {
                parameter: "base_kernels".to_string(),
                value: "[]".to_string(),
                reason: "LearnedMixtureBuilder requires at least one kernel".to_string(),
            });
        }
        if self.logits.len() != self.base_kernels.len() {
            return Err(KernelError::DimensionMismatch {
                expected: vec![self.base_kernels.len()],
                got: vec![self.logits.len()],
                context: "LearnedMixtureBuilder::with_logits length".to_string(),
            });
        }
        LearnedMixtureKernel::new(self.base_kernels, self.logits)
    }
}