Skip to main content

sklears_simd/
adaptive_optimization.rs

1//! Adaptive optimization and runtime algorithm selection
2//!
3//! Provides dynamic dispatch, auto-tuning capabilities, and machine learning-guided optimization
4//! for SIMD operations based on runtime characteristics and performance feedback.
5//!
6//! ## no-std Compatibility
7//!
8//! This module supports both std and no-std environments through conditional compilation.
9//!
//! ### Dependencies for no-std:
//! - `alloc` crate for collections and `Arc`
//! - `spin` crate for `Mutex` (add `spin = "0.9"` to Cargo.toml)
//! - `scirs2_core::random` for random number generation (exploration noise)
//!
//! ### Features:
//! - std functionality is the default; enable the `no-std` feature flag to build without std
//! - With the `no-std` feature enabled, timing functionality uses mock values
//! - `SystemTime`-based timestamps are replaced by `()` in no-std mode
19
20use crate::SimdCapabilities;
21
22// Required for no-std compatibility
23#[cfg(feature = "no-std")]
24extern crate alloc;
25
26// Conditional imports for std vs no-std
27#[cfg(not(feature = "no-std"))]
28use std::{
29    boxed::Box,
30    collections::HashMap,
31    fmt,
32    string::ToString,
33    sync::{Arc, Mutex},
34    time::Duration,
35};
36
37#[cfg(feature = "no-std")]
38use alloc::boxed::Box;
39#[cfg(feature = "no-std")]
40use alloc::collections::BTreeMap as HashMap;
41#[cfg(feature = "no-std")]
42use alloc::string::{String, ToString};
43#[cfg(feature = "no-std")]
44use alloc::sync::Arc;
45#[cfg(feature = "no-std")]
46use alloc::vec::Vec;
47#[cfg(feature = "no-std")]
48use core::fmt;
49#[cfg(feature = "no-std")]
50use spin::Mutex;
51
52#[cfg(feature = "no-std")]
53use core::time::Duration;
54#[cfg(not(feature = "no-std"))]
55use std::time::Instant;
56
57// SystemTime is only available in std
58#[cfg(not(feature = "no-std"))]
59use std::time::SystemTime;
60
61/// Runtime algorithm selector
62pub struct AdaptiveOptimizer {
63    #[allow(dead_code)] // Used via detect() at construction; stored for future adaptive queries
64    capabilities: SimdCapabilities,
65    performance_cache: Arc<Mutex<HashMap<String, AlgorithmPerformance>>>,
66    auto_tuning_enabled: bool,
67    learning_rate: f64,
68}
69
/// Running performance statistics recorded for one algorithm variant.
#[derive(Debug, Clone)]
pub struct AlgorithmPerformance {
    /// Name the record is stored under (the variant's `name()`).
    pub algorithm_name: String,
    /// Exponential moving average of observed execution time.
    pub avg_duration: Duration,
    /// Number of samples folded into this record.
    pub sample_count: usize,
    /// Input-size range covered by the record; the optimizer currently
    /// always creates it as `(0, 0)`.
    pub data_size_range: (usize, usize),
    /// Exponential moving average of outcomes (success = 1.0, failure = 0.0).
    pub success_rate: f64,
    /// Wall-clock time of the most recent update.
    #[cfg(not(feature = "no-std"))]
    pub last_updated: SystemTime,
    /// Placeholder: no wall clock exists without std.
    #[cfg(feature = "no-std")]
    pub last_updated: (),
}
83
/// Dynamic dispatch strategy used by `AdaptiveOptimizer::select_algorithm`.
///
/// The enum is fieldless, so it derives `Copy`, `PartialEq`, `Eq` and `Hash`.
/// A strategy can therefore be kept in configuration, compared, used as a map
/// key, and passed by value repeatedly without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DispatchStrategy {
    /// Pick the variant with the lowest recorded average duration.
    AlwaysFastest,
    /// Pick the variant with the highest recorded success rate.
    MostReliable,
    /// Pick the best weighted mix of speed and success rate.
    Balanced,
    /// Pick the variant whose `estimated_cost` for the input is lowest.
    DataDriven,
    /// Pick by predicted performance (currently a heuristic score plus random
    /// exploration noise rather than a trained model).
    MLGuided,
}
98
99/// Algorithm variant for dynamic selection
100pub trait AlgorithmVariant<T> {
101    fn name(&self) -> &str;
102    fn execute(&self, input: &T) -> Result<T, AlgorithmError>;
103    fn is_applicable(&self, input: &T) -> bool;
104    fn estimated_cost(&self, input: &T) -> f64;
105}
106
/// Error type for algorithm execution.
///
/// Derives `Clone`, `PartialEq` and `Eq` so errors can be stored, cloned, and
/// compared against expected values (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AlgorithmError {
    /// The variant cannot handle the given input.
    UnsupportedInput,
    /// Not enough resources were available to run the variant.
    InsufficientResources,
    /// A numeric computation failed.
    NumericError,
    /// Any other failure, with a human-readable description.
    RuntimeError(String),
}
115
116impl core::fmt::Display for AlgorithmError {
117    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
118        match self {
119            AlgorithmError::UnsupportedInput => write!(f, "Unsupported input for algorithm"),
120            AlgorithmError::InsufficientResources => write!(f, "Insufficient resources"),
121            AlgorithmError::NumericError => write!(f, "Numeric computation error"),
122            AlgorithmError::RuntimeError(msg) => write!(f, "Runtime error: {}", msg),
123        }
124    }
125}
126
127#[cfg(not(feature = "no-std"))]
128impl std::error::Error for AlgorithmError {}
129
130#[cfg(feature = "no-std")]
131impl core::error::Error for AlgorithmError {}
132
133impl Default for AdaptiveOptimizer {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139impl AdaptiveOptimizer {
140    /// Create a new adaptive optimizer
141    pub fn new() -> Self {
142        Self {
143            capabilities: SimdCapabilities::detect(),
144            performance_cache: Arc::new(Mutex::new(HashMap::new())),
145            auto_tuning_enabled: true,
146            learning_rate: 0.1,
147        }
148    }
149
150    /// Helper function to handle mutex locking in both std and no-std environments
151    #[cfg(not(feature = "no-std"))]
152    fn lock_cache(&self) -> std::sync::MutexGuard<'_, HashMap<String, AlgorithmPerformance>> {
153        self.performance_cache
154            .lock()
155            .expect("lock should not be poisoned")
156    }
157
158    #[cfg(feature = "no-std")]
159    fn lock_cache(
160        &self,
161    ) -> spin::MutexGuard<'_, HashMap<String, AlgorithmPerformance>, spin::Spin> {
162        self.performance_cache.lock()
163    }
164
165    /// Helper function to handle mutable mutex locking in both std and no-std environments
166    #[cfg(not(feature = "no-std"))]
167    fn lock_cache_mut(&self) -> std::sync::MutexGuard<'_, HashMap<String, AlgorithmPerformance>> {
168        self.performance_cache
169            .lock()
170            .expect("lock should not be poisoned")
171    }
172
173    #[cfg(feature = "no-std")]
174    fn lock_cache_mut(
175        &self,
176    ) -> spin::MutexGuard<'_, HashMap<String, AlgorithmPerformance>, spin::Spin> {
177        self.performance_cache.lock()
178    }
179
180    /// Enable or disable auto-tuning
181    pub fn set_auto_tuning(&mut self, enabled: bool) {
182        self.auto_tuning_enabled = enabled;
183    }
184
185    /// Set learning rate for performance adaptation
186    pub fn set_learning_rate(&mut self, rate: f64) {
187        self.learning_rate = rate.clamp(0.0, 1.0);
188    }
189
190    /// Select best algorithm variant based on strategy
191    pub fn select_algorithm<'a, T>(
192        &self,
193        variants: &'a [Box<dyn AlgorithmVariant<T>>],
194        input: &T,
195        strategy: DispatchStrategy,
196    ) -> Option<&'a dyn AlgorithmVariant<T>> {
197        let applicable_variants: Vec<&'a dyn AlgorithmVariant<T>> = variants
198            .iter()
199            .filter(|variant| variant.is_applicable(input))
200            .map(|boxed| boxed.as_ref())
201            .collect();
202
203        if applicable_variants.is_empty() {
204            return None;
205        }
206
207        match strategy {
208            DispatchStrategy::AlwaysFastest => self.select_fastest(&applicable_variants, input),
209            DispatchStrategy::MostReliable => self.select_most_reliable(&applicable_variants),
210            DispatchStrategy::Balanced => self.select_balanced(&applicable_variants, input),
211            DispatchStrategy::DataDriven => self.select_data_driven(&applicable_variants, input),
212            DispatchStrategy::MLGuided => self.select_ml_guided(&applicable_variants, input),
213        }
214    }
215
216    /// Execute algorithm with performance tracking
217    pub fn execute_with_tracking<T>(
218        &self,
219        variant: &dyn AlgorithmVariant<T>,
220        input: &T,
221    ) -> Result<T, AlgorithmError> {
222        #[cfg(not(feature = "no-std"))]
223        let start = Instant::now();
224
225        let result = variant.execute(input);
226
227        #[cfg(not(feature = "no-std"))]
228        let duration = start.elapsed();
229        #[cfg(feature = "no-std")]
230        let duration = Duration::from_millis(1); // Mock duration for no-std
231
232        if self.auto_tuning_enabled {
233            self.update_performance_stats(variant.name(), duration, result.is_ok());
234        }
235
236        result
237    }
238
239    /// Get performance statistics for an algorithm
240    pub fn get_performance_stats(&self, algorithm_name: &str) -> Option<AlgorithmPerformance> {
241        let cache = self.lock_cache();
242        cache.get(algorithm_name).cloned()
243    }
244
245    /// Auto-tune parameters for an algorithm
246    pub fn auto_tune_parameters<T, P>(
247        &self,
248        algorithm_factory: impl Fn(P) -> Box<dyn AlgorithmVariant<T>>,
249        parameter_ranges: Vec<(P, P)>,
250        test_inputs: &[T],
251        iterations: usize,
252    ) -> P
253    where
254        P: Clone + PartialOrd + fmt::Debug,
255        T: Clone,
256    {
257        // Simple grid search for parameter tuning
258        // In a real implementation, this could use more sophisticated optimization
259        let mut best_params = parameter_ranges[0].0.clone();
260        let mut best_performance = Duration::from_secs(u64::MAX);
261
262        for _ in 0..iterations {
263            for (min_param, _max_param) in &parameter_ranges {
264                // For simplicity, just test the midpoint
265                // A real implementation would use proper parameter sampling
266                let test_param = min_param.clone(); // Simplified
267
268                let algorithm = algorithm_factory(test_param.clone());
269                let mut total_duration = Duration::from_nanos(0);
270                let mut successful_runs = 0;
271
272                for test_input in test_inputs {
273                    #[cfg(not(feature = "no-std"))]
274                    let start = Instant::now();
275
276                    if algorithm.execute(test_input).is_ok() {
277                        #[cfg(not(feature = "no-std"))]
278                        {
279                            total_duration += start.elapsed();
280                        }
281                        #[cfg(feature = "no-std")]
282                        {
283                            total_duration += Duration::from_millis(1); // Mock duration for no-std
284                        }
285                        successful_runs += 1;
286                    }
287                }
288
289                if successful_runs > 0 {
290                    let avg_duration = total_duration / successful_runs as u32;
291                    if avg_duration < best_performance {
292                        best_performance = avg_duration;
293                        best_params = test_param;
294                    }
295                }
296            }
297        }
298
299        best_params
300    }
301
302    fn select_fastest<'a, T>(
303        &self,
304        variants: &[&'a dyn AlgorithmVariant<T>],
305        _input: &T,
306    ) -> Option<&'a dyn AlgorithmVariant<T>> {
307        let cache = self.lock_cache();
308
309        variants
310            .iter()
311            .min_by_key(|variant| {
312                cache
313                    .get(variant.name())
314                    .map(|perf| perf.avg_duration)
315                    .unwrap_or(Duration::from_secs(u64::MAX))
316            })
317            .copied()
318    }
319
320    fn select_most_reliable<'a, T>(
321        &self,
322        variants: &[&'a dyn AlgorithmVariant<T>],
323    ) -> Option<&'a dyn AlgorithmVariant<T>> {
324        let cache = self.lock_cache();
325
326        variants
327            .iter()
328            .max_by(|a, b| {
329                let a_reliability = cache
330                    .get(a.name())
331                    .map(|perf| perf.success_rate)
332                    .unwrap_or(0.0);
333                let b_reliability = cache
334                    .get(b.name())
335                    .map(|perf| perf.success_rate)
336                    .unwrap_or(0.0);
337                a_reliability
338                    .partial_cmp(&b_reliability)
339                    .unwrap_or(core::cmp::Ordering::Equal)
340            })
341            .copied()
342    }
343
344    fn select_balanced<'a, T>(
345        &self,
346        variants: &[&'a dyn AlgorithmVariant<T>],
347        _input: &T,
348    ) -> Option<&'a dyn AlgorithmVariant<T>> {
349        let cache = self.lock_cache();
350
351        variants
352            .iter()
353            .max_by(|a, b| {
354                let a_score = self.calculate_balanced_score(a.name(), &cache);
355                let b_score = self.calculate_balanced_score(b.name(), &cache);
356                a_score
357                    .partial_cmp(&b_score)
358                    .unwrap_or(core::cmp::Ordering::Equal)
359            })
360            .copied()
361    }
362
363    fn select_data_driven<'a, T>(
364        &self,
365        variants: &[&'a dyn AlgorithmVariant<T>],
366        input: &T,
367    ) -> Option<&'a dyn AlgorithmVariant<T>> {
368        // Simple heuristic: prefer algorithms with lower estimated cost
369        variants
370            .iter()
371            .min_by(|a, b| {
372                let a_cost = a.estimated_cost(input);
373                let b_cost = b.estimated_cost(input);
374                a_cost
375                    .partial_cmp(&b_cost)
376                    .unwrap_or(core::cmp::Ordering::Equal)
377            })
378            .copied()
379    }
380
381    fn select_ml_guided<'a, T>(
382        &self,
383        variants: &[&'a dyn AlgorithmVariant<T>],
384        input: &T,
385    ) -> Option<&'a dyn AlgorithmVariant<T>> {
386        // Simplified ML-guided selection
387        // In a real implementation, this would use a trained model
388        let cache = self.lock_cache();
389
390        variants
391            .iter()
392            .max_by(|a, b| {
393                let a_prediction = self.predict_performance(a.name(), input, &cache);
394                let b_prediction = self.predict_performance(b.name(), input, &cache);
395                a_prediction
396                    .partial_cmp(&b_prediction)
397                    .unwrap_or(core::cmp::Ordering::Equal)
398            })
399            .copied()
400    }
401
402    fn calculate_balanced_score(
403        &self,
404        algorithm_name: &str,
405        cache: &HashMap<String, AlgorithmPerformance>,
406    ) -> f64 {
407        if let Some(perf) = cache.get(algorithm_name) {
408            // Balance speed (1/duration) and reliability
409            let speed_score = 1.0 / perf.avg_duration.as_secs_f64();
410            let reliability_score = perf.success_rate;
411            (speed_score * 0.6) + (reliability_score * 0.4)
412        } else {
413            0.0
414        }
415    }
416
417    fn predict_performance<T>(
418        &self,
419        algorithm_name: &str,
420        _input: &T,
421        cache: &HashMap<String, AlgorithmPerformance>,
422    ) -> f64 {
423        // Simplified performance prediction
424        // A real implementation would use features from the input and a trained model
425        if let Some(perf) = cache.get(algorithm_name) {
426            // Simple heuristic: recent performance with some randomness for exploration
427            let base_score = 1.0 / perf.avg_duration.as_secs_f64() * perf.success_rate;
428            let exploration_factor = {
429                use scirs2_core::random::thread_rng;
430                let mut rng = thread_rng();
431                0.1 * rng.random::<f64>()
432            };
433            base_score + exploration_factor
434        } else {
435            // Unknown algorithm gets a default score
436            0.5
437        }
438    }
439
440    fn update_performance_stats(&self, algorithm_name: &str, duration: Duration, success: bool) {
441        let mut cache = self.lock_cache_mut();
442
443        let updated_perf = if let Some(mut perf) = cache.get(algorithm_name).cloned() {
444            // Update existing performance record using exponential moving average
445            let new_duration_secs = duration.as_secs_f64();
446            let old_duration_secs = perf.avg_duration.as_secs_f64();
447            let updated_duration_secs = old_duration_secs * (1.0 - self.learning_rate)
448                + new_duration_secs * self.learning_rate;
449
450            perf.avg_duration = Duration::from_secs_f64(updated_duration_secs);
451            perf.sample_count += 1;
452
453            let new_success_rate = if success { 1.0 } else { 0.0 };
454            perf.success_rate = perf.success_rate * (1.0 - self.learning_rate)
455                + new_success_rate * self.learning_rate;
456
457            #[cfg(not(feature = "no-std"))]
458            {
459                perf.last_updated = SystemTime::now();
460            }
461            #[cfg(feature = "no-std")]
462            {
463                perf.last_updated = ();
464            }
465
466            perf
467        } else {
468            // Create new performance record
469            AlgorithmPerformance {
470                algorithm_name: algorithm_name.to_string(),
471                avg_duration: duration,
472                sample_count: 1,
473                data_size_range: (0, 0), // Would be updated based on input characteristics
474                success_rate: if success { 1.0 } else { 0.0 },
475                #[cfg(not(feature = "no-std"))]
476                last_updated: SystemTime::now(),
477                #[cfg(feature = "no-std")]
478                last_updated: (),
479            }
480        };
481
482        cache.insert(algorithm_name.to_string(), updated_perf);
483    }
484}
485
486/// Performance feedback loop for continuous optimization
487pub struct PerformanceFeedbackLoop {
488    optimizer: AdaptiveOptimizer,
489    feedback_history: Vec<FeedbackRecord>,
490    adaptation_threshold: f64,
491}
492
/// One performance observation supplied to the feedback loop.
#[derive(Debug, Clone)]
pub struct FeedbackRecord {
    /// Caller-supplied time of the observation.
    #[cfg(not(feature = "no-std"))]
    pub timestamp: SystemTime,
    /// Placeholder: no wall clock exists without std.
    #[cfg(feature = "no-std")]
    pub timestamp: (),
    /// Algorithm the observation refers to; records are grouped by this name.
    pub algorithm_name: String,
    /// Free-form description of the input.
    pub input_characteristics: String,
    /// Measured metric; the variance of recent values drives learning-rate
    /// adaptation.
    pub performance_metric: f64,
    /// Free-form context label.
    pub context: String,
}
504
505impl Default for PerformanceFeedbackLoop {
506    fn default() -> Self {
507        Self::new()
508    }
509}
510
511impl PerformanceFeedbackLoop {
512    /// Create a new feedback loop
513    pub fn new() -> Self {
514        Self {
515            optimizer: AdaptiveOptimizer::new(),
516            feedback_history: Vec::new(),
517            adaptation_threshold: 0.05, // 5% performance change triggers adaptation
518        }
519    }
520
521    /// Add performance feedback
522    pub fn add_feedback(&mut self, record: FeedbackRecord) {
523        self.feedback_history.push(record);
524
525        // Trigger adaptation if we have enough feedback
526        if self.feedback_history.len().is_multiple_of(10) {
527            self.adapt_strategies();
528        }
529    }
530
531    /// Analyze feedback and adapt optimization strategies
532    fn adapt_strategies(&mut self) {
533        // Analyze recent feedback to identify trends
534        let recent_feedback: Vec<&FeedbackRecord> =
535            self.feedback_history.iter().rev().take(20).collect();
536
537        // Group by algorithm
538        let mut algorithm_groups: HashMap<String, Vec<&FeedbackRecord>> = HashMap::new();
539        for record in recent_feedback {
540            algorithm_groups
541                .entry(record.algorithm_name.clone())
542                .or_default()
543                .push(record);
544        }
545
546        // Adapt learning rate based on performance variance
547        for (_algorithm_name, records) in algorithm_groups {
548            if records.len() >= 3 {
549                let metrics: Vec<f64> = records.iter().map(|r| r.performance_metric).collect();
550                let variance = self.calculate_variance(&metrics);
551
552                // If performance is highly variable, increase exploration
553                if variance > self.adaptation_threshold {
554                    self.optimizer.set_learning_rate(0.2);
555                } else {
556                    self.optimizer.set_learning_rate(0.05);
557                }
558            }
559        }
560    }
561
562    fn calculate_variance(&self, values: &[f64]) -> f64 {
563        if values.len() < 2 {
564            return 0.0;
565        }
566
567        let mean = values.iter().sum::<f64>() / values.len() as f64;
568        let variance =
569            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (values.len() - 1) as f64;
570
571        variance
572    }
573}
574
/// Auto-tuning configuration.
///
/// NOTE(review): nothing in this module reads this config; the field
/// descriptions below follow the field names and should be confirmed against
/// the consumers.
#[derive(Debug, Clone)]
pub struct AutoTuningConfig {
    /// Whether auto-tuning is enabled.
    pub enabled: bool,
    /// Presumably the cap on tuning iterations.
    pub max_iterations: usize,
    /// Presumably the improvement below which tuning counts as converged.
    pub convergence_threshold: f64,
    /// Presumably the share of decisions spent exploring alternatives.
    pub exploration_rate: f64,
    /// Presumably the period between adaptation passes.
    pub adaptation_interval: Duration,
}

