use crate::error::{FFTError, FFTResult};
#[cfg(feature = "rustfft-backend")]
use rustfft::FftPlanner;
use scirs2_core::numeric::Complex64;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
/// Common interface implemented by every pluggable FFT backend.
///
/// The `Send + Sync` supertraits let a backend be shared between threads as an
/// `Arc<dyn FftBackend>`, which is the form `BackendManager` stores.
pub trait FftBackend: Send + Sync {
/// Short identifier of the backend (the bundled RustFFT backend returns `"rustfft"`).
fn name(&self) -> &str;
/// Human-readable, one-line description of the backend.
fn description(&self) -> &str;
/// Reports whether the backend can be used; `BackendManager::set_backend`
/// refuses to select a backend for which this returns `false`.
fn is_available(&self) -> bool;
/// Forward transform of `input`, written to `output`.
///
/// The bundled RustFFT backend forwards to `fft_sized` with `input.len()`
/// as the size, so `output` must have the same length as `input` there.
fn fft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()>;
/// Inverse transform of `input`, written to `output`.
///
/// The bundled RustFFT backend forwards to `ifft_sized` with `input.len()`
/// as the size.
fn ifft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()>;
/// Forward transform with an explicitly stated length.
///
/// The bundled RustFFT backend returns an error unless both `input` and
/// `output` are exactly `size` elements long; other implementations may
/// interpret `size` differently — confirm against the implementation.
fn fft_sized(
&self,
input: &[Complex64],
output: &mut [Complex64],
size: usize,
) -> FFTResult<()>;
/// Inverse transform with an explicitly stated length.
///
/// The bundled RustFFT backend applies the same length check as
/// `fft_sized` and multiplies every output element by `1 / size`.
fn ifft_sized(
&self,
input: &[Complex64],
output: &mut [Complex64],
size: usize,
) -> FFTResult<()>;
/// Capability query by string key (the bundled RustFFT backend recognizes
/// `"1d_fft"`, `"2d_fft"`, `"nd_fft"` and `"cached_plans"`).
fn supports_feature(&self, feature: &str) -> bool;
}
/// FFT backend that delegates to the RustFFT crate.
///
/// Only compiled when the `rustfft-backend` feature is enabled.
#[cfg(feature = "rustfft-backend")]
pub struct RustFftBackend {
/// Shared planner used to create the plan for each transform. The mutex
/// gives the `&self` trait methods the mutable access they take when planning.
planner: Arc<Mutex<FftPlanner<f64>>>,
}
#[cfg(feature = "rustfft-backend")]
impl RustFftBackend {
    /// Creates a backend that owns a newly constructed RustFFT planner.
    pub fn new() -> Self {
        let planner = Arc::new(Mutex::new(FftPlanner::new()));
        Self { planner }
    }
}
#[cfg(feature = "rustfft-backend")]
impl Default for RustFftBackend {
    /// Same as [`RustFftBackend::new`].
    fn default() -> Self {
        RustFftBackend::new()
    }
}
#[cfg(feature = "rustfft-backend")]
impl FftBackend for RustFftBackend {
    /// Returns the identifier `"rustfft"`.
    fn name(&self) -> &str {
        "rustfft"
    }

    /// Returns a one-line description of this backend.
    fn description(&self) -> &str {
        "Pure Rust FFT implementation using RustFFT library"
    }

    /// Always `true`: this backend exists whenever it is compiled in.
    fn is_available(&self) -> bool {
        true
    }

