use std::{fmt::Debug, marker::PhantomData};
use linfa::ParamGuard;
use rand::Rng;
use crate::ReductionError;
use super::methods::ProjectionMethod;
/// Unvalidated hyperparameters for a random projection.
///
/// Configure them with [`target_dim`](Self::target_dim), [`eps`](Self::eps) and
/// [`with_rng`](Self::with_rng), then validate them through [`ParamGuard`] to obtain a
/// [`RandomProjectionValidParams`].
///
/// The derives mirror those of [`RandomProjectionValidParams`], so the unchecked builder can be
/// debug-printed, cloned and compared just like the checked parameter set it wraps.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomProjectionParams<Proj: ProjectionMethod, R: Rng + Clone>(
    pub(crate) RandomProjectionValidParams<Proj, R>,
);
impl<Proj: ProjectionMethod, R: Rng + Clone> RandomProjectionParams<Proj, R> {
pub fn target_dim(mut self, dim: usize) -> Self {
self.0.params = RandomProjectionParamsInner::Dimension { target_dim: dim };
self
}
pub fn eps(mut self, eps: f64) -> Self {
self.0.params = RandomProjectionParamsInner::Epsilon { eps };
self
}
pub fn with_rng<R2: Rng + Clone>(self, rng: R2) -> RandomProjectionParams<Proj, R2> {
RandomProjectionParams(RandomProjectionValidParams {
params: self.0.params,
rng,
marker: PhantomData,
})
}
}
/// Validated hyperparameters for a random projection.
///
/// This is the `Checked` type produced by the `ParamGuard` implementation for
/// `RandomProjectionParams`. That implementation rejects a target dimension of `0` and an
/// `eps` outside the open interval `(0, 1)`.
#[derive(Debug, Clone, PartialEq)]
pub struct RandomProjectionValidParams<Proj: ProjectionMethod, R: Rng + Clone> {
    // Either an explicit target dimension or a precision `eps`; exactly one is set.
    pub(super) params: RandomProjectionParamsInner,
    // Random number generator carried with the parameters.
    // NOTE(review): presumably used to sample the projection matrix — confirm at the call site.
    pub(super) rng: R,
    // Records the projection method type `Proj` at the type level without storing a value.
    pub(crate) marker: PhantomData<Proj>,
}
/// How the size of the embedding is specified: either directly or through a precision.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum RandomProjectionParamsInner {
    /// An explicit number of output dimensions.
    Dimension { target_dim: usize },
    /// A precision parameter; the dimension is not stored and must be derived elsewhere.
    Epsilon { eps: f64 },
}

impl RandomProjectionParamsInner {
    /// The explicit dimension, or `None` when a precision was configured instead.
    fn target_dim(&self) -> Option<usize> {
        match *self {
            Self::Dimension { target_dim } => Some(target_dim),
            Self::Epsilon { .. } => None,
        }
    }

    /// The configured precision, or `None` when an explicit dimension was configured instead.
    fn eps(&self) -> Option<f64> {
        match *self {
            Self::Epsilon { eps } => Some(eps),
            Self::Dimension { .. } => None,
        }
    }
}
impl<Proj: ProjectionMethod, R: Rng + Clone> RandomProjectionValidParams<Proj, R> {
pub fn target_dim(&self) -> Option<usize> {
self.params.target_dim()
}
pub fn eps(&self) -> Option<f64> {
self.params.eps()
}
pub fn rng(&self) -> &R {
&self.rng
}
}
impl<Proj: ProjectionMethod, R: Rng + Clone> ParamGuard for RandomProjectionParams<Proj, R> {
    type Checked = RandomProjectionValidParams<Proj, R>;
    type Error = ReductionError;

    /// Validates the hyperparameters without consuming them.
    ///
    /// # Errors
    ///
    /// * [`ReductionError::NonPositiveEmbeddingSize`] if an explicit target dimension of `0`
    ///   was requested.
    /// * [`ReductionError::InvalidPrecision`] if `eps` is NaN or does not lie strictly
    ///   between `0` and `1`.
    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        match self.0.params {
            RandomProjectionParamsInner::Dimension { target_dim } => {
                if target_dim == 0 {
                    return Err(ReductionError::NonPositiveEmbeddingSize);
                }
            }
            RandomProjectionParamsInner::Epsilon { eps } => {
                // Every comparison involving NaN is false, so NaN would slip past both bound
                // checks below; reject it explicitly.
                if eps.is_nan() || eps <= 0.0 || eps >= 1.0 {
                    return Err(ReductionError::InvalidPrecision);
                }
            }
        }
        Ok(&self.0)
    }

    /// Validates the hyperparameters and returns the checked set by value.
    ///
    /// # Errors
    ///
    /// Fails under exactly the same conditions as [`ParamGuard::check_ref`].
    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}