use crate::error::FFTResult;
use crate::planning::{FftPlan, PlannerBackend, PlanningStrategy};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Tuning knobs for [`AdaptivePlanner`].
#[derive(Debug, Clone)]
pub struct AdaptivePlanningConfig {
    /// Master switch: when `false`, recorded execution times are ignored.
    pub enabled: bool,
    /// Number of samples the active strategy must accumulate before the
    /// planner considers switching.
    pub min_samples: usize,
    /// Minimum time since the last strategy switch (or planner creation)
    /// before the planner re-evaluates.
    pub evaluation_interval: Duration,
    /// Re-evaluation stops once this many strategy switches have happened.
    pub max_strategy_switches: usize,
    /// Whether the planner may switch FFT backends.
    /// NOTE(review): the planner in this module does not act on this flag yet.
    pub enable_backend_switching: bool,
    /// Required ratio of the incumbent's average time to a candidate's average
    /// time before switching (`1.1` means the incumbent must be more than 10%
    /// slower).
    pub improvement_threshold: f64,
}

impl Default for AdaptivePlanningConfig {
    /// Adaptation enabled; re-evaluate after 5 samples and 10 s; at most 3
    /// switches; backend switching allowed; 1.1 improvement ratio.
    fn default() -> Self {
        let evaluation_interval = Duration::from_secs(10);
        Self {
            improvement_threshold: 1.1,
            max_strategy_switches: 3,
            min_samples: 5,
            evaluation_interval,
            enable_backend_switching: true,
            enabled: true,
        }
    }
}
/// Running timing statistics collected for a single planning strategy.
#[derive(Debug, Clone)]
struct StrategyMetrics {
    /// Sum of every execution time recorded so far.
    total_time: Duration,
    /// Number of executions recorded.
    count: usize,
    /// Mean execution time: `total_time / count`, truncated to whole nanoseconds.
    avg_time: Duration,
    /// Set at construction; flagged `dead_code` because nothing reads it yet.
    #[allow(dead_code)]
    last_evaluated: Instant,
}

impl StrategyMetrics {
    /// Creates an empty record with zero samples and zero times.
    fn new() -> Self {
        let zero = Duration::from_nanos(0);
        Self {
            last_evaluated: Instant::now(),
            avg_time: zero,
            count: 0,
            total_time: zero,
        }
    }

    /// Adds one execution sample and refreshes the mean.
    fn record(&mut self, time: Duration) {
        self.count += 1;
        self.total_time += time;
        let samples = self.count as u128;
        let mean_ns = self.total_time.as_nanos() / samples;
        // `as u64` truncates, but an average above u64::MAX ns (~584 years)
        // is not a realistic execution time.
        self.avg_time = Duration::from_nanos(mean_ns as u64);
    }
}
/// Tracks per-strategy execution timings for one FFT size/direction and
/// switches [`PlanningStrategy`] when another strategy measures faster.
pub struct AdaptivePlanner {
    // --- problem description ---
    /// FFT dimensions the plan is built for.
    size: Vec<usize>,
    /// Direction flag forwarded to `plan_fft`.
    forward: bool,
    /// Adaptation settings.
    config: AdaptivePlanningConfig,

    // --- current choice ---
    /// Strategy used for the next plan creation.
    current_strategy: PlanningStrategy,
    /// Backend used for plan creation.
    current_backend: PlannerBackend,
    /// Cached plan for the current strategy; cleared on a strategy switch.
    current_plan: Option<Arc<FftPlan>>,

    // --- adaptation bookkeeping ---
    /// Timing statistics keyed by strategy.
    metrics: HashMap<PlanningStrategy, StrategyMetrics>,
    /// When the strategy last changed (initially: creation time).
    last_strategy_switch: Instant,
    /// How many strategy switches have happened so far.
    strategy_switches: usize,
}
use std::collections::HashMap;
impl AdaptivePlanner {
    /// Creates a planner for an FFT of the given `size` and direction.
    ///
    /// Starts on [`PlanningStrategy::CacheFirst`] with the default backend and
    /// empty metrics for `AlwaysNew`, `CacheFirst`, `SerializedFirst` and
    /// `AutoTuned`. A `None` config falls back to
    /// [`AdaptivePlanningConfig::default`].
    pub fn new(size: &[usize], forward: bool, config: Option<AdaptivePlanningConfig>) -> Self {
        let config = config.unwrap_or_default();
        let mut metrics = HashMap::new();
        metrics.insert(PlanningStrategy::AlwaysNew, StrategyMetrics::new());
        metrics.insert(PlanningStrategy::CacheFirst, StrategyMetrics::new());
        metrics.insert(PlanningStrategy::SerializedFirst, StrategyMetrics::new());
        metrics.insert(PlanningStrategy::AutoTuned, StrategyMetrics::new());
        Self {
            size: size.to_vec(),
            forward,
            current_strategy: PlanningStrategy::CacheFirst,
            current_backend: PlannerBackend::default(),
            metrics,
            last_strategy_switch: Instant::now(),
            strategy_switches: 0,
            config,
            current_plan: None,
        }
    }