impl Default for AutoTuningConfig {
    /// Enabled, 100 iterations, convergence threshold `0.01`, exploration
    /// rate `0.1`, and a five-minute adaptation interval.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        AutoTuningConfig {
            enabled: true,
            max_iterations: 100,
            convergence_threshold: 0.01,
            exploration_rate: 0.1,
            adaptation_interval: five_minutes,
        }
    }
}
596
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    // This module only builds with std, so std types are imported directly.
    use std::time::{Duration, SystemTime};

    /// Mock variant that sleeps for `execution_time` and succeeds with
    /// probability `success_rate`.
    struct MockAlgorithmVariant {
        name: String,
        execution_time: Duration,
        success_rate: f64,
    }

    impl AlgorithmVariant<Vec<f32>> for MockAlgorithmVariant {
        fn name(&self) -> &str {
            &self.name
        }

        fn execute(&self, _input: &Vec<f32>) -> Result<Vec<f32>, AlgorithmError> {
            std::thread::sleep(self.execution_time);

            let random_val = {
                use scirs2_core::random::thread_rng;
                let mut rng = thread_rng();
                rng.random::<f64>()
            };
            if random_val < self.success_rate {
                Ok(vec![1.0, 2.0, 3.0])
            } else {
                Err(AlgorithmError::RuntimeError("Mock error".to_string()))
            }
        }

        fn is_applicable(&self, input: &Vec<f32>) -> bool {
            !input.is_empty()
        }

        fn estimated_cost(&self, input: &Vec<f32>) -> f64 {
            input.len() as f64 * self.execution_time.as_secs_f64()
        }
    }

    /// Two applicable variants: a fast, less reliable one and a slow, reliable one.
    fn mock_variants() -> Vec<Box<dyn AlgorithmVariant<Vec<f32>>>> {
        let variants: Vec<Box<dyn AlgorithmVariant<Vec<f32>>>> = vec![
            Box::new(MockAlgorithmVariant {
                name: "fast_algorithm".to_string(),
                execution_time: Duration::from_millis(10),
                success_rate: 0.9,
            }),
            Box::new(MockAlgorithmVariant {
                name: "slow_algorithm".to_string(),
                execution_time: Duration::from_millis(100),
                success_rate: 0.99,
            }),
        ];
        variants
    }

    #[test]
    fn test_adaptive_optimizer_creation() {
        let optimizer = AdaptiveOptimizer::new();
        assert!(optimizer.auto_tuning_enabled);
        assert_eq!(optimizer.learning_rate, 0.1);
    }

    #[test]
    fn test_set_learning_rate_clamps() {
        let mut optimizer = AdaptiveOptimizer::new();
        optimizer.set_learning_rate(5.0);
        assert_eq!(optimizer.learning_rate, 1.0);
        optimizer.set_learning_rate(-0.5);
        assert_eq!(optimizer.learning_rate, 0.0);
        optimizer.set_learning_rate(0.25);
        assert_eq!(optimizer.learning_rate, 0.25);
    }

    #[test]
    fn test_algorithm_selection() {
        let optimizer = AdaptiveOptimizer::new();
        let input = vec![1.0, 2.0, 3.0];
        let variants = mock_variants();

        let selected =
            optimizer.select_algorithm(&variants, &input, DispatchStrategy::AlwaysFastest);

        assert!(selected.is_some());
    }

    #[test]
    fn test_no_applicable_variant() {
        let optimizer = AdaptiveOptimizer::new();
        let variants = mock_variants();
        let empty: Vec<f32> = Vec::new();

        let selected = optimizer.select_algorithm(&variants, &empty, DispatchStrategy::Balanced);

        assert!(selected.is_none());
    }

    #[test]
    fn test_fastest_uses_recorded_durations() {
        let optimizer = AdaptiveOptimizer::new();
        optimizer.update_performance_stats("fast_algorithm", Duration::from_millis(1), true);
        optimizer.update_performance_stats("slow_algorithm", Duration::from_millis(50), true);
        let variants = mock_variants();
        let input = vec![1.0, 2.0, 3.0];

        let selected =
            optimizer.select_algorithm(&variants, &input, DispatchStrategy::AlwaysFastest);

        assert_eq!(selected.map(|v| v.name()), Some("fast_algorithm"));
    }

    #[test]
    fn test_balanced_uses_recorded_stats() {
        let optimizer = AdaptiveOptimizer::new();
        optimizer.update_performance_stats("fast_algorithm", Duration::from_millis(1), true);
        optimizer.update_performance_stats("slow_algorithm", Duration::from_millis(50), true);
        let variants = mock_variants();
        let input = vec![1.0, 2.0, 3.0];

        let selected = optimizer.select_algorithm(&variants, &input, DispatchStrategy::Balanced);

        assert_eq!(selected.map(|v| v.name()), Some("fast_algorithm"));
    }

    #[test]
    fn test_most_reliable_uses_recorded_success() {
        let optimizer = AdaptiveOptimizer::new();
        optimizer.update_performance_stats("fast_algorithm", Duration::from_millis(1), false);
        optimizer.update_performance_stats("slow_algorithm", Duration::from_millis(50), true);
        let variants = mock_variants();
        let input = vec![1.0, 2.0, 3.0];

        let selected =
            optimizer.select_algorithm(&variants, &input, DispatchStrategy::MostReliable);

        assert_eq!(selected.map(|v| v.name()), Some("slow_algorithm"));
    }

    #[test]
    fn test_data_driven_prefers_lower_estimated_cost() {
        let optimizer = AdaptiveOptimizer::new();
        let variants = mock_variants();
        let input = vec![1.0, 2.0, 3.0];

        let selected = optimizer.select_algorithm(&variants, &input, DispatchStrategy::DataDriven);

        assert_eq!(selected.map(|v| v.name()), Some("fast_algorithm"));
    }

    #[test]
    fn test_performance_tracking() {
        let optimizer = AdaptiveOptimizer::new();
        let input = vec![1.0, 2.0, 3.0];

        let variant: Box<dyn AlgorithmVariant<Vec<f32>>> = Box::new(MockAlgorithmVariant {
            name: "test_algorithm".to_string(),
            execution_time: Duration::from_millis(1),
            success_rate: 1.0,
        });

        let result = optimizer.execute_with_tracking(variant.as_ref(), &input);
        assert!(result.is_ok());

        // Check that performance stats were recorded
        let stats = optimizer.get_performance_stats("test_algorithm");
        assert!(stats.is_some());
    }

    #[test]
    fn test_feedback_loop() {
        let mut feedback_loop = PerformanceFeedbackLoop::new();

        let record = FeedbackRecord {
            timestamp: SystemTime::now(),
            algorithm_name: "test_algo".to_string(),
            input_characteristics: "small_data".to_string(),
            performance_metric: 0.5,
            context: "test".to_string(),
        };

        feedback_loop.add_feedback(record);
        assert_eq!(feedback_loop.feedback_history.len(), 1);
    }

    #[test]
    fn test_variance() {
        let feedback_loop = PerformanceFeedbackLoop::new();
        assert_eq!(feedback_loop.calculate_variance(&[1.0, 2.0, 3.0]), 1.0);
        assert_eq!(feedback_loop.calculate_variance(&[4.0]), 0.0);
    }

    #[test]
    fn test_auto_tuning_config() {
        let config = AutoTuningConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_iterations, 100);
        assert_eq!(config.convergence_threshold, 0.01);
    }
}