1#[cfg(feature = "oxifft")]
12use crate::oxifft_plan_cache;
13#[cfg(feature = "oxifft")]
14use oxifft::{Complex as OxiComplex, Direction};
15use scirs2_core::numeric::Complex64;
16use serde::{Deserialize, Serialize};
17use std::collections::HashMap;
18use std::fs::{self, File};
19use std::io::{BufReader, BufWriter};
20use std::path::{Path, PathBuf};
21use std::time::Instant;
22
23use crate::error::{FFTError, FFTResult};
24use crate::plan_serialization::PlanSerializationManager;
25
26#[derive(Debug, Clone)]
28pub struct SizeRange {
29 pub min: usize,
31 pub max: usize,
33 pub step: SizeStep,
35}
36
/// Rule for enumerating benchmark sizes inside a [`SizeRange`].
#[derive(Clone, Debug)]
pub enum SizeStep {
    /// Start at `min` and add this increment each time.
    Linear(usize),
    /// Start at `min` and multiply by this factor each time.
    Exponential(f64),
    /// Every power of two inside the range.
    PowersOfTwo,
    /// An explicit list of sizes; entries outside the range are discarded.
    Custom(Vec<usize>),
}
49
50#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
52pub enum FftVariant {
53 Standard,
55 InPlace,
57 Cached,
59 SplitRadix,
61}
62
63#[derive(Debug, Clone)]
65pub struct AutoTuneConfig {
66 pub sizes: SizeRange,
68 pub repetitions: usize,
70 pub warmup: usize,
72 pub variants: Vec<FftVariant>,
74 pub database_path: PathBuf,
76}
77
78impl Default for AutoTuneConfig {
79 fn default() -> Self {
80 Self {
81 sizes: SizeRange {
82 min: 16,
83 max: 8192,
84 step: SizeStep::PowersOfTwo,
85 },
86 repetitions: 10,
87 warmup: 3,
88 variants: vec![FftVariant::Standard, FftVariant::Cached],
89 database_path: PathBuf::from(".fft_tuning_db.json"),
90 }
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
96pub struct BenchmarkResult {
97 pub size: usize,
99 pub variant: FftVariant,
101 pub forward: bool,
103 pub avg_time_ns: u64,
105 pub min_time_ns: u64,
107 pub std_dev_ns: f64,
109 pub system_info: SystemInfo,
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct SystemInfo {
116 pub cpu_model: String,
118 pub num_cores: usize,
120 pub architecture: String,
122 pub cpu_features: Vec<String>,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct TuningDatabase {
129 pub results: Vec<BenchmarkResult>,
131 pub last_updated: u64,
133 pub best_algorithms: HashMap<(usize, bool), FftVariant>,
135}
136
137pub struct AutoTuner {
139 config: AutoTuneConfig,
141 database: TuningDatabase,
143 enabled: bool,
145}
146
147impl Default for AutoTuner {
148 fn default() -> Self {
149 Self::with_config(AutoTuneConfig::default())
150 }
151}
152
153impl AutoTuner {
154 pub fn new() -> Self {
156 Self::default()
157 }
158
159 pub fn with_config(config: AutoTuneConfig) -> Self {
161 let database =
162 Self::load_database(&config.database_path).unwrap_or_else(|_| TuningDatabase {
163 results: Vec::new(),
164 last_updated: std::time::SystemTime::now()
165 .duration_since(std::time::UNIX_EPOCH)
166 .unwrap_or_default()
167 .as_secs(),
168 best_algorithms: HashMap::new(),
169 });
170
171 Self {
172 config,
173 database,
174 enabled: true,
175 }
176 }
177
178 fn load_database(path: &Path) -> FFTResult<TuningDatabase> {
180 if !path.exists() {
181 return Err(FFTError::IOError(format!(
182 "Tuning database file not found: {}",
183 path.display()
184 )));
185 }
186
187 let file = File::open(path)
188 .map_err(|e| FFTError::IOError(format!("Failed to open tuning database: {e}")))?;
189
190 let reader = BufReader::new(file);
191 let database: TuningDatabase = serde_json::from_reader(reader)
192 .map_err(|e| FFTError::ValueError(format!("Failed to parse tuning database: {e}")))?;
193
194 Ok(database)
195 }
196
197 pub fn save_database(&self) -> FFTResult<()> {
199 if let Some(parent) = self.config.database_path.parent() {
201 fs::create_dir_all(parent).map_err(|e| {
202 FFTError::IOError(format!(
203 "Failed to create directory for tuning database: {e}"
204 ))
205 })?;
206 }
207
208 let file = File::create(&self.config.database_path).map_err(|e| {
209 FFTError::IOError(format!("Failed to create tuning database file: {e}"))
210 })?;
211
212 let writer = BufWriter::new(file);
213 serde_json::to_writer_pretty(writer, &self.database)
214 .map_err(|e| FFTError::IOError(format!("Failed to serialize tuning database: {e}")))?;
215
216 Ok(())
217 }
218
219 pub fn set_enabled(&mut self, enabled: bool) {
221 self.enabled = enabled;
222 }
223
224 pub fn is_enabled(&self) -> bool {
226 self.enabled
227 }
228
229 pub fn run_benchmarks(&mut self) -> FFTResult<()> {
231 if !self.enabled {
232 return Ok(());
233 }
234
235 let sizes = self.generate_sizes();
236 let mut results = Vec::new();
237
238 for size in sizes {
239 for &variant in &self.config.variants {
240 let forward_result = self.benchmark_variant(size, variant, true)?;
242 results.push(forward_result);
243
244 let inverse_result = self.benchmark_variant(size, variant, false)?;
246 results.push(inverse_result);
247 }
248 }
249
250 self.database.results.extend(results);
252 self.update_best_algorithms();
253 self.save_database()?;
254
255 Ok(())
256 }
257
258 fn generate_sizes(&self) -> Vec<usize> {
260 let mut sizes = Vec::new();
261
262 match &self.config.sizes.step {
263 SizeStep::Linear(step) => {
264 let mut size = self.config.sizes.min;
265 while size <= self.config.sizes.max {
266 sizes.push(size);
267 size += step;
268 }
269 }
270 SizeStep::Exponential(factor) => {
271 let mut size = self.config.sizes.min as f64;
272 while size <= self.config.sizes.max as f64 {
273 sizes.push(size as usize);
274 size *= factor;
275 }
276 }
277 SizeStep::PowersOfTwo => {
278 let mut size = 1;
279 while size < self.config.sizes.min {
280 size *= 2;
281 }
282 while size <= self.config.sizes.max {
283 sizes.push(size);
284 size *= 2;
285 }
286 }
287 SizeStep::Custom(custom_sizes) => {
288 for &size in custom_sizes {
289 if size >= self.config.sizes.min && size <= self.config.sizes.max {
290 sizes.push(size);
291 }
292 }
293 }
294 }
295
296 sizes
297 }
298
299 fn benchmark_variant(
301 &self,
302 size: usize,
303 variant: FftVariant,
304 forward: bool,
305 ) -> FFTResult<BenchmarkResult> {
306 let mut data = vec![Complex64::new(0.0, 0.0); size];
308 for (i, val) in data.iter_mut().enumerate().take(size) {
309 *val = Complex64::new(i as f64, (i * 2) as f64);
310 }
311
312 for _ in 0..self.config.warmup {
314 match variant {
315 FftVariant::Standard => {
316 #[cfg(feature = "oxifft")]
317 {
318 let input_oxi: Vec<OxiComplex<f64>> =
319 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
320 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
321
322 let direction = if forward {
323 Direction::Forward
324 } else {
325 Direction::Backward
326 };
327 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
328 }
329
330 #[cfg(not(feature = "oxifft"))]
331 {
332 #[cfg(feature = "rustfft-backend")]
333 {
334 let mut planner = FftPlanner::new();
335 let fft = if forward {
336 planner.plan_fft_forward(size)
337 } else {
338 planner.plan_fft_inverse(size)
339 };
340 let mut buffer = data.clone();
341 fft.process(&mut buffer);
342 }
343 }
344 }
345 FftVariant::InPlace => {
346 #[cfg(feature = "oxifft")]
347 {
348 let input_oxi: Vec<OxiComplex<f64>> =
349 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
350 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
351
352 let direction = if forward {
353 Direction::Forward
354 } else {
355 Direction::Backward
356 };
357 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
358 }
359
360 #[cfg(not(feature = "oxifft"))]
361 {
362 #[cfg(feature = "rustfft-backend")]
363 {
364 let mut planner = FftPlanner::new();
365 let fft = if forward {
366 planner.plan_fft_forward(size)
367 } else {
368 planner.plan_fft_inverse(size)
369 };
370 let mut buffer = data.clone();
372 let mut scratch =
373 vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
374 fft.process_with_scratch(&mut buffer, &mut scratch);
375 }
376 }
377 }
378 FftVariant::Cached => {
379 let manager = PlanSerializationManager::new(&self.config.database_path);
381 let plan_info = manager.create_plan_info(size, forward);
382 let time = crate::plan_serialization::create_and_time_plan(size, forward);
383 manager.record_plan_usage(&plan_info, time).unwrap_or(());
384 }
385 FftVariant::SplitRadix => {
386 #[cfg(feature = "oxifft")]
387 {
388 let input_oxi: Vec<OxiComplex<f64>> =
390 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
391 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
392
393 let direction = if forward {
394 Direction::Forward
395 } else {
396 Direction::Backward
397 };
398 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
399 }
400
401 #[cfg(not(feature = "oxifft"))]
402 {
403 #[cfg(feature = "rustfft-backend")]
404 {
405 let mut planner = FftPlanner::new();
408 let fft = if forward {
409 planner.plan_fft_forward(size)
410 } else {
411 planner.plan_fft_inverse(size)
412 };
413 let mut buffer = data.clone();
414 fft.process(&mut buffer);
415 }
416 }
417 }
418 }
419 }
420
421 let mut times = Vec::with_capacity(self.config.repetitions);
423
424 for _ in 0..self.config.repetitions {
425 let start = Instant::now();
426
427 match variant {
428 FftVariant::Standard => {
429 #[cfg(feature = "oxifft")]
430 {
431 let input_oxi: Vec<OxiComplex<f64>> =
432 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
433 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
434
435 let direction = if forward {
436 Direction::Forward
437 } else {
438 Direction::Backward
439 };
440 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
441 }
442
443 #[cfg(not(feature = "oxifft"))]
444 {
445 #[cfg(feature = "rustfft-backend")]
446 {
447 let mut planner = FftPlanner::new();
448 let fft = if forward {
449 planner.plan_fft_forward(size)
450 } else {
451 planner.plan_fft_inverse(size)
452 };
453 let mut buffer = data.clone();
454 fft.process(&mut buffer);
455 }
456 }
457 }
458 FftVariant::InPlace => {
459 #[cfg(feature = "oxifft")]
460 {
461 let input_oxi: Vec<OxiComplex<f64>> =
462 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
463 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
464
465 let direction = if forward {
466 Direction::Forward
467 } else {
468 Direction::Backward
469 };
470 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
471 }
472
473 #[cfg(not(feature = "oxifft"))]
474 {
475 #[cfg(feature = "rustfft-backend")]
476 {
477 let mut planner = FftPlanner::new();
478 let fft = if forward {
479 planner.plan_fft_forward(size)
480 } else {
481 planner.plan_fft_inverse(size)
482 };
483 let mut buffer = data.clone();
485 let mut scratch =
486 vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
487 fft.process_with_scratch(&mut buffer, &mut scratch);
488 }
489 }
490 }
491 FftVariant::Cached => {
492 #[cfg(feature = "oxifft")]
493 {
494 let input_oxi: Vec<OxiComplex<f64>> =
495 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
496 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
497
498 let direction = if forward {
499 Direction::Forward
500 } else {
501 Direction::Backward
502 };
503 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
504 }
505
506 #[cfg(not(feature = "oxifft"))]
507 {
508 #[cfg(feature = "rustfft-backend")]
509 {
510 let mut planner = FftPlanner::new();
512 let fft = if forward {
513 planner.plan_fft_forward(size)
514 } else {
515 planner.plan_fft_inverse(size)
516 };
517 let mut buffer = data.clone();
518 fft.process(&mut buffer);
519 }
520 }
521 }
522 FftVariant::SplitRadix => {
523 #[cfg(feature = "oxifft")]
524 {
525 let input_oxi: Vec<OxiComplex<f64>> =
526 data.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
527 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); size];
528
529 let direction = if forward {
530 Direction::Forward
531 } else {
532 Direction::Backward
533 };
534 let _ = oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction);
535 }
536
537 #[cfg(not(feature = "oxifft"))]
538 {
539 #[cfg(feature = "rustfft-backend")]
540 {
541 let mut planner = FftPlanner::new();
543 let fft = if forward {
544 planner.plan_fft_forward(size)
545 } else {
546 planner.plan_fft_inverse(size)
547 };
548 let mut buffer = data.clone();
549 fft.process(&mut buffer);
550 }
551 }
552 }
553 }
554
555 let elapsed = start.elapsed();
556 times.push(elapsed.as_nanos() as u64);
557 }
558
559 let avg_time = times.iter().sum::<u64>() / times.len() as u64;
561 let min_time = *times.iter().min().unwrap_or(&0);
562
563 let variance = times
565 .iter()
566 .map(|&t| {
567 let diff = t as f64 - avg_time as f64;
568 diff * diff
569 })
570 .sum::<f64>()
571 / times.len() as f64;
572 let std_dev = variance.sqrt();
573
574 Ok(BenchmarkResult {
575 size,
576 variant,
577 forward,
578 avg_time_ns: avg_time,
579 min_time_ns: min_time,
580 std_dev_ns: std_dev,
581 system_info: self.detect_system_info(),
582 })
583 }
584
585 fn detect_system_info(&self) -> SystemInfo {
587 SystemInfo {
590 cpu_model: String::from("Unknown"),
591 num_cores: num_cpus::get(),
592 architecture: std::env::consts::ARCH.to_string(),
593 cpu_features: detect_cpu_features(),
594 }
595 }
596
597 fn update_best_algorithms(&mut self) {
599 self.database.best_algorithms.clear();
601
602 let mut grouped: HashMap<(usize, bool), Vec<&BenchmarkResult>> = HashMap::new();
604 for result in &self.database.results {
605 grouped
606 .entry((result.size, result.forward))
607 .or_default()
608 .push(result);
609 }
610
611 for ((size, forward), results) in grouped {
613 if let Some(best) = results.iter().min_by_key(|r| r.avg_time_ns) {
614 self.database
615 .best_algorithms
616 .insert((size, forward), best.variant);
617 }
618 }
619 }
620
621 pub fn get_best_variant(&self, size: usize, forward: bool) -> FftVariant {
623 if !self.enabled {
624 return FftVariant::Standard;
625 }
626
627 if let Some(&variant) = self.database.best_algorithms.get(&(size, forward)) {
629 return variant;
630 }
631
632 let mut closest_size = 0;
634 let mut min_diff = usize::MAX;
635
636 for &(s, f) in self.database.best_algorithms.keys() {
637 if f == forward {
638 let diff = s.abs_diff(size);
639 if diff < min_diff {
640 min_diff = diff;
641 closest_size = s;
642 }
643 }
644 }
645
646 if closest_size > 0 {
647 if let Some(&variant) = self.database.best_algorithms.get(&(closest_size, forward)) {
648 return variant;
649 }
650 }
651
652 FftVariant::Standard
654 }
655
656 pub fn run_optimal_fft<T>(
658 &self,
659 input: &[T],
660 size: Option<usize>,
661 forward: bool,
662 ) -> FFTResult<Vec<Complex64>>
663 where
664 T: Clone + Into<Complex64>,
665 {
666 let actual_size = size.unwrap_or(input.len());
667 let variant = self.get_best_variant(actual_size, forward);
668
669 let mut buffer: Vec<Complex64> = input.iter().map(|x| x.clone().into()).collect();
671 if buffer.len() < actual_size {
673 buffer.resize(actual_size, Complex64::new(0.0, 0.0));
674 }
675
676 #[cfg(feature = "oxifft")]
677 {
678 let input_oxi: Vec<OxiComplex<f64>> =
679 buffer.iter().map(|c| OxiComplex::new(c.re, c.im)).collect();
680 let mut output: Vec<OxiComplex<f64>> = vec![OxiComplex::zero(); actual_size];
681
682 let direction = if forward {
683 Direction::Forward
684 } else {
685 Direction::Backward
686 };
687 oxifft_plan_cache::execute_c2c(&input_oxi, &mut output, direction)?;
688
689 for (i, val) in output.iter().enumerate() {
691 buffer[i] = Complex64::new(val.re, val.im);
692 }
693 }
694
695 #[cfg(not(feature = "oxifft"))]
696 {
697 #[cfg(feature = "rustfft-backend")]
698 {
699 match variant {
700 FftVariant::Standard => {
701 let mut planner = FftPlanner::new();
702 let fft = if forward {
703 planner.plan_fft_forward(actual_size)
704 } else {
705 planner.plan_fft_inverse(actual_size)
706 };
707 fft.process(&mut buffer);
708 }
709 FftVariant::InPlace => {
710 let mut planner = FftPlanner::new();
711 let fft = if forward {
712 planner.plan_fft_forward(actual_size)
713 } else {
714 planner.plan_fft_inverse(actual_size)
715 };
716 let mut scratch =
717 vec![Complex64::new(0.0, 0.0); fft.get_inplace_scratch_len()];
718 fft.process_with_scratch(&mut buffer, &mut scratch);
719 }
720 FftVariant::Cached => {
721 let (plan_, _) =
724 crate::plan_serialization::create_and_time_plan(actual_size, forward);
725 plan_.process(&mut buffer);
726 }
727 FftVariant::SplitRadix => {
728 let mut planner = FftPlanner::new();
730 let fft = if forward {
731 planner.plan_fft_forward(actual_size)
732 } else {
733 planner.plan_fft_inverse(actual_size)
734 };
735 fft.process(&mut buffer);
736 }
737 }
738 }
739
740 {
741 return Err(FFTError::ComputationError(
742 "No FFT backend available. Enable either 'oxifft' or 'rustfft-backend' feature.".to_string()
743 ));
744 }
745 }
746
747 if !forward {
749 let scale = 1.0 / (actual_size as f64);
750 for val in &mut buffer {
751 *val *= scale;
752 }
753 }
754
755 Ok(buffer)
756 }
757}
758
/// Lists SIMD instruction-set extensions available on the running CPU.
///
/// Detection happens at runtime, so the list describes the host machine
/// rather than only the features enabled at compile time. Only x86_64
/// (SSE through AVX2/FMA) and aarch64 (NEON) are probed. Other architectures
/// return an empty list.
fn detect_cpu_features() -> Vec<String> {
    #[cfg(target_arch = "x86_64")]
    let candidates: Vec<(&str, bool)> = vec![
        ("sse", std::arch::is_x86_feature_detected!("sse")),
        ("sse2", std::arch::is_x86_feature_detected!("sse2")),
        ("sse3", std::arch::is_x86_feature_detected!("sse3")),
        ("sse4.1", std::arch::is_x86_feature_detected!("sse4.1")),
        ("sse4.2", std::arch::is_x86_feature_detected!("sse4.2")),
        ("avx", std::arch::is_x86_feature_detected!("avx")),
        ("avx2", std::arch::is_x86_feature_detected!("avx2")),
        ("fma", std::arch::is_x86_feature_detected!("fma")),
    ];

    #[cfg(target_arch = "aarch64")]
    let candidates: Vec<(&str, bool)> =
        vec![("neon", std::arch::is_aarch64_feature_detected!("neon"))];

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    let candidates: Vec<(&str, bool)> = Vec::new();

    candidates
        .into_iter()
        .filter(|&(_, present)| present)
        .map(|(name, _)| name.to_string())
        .collect()
}
803
804pub struct IntegratedAutoSelector {
810 selector: crate::algorithm_selector::AlgorithmSelector,
812 tuner: AutoTuner,
814 prefer_learned: bool,
816}
817
818impl Default for IntegratedAutoSelector {
819 fn default() -> Self {
820 Self::new()
821 }
822}
823
824impl IntegratedAutoSelector {
825 pub fn new() -> Self {
827 Self {
828 selector: crate::algorithm_selector::AlgorithmSelector::new(),
829 tuner: AutoTuner::new(),
830 prefer_learned: true,
831 }
832 }
833
834 pub fn with_config(
836 selector_config: crate::algorithm_selector::SelectionConfig,
837 tuner_config: AutoTuneConfig,
838 prefer_learned: bool,
839 ) -> Self {
840 Self {
841 selector: crate::algorithm_selector::AlgorithmSelector::with_config(selector_config),
842 tuner: AutoTuner::with_config(tuner_config),
843 prefer_learned,
844 }
845 }
846
847 pub fn select(&self, size: usize, forward: bool) -> FFTResult<SelectionResult> {
849 if self.prefer_learned && self.tuner.is_enabled() {
851 let variant = self.tuner.get_best_variant(size, forward);
852 if variant != FftVariant::Standard {
853 return Ok(SelectionResult {
855 algorithm: variant_to_algorithm(variant),
856 variant,
857 source: SelectionSource::Learned,
858 confidence: 0.9,
859 recommendation: self.selector.select_algorithm(size, forward).ok(),
860 });
861 }
862 }
863
864 let recommendation = self.selector.select_algorithm(size, forward)?;
866 let variant = algorithm_to_variant(recommendation.algorithm);
867
868 Ok(SelectionResult {
869 algorithm: recommendation.algorithm,
870 variant,
871 source: SelectionSource::Characteristic,
872 confidence: recommendation.confidence,
873 recommendation: Some(recommendation),
874 })
875 }
876
877 pub fn auto_tune(&mut self, sizes: &[usize]) -> FFTResult<()> {
879 if sizes.is_empty() {
881 return Ok(());
882 }
883
884 let min = *sizes.iter().min().unwrap_or(&16);
885 let max = *sizes.iter().max().unwrap_or(&8192);
886
887 let config = AutoTuneConfig {
888 sizes: SizeRange {
889 min,
890 max,
891 step: SizeStep::Custom(sizes.to_vec()),
892 },
893 ..Default::default()
894 };
895
896 self.tuner = AutoTuner::with_config(config);
897 self.tuner.run_benchmarks()
898 }
899
900 pub fn execute<T>(
902 &self,
903 input: &[T],
904 size: Option<usize>,
905 forward: bool,
906 ) -> FFTResult<Vec<Complex64>>
907 where
908 T: Clone + Into<Complex64>,
909 {
910 let actual_size = size.unwrap_or(input.len());
911 let selection = self.select(actual_size, forward)?;
912
913 self.tuner.run_optimal_fft(input, size, forward)
915 }
916
917 pub fn selector(&self) -> &crate::algorithm_selector::AlgorithmSelector {
919 &self.selector
920 }
921
922 pub fn tuner(&self) -> &AutoTuner {
924 &self.tuner
925 }
926}
927
928#[derive(Debug, Clone)]
930pub struct SelectionResult {
931 pub algorithm: crate::algorithm_selector::FftAlgorithm,
933 pub variant: FftVariant,
935 pub source: SelectionSource,
937 pub confidence: f64,
939 pub recommendation: Option<crate::algorithm_selector::AlgorithmRecommendation>,
941}
942
/// Origin of a selection decision.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SelectionSource {
    /// Taken from benchmark (tuning) results.
    Learned,
    /// Chosen from input characteristics by the algorithm selector.
    Characteristic,
    /// Presumably an explicit caller override — confirm; not constructed in
    /// the visible code.
    Forced,
    /// Presumably a fallback default — confirm; not constructed in the
    /// visible code.
    Default,
}
955
956fn variant_to_algorithm(variant: FftVariant) -> crate::algorithm_selector::FftAlgorithm {
958 use crate::algorithm_selector::FftAlgorithm;
959 match variant {
960 FftVariant::Standard => FftAlgorithm::MixedRadix,
961 FftVariant::InPlace => FftAlgorithm::InPlace,
962 FftVariant::Cached => FftAlgorithm::MixedRadix,
963 FftVariant::SplitRadix => FftAlgorithm::SplitRadix,
964 }
965}
966
967fn algorithm_to_variant(algorithm: crate::algorithm_selector::FftAlgorithm) -> FftVariant {
969 use crate::algorithm_selector::FftAlgorithm;
970 match algorithm {
971 FftAlgorithm::SplitRadix => FftVariant::SplitRadix,
972 FftAlgorithm::InPlace => FftVariant::InPlace,
973 _ => FftVariant::Standard,
974 }
975}
976
977pub fn auto_select_algorithm(size: usize, forward: bool) -> FFTResult<SelectionResult> {
1001 let selector = IntegratedAutoSelector::new();
1002 selector.select(size, forward)
1003}
1004
1005pub fn auto_fft<T>(input: &[T], size: Option<usize>, forward: bool) -> FFTResult<Vec<Complex64>>
1029where
1030 T: Clone + Into<Complex64>,
1031{
1032 let selector = IntegratedAutoSelector::new();
1033 selector.execute(input, size, forward)
1034}
1035
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Returns the sizes a tuner with otherwise-default settings would
    /// benchmark for the given range and step.
    fn generated_sizes(min: usize, max: usize, step: SizeStep) -> Vec<usize> {
        let config = AutoTuneConfig {
            sizes: SizeRange { min, max, step },
            ..Default::default()
        };
        AutoTuner::with_config(config).generate_sizes()
    }

    #[test]
    fn test_size_generation() {
        // Powers of two inside [8, 64].
        assert_eq!(
            generated_sizes(8, 64, SizeStep::PowersOfTwo),
            vec![8, 16, 32, 64]
        );

        // Arithmetic progression including both endpoints.
        assert_eq!(
            generated_sizes(10, 30, SizeStep::Linear(5)),
            vec![10, 15, 20, 25, 30]
        );

        // Geometric progression stopping before the maximum is exceeded.
        assert_eq!(
            generated_sizes(10, 100, SizeStep::Exponential(2.0)),
            vec![10, 20, 40, 80]
        );

        // Custom sizes outside [min, max] are dropped.
        assert_eq!(
            generated_sizes(10, 100, SizeStep::Custom(vec![5, 15, 25, 50, 150])),
            vec![15, 25, 50]
        );
    }

    #[test]
    fn test_auto_tuner_basic() {
        let temp_dir = tempdir().expect("Operation failed");
        let db_path = temp_dir.path().join("test_tuning_db.json");

        let mut tuner = AutoTuner::with_config(AutoTuneConfig {
            sizes: SizeRange {
                min: 16,
                max: 32,
                step: SizeStep::PowersOfTwo,
            },
            repetitions: 2,
            warmup: 1,
            variants: vec![FftVariant::Standard, FftVariant::InPlace],
            database_path: db_path.clone(),
        });

        // A benchmarking failure is reported rather than asserted.
        if let Err(e) = tuner.run_benchmarks() {
            println!("Benchmark failed: {e}");
            return;
        }

        assert!(db_path.exists());

        let variant = tuner.get_best_variant(16, true);
        assert!(matches!(
            variant,
            FftVariant::Standard | FftVariant::InPlace
        ));
    }
}