#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::string::{String, ToString};
use super::buffer::GpuBuffer;
use super::error::{GpuError, GpuResult};
use super::plan::GpuDirection;
use super::GpuBackend;
use super::GpuCapabilities;
use crate::kernel::{Complex, Float};
/// Reports whether the Metal backend can be used on the current target.
///
/// Metal only exists on Apple platforms, so this is a compile-time
/// constant: `true` on macOS and iOS, `false` everywhere else.
#[must_use]
pub fn is_available() -> bool {
    // `cfg!` folds the per-target branches into one boolean constant.
    cfg!(any(target_os = "macos", target_os = "ios"))
}
/// Queries the capabilities of the Metal device.
///
/// The memory and compute-unit figures are placeholders (0) until a real
/// Metal device query is wired up; f16 is advertised, f64 is not.
///
/// # Errors
///
/// Returns [`GpuError::NoBackendAvailable`] when Metal is not supported
/// on the current target.
pub fn query_capabilities() -> GpuResult<GpuCapabilities> {
    if is_available() {
        Ok(GpuCapabilities {
            backend: GpuBackend::Metal,
            device_name: get_device_name(),
            total_memory: 0,
            available_memory: 0,
            max_fft_size: 1 << 24,
            supports_f64: false,
            supports_f16: true,
            compute_units: 0,
            max_workgroup_size: 1024,
        })
    } else {
        Err(GpuError::NoBackendAvailable)
    }
}
/// Returns a human-readable name for the Metal device.
///
/// No Metal API call is wired up yet, so this is a static label. The
/// Apple-platform set matches `is_available` (macOS *and* iOS) — the
/// previous version only special-cased macOS, so an iOS build reported
/// "Unknown Metal Device" for a platform the backend claims to support.
fn get_device_name() -> String {
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    {
        "Apple GPU".to_string()
    }
    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
    {
        "Unknown Metal Device".to_string()
    }
}
/// Blocks until all queued Metal work has completed.
///
/// # Errors
///
/// Currently infallible: no command queue exists yet, so this is a no-op
/// that always returns `Ok(())`.
pub fn synchronize() -> GpuResult<()> {
    Ok(())
}
/// An FFT execution plan for the Metal backend.
///
/// Execution currently goes through the CPU fallback path; the fields are
/// kept for the eventual native Metal implementation.
#[derive(Debug)]
pub struct MetalFftPlan {
    /// Number of complex points per transform (power of two).
    size: usize,
    /// Number of independent transforms executed per call.
    batch_size: usize,
    /// log2(size), precomputed at plan creation.
    log2n: u32,
    /// Whether Metal Performance Shaders would be used; unread until the
    /// native dispatch path lands, hence the `dead_code` allow.
    #[allow(dead_code)]
    use_mps: bool,
}
impl MetalFftPlan {
    /// Creates a plan for `batch_size` transforms of `size` points each.
    ///
    /// # Errors
    ///
    /// * [`GpuError::NoBackendAvailable`] when Metal is unavailable.
    /// * [`GpuError::InvalidSize`] when `size` is zero.
    /// * [`GpuError::Unsupported`] when `size` is not a power of two.
    pub fn new(size: usize, batch_size: usize) -> GpuResult<Self> {
        if !is_available() {
            return Err(GpuError::NoBackendAvailable);
        }
        if size == 0 {
            return Err(GpuError::InvalidSize(size));
        }
        if !size.is_power_of_two() {
            return Err(GpuError::Unsupported(
                "Metal FFT currently requires power-of-2 sizes".into(),
            ));
        }
        Ok(Self {
            size,
            batch_size,
            log2n: size.trailing_zeros(),
            use_mps: true,
        })
    }

    /// Executes the planned FFT, reading `input` and writing `output`.
    ///
    /// Both buffers must hold exactly `size * batch_size` elements.
    ///
    /// # Errors
    ///
    /// * [`GpuError::SizeMismatch`] when either buffer has the wrong length.
    /// * [`GpuError::ExecutionFailed`] when the CPU fallback plan cannot be built.
    pub fn execute<T: Float>(
        &self,
        input: &GpuBuffer<T>,
        output: &mut GpuBuffer<T>,
        direction: GpuDirection,
    ) -> GpuResult<()> {
        let expected_size = self.size * self.batch_size;
        if input.size() != expected_size || output.size() != expected_size {
            return Err(GpuError::SizeMismatch {
                expected: expected_size,
                got: input.size().min(output.size()),
            });
        }
        // Native Metal dispatch is not implemented yet; run on the CPU.
        self.execute_fallback(input, output, direction)
    }

    /// CPU fallback: runs each batch through the scalar FFT implementation,
    /// round-tripping through f64 because the CPU plan is f64-only.
    fn execute_fallback<T: Float>(
        &self,
        input: &GpuBuffer<T>,
        output: &mut GpuBuffer<T>,
        direction: GpuDirection,
    ) -> GpuResult<()> {
        use crate::api::{Direction, Flags, Plan};
        let dir = match direction {
            GpuDirection::Forward => Direction::Forward,
            GpuDirection::Inverse => Direction::Backward,
        };
        // Preserve the old behavior of never touching the planner when
        // there is nothing to do.
        if self.batch_size == 0 {
            return Ok(());
        }
        // The plan depends only on size and direction, so build it once
        // instead of once per batch, and reuse the f64 scratch buffers
        // across batches rather than reallocating them in the loop.
        let plan = Plan::dft_1d(self.size, dir, Flags::ESTIMATE).ok_or_else(|| {
            GpuError::ExecutionFailed("Failed to create CPU fallback plan".into())
        })?;
        let mut input_f64 = vec![Complex::<f64>::zero(); self.size];
        let mut output_f64 = vec![Complex::<f64>::zero(); self.size];
        for batch in 0..self.batch_size {
            let start = batch * self.size;
            let end = start + self.size;
            let input_slice = &input.cpu_data()[start..end];
            let output_slice = &mut output.cpu_data_mut()[start..end];
            // Widen T -> f64 (non-finite/unconvertible components become 0.0,
            // as before).
            for (dst, c) in input_f64.iter_mut().zip(input_slice) {
                *dst = Complex::new(c.re.to_f64().unwrap_or(0.0), c.im.to_f64().unwrap_or(0.0));
            }
            plan.execute(&input_f64, &mut output_f64);
            // Narrow f64 -> T into the caller's buffer.
            for (dst, c) in output_slice.iter_mut().zip(&output_f64) {
                *dst = Complex::new(T::from_f64(c.re), T::from_f64(c.im));
            }
        }
        Ok(())
    }

