Skip to main content

oxicuda_memory/
aligned.rs

//! Aligned GPU memory allocation for optimal access patterns.
//!
//! This module provides [`AlignedBuffer<T>`], a device memory buffer that
//! guarantees a specific alignment for the starting address.  Proper alignment
//! is critical for coalesced memory accesses on GPUs — misaligned loads and
//! stores can incur extra memory transactions, significantly hurting
//! throughput.
//!
//! # Alignment options
//!
//! | Variant     | Bytes | Use case                                 |
//! |-------------|-------|------------------------------------------|
//! | `Default`   | 256   | CUDA's natural allocation alignment      |
//! | `Align256`  | 256   | Explicit 256-byte alignment              |
//! | `Align512`  | 512   | Optimal for many GPU texture/surface ops |
//! | `Align1024` | 1024  | Large-stride access patterns             |
//! | `Align4096` | 4096  | Page-aligned for unified/mapped memory   |
//! | `Custom(n)` | n     | User-specified (must be a power of two)  |
//!
//! # Example
//!
//! ```rust,no_run
//! # use oxicuda_memory::aligned::{Alignment, AlignedBuffer};
//! let buf = AlignedBuffer::<f32>::alloc(1024, Alignment::Align512)?;
//! assert!(buf.is_aligned());
//! assert_eq!(buf.as_device_ptr() % 512, 0);
//! # Ok::<(), oxicuda_driver::error::CudaError>(())
//! ```
29
30use std::marker::PhantomData;
31
32use oxicuda_driver::error::{CudaError, CudaResult};
33use oxicuda_driver::ffi::CUdeviceptr;
34#[cfg(not(target_os = "macos"))]
35use oxicuda_driver::loader::try_driver;
36
37// ---------------------------------------------------------------------------
38// Alignment enum
39// ---------------------------------------------------------------------------
40
/// Byte alignment requested for a device memory allocation.
///
/// Every named variant is a power of two.  A [`Custom`](Alignment::Custom)
/// value is not checked when it is constructed; it is validated when an
/// allocation is attempted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Alignment {
    /// The alignment CUDA allocations typically have (treated as 256 bytes).
    Default,
    /// Align to 256 bytes.
    Align256,
    /// Align to 512 bytes.
    Align512,
    /// Align to 1024 bytes.
    Align1024,
    /// Align to 4096 bytes (one page).
    Align4096,
    /// Caller-chosen alignment in bytes; expected to be a power of two.
    Custom(usize),
}

impl Alignment {
    /// Alignment expressed as a byte count.
    ///
    /// [`Default`](Alignment::Default) and [`Align256`](Alignment::Align256)
    /// both report 256; [`Custom`](Alignment::Custom) reports its payload
    /// unchanged, even when that payload is not a valid alignment.
    #[inline]
    pub fn bytes(&self) -> usize {
        match *self {
            Self::Default | Self::Align256 => 256,
            Self::Align512 => 512,
            Self::Align1024 => 1024,
            Self::Align4096 => 4096,
            Self::Custom(requested) => requested,
        }
    }

    /// Reports whether [`bytes`](Self::bytes) is a power of two.
    ///
    /// Named variants always pass.  A [`Custom`](Alignment::Custom) value
    /// fails when it is zero or has more than one bit set.
    #[inline]
    pub fn is_power_of_two(&self) -> bool {
        // The std helper already treats zero as "not a power of two".
        self.bytes().is_power_of_two()
    }

    /// Reports whether `ptr` is a multiple of this alignment.
    ///
    /// A zero-byte alignment never matches any pointer.
    #[inline]
    pub fn is_aligned(&self, ptr: u64) -> bool {
        match self.bytes() as u64 {
            // Guard the modulo below against division by zero.
            0 => false,
            step => ptr % step == 0,
        }
    }
}

