use crate::error::Result;
use crate::types::FloatBounds;
use futures_core::{Future, Stream};
use std::pin::Pin;
use std::time::Duration;
/// Boxed, `Send` future resolving to the crate `Result<T>`.
///
/// The other future aliases in this module are specialisations of this shape.
pub type AsyncFitFuture<'a, T> = Pin<Box<dyn Future<Output = Result<T>> + Send + 'a>>;

/// Future resolving to an optional value (`Result<Option<T>>`), as returned by
/// `AsyncFitAdvanced::fit_async_cancellable`.
pub type AsyncPartialFitFuture<'a, T> = AsyncFitFuture<'a, Option<T>>;

/// Future resolving to a prediction paired with its `ConfidenceInterval`.
pub type AsyncPredictConfidenceFuture<'a, Output> =
    AsyncFitFuture<'a, (Output, ConfidenceInterval)>;

/// Future resolving to a vector of scores.
pub type AsyncScoreFuture<'a, Score> = AsyncFitFuture<'a, Vec<Score>>;

/// Future resolving to a transformed value of type `Output`.
pub type AsyncTransformFuture<'a, Output> = AsyncFitFuture<'a, Output>;

/// Future resolving to `()`, for operations that produce no value on success.
pub type AsyncUnitFuture<'a> = AsyncFitFuture<'a, ()>;

/// Stream of `(index, score)` results produced by cross-validation.
pub type AsyncCVStream<'a, Score> =
    Pin<Box<dyn Stream<Item = Result<(usize, Score)>> + Send + 'a>>;

/// Stream of `(index, model)` results produced while fitting an ensemble.
pub type AsyncEnsembleFitStream<'a, Model> =
    Pin<Box<dyn Stream<Item = Result<(usize, Model)>> + Send + 'a>>;

/// Stream of `(index, prediction)` results produced by ensemble members.
pub type AsyncEnsemblePredictStream<'a, Output> =
    Pin<Box<dyn Stream<Item = Result<(usize, Output)>> + Send + 'a>>;

/// Stream of per-trial results produced by hyperparameter search.
pub type AsyncOptimizationStream<'a, Config, Score> =
    Pin<Box<dyn Stream<Item = Result<OptimizationResult<Config, Score>>> + Send + 'a>>;

/// Thread-safe closure building a `Config` from a map of parameter values
/// (presumably keyed by parameter name — confirm against `ParameterSpace` users).
pub type ConfigFactory<Config> =
    Box<dyn Fn(&std::collections::HashMap<String, f64>) -> Config + Send + Sync>;
/// Snapshot of how far a long-running asynchronous operation has progressed.
///
/// Created with [`ProgressInfo::new`] and refined with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ProgressInfo {
    /// Completed fraction of the work. `new` and `with_total_steps` keep it within
    /// `[0.0, 1.0]`; a NaN passed to `new` is stored unchanged (`f64::clamp` propagates NaN).
    pub progress: f64,
    /// Step the operation has reached.
    pub current_step: usize,
    /// Total number of steps, when known.
    pub total_steps: Option<usize>,
    /// Time elapsed so far (zero unless set via `with_elapsed`).
    pub elapsed: Duration,
    /// Estimated time remaining, when available.
    pub eta: Option<Duration>,
    /// Latest value of the metric the operation tracks, if any.
    pub current_metric: Option<f64>,
    /// Free-form status message (empty by default).
    pub message: String,
}

impl ProgressInfo {
    /// Creates a snapshot at `current_step` with the given completed fraction.
    ///
    /// `progress` is clamped to `[0.0, 1.0]`. `total_steps`, `eta` and `current_metric`
    /// start as `None`, `elapsed` starts at zero and `message` is empty.
    pub fn new(progress: f64, current_step: usize) -> Self {
        Self {
            progress: progress.clamp(0.0, 1.0),
            current_step,
            total_steps: None,
            elapsed: Duration::from_secs(0),
            eta: None,
            current_metric: None,
            message: String::new(),
        }
    }

