use std::marker::PhantomData;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::ffi::CUdeviceptr;
#[cfg(not(target_os = "macos"))]
use oxicuda_driver::loader::try_driver;
/// Byte alignment requested for a device allocation.
///
/// The named variants map to fixed byte counts. [`Alignment::Custom`] carries
/// a caller-supplied count that is *not* checked on construction, so it may be
/// zero or a non-power-of-two.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Alignment {
    /// Default alignment; reports 256 bytes.
    Default,
    /// 256-byte alignment.
    Align256,
    /// 512-byte alignment.
    Align512,
    /// 1024-byte alignment.
    Align1024,
    /// 4096-byte alignment.
    Align4096,
    /// Caller-chosen alignment in bytes.
    Custom(usize),
}

impl Alignment {
    /// Returns the alignment expressed in bytes.
    #[inline]
    pub fn bytes(&self) -> usize {
        match *self {
            Self::Default | Self::Align256 => 256,
            Self::Align512 => 512,
            Self::Align1024 => 1024,
            Self::Align4096 => 4096,
            Self::Custom(requested) => requested,
        }
    }

    /// Returns `true` when the byte count is a non-zero power of two.
    #[inline]
    pub fn is_power_of_two(&self) -> bool {
        // The integer method is already `false` for zero.
        self.bytes().is_power_of_two()
    }

    /// Returns `true` when `ptr` is a multiple of this alignment.
    ///
    /// A zero-byte alignment never matches, not even for `ptr == 0`.
    #[inline]
    pub fn is_aligned(&self, ptr: u64) -> bool {
        match self.bytes() as u64 {
            0 => false,
            step => ptr % step == 0,
        }
    }
}

