Skip to main content

apple_accelerate/
vdsp.rs

1use crate::bridge;
2use crate::error::{Error, Result};
3use core::ffi::c_void;
4use core::ptr;
5
/// Integer values for the `FFTDirection` argument of the FFT routines.
pub mod fft_direction {
    /// Inverse-transform direction for `vDSP_fft_zip` (`FFTDirection` value `-1`).
    pub const INVERSE: i32 = -1;
    /// Forward-transform direction for `vDSP_fft_zip` (`FFTDirection` value `+1`).
    pub const FORWARD: i32 = 1;
}
13
/// Integer values for the `FFTRadix` argument of `vDSP_create_fftsetup`.
///
/// Pass one of these as the `radix` argument of `FftSetup::new`.
pub mod fft_radix {
    /// Radix-2 setup selector (`FFTRadix` value `0`).
    pub const RADIX2: i32 = 0;
    /// Radix-3 setup selector (`FFTRadix` value `1`).
    pub const RADIX3: i32 = 1;
    /// Radix-5 setup selector (`FFTRadix` value `2`).
    pub const RADIX5: i32 = 2;
}
23
/// Window-generation flags.
///
/// These mirror the `vDSP_HALF_WINDOW`, `vDSP_HANN_DENORM` and `vDSP_HANN_NORM`
/// constants from `vDSP.h`. Only [`HALF_WINDOW`](self::HALF_WINDOW) is meaningful
/// for the Hamming and Blackman generators; the two `HANN_*` values select the
/// normalization of `vDSP_hann_window` and are kept here for completeness.
pub mod window_flags {
    /// `vDSP_HANN_DENORM`: denormalized Hann window, for `vDSP_hann_window`.
    ///
    /// Its value is `0`, which is also the "no flags / full-size window" value
    /// accepted by `vDSP_hamm_window` and `vDSP_blkman_window`.
    pub const HANN_DENORM: i32 = 0;
    /// `vDSP_HALF_WINDOW`: generate only the first half of the window.
    ///
    /// Accepted by `vDSP_hamm_window`, `vDSP_blkman_window` and (combined with a
    /// `HANN_*` value) `vDSP_hann_window`.
    pub const HALF_WINDOW: i32 = 1;
    /// `vDSP_HANN_NORM`: normalized Hann window, for `vDSP_hann_window`.
    ///
    /// NOTE(review): not a documented flag for the Hamming/Blackman generators —
    /// confirm before passing it to `hamming_window` or `blackman_window`.
    pub const HANN_NORM: i32 = 2;
}
33
34/// Owned `FFTSetup` handle backed by the Swift bridge.
35pub struct FftSetup {
36    ptr: *mut c_void,
37}
38
39unsafe impl Send for FftSetup {}
40unsafe impl Sync for FftSetup {}
41
42impl Drop for FftSetup {
43    fn drop(&mut self) {
44        if !self.ptr.is_null() {
45            // SAFETY: `ptr` is an opaque Swift object retained by the bridge.
46            unsafe { bridge::acc_release_handle(self.ptr) };
47            self.ptr = ptr::null_mut();
48        }
49    }
50}
51
52impl FftSetup {
53    /// Creates an `FFTSetup` with `vDSP_create_fftsetup`.
54    #[must_use]
55    pub fn new(log2n: usize, radix: i32) -> Option<Self> {
56        // SAFETY: Pure constructor over scalar inputs.
57        let ptr = unsafe { bridge::acc_vdsp_fft_setup_create(log2n, radix) };
58        if ptr.is_null() {
59            None
60        } else {
61            Some(Self { ptr })
62        }
63    }
64
65    /// Wraps `vDSP_fft_zip` for split-complex single-precision buffers.
66    pub fn fft_zip(
67        &self,
68        real: &mut [f32],
69        imag: &mut [f32],
70        log2n: usize,
71        direction: i32,
72    ) -> Result<()> {
73        let shift = u32::try_from(log2n)
74            .map_err(|_| Error::OperationFailed("FFT log2 length exceeds u32"))?;
75        let expected = 1_usize
76            .checked_shl(shift)
77            .ok_or(Error::OperationFailed("FFT length overflowed"))?;
78        if real.len() != expected {
79            return Err(Error::InvalidLength {
80                expected,
81                actual: real.len(),
82            });
83        }
84        if imag.len() != expected {
85            return Err(Error::InvalidLength {
86                expected,
87                actual: imag.len(),
88            });
89        }
90
91        // SAFETY: Buffers are valid for `expected` elements and `self.ptr` is a live bridge handle.
92        let ok = unsafe {
93            bridge::acc_vdsp_fft_setup_apply(
94                self.ptr,
95                real.as_mut_ptr(),
96                imag.as_mut_ptr(),
97                log2n,
98                direction,
99            )
100        };
101        if ok {
102            Ok(())
103        } else {
104            Err(Error::OperationFailed("vDSP FFT operation failed"))
105        }
106    }
107}
108
109/// Owned `vDSP_biquad_Setup` handle backed by the Swift bridge.
110pub struct BiquadSetup {
111    ptr: *mut c_void,
112}
113
114unsafe impl Send for BiquadSetup {}
115unsafe impl Sync for BiquadSetup {}
116
117impl Drop for BiquadSetup {
118    fn drop(&mut self) {
119        if !self.ptr.is_null() {
120            // SAFETY: `ptr` is an opaque Swift object retained by the bridge.
121            unsafe { bridge::acc_release_handle(self.ptr) };
122            self.ptr = ptr::null_mut();
123        }
124    }
125}
126
127impl BiquadSetup {
128    /// Creates a `vDSP_biquad_Setup` with `vDSP_biquad_CreateSetup`.
129    #[must_use]
130    pub fn new(coefficients: &[f64]) -> Option<Self> {
131        if coefficients.is_empty() || coefficients.len() % 5 != 0 {
132            return None;
133        }
134
135        // SAFETY: `coefficients` is valid for `count` contiguous `f64` values.
136        let ptr = unsafe {
137            bridge::acc_vdsp_biquad_setup_create(coefficients.as_ptr(), coefficients.len())
138        };
139        if ptr.is_null() {
140            None
141        } else {
142            Some(Self { ptr })
143        }
144    }
145
146    /// Wraps `vDSP_biquad` for single-precision input and output buffers.
147    pub fn apply(&self, delay: &mut [f32], input: &[f32], output: &mut [f32]) -> Result<()> {
148        if delay.is_empty() {
149            return Err(Error::InvalidLength {
150                expected: 1,
151                actual: 0,
152            });
153        }
154        if input.len() != output.len() {
155            return Err(Error::InvalidLength {
156                expected: input.len(),
157                actual: output.len(),
158            });
159        }
160
161        // SAFETY: Buffers are valid and `self.ptr` is a live bridge handle.
162        let ok = unsafe {
163            bridge::acc_vdsp_biquad_setup_apply(
164                self.ptr,
165                delay.as_mut_ptr(),
166                input.as_ptr(),
167                output.as_mut_ptr(),
168                input.len(),
169            )
170        };
171        if ok {
172            Ok(())
173        } else {
174            Err(Error::OperationFailed("vDSP biquad operation failed"))
175        }
176    }
177}
178
/// Bridge entry point combining two `f32` vectors of `len` elements into `out`.
type BinaryVectorOpF32 =
    unsafe extern "C" fn(lhs: *const f32, rhs: *const f32, out: *mut f32, len: usize) -> bool;
