Skip to main content

scirs2_fft/
auto_tuning.rs

1//! Auto-tuning for hardware-specific FFT optimizations
2//!
3//! This module provides functionality to automatically tune FFT parameters
4//! for optimal performance on the current hardware. It includes:
5//!
6//! - Benchmarking different FFT configurations
7//! - Selecting optimal parameters based on timing results
8//! - Persisting tuning results for future use
9//! - Detecting CPU features and adapting algorithms accordingly
10
11#[cfg(feature = "oxifft")]
12use crate::oxifft_plan_cache;
13#[cfg(feature = "oxifft")]
14use oxifft::{Complex as OxiComplex, Direction};
15#[cfg(feature = "rustfft-backend")]
16use rustfft::FftPlanner;
17use scirs2_core::numeric::Complex64;
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20use std::fs::{self, File};
21use std::io::{BufReader, BufWriter};
22use std::path::{Path, PathBuf};
23use std::time::Instant;
24
25use crate::error::{FFTError, FFTResult};
26use crate::plan_serialization::PlanSerializationManager;
27
/// A range of FFT sizes to benchmark.
///
/// Both bounds are inclusive: sizes equal to `min` or `max` may be generated.
/// Which sizes in between are produced is decided by [`SizeStep`].
#[derive(Debug, Clone)]
pub struct SizeRange {
    /// Minimum size to test (inclusive)
    pub min: usize,
    /// Maximum size to test (inclusive)
    pub max: usize,
    /// How to step from `min` towards `max`: additive, multiplicative,
    /// powers of two, or an explicit list of sizes
    pub step: SizeStep,
}
38
/// Rule used to generate the sizes of a [`SizeRange`].
///
/// The `Linear` step must be positive and the `Exponential` factor must be
/// greater than 1.0 for the sequence to advance past `min`.
#[derive(Debug, Clone)]
pub enum SizeStep {
    /// Start at `min` and repeatedly add this constant value
    Linear(usize),
    /// Start at `min` and repeatedly multiply by this factor; each generated
    /// value is truncated to an integer size
    Exponential(f64),
    /// Use every power of two that lies within `[min, max]`
    PowersOfTwo,
    /// Use exactly these sizes, keeping only those within `[min, max]`
    Custom(Vec<usize>),
}
51
/// FFT algorithm variant to benchmark.
///
/// Variants are compared by measured average run time; the fastest one per
/// `(size, direction)` pair is what the tuner later selects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum FftVariant {
    /// Standard FFT
    Standard,
    /// In-place FFT (with the rustfft backend this runs `process_with_scratch`
    /// using an explicitly allocated scratch buffer)
    InPlace,
    /// Cached-plan FFT
    Cached,
    /// Split-radix FFT (currently a placeholder that runs the same backend
    /// transform as `Standard`)
    SplitRadix,
}
64
/// Configuration for auto-tuning.
///
/// The [`Default`] configuration benchmarks powers of two from 16 to 8192,
/// with 3 warm-up and 10 timed runs per measurement. It compares the
/// `Standard` and `Cached` variants and stores results in
/// `.fft_tuning_db.json`, which is a relative path and so is resolved
/// against the current working directory.
#[derive(Debug, Clone)]
pub struct AutoTuneConfig {
    /// Sizes to benchmark
    pub sizes: SizeRange,
    /// Number of timed repetitions per (size, variant, direction) test
    pub repetitions: usize,
    /// Warm-up iterations run before timing starts (not timed)
    pub warmup: usize,
    /// FFT variants to compare
    pub variants: Vec<FftVariant>,
    /// Path of the JSON file the tuning results are loaded from and saved to
    pub database_path: PathBuf,
}
79
80impl Default for AutoTuneConfig {
81    fn default() -> Self {
82        Self {
83            sizes: SizeRange {
84                min: 16,
85                max: 8192,
86                step: SizeStep::PowersOfTwo,
87            },
88            repetitions: 10,
89            warmup: 3,
90            variants: vec![FftVariant::Standard, FftVariant::Cached],
91            database_path: PathBuf::from(".fft_tuning_db.json"),
92        }
93    }
94}
95
/// Results from a single benchmark of one variant at one size and direction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResult {
    /// FFT size
    pub size: usize,
    /// FFT variant
    pub variant: FftVariant,
    /// `true` for the forward FFT, `false` for the inverse FFT
    pub forward: bool,
    /// Average execution time in nanoseconds (integer mean, truncated)
    pub avg_time_ns: u64,
    /// Minimum execution time in nanoseconds
    pub min_time_ns: u64,
    /// Population standard deviation of the timed runs, in nanoseconds
    pub std_dev_ns: f64,
    /// System information when the benchmark was run
    pub system_info: SystemInfo,
}
114
/// System information recorded alongside each benchmark, for result matching.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    /// CPU model (currently not probed; the tuner records `"Unknown"`)
    pub cpu_model: String,
    /// Number of cores, as reported by `num_cpus::get()`
    pub num_cores: usize,
    /// Architecture name, taken from `std::env::consts::ARCH`
    pub architecture: String,
    /// CPU features (SIMD instruction-set names such as `"avx2"` or `"neon"`)
    pub cpu_features: Vec<String>,
}
127
128/// Database of tuning results
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct TuningDatabase {
131    /// Benchmark results
132    pub results: Vec<BenchmarkResult>,
133    /// Last updated timestamp
134    pub last_updated: u64,
135    /// Best algorithm for each size
136    pub best_algorithms: HashMap<(usize, bool), FftVariant>,
137}
138
/// Auto-tuning manager.
///
/// Owns the tuning configuration and an in-memory [`TuningDatabase`]. The
/// database is loaded from `config.database_path` on construction (a fresh
/// one is used if loading fails) and written back by `save_database`.
pub struct AutoTuner {
    /// Configuration
    config: AutoTuneConfig,
    /// Database of results
    database: TuningDatabase,
    /// Whether to use tuning. When `false`, `run_benchmarks` does nothing and
    /// `get_best_variant` always answers `FftVariant::Standard`.
    enabled: bool,
}
148
149impl Default for AutoTuner {
150    fn default() -> Self {
151        Self::with_config(AutoTuneConfig::default())
152    }
153}
154
155impl AutoTuner {
156    /// Create a new auto-tuner with default configuration
157    pub fn new() -> Self {
158        Self::default()
159    }
160
161    /// Create a new auto-tuner with custom configuration
162    pub fn with_config(config: AutoTuneConfig) -> Self {
163        let database =
164            Self::load_database(&config.database_path).unwrap_or_else(|_| TuningDatabase {
165                results: Vec::new(),
166                last_updated: std::time::SystemTime::now()
167                    .duration_since(std::time::UNIX_EPOCH)
168                    .unwrap_or_default()
169                    .as_secs(),
170                best_algorithms: HashMap::new(),
171            });
172
173        Self {
174            config,
175            database,
176            enabled: true,
177        }
178    }
179
180    /// Load the tuning database from disk
181    fn load_database(path: &Path) -> FFTResult<TuningDatabase> {
182        if !path.exists() {
183            return Err(FFTError::IOError(format!(
184                "Tuning database file not found: {}",
185                path.display()
186            )));
187        }
188
189        let file = File::open(path)
190            .map_err(|e| FFTError::IOError(format!("Failed to open tuning database: {e}")))?;
191
192        let reader = BufReader::new(file);
193        let database: TuningDatabase = serde_json::from_reader(reader)
194            .map_err(|e| FFTError::ValueError(format!("Failed to parse tuning database: {e}")))?;
195
196        Ok(database)
197    }
198
199    /// Save the tuning database to disk
200    pub fn save_database(&self) -> FFTResult<()> {
201        // Create parent directories if they don't exist
202        if let Some(parent) = self.config.database_path.parent() {
203            fs::create_dir_all(parent).map_err(|e| {
204                FFTError::IOError(format!(
205                    "Failed to create directory for tuning database: {e}"
206                ))
207            })?;
208        }
209
210        let file = File::create(&self.config.database_path).map_err(|e| {
211            FFTError::IOError(format!("Failed to create tuning database file: {e}"))
212        })?;
213
214        let writer = BufWriter::new(file);
215        serde_json::to_writer_pretty(writer, &self.database)
216            .map_err(|e| FFTError::IOError(format!("Failed to serialize tuning database: {e}")))?;
217
218        Ok(())
219    }
220
221    /// Enable or disable auto-tuning
222    pub fn set_enabled(&mut self, enabled: bool) {
223        self.enabled = enabled;
224    }
225
226    /// Check if auto-tuning is enabled
227    pub fn is_enabled(&self) -> bool {
228        self.enabled
229    }
230
231    /// Run benchmarks for all configured FFT variants and sizes
232    pub fn run_benchmarks(&mut self) -> FFTResult<()> {
233        if !self.enabled {
234            return Ok(());
235        }
236
237        let sizes = self.generate_sizes();
238        let mut results = Vec::new();
239
240        for size in sizes {
241            for &variant in &self.config.variants {
242                // Benchmark forward transform
243                let forward_result = self.benchmark_variant(size, variant, true)?;
244                results.push(forward_result);
245
246                // Benchmark inverse transform
247                let inverse_result = self.benchmark_variant(size, variant, false)?;
248                results.push(inverse_result);
249            }
250        }
251
252        // Update database
253        self.database.results.extend(results);
254        self.update_best_algorithms();
255        self.save_database()?;
256
257        Ok(())
258    }
259
260    /// Generate the list of sizes to benchmark
261    fn generate_sizes(&self) -> Vec<usize> {
262        let mut sizes = Vec::new();
263
264        match &self.config.sizes.step {
265            SizeStep::Linear(step) => {
266                let mut size = self.config.sizes.min;
267                while size <= self.config.sizes.max {
268                    sizes.push(size);
269                    size += step;
270                }
271            }
272            SizeStep::Exponential(factor) => {
273                let mut size = self.config.sizes.min as f64;
274                while size <= self.config.sizes.max as f64 {
275                    sizes.push(size as usize);
276                    size *= factor;
277                }
278            }
279            SizeStep::PowersOfTwo => {
280                let mut size = 1;
281                while size < self.config.sizes.min {
282                    size *= 2;
283                }
284                while size <= self.config.sizes.max {
285                    sizes.push(size);
286                    size *= 2;
287                }
288            }
289            SizeStep::Custom(custom_sizes) => {
290                for &size in custom_sizes {
291                    if size >= self.config.sizes.min && size <= self.config.sizes.max {
292                        sizes.push(size);
293                    }
294                }
295            }
296        }
297
298        sizes
299    }
300
301    /// Benchmark a specific FFT variant for a given size
302    fn benchmark_variant(
303        &self,
304        size: usize,
305        variant: FftVariant,
306        forward: bool,
307    ) -> FFTResult<BenchmarkResult> {
308        // Create test data
309        let mut data = vec![Complex64::new(0.0, 0.0); size];
310        for (i, val) in data.iter_mut().enumerate().take(size) {
311            *val = Complex64::new(i as f64, (i * 2) as f64);
312        }
313
314        // Warm-up phase
315        for _ in 0..self.config.warmup {
316            match variant {
317                FftVariant::Standard => {
318                    #[cfg(feature = "oxifft")]
319                    {
320                        let input_oxi: Vec<OxiComplex<f64>> =
321                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
322                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
323
324                        let direction = if forward {
325                            Direction::Forward
326                        } else {
327                            Direction::Backward
328                        };
329                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
330                    }
331
332                    #[cfg(not(feature = "oxifft"))]
333                    {
334                        #[cfg(feature = "rustfft-backend")]
335                        {
336                            let mut planner = FftPlanner::new();
337                            let fft = if forward {
338                                planner.plan_fft_forward(size)
339                            } else {
340                                planner.plan_fft_inverse(size)
341                            };
342                            let mut buffer = data.clone();
343                            fft.process(&mut buffer);
344                        }
345                    }
346                }
347                FftVariant::InPlace => {
348                    #[cfg(feature = "oxifft")]
349                    {
350                        let input_oxi: Vec<OxiComplex<f64>> =
351                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
352                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
353
354                        let direction = if forward {
355                            Direction::Forward
356                        } else {
357                            Direction::Backward
358                        };
359                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
360                    }
361
362                    #[cfg(not(feature = "oxifft"))]
363                    {
364                        #[cfg(feature = "rustfft-backend")]
365                        {
366                            let mut planner = FftPlanner::new();
367                            let fft = if forward {
368                                planner.plan_fft_forward(size)
369                            } else {
370                                planner.plan_fft_inverse(size)
371                            };
372                            // Use in-place processing with scratch buffer
373                            let mut buffer = data.clone();
374                            let mut scratch =
375                                vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
376                            fft.process_with_scratch(&mut buffer, &mut scratch);
377                        }
378                    }
379                }
380                FftVariant::Cached => {
381                    // Create a plan via the serialization manager
382                    let manager = PlanSerializationManager::new(&self.config.database_path);
383                    let plan_info = manager.create_plan_info(size, forward);
384                    let (_, time) = crate::plan_serialization::create_and_time_plan(size, forward);
385                    manager.record_plan_usage(&plan_info, time).unwrap_or(());
386                }
387                FftVariant::SplitRadix => {
388                    #[cfg(feature = "oxifft")]
389                    {
390                        // For now, use OxiFFT's standard algorithm
391                        let input_oxi: Vec<OxiComplex<f64>> =
392                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
393                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
394
395                        let direction = if forward {
396                            Direction::Forward
397                        } else {
398                            Direction::Backward
399                        };
400                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
401                    }
402
403                    #[cfg(not(feature = "oxifft"))]
404                    {
405                        #[cfg(feature = "rustfft-backend")]
406                        {
407                            // For now, this is just an example variant
408                            // In a real implementation, we'd use a specific split-radix algorithm
409                            let mut planner = FftPlanner::new();
410                            let fft = if forward {
411                                planner.plan_fft_forward(size)
412                            } else {
413                                planner.plan_fft_inverse(size)
414                            };
415                            let mut buffer = data.clone();
416                            fft.process(&mut buffer);
417                        }
418                    }
419                }
420            }
421        }
422
423        // Timing phase
424        let mut times = Vec::with_capacity(self.config.repetitions);
425
426        for _ in 0..self.config.repetitions {
427            let start = Instant::now();
428
429            match variant {
430                FftVariant::Standard => {
431                    #[cfg(feature = "oxifft")]
432                    {
433                        let input_oxi: Vec<OxiComplex<f64>> =
434                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
435                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
436
437                        let direction = if forward {
438                            Direction::Forward
439                        } else {
440                            Direction::Backward
441                        };
442                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
443                    }
444
445                    #[cfg(not(feature = "oxifft"))]
446                    {
447                        #[cfg(feature = "rustfft-backend")]
448                        {
449                            let mut planner = FftPlanner::new();
450                            let fft = if forward {
451                                planner.plan_fft_forward(size)
452                            } else {
453                                planner.plan_fft_inverse(size)
454                            };
455                            let mut buffer = data.clone();
456                            fft.process(&mut buffer);
457                        }
458                    }
459                }
460                FftVariant::InPlace => {
461                    #[cfg(feature = "oxifft")]
462                    {
463                        let input_oxi: Vec<OxiComplex<f64>> =
464                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
465                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
466
467                        let direction = if forward {
468                            Direction::Forward
469                        } else {
470                            Direction::Backward
471                        };
472                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
473                    }
474
475                    #[cfg(not(feature = "oxifft"))]
476                    {
477                        #[cfg(feature = "rustfft-backend")]
478                        {
479                            let mut planner = FftPlanner::new();
480                            let fft = if forward {
481                                planner.plan_fft_forward(size)
482                            } else {
483                                planner.plan_fft_inverse(size)
484                            };
485                            // Use in-place processing with scratch buffer
486                            let mut buffer = data.clone();
487                            let mut scratch =
488                                vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
489                            fft.process_with_scratch(&mut buffer, &mut scratch);
490                        }
491                    }
492                }
493                FftVariant::Cached => {
494                    #[cfg(feature = "oxifft")]
495                    {
496                        let input_oxi: Vec<OxiComplex<f64>> =
497                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
498                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
499
500                        let direction = if forward {
501                            Direction::Forward
502                        } else {
503                            Direction::Backward
504                        };
505                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
506                    }
507
508                    #[cfg(not(feature = "oxifft"))]
509                    {
510                        #[cfg(feature = "rustfft-backend")]
511                        {
512                            // Use the plan cache
513                            let mut planner = FftPlanner::new();
514                            let fft = if forward {
515                                planner.plan_fft_forward(size)
516                            } else {
517                                planner.plan_fft_inverse(size)
518                            };
519                            let mut buffer = data.clone();
520                            fft.process(&mut buffer);
521                        }
522                    }
523                }
524                FftVariant::SplitRadix => {
525                    #[cfg(feature = "oxifft")]
526                    {
527                        let input_oxi: Vec<OxiComplex<f64>> =
528                            data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
529                        let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
530
531                        let direction = if forward {
532                            Direction::Forward
533                        } else {
534                            Direction::Backward
535                        };
536                        let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
537                    }
538
539                    #[cfg(not(feature = "oxifft"))]
540                    {
541                        #[cfg(feature = "rustfft-backend")]
542                        {
543                            // Placeholder for split-radix implementation
544                            let mut planner = FftPlanner::new();
545                            let fft = if forward {
546                                planner.plan_fft_forward(size)
547                            } else {
548                                planner.plan_fft_inverse(size)
549                            };
550                            let mut buffer = data.clone();
551                            fft.process(&mut buffer);
552                        }
553                    }
554                }
555            }
556
557            let elapsed = start.elapsed();
558            times.push(elapsed.as_nanos() as u64);
559        }
560
561        // Calculate statistics
562        let avg_time = times.iter().sum::<u64>() / times.len() as u64;
563        let min_time = *times.iter().min().unwrap_or(&0);
564
565        // Calculate standard deviation
566        let variance = times
567            .iter()
568            .map(|&t| {
569                let diff = t as f64 - avg_time as f64;
570                diff * diff
571            })
572            .sum::<f64>()
573            / times.len() as f64;
574        let std_dev = variance.sqrt();
575
576        Ok(BenchmarkResult {
577            size,
578            variant,
579            forward,
580            avg_time_ns: avg_time,
581            min_time_ns: min_time,
582            std_dev_ns: std_dev,
583            system_info: self.detect_system_info(),
584        })
585    }
586
587    /// Detect system information for result matching
588    fn detect_system_info(&self) -> SystemInfo {
589        // This is a simplified version - a real implementation would
590        // detect actual CPU model, features, etc.
591        SystemInfo {
592            cpu_model: String::from("Unknown"),
593            num_cores: num_cpus::get(),
594            architecture: std::env::consts::ARCH.to_string(),
595            cpu_features: detect_cpu_features(),
596        }
597    }
598
599    /// Update the best algorithms based on benchmark results
600    fn update_best_algorithms(&mut self) {
601        // Clear existing best algorithms
602        self.database.best_algorithms.clear();
603
604        // Group results by size and direction
605        let mut grouped: HashMap<(usize, bool), Vec<&BenchmarkResult>> = HashMap::new();
606        for result in &self.database.results {
607            grouped
608                .entry((result.size, result.forward))
609                .or_default()
610                .push(result);
611        }
612
613        // Find the best algorithm for each group
614        for ((size, forward), results) in grouped {
615            if let Some(best) = results.iter().min_by_key(|r| r.avg_time_ns) {
616                self.database
617                    .best_algorithms
618                    .insert((size, forward), best.variant);
619            }
620        }
621    }
622
623    /// Get the best FFT variant for the given size and direction
624    pub fn get_best_variant(&self, size: usize, forward: bool) -> FftVariant {
625        if !self.enabled {
626            return FftVariant::Standard;
627        }
628
629        // Look for exact size match
630        if let Some(&variant) = self.database.best_algorithms.get(&(size, forward)) {
631            return variant;
632        }
633
634        // Look for closest size match
635        let mut closest_size = 0;
636        let mut min_diff = usize::MAX;
637
638        for &(s, f) in self.database.best_algorithms.keys() {
639            if f == forward {
640                let diff = s.abs_diff(size);
641                if diff < min_diff {
642                    min_diff = diff;
643                    closest_size = s;
644                }
645            }
646        }
647
648        if closest_size > 0 {
649            if let Some(&variant) = self.database.best_algorithms.get(&(closest_size, forward)) {
650                return variant;
651            }
652        }
653
654        // Default to standard FFT if no match
655        FftVariant::Standard
656    }
657
658    /// Run FFT with optimal algorithm selection
659    pub fn run_optimal_fft<T>(
660        &self,
661        input: &[T],
662        size: Option<usize>,
663        forward: bool,
664    ) -> FFTResult<Vec<Complex64>>
665    where
666        T: Clone + Into<Complex64>,
667    {
668        let actual_size = size.unwrap_or(input.len());
669        let variant = self.get_best_variant(actual_size, forward);
670
671        // Convert input to complex
672        let mut buffer: Vec<Complex64> = input.iter().map(|x| x.clone().into()).collect();
673        // Pad if necessary
674        if buffer.len() < actual_size {
675            buffer.resize(actual_size, Complex64::new(0.0, 0.0));
676        }
677
678        #[cfg(feature = "oxifft")]
679        {
680            let input_oxi: Vec<OxiComplex<f64>> =
681                buffer.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
682            let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); actual_size];
683
684            let direction = if forward {
685                Direction::Forward
686            } else {
687                Direction::Backward
688            };
689            oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
690
691            // Copy result back to buffer
692            for (i, val) in output.iter().enumerate() {
693                buffer[i] = Complex64::new(val.re, val.im);
694            }
695        }
696
697        #[cfg(not(feature = "oxifft"))]
698        {
699            #[cfg(feature = "rustfft-backend")]
700            {
701                match variant {
702                    FftVariant::Standard => {
703                        let mut planner = FftPlanner::new();
704                        let fft = if forward {
705                            planner.plan_fft_forward(actual_size)
706                        } else {
707                            planner.plan_fft_inverse(actual_size)
708                        };
709                        fft.process(&mut buffer);
710                    }
711                    FftVariant::InPlace => {
712                        let mut planner = FftPlanner::new();
713                        let fft = if forward {
714                            planner.plan_fft_forward(actual_size)
715                        } else {
716                            planner.plan_fft_inverse(actual_size)
717                        };
718                        let mut scratch =
719                            vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
720                        fft.process_with_scratch(&mut buffer, &mut scratch);
721                    }
722                    FftVariant::Cached => {
723                        // Use the plan cache via PlanSerializationManager
724                        // Create a plan directly - manager is not needed here
725                        let (plan_, _) =
726                            crate::plan_serialization::create_and_time_plan(actual_size, forward);
727                        plan_.process(&mut buffer);
728                    }
729                    FftVariant::SplitRadix => {
730                        // Placeholder for split-radix FFT
731                        let mut planner = FftPlanner::new();
732                        let fft = if forward {
733                            planner.plan_fft_forward(actual_size)
734                        } else {
735                            planner.plan_fft_inverse(actual_size)
736                        };
737                        fft.process(&mut buffer);
738                    }
739                }
740            }
741
742            #[cfg(not(feature = "rustfft-backend"))]
743            {
744                return Err(FFTError::ComputationError(
745                    "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature.".to_string()
746                ));
747            }
748        }
749
750        // Scale inverse FFT by 1/N if required
751        if !forward {
752            let scale = 1.0 / (actual_size as f64);
753            for val in &mut buffer {
754                *val *= scale;
755            }
756        }
757
758        Ok(buffer)
759    }
760}
761
/// Detect the SIMD-related CPU features of the machine this code runs on.
///
/// Detection happens at run time via `is_x86_feature_detected!` and
/// `is_aarch64_feature_detected!`, so the list reflects the actual host CPU
/// rather than the features the binary was compiled for. The default
/// `x86_64` target enables only SSE/SSE2 at compile time, so
/// `cfg(target_feature)` checks under-report modern CPUs.
///
/// Feature names use the spelling accepted by those macros (e.g. `"sse4.1"`,
/// `"avx2"`, `"neon"`). The list is empty on other architectures.
fn detect_cpu_features() -> Vec<String> {
    #[allow(unused_mut)] // nothing is pushed on architectures without probes
    let mut features = Vec::new();

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // The detection macro needs a literal, so the probe list is expanded
        // by a local macro instead of a runtime loop.
        macro_rules! probe_x86 {
            ($($name:tt),* $(,)?) => {
                $(
                    if std::arch::is_x86_feature_detected!($name) {
                        features.push(String::from($name));
                    }
                )*
            };
        }
        probe_x86!("sse", "sse2", "sse3", "sse4.1", "sse4.2", "avx", "avx2", "fma");
    }

    // ARM-specific features
    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") {
            features.push("neon".to_string());
        }
    }

    features
}
806
807// ============================================================================
808// Enhanced Auto-Selection (v0.2.0)
809// ============================================================================
810
/// Integrated auto-selection that combines algorithm selection with auto-tuning.
///
/// Two sources of information are combined:
/// * an input-characteristic `AlgorithmSelector`;
/// * benchmark data learned by an [`AutoTuner`].
///
/// `prefer_learned` controls whether the learned data is consulted before
/// falling back to the characteristic-based selector.
pub struct IntegratedAutoSelector {
    /// Algorithm selector for input-characteristic based selection
    selector: crate::algorithm_selector::AlgorithmSelector,
    /// Auto-tuner for performance-based selection
    tuner: AutoTuner,
    /// Whether to consult learned performance data first
    prefer_learned: bool,
}
820
821impl Default for IntegratedAutoSelector {
822    fn default() -> Self {
823        Self::new()
824    }
825}
826
827impl IntegratedAutoSelector {
828    /// Create a new integrated auto-selector
829    pub fn new() -> Self {
830        Self {
831            selector: crate::algorithm_selector::AlgorithmSelector::new(),
832            tuner: AutoTuner::new(),
833            prefer_learned: true,
834        }
835    }
836
837    /// Create with custom configuration
838    pub fn with_config(
839        selector_config: crate::algorithm_selector::SelectionConfig,
840        tuner_config: AutoTuneConfig,
841        prefer_learned: bool,
842    ) -> Self {
843        Self {
844            selector: crate::algorithm_selector::AlgorithmSelector::with_config(selector_config),
845            tuner: AutoTuner::with_config(tuner_config),
846            prefer_learned,
847        }
848    }
849
850    /// Select the best algorithm for the given size
851    pub fn select(&self, size: usize, forward: bool) -> FFTResult<SelectionResult> {
852        // First, check if we have learned performance data
853        if self.prefer_learned && self.tuner.is_enabled() {
854            let variant = self.tuner.get_best_variant(size, forward);
855            if variant != FftVariant::Standard {
856                // We have learned data, use it
857                return Ok(SelectionResult {
858                    algorithm: variant_to_algorithm(variant),
859                    variant,
860                    source: SelectionSource::Learned,
861                    confidence: 0.9,
862                    recommendation: self.selector.select_algorithm(size, forward).ok(),
863                });
864            }
865        }
866
867        // Fall back to input-characteristic based selection
868        let recommendation = self.selector.select_algorithm(size, forward)?;
869        let variant = algorithm_to_variant(recommendation.algorithm);
870
871        Ok(SelectionResult {
872            algorithm: recommendation.algorithm,
873            variant,
874            source: SelectionSource::Characteristic,
875            confidence: recommendation.confidence,
876            recommendation: Some(recommendation),
877        })
878    }
879
880    /// Run auto-tuning for a range of sizes
881    pub fn auto_tune(&mut self, sizes: &[usize]) -> FFTResult<()> {
882        // Generate size range from provided sizes
883        if sizes.is_empty() {
884            return Ok(());
885        }
886
887        let min = *sizes.iter().min().unwrap_or(&16);
888        let max = *sizes.iter().max().unwrap_or(&8192);
889
890        let config = AutoTuneConfig {
891            sizes: SizeRange {
892                min,
893                max,
894                step: SizeStep::Custom(sizes.to_vec()),
895            },
896            ..Default::default()
897        };
898
899        self.tuner = AutoTuner::with_config(config);
900        self.tuner.run_benchmarks()
901    }
902
903    /// Execute FFT with optimal algorithm
904    pub fn execute<T>(
905        &self,
906        input: &[T],
907        size: Option<usize>,
908        forward: bool,
909    ) -> FFTResult<Vec<Complex64>>
910    where
911        T: Clone + Into<Complex64>,
912    {
913        let actual_size = size.unwrap_or(input.len());
914        let selection = self.select(actual_size, forward)?;
915
916        // Use the tuner's run_optimal_fft which handles the actual execution
917        self.tuner.run_optimal_fft(input, size, forward)
918    }
919
920    /// Get the algorithm selector
921    pub fn selector(&self) -> &crate::algorithm_selector::AlgorithmSelector {
922        &self.selector
923    }
924
925    /// Get the auto-tuner
926    pub fn tuner(&self) -> &AutoTuner {
927        &self.tuner
928    }
929}
930
931/// Result of algorithm selection
932#[derive(Debug, Clone)]
933pub struct SelectionResult {
934    /// Selected algorithm
935    pub algorithm: crate::algorithm_selector::FftAlgorithm,
936    /// Corresponding FFT variant
937    pub variant: FftVariant,
938    /// Source of the selection
939    pub source: SelectionSource,
940    /// Confidence in the selection
941    pub confidence: f64,
942    /// Full recommendation (if available)
943    pub recommendation: Option<crate::algorithm_selector::AlgorithmRecommendation>,
944}
945
/// Origin of an algorithm selection decision.
///
/// Derives `Hash` (like [`FftVariant`]) so it can be used as a key in maps or
/// sets, e.g. when aggregating selection statistics by source.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SelectionSource {
    /// Selected based on learned performance data
    Learned,
    /// Selected based on input characteristics
    Characteristic,
    /// Forced by configuration
    Forced,
    /// Default fallback
    Default,
}
958
959/// Convert FftVariant to FftAlgorithm
960fn variant_to_algorithm(variant: FftVariant) -> crate::algorithm_selector::FftAlgorithm {
961    use crate::algorithm_selector::FftAlgorithm;
962    match variant {
963        FftVariant::Standard => FftAlgorithm::MixedRadix,
964        FftVariant::InPlace => FftAlgorithm::InPlace,
965        FftVariant::Cached => FftAlgorithm::MixedRadix,
966        FftVariant::SplitRadix => FftAlgorithm::SplitRadix,
967    }
968}
969
970/// Convert FftAlgorithm to FftVariant
971fn algorithm_to_variant(algorithm: crate::algorithm_selector::FftAlgorithm) -> FftVariant {
972    use crate::algorithm_selector::FftAlgorithm;
973    match algorithm {
974        FftAlgorithm::SplitRadix => FftVariant::SplitRadix,
975        FftAlgorithm::InPlace => FftVariant::InPlace,
976        _ => FftVariant::Standard,
977    }
978}
979
980/// Auto-select the best FFT algorithm for the given input
981///
982/// This is a convenience function that uses the integrated auto-selector
983/// to determine the optimal algorithm based on input characteristics and
984/// learned performance data.
985///
986/// # Arguments
987///
988/// * `size` - FFT size
989/// * `forward` - Whether this is a forward (true) or inverse (false) transform
990///
991/// # Returns
992///
993/// The recommended algorithm and metadata
994///
995/// # Example
996///
997/// ```rust
998/// use scirs2_fft::auto_tuning::auto_select_algorithm;
999///
1000/// let result = auto_select_algorithm(1024, true).expect("Selection failed");
1001/// println!("Recommended: {:?}", result.algorithm);
1002/// ```
1003pub fn auto_select_algorithm(size: usize, forward: bool) -> FFTResult<SelectionResult> {
1004    let selector = IntegratedAutoSelector::new();
1005    selector.select(size, forward)
1006}
1007
1008/// Execute FFT with automatic algorithm selection
1009///
1010/// This function automatically selects the best algorithm based on
1011/// input characteristics and executes the FFT.
1012///
1013/// # Arguments
1014///
1015/// * `input` - Input data
1016/// * `size` - Optional FFT size (if different from input length)
1017/// * `forward` - Whether this is a forward (true) or inverse (false) transform
1018///
1019/// # Returns
1020///
1021/// The FFT result as a vector of complex numbers
1022///
1023/// # Example
1024///
1025/// ```rust
1026/// use scirs2_fft::auto_tuning::auto_fft;
1027///
1028/// let signal = vec![1.0, 2.0, 3.0, 4.0];
1029/// let spectrum = auto_fft(&signal, None, true).expect("FFT failed");
1030/// ```
1031pub fn auto_fft<T>(input: &[T], size: Option<usize>, forward: bool) -> FFTResult<Vec<Complex64>>
1032where
1033    T: Clone + Into<Complex64>,
1034{
1035    let selector = IntegratedAutoSelector::new();
1036    selector.execute(input, size, forward)
1037}
1038
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a tuner for `min..=max` stepped by `step`, then check the sizes it
    /// would benchmark against `expected`.
    #[track_caller]
    fn assert_generated_sizes(min: usize, max: usize, step: SizeStep, expected: &[usize]) {
        let tuner = AutoTuner::with_config(AutoTuneConfig {
            sizes: SizeRange { min, max, step },
            ..Default::default()
        });
        let generated = tuner.generate_sizes();
        assert_eq!(generated, expected.to_vec());
    }

    #[test]
    fn test_size_generation() {
        // Powers of two, both bounds inclusive.
        assert_generated_sizes(8, 64, SizeStep::PowersOfTwo, &[8, 16, 32, 64]);

        // Constant additive step.
        assert_generated_sizes(10, 30, SizeStep::Linear(5), &[10, 15, 20, 25, 30]);

        // Multiplicative step; 160 exceeds the maximum and is dropped.
        assert_generated_sizes(10, 100, SizeStep::Exponential(2.0), &[10, 20, 40, 80]);

        // Custom list, filtered to the [min, max] window.
        assert_generated_sizes(
            10,
            100,
            SizeStep::Custom(vec![5, 15, 25, 50, 150]),
            &[15, 25, 50],
        );
    }

    #[test]
    fn test_auto_tuner_basic() {
        // Keep the tuning database in a scratch directory owned by this test.
        let scratch = tempdir().expect("Operation failed");
        let db_path = scratch.path().join("test_tuning_db.json");

        // Minimal benchmarking so the test stays fast.
        let mut tuner = AutoTuner::with_config(AutoTuneConfig {
            sizes: SizeRange {
                min: 16,
                max: 32,
                step: SizeStep::PowersOfTwo,
            },
            repetitions: 2,
            warmup: 1,
            variants: vec![FftVariant::Standard, FftVariant::InPlace],
            database_path: db_path.clone(),
        });

        // Benchmarking may not work in every environment; log and bail out.
        if let Err(e) = tuner.run_benchmarks() {
            println!("Benchmark failed: {e}");
            return;
        }

        // A successful run must have persisted the database...
        assert!(db_path.exists());

        // ...and may only recommend one of the benchmarked variants.
        let best = tuner.get_best_variant(16, true);
        assert!(matches!(best, FftVariant::Standard | FftVariant::InPlace));
    }
}