use crate::{DiffusionMap, ReductionError};
use linfa::{param_guard::TransformGuard, ParamGuard};
/// Hyperparameters for a diffusion map that have passed validation.
///
/// Values of this type come from checking a `DiffusionMapParams` through its
/// `ParamGuard` implementation. That check rejects a zero `steps` and a zero
/// `embedding_size`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DiffusionMapValidParams {
    /// Number of diffusion steps.
    /// NOTE(review): presumably how many times the diffusion operator is
    /// applied during fitting — confirm in the fit implementation.
    steps: usize,
    /// Requested size (number of dimensions) of the resulting embedding.
    embedding_size: usize,
}
impl DiffusionMapValidParams {
    /// Returns the configured number of diffusion steps (never `0` once validated).
    pub fn steps(&self) -> usize {
        let Self { steps, .. } = self;
        *steps
    }

    /// Returns the configured embedding size (never `0` once validated).
    pub fn embedding_size(&self) -> usize {
        let Self { embedding_size, .. } = self;
        *embedding_size
    }
}
/// Unchecked hyperparameters for a diffusion map, configured through builder methods.
///
/// This wraps a `DiffusionMapValidParams` whose values have not been validated
/// yet. Validation happens through the `ParamGuard` implementation (`check` /
/// `check_ref`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DiffusionMapParams(DiffusionMapValidParams);
impl DiffusionMapParams {
pub fn steps(mut self, steps: usize) -> Self {
self.0.steps = steps;
self
}
pub fn embedding_size(mut self, embedding_size: usize) -> Self {
self.0.embedding_size = embedding_size;
self
}
pub fn new(embedding_size: usize) -> DiffusionMapParams {
Self(DiffusionMapValidParams {
steps: 1,
embedding_size,
})
}
}
impl Default for DiffusionMapParams {
fn default() -> Self {
Self::new(2)
}
}
impl<F> DiffusionMap<F> {
    /// Starts configuring a diffusion map with the requested embedding size.
    ///
    /// The result is the same as `DiffusionMapParams::new(embedding_size)`: one
    /// diffusion step and the given embedding size, not yet validated.
    pub fn params(embedding_size: usize) -> DiffusionMapParams {
        DiffusionMapParams::default().embedding_size(embedding_size)
    }
}
impl ParamGuard for DiffusionMapParams {
    type Checked = DiffusionMapValidParams;
    type Error = ReductionError;

    /// Validates the parameters in place and borrows the checked inner value.
    ///
    /// # Errors
    ///
    /// * `ReductionError::StepsZero` if `steps` is `0`.
    /// * `ReductionError::EmbeddingTooSmall` (carrying the offending size) if
    ///   `embedding_size` is `0`.
    ///
    /// `steps` is tested first, so its error wins when both values are zero.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let inner = &self.0;
        if inner.steps == 0 {
            return Err(ReductionError::StepsZero);
        }
        if inner.embedding_size == 0 {
            return Err(ReductionError::EmbeddingTooSmall(inner.embedding_size));
        }
        Ok(inner)
    }

    /// Validates the parameters and returns the owned checked value.
    ///
    /// # Errors
    ///
    /// Same conditions as [`check_ref`](Self::check_ref).
    fn check(self) -> Result<Self::Checked, Self::Error> {
        // Drop the borrow produced by validation before moving the inner value out.
        self.check_ref().map(|_| ())?;
        Ok(self.0)
    }
}
// Empty marker impl: opts `DiffusionMapParams` into linfa's `TransformGuard`.
// No methods are defined here. NOTE(review): presumably this lets the unchecked
// params be used directly where a transformer is expected, validating through the
// `ParamGuard` impl first — confirm against the linfa `param_guard` docs.
impl TransformGuard for DiffusionMapParams {}