impl std::fmt::Display for Alignment {
    /// Writes the byte count for explicit variants, `Default(256)` for the
    /// default, and `Custom(n)` for caller-chosen values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Custom(n) => write!(f, "Custom({n})"),
            Self::Default => f.write_str("Default(256)"),
            Self::Align256 => f.write_str("256"),
            Self::Align512 => f.write_str("512"),
            Self::Align1024 => f.write_str("1024"),
            Self::Align4096 => f.write_str("4096"),
        }
    }
}
111
112// ---------------------------------------------------------------------------
113// Validation helpers
114// ---------------------------------------------------------------------------
115
116/// Maximum alignment we allow (256 MiB).  Anything beyond this is almost
117/// certainly a programming error.
118const MAX_ALIGNMENT: usize = 256 * 1024 * 1024;
119
120/// Validates that an [`Alignment`] is a power of two and within a reasonable
121/// range.
122///
123/// # Errors
124///
125/// Returns [`CudaError::InvalidValue`] if:
126/// - The alignment is zero.
127/// - The alignment is not a power of two.
128/// - The alignment exceeds 256 MiB.
129pub fn validate_alignment(alignment: &Alignment) -> CudaResult<()> {
130    let b = alignment.bytes();
131    if b == 0 {
132        return Err(CudaError::InvalidValue);
133    }
134    if !alignment.is_power_of_two() {
135        return Err(CudaError::InvalidValue);
136    }
137    if b > MAX_ALIGNMENT {
138        return Err(CudaError::InvalidValue);
139    }
140    Ok(())
141}
142
/// Rounds `bytes` up to the nearest multiple of `alignment`.
///
/// Any non-zero `alignment` is accepted; it does not have to be a power of
/// two.  An `alignment` of zero leaves `bytes` untouched.
///
/// If the rounded value does not fit in a `usize`, the result saturates at
/// `usize::MAX` instead of wrapping around to a small number, so the return
/// value is never smaller than `bytes`.
///
/// # Examples
///
/// ```
/// # use oxicuda_memory::aligned::round_up_to_alignment;
/// assert_eq!(round_up_to_alignment(100, 256), 256);
/// assert_eq!(round_up_to_alignment(256, 256), 256);
/// assert_eq!(round_up_to_alignment(257, 256), 512);
/// assert_eq!(round_up_to_alignment(0, 256), 0);
/// // Non-power-of-two alignments are rounded correctly as well.
/// assert_eq!(round_up_to_alignment(10, 6), 12);
/// ```
#[inline]
pub fn round_up_to_alignment(bytes: usize, alignment: usize) -> usize {
    if alignment == 0 {
        return bytes;
    }
    match bytes % alignment {
        // Already on a boundary.
        0 => bytes,
        // `remainder < alignment`, so the subtraction cannot underflow; the
        // addition saturates when the next boundary is not representable.
        remainder => bytes.saturating_add(alignment - remainder),
    }
}
164
165/// Recommends an optimal [`Alignment`] for a type based on its size.
166///
167/// The heuristic prefers alignments that enable coalesced memory accesses:
168///
169/// - Types of 16 bytes or more benefit from 512-byte alignment because a
170///   32-thread warp issuing 16-byte loads touches exactly 512 bytes.
171/// - Types of 8 bytes benefit from 256-byte alignment (warp touches 256 bytes).
172/// - Smaller types use the CUDA default (256 bytes).
173pub fn optimal_alignment_for_type<T>() -> Alignment {
174    let size = std::mem::size_of::<T>();
175    if size >= 16 {
176        Alignment::Align512
177    } else if size >= 8 {
178        Alignment::Align256
179    } else {
180        Alignment::Default
181    }
182}
183
/// Computes the smallest alignment that ensures coalesced memory access for a
/// given `access_width` (in bytes) across a warp of `warp_size` threads.
///
/// The result is the smallest power of two that is at least
/// `warp_size * access_width` bytes, capped at 4096 (page alignment).  The
/// product saturates instead of overflowing, and arbitrarily large inputs
/// simply yield the 4096-byte cap.
///
/// Returns `1` when either argument is zero, since an empty access imposes
/// no alignment requirement.
///
/// # Examples
///
/// ```
/// # use oxicuda_memory::aligned::coalesce_alignment;
/// // 32 threads × 4 bytes = 128 → already a power of two
/// assert_eq!(coalesce_alignment(4, 32), 128);
/// // 32 threads × 16 bytes = 512
/// assert_eq!(coalesce_alignment(16, 32), 512);
/// // Huge footprints are capped at page alignment.
/// assert_eq!(coalesce_alignment(usize::MAX, 32), 4096);
/// ```
pub fn coalesce_alignment(access_width: usize, warp_size: u32) -> usize {
    /// Upper bound on the returned alignment (one page).
    const PAGE_ALIGNMENT: usize = 4096;

    let total = (warp_size as usize).saturating_mul(access_width);
    if total == 0 {
        return 1;
    }
    // Clamp *before* rounding: `next_power_of_two` overflows for values above
    // the largest representable power of two (panic in debug, 0 in release).
    // Because the cap is itself a power of two, clamping first yields the
    // same result for every input that did not overflow.
    total.min(PAGE_ALIGNMENT).next_power_of_two()
}
211
212// ---------------------------------------------------------------------------
213// AlignmentInfo
214// ---------------------------------------------------------------------------
215
216/// Information about the alignment of an existing device pointer.
217#[derive(Debug, Clone, Copy, PartialEq, Eq)]
218pub struct AlignmentInfo {
219    /// The device pointer that was inspected.
220    pub ptr: CUdeviceptr,
221    /// The largest power-of-two alignment that the pointer satisfies.
222    pub natural_alignment: usize,
223    /// Whether the pointer is 256-byte aligned.
224    pub is_256_aligned: bool,
225    /// Whether the pointer is 512-byte aligned.
226    pub is_512_aligned: bool,
227    /// Whether the pointer is page-aligned (4096 bytes).
228    pub is_page_aligned: bool,
229}
230
231/// Inspects a device pointer and reports its alignment characteristics.
232///
233/// For a null (zero) pointer the natural alignment is reported as `usize::MAX`
234/// because zero is trivially aligned to every power of two.
235pub fn check_alignment(ptr: CUdeviceptr) -> AlignmentInfo {
236    let natural = if ptr == 0 {
237        // Zero is aligned to every power of two; report maximum.
238        usize::MAX
239    } else {
240        // The largest power-of-two factor is 2^(trailing_zeros).
241        1_usize << (ptr.trailing_zeros().min(63))
242    };
243    AlignmentInfo {
244        ptr,
245        natural_alignment: natural,
246        is_256_aligned: (ptr % 256) == 0,
247        is_512_aligned: (ptr % 512) == 0,
248        is_page_aligned: (ptr % 4096) == 0,
249    }
250}
251
252// ---------------------------------------------------------------------------
253// AlignedBuffer<T>
254// ---------------------------------------------------------------------------
255
256/// A device memory buffer whose starting address is guaranteed to meet the
257/// requested [`Alignment`].
258///
259/// Internally this may over-allocate by up to `alignment - 1` extra bytes and
260/// offset the user-visible pointer so that it lands on an aligned boundary.
261/// The extra bytes (if any) are reported by [`wasted_bytes`](Self::wasted_bytes).
262///
263/// The buffer frees the *original* (unaligned) allocation on [`Drop`].
264pub struct AlignedBuffer<T: Copy> {
265    /// The aligned device pointer presented to the user.
266    ptr: CUdeviceptr,
267    /// Number of `T` elements.
268    len: usize,
269    /// Total bytes allocated (may be larger than `len * size_of::<T>()`).
270    allocated_bytes: usize,
271    /// The alignment that was requested.
272    alignment: Alignment,
273    /// Byte offset from the raw allocation base to `ptr`.
274    offset: usize,
275    /// The raw allocation base pointer (what we pass to `cuMemFree`).
276    #[cfg_attr(target_os = "macos", allow(dead_code))]
277    raw_ptr: CUdeviceptr,
278    /// Phantom marker for `T`.
279    _phantom: PhantomData<T>,
280}
281
282// SAFETY: Same reasoning as `DeviceBuffer<T>` — the `u64` device pointer
283// handle is managed by the thread-safe CUDA driver.
284unsafe impl<T: Copy + Send> Send for AlignedBuffer<T> {}
285unsafe impl<T: Copy + Sync> Sync for AlignedBuffer<T> {}
286
287impl<T: Copy> AlignedBuffer<T> {
288    /// Allocates an aligned device buffer capable of holding `n` elements of
289    /// type `T`.
290    ///
291    /// The returned buffer's device pointer is guaranteed to be aligned to
292    /// `alignment.bytes()`.  The allocation may be slightly larger than
293    /// `n * size_of::<T>()` to accommodate the alignment offset.
294    ///
295    /// # Errors
296    ///
297    /// * [`CudaError::InvalidValue`] if `n` is zero, alignment is invalid, or
298    ///   the byte-size computation overflows.
299    /// * [`CudaError::OutOfMemory`] if the GPU cannot satisfy the allocation.
300    pub fn alloc(n: usize, alignment: Alignment) -> CudaResult<Self> {
301        if n == 0 {
302            return Err(CudaError::InvalidValue);
303        }
304        validate_alignment(&alignment)?;
305
306        let elem_bytes = n
307            .checked_mul(std::mem::size_of::<T>())
308            .ok_or(CudaError::InvalidValue)?;
309
310        let align_bytes = alignment.bytes();
311
312        // Over-allocate by (alignment - 1) so we can always find an aligned
313        // address within the allocation.
314        let extra = align_bytes.saturating_sub(1);
315        let total_bytes = elem_bytes
316            .checked_add(extra)
317            .ok_or(CudaError::InvalidValue)?;
318
319        #[cfg(target_os = "macos")]
320        let (raw_ptr, aligned_ptr, offset) = {
321            // On macOS there is no CUDA driver.  Simulate with a synthetic
322            // pointer that mimics typical driver behaviour (256-byte aligned
323            // base).  Tests can exercise the alignment arithmetic.
324            let base: CUdeviceptr = 0x0000_0001_0000_0100; // 256-byte aligned
325            let aligned = round_up_to_alignment(base as usize, align_bytes) as CUdeviceptr;
326            let off = (aligned - base) as usize;
327            (base, aligned, off)
328        };
329
330        #[cfg(not(target_os = "macos"))]
331        let (raw_ptr, aligned_ptr, offset) = {
332            let api = try_driver()?;
333            let mut base: CUdeviceptr = 0;
334            let rc = unsafe { (api.cu_mem_alloc_v2)(&mut base, total_bytes) };
335            oxicuda_driver::check(rc)?;
336            let aligned = round_up_to_alignment(base as usize, align_bytes) as CUdeviceptr;
337            let off = (aligned - base) as usize;
338            (base, aligned, off)
339        };
340
341        Ok(Self {
342            ptr: aligned_ptr,
343            len: n,
344            allocated_bytes: total_bytes,
345            alignment,
346            offset,
347            raw_ptr,
348            _phantom: PhantomData,
349        })
350    }
351
352    /// Returns the aligned device pointer.
353    #[inline]
354    pub fn as_device_ptr(&self) -> CUdeviceptr {
355        self.ptr
356    }
357
358    /// Returns the number of `T` elements in this buffer.
359    #[inline]
360    pub fn len(&self) -> usize {
361        self.len
362    }
363
364    /// Returns `true` if the buffer contains zero elements.
365    ///
366    /// In practice this is always `false` because [`alloc`](Self::alloc)
367    /// rejects zero-length allocations.
368    #[inline]
369    pub fn is_empty(&self) -> bool {
370        self.len == 0
371    }
372
373    /// Returns a reference to the alignment that was requested.
374    #[inline]
375    pub fn alignment(&self) -> &Alignment {
376        &self.alignment
377    }
378
379    /// Returns the number of bytes wasted for alignment padding.
380    ///
381    /// This is the difference between the total allocation size and the
382    /// minimum required (`len * size_of::<T>()`).
383    #[inline]
384    pub fn wasted_bytes(&self) -> usize {
385        let needed = self.len * std::mem::size_of::<T>();
386        self.allocated_bytes.saturating_sub(needed)
387    }
388
389    /// Returns `true` if the buffer's device pointer satisfies the requested
390    /// alignment.
391    #[inline]
392    pub fn is_aligned(&self) -> bool {
393        self.alignment.is_aligned(self.ptr)
394    }
395
396    /// Returns the total number of bytes that were allocated (including
397    /// alignment padding).
398    #[inline]
399    pub fn allocated_bytes(&self) -> usize {
400        self.allocated_bytes
401    }
402
403    /// Returns the byte offset from the raw allocation base to the aligned
404    /// pointer.
405    #[inline]
406    pub fn offset(&self) -> usize {
407        self.offset
408    }
409}
410
411impl<T: Copy> Drop for AlignedBuffer<T> {
412    fn drop(&mut self) {
413        // Free the *raw* (unaligned) allocation, not the offset pointer.
414        #[cfg(not(target_os = "macos"))]
415        {
416            if let Ok(api) = try_driver() {
417                let rc = unsafe { (api.cu_mem_free_v2)(self.raw_ptr) };
418                if rc != 0 {
419                    tracing::warn!(
420                        cuda_error = rc,
421                        ptr = self.raw_ptr,
422                        aligned_ptr = self.ptr,
423                        len = self.len,
424                        "cuMemFree_v2 failed during AlignedBuffer drop"
425                    );
426                }
427            }
428        }
429    }
430}
431
432// ---------------------------------------------------------------------------
433// Tests
434// ---------------------------------------------------------------------------
435
#[cfg(test)]
mod tests {
    //! Unit tests for the alignment helpers.  The `AlignedBuffer` tests only
    //! run on macOS, where allocation uses a synthetic base pointer and needs
    //! no GPU.