    /// The strategy currently used to create plans.
    pub fn current_strategy(&self) -> PlanningStrategy {
        self.current_strategy
    }

    /// The backend currently used to create plans.
    pub fn current_backend(&self) -> PlannerBackend {
        self.current_backend.clone()
    }

    /// Returns the cached plan, creating one with the current strategy and
    /// backend on first use or after a strategy switch.
    ///
    /// NOTE(review): a fresh `AdvancedFftPlanner` is built for every (re)plan,
    /// so any caching it does per instance is not reused here. Confirm
    /// whether its cache is process-wide.
    ///
    /// # Errors
    /// Propagates any error from `AdvancedFftPlanner::plan_fft`.
    pub fn get_plan(&mut self) -> FFTResult<Arc<FftPlan>> {
        use crate::planning::{AdvancedFftPlanner, PlanningConfig};

        if let Some(plan) = &self.current_plan {
            return Ok(Arc::clone(plan));
        }
        let config = PlanningConfig {
            strategy: self.current_strategy,
            ..Default::default()
        };
        let mut planner = AdvancedFftPlanner::with_config(config);
        let plan = planner.plan_fft(&self.size, self.forward, self.current_backend.clone())?;
        self.current_plan = Some(Arc::clone(&plan));
        Ok(plan)
    }

    /// Records one execution time for the current strategy.
    ///
    /// It may then re-evaluate the strategy. That happens once the current
    /// strategy has at least `min_samples` samples, `evaluation_interval` has
    /// elapsed since the last switch, and fewer than `max_strategy_switches`
    /// switches have occurred. Does nothing when adaptation is disabled.
    ///
    /// # Errors
    /// Returns `Ok(())` on every path shown here; the `Result` is forwarded
    /// from strategy evaluation.
    pub fn record_execution(&mut self, executiontime: Duration) -> FFTResult<()> {
        if !self.config.enabled {
            return Ok(());
        }
        // The entry API both records the sample and yields the updated count,
        // avoiding a second (panicking) index lookup.
        let samples = {
            let metrics = self
                .metrics
                .entry(self.current_strategy)
                .or_insert_with(StrategyMetrics::new);
            metrics.record(executiontime);
            metrics.count
        };
        // NOTE: `last_strategy_switch` only moves on an actual switch, so once
        // the interval has elapsed every further sample triggers an evaluation.
        let should_evaluate = samples >= self.config.min_samples
            && self.last_strategy_switch.elapsed() >= self.config.evaluation_interval
            && self.strategy_switches < self.config.max_strategy_switches;
        if should_evaluate {
            self.evaluate_strategies()?;
        }
        Ok(())
    }

    /// Switches to the fastest sampled strategy when it beats the current one
    /// by more than `improvement_threshold`.
    ///
    /// It takes the global minimum average rather than scanning for the first
    /// candidate that beats a running best, so the decision does not depend on
    /// `HashMap` iteration order.
    fn evaluate_strategies(&mut self) -> FFTResult<()> {
        let current_avg = match self.metrics.get(&self.current_strategy) {
            Some(metrics) => metrics.avg_time,
            None => return Ok(()),
        };
        let fastest = self
            .metrics
            .iter()
            .filter(|(_, metrics)| metrics.count > 0)
            .min_by_key(|(_, metrics)| metrics.avg_time)
            .map(|(strategy, metrics)| (*strategy, metrics.avg_time));

        if let Some((candidate, candidate_avg)) = fastest {
            let current_ns = current_avg.as_nanos() as f64;
            let candidate_ns = candidate_avg.as_nanos() as f64;
            // Multiplying instead of dividing avoids a division by a
            // zero-duration average. Ties never switch when threshold >= 1.
            if candidate != self.current_strategy
                && current_ns > self.config.improvement_threshold * candidate_ns
            {
                self.current_strategy = candidate;
                self.last_strategy_switch = Instant::now();
                self.strategy_switches += 1;
                // Force `get_plan` to rebuild under the new strategy.
                self.current_plan = None;
            }
        }
        // `enable_backend_switching` is not acted upon yet: only the planning
        // strategy adapts, and `current_backend` is never changed here.
        Ok(())
    }

