use crate::error::{FFTError, FFTResult};
use crate::sparse_fft::{
SparseFFTAlgorithm, SparseFFTConfig, SparseFFTResult, SparsityEstimationMethod, WindowFunction,
};
use scirs2_core::gpu::{GpuBackend, GpuDevice};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::simd_ops::PlatformCapabilities;
use std::fmt::Debug;
use std::time::Instant;
/// Handle describing a single GPU buffer allocation.
#[allow(dead_code)]
pub struct BufferDescriptor {
    /// Allocation size in bytes (callers size this as element_count * size_of::<T>()).
    size: usize,
    /// Identifier assigned by the memory manager.
    id: u64,
}
/// Where a buffer physically lives.
pub enum BufferLocation {
    /// On the GPU device.
    Device,
    /// In host (CPU) memory.
    Host,
}
/// Role of a buffer within an FFT computation.
pub enum BufferType {
    /// Holds the input signal.
    Input,
    /// Holds computed results.
    Output,
    /// Scratch space for intermediate work.
    Work,
}
/// Handle for a GPU execution stream (creation is currently a placeholder
/// that always errors — see `GpuStream::new`).
#[allow(dead_code)]
pub struct GpuStream {
    /// Stream identifier.
    id: u64,
}
impl GpuStream {
    /// Attempts to create a stream on the device with the given index.
    ///
    /// # Errors
    /// Always returns `FFTError::NotImplementedError` until stream support
    /// is ported to the `scirs2-core::gpu` abstractions.
    pub fn new(_deviceid: i32) -> FFTResult<Self> {
        let message =
            "GPU streams need to be implemented with scirs2-core::gpu abstractions".to_string();
        Err(FFTError::NotImplementedError(message))
    }
}
pub struct GpuMemoryManager;
impl GpuMemoryManager {
pub fn allocate(
&self,
_size: usize,
_location: BufferLocation,
_buffer_type: BufferType,
) -> FFTResult<BufferDescriptor> {
Err(FFTError::NotImplementedError(
"GPU memory management needs to be implemented with scirs2-core::gpu abstractions"
.to_string(),
))
}
pub fn free(_descriptor: BufferDescriptor) -> FFTResult<()> {
Err(FFTError::NotImplementedError(
"GPU memory management needs to be implemented with scirs2-core::gpu abstractions"
.to_string(),
))
}
}
/// Returns the process-wide GPU memory manager.
///
/// # Errors
/// Always fails with `NotImplementedError` until the manager is backed by
/// the `scirs2-core::gpu` abstractions.
#[allow(dead_code)]
pub fn get_global_memory_manager() -> FFTResult<GpuMemoryManager> {
    let message =
        "GPU memory management needs to be implemented with scirs2-core::gpu abstractions";
    Err(FFTError::NotImplementedError(message.to_string()))
}
/// Reports whether any GPU backend (CUDA or generic) is usable on this host,
/// based on detected platform capabilities.
#[allow(dead_code)]
pub fn ensure_gpu_available() -> FFTResult<bool> {
    let detected = PlatformCapabilities::detect();
    let available = detected.cuda_available || detected.gpu_available;
    Ok(available)
}
/// Metadata about a GPU device usable for FFT work.
pub struct GpuDeviceInfo {
    /// Underlying scirs2-core device handle.
    pub device: GpuDevice,
    /// Set to true once the wrapper has been constructed.
    pub initialized: bool,
}
impl GpuDeviceInfo {
    /// Wraps the device with index `device_id` on the default backend.
    ///
    /// The parameter is used (passed to `GpuDevice::new`), so it must not
    /// carry an underscore prefix that would mark it as unused.
    ///
    /// # Errors
    /// Currently infallible, but returns `FFTResult` so real device probing
    /// can fail here later without an interface change.
    pub fn new(device_id: usize) -> FFTResult<Self> {
        let device = GpuDevice::new(GpuBackend::default(), device_id);
        Ok(Self {
            device,
            initialized: true,
        })
    }
    /// Returns whether this device wrapper finished construction.
    pub fn is_available(&self) -> bool {
        self.initialized
    }
}
/// Bundles everything needed to run FFT work on one GPU: the core context,
/// the selected device's metadata, and an execution stream.
#[allow(dead_code)]
pub struct FftGpuContext {
    /// Core GPU context from scirs2-core (constructed with the CUDA backend).
    core_context: scirs2_core::gpu::GpuContext,
    /// Index of the device this context is bound to.
    device_id: i32,
    /// Metadata wrapper for the bound device.
    device_info: GpuDeviceInfo,
    /// Execution stream (stream creation is currently a placeholder).
    stream: GpuStream,
    /// True once construction completed.
    initialized: bool,
}
impl FftGpuContext {
    /// Creates a CUDA-backed context bound to device `deviceid`.
    ///
    /// # Errors
    /// Fails if the core GPU context, the device wrapper, or the stream
    /// cannot be created (stream creation is currently a placeholder and
    /// always errors).
    pub fn new(deviceid: i32) -> FFTResult<Self> {
        let core_context = scirs2_core::gpu::GpuContext::new(scirs2_core::gpu::GpuBackend::Cuda)
            .map_err(|e| FFTError::ComputationError(e.to_string()))?;
        let device_info = GpuDeviceInfo::new(deviceid as usize)?;
        let stream = GpuStream::new(deviceid)?;
        Ok(Self {
            core_context,
            device_id: deviceid,
            device_info,
            stream,
            initialized: true,
        })
    }
    /// Borrows metadata for the bound device.
    pub fn device_info(&self) -> &GpuDeviceInfo {
        &self.device_info
    }
    /// Borrows the execution stream owned by this context.
    pub fn stream(&self) -> &GpuStream {
        &self.stream
    }
    /// Allocates `sizebytes` bytes of device work memory through the global
    /// memory manager (currently a placeholder that always errors).
    pub fn allocate(&self, sizebytes: usize) -> FFTResult<BufferDescriptor> {
        get_global_memory_manager()?.allocate(sizebytes, BufferLocation::Device, BufferType::Work)
    }
    /// Returns `descriptor` to the global memory manager.
    pub fn free(&self, descriptor: BufferDescriptor) -> FFTResult<()> {
        let _manager = get_global_memory_manager()?;
        GpuMemoryManager::free(descriptor)
    }
    /// Validates that `host_data` fits in `device_buffer`.
    ///
    /// NOTE(review): no bytes are transferred yet — this is only a size
    /// check until real copies are wired through scirs2-core::gpu.
    pub fn copy_host_to_device<T>(
        &self,
        host_data: &[T],
        device_buffer: &BufferDescriptor,
    ) -> FFTResult<()> {
        let host_size_bytes = std::mem::size_of_val(host_data);
        let device_size_bytes = device_buffer.size;
        if host_size_bytes > device_size_bytes {
            Err(FFTError::DimensionError(format!(
                "Host buffer size ({host_size_bytes} bytes) exceeds device buffer size ({device_size_bytes} bytes)"
            )))
        } else {
            Ok(())
        }
    }
    /// Validates that the contents of `device_buffer` fit in `host_data`.
    ///
    /// NOTE(review): no bytes are transferred yet — this is only a size
    /// check until real copies are wired through scirs2-core::gpu.
    pub fn copy_device_to_host<T>(
        &self,
        device_buffer: &BufferDescriptor,
        host_data: &mut [T],
    ) -> FFTResult<()> {
        let host_size_bytes = std::mem::size_of_val(host_data);
        let device_size_bytes = device_buffer.size;
        if device_size_bytes > host_size_bytes {
            Err(FFTError::DimensionError(format!(
                "Device buffer size ({device_size_bytes} bytes) exceeds host buffer size ({host_size_bytes} bytes)"
            )))
        } else {
            Ok(())
        }
    }
}
/// GPU-accelerated sparse FFT processor.
///
/// Holds a device context plus lazily allocated device buffers; buffers are
/// created on the first `sparse_fft` call and released on drop.
pub struct GpuSparseFFT {
    /// GPU context this processor runs on.
    context: FftGpuContext,
    /// Algorithm, sparsity, and windowing configuration.
    config: SparseFFTConfig,
    /// Device buffer for the input signal (None until initialized).
    input_buffer: Option<BufferDescriptor>,
    /// Device buffer for output coefficient values (None until initialized).
    output_values_buffer: Option<BufferDescriptor>,
    /// Device buffer for output frequency indices (None until initialized).
    output_indices_buffer: Option<BufferDescriptor>,
}
impl GpuSparseFFT {
    /// Creates a processor bound to device `_deviceid` with the given
    /// configuration. Device buffers are allocated lazily on the first
    /// `sparse_fft` call.
    pub fn new(_deviceid: i32, config: SparseFFTConfig) -> FFTResult<Self> {
        let context = FftGpuContext::new(_deviceid)?;
        Ok(Self {
            context,
            config,
            input_buffer: None,
            output_values_buffer: None,
            output_indices_buffer: None,
        })
    }
    /// (Re)allocates the input and output device buffers for a signal of
    /// `signalsize` samples, releasing any buffers held from a prior call.
    ///
    /// NOTE(review): the global memory manager is a placeholder that always
    /// errors, so this path currently cannot succeed end-to-end.
    fn initialize_buffers(&mut self, signalsize: usize) -> FFTResult<()> {
        self.free_buffers()?;
        let memory_manager = get_global_memory_manager()?;
        // Input buffer: one Complex64 per input sample.
        let input_buffer = memory_manager.allocate(
            signalsize * std::mem::size_of::<Complex64>(),
            BufferLocation::Device,
            BufferType::Input,
        )?;
        self.input_buffer = Some(input_buffer);
        // Output size is bounded by the requested sparsity, capped at the
        // signal length (cannot recover more components than samples).
        let max_components = self.config.sparsity.min(signalsize);
        let output_values_buffer = memory_manager.allocate(
            max_components * std::mem::size_of::<Complex64>(),
            BufferLocation::Device,
            BufferType::Output,
        )?;
        self.output_values_buffer = Some(output_values_buffer);
        let output_indices_buffer = memory_manager.allocate(
            max_components * std::mem::size_of::<usize>(),
            BufferLocation::Device,
            BufferType::Output,
        )?;
        self.output_indices_buffer = Some(output_indices_buffer);
        Ok(())
    }
    /// Releases any held device buffers. Skips the release entirely when the
    /// global memory manager itself is unavailable (best-effort cleanup).
    fn free_buffers(&mut self) -> FFTResult<()> {
        if let Ok(_memory_manager) = get_global_memory_manager() {
            if let Some(buffer) = self.input_buffer.take() {
                GpuMemoryManager::free(buffer)?;
            }
            if let Some(buffer) = self.output_values_buffer.take() {
                GpuMemoryManager::free(buffer)?;
            }
            if let Some(buffer) = self.output_indices_buffer.take() {
                GpuMemoryManager::free(buffer)?;
            }
        }
        Ok(())
    }
    /// Computes a sparse FFT of `signal` using the configured algorithm.
    ///
    /// The signal is converted to `Complex64`, staged toward the device, and
    /// dispatched to a crate-level CUDA kernel for the algorithms that have
    /// one; every other algorithm falls back to the CPU implementation.
    ///
    /// # Errors
    /// - `ValueError` if an element cannot be converted to `f64`.
    /// - `MemoryError` if the input buffer was not initialized.
    /// - Any error from buffer setup or the selected kernel/CPU path.
    pub fn sparse_fft<T>(&mut self, signal: &[T]) -> FFTResult<SparseFFTResult>
    where
        T: NumCast + Copy + Debug + 'static,
    {
        // Timer covers the whole call; only the CPU fallback records it
        // into the result (GPU kernels report their own timing).
        let start = Instant::now();
        self.initialize_buffers(signal.len())?;
        // Promote the (possibly real-valued) input to complex samples,
        // failing fast on any unconvertible element.
        let signal_complex: Vec<Complex64> = signal
            .iter()
            .map(|&val| {
                let val_f64 = NumCast::from(val).ok_or_else(|| {
                    FFTError::ValueError(format!("Could not convert {val:?} to f64"))
                })?;
                Ok(Complex64::new(val_f64, 0.0))
            })
            .collect::<FFTResult<Vec<_>>>()?;
        if let Some(input_buffer) = &self.input_buffer {
            self.context
                .copy_host_to_device(&signal_complex, input_buffer)?;
        } else {
            return Err(FFTError::MemoryError(
                "Input buffer not initialized".to_string(),
            ));
        }
        // Dispatch on the configured algorithm. The CUDA helpers take the
        // host-side complex signal directly.
        let result = match self.config.algorithm {
            SparseFFTAlgorithm::Sublinear => crate::execute_cuda_sublinear_sparse_fft(
                &signal_complex,
                self.config.sparsity,
                self.config.algorithm,
            )?,
            SparseFFTAlgorithm::CompressedSensing => {
                crate::execute_cuda_compressed_sensing_sparse_fft(
                    &signal_complex,
                    self.config.sparsity,
                )?
            }
            SparseFFTAlgorithm::Iterative => {
                // 100 — presumably the iteration budget; confirm against
                // the kernel's signature.
                crate::execute_cuda_iterative_sparse_fft(
                    &signal_complex,
                    self.config.sparsity,
                    100, )?
            }
            SparseFFTAlgorithm::FrequencyPruning => {
                // 0.01 — presumably the pruning threshold; confirm against
                // the kernel's signature.
                crate::execute_cuda_frequency_pruning_sparse_fft(
                    &signal_complex,
                    self.config.sparsity,
                    0.01, )?
            }
            SparseFFTAlgorithm::SpectralFlatness => {
                crate::execute_cuda_spectral_flatness_sparse_fft(
                    &signal_complex,
                    self.config.sparsity,
                    self.config.flatness_threshold,
                )?
            }
            _ => {
                // No GPU kernel for this algorithm: run the CPU
                // implementation, then stamp the elapsed time and the
                // requested algorithm onto the result.
                let mut cpu_processor = crate::sparse_fft::SparseFFT::new(self.config.clone());
                let mut cpu_result = cpu_processor.sparse_fft(&signal_complex)?;
                cpu_result.computation_time = start.elapsed();
                cpu_result.algorithm = self.config.algorithm;
                cpu_result
            }
        };
        Ok(result)
    }
}
impl Drop for GpuSparseFFT {
    /// Best-effort release of device buffers; errors are discarded because
    /// `drop` cannot propagate them.
    fn drop(&mut self) {
        let _ = self.free_buffers();
    }
}
/// Computes a `k`-sparse FFT of `signal` on CUDA device `device_id`.
///
/// `algorithm` defaults to `Sublinear` and `window_function` to `None` when
/// not provided. (The previous `#[allow(clippy::too_many_arguments)]` was
/// stale: this function has 5 parameters, below clippy's threshold.)
///
/// # Errors
/// Returns `ComputationError` when no GPU is available, plus any error from
/// processor construction or the underlying computation.
#[allow(dead_code)]
pub fn cuda_sparse_fft<T>(
    signal: &[T],
    k: usize,
    device_id: i32,
    algorithm: Option<SparseFFTAlgorithm>,
    window_function: Option<WindowFunction>,
) -> FFTResult<SparseFFTResult>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Fail fast with a clear message when no GPU backend is usable.
    if !ensure_gpu_available()? {
        return Err(FFTError::ComputationError(
            "GPU is not available. Either GPU features are not enabled or GPU hardware/drivers are not available.".to_string()
        ));
    }
    let config = SparseFFTConfig {
        estimation_method: SparsityEstimationMethod::Manual,
        sparsity: k,
        algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
        window_function: window_function.unwrap_or(WindowFunction::None),
        ..SparseFFTConfig::default()
    };
    let mut processor = GpuSparseFFT::new(device_id, config)?;
    processor.sparse_fft(signal)
}
#[allow(clippy::too_many_arguments)]
#[allow(dead_code)]
pub fn cuda_batch_sparse_fft<T>(
signals: &[Vec<T>],
k: usize,
device_id: i32,
algorithm: Option<SparseFFTAlgorithm>,
window_function: Option<WindowFunction>,
) -> FFTResult<Vec<SparseFFTResult>>
where
T: NumCast + Copy + Debug + 'static,
{
let config = SparseFFTConfig {
estimation_method: SparsityEstimationMethod::Manual,
sparsity: k,
algorithm: algorithm.unwrap_or(SparseFFTAlgorithm::Sublinear),
window_function: window_function.unwrap_or(WindowFunction::None),
..SparseFFTConfig::default()
};
let mut processor = GpuSparseFFT::new(device_id, config)?;
let mut results = Vec::with_capacity(signals.len());
for signal in signals {
results.push(processor.sparse_fft(signal)?);
}
Ok(results)
}
/// Enumerates available CUDA devices.
///
/// Returns an empty list when no GPU can be used (including when the
/// availability probe itself errors); otherwise reports device 0.
#[allow(dead_code)]
pub fn get_cuda_devices() -> FFTResult<Vec<GpuDeviceInfo>> {
    match ensure_gpu_available() {
        Ok(true) => Ok(vec![GpuDeviceInfo::new(0)?]),
        _ => Ok(Vec::new()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_fft_gpu_memory::AllocationStrategy;
    use std::f64::consts::PI;
    /// Builds a length-`n` signal that is the sum of the given
    /// (frequency, amplitude) sine components.
    fn create_sparse_signal(n: usize, frequencies: &[(usize, f64)]) -> Vec<f64> {
        let mut signal = vec![0.0; n];
        for i in 0..n {
            let t = 2.0 * PI * (i as f64) / (n as f64);
            for &(freq, amp) in frequencies {
                signal[i] += amp * (freq as f64 * t).sin();
            }
        }
        signal
    }
    #[test]
    fn test_cuda_initialization() {
        if !ensure_gpu_available().unwrap_or(false) {
            eprintln!("GPU not available, using mock initialization test");
            // Without a GPU the device query must still succeed (and may
            // legitimately be empty). The previous assertion here,
            // `assert!(devices.is_empty() || !devices.is_empty())`, was a
            // tautology and checked nothing.
            let _devices = get_cuda_devices().expect("Operation failed");
            return;
        }
        let _ = crate::sparse_fft_gpu_memory::init_global_memory_manager(
            crate::sparse_fft_gpu::GPUBackend::CUDA,
            0,
            AllocationStrategy::CacheBySize,
            1024 * 1024 * 1024, // 1 GiB cache limit
        );
        let devices = get_cuda_devices().expect("CUDA devices query should succeed");
        if devices.is_empty() {
            // GPU reported available but enumeration found nothing usable.
            // (A follow-up `assert!(!devices.is_empty())` was trivially true
            // after this early return and has been removed.)
            return;
        }
        match FftGpuContext::new(0) {
            Ok(context) => {
                assert_eq!(context.device_id, 0);
                assert!(context.initialized);
            }
            Err(_) => {
                eprintln!("GPU context creation failed - no GPU hardware available");
            }
        }
    }
    #[test]
    fn test_cuda_sparse_fft() {
        let n = 256;
        let frequencies = vec![(3, 1.0), (7, 0.5), (15, 0.25)];
        let signal = create_sparse_signal(n, &frequencies);
        if !ensure_gpu_available().unwrap_or(false) {
            eprintln!("GPU not available, using CPU fallback for sparse FFT");
            let config = SparseFFTConfig {
                estimation_method: SparsityEstimationMethod::Manual,
                sparsity: 6,
                algorithm: SparseFFTAlgorithm::Sublinear,
                window_function: WindowFunction::Hann,
                ..SparseFFTConfig::default()
            };
            let mut processor = crate::sparse_fft::algorithms::SparseFFT::new(config);
            let result = processor.sparse_fft(&signal).expect("Operation failed");
            assert!(!result.values.is_empty());
            assert_eq!(result.algorithm, SparseFFTAlgorithm::Sublinear);
            return;
        }
        match cuda_sparse_fft(
            &signal,
            6,
            0,
            Some(SparseFFTAlgorithm::Sublinear),
            Some(WindowFunction::Hann),
        ) {
            Ok(result) => {
                assert!(!result.values.is_empty());
                assert_eq!(result.algorithm, SparseFFTAlgorithm::Sublinear);
            }
            Err(e) => {
                // A GPU-unavailability error is an acceptable skip.
                assert!(e.to_string().contains("GPU") || e.to_string().contains("not available"));
                eprintln!("GPU test skipped: {}", e);
            }
        }
    }
    #[test]
    fn test_cuda_batch_processing() {
        let n = 128;
        let signals = vec![
            create_sparse_signal(n, &[(3, 1.0), (7, 0.5)]),
            create_sparse_signal(n, &[(5, 1.0), (10, 0.7)]),
            create_sparse_signal(n, &[(2, 0.8), (12, 0.6)]),
        ];
        if !ensure_gpu_available().unwrap_or(false) {
            eprintln!("GPU not available, using CPU fallback for batch processing");
            let config = SparseFFTConfig {
                estimation_method: SparsityEstimationMethod::Manual,
                sparsity: 4,
                algorithm: SparseFFTAlgorithm::Sublinear,
                window_function: WindowFunction::None,
                ..SparseFFTConfig::default()
            };
            let mut processor = crate::sparse_fft::algorithms::SparseFFT::new(config);
            let mut results = Vec::new();
            for signal in &signals {
                results.push(processor.sparse_fft(signal).expect("Operation failed"));
            }
            assert_eq!(results.len(), signals.len());
            return;
        }
        match cuda_batch_sparse_fft(&signals, 4, 0, Some(SparseFFTAlgorithm::Sublinear), None) {
            Ok(results) => {
                assert_eq!(results.len(), signals.len());
                for result in results {
                    assert!(!result.values.is_empty());
                }
            }
            Err(e) => {
                // A GPU-unavailability error is an acceptable skip.
                assert!(e.to_string().contains("GPU") || e.to_string().contains("not available"));
                eprintln!("GPU batch test skipped: {}", e);
            }
        }
    }
}