    /// Returns log2 of the transform size.
    #[must_use]
    pub const fn log2n(&self) -> u32 {
        self.log2n
    }
}
impl Drop for MetalFftPlan {
    // Placeholder: no native Metal resources are held yet, so there is
    // nothing to release. Kept so the drop hook already exists when the
    // native implementation lands.
    fn drop(&mut self) {
    }
}
/// Opaque handle identifying a buffer on the Metal device.
#[derive(Debug)]
pub struct MetalBufferHandle {
    /// Backend-assigned buffer identifier.
    pub id: u64,
}
/// Uploads the buffer's host data to the Metal device.
///
/// Currently a no-op: execution uses the CPU fallback, so data stays in
/// host memory. Always returns `Ok(())`.
pub fn upload_buffer<T: Float>(_buffer: &mut GpuBuffer<T>) -> GpuResult<()> {
    Ok(())
}
/// Downloads device data back into the buffer's host storage.
///
/// Currently a no-op: execution uses the CPU fallback, so the host copy
/// is already up to date. Always returns `Ok(())`.
pub fn download_buffer<T: Float>(_buffer: &mut GpuBuffer<T>) -> GpuResult<()> {
    Ok(())
}
/// Releases a device buffer handle.
///
/// Currently a no-op (no device allocations exist); consuming the handle
/// still prevents reuse. Always returns `Ok(())`.
pub fn free_buffer(_handle: MetalBufferHandle) -> GpuResult<()> {
    Ok(())
}
/// Metal Shading Language source for a radix-2 FFT: a butterfly kernel and
/// a bit-reversal permutation kernel. Not yet compiled or dispatched — kept
/// for the native Metal path; `dead_code` allowed until that path lands.
#[allow(dead_code)]
const FFT_SHADER_SOURCE: &str = r"
#include <metal_stdlib>
using namespace metal;
// Complex number type
struct Complex {
float re;
float im;
};
// Complex multiplication
Complex cmul(Complex a, Complex b) {
return Complex{a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re};
}
// Complex addition
Complex cadd(Complex a, Complex b) {
return Complex{a.re + b.re, a.im + b.im};
}
// Complex subtraction
Complex csub(Complex a, Complex b) {
return Complex{a.re - b.re, a.im - b.im};
}
// Twiddle factor
Complex twiddle(uint k, uint n, bool inverse) {
float angle = (inverse ? 1.0 : -1.0) * 2.0 * M_PI_F * float(k) / float(n);
return Complex{cos(angle), sin(angle)};
}
// Radix-2 butterfly kernel
kernel void fft_butterfly(
device Complex* data [[buffer(0)]],
constant uint& stage [[buffer(1)]],
constant uint& n [[buffer(2)]],
constant bool& inverse [[buffer(3)]],
uint gid [[thread_position_in_grid]]
) {
uint butterfly_size = 1u << (stage + 1);
uint half_size = butterfly_size >> 1;
uint group = gid / half_size;
uint pair = gid % half_size;
uint i = group * butterfly_size + pair;
uint j = i + half_size;
Complex w = twiddle(pair, butterfly_size, inverse);
Complex u = data[i];
Complex t = cmul(w, data[j]);
data[i] = cadd(u, t);
data[j] = csub(u, t);
}
// Bit-reversal permutation kernel
kernel void bit_reverse(
device Complex* input [[buffer(0)]],
device Complex* output [[buffer(1)]],
constant uint& log2n [[buffer(2)]],
uint gid [[thread_position_in_grid]]
) {
uint rev = 0;
uint idx = gid;
for (uint i = 0; i < log2n; i++) {
rev = (rev << 1) | (idx & 1);
idx >>= 1;
}
output[rev] = input[gid];
}
";
#[cfg(test)]
mod tests {
    use super::*;

    // `is_available` must be callable on any target without panicking.
    #[test]
    fn test_metal_availability() {
        let _ = is_available();
    }

    // On supported targets the capability report must identify the Metal
    // backend and advertise f16 support.
    #[test]
    fn test_metal_capabilities() {
        if !is_available() {
            return;
        }
        let caps = query_capabilities().expect("Failed to query capabilities");
        assert_eq!(caps.backend, GpuBackend::Metal);
        assert!(caps.supports_f16);
    }

    // A power-of-two plan must build and report the correct log2 size.
    #[test]
    fn test_metal_plan_creation() {
        if !is_available() {
            return;
        }
        match MetalFftPlan::new(1024, 1) {
            Ok(p) => assert_eq!(p.log2n(), 10),
            Err(_) => panic!("expected plan creation to succeed"),
        }
    }

    // Non-power-of-two sizes are rejected.
    #[test]
    fn test_metal_non_power_of_2() {
        if !is_available() {
            return;
        }
        assert!(MetalFftPlan::new(1000, 1).is_err());
    }
}