    /// Records the total step count and, when `total > 0`, recomputes `progress` as
    /// `current_step / total`, clamped to `[0.0, 1.0]`.
    ///
    /// With `total == 0` the fraction given to `new` is kept, avoiding a division by zero.
    pub fn with_total_steps(mut self, total: usize) -> Self {
        self.total_steps = Some(total);
        if total > 0 {
            // `current_step` may exceed `total` (e.g. an overshooting final step); clamp so
            // the `[0.0, 1.0]` range established by `new` still holds.
            self.progress = (self.current_step as f64 / total as f64).clamp(0.0, 1.0);
        }
        self
    }

    /// Sets the time elapsed so far.
    pub fn with_elapsed(mut self, elapsed: Duration) -> Self {
        self.elapsed = elapsed;
        self
    }

    /// Sets the estimated time remaining.
    pub fn with_eta(mut self, eta: Duration) -> Self {
        self.eta = Some(eta);
        self
    }

    /// Sets the current metric value.
    pub fn with_metric(mut self, metric: f64) -> Self {
        self.current_metric = Some(metric);
        self
    }

    /// Sets the status message.
    pub fn with_message<S: Into<String>>(mut self, message: S) -> Self {
        self.message = message.into();
        self
    }
}
/// Runtime settings shared by the asynchronous training and inference traits in this module.
#[derive(Debug, Clone)]
pub struct AsyncConfig {
    /// Batch size used when work is split into batches.
    pub batch_size: usize,
    /// Time limit for a single operation (`None` presumably means no limit — confirm).
    pub operation_timeout: Option<Duration>,
    /// Whether progress reporting is enabled.
    pub enable_progress: bool,
    /// Interval between progress reports.
    pub progress_interval: Duration,
    /// Maximum number of tasks to run concurrently.
    pub max_concurrency: usize,
}