    use super::*;

    // -- Alignment enum tests -----------------------------------------------

    #[test]
    fn alignment_bytes_named_variants() {
        assert_eq!(Alignment::Default.bytes(), 256);
        assert_eq!(Alignment::Align256.bytes(), 256);
        assert_eq!(Alignment::Align512.bytes(), 512);
        assert_eq!(Alignment::Align1024.bytes(), 1024);
        assert_eq!(Alignment::Align4096.bytes(), 4096);
    }

    #[test]
    fn alignment_bytes_custom() {
        assert_eq!(Alignment::Custom(64).bytes(), 64);
        assert_eq!(Alignment::Custom(2048).bytes(), 2048);
        // The payload is reported verbatim, even when it is not valid.
        assert_eq!(Alignment::Custom(0).bytes(), 0);
    }

    #[test]
    fn alignment_is_power_of_two() {
        assert!(Alignment::Default.is_power_of_two());
        assert!(Alignment::Align256.is_power_of_two());
        assert!(Alignment::Align512.is_power_of_two());
        assert!(Alignment::Align1024.is_power_of_two());
        assert!(Alignment::Align4096.is_power_of_two());
        assert!(Alignment::Custom(1).is_power_of_two());
        assert!(Alignment::Custom(128).is_power_of_two());
        assert!(!Alignment::Custom(0).is_power_of_two());
        assert!(!Alignment::Custom(3).is_power_of_two());
        assert!(!Alignment::Custom(100).is_power_of_two());
    }

