#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::string::{String, ToString};
use super::buffer::GpuBuffer;
use super::error::{GpuError, GpuResult};
use super::plan::GpuDirection;
use super::GpuBackend;
use super::GpuCapabilities;
use crate::kernel::{Complex, Float};
/// Reports whether a Metal device can be opened on this machine.
///
/// Each call tries to create a fresh `MetalDevice`. The probe is not retained,
/// and the reason for any failure is discarded; only success/failure is returned.
#[must_use]
pub fn is_available() -> bool {
    let probe = oxicuda_metal::device::MetalDevice::new();
    probe.is_ok()
}
/// Opens the default Metal device and describes what this backend offers.
///
/// Only the device name is read from the hardware. Every other field is a
/// fixed value for this backend:
/// - memory sizes and compute-unit count are `0` (presumably "not queried");
/// - the maximum FFT length is `1 << 24`;
/// - f16 is supported but f64 is not;
/// - the workgroup limit is 1024.
///
/// # Errors
///
/// Returns [`GpuError::InitializationFailed`] with the device error's message
/// when no Metal device can be created.
pub fn query_capabilities() -> GpuResult<GpuCapabilities> {
    oxicuda_metal::device::MetalDevice::new()
        .map(|device| GpuCapabilities {
            backend: GpuBackend::Metal,
            device_name: device.name().to_string(),
            total_memory: 0,
            available_memory: 0,
            max_fft_size: 1 << 24,
            supports_f64: false,
            supports_f16: true,
            compute_units: 0,
            max_workgroup_size: 1024,
        })
        .map_err(|err| GpuError::InitializationFailed(err.to_string()))
}
/// Synchronization hook for the Metal backend.
///
/// Currently a no-op that always returns `Ok(())`. `MetalFftPlan::execute`
/// has already written its results into the output buffer by the time it
/// returns, so nothing is left pending for this function to wait on.
pub fn synchronize() -> GpuResult<()> {
    Ok(())
}
/// FFT plan that runs batched transforms through `oxicuda_metal`'s Metal FFT.
///
/// Built with `MetalFftPlan::new`, which rejects zero and non-power-of-two sizes.
pub struct MetalFftPlan {
    /// Length of each individual transform.
    size: usize,
    /// Number of `size`-length transforms per execution; buffers passed to
    /// `execute` must hold `size * batch_size` elements.
    batch_size: usize,
    /// Underlying Metal plan that performs the computation.
    inner: oxicuda_metal::fft::MetalFftPlan,
}
// `inner` is an opaque handle from `oxicuda_metal` and is deliberately left out
// of the output. The struct is printed as non-exhaustive instead, hence the allow.
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for MetalFftPlan {
    /// Prints `size` and `batch_size`, followed by `..` for the omitted Metal plan.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MetalFftPlan");
        builder.field("size", &self.size);
        builder.field("batch_size", &self.batch_size);
        builder.finish_non_exhaustive()
    }
}
impl MetalFftPlan {
    /// Creates a plan for `batch_size` transforms of length `size` on Metal.
    ///
    /// # Errors
    ///
    /// - [`GpuError::NoBackendAvailable`] if no Metal device can be opened.
    /// - [`GpuError::InvalidSize`] if `size` is zero.
    /// - [`GpuError::Unsupported`] if `size` is not a power of two.
    /// - [`GpuError::InitializationFailed`] if the Metal plan cannot be built.
    ///
    /// NOTE(review): `batch_size` is forwarded to the Metal plan without any
    /// validation here (including `0`); confirm the Metal side rejects bad values.
    pub fn new(size: usize, batch_size: usize) -> GpuResult<Self> {
        if !is_available() {
            return Err(GpuError::NoBackendAvailable);
        }
        if size == 0 {
            return Err(GpuError::InvalidSize(size));
        }
        if !size.is_power_of_two() {
            return Err(GpuError::Unsupported(
                "Metal FFT requires power-of-2 sizes".into(),
            ));
        }
        let inner = oxicuda_metal::fft::MetalFftPlan::new(size, batch_size)
            .map_err(|e| GpuError::InitializationFailed(e.to_string()))?;
        Ok(Self {
            size,
            batch_size,
            inner,
        })
    }

