use linfa::{Float, ParamGuard};
use ndarray_rand::rand::{rngs::SmallRng, Rng, SeedableRng};
use crate::TSneError;
/// Validated hyperparameters for t-SNE.
///
/// Values of this type are produced by validating a [`TSneParams`] through its
/// `ParamGuard` implementation; fields are exposed via read-only accessors.
#[derive(Debug, Clone, PartialEq)]
pub struct TSneValidParams<F, R> {
    /// Number of dimensions of the output embedding.
    embedding_size: usize,
    /// Approximation threshold; validation rejects negative values.
    approx_threshold: F,
    /// Perplexity; validation rejects negative values.
    perplexity: F,
    /// Maximum number of iterations.
    max_iter: usize,
    /// Optional number of preliminary iterations (`None` unless set explicitly).
    preliminary_iter: Option<usize>,
    /// Random number generator carried along with the parameters.
    rng: R,
}
impl<F: Float, R> TSneValidParams<F, R> {
    /// Returns the dimensionality of the output embedding.
    pub fn embedding_size(&self) -> usize {
        self.embedding_size
    }

    /// Returns the maximum number of iterations.
    pub fn max_iter(&self) -> usize {
        self.max_iter
    }

    /// Returns the configured perplexity.
    pub fn perplexity(&self) -> F {
        self.perplexity
    }

    /// Returns the configured approximation threshold.
    pub fn approx_threshold(&self) -> F {
        self.approx_threshold
    }

    /// Returns the optional preliminary iteration count (`None` when unset).
    pub fn preliminary_iter(&self) -> &Option<usize> {
        &self.preliminary_iter
    }

    /// Borrows the random number generator stored in the parameters.
    pub fn rng(&self) -> &R {
        &self.rng
    }
}
/// Unchecked t-SNE hyperparameters, configured through builder-style methods.
///
/// This is a thin wrapper around [`TSneValidParams`]; validation happens through
/// the `ParamGuard` trait (`check` / `check_ref`).
#[derive(Debug, Clone, PartialEq)]
pub struct TSneParams<F, R>(TSneValidParams<F, R>);
impl<F: Float> TSneParams<F, SmallRng> {
    /// Creates parameters for an embedding with `embedding_size` dimensions.
    ///
    /// The generator is a `SmallRng` seeded with the constant `42`, so every call
    /// starts from the same random stream. Use `embedding_size_with_rng` to supply
    /// a different generator.
    pub fn embedding_size(embedding_size: usize) -> TSneParams<F, SmallRng> {
        let seeded = SmallRng::seed_from_u64(42);
        Self::embedding_size_with_rng(embedding_size, seeded)
    }
}
impl<F: Float, R: Rng + Clone> TSneParams<F, R> {
    /// Creates parameters for an embedding with `embedding_size` dimensions that
    /// draws randomness from `rng`.
    ///
    /// Defaults: `approx_threshold = 0.5`, `perplexity = 5.0`, `max_iter = 2000`,
    /// and no preliminary iteration count.
    pub fn embedding_size_with_rng(embedding_size: usize, rng: R) -> TSneParams<F, R> {
        let defaults = TSneValidParams {
            embedding_size,
            approx_threshold: F::cast(0.5),
            perplexity: F::cast(5.0),
            max_iter: 2000,
            preliminary_iter: None,
            rng,
        };
        Self(defaults)
    }

    /// Sets the approximation threshold (not validated until `check`/`check_ref`).
    pub fn approx_threshold(self, threshold: F) -> Self {
        Self(TSneValidParams {
            approx_threshold: threshold,
            ..self.0
        })
    }

    /// Sets the perplexity (not validated until `check`/`check_ref`).
    pub fn perplexity(self, perplexity: F) -> Self {
        Self(TSneValidParams {
            perplexity,
            ..self.0
        })
    }

    /// Sets the maximum number of iterations.
    pub fn max_iter(self, max_iter: usize) -> Self {
        Self(TSneValidParams { max_iter, ..self.0 })
    }

    /// Sets the number of preliminary iterations.
    pub fn preliminary_iter(self, num_iter: usize) -> Self {
        Self(TSneValidParams {
            preliminary_iter: Some(num_iter),
            ..self.0
        })
    }
}
impl<F: Float, R> ParamGuard for TSneParams<F, R> {
    type Checked = TSneValidParams<F, R>;
    type Error = TSneError;

    /// Validates the hyperparameters and returns a reference to them.
    ///
    /// # Errors
    ///
    /// * `TSneError::NegativePerplexity` if `perplexity` reports `is_negative()`
    ///   or is NaN.
    /// * `TSneError::NegativeApproximationThreshold` if `approx_threshold`
    ///   reports `is_negative()` or is NaN.
    ///
    /// NaN is reported through the "negative" variants because no dedicated
    /// variant is available for it.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        // `is_negative` only looks at the sign, so a NaN with a positive sign
        // bit would pass it. Reject NaN explicitly so the checked params are usable.
        let invalid = |value: F| value.is_negative() || value.is_nan();

        if invalid(self.0.perplexity) {
            Err(TSneError::NegativePerplexity)
        } else if invalid(self.0.approx_threshold) {
            Err(TSneError::NegativeApproximationThreshold)
        } else {
            Ok(&self.0)
        }
    }

    /// Validates the hyperparameters and returns them by value.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as `check_ref`.
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}