    #[test]
    fn alignment_is_aligned() {
        let a256 = Alignment::Align256;
        assert!(a256.is_aligned(0));
        assert!(a256.is_aligned(256));
        assert!(a256.is_aligned(512));
        assert!(!a256.is_aligned(1));
        assert!(!a256.is_aligned(128));
        assert!(!a256.is_aligned(255));

        let a512 = Alignment::Align512;
        assert!(a512.is_aligned(0));
        assert!(a512.is_aligned(512));
        assert!(!a512.is_aligned(256));

        // A zero-byte alignment matches nothing, not even the null pointer.
        assert!(!Alignment::Custom(0).is_aligned(0));
        assert!(!Alignment::Custom(0).is_aligned(256));
    }

    // -- round_up_to_alignment tests ----------------------------------------

    #[test]
    fn round_up_basic() {
        assert_eq!(round_up_to_alignment(0, 256), 0);
        assert_eq!(round_up_to_alignment(1, 256), 256);
        assert_eq!(round_up_to_alignment(100, 256), 256);
        assert_eq!(round_up_to_alignment(256, 256), 256);
        assert_eq!(round_up_to_alignment(257, 256), 512);
        assert_eq!(round_up_to_alignment(511, 512), 512);
        assert_eq!(round_up_to_alignment(512, 512), 512);
        assert_eq!(round_up_to_alignment(513, 512), 1024);
        // An alignment of one never changes the value.
        assert_eq!(round_up_to_alignment(7, 1), 7);
    }

