use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use std::marker::PhantomData;
/// Errors reported by resampling components.
///
/// Variants carrying a `String` append that message to their display text.
#[derive(Clone, Debug)]
pub enum ResamplingError {
    /// Input was rejected; displayed as `Invalid input: <msg>`.
    InvalidInput(String),
    /// Displayed as `Insufficient samples for resampling`.
    InsufficientSamples,
    /// A configuration problem; displayed as `Configuration error: <msg>`.
    ConfigError(String),
    /// A failure during computation; displayed as `Computation error: <msg>`.
    ComputationError(String),
}

impl std::fmt::Display for ResamplingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientSamples => f.write_str("Insufficient samples for resampling"),
            Self::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
            Self::ConfigError(msg) => write!(f, "Configuration error: {}", msg),
            Self::ComputationError(msg) => write!(f, "Computation error: {}", msg),
        }
    }
}

// No `source()` override: none of the variants wrap another error.
impl std::error::Error for ResamplingError {}
/// A single optimisation hint a resampling strategy can advertise.
///
/// Hints are plain data: nothing in this type acts on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PerformanceHint {
    CacheFriendly,
    Vectorize,
    Parallel,
    GpuAccelerated,
}

/// An insertion-ordered list of [`PerformanceHint`]s, built with
/// [`PerformanceHints::with_hint`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PerformanceHints {
    hints: Vec<PerformanceHint>,
}

impl PerformanceHints {
    /// Creates an empty set of hints.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `hint` and returns the updated value (builder style).
    ///
    /// Duplicates are kept. They do not change the result of
    /// [`has_hint`](Self::has_hint).
    #[must_use]
    pub fn with_hint(mut self, hint: PerformanceHint) -> Self {
        self.hints.push(hint);
        self
    }

    /// Returns `true` if `hint` has been added at least once.
    pub fn has_hint(&self, hint: PerformanceHint) -> bool {
        // `PerformanceHint` is field-less and derives `PartialEq`, so plain
        // equality matches what the former `mem::discriminant` comparison did.
        self.hints.contains(&hint)
    }

    /// Returns the hints in the order they were added, including duplicates.
    pub fn hints(&self) -> &[PerformanceHint] {
        &self.hints
    }
}
/// A pluggable resampling algorithm.
///
/// Implementors must be `Send + Sync`.
///
/// NOTE(review): the associated types `Input` and `Output` are not referenced
/// by any method in this trait. [`resample`](Self::resample) is fixed to `f64`
/// feature arrays and `i32` label arrays. Confirm whether `Input`/`Output` are
/// used elsewhere before relying on them.
pub trait ResamplingStrategy: Send + Sync {
    /// Strategy-specific settings, passed by reference to
    /// [`resample`](Self::resample).
    type Config;
    type Input;
    type Output;

    /// Reports optimisation hints for this strategy.
    ///
    /// The default implementation returns an empty [`PerformanceHints`].
    fn performance_hints(&self) -> PerformanceHints {
        PerformanceHints::new()
    }

    /// Produces resampled data as newly owned arrays.
    ///
    /// Presumably `x` is the 2-D feature matrix and `y` the matching per-row
    /// labels. TODO: confirm against implementors.
    ///
    /// # Errors
    /// Returns a [`ResamplingError`] when the strategy cannot produce output.
    fn resample(
        &self,
        x: ArrayView2<f64>,
        y: ArrayView1<i32>,
        config: &Self::Config,
    ) -> Result<(Array2<f64>, Array1<i32>), ResamplingError>;
}
/// Typestate marker for a [`Resampler`] that has been created but not configured.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Uninitialized;

/// Typestate marker for a [`Resampler`] returned by `configure`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Configured;

/// Typestate marker for a fitted [`Resampler`].
///
/// NOTE(review): no transition into this state appears in this section;
/// presumably it is defined elsewhere.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Fitted;
/// A resampling strategy tagged with a compile-time state marker.
///
/// `State` is not stored at runtime. It only controls which `impl` blocks
/// apply; for example, `configure` exists only while the state is
/// `Uninitialized`.
pub struct Resampler<Strategy, State> {
    /// Zero-sized tag carrying the `State` type parameter.
    _state: PhantomData<State>,
    /// The wrapped strategy.
    strategy: Strategy,
}
impl<S> Resampler<S, Uninitialized>
where
    S: ResamplingStrategy,
{
    /// Wraps `strategy` in a resampler in the `Uninitialized` state.
    pub fn new(strategy: S) -> Self {
        Resampler {
            _state: PhantomData,
            strategy,
        }
    }

    /// Consumes this resampler and returns one in the `Configured` state
    /// wrapping the same strategy.
    ///
    /// NOTE(review): `_config` is accepted but dropped. It is not stored on the
    /// returned resampler, so the configured value is not carried forward from
    /// here. Confirm this is intended.
    pub fn configure(self, _config: S::Config) -> Resampler<S, Configured> {
        let Resampler { strategy, .. } = self;
        Resampler {
            _state: PhantomData,
            strategy,
        }
    }
}