    /// Runs the planned transform(s) from `input` into `output` in `direction`.
    ///
    /// Both buffers must hold exactly `size * batch_size` elements. The input is
    /// read from the buffer's CPU-side data and converted to `Complex<f32>` for
    /// the Metal kernel. The result is converted back to `T` and written into
    /// `output`'s CPU-side data. Precision is therefore limited to `f32` even
    /// when `T` is wider.
    ///
    /// # Errors
    ///
    /// - [`GpuError::SizeMismatch`] if either buffer has the wrong length. `got`
    ///   is the length of the offending buffer; `input` is checked first.
    /// - [`GpuError::ExecutionFailed`] if the Metal FFT reports an error.
    pub fn execute<T: Float>(
        &self,
        input: &GpuBuffer<T>,
        output: &mut GpuBuffer<T>,
        direction: GpuDirection,
    ) -> GpuResult<()> {
        let expected_size = self.size * self.batch_size;
        // Check each buffer separately so `got` names the buffer that is
        // actually wrong. Combining both lengths (e.g. taking their minimum) can
        // report a value equal to `expected` and hide the real mismatch.
        for got in [input.size(), output.size()] {
            if got != expected_size {
                return Err(GpuError::SizeMismatch {
                    expected: expected_size,
                    got,
                });
            }
        }

        // Values that cannot be represented as f64 fall back to 0.0.
        let to_f32 =
            |v: &T| num_traits::ToPrimitive::to_f64(v).map_or(0.0_f32, |x| x as f32);
        let input_f32: Vec<num_complex::Complex<f32>> = input
            .cpu_data()
            .iter()
            .map(|c| num_complex::Complex::new(to_f32(&c.re), to_f32(&c.im)))
            .collect();
        let mut output_f32 = vec![num_complex::Complex::<f32>::new(0.0, 0.0); expected_size];

        let metal_dir = match direction {
            GpuDirection::Forward => oxicuda_metal::fft::MetalFftDirection::Forward,
            GpuDirection::Inverse => oxicuda_metal::fft::MetalFftDirection::Inverse,
        };
        self.inner
            .execute(&input_f32, &mut output_f32, metal_dir)
            .map_err(|e| GpuError::ExecutionFailed(e.to_string()))?;

        let out_data = output.cpu_data_mut();
        // `output.size()` was validated above, so the CPU storage is expected to
        // match the Metal result exactly. The check is for debug builds only;
        // the zip avoids per-element bounds checks.
        debug_assert_eq!(out_data.len(), output_f32.len());
        for (dst, src) in out_data.iter_mut().zip(&output_f32) {
            *dst = Complex::new(
                T::from_f64(f64::from(src.re)),
                T::from_f64(f64::from(src.im)),
            );
        }
        Ok(())
    }