    #[test]
    fn round_up_zero_alignment() {
        // Zero alignment should not modify the value.
        assert_eq!(round_up_to_alignment(42, 0), 42);
        assert_eq!(round_up_to_alignment(0, 0), 0);
    }

    // -- validate_alignment tests -------------------------------------------

    #[test]
    fn validate_named_variants_ok() {
        assert!(validate_alignment(&Alignment::Default).is_ok());
        assert!(validate_alignment(&Alignment::Align256).is_ok());
        assert!(validate_alignment(&Alignment::Align512).is_ok());
        assert!(validate_alignment(&Alignment::Align1024).is_ok());
        assert!(validate_alignment(&Alignment::Align4096).is_ok());
    }

    #[test]
    fn validate_custom_ok() {
        assert!(validate_alignment(&Alignment::Custom(1)).is_ok());
        assert!(validate_alignment(&Alignment::Custom(64)).is_ok());
        assert!(validate_alignment(&Alignment::Custom(128)).is_ok());
        assert!(validate_alignment(&Alignment::Custom(8192)).is_ok());
        // The 256 MiB limit itself is still accepted.
        assert!(validate_alignment(&Alignment::Custom(256 * 1024 * 1024)).is_ok());
    }

    #[test]
    fn validate_custom_bad() {
        // Zero
        assert!(validate_alignment(&Alignment::Custom(0)).is_err());
        // Not power of two
        assert!(validate_alignment(&Alignment::Custom(3)).is_err());
        assert!(validate_alignment(&Alignment::Custom(100)).is_err());
        // Too large (> 256 MiB)
        assert!(validate_alignment(&Alignment::Custom(512 * 1024 * 1024)).is_err());
    }