impl Default for AsyncConfig {
    /// Batches of 1000, a 300-second operation timeout, progress reporting enabled at
    /// one-second intervals, and concurrency equal to `num_cpus::get()`.
    fn default() -> Self {
        let max_concurrency = num_cpus::get();
        let operation_timeout = Some(Duration::from_secs(5 * 60));
        Self {
            max_concurrency,
            progress_interval: Duration::from_secs(1),
            enable_progress: true,
            operation_timeout,
            batch_size: 1000,
        }
    }
}
/// Extended asynchronous fitting API with progress reporting and cancellation.
///
/// Every method takes `self` by value, so the estimator is consumed by the fit.
/// `State` defaults to `crate::traits::Untrained` and is not referenced by any method here;
/// presumably it is a type-state marker for untrained estimators — confirm.
pub trait AsyncFitAdvanced<X, Y, State = crate::traits::Untrained> {
/// Model type produced by a successful fit.
type Fitted;
/// Error type associated with this estimator.
///
/// NOTE(review): no method below returns this type; they all return the crate `Result`
/// alias imported from `crate::error`.
type Error: std::error::Error + Send + Sync;
/// Fits on `x` and `y`, resolving to the fitted model.
///
/// `config` carries batching/timeout/progress settings; how each is honoured is up to the
/// implementation. NOTE(review): despite the name, no progress channel appears in the
/// return type.
fn fit_async_with_progress<'a>(
self,
x: &'a X,
y: &'a Y,
config: &'a AsyncConfig,
) -> AsyncFitFuture<'a, Self::Fitted>
where
Self: 'a;
/// Fits on `x` and `y`, yielding `ProgressInfo` snapshots as a stream.
///
/// NOTE(review): stream items carry only progress and `self` is consumed, so the fitted
/// model is not reachable through this method's return value — confirm how implementors
/// expose it.
fn fit_async_with_progress_stream<'a>(
self,
x: &'a X,
y: &'a Y,
config: &'a AsyncConfig,
) -> Pin<Box<dyn Stream<Item = Result<ProgressInfo>> + Send + 'a>>
where
Self: 'a;
/// Fits on `x` and `y`, with `cancel_token` available for cooperative cancellation.
///
/// Resolves to `Option<Self::Fitted>`; presumably `None` means the fit was cancelled before
/// completing — confirm with implementors.
fn fit_async_cancellable<'a>(
self,
x: &'a X,
y: &'a Y,
cancel_token: CancellationToken,
) -> AsyncPartialFitFuture<'a, Self::Fitted>
where
Self: 'a;
}
/// Extended asynchronous prediction API: batched, streaming and uncertainty-aware variants.
pub trait AsyncPredictAdvanced<X, Output> {
/// Error type associated with this predictor.
///
/// NOTE(review): not referenced by any method signature below (they use the crate `Result`).
type Error: std::error::Error + Send + Sync;
/// Predicts for `x` as a single boxed future; `config` presumably controls batching
/// (e.g. `batch_size`).
///
/// The return type is the same as `AsyncTransformFuture<'a, Output>`.
fn predict_async_batched<'a>(
&'a self,
x: &'a X,
config: &'a AsyncConfig,
) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
/// Turns a stream of inputs into a stream of prediction results
/// (presumably one item per input, in input order — confirm with implementors).
fn predict_stream<'a>(
&'a self,
x_stream: Pin<Box<dyn Stream<Item = X> + Send + 'a>>,
config: &'a AsyncConfig,
) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
/// Predicts for `x` and pairs the result with a `ConfidenceInterval`.
///
/// `confidence_level` is presumably a probability in `(0, 1)`, such as `0.95`.
fn predict_async_with_uncertainty<'a>(
&'a self,
x: &'a X,
confidence_level: f64,
) -> AsyncPredictConfidenceFuture<'a, Output>
where
Self: 'a;
}
/// Extended asynchronous transformation API; `Output` defaults to the input type `X`.
pub trait AsyncTransformAdvanced<X, Output = X> {
/// Error type associated with this transformer.
///
/// NOTE(review): not referenced by any method signature below (they use the crate `Result`).
type Error: std::error::Error + Send + Sync;
/// Transforms `x` as a single boxed future using the settings in `config`.
///
/// The return type is the same as `AsyncTransformFuture<'a, Output>`.
/// NOTE(review): despite the name, no progress channel appears in the return type.
fn transform_async_with_progress<'a>(
&'a self,
x: &'a X,
config: &'a AsyncConfig,
) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>;
/// Turns a stream of inputs into a stream of transformed results.
fn transform_stream<'a>(
&'a self,
x_stream: Pin<Box<dyn Stream<Item = X> + Send + 'a>>,
config: &'a AsyncConfig,
) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
/// Transforms `x` piecewise, presumably yielding one result per chunk of `chunk_size`
/// items — chunking semantics are implementation-defined, confirm with implementors.
fn transform_async_chunked<'a>(
&'a self,
x: &'a X,
chunk_size: usize,
) -> Pin<Box<dyn Stream<Item = Result<Output>> + Send + 'a>>;
}
/// Incremental asynchronous training API operating on a mutably borrowed model.
pub trait AsyncPartialFit<X, Y> {
/// Error type associated with this estimator.
///
/// NOTE(review): not referenced by any method signature below (they use the crate `Result`).
type Error: std::error::Error + Send + Sync;
/// Updates the model with one batch `x`/`y`.
///
/// The return type is the same as `AsyncUnitFuture<'a>`.
fn partial_fit_async<'a>(
&'a mut self,
x: &'a X,
y: &'a Y,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
/// Consumes a stream of `(x, y)` batches, yielding progress snapshots as training proceeds.
fn partial_fit_stream<'a>(
&'a mut self,
data_stream: Pin<Box<dyn Stream<Item = (X, Y)> + Send + 'a>>,
config: &'a AsyncConfig,
) -> Pin<Box<dyn Stream<Item = Result<ProgressInfo>> + Send + 'a>>;
/// Consumes a stream of `(x, y)` batches, reporting `AdaptationInfo` per step; presumably
/// the batch size is adapted within the bounds of `adaptation_config` — confirm.
fn adaptive_partial_fit<'a>(
&'a mut self,
data_stream: Pin<Box<dyn Stream<Item = (X, Y)> + Send + 'a>>,
adaptation_config: AdaptationConfig,
) -> Pin<Box<dyn Stream<Item = Result<AdaptationInfo>> + Send + 'a>>;
}
/// Settings for `AsyncPartialFit::adaptive_partial_fit`.
///
/// NOTE(review): nothing here validates `min_batch_size <= initial_batch_size <= max_batch_size`.
#[derive(Debug, Clone)]
pub struct AdaptationConfig {
/// Batch size to start with.
pub initial_batch_size: usize,
/// Lower bound for the adapted batch size.
pub min_batch_size: usize,
/// Upper bound for the adapted batch size.
pub max_batch_size: usize,
/// How aggressively adaptation reacts (scale is implementation-defined — confirm).
pub adaptation_rate: f64,
/// Performance level used to drive adaptation (semantics implementation-defined — confirm).
pub performance_threshold: f64,
/// Memory limit used to drive adaptation; units not specified here (bytes?) — confirm.
pub memory_threshold: usize,
}
/// Per-step report streamed by `AsyncPartialFit::adaptive_partial_fit`.
#[derive(Debug, Clone)]
pub struct AdaptationInfo {
/// Batch size currently in use.
pub current_batch_size: usize,
/// Learning rate currently in use.
pub current_learning_rate: f64,
/// Most recent performance measurement.
pub performance_metric: f64,
/// Memory usage at this step; units unspecified (presumably matching
/// `AdaptationConfig::memory_threshold`) — confirm.
pub memory_usage: usize,
/// Generic progress snapshot for the same step.
pub progress: ProgressInfo,
}
/// Interval estimate returned alongside a prediction by
/// `AsyncPredictAdvanced::predict_async_with_uncertainty`.
///
/// NOTE(review): the bounds are scalar `f64`s whatever the prediction's `Output` type is,
/// and nothing here enforces `lower <= upper`.
#[derive(Debug, Clone)]
pub struct ConfidenceInterval {
/// Lower bound of the interval.
pub lower: f64,
/// Upper bound of the interval.
pub upper: f64,
/// Nominal confidence level of the interval, e.g. `0.95`.
pub confidence_level: f64,
}
/// Cooperative cancellation flag shared between whoever requests cancellation and the
/// task(s) that poll it.
///
/// Cloning is cheap and yields a handle to the *same* flag (the state lives behind an
/// `Arc`), so cancelling through any clone is observed by all of them. The API offers no
/// way to reset a cancelled token.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    inner: std::sync::Arc<std::sync::atomic::AtomicBool>,
}

