Skip to main content

oxifft/api/
memory.rs

1//! Memory allocation utilities for FFT buffers.
2//!
3//! Provides aligned memory allocation for optimal SIMD performance.
4
5use crate::kernel::{Complex, Float};
6use crate::prelude::*;
7use core::alloc::Layout;
8
/// Minimum alignment, in bytes, used for the storage of a non-empty `AlignedBuffer`
/// and the granularity checked by `is_aligned`.
///
/// Expressed as 512 bits / 8 so the derivation is explicit: one AVX-512 register is
/// 512 bits wide, i.e. 64 bytes.
pub const DEFAULT_ALIGNMENT: usize = 512 / 8;
11
/// An owned buffer of `T` whose storage is over-aligned for SIMD operations.
///
/// When the buffer occupies a non-zero number of bytes, the storage is allocated with
/// alignment `DEFAULT_ALIGNMENT.max(align_of::<T>())` (at least 64 bytes).
/// An empty buffer does not allocate. It uses a dangling pointer that is only aligned
/// for `T`, so `is_aligned` is not guaranteed to hold for it.
pub struct AlignedBuffer<T> {
    /// Start of the element storage (a dangling pointer when nothing was allocated).
    ptr: *mut T,
    /// Number of initialized elements reachable through `ptr`.
    len: usize,
    /// Number of elements the allocation was sized for; `0` means no allocation was made.
    capacity: usize,
}
21
22impl<T: Clone + Default> AlignedBuffer<T> {
23    /// Create a new aligned buffer with the given size, initialized to default values.
24    ///
25    /// # Panics
26    /// Panics if memory allocation fails.
27    #[must_use]
28    pub fn new(size: usize) -> Self {
29        if size == 0 {
30            return Self {
31                ptr: core::ptr::NonNull::dangling().as_ptr(),
32                len: 0,
33                capacity: 0,
34            };
35        }
36
37        let layout = Layout::from_size_align(
38            size * core::mem::size_of::<T>(),
39            DEFAULT_ALIGNMENT.max(core::mem::align_of::<T>()),
40        )
41        .expect("Invalid layout");
42
43        // SAFETY: layout is non-zero size
44        #[cfg(feature = "std")]
45        let ptr = unsafe { std::alloc::alloc_zeroed(layout) as *mut T };
46        #[cfg(not(feature = "std"))]
47        let ptr = unsafe { alloc::alloc::alloc_zeroed(layout) as *mut T };
48
49        if ptr.is_null() {
50            #[cfg(feature = "std")]
51            std::alloc::handle_alloc_error(layout);
52            #[cfg(not(feature = "std"))]
53            alloc::alloc::handle_alloc_error(layout);
54        }
55
56        // Initialize with default values
57        for i in 0..size {
58            // SAFETY: ptr is valid for size elements, and we're within bounds
59            unsafe {
60                core::ptr::write(ptr.add(i), T::default());
61            }
62        }
63
64        Self {
65            ptr,
66            len: size,
67            capacity: size,
68        }
69    }
70
71    /// Get the length of the buffer.
72    #[must_use]
73    pub fn len(&self) -> usize {
74        self.len
75    }
76
77    /// Check if the buffer is empty.
78    #[must_use]
79    pub fn is_empty(&self) -> bool {
80        self.len == 0
81    }
82
83    /// Get a raw pointer to the data.
84    #[must_use]
85    pub fn as_ptr(&self) -> *const T {
86        self.ptr
87    }
88
89    /// Get a mutable raw pointer to the data.
90    #[must_use]
91    pub fn as_mut_ptr(&mut self) -> *mut T {
92        self.ptr
93    }
94
95    /// Get a slice view of the buffer.
96    #[must_use]
97    pub fn as_slice(&self) -> &[T] {
98        if self.len == 0 {
99            &[]
100        } else {
101            // SAFETY: ptr is valid for len elements
102            unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
103        }
104    }
105
106    /// Get a mutable slice view of the buffer.
107    #[must_use]
108    pub fn as_mut_slice(&mut self) -> &mut [T] {
109        if self.len == 0 {
110            &mut []
111        } else {
112            // SAFETY: ptr is valid for len elements
113            unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
114        }
115    }
116}
117
118impl<T> Drop for AlignedBuffer<T> {
119    fn drop(&mut self) {
120        if self.capacity > 0 {
121            let layout = Layout::from_size_align(
122                self.capacity * core::mem::size_of::<T>(),
123                DEFAULT_ALIGNMENT.max(core::mem::align_of::<T>()),
124            )
125            .expect("Invalid layout");
126
127            // SAFETY: ptr was allocated with this layout
128            unsafe {
129                // Drop all elements first
130                for i in 0..self.len {
131                    core::ptr::drop_in_place(self.ptr.add(i));
132                }
133                #[cfg(feature = "std")]
134                std::alloc::dealloc(self.ptr as *mut u8, layout);
135                #[cfg(not(feature = "std"))]
136                alloc::alloc::dealloc(self.ptr as *mut u8, layout);
137            }
138        }
139    }
140}
141
// SAFETY: `AlignedBuffer<T>` uniquely owns its `T` elements. It writes them when
// constructed and drops them in `Drop`, so sending the buffer to another thread moves
// those `T`s with it. That is sound whenever `T: Send`.
unsafe impl<T: Send> Send for AlignedBuffer<T> {}

