1use std::marker::PhantomData;
31
32use oxicuda_driver::error::{CudaError, CudaResult};
33use oxicuda_driver::ffi::CUdeviceptr;
34#[cfg(not(target_os = "macos"))]
35use oxicuda_driver::loader::try_driver;
36
/// Requested start-address alignment for a device allocation.
///
/// The named variants cover the common sizes; `Custom` carries an arbitrary
/// byte count, which must pass `validate_alignment` before it is used for an
/// allocation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Alignment {
    /// Library default; resolves to 256 bytes.
    Default,
    /// 256-byte alignment.
    Align256,
    /// 512-byte alignment.
    Align512,
    /// 1024-byte alignment.
    Align1024,
    /// 4096-byte alignment.
    Align4096,
    /// Caller-chosen alignment in bytes (not validated by construction).
    Custom(usize),
}
60
61impl Alignment {
62 #[inline]
67 pub fn bytes(&self) -> usize {
68 match self {
69 Self::Default => 256,
70 Self::Align256 => 256,
71 Self::Align512 => 512,
72 Self::Align1024 => 1024,
73 Self::Align4096 => 4096,
74 Self::Custom(n) => *n,
75 }
76 }
77
78 #[inline]
83 pub fn is_power_of_two(&self) -> bool {
84 let b = self.bytes();
85 b > 0 && (b & (b - 1)) == 0
86 }
87
88 #[inline]
90 pub fn is_aligned(&self, ptr: u64) -> bool {
91 let b = self.bytes() as u64;
92 if b == 0 {
93 return false;
94 }
95 (ptr % b) == 0
96 }
97}
98
99impl std::fmt::Display for Alignment {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 match self {
102 Self::Default => write!(f, "Default(256)"),
103 Self::Align256 => write!(f, "256"),
104 Self::Align512 => write!(f, "512"),
105 Self::Align1024 => write!(f, "1024"),
106 Self::Align4096 => write!(f, "4096"),
107 Self::Custom(n) => write!(f, "Custom({n})"),
108 }
109 }
110}
111
112const MAX_ALIGNMENT: usize = 256 * 1024 * 1024;
119
120pub fn validate_alignment(alignment: &Alignment) -> CudaResult<()> {
130 let b = alignment.bytes();
131 if b == 0 {
132 return Err(CudaError::InvalidValue);
133 }
134 if !alignment.is_power_of_two() {
135 return Err(CudaError::InvalidValue);
136 }
137 if b > MAX_ALIGNMENT {
138 return Err(CudaError::InvalidValue);
139 }
140 Ok(())
141}
142
/// Rounds `bytes` up to the nearest multiple of `alignment`.
///
/// An `alignment` of `0` means "no rounding" and returns `bytes` unchanged.
/// Power-of-two alignments use a mask-based fast path. Any other non-zero
/// alignment is handled with a remainder, so the result is always a true
/// multiple of `alignment`; the plain mask trick is only valid for powers of two.
///
/// # Panics
///
/// Panics if the rounded value does not fit in a `usize`, i.e. `bytes` is not
/// already aligned and lies within `alignment - 1` of `usize::MAX`. Previously
/// that case panicked in debug builds but silently wrapped to a small, wrong
/// value in release builds.
#[inline]
pub fn round_up_to_alignment(bytes: usize, alignment: usize) -> usize {
    if alignment == 0 {
        return bytes;
    }
    let rounded = if alignment.is_power_of_two() {
        let mask = alignment - 1;
        bytes.checked_add(mask).map(|v| v & !mask)
    } else {
        match bytes % alignment {
            0 => Some(bytes),
            rem => bytes.checked_add(alignment - rem),
        }
    };
    rounded.expect("round_up_to_alignment: rounded value overflows usize")
}
164
165pub fn optimal_alignment_for_type<T>() -> Alignment {
174 let size = std::mem::size_of::<T>();
175 if size >= 16 {
176 Alignment::Align512
177 } else if size >= 8 {
178 Alignment::Align256
179 } else {
180 Alignment::Default
181 }
182}
183
/// Suggests an alignment (in bytes) covering one warp-wide access.
///
/// The footprint `warp_size * access_width` is rounded up to the next power of
/// two and capped at 4096 bytes. Returns `1` when either input is zero.
///
/// Never panics: even a footprint that saturates `usize` yields the 4096 cap.
pub fn coalesce_alignment(access_width: usize, warp_size: u32) -> usize {
    /// Upper bound on the suggested alignment.
    const MAX_COALESCE_ALIGNMENT: usize = 4096;

    let footprint = (warp_size as usize).saturating_mul(access_width);
    if footprint == 0 {
        return 1;
    }
    // Clamp *before* rounding. Because the cap is a power of two,
    // `min(x, cap).next_power_of_two() == min(x.next_power_of_two(), cap)`,
    // but the clamped form cannot overflow. Rounding first would panic in debug
    // builds, or return 0 in release builds, once `footprint` exceeds the
    // largest power of two a `usize` can hold.
    footprint.min(MAX_COALESCE_ALIGNMENT).next_power_of_two()
}
211
212#[derive(Debug, Clone, Copy, PartialEq, Eq)]
218pub struct AlignmentInfo {
219 pub ptr: CUdeviceptr,
221 pub natural_alignment: usize,
223 pub is_256_aligned: bool,
225 pub is_512_aligned: bool,
227 pub is_page_aligned: bool,
229}
230
231pub fn check_alignment(ptr: CUdeviceptr) -> AlignmentInfo {
236 let natural = if ptr == 0 {
237 usize::MAX
239 } else {
240 1_usize << (ptr.trailing_zeros().min(63))
242 };
243 AlignmentInfo {
244 ptr,
245 natural_alignment: natural,
246 is_256_aligned: (ptr % 256) == 0,
247 is_512_aligned: (ptr % 512) == 0,
248 is_page_aligned: (ptr % 4096) == 0,
249 }
250}
251
252pub struct AlignedBuffer<T: Copy> {
265 ptr: CUdeviceptr,
267 len: usize,
269 allocated_bytes: usize,
271 alignment: Alignment,
273 offset: usize,
275 #[cfg_attr(target_os = "macos", allow(dead_code))]
277 raw_ptr: CUdeviceptr,
278 _phantom: PhantomData<T>,
280}
281
282unsafe impl<T: Copy + Send> Send for AlignedBuffer<T> {}
285unsafe impl<T: Copy + Sync> Sync for AlignedBuffer<T> {}
286
287impl<T: Copy> AlignedBuffer<T> {
288 pub fn alloc(n: usize, alignment: Alignment) -> CudaResult<Self> {
301 if n == 0 {
302 return Err(CudaError::InvalidValue);
303 }
304 validate_alignment(&alignment)?;
305
306 let elem_bytes = n
307 .checked_mul(std::mem::size_of::<T>())
308 .ok_or(CudaError::InvalidValue)?;
309
310 let align_bytes = alignment.bytes();
311
312 let extra = align_bytes.saturating_sub(1);
315 let total_bytes = elem_bytes
316 .checked_add(extra)
317 .ok_or(CudaError::InvalidValue)?;
318
319 #[cfg(target_os = "macos")]
320 let (raw_ptr, aligned_ptr, offset) = {
321 let base: CUdeviceptr = 0x0000_0001_0000_0100; let aligned = round_up_to_alignment(base as usize, align_bytes) as CUdeviceptr;
326 let off = (aligned - base) as usize;
327 (base, aligned, off)
328 };
329
330 #[cfg(not(target_os = "macos"))]
331 let (raw_ptr, aligned_ptr, offset) = {
332 let api = try_driver()?;
333 let mut base: CUdeviceptr = 0;
334 let rc = unsafe { (api.cu_mem_alloc_v2)(&mut base, total_bytes) };
335 oxicuda_driver::check(rc)?;
336 let aligned = round_up_to_alignment(base as usize, align_bytes) as CUdeviceptr;
337 let off = (aligned - base) as usize;
338 (base, aligned, off)
339 };
340
341 Ok(Self {
342 ptr: aligned_ptr,
343 len: n,
344 allocated_bytes: total_bytes,
345 alignment,
346 offset,
347 raw_ptr,
348 _phantom: PhantomData,
349 })
350 }
351
352 #[inline]
354 pub fn as_device_ptr(&self) -> CUdeviceptr {
355 self.ptr
356 }
357
358 #[inline]
360 pub fn len(&self) -> usize {
361 self.len
362 }
363
364 #[inline]
369 pub fn is_empty(&self) -> bool {
370 self.len == 0
371 }
372
373 #[inline]
375 pub fn alignment(&self) -> &Alignment {
376 &self.alignment
377 }
378
379 #[inline]
384 pub fn wasted_bytes(&self) -> usize {
385 let needed = self.len * std::mem::size_of::<T>();
386 self.allocated_bytes.saturating_sub(needed)
387 }
388
389 #[inline]
392 pub fn is_aligned(&self) -> bool {
393 self.alignment.is_aligned(self.ptr)
394 }
395
396 #[inline]
399 pub fn allocated_bytes(&self) -> usize {
400 self.allocated_bytes
401 }
402
403 #[inline]
406 pub fn offset(&self) -> usize {
407 self.offset
408 }
409}
410
411impl<T: Copy> Drop for AlignedBuffer<T> {
412 fn drop(&mut self) {
413 #[cfg(not(target_os = "macos"))]
415 {
416 if let Ok(api) = try_driver() {
417 let rc = unsafe { (api.cu_mem_free_v2)(self.raw_ptr) };
418 if rc != 0 {
419 tracing::warn!(
420 cuda_error = rc,
421 ptr = self.raw_ptr,
422 aligned_ptr = self.ptr,
423 len = self.len,
424 "cuMemFree_v2 failed during AlignedBuffer drop"
425 );
426 }
427 }
428 }
429 }
430}
431
// Unit tests for the pure alignment helpers. The `AlignedBuffer` tests live in
// a macOS-only submodule, where `alloc` uses a placeholder base address instead
// of calling the CUDA driver.
#[cfg(test)]
mod tests {
    use super::*;