impl CancellationToken {
    /// Creates a token in the not-cancelled state.
    pub fn new() -> Self {
        Self {
            inner: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Requests cancellation. Calling it more than once has no further effect.
    pub fn cancel(&self) {
        // Release pairs with the Acquire load in `is_cancelled`: writes made by the
        // cancelling thread before this call are visible to a task that observes `true`.
        self.inner.store(true, std::sync::atomic::Ordering::Release);
    }

    /// Returns `true` once `cancel` has been called on this token or any of its clones.
    #[must_use]
    pub fn is_cancelled(&self) -> bool {
        self.inner.load(std::sync::atomic::Ordering::Acquire)
    }
}

impl Default for CancellationToken {
    /// Equivalent to [`CancellationToken::new`]: a token that is not cancelled.
    fn default() -> Self {
        Self::new()
    }
}
/// Asynchronous cross-validation of a model over `cv_folds` folds.
pub trait AsyncCrossValidation<X, Y> {
/// Score type reported for the folds.
type Score: FloatBounds + Send;
/// Model being validated; `Clone` is presumably required so each fold can train its own
/// copy — confirm with implementors.
type Model: Clone + Send + Sync;
/// Runs cross-validation and resolves to all scores at once (presumably one per fold, in
/// fold order — confirm).
fn cross_validate_async<'a>(
&'a self,
model: Self::Model,
x: &'a X,
y: &'a Y,
cv_folds: usize,
config: &'a AsyncConfig,
) -> AsyncScoreFuture<'a, Self::Score>
where
X: Clone + Send + Sync,
Y: Clone + Send + Sync;
/// Runs cross-validation, streaming `(index, score)` pairs; the index is presumably the
/// fold number — confirm whether items may arrive out of order.
fn cross_validate_stream<'a>(
&'a self,
model: Self::Model,
x: &'a X,
y: &'a Y,
cv_folds: usize,
config: &'a AsyncConfig,
) -> AsyncCVStream<'a, Self::Score>
where
X: Clone + Send + Sync,
Y: Clone + Send + Sync;
}
/// Asynchronous fitting and prediction for ensembles of models.
///
/// All methods are associated functions (no `self` receiver) that operate on the models
/// passed in; the implementing type presumably acts only as a strategy marker — confirm.
pub trait AsyncEnsemble<X, Y, Output> {
/// Ensemble member type.
type Model: Send + Sync;
/// Error type associated with this ensemble strategy.
///
/// NOTE(review): not referenced by any method signature below (they use the crate `Result`).
type Error: std::error::Error + Send + Sync;
/// Fits the given `models` on `x`/`y`, streaming `(index, fitted_model)` pairs; the index
/// presumably refers to the position in `models` — confirm ordering guarantees.
fn fit_ensemble_async<'a>(
models: Vec<Self::Model>,
x: &'a X,
y: &'a Y,
config: &'a AsyncConfig,
) -> AsyncEnsembleFitStream<'a, Self::Model>
where
X: Send + Sync,
Y: Send + Sync,
Self::Model: 'a;
/// Predicts with all `models` and resolves to a single `Output` (presumably an aggregate
/// of the members' predictions — confirm).
fn predict_ensemble_async<'a>(
models: &'a [Self::Model],
x: &'a X,
config: &'a AsyncConfig,
) -> Pin<Box<dyn Future<Output = Result<Output>> + Send + 'a>>
where
X: Send + Sync;
/// Streams `(index, prediction)` pairs, presumably one per member of `models`.
fn predict_ensemble_stream<'a>(
models: &'a [Self::Model],
x: &'a X,
config: &'a AsyncConfig,
) -> AsyncEnsemblePredictStream<'a, Output>
where
X: Send + Sync;
}
/// Asynchronous hyperparameter search over a `ParameterSpace`.
pub trait AsyncHyperparameterOptimization<X, Y, Config> {
/// Score type reported for each trial.
type Score: FloatBounds + Send;
/// Error type associated with this optimizer.
///
/// NOTE(review): not referenced by any method signature below (they use the crate `Result`).
type Error: std::error::Error + Send + Sync;
/// Searches `param_space` on `x`/`y` as configured by `optimization_config`, streaming an
/// `OptimizationResult` per evaluated trial (presumably as trials finish — confirm).
fn optimize_async<'a>(
&'a self,
x: &'a X,
y: &'a Y,
param_space: ParameterSpace<Config>,
optimization_config: OptimizationConfig,
) -> AsyncOptimizationStream<'a, Config, Self::Score>
where
X: Send + Sync,
Y: Send + Sync,
Config: Send + Sync;
}
/// Search space for hyperparameter optimization.
///
/// Has no derives: `config_factory` and `dependencies` hold boxed closures, which implement
/// neither `Debug` nor `Clone`.
pub struct ParameterSpace<Config> {
/// Parameter ranges, keyed by (presumably) parameter name.
pub parameters: std::collections::HashMap<String, ParameterRange>,
/// Conditional relationships between parameters.
pub dependencies: Vec<ParameterDependency>,
/// Builds a concrete `Config` from a map of parameter values (presumably keyed like
/// `parameters`) — confirm.
pub config_factory: ConfigFactory<Config>,
}
/// Domain from which a single hyperparameter is drawn.
#[derive(Debug, Clone)]
pub enum ParameterRange {
/// Real values between `min` and `max` (bound inclusivity implementation-defined).
Continuous { min: f64, max: f64 },
/// One of an explicit list of values.
Discrete { values: Vec<f64> },
/// Real values between `min` and `max`, presumably sampled on a log scale (which would
/// require `min > 0`) — confirm.
LogContinuous { min: f64, max: f64 },
/// Integers between `min` and `max` (inclusivity implementation-defined — confirm).
Integer { min: i64, max: i64 },
}
/// Declares that parameter `dependent` is conditional on the value of parameter `parent`.
pub struct ParameterDependency {
/// Name of the conditional parameter.
pub dependent: String,
/// Name of the parameter whose value is tested.
pub parent: String,
/// Predicate over the parent's value; presumably `dependent` is only active when it
/// returns `true` — confirm with optimizer implementations.
pub condition: Box<dyn Fn(f64) -> bool + Send + Sync>,
}
/// Settings for `AsyncHyperparameterOptimization::optimize_async`.
#[derive(Debug, Clone)]
pub struct OptimizationConfig {
/// Maximum number of configurations to evaluate.
pub max_evaluations: usize,
/// Search strategy.
pub algorithm: OptimizationAlgorithm,
/// Optional early-stopping rule; `None` presumably disables early stopping.
pub early_stopping: Option<EarlyStoppingConfig>,
/// Async settings presumably applied while evaluating trials.
pub parallel_config: AsyncConfig,
}
/// Hyperparameter search strategy.
#[derive(Debug, Clone)]
pub enum OptimizationAlgorithm {
/// Random search.
Random,
/// Bayesian optimization guided by `acquisition_function`, after `n_initial_points`
/// initial evaluations.
BayesianOptimization {
acquisition_function: AcquisitionFunction,
n_initial_points: usize,
},
/// TPE (presumably Tree-structured Parzen Estimator) with `n_startup_trials` initial
/// trials and `n_ei_candidates` candidates considered per step.
TPE {
n_startup_trials: usize,
n_ei_candidates: usize,
},
/// Hyperband with resource cap `max_resource`; `eta` is presumably the successive-halving
/// reduction factor — confirm.
Hyperband { max_resource: usize, eta: f64 },
}
/// Acquisition function for `OptimizationAlgorithm::BayesianOptimization`.
#[derive(Debug, Clone)]
pub enum AcquisitionFunction {
/// Expected improvement.
ExpectedImprovement,
/// Upper confidence bound; `kappa` conventionally weights exploration against
/// exploitation.
UpperConfidenceBound { kappa: f64 },
/// Probability of improvement.
ProbabilityOfImprovement,
}
/// Early-stopping rule for hyperparameter search (see `OptimizationConfig::early_stopping`).
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
/// Presumably the number of trials without sufficient improvement tolerated before
/// stopping — confirm.
pub patience: usize,
/// Smallest score change that counts as an improvement.
pub min_improvement: f64,
/// Presumably `true` when larger scores are better — confirm.
pub maximize: bool,
}
/// Outcome of one hyperparameter-search trial, as carried by `AsyncOptimizationStream`.
#[derive(Debug, Clone)]
pub struct OptimizationResult<Config, Score> {
/// Trial number (base and ordering implementation-defined — confirm).
pub trial: usize,
/// Configuration evaluated in this trial.
pub config: Config,
/// Score achieved by `config`.
pub score: Score,
/// Time spent evaluating `config`.
pub evaluation_time: Duration,
/// Additional named metrics recorded for the trial.
pub metrics: std::collections::HashMap<String, f64>,
}
/// Asynchronous saving and loading of models.
pub trait AsyncModelPersistence {
/// Error type associated with persistence.
///
/// NOTE(review): not referenced by any method signature below (they use the crate `Result`).
type Error: std::error::Error + Send + Sync;
/// Saves the model to `path`.
///
/// The return type is the same as `AsyncUnitFuture<'a>`.
fn save_async<'a>(
&'a self,
path: &'a std::path::Path,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
/// Loads a model from `path`; `Self: Sized` is required because the model is returned by
/// value.
fn load_async<'a>(
path: &'a std::path::Path,
) -> Pin<Box<dyn Future<Output = Result<Self>> + Send + 'a>>
where
Self: Sized;
/// Saves the model to `path` with compression at `compression_level` (valid range is
/// implementation-defined — confirm).
///
/// The return type is the same as `AsyncUnitFuture<'a>`.
fn save_compressed_async<'a>(
&'a self,
path: &'a std::path::Path,
compression_level: u32,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
}
// Unit tests for the concrete helper types defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_progress_info() {
        let progress = ProgressInfo::new(0.5, 50)
            .with_total_steps(100)
            .with_elapsed(Duration::from_secs(30))
            .with_eta(Duration::from_secs(30))
            .with_metric(0.85)
            .with_message("Training in progress");
        assert_eq!(progress.progress, 0.5);
        assert_eq!(progress.current_step, 50);
        assert_eq!(progress.total_steps, Some(100));
        assert_eq!(progress.elapsed, Duration::from_secs(30));
        assert_eq!(progress.eta, Some(Duration::from_secs(30)));
        assert_eq!(progress.current_metric, Some(0.85));
        assert_eq!(progress.message, "Training in progress");
    }

    #[test]
    fn test_progress_info_defaults_and_clamping() {
        let progress = ProgressInfo::new(1.5, 3);
        // `new` clamps the fraction into [0, 1].
        assert_eq!(progress.progress, 1.0);
        assert_eq!(progress.current_step, 3);
        assert_eq!(progress.total_steps, None);
        assert_eq!(progress.elapsed, Duration::from_secs(0));
        assert_eq!(progress.eta, None);
        assert_eq!(progress.current_metric, None);
        assert!(progress.message.is_empty());
        assert_eq!(ProgressInfo::new(-0.5, 0).progress, 0.0);
        // A zero total keeps the supplied fraction instead of dividing by zero.
        assert_eq!(ProgressInfo::new(0.4, 0).with_total_steps(0).progress, 0.4);
    }

    #[test]
    fn test_cancellation_token() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancellation_token_clones_share_state() {
        let token = CancellationToken::default();
        let handle = token.clone();
        assert!(!handle.is_cancelled());
        std::thread::spawn(move || handle.cancel())
            .join()
            .expect("cancelling thread panicked");
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_async_config_default() {
        let config = AsyncConfig::default();
        assert_eq!(config.batch_size, 1000);
        assert_eq!(config.operation_timeout, Some(Duration::from_secs(300)));
        assert!(config.enable_progress);
        assert_eq!(config.progress_interval, Duration::from_secs(1));
        assert!(config.max_concurrency >= 1);
    }

    #[test]
    fn test_confidence_interval() {
        let ci = ConfidenceInterval {
            lower: 0.1,
            upper: 0.9,
            confidence_level: 0.95,
        };
        assert_eq!(ci.lower, 0.1);
        assert_eq!(ci.upper, 0.9);
        assert_eq!(ci.confidence_level, 0.95);
    }
}