    /// Forward FFT of `input` into `output`, using `input.len()` as the size.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` when `output.len() != input.len()`.
    fn fft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
        self.fft_sized(input, output, input.len())
    }

    /// Inverse FFT of `input` into `output`, using `input.len()` as the size.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` when `output.len() != input.len()`.
    fn ifft(&self, input: &[Complex64], output: &mut [Complex64]) -> FFTResult<()> {
        self.ifft_sized(input, output, input.len())
    }

    /// Forward FFT of exactly `size` points.
    ///
    /// The result is written to `output` without any scaling.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` unless both `input` and `output` are
    /// exactly `size` elements long.
    ///
    /// # Panics
    /// Panics if the planner mutex has been poisoned.
    fn fft_sized(
        &self,
        input: &[Complex64],
        output: &mut [Complex64],
        size: usize,
    ) -> FFTResult<()> {
        if input.len() != size || output.len() != size {
            return Err(FFTError::ValueError(
                "Input and output sizes must match the specified size".to_string(),
            ));
        }

        // Hold the planner lock only while the plan is created. The plan is a
        // shared handle that remains usable after the guard is dropped, so
        // concurrent callers no longer serialize on the transform itself.
        let plan = {
            let mut planner = self.planner.lock().expect("Operation failed");
            planner.plan_fft_forward(size)
        };

        // RustFFT transforms in place, so work on a scratch copy of the input.
        let mut buffer: Vec<rustfft::num_complex::Complex<f64>> = input
            .iter()
            .map(|c| rustfft::num_complex::Complex::new(c.re, c.im))
            .collect();
        plan.process(&mut buffer);

        for (dst, src) in output.iter_mut().zip(&buffer) {
            *dst = Complex64::new(src.re, src.im);
        }
        Ok(())
    }

    /// Inverse FFT of exactly `size` points, normalized by `1 / size`.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` unless both `input` and `output` are
    /// exactly `size` elements long.
    ///
    /// # Panics
    /// Panics if the planner mutex has been poisoned.
    fn ifft_sized(
        &self,
        input: &[Complex64],
        output: &mut [Complex64],
        size: usize,
    ) -> FFTResult<()> {
        if input.len() != size || output.len() != size {
            return Err(FFTError::ValueError(
                "Input and output sizes must match the specified size".to_string(),
            ));
        }

        // Same lock discipline as `fft_sized`: lock for planning only.
        let plan = {
            let mut planner = self.planner.lock().expect("Operation failed");
            planner.plan_fft_inverse(size)
        };

        let mut buffer: Vec<rustfft::num_complex::Complex<f64>> = input
            .iter()
            .map(|c| rustfft::num_complex::Complex::new(c.re, c.im))
            .collect();
        plan.process(&mut buffer);

        // The raw inverse transform is unscaled; divide by the length here.
        // For `size == 0` the loop below is empty, so the infinite scale is never used.
        let scale = 1.0 / size as f64;
        for (dst, src) in output.iter_mut().zip(&buffer) {
            *dst = Complex64::new(src.re * scale, src.im * scale);
        }
        Ok(())
    }

    /// Returns `true` for the feature keys `"1d_fft"`, `"2d_fft"`, `"nd_fft"`
    /// and `"cached_plans"`, and `false` for anything else.
    fn supports_feature(&self, feature: &str) -> bool {
        matches!(feature, "1d_fft" | "2d_fft" | "nd_fft" | "cached_plans")
    }
}
/// Registry of named FFT backends together with the name of the selected one.
///
/// Each field sits behind its own mutex, so registration and selection work
/// through `&self`.
pub struct BackendManager {
/// Registered backends, keyed by the name they were registered under.
backends: Arc<Mutex<HashMap<String, Arc<dyn FftBackend>>>>,
/// Name of the currently selected backend; initialized to `"rustfft"`, or
/// to `"none"` when the `rustfft-backend` feature is disabled.
current_backend: Arc<Mutex<String>>,
}
impl BackendManager {
    /// Creates a manager populated with the backends compiled into this build.
    ///
    /// With the `rustfft-backend` feature a `RustFftBackend` is registered as
    /// `"rustfft"` and selected. Without it the registry starts empty and the
    /// current backend name is the placeholder `"none"`.
    pub fn new() -> Self {
        // `mut` is only needed when at least one backend feature is compiled in.
        #[allow(unused_mut)]
        let mut backends: HashMap<String, Arc<dyn FftBackend>> = HashMap::new();
        #[cfg(feature = "rustfft-backend")]
        {
            let rustfft_backend: Arc<dyn FftBackend> = Arc::new(RustFftBackend::new());
            backends.insert("rustfft".to_string(), rustfft_backend);
        }

        let default_backend = if cfg!(feature = "rustfft-backend") {
            "rustfft"
        } else {
            "none"
        };

        Self {
            backends: Arc::new(Mutex::new(backends)),
            current_backend: Arc::new(Mutex::new(default_backend.to_string())),
        }
    }