    /// Base-2 logarithm of the transform length, as reported by the Metal plan.
    #[must_use]
    pub fn log2n(&self) -> u32 {
        self.inner.log2n()
    }
}
// NOTE(review): this `Drop` impl has an empty body. The `inner` Metal plan is
// still dropped automatically after `drop` returns, so the impl adds no
// cleanup of its own. Its only effect is forbidding moves out of
// `MetalFftPlan`'s fields by destructuring. Consider removing it.
impl Drop for MetalFftPlan {
    fn drop(&mut self) {
    }
}
/// Handle identifying a Metal-side buffer, accepted by `free_buffer`.
///
/// NOTE(review): nothing visible in this module creates these handles, and
/// `free_buffer` ignores them. Confirm whether this type is still needed.
#[derive(Debug)]
pub struct MetalBufferHandle {
    /// Buffer identifier; its meaning is not defined in this module.
    pub id: u64,
}
/// Host-to-device transfer hook for Metal; currently does nothing.
///
/// `MetalFftPlan::execute` reads its input straight from the buffer's CPU-side
/// data, so no separate upload is performed. The buffer is left untouched and
/// `Ok(())` is always returned.
pub fn upload_buffer<T>(_buffer: &mut GpuBuffer<T>) -> GpuResult<()>
where
    T: Float,
{
    Ok(())
}
/// Device-to-host transfer hook for Metal; currently does nothing.
///
/// `MetalFftPlan::execute` writes its results directly into the output
/// buffer's CPU-side data, so there is nothing to copy back. The buffer is left
/// untouched and `Ok(())` is always returned.
pub fn download_buffer<T>(_buffer: &mut GpuBuffer<T>) -> GpuResult<()>
where
    T: Float,
{
    Ok(())
}
/// Releases a Metal buffer handle.
///
/// Currently a no-op: the handle is consumed and dropped without any Metal
/// call, and `Ok(())` is always returned.
pub fn free_buffer(_handle: MetalBufferHandle) -> GpuResult<()> {
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Probing for a device must never panic, whether or not Metal exists.
    #[test]
    fn test_metal_availability() {
        let _ = is_available();
    }

    #[test]
    fn test_metal_capabilities() {
        if is_available() {
            let caps = query_capabilities().expect("Failed to query capabilities");
            assert_eq!(caps.backend, GpuBackend::Metal);
            assert!(caps.supports_f16);
            assert!(!caps.supports_f64);
        }
    }

    #[test]
    fn test_metal_plan_creation() {
        if is_available() {
            let plan = MetalFftPlan::new(1024, 1).expect("plan creation for n = 1024");
            assert_eq!(plan.log2n(), 10);
        }
    }

    #[test]
    fn test_metal_zero_size() {
        if is_available() {
            assert!(matches!(
                MetalFftPlan::new(0, 1),
                Err(GpuError::InvalidSize(0))
            ));
        }
    }

    #[test]
    fn test_metal_non_power_of_2() {
        if is_available() {
            assert!(matches!(
                MetalFftPlan::new(1000, 1),
                Err(GpuError::Unsupported(_))
            ));
        }
    }

    /// A too-short input buffer must be rejected before any GPU work happens.
    #[test]
    fn test_metal_size_mismatch() {
        if !is_available() {
            return;
        }
        let plan = MetalFftPlan::new(64, 1).expect("plan creation");
        let input: GpuBuffer<f32> = GpuBuffer::new(32, GpuBackend::Metal).expect("buffer");
        let mut output: GpuBuffer<f32> = GpuBuffer::new(64, GpuBackend::Metal).expect("buffer");
        let result = plan.execute(&input, &mut output, GpuDirection::Forward);
        assert!(matches!(
            result,
            Err(GpuError::SizeMismatch {
                expected: 64,
                got: 32
            })
        ));
    }

    /// The DFT of a unit impulse has magnitude 1 in every bin.
    #[test]
    fn test_metal_fft_correctness_impulse() {
        if !is_available() {
            return;
        }
        let n = 64usize;
        let plan = MetalFftPlan::new(n, 1).expect("plan creation");
        let mut input: GpuBuffer<f32> = GpuBuffer::new(n, GpuBackend::Metal).expect("buffer");
        let mut output: GpuBuffer<f32> = GpuBuffer::new(n, GpuBackend::Metal).expect("buffer");
        let mut data = vec![Complex::<f32>::zero(); n];
        data[0] = Complex::new(1.0f32, 0.0f32);
        input.upload(&data).expect("upload");
        plan.execute(&input, &mut output, GpuDirection::Forward)
            .expect("FFT execute");
        let mut result = vec![Complex::<f32>::zero(); n];
        output.download(&mut result).expect("download");
        for (i, c) in result.iter().enumerate() {
            let mag = (c.re * c.re + c.im * c.im).sqrt();
            assert!(
                (mag - 1.0).abs() < 1e-4,
                "bin {i}: expected magnitude 1.0, got {mag}"
            );
        }
    }

    /// Forward followed by inverse should reproduce the original signal.
    /// This relies on the Metal inverse transform applying the 1/n scaling.
    #[test]
    fn test_metal_fft_round_trip() {
        if !is_available() {
            return;
        }
        let n = 128usize;
        let plan = MetalFftPlan::new(n, 1).expect("plan");
        let original: Vec<Complex<f32>> = (0..n)
            .map(|k| {
                let t = k as f32 / n as f32;
                Complex::new(t.sin(), 0.0f32)
            })
            .collect();
        let mut buf_in: GpuBuffer<f32> = GpuBuffer::new(n, GpuBackend::Metal).expect("buf");
        let mut buf_mid: GpuBuffer<f32> = GpuBuffer::new(n, GpuBackend::Metal).expect("buf");
        let mut buf_out: GpuBuffer<f32> = GpuBuffer::new(n, GpuBackend::Metal).expect("buf");
        buf_in.upload(&original).expect("upload");
        plan.execute(&buf_in, &mut buf_mid, GpuDirection::Forward)
            .expect("forward");
        plan.execute(&buf_mid, &mut buf_out, GpuDirection::Inverse)
            .expect("inverse");
        let mut recovered = vec![Complex::<f32>::zero(); n];
        buf_out.download(&mut recovered).expect("download");
        for (i, (got, want)) in recovered.iter().zip(&original).enumerate() {
            let err = ((got.re - want.re).powi(2) + (got.im - want.im).powi(2)).sqrt();
            assert!(
                err < 1e-4,
                "sample {i}: expected ({}, {}), got ({}, {}), error={err}",
                want.re,
                want.im,
                got.re,
                got.im
            );
        }
    }
}