    // Named variants map to fixed byte counts; `Default` resolves to 256.
    #[test]
    fn alignment_bytes_named_variants() {
        assert_eq!(Alignment::Default.bytes(), 256);
        assert_eq!(Alignment::Align256.bytes(), 256);
        assert_eq!(Alignment::Align512.bytes(), 512);
        assert_eq!(Alignment::Align1024.bytes(), 1024);
        assert_eq!(Alignment::Align4096.bytes(), 4096);
    }

    // `Custom(n)` reports `n` unchanged.
    #[test]
    fn alignment_bytes_custom() {
        assert_eq!(Alignment::Custom(64).bytes(), 64);
        assert_eq!(Alignment::Custom(2048).bytes(), 2048);
    }

    // Zero and non-powers-of-two (3, 100) must be rejected.
    #[test]
    fn alignment_is_power_of_two() {
        assert!(Alignment::Default.is_power_of_two());
        assert!(Alignment::Align256.is_power_of_two());
        assert!(Alignment::Align512.is_power_of_two());
        assert!(Alignment::Align1024.is_power_of_two());
        assert!(Alignment::Align4096.is_power_of_two());
        assert!(Alignment::Custom(128).is_power_of_two());
        assert!(!Alignment::Custom(0).is_power_of_two());
        assert!(!Alignment::Custom(3).is_power_of_two());
        assert!(!Alignment::Custom(100).is_power_of_two());
    }