    /// Registers `backend` under `name`.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` if a backend with that name is already
    /// registered; the existing entry is left untouched.
    ///
    /// # Panics
    /// Panics if the registry mutex has been poisoned.
    pub fn register_backend(&self, name: String, backend: Arc<dyn FftBackend>) -> FFTResult<()> {
        let mut backends = self.backends.lock().expect("Operation failed");
        // One lookup via the entry API instead of `contains_key` + `insert`.
        match backends.entry(name) {
            std::collections::hash_map::Entry::Occupied(slot) => Err(FFTError::ValueError(
                format!("Backend '{}' already exists", slot.key()),
            )),
            std::collections::hash_map::Entry::Vacant(slot) => {
                slot.insert(backend);
                Ok(())
            }
        }
    }

    /// Returns the names of all registered backends, in unspecified order.
    ///
    /// # Panics
    /// Panics if the registry mutex has been poisoned.
    pub fn list_backends(&self) -> Vec<String> {
        let backends = self.backends.lock().expect("Operation failed");
        backends.keys().cloned().collect()
    }

    /// Selects the backend registered as `name`.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` if no such backend is registered or if
    /// it reports itself as unavailable; the selection is unchanged then.
    ///
    /// # Panics
    /// Panics if either internal mutex has been poisoned.
    pub fn set_backend(&self, name: &str) -> FFTResult<()> {
        // Lock order: `backends` first, then `current_backend`.
        let backends = self.backends.lock().expect("Operation failed");
        let Some(backend) = backends.get(name) else {
            return Err(FFTError::ValueError(format!("Backend '{name}' not found")));
        };
        if !backend.is_available() {
            return Err(FFTError::ValueError(format!(
                "Backend '{name}' is not available"
            )));
        }
        *self.current_backend.lock().expect("Operation failed") = name.to_string();
        Ok(())
    }

    /// Returns the name of the currently selected backend.
    ///
    /// # Panics
    /// Panics if the selection mutex has been poisoned.
    pub fn get_backend_name(&self) -> String {
        self.current_backend
            .lock()
            .expect("Operation failed")
            .clone()
    }

    /// Returns a handle to the currently selected backend.
    ///
    /// # Panics
    /// Panics if the current name has no registered backend (for example the
    /// `"none"` placeholder used when no backend feature is compiled in), or
    /// if either internal mutex has been poisoned.
    pub fn get_backend(&self) -> Arc<dyn FftBackend> {
        // Acquire the locks in the same order as `set_backend` (`backends`,
        // then `current_backend`). Taking them in the opposite order allowed
        // two threads to each hold one lock while waiting for the other.
        let backends = self.backends.lock().expect("Operation failed");
        let current_name = self.current_backend.lock().expect("Operation failed");
        backends
            .get(current_name.as_str())
            .cloned()
            .expect("Current backend should always exist")
    }

    /// Returns a snapshot describing the backend registered as `name`, or
    /// `None` if there is no such backend.
    ///
    /// # Panics
    /// Panics if the registry mutex has been poisoned.
    pub fn get_backend_info(&self, name: &str) -> Option<BackendInfo> {
        let backends = self.backends.lock().expect("Operation failed");
        backends.get(name).map(|backend| BackendInfo {
            name: backend.name().to_string(),
            description: backend.description().to_string(),
            available: backend.is_available(),
        })
    }
}
impl Default for BackendManager {
fn default() -> Self {
Self::new()
}
}
/// Descriptive snapshot of a registered backend.
#[derive(Debug, Clone)]
pub struct BackendInfo {
    /// Name the backend reports for itself.
    pub name: String,
    /// Human-readable description of the backend.
    pub description: String,
    /// Whether the backend reported itself as usable when the snapshot was taken.
    pub available: bool,
}