    // -- optimal_alignment_for_type tests -----------------------------------

    #[test]
    fn optimal_alignment_small_types() {
        // f32 = 4 bytes → Default
        assert_eq!(optimal_alignment_for_type::<f32>(), Alignment::Default);
        // u8 = 1 byte → Default
        assert_eq!(optimal_alignment_for_type::<u8>(), Alignment::Default);
        // () = 0 bytes → Default
        assert_eq!(optimal_alignment_for_type::<()>(), Alignment::Default);
    }

    #[test]
    fn optimal_alignment_medium_types() {
        // f64 = 8 bytes → Align256
        assert_eq!(optimal_alignment_for_type::<f64>(), Alignment::Align256);
        // u64 = 8 bytes → Align256
        assert_eq!(optimal_alignment_for_type::<u64>(), Alignment::Align256);
        // [u8; 15] = 15 bytes → still Align256 (just below the 16-byte tier)
        assert_eq!(
            optimal_alignment_for_type::<[u8; 15]>(),
            Alignment::Align256
        );
    }

    #[test]
    fn optimal_alignment_large_types() {
        // [f32; 4] = 16 bytes → Align512
        assert_eq!(
            optimal_alignment_for_type::<[f32; 4]>(),
            Alignment::Align512
        );
        // [f64; 4] = 32 bytes → Align512
        assert_eq!(
            optimal_alignment_for_type::<[f64; 4]>(),
            Alignment::Align512
        );
    }

    // -- coalesce_alignment tests -------------------------------------------

    #[test]
    fn coalesce_basic() {
        // 32 threads × 4 bytes = 128
        assert_eq!(coalesce_alignment(4, 32), 128);
        // 32 threads × 8 bytes = 256
        assert_eq!(coalesce_alignment(8, 32), 256);
        // 32 threads × 16 bytes = 512
        assert_eq!(coalesce_alignment(16, 32), 512);
        // 32 threads × 32 bytes = 1024
        assert_eq!(coalesce_alignment(32, 32), 1024);
        // 32 threads × 3 bytes = 96 → rounded up to 128
        assert_eq!(coalesce_alignment(3, 32), 128);
    }

    #[test]
    fn coalesce_caps_at_page() {
        // 32 threads × 128 bytes = 4096 → exactly the cap
        assert_eq!(coalesce_alignment(128, 32), 4096);
        // 64 threads × 128 bytes = 8192 → capped at 4096
        assert_eq!(coalesce_alignment(128, 64), 4096);
    }

    #[test]
    fn coalesce_zero_inputs() {
        assert_eq!(coalesce_alignment(0, 32), 1);
        assert_eq!(coalesce_alignment(4, 0), 1);
        assert_eq!(coalesce_alignment(0, 0), 1);
    }

    // -- check_alignment tests ----------------------------------------------

    #[test]
    fn check_alignment_page_aligned() {
        let info = check_alignment(4096);
        assert!(info.is_256_aligned);
        assert!(info.is_512_aligned);
        assert!(info.is_page_aligned);
        // 4096 has exactly twelve trailing zero bits.
        assert_eq!(info.natural_alignment, 4096);
        assert_eq!(info.ptr, 4096);
    }