/// Bridge entry point combining two `f64` vectors of `len` elements into `out`.
type BinaryVectorOpF64 =
    unsafe extern "C" fn(lhs: *const f64, rhs: *const f64, out: *mut f64, len: usize) -> bool;
/// Bridge entry point reducing `len` `f32` values to a single value written to `out`.
type ReduceOpF32 = unsafe extern "C" fn(values: *const f32, out: *mut f32, len: usize) -> bool;
/// Bridge entry point reducing `len` `f64` values to a single value written to `out`.
type ReduceOpF64 = unsafe extern "C" fn(values: *const f64, out: *mut f64, len: usize) -> bool;
/// Bridge entry point filling `out` with `len` `f32` window samples.
type WindowOpF32 = unsafe extern "C" fn(out: *mut f32, len: usize, flags: i32) -> bool;
/// Bridge entry point filling `out` with `len` `f64` window samples.
type WindowOpF64 = unsafe extern "C" fn(out: *mut f64, len: usize, flags: i32) -> bool;
185
186fn binary_vector_op_f32(a: &[f32], b: &[f32], f: BinaryVectorOpF32) -> Result<Vec<f32>> {
187    if a.len() != b.len() {
188        return Err(Error::InvalidLength {
189            expected: a.len(),
190            actual: b.len(),
191        });
192    }
193
194    let mut out = vec![0.0_f32; a.len()];
195    // SAFETY: All slices are valid for `a.len()` contiguous `f32` elements.
196    let ok = unsafe { f(a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), a.len()) };
197    if ok {
198        Ok(out)
199    } else {
200        Err(Error::OperationFailed("vDSP vector operation failed"))
201    }
202}
203
204fn binary_vector_op_f64(a: &[f64], b: &[f64], f: BinaryVectorOpF64) -> Result<Vec<f64>> {
205    if a.len() != b.len() {
206        return Err(Error::InvalidLength {
207            expected: a.len(),
208            actual: b.len(),
209        });
210    }
211
212    let mut out = vec![0.0_f64; a.len()];
213    // SAFETY: All slices are valid for `a.len()` contiguous `f64` elements.
214    let ok = unsafe { f(a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), a.len()) };
215    if ok {
216        Ok(out)
217    } else {
218        Err(Error::OperationFailed("vDSP vector operation failed"))
219    }
220}
221
222fn reduce_f32(values: &[f32], f: ReduceOpF32) -> Result<f32> {
223    if values.is_empty() {
224        return Err(Error::InvalidLength {
225            expected: 1,
226            actual: 0,
227        });
228    }
229
230    let mut out = 0.0_f32;
231    // SAFETY: The slice is valid for `values.len()` contiguous `f32` elements.
232    let ok = unsafe { f(values.as_ptr(), &mut out, values.len()) };
233    if ok {
234        Ok(out)
235    } else {
236        Err(Error::OperationFailed("vDSP reduction failed"))
237    }
238}
239
240fn reduce_f64(values: &[f64], f: ReduceOpF64) -> Result<f64> {
241    if values.is_empty() {
242        return Err(Error::InvalidLength {
243            expected: 1,
244            actual: 0,
245        });
246    }
247
248    let mut out = 0.0_f64;
249    // SAFETY: The slice is valid for `values.len()` contiguous `f64` elements.
250    let ok = unsafe { f(values.as_ptr(), &mut out, values.len()) };
251    if ok {
252        Ok(out)
253    } else {
254        Err(Error::OperationFailed("vDSP reduction failed"))
255    }
256}
257
258#[must_use]
259fn window_f32(length: usize, flags: i32, f: WindowOpF32) -> Vec<f32> {
260    let mut out = vec![0.0_f32; length];
261    if length == 0 {
262        return out;
263    }
264
265    // SAFETY: `out` is valid for `length` contiguous `f32` values.
266    let _ = unsafe { f(out.as_mut_ptr(), length, flags) };
267    out
268}
269
270#[must_use]
271fn window_f64(length: usize, flags: i32, f: WindowOpF64) -> Vec<f64> {
272    let mut out = vec![0.0_f64; length];
273    if length == 0 {
274        return out;
275    }
276
277    // SAFETY: `out` is valid for `length` contiguous `f64` values.
278    let _ = unsafe { f(out.as_mut_ptr(), length, flags) };
279    out
280}
281
282/// Wraps `vDSP_vadd`.
283pub fn add_f32(a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
284    binary_vector_op_f32(a, b, bridge::acc_vdsp_add_f32)
285}
286
287/// Wraps `vDSP_vaddD`.
288pub fn add_f64(a: &[f64], b: &[f64]) -> Result<Vec<f64>> {
289    binary_vector_op_f64(a, b, bridge::acc_vdsp_add_f64)
290}
291
292/// Wraps `vDSP_vsub`.
293pub fn sub_f32(a: &[f32], b: &[f32]) -> Result<Vec<f32>> {
294    binary_vector_op_f32(a, b, bridge::acc_vdsp_sub_f32)
295}
296
297/// Wraps `vDSP_vsubD`.
298pub fn sub_f64(a: &[f64], b: &[f64]) -> Result<Vec<f64>> {
299    binary_vector_op_f64(a, b, bridge::acc_vdsp_sub_f64)
300}
301
302/// Wraps `vDSP_dotpr`.
303pub fn dot_f32(a: &[f32], b: &[f32]) -> Result<f32> {
304    if a.len() != b.len() {
305        return Err(Error::InvalidLength {
306            expected: a.len(),
307            actual: b.len(),
308        });
309    }
310
311    let mut out = 0.0_f32;
312    // SAFETY: The slices are valid for `a.len()` contiguous `f32` elements.
313    let ok = unsafe { bridge::acc_vdsp_dot_f32(a.as_ptr(), b.as_ptr(), &mut out, a.len()) };
314    if ok {
315        Ok(out)
316    } else {
317        Err(Error::OperationFailed("vDSP dot-product failed"))
318    }
319}
320
321/// Wraps `vDSP_dotprD`.
322pub fn dot_f64(a: &[f64], b: &[f64]) -> Result<f64> {
323    if a.len() != b.len() {
324        return Err(Error::InvalidLength {
325            expected: a.len(),
326            actual: b.len(),
327        });
328    }
329
330    let mut out = 0.0_f64;
331    // SAFETY: The slices are valid for `a.len()` contiguous `f64` elements.
332    let ok = unsafe { bridge::acc_vdsp_dot_f64(a.as_ptr(), b.as_ptr(), &mut out, a.len()) };
333    if ok {
334        Ok(out)
335    } else {
336        Err(Error::OperationFailed("vDSP dot-product failed"))
337    }
338}
339
340/// Wraps `vDSP_maxv`.
341pub fn max_f32(values: &[f32]) -> Result<f32> {
342    reduce_f32(values, bridge::acc_vdsp_max_f32)
343}
344
345/// Wraps `vDSP_maxvD`.
346pub fn max_f64(values: &[f64]) -> Result<f64> {
347    reduce_f64(values, bridge::acc_vdsp_max_f64)
348}
349
350/// Wraps `vDSP_minv`.
351pub fn min_f32(values: &[f32]) -> Result<f32> {
352    reduce_f32(values, bridge::acc_vdsp_min_f32)
353}
354
355/// Wraps `vDSP_minvD`.
356pub fn min_f64(values: &[f64]) -> Result<f64> {
357    reduce_f64(values, bridge::acc_vdsp_min_f64)
358}
359
360/// Wraps `vDSP_meanv`.
361pub fn mean_f32(values: &[f32]) -> Result<f32> {
362    reduce_f32(values, bridge::acc_vdsp_mean_f32)
363}
364
365/// Wraps `vDSP_meanvD`.
366pub fn mean_f64(values: &[f64]) -> Result<f64> {
367    reduce_f64(values, bridge::acc_vdsp_mean_f64)
368}
369
370/// Wraps `vDSP_sve`.
371pub fn sum_f32(values: &[f32]) -> Result<f32> {
372    reduce_f32(values, bridge::acc_vdsp_sum_f32)
373}
374
375/// Wraps `vDSP_sveD`.
376pub fn sum_f64(values: &[f64]) -> Result<f64> {
377    reduce_f64(values, bridge::acc_vdsp_sum_f64)
378}
379
380/// Wraps `vDSP_hamm_window`.
381#[must_use]
382pub fn hamming_window(length: usize, flags: i32) -> Vec<f32> {
383    window_f32(length, flags, bridge::acc_vdsp_hamming_window)
384}
385
386/// Wraps `vDSP_hamm_windowD`.
387#[must_use]
388pub fn hamming_window_f64(length: usize, flags: i32) -> Vec<f64> {
389    window_f64(length, flags, bridge::acc_vdsp_hamming_window_f64)
390}
391
392/// Wraps `vDSP_blkman_window`.
393#[must_use]
394pub fn blackman_window(length: usize, flags: i32) -> Vec<f32> {
395    window_f32(length, flags, bridge::acc_vdsp_blackman_window)
396}
397
398/// Wraps `vDSP_blkman_windowD`.
399#[must_use]
400pub fn blackman_window_f64(length: usize, flags: i32) -> Vec<f64> {
401    window_f64(length, flags, bridge::acc_vdsp_blackman_window_f64)
402}