/// Renders as `<name> - <description> (<available|not available>)`.
impl std::fmt::Display for BackendInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let status = if self.available {
            "available"
        } else {
            "not available"
        };
        let Self {
            name, description, ..
        } = self;
        write!(f, "{name} - {description} ({status})")
    }
}
/// Process-wide backend manager. It is populated at most once: either
/// explicitly through `init_backend_manager`, or on first use by
/// `get_backend_manager`.
static GLOBAL_BACKEND_MANAGER: OnceLock<BackendManager> = OnceLock::new();
/// Returns the global backend manager, creating it with
/// `BackendManager::new` if it has not been initialized yet.
#[allow(dead_code)]
pub fn get_backend_manager() -> &'static BackendManager {
GLOBAL_BACKEND_MANAGER.get_or_init(BackendManager::new)
}
/// Installs `manager` as the global backend manager.
///
/// This only succeeds while the global manager is still unset, i.e. before
/// any earlier call to this function or to `get_backend_manager`.
///
/// # Errors
/// Returns a static message if the global manager was already initialized;
/// the supplied `manager` is dropped in that case.
#[allow(dead_code)]
pub fn init_backend_manager(manager: BackendManager) -> Result<(), &'static str> {
    // The message previously read "backend _manager" (stray underscore).
    GLOBAL_BACKEND_MANAGER
        .set(manager)
        .map_err(|_| "Global backend manager already initialized")
}
/// Lists the names of the backends registered with the global manager.
#[allow(dead_code)]
pub fn list_backends() -> Vec<String> {
    let manager = get_backend_manager();
    manager.list_backends()
}
/// Selects the backend called `name` on the global manager.
///
/// # Errors
/// Propagates the error from `BackendManager::set_backend`.
#[allow(dead_code)]
pub fn set_backend(name: &str) -> FFTResult<()> {
    let manager = get_backend_manager();
    manager.set_backend(name)
}
/// Returns the name of the backend currently selected on the global manager.
#[allow(dead_code)]
pub fn get_backend_name() -> String {
    let manager = get_backend_manager();
    manager.get_backend_name()
}
/// Looks up `name` on the global manager and returns its description, or
/// `None` if no backend is registered under that name.
#[allow(dead_code)]
pub fn get_backend_info(name: &str) -> Option<BackendInfo> {
    let manager = get_backend_manager();
    manager.get_backend_info(name)
}
/// Guard that switches the global backend on creation and attempts to switch
/// back to the previous one when dropped.
pub struct BackendContext {
/// Name of the backend that was selected when the context was created.
previous_backend: String,
/// The global manager the switch was applied to.
manager: &'static BackendManager,
}
impl BackendContext {
pub fn new(_backendname: &str) -> FFTResult<Self> {
let manager = get_backend_manager();
let previous_backend = manager.get_backend_name();
manager.set_backend(_backendname)?;
Ok(Self {
previous_backend,
manager,
})
}
}
impl Drop for BackendContext {
    /// Re-selects the backend that was active before this context was created.
    /// A failure to restore is deliberately ignored.
    fn drop(&mut self) {
        let previous = self.previous_backend.as_str();
        let _ = self.manager.set_backend(previous);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The RustFFT backend reports its name, availability and 1-D support.
    #[test]
    #[cfg(feature = "rustfft-backend")]
    fn test_rustfft_backend() {
        let rustfft = RustFftBackend::new();
        assert_eq!(rustfft.name(), "rustfft");
        assert!(rustfft.is_available());
        assert!(rustfft.supports_feature("1d_fft"));
    }

    /// A fresh manager selects and lists the RustFFT backend.
    #[test]
    #[cfg(feature = "rustfft-backend")]
    fn test_backend_manager() {
        let manager = BackendManager::new();
        assert_eq!(manager.get_backend_name(), "rustfft");

        let registered = manager.list_backends();
        assert!(registered.iter().any(|name| name == "rustfft"));

        let info = manager
            .get_backend_info("rustfft")
            .expect("Operation failed");
        assert!(info.available);
    }

    /// A context switches the global backend and restores it when dropped.
    #[test]
    #[cfg(feature = "rustfft-backend")]
    fn test_backend_context() {
        let manager = get_backend_manager();
        let before = manager.get_backend_name();

        let guard = BackendContext::new("rustfft").expect("Operation failed");
        assert_eq!(manager.get_backend_name(), "rustfft");
        drop(guard);

        assert_eq!(manager.get_backend_name(), before);
    }
}