    #[test]
    fn check_alignment_512_not_page() {
        let info = check_alignment(512);
        assert!(info.is_256_aligned);
        assert!(info.is_512_aligned);
        assert!(!info.is_page_aligned);
        assert_eq!(info.natural_alignment, 512);
    }

    #[test]
    fn check_alignment_odd_ptr() {
        let info = check_alignment(0x0001_0001);
        assert!(!info.is_256_aligned);
        assert!(!info.is_512_aligned);
        assert!(!info.is_page_aligned);
        assert_eq!(info.natural_alignment, 1);
    }

    #[test]
    fn check_alignment_null() {
        let info = check_alignment(0);
        assert_eq!(info.natural_alignment, usize::MAX);
        assert!(info.is_256_aligned);
        assert!(info.is_512_aligned);
        assert!(info.is_page_aligned);
    }

    // -- AlignedBuffer tests (macOS synthetic) ------------------------------

    #[cfg(target_os = "macos")]
    mod buffer_tests {
        use super::super::*;

        /// Allocates or fails the test with a fixed message.
        fn must_alloc<T: Copy>(n: usize, alignment: Alignment) -> AlignedBuffer<T> {
            AlignedBuffer::<T>::alloc(n, alignment).unwrap_or_else(|_| panic!("alloc failed"))
        }

        #[test]
        fn alloc_default_alignment() {
            let buf = must_alloc::<f32>(128, Alignment::Default);
            assert_eq!(buf.len(), 128);
            assert!(!buf.is_empty());
            assert!(buf.is_aligned());
            assert!(buf.offset() < buf.alignment().bytes());
        }

        #[test]
        fn alloc_512_alignment() {
            let buf = must_alloc::<f32>(256, Alignment::Align512);
            assert!(buf.is_aligned());
            assert_eq!(buf.as_device_ptr() % 512, 0);
            assert!(buf.offset() < 512);
        }

        #[test]
        fn alloc_4096_alignment() {
            let buf = must_alloc::<f64>(64, Alignment::Align4096);
            assert!(buf.is_aligned());
            assert_eq!(buf.as_device_ptr() % 4096, 0);
            assert!(buf.offset() < 4096);
        }

        #[test]
        fn alloc_zero_elements_fails() {
            let result = AlignedBuffer::<f32>::alloc(0, Alignment::Default);
            assert!(result.is_err());
        }

        #[test]
        fn alloc_invalid_alignment_fails() {
            assert!(AlignedBuffer::<f32>::alloc(64, Alignment::Custom(3)).is_err());
            assert!(AlignedBuffer::<f32>::alloc(64, Alignment::Custom(0)).is_err());
        }

        #[test]
        fn alloc_size_overflow_fails() {
            // usize::MAX elements of 4 bytes cannot be expressed in bytes.
            let result = AlignedBuffer::<f32>::alloc(usize::MAX, Alignment::Default);
            assert!(result.is_err());
        }

        #[test]
        fn wasted_bytes_at_least_zero() {
            let buf = must_alloc::<f32>(128, Alignment::Align512);
            // allocated_bytes = 128 * 4 + (512 - 1) = 1023
            assert_eq!(buf.allocated_bytes(), 1023);
            // wasted = 1023 - 512 = 511, i.e. exactly alignment - 1
            assert_eq!(buf.wasted_bytes(), 511);
        }

        #[test]
        fn alignment_accessor() {
            let buf = must_alloc::<u8>(64, Alignment::Align1024);
            assert_eq!(*buf.alignment(), Alignment::Align1024);
        }
    }

    // -- Display -----------------------------------------------------------

    #[test]
    fn alignment_display() {
        assert_eq!(format!("{}", Alignment::Default), "Default(256)");
        assert_eq!(format!("{}", Alignment::Align256), "256");
        assert_eq!(format!("{}", Alignment::Align512), "512");
        assert_eq!(format!("{}", Alignment::Align1024), "1024");
        assert_eq!(format!("{}", Alignment::Align4096), "4096");
        assert_eq!(format!("{}", Alignment::Custom(128)), "Custom(128)");
    }
}