    /// Per-strategy `(average time, sample count)` snapshot.
    pub fn get_statistics(&self) -> HashMap<PlanningStrategy, (Duration, usize)> {
        self.metrics
            .iter()
            .map(|(strategy, metrics)| (*strategy, (metrics.avg_time, metrics.count)))
            .collect()
    }
}
/// Executes FFTs through a shared [`AdaptivePlanner`] (held behind
/// `Arc<Mutex<_>>`) and feeds each measured wall-clock time back into it.
pub struct AdaptiveExecutor {
    planner: Arc<Mutex<AdaptivePlanner>>,
}

impl AdaptiveExecutor {
    /// Creates an executor for an FFT of the given `size` and direction.
    /// A `None` config uses [`AdaptivePlanningConfig::default`].
    pub fn new(size: &[usize], forward: bool, config: Option<AdaptivePlanningConfig>) -> Self {
        let planner = AdaptivePlanner::new(size, forward, config);
        Self {
            planner: Arc::new(Mutex::new(planner)),
        }
    }

    /// Locks the shared planner, recovering the guard if the mutex is poisoned.
    ///
    /// The planner holds timing statistics and a cached plan handle. If a
    /// thread panicked mid-update, at worst the statistics are slightly off.
    /// Recovering keeps one panic from turning every later call into a panic.
    fn lock_planner(&self) -> std::sync::MutexGuard<'_, AdaptivePlanner> {
        self.planner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Runs the transform of `input` into `output` with the planner's current
    /// plan, then records the elapsed time.
    ///
    /// The measured time includes acquiring the plan, so plan creation cost
    /// counts toward the strategy.
    ///
    /// # Errors
    /// Propagates errors from plan creation, execution, or metric recording.
    pub fn execute(
        &self,
        input: &[scirs2_core::numeric::Complex64],
        output: &mut [scirs2_core::numeric::Complex64],
    ) -> FFTResult<()> {
        let start = Instant::now();
        // The guard is a temporary dropped at the end of this statement, so the
        // lock is not held while the transform runs.
        let plan = self.lock_planner().get_plan()?;
        let executor = crate::planning::FftPlanExecutor::new(plan);
        executor.execute(input, output)?;
        let execution_time = start.elapsed();
        self.lock_planner().record_execution(execution_time)?;
        Ok(())
    }

    /// The strategy the shared planner currently uses.
    pub fn current_strategy(&self) -> PlanningStrategy {
        self.lock_planner().current_strategy()
    }

    /// Per-strategy `(average time, sample count)` snapshot from the planner.
    pub fn get_statistics(&self) -> HashMap<PlanningStrategy, (Duration, usize)> {
        self.lock_planner().get_statistics()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::numeric::Complex64;

    /// A fresh planner starts on `CacheFirst` and attributes every recorded
    /// sample to it while no re-evaluation is due.
    #[test]
    fn test_adaptive_planner_basics() {
        let mut adaptive = AdaptivePlanner::new(&[16], true, None);
        assert_eq!(adaptive.current_strategy(), PlanningStrategy::CacheFirst);

        let sample = Duration::from_micros(100);
        (0..10).for_each(|_| {
            adaptive
                .record_execution(sample)
                .expect("Operation failed");
        });

        let recorded = adaptive.get_statistics()[&PlanningStrategy::CacheFirst].1;
        assert_eq!(recorded, 10);
    }

    /// Every `execute` call is recorded against the active strategy.
    #[test]
    fn test_adaptive_executor() {
        const LEN: usize = 16;
        let executor = AdaptiveExecutor::new(&[LEN], true, None);
        let signal = vec![Complex64::new(1.0, 0.0); LEN];
        let mut spectrum = vec![Complex64::default(); LEN];

        for _ in 0..5 {
            executor
                .execute(&signal, &mut spectrum)
                .expect("Operation failed");
        }

        let stats = executor.get_statistics();
        let active = executor.current_strategy();
        assert!(stats[&active].1 >= 5);
    }
}