    // Address 0 counts as aligned; smaller alignments do not imply larger ones.
    #[test]
    fn alignment_is_aligned() {
        let a256 = Alignment::Align256;
        assert!(a256.is_aligned(0));
        assert!(a256.is_aligned(256));
        assert!(a256.is_aligned(512));
        assert!(!a256.is_aligned(1));
        assert!(!a256.is_aligned(128));
        assert!(!a256.is_aligned(255));

        let a512 = Alignment::Align512;
        assert!(a512.is_aligned(0));
        assert!(a512.is_aligned(512));
        assert!(!a512.is_aligned(256));
    }

    // Values already on a boundary stay put; anything past one moves to the next.
    #[test]
    fn round_up_basic() {
        assert_eq!(round_up_to_alignment(0, 256), 0);
        assert_eq!(round_up_to_alignment(1, 256), 256);
        assert_eq!(round_up_to_alignment(100, 256), 256);
        assert_eq!(round_up_to_alignment(256, 256), 256);
        assert_eq!(round_up_to_alignment(257, 256), 512);
        assert_eq!(round_up_to_alignment(511, 512), 512);
        assert_eq!(round_up_to_alignment(512, 512), 512);
        assert_eq!(round_up_to_alignment(513, 512), 1024);
    }

    // An alignment of 0 is treated as "no rounding".
    #[test]
    fn round_up_zero_alignment() {
        assert_eq!(round_up_to_alignment(42, 0), 42);
    }

    // Every named variant is a valid allocation alignment.
    #[test]
    fn validate_named_variants_ok() {
        assert!(validate_alignment(&Alignment::Default).is_ok());
        assert!(validate_alignment(&Alignment::Align256).is_ok());
        assert!(validate_alignment(&Alignment::Align512).is_ok());
        assert!(validate_alignment(&Alignment::Align1024).is_ok());
        assert!(validate_alignment(&Alignment::Align4096).is_ok());
    }

    // Custom powers of two below the ceiling are accepted, including ones
    // smaller than 256.
    #[test]
    fn validate_custom_ok() {
        assert!(validate_alignment(&Alignment::Custom(64)).is_ok());
        assert!(validate_alignment(&Alignment::Custom(128)).is_ok());
        assert!(validate_alignment(&Alignment::Custom(8192)).is_ok());
    }

    // Zero, non-powers-of-two, and anything above the 256 MiB ceiling are rejected.
    #[test]
    fn validate_custom_bad() {
        assert!(validate_alignment(&Alignment::Custom(0)).is_err());
        assert!(validate_alignment(&Alignment::Custom(3)).is_err());
        assert!(validate_alignment(&Alignment::Custom(100)).is_err());
        // 512 MiB: a power of two, but over the ceiling.
        assert!(validate_alignment(&Alignment::Custom(512 * 1024 * 1024)).is_err());
    }

    // Types smaller than 8 bytes get the default alignment.
    #[test]
    fn optimal_alignment_small_types() {
        assert_eq!(optimal_alignment_for_type::<f32>(), Alignment::Default);
        assert_eq!(optimal_alignment_for_type::<u8>(), Alignment::Default);
    }

    // 8-byte types get 256-byte alignment.
    #[test]
    fn optimal_alignment_medium_types() {
        assert_eq!(optimal_alignment_for_type::<f64>(), Alignment::Align256);
        assert_eq!(optimal_alignment_for_type::<u64>(), Alignment::Align256);
    }

    // Types of 16 bytes or more (here 16 and 32) get 512-byte alignment.
    #[test]
    fn optimal_alignment_large_types() {
        assert_eq!(
            optimal_alignment_for_type::<[f32; 4]>(),
            Alignment::Align512
        );
        assert_eq!(
            optimal_alignment_for_type::<[f64; 4]>(),
            Alignment::Align512
        );
    }

    // 32-lane warp times per-lane width, already powers of two.
    #[test]
    fn coalesce_basic() {
        assert_eq!(coalesce_alignment(4, 32), 128);
        assert_eq!(coalesce_alignment(8, 32), 256);
        assert_eq!(coalesce_alignment(16, 32), 512);
        assert_eq!(coalesce_alignment(32, 32), 1024);
    }

    // 128 * 64 = 8192, which is capped at 4096.
    #[test]
    fn coalesce_caps_at_page() {
        assert_eq!(coalesce_alignment(128, 64), 4096);
    }

    // A zero footprint falls back to an alignment of 1.
    #[test]
    fn coalesce_zero_inputs() {
        assert_eq!(coalesce_alignment(0, 32), 1);
        assert_eq!(coalesce_alignment(4, 0), 1);
        assert_eq!(coalesce_alignment(0, 0), 1);
    }

    // 4096 satisfies every flag.
    #[test]
    fn check_alignment_page_aligned() {
        let info = check_alignment(4096);
        assert!(info.is_256_aligned);
        assert!(info.is_512_aligned);
        assert!(info.is_page_aligned);
        assert!(info.natural_alignment >= 4096);
    }

    // 512 is 256- and 512-aligned but not page-aligned.
    #[test]
    fn check_alignment_512_not_page() {
        let info = check_alignment(512);
        assert!(info.is_256_aligned);
        assert!(info.is_512_aligned);
        assert!(!info.is_page_aligned);
        assert_eq!(info.natural_alignment, 512);
    }

    // An odd address has natural alignment 1 and no flags set.
    #[test]
    fn check_alignment_odd_ptr() {
        let info = check_alignment(0x0001_0001);
        assert!(!info.is_256_aligned);
        assert!(!info.is_512_aligned);
        assert!(!info.is_page_aligned);
        assert_eq!(info.natural_alignment, 1);
    }

    // Null is reported with the `usize::MAX` natural-alignment sentinel and all flags set.
    #[test]
    fn check_alignment_null() {
        let info = check_alignment(0);
        assert_eq!(info.natural_alignment, usize::MAX);
        assert!(info.is_256_aligned);
        assert!(info.is_512_aligned);
        assert!(info.is_page_aligned);
    }

    // `AlignedBuffer` tests only run on macOS, where `alloc` uses a placeholder
    // base address instead of the CUDA driver.
    #[cfg(target_os = "macos")]
    mod buffer_tests {
        // Reach the parent module's items (two levels up: tests -> module).
        use super::super::*;

        #[test]
        fn alloc_default_alignment() {
            let buf = AlignedBuffer::<f32>::alloc(128, Alignment::Default);
            assert!(buf.is_ok());
            let buf = buf.unwrap_or_else(|_| panic!("alloc failed"));
            assert_eq!(buf.len(), 128);
            assert!(!buf.is_empty());
            assert!(buf.is_aligned());
        }

        #[test]
        fn alloc_512_alignment() {
            let buf = AlignedBuffer::<f32>::alloc(256, Alignment::Align512);
            assert!(buf.is_ok());
            let buf = buf.unwrap_or_else(|_| panic!("alloc failed"));
            assert!(buf.is_aligned());
            assert_eq!(buf.as_device_ptr() % 512, 0);
        }

        // The placeholder base is only 256-aligned, so this exercises a non-zero offset.
        #[test]
        fn alloc_4096_alignment() {
            let buf = AlignedBuffer::<f64>::alloc(64, Alignment::Align4096);
            assert!(buf.is_ok());
            let buf = buf.unwrap_or_else(|_| panic!("alloc failed"));
            assert!(buf.is_aligned());
            assert_eq!(buf.as_device_ptr() % 4096, 0);
        }

        #[test]
        fn alloc_zero_elements_fails() {
            let result = AlignedBuffer::<f32>::alloc(0, Alignment::Default);
            assert!(result.is_err());
        }

        #[test]
        fn alloc_invalid_alignment_fails() {
            let result = AlignedBuffer::<f32>::alloc(64, Alignment::Custom(3));
            assert!(result.is_err());
        }

        // Padding is `alignment - 1` bytes, so it never exceeds the alignment itself.
        #[test]
        fn wasted_bytes_at_least_zero() {
            let buf = AlignedBuffer::<f32>::alloc(128, Alignment::Align512)
                .unwrap_or_else(|_| panic!("alloc failed"));
            assert!(buf.wasted_bytes() <= buf.alignment().bytes());
        }

        #[test]
        fn alignment_accessor() {
            let buf = AlignedBuffer::<u8>::alloc(64, Alignment::Align1024)
                .unwrap_or_else(|_| panic!("alloc failed"));
            assert_eq!(*buf.alignment(), Alignment::Align1024);
        }
    }

    // `Display` shows named variants as byte counts and custom ones as `Custom(n)`.
    #[test]
    fn alignment_display() {
        assert_eq!(format!("{}", Alignment::Default), "Default(256)");
        assert_eq!(format!("{}", Alignment::Align256), "256");
        assert_eq!(format!("{}", Alignment::Align512), "512");
        assert_eq!(format!("{}", Alignment::Custom(128)), "Custom(128)");
    }
}