// SAFETY: through `&AlignedBuffer<T>`, safe code can only obtain `&[T]` (via `Deref` /
// `as_slice`) or a `*const T`. Every mutable accessor requires `&mut self`. Sharing
// references across threads is therefore sound whenever `T: Sync`.
unsafe impl<T: Sync> Sync for AlignedBuffer<T> {}
147
148impl<T> core::ops::Deref for AlignedBuffer<T> {
149    type Target = [T];
150
151    fn deref(&self) -> &Self::Target {
152        if self.len == 0 {
153            &[]
154        } else {
155            // SAFETY: ptr is valid for len elements
156            unsafe { core::slice::from_raw_parts(self.ptr, self.len) }
157        }
158    }
159}
160
161impl<T> core::ops::DerefMut for AlignedBuffer<T> {
162    fn deref_mut(&mut self) -> &mut Self::Target {
163        if self.len == 0 {
164            &mut []
165        } else {
166            // SAFETY: ptr is valid for len elements
167            unsafe { core::slice::from_raw_parts_mut(self.ptr, self.len) }
168        }
169    }
170}
171
172/// Allocate an aligned buffer for complex values.
173///
174/// The returned vector is guaranteed to have its data pointer
175/// aligned to `DEFAULT_ALIGNMENT` bytes (64 bytes for AVX-512).
176///
177/// Note: This returns a standard Vec which may not be aligned.
178/// For guaranteed alignment, use `AlignedBuffer::new()`.
179pub fn alloc_complex<T: Float>(size: usize) -> Vec<Complex<T>> {
180    vec![Complex::zero(); size]
181}
182
183/// Allocate an aligned buffer for complex values with guaranteed alignment.
184///
185/// The returned buffer is guaranteed to have its data pointer
186/// aligned to `DEFAULT_ALIGNMENT` bytes (64 bytes for AVX-512).
187pub fn alloc_complex_aligned<T: Float>(size: usize) -> AlignedBuffer<Complex<T>> {
188    AlignedBuffer::new(size)
189}
190
191/// Allocate an aligned buffer for real values.
192///
193/// Note: This returns a standard Vec which may not be aligned.
194/// For guaranteed alignment, use `AlignedBuffer::new()`.
195pub fn alloc_real<T: Float>(size: usize) -> Vec<T> {
196    vec![T::ZERO; size]
197}
198
199/// Allocate an aligned buffer for real values with guaranteed alignment.
200///
201/// The returned buffer is guaranteed to have its data pointer
202/// aligned to `DEFAULT_ALIGNMENT` bytes (64 bytes for AVX-512).
203pub fn alloc_real_aligned<T: Float>(size: usize) -> AlignedBuffer<T> {
204    AlignedBuffer::new(size)
205}
206
/// No-op kept for FFI-style API compatibility; it does **not** free any memory.
///
/// `alloc_complex` and `alloc_real` return `Vec`s, and the `*_aligned` variants return
/// `AlignedBuffer`s. All of them release their memory automatically when dropped.
/// Passing a pointer to this function has no effect. In particular, it will not reclaim
/// memory obtained by leaking one of those containers.
///
/// # Safety
/// The pointer is never dereferenced or deallocated, so any value (including null) is
/// accepted. The function is declared `unsafe`, but its body performs no unsafe operation.
pub unsafe fn free<T>(_ptr: *mut T) {
    // Intentionally empty: deallocation is handled by the owning `Vec`/`AlignedBuffer`.
}
215
216/// Check if a pointer is properly aligned for SIMD operations.
217pub fn is_aligned<T>(ptr: *const T) -> bool {
218    (ptr as usize).is_multiple_of(DEFAULT_ALIGNMENT)
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_aligned_buffer_alignment() {
        let buf: AlignedBuffer<f64> = AlignedBuffer::new(64);
        assert!(is_aligned(buf.as_ptr()), "Buffer should be aligned");
        assert_eq!(buf.len(), 64);
    }

    #[test]
    fn test_aligned_buffer_complex() {
        let buf = alloc_complex_aligned::<f64>(32);
        assert!(is_aligned(buf.as_ptr()), "Complex buffer should be aligned");
        assert_eq!(buf.len(), 32);
    }

    #[test]
    fn test_aligned_buffer_real() {
        let buf = alloc_real_aligned::<f64>(32);
        assert!(is_aligned(buf.as_ptr()), "Real buffer should be aligned");
        assert_eq!(buf.len(), 32);
    }

    #[test]
    fn test_aligned_buffer_small_element_alignment() {
        // Over-alignment must hold even when `T` itself only needs 1-byte alignment.
        let buf: AlignedBuffer<u8> = AlignedBuffer::new(3);
        assert!(is_aligned(buf.as_ptr()), "u8 buffer should be aligned");
        assert_eq!(buf.len(), 3);
    }

    #[test]
    fn test_aligned_buffer_empty() {
        let buf: AlignedBuffer<f64> = AlignedBuffer::new(0);
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
        assert!(buf.as_slice().is_empty());
    }

    #[test]
    fn test_aligned_buffer_default_initialized() {
        let buf: AlignedBuffer<f64> = AlignedBuffer::new(16);
        assert!(buf.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_aligned_buffer_access() {
        let mut buf: AlignedBuffer<f64> = AlignedBuffer::new(4);
        buf[0] = 1.0;
        buf[1] = 2.0;
        buf[2] = 3.0;
        buf[3] = 4.0;

        assert_eq!(buf[0], 1.0);
        assert_eq!(buf[1], 2.0);
        assert_eq!(buf[2], 3.0);
        assert_eq!(buf[3], 4.0);
    }

    #[test]
    fn test_aligned_buffer_slice() {
        let buf: AlignedBuffer<f64> = AlignedBuffer::new(4);
        let slice = buf.as_slice();
        assert_eq!(slice.len(), 4);
    }

    #[test]
    fn test_aligned_buffer_mut_slice() {
        let mut buf: AlignedBuffer<u32> = AlignedBuffer::new(5);
        buf.as_mut_slice().copy_from_slice(&[0, 10, 20, 30, 40]);
        assert_eq!(buf.as_slice(), &[0, 10, 20, 30, 40]);
    }

    #[test]
    #[should_panic]
    fn test_aligned_buffer_size_overflow_panics() {
        // The byte count does not fit in `usize`; this must panic rather than allocate.
        let _buf: AlignedBuffer<f64> = AlignedBuffer::new(usize::MAX);
    }

    /// Counts drops of `DropCounter`; only used by `test_aligned_buffer_drops_each_element`.
    static DROPS: AtomicUsize = AtomicUsize::new(0);

    /// Non-zero-sized element type whose destructor increments `DROPS`.
    #[derive(Clone, Default)]
    struct DropCounter(u8);

    impl Drop for DropCounter {
        fn drop(&mut self) {
            DROPS.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_aligned_buffer_drops_each_element() {
        DROPS.store(0, Ordering::SeqCst);
        {
            let buf: AlignedBuffer<DropCounter> = AlignedBuffer::new(7);
            assert_eq!(buf.len(), 7);
            assert!(buf.iter().all(|c| c.0 == 0));
            assert_eq!(DROPS.load(Ordering::SeqCst), 0);
        }
        assert_eq!(DROPS.load(Ordering::SeqCst), 7);
    }
}