impl std::fmt::Display for Alignment {
    /// Writes the byte count for the fixed variants, `Default(256)` for the
    /// default variant and `Custom(n)` for custom values.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Custom(n) => return write!(f, "Custom({n})"),
            Self::Default => "Default(256)",
            Self::Align256 => "256",
            Self::Align512 => "512",
            Self::Align1024 => "1024",
            Self::Align4096 => "4096",
        };
        f.write_str(label)
    }
}
/// Largest alignment, in bytes, that `validate_alignment` accepts (256 MiB).
const MAX_ALIGNMENT: usize = 256 * 1024 * 1024;
/// Checks that `alignment` is usable for an allocation.
///
/// # Errors
///
/// Returns [`CudaError::InvalidValue`] when the byte count is zero, is not a
/// power of two, or is larger than `MAX_ALIGNMENT`.
pub fn validate_alignment(alignment: &Alignment) -> CudaResult<()> {
    // `is_power_of_two` is already false for zero, so a single predicate
    // covers the zero and the non-power-of-two rejections; every failure maps
    // to the same error value.
    let acceptable = alignment.is_power_of_two() && alignment.bytes() <= MAX_ALIGNMENT;
    if acceptable {
        Ok(())
    } else {
        Err(CudaError::InvalidValue)
    }
}
/// Rounds `bytes` up to the next multiple of `alignment`.
///
/// * `alignment == 0` is treated as "no alignment" and returns `bytes`
///   unchanged.
/// * Any non-zero `alignment` is supported, not only powers of two.
/// * If the next multiple does not fit in a `usize`, the result saturates at
///   `usize::MAX` instead of wrapping around to a small value.
#[inline]
pub fn round_up_to_alignment(bytes: usize, alignment: usize) -> usize {
    if alignment == 0 {
        return bytes;
    }
    let remainder = bytes % alignment;
    if remainder == 0 {
        return bytes;
    }
    // `alignment - remainder` is in `1..alignment`, so only the final addition
    // can exceed the range of `usize`.
    bytes.saturating_add(alignment - remainder)
}
/// Picks an alignment for buffers of `T` based solely on `size_of::<T>()`.
///
/// * fewer than 8 bytes: [`Alignment::Default`]
/// * 8 to 15 bytes: [`Alignment::Align256`]
/// * 16 bytes or more: [`Alignment::Align512`]
pub fn optimal_alignment_for_type<T>() -> Alignment {
    match std::mem::size_of::<T>() {
        0..=7 => Alignment::Default,
        8..=15 => Alignment::Align256,
        _ => Alignment::Align512,
    }
}
/// Computes an alignment from the bytes touched by one warp-wide access.
///
/// The result is `warp_size * access_width` rounded up to a power of two and
/// capped at 4096. If either input is zero the result is `1`.
///
/// Very large inputs are handled safely: the product saturates and is clamped
/// to the cap *before* rounding, so the power-of-two rounding can never
/// overflow.
pub fn coalesce_alignment(access_width: usize, warp_size: u32) -> usize {
    /// Upper bound on the returned alignment.
    const CAP: usize = 4096;

    let total = (warp_size as usize).saturating_mul(access_width);
    if total == 0 {
        return 1;
    }
    // `CAP` is itself a power of two, so clamping first yields the same value
    // as rounding first for every input whose rounding does not overflow.
    total.min(CAP).next_power_of_two()
}
/// Alignment facts about a single device address, as produced by
/// `check_alignment`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AlignmentInfo {
/// The address that was inspected.
pub ptr: CUdeviceptr,
/// `1 << trailing_zeros(ptr)`, i.e. the largest power of two dividing the
/// address; `check_alignment` stores `usize::MAX` for a zero address.
pub natural_alignment: usize,
/// `true` when the address is a multiple of 256.
pub is_256_aligned: bool,
/// `true` when the address is a multiple of 512.
pub is_512_aligned: bool,
/// `true` when the address is a multiple of 4096.
pub is_page_aligned: bool,
}
/// Reports how well `ptr` is aligned.
///
/// `natural_alignment` is the largest power of two dividing `ptr` that fits
/// in a `usize`. A zero address is a multiple of every alignment, so it is
/// reported with the sentinel `usize::MAX`. The three boolean flags state
/// whether `ptr` is a multiple of 256, 512 and 4096 respectively.
pub fn check_alignment(ptr: CUdeviceptr) -> AlignmentInfo {
    let natural_alignment = if ptr == 0 {
        usize::MAX
    } else {
        // Cap the shift at the width of `usize`: shifting a `usize` by its
        // full bit width or more is an overflow (panic in debug builds). On
        // 64-bit targets the cap is 63, exactly as before.
        let shift = ptr.trailing_zeros().min(usize::BITS - 1);
        1_usize << shift
    };

    AlignmentInfo {
        ptr,
        natural_alignment,
        is_256_aligned: ptr % 256 == 0,
        is_512_aligned: ptr % 512 == 0,
        is_page_aligned: ptr % 4096 == 0,
    }
}
/// A device allocation for `len` elements of `T` whose usable start address
/// satisfies a requested [`Alignment`].
///
/// `alloc` over-allocates by `alignment - 1` bytes and rounds the base
/// address up; both the rounded and the original address are kept.
pub struct AlignedBuffer<T: Copy> {
/// Aligned device address handed out by `as_device_ptr`.
ptr: CUdeviceptr,
/// Number of `T` elements requested at allocation time.
len: usize,
/// Total bytes requested from the allocator, padding included.
allocated_bytes: usize,
/// Alignment the buffer was created with.
alignment: Alignment,
/// Distance in bytes from `raw_ptr` to `ptr`.
offset: usize,
// Only the non-macOS `Drop` path reads this field, hence the conditional
// dead-code allowance.
/// Unaligned base address of the allocation; this is what `Drop` frees.
#[cfg_attr(target_os = "macos", allow(dead_code))]
raw_ptr: CUdeviceptr,
/// Ties the buffer to its element type without storing a `T`.
_phantom: PhantomData<T>,
}
// SAFETY: the buffer holds only device addresses (used as plain integers in
// this file), sizes, an `Alignment` value and `PhantomData<T>`; it contains
// no host references. Moving it to another thread moves only those values.
// NOTE(review): assumes the driver permits freeing the allocation from a
// thread other than the one that created it — confirm against driver docs.
unsafe impl<T: Copy + Send> Send for AlignedBuffer<T> {}
// SAFETY: every method taking `&self` in this file only reads and returns
// copies of the stored fields; none mutates state or dereferences the device
// address on the host, so shared references can be used concurrently.
unsafe impl<T: Copy + Sync> Sync for AlignedBuffer<T> {}
impl<T: Copy> AlignedBuffer<T> {
    /// Allocates room for `n` elements of `T` starting at an address that is
    /// a multiple of `alignment`.
    ///
    /// The request is padded by `alignment - 1` bytes so that an aligned
    /// address is guaranteed to exist inside the allocation; the base address
    /// is then rounded up to that alignment.
    ///
    /// On macOS no driver call is made: a fixed placeholder base address is
    /// used instead, and nothing is freed on drop.
    ///
    /// # Errors
    ///
    /// * [`CudaError::InvalidValue`] if `n` is zero, if `alignment` fails
    ///   `validate_alignment`, or if the byte size overflows `usize`.
    /// * Any error from `try_driver()` or from the allocation call itself
    ///   (non-macOS only).
    pub fn alloc(n: usize, alignment: Alignment) -> CudaResult<Self> {
        if n == 0 {
            return Err(CudaError::InvalidValue);
        }
        validate_alignment(&alignment)?;

        let elem_bytes = n
            .checked_mul(std::mem::size_of::<T>())
            .ok_or(CudaError::InvalidValue)?;
        let align_bytes = alignment.bytes();
        // Worst case the base address is one byte past an aligned boundary,
        // so `align - 1` bytes of padding always suffice.
        let total_bytes = elem_bytes
            .checked_add(align_bytes.saturating_sub(1))
            .ok_or(CudaError::InvalidValue)?;

        // Placeholder base address; no device memory is allocated on macOS.
        #[cfg(target_os = "macos")]
        let raw_ptr: CUdeviceptr = 0x0000_0001_0000_0100;

        #[cfg(not(target_os = "macos"))]
        let raw_ptr: CUdeviceptr = {
            let api = try_driver()?;
            let mut base: CUdeviceptr = 0;
            // SAFETY: `&mut base` points to a live, exclusively borrowed local
            // for the duration of the call, and `api` comes from a successful
            // `try_driver()`.
            // NOTE(review): assumes this entry follows the
            // cuMemAlloc_v2(dptr, bytesize) contract — confirm against the
            // ffi declarations.
            let rc = unsafe { (api.cu_mem_alloc_v2)(&mut base, total_bytes) };
            oxicuda_driver::check(rc)?;
            base
        };

        let (ptr, offset) = Self::align_device_ptr(raw_ptr, align_bytes);

        Ok(Self {
            ptr,
            len: n,
            allocated_bytes: total_bytes,
            alignment,
            offset,
            raw_ptr,
            _phantom: PhantomData,
        })
    }

    /// Rounds `base` up to a multiple of `align_bytes` and returns the aligned
    /// address together with its byte distance from `base`.
    ///
    /// `align_bytes` must be a non-zero power of two (guaranteed by
    /// `validate_alignment` in `alloc`). The arithmetic is done in the width
    /// of `CUdeviceptr`, so the address is never narrowed through `usize`.
    #[inline]
    fn align_device_ptr(base: CUdeviceptr, align_bytes: usize) -> (CUdeviceptr, usize) {
        let mask = align_bytes as CUdeviceptr - 1;
        let aligned = (base + mask) & !mask;
        // The distance is strictly less than `align_bytes`, so it fits.
        (aligned, (aligned - base) as usize)
    }

    /// Returns the aligned device address of the first element.
    #[inline]
    pub fn as_device_ptr(&self) -> CUdeviceptr {
        self.ptr
    }

    /// Returns the number of `T` elements the buffer was allocated for.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the buffer holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the alignment the buffer was created with.
    #[inline]
    pub fn alignment(&self) -> &Alignment {
        &self.alignment
    }

    /// Returns how many allocated bytes are not needed for element storage,
    /// i.e. `allocated_bytes - len * size_of::<T>()`.
    #[inline]
    pub fn wasted_bytes(&self) -> usize {
        // The product was overflow-checked in `alloc`.
        let needed = self.len * std::mem::size_of::<T>();
        self.allocated_bytes.saturating_sub(needed)
    }

    /// Returns `true` when the device address satisfies the stored alignment.
    #[inline]
    pub fn is_aligned(&self) -> bool {
        self.alignment.is_aligned(self.ptr)
    }

    /// Returns the total number of bytes requested from the allocator,
    /// including alignment padding.
    #[inline]
    pub fn allocated_bytes(&self) -> usize {
        self.allocated_bytes
    }

    /// Returns the byte distance from the raw base address to the aligned
    /// address.
    #[inline]
    pub fn offset(&self) -> usize {
        self.offset
    }
}
impl<T: Copy> Drop for AlignedBuffer<T> {
    /// Frees the raw (unaligned) base address through the driver.
    ///
    /// Nothing happens on macOS. Elsewhere, an unavailable driver is ignored
    /// and a failing free is logged as a warning rather than panicking.
    fn drop(&mut self) {
        #[cfg(not(target_os = "macos"))]
        {
            let api = match try_driver() {
                Ok(api) => api,
                // Without a driver there is nothing that could be freed.
                Err(_) => return,
            };
            // The free must use the base address stored at construction, not
            // the rounded-up `ptr`.
            let status = unsafe { (api.cu_mem_free_v2)(self.raw_ptr) };
            if status == 0 {
                return;
            }
            tracing::warn!(
                cuda_error = status,
                ptr = self.raw_ptr,
                aligned_ptr = self.ptr,
                len = self.len,
                "cuMemFree_v2 failed during AlignedBuffer drop"
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every named variant paired with the byte count it must report.
    const NAMED: [(Alignment, usize); 5] = [
        (Alignment::Default, 256),
        (Alignment::Align256, 256),
        (Alignment::Align512, 512),
        (Alignment::Align1024, 1024),
        (Alignment::Align4096, 4096),
    ];

    #[test]
    fn alignment_bytes_named_variants() {
        for (variant, expected) in NAMED {
            assert_eq!(variant.bytes(), expected);
        }
    }

    #[test]
    fn alignment_bytes_custom() {
        for n in [64, 2048] {
            assert_eq!(Alignment::Custom(n).bytes(), n);
        }
    }

    #[test]
    fn alignment_is_power_of_two() {
        for (variant, _) in NAMED {
            assert!(variant.is_power_of_two());
        }
        assert!(Alignment::Custom(128).is_power_of_two());
        for n in [0, 3, 100] {
            assert!(!Alignment::Custom(n).is_power_of_two());
        }
    }

    #[test]
    fn alignment_is_aligned() {
        let a256 = Alignment::Align256;
        for ptr in [0, 256, 512] {
            assert!(a256.is_aligned(ptr));
        }
        for ptr in [1, 128, 255] {
            assert!(!a256.is_aligned(ptr));
        }

        let a512 = Alignment::Align512;
        for ptr in [0, 512] {
            assert!(a512.is_aligned(ptr));
        }
        assert!(!a512.is_aligned(256));
    }

    #[test]
    fn round_up_basic() {
        // (bytes, alignment, expected)
        let cases = [
            (0, 256, 0),
            (1, 256, 256),
            (100, 256, 256),
            (256, 256, 256),
            (257, 256, 512),
            (511, 512, 512),
            (512, 512, 512),
            (513, 512, 1024),
        ];
        for (bytes, alignment, expected) in cases {
            assert_eq!(round_up_to_alignment(bytes, alignment), expected);
        }
    }

    #[test]
    fn round_up_zero_alignment() {
        assert_eq!(round_up_to_alignment(42, 0), 42);
    }

    #[test]
    fn validate_named_variants_ok() {
        for (variant, _) in NAMED {
            assert!(validate_alignment(&variant).is_ok());
        }
    }

    #[test]
    fn validate_custom_ok() {
        for n in [64, 128, 8192] {
            assert!(validate_alignment(&Alignment::Custom(n)).is_ok());
        }
    }

    #[test]
    fn validate_custom_bad() {
        for n in [0, 3, 100, 512 * 1024 * 1024] {
            assert!(validate_alignment(&Alignment::Custom(n)).is_err());
        }
    }

    #[test]
    fn optimal_alignment_small_types() {
        assert_eq!(optimal_alignment_for_type::<f32>(), Alignment::Default);
        assert_eq!(optimal_alignment_for_type::<u8>(), Alignment::Default);
    }

    #[test]
    fn optimal_alignment_medium_types() {
        assert_eq!(optimal_alignment_for_type::<f64>(), Alignment::Align256);
        assert_eq!(optimal_alignment_for_type::<u64>(), Alignment::Align256);
    }

    #[test]
    fn optimal_alignment_large_types() {
        let four_f32 = optimal_alignment_for_type::<[f32; 4]>();
        let four_f64 = optimal_alignment_for_type::<[f64; 4]>();
        assert_eq!(four_f32, Alignment::Align512);
        assert_eq!(four_f64, Alignment::Align512);
    }

    #[test]
    fn coalesce_basic() {
        // (access width, expected alignment) for a 32-lane warp
        for (width, expected) in [(4, 128), (8, 256), (16, 512), (32, 1024)] {
            assert_eq!(coalesce_alignment(width, 32), expected);
        }
    }

    #[test]
    fn coalesce_caps_at_page() {
        assert_eq!(coalesce_alignment(128, 64), 4096);
    }

    #[test]
    fn coalesce_zero_inputs() {
        for (width, warp) in [(0, 32), (4, 0), (0, 0)] {
            assert_eq!(coalesce_alignment(width, warp), 1);
        }
    }

    #[test]
    fn check_alignment_page_aligned() {
        let AlignmentInfo {
            natural_alignment,
            is_256_aligned,
            is_512_aligned,
            is_page_aligned,
            ..
        } = check_alignment(4096);
        assert!(is_256_aligned);
        assert!(is_512_aligned);
        assert!(is_page_aligned);
        assert!(natural_alignment >= 4096);
    }

    #[test]
    fn check_alignment_512_not_page() {
        let AlignmentInfo {
            natural_alignment,
            is_256_aligned,
            is_512_aligned,
            is_page_aligned,
            ..
        } = check_alignment(512);
        assert!(is_256_aligned);
        assert!(is_512_aligned);
        assert!(!is_page_aligned);
        assert_eq!(natural_alignment, 512);
    }

    #[test]
    fn check_alignment_odd_ptr() {
        let AlignmentInfo {
            natural_alignment,
            is_256_aligned,
            is_512_aligned,
            is_page_aligned,
            ..
        } = check_alignment(0x0001_0001);
        assert!(!is_256_aligned);
        assert!(!is_512_aligned);
        assert!(!is_page_aligned);
        assert_eq!(natural_alignment, 1);
    }

    #[test]
    fn check_alignment_null() {
        let AlignmentInfo {
            natural_alignment,
            is_256_aligned,
            is_512_aligned,
            is_page_aligned,
            ..
        } = check_alignment(0);
        assert_eq!(natural_alignment, usize::MAX);
        assert!(is_256_aligned);
        assert!(is_512_aligned);
        assert!(is_page_aligned);
    }

    // Buffer tests run only where `alloc` uses the placeholder address and
    // therefore needs no device.
    #[cfg(target_os = "macos")]
    mod buffer_tests {
        use super::super::*;

        /// Allocates a buffer or fails the calling test.
        fn alloc_or_panic<T: Copy>(n: usize, alignment: Alignment) -> AlignedBuffer<T> {
            match AlignedBuffer::<T>::alloc(n, alignment) {
                Ok(buf) => buf,
                Err(_) => panic!("alloc failed"),
            }
        }

        #[test]
        fn alloc_default_alignment() {
            let buf = alloc_or_panic::<f32>(128, Alignment::Default);
            assert_eq!(buf.len(), 128);
            assert!(!buf.is_empty());
            assert!(buf.is_aligned());
        }

        #[test]
        fn alloc_512_alignment() {
            let buf = alloc_or_panic::<f32>(256, Alignment::Align512);
            assert!(buf.is_aligned());
            assert_eq!(buf.as_device_ptr() % 512, 0);
        }

        #[test]
        fn alloc_4096_alignment() {
            let buf = alloc_or_panic::<f64>(64, Alignment::Align4096);
            assert!(buf.is_aligned());
            assert_eq!(buf.as_device_ptr() % 4096, 0);
        }

        #[test]
        fn alloc_zero_elements_fails() {
            assert!(AlignedBuffer::<f32>::alloc(0, Alignment::Default).is_err());
        }

        #[test]
        fn alloc_invalid_alignment_fails() {
            assert!(AlignedBuffer::<f32>::alloc(64, Alignment::Custom(3)).is_err());
        }

        #[test]
        fn wasted_bytes_at_least_zero() {
            let buf = alloc_or_panic::<f32>(128, Alignment::Align512);
            assert!(buf.wasted_bytes() <= buf.alignment().bytes());
        }

        #[test]
        fn alignment_accessor() {
            let buf = alloc_or_panic::<u8>(64, Alignment::Align1024);
            assert_eq!(*buf.alignment(), Alignment::Align1024);
        }
    }

    #[test]
    fn alignment_display() {
        let cases = [
            (Alignment::Default, "Default(256)"),
            (Alignment::Align256, "256"),
            (Alignment::Align512, "512"),
            (Alignment::Custom(128), "Custom(128)"),
        ];
        for (variant, expected) in cases {
            assert_eq!(format!("{}", variant), expected);
        }
    }
}