use oxicuda_driver::device::Device;
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::stream::Stream;
use crate::memory_info::{MemAdvice, mem_advise, mem_prefetch};
use crate::unified::UnifiedBuffer;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MigrationPolicy {
Default,
ReadMostly,
PreferDevice(i32),
PreferHost,
}
impl MigrationPolicy {
pub fn to_advice_pairs(&self) -> Vec<MemAdvice> {
match self {
Self::Default => Vec::new(),
Self::ReadMostly => vec![MemAdvice::SetReadMostly],
Self::PreferDevice(_) => vec![MemAdvice::SetPreferredLocation],
Self::PreferHost => vec![MemAdvice::SetPreferredLocation],
}
}
#[inline]
pub fn is_default(&self) -> bool {
matches!(self, Self::Default)
}
}
impl std::fmt::Display for MigrationPolicy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Default => write!(f, "MigrationPolicy::Default"),
Self::ReadMostly => write!(f, "MigrationPolicy::ReadMostly"),
Self::PreferDevice(ord) => write!(f, "MigrationPolicy::PreferDevice({ord})"),
Self::PreferHost => write!(f, "MigrationPolicy::PreferHost"),
}
}
}
#[derive(Debug, Clone)]
pub struct ManagedMemoryHints {
ptr: u64,
byte_size: usize,
}
impl ManagedMemoryHints {
pub fn for_buffer(ptr: u64, byte_size: usize) -> CudaResult<Self> {
if byte_size == 0 {
return Err(CudaError::InvalidValue);
}
Ok(Self { ptr, byte_size })
}
pub fn from_unified<T: Copy>(buf: &UnifiedBuffer<T>) -> CudaResult<Self> {
Self::for_buffer(buf.as_device_ptr(), buf.byte_size())
}
#[inline]
pub fn ptr(&self) -> u64 {
self.ptr
}
#[inline]
pub fn byte_size(&self) -> usize {
self.byte_size
}
pub fn set_read_mostly(&self, device: &Device) -> CudaResult<()> {
mem_advise(self.ptr, self.byte_size, MemAdvice::SetReadMostly, device)
}
pub fn unset_read_mostly(&self, device: &Device) -> CudaResult<()> {
mem_advise(self.ptr, self.byte_size, MemAdvice::UnsetReadMostly, device)
}
pub fn set_preferred_location(&self, device: &Device) -> CudaResult<()> {
mem_advise(
self.ptr,
self.byte_size,
MemAdvice::SetPreferredLocation,
device,
)
}
pub fn unset_preferred_location(&self, device: &Device) -> CudaResult<()> {
mem_advise(
self.ptr,
self.byte_size,
MemAdvice::UnsetPreferredLocation,
device,
)
}
pub fn set_accessed_by(&self, device: &Device) -> CudaResult<()> {
mem_advise(self.ptr, self.byte_size, MemAdvice::SetAccessedBy, device)
}
pub fn unset_accessed_by(&self, device: &Device) -> CudaResult<()> {
mem_advise(self.ptr, self.byte_size, MemAdvice::UnsetAccessedBy, device)
}
pub fn prefetch_to(&self, device: &Device, stream: &Stream) -> CudaResult<()> {
mem_prefetch(self.ptr, self.byte_size, device, stream)
}
pub fn prefetch_range(
&self,
offset_bytes: usize,
count_bytes: usize,
device: &Device,
stream: &Stream,
) -> CudaResult<()> {
if count_bytes == 0 {
return Err(CudaError::InvalidValue);
}
let end = offset_bytes
.checked_add(count_bytes)
.ok_or(CudaError::InvalidValue)?;
if end > self.byte_size {
return Err(CudaError::InvalidValue);
}
let range_ptr = self
.ptr
.checked_add(offset_bytes as u64)
.ok_or(CudaError::InvalidValue)?;
mem_prefetch(range_ptr, count_bytes, device, stream)
}
pub fn apply_policy(&self, policy: &MigrationPolicy, device: &Device) -> CudaResult<()> {
apply_migration_policy(self.ptr, self.byte_size, policy, device)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefetchEntry {
pub ptr: u64,
pub byte_size: usize,
pub device_ordinal: i32,
}
#[derive(Debug, Clone)]
pub struct PrefetchPlan {
entries: Vec<PrefetchEntry>,
}
impl PrefetchPlan {
pub fn new() -> Self {
Self {
entries: Vec::new(),
}
}
pub fn add(&mut self, ptr: u64, byte_size: usize, device_ordinal: i32) -> &mut Self {
self.entries.push(PrefetchEntry {
ptr,
byte_size,
device_ordinal,
});
self
}
#[inline]
pub fn len(&self) -> usize {
self.entries.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
#[inline]
pub fn entries(&self) -> &[PrefetchEntry] {
&self.entries
}
pub fn execute(&self, stream: &Stream) -> CudaResult<()> {
for entry in &self.entries {
let device = Device::get(entry.device_ordinal)?;
mem_prefetch(entry.ptr, entry.byte_size, &device, stream)?;
}
Ok(())
}
}
impl Default for PrefetchPlan {
fn default() -> Self {
Self::new()
}
}
pub fn apply_migration_policy(
ptr: u64,
byte_size: usize,
policy: &MigrationPolicy,
device: &Device,
) -> CudaResult<()> {
match policy {
MigrationPolicy::Default => Ok(()),
MigrationPolicy::ReadMostly => mem_advise(ptr, byte_size, MemAdvice::SetReadMostly, device),
MigrationPolicy::PreferDevice(_ordinal) => {
mem_advise(ptr, byte_size, MemAdvice::SetPreferredLocation, device)
}
MigrationPolicy::PreferHost => {
mem_advise(ptr, byte_size, MemAdvice::SetPreferredLocation, device)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn migration_policy_default_produces_empty_advice() {
let pairs = MigrationPolicy::Default.to_advice_pairs();
assert!(pairs.is_empty());
}
#[test]
fn migration_policy_read_mostly_advice() {
let pairs = MigrationPolicy::ReadMostly.to_advice_pairs();
assert_eq!(pairs.len(), 1);
assert_eq!(pairs[0], MemAdvice::SetReadMostly);
}
#[test]
fn migration_policy_prefer_device_advice() {
let pairs = MigrationPolicy::PreferDevice(0).to_advice_pairs();
assert_eq!(pairs.len(), 1);
assert_eq!(pairs[0], MemAdvice::SetPreferredLocation);
}
#[test]
fn migration_policy_prefer_host_advice() {
let pairs = MigrationPolicy::PreferHost.to_advice_pairs();
assert_eq!(pairs.len(), 1);
assert_eq!(pairs[0], MemAdvice::SetPreferredLocation);
}
#[test]
fn migration_policy_is_default() {
assert!(MigrationPolicy::Default.is_default());
assert!(!MigrationPolicy::ReadMostly.is_default());
assert!(!MigrationPolicy::PreferDevice(0).is_default());
assert!(!MigrationPolicy::PreferHost.is_default());
}
#[test]
fn migration_policy_display() {
let s = format!("{}", MigrationPolicy::PreferDevice(2));
assert!(s.contains("PreferDevice(2)"));
let s2 = format!("{}", MigrationPolicy::Default);
assert!(s2.contains("Default"));
}
#[test]
fn hints_for_buffer_rejects_zero_size() {
let result = ManagedMemoryHints::for_buffer(0x1000, 0);
assert!(result.is_err());
}
#[test]
fn hints_for_buffer_valid() {
let hints = ManagedMemoryHints::for_buffer(0x1000, 4096);
assert!(hints.is_ok());
let hints = hints.ok();
assert!(hints.is_some());
let hints = hints.map(|h| {
assert_eq!(h.ptr(), 0x1000);
assert_eq!(h.byte_size(), 4096);
});
let _ = hints;
}
#[test]
fn hints_accessors() {
let hints = ManagedMemoryHints::for_buffer(0xDEAD, 512);
if let Ok(h) = hints {
assert_eq!(h.ptr(), 0xDEAD);
assert_eq!(h.byte_size(), 512);
}
}
#[test]
fn prefetch_plan_new_is_empty() {
let plan = PrefetchPlan::new();
assert!(plan.is_empty());
assert_eq!(plan.len(), 0);
}
#[test]
fn prefetch_plan_default_is_empty() {
let plan = PrefetchPlan::default();
assert!(plan.is_empty());
}
#[test]
fn prefetch_plan_add_and_len() {
let mut plan = PrefetchPlan::new();
plan.add(0x1000, 4096, 0).add(0x2000, 8192, 1);
assert_eq!(plan.len(), 2);
assert!(!plan.is_empty());
let entries = plan.entries();
assert_eq!(entries[0].ptr, 0x1000);
assert_eq!(entries[0].byte_size, 4096);
assert_eq!(entries[0].device_ordinal, 0);
assert_eq!(entries[1].ptr, 0x2000);
assert_eq!(entries[1].byte_size, 8192);
assert_eq!(entries[1].device_ordinal, 1);
}
#[test]
fn prefetch_plan_chaining() {
let mut plan = PrefetchPlan::new();
plan.add(0x100, 100, 0)
.add(0x200, 200, 0)
.add(0x300, 300, 0);
assert_eq!(plan.len(), 3);
}
#[test]
fn prefetch_range_rejects_zero_count() {
if let Ok(dev) = Device::get(0) {
let hints = ManagedMemoryHints::for_buffer(0x1000, 4096);
let _ = (hints, dev);
}
let _: fn(&ManagedMemoryHints, usize, usize, &Device, &Stream) -> CudaResult<()> =
ManagedMemoryHints::prefetch_range;
}
#[test]
fn prefetch_range_out_of_bounds_detected() {
let byte_size: usize = 4096;
let offset: usize = 4000;
let count: usize = 200;
let end = offset.checked_add(count);
assert!(end.is_some());
let end = end.map(|e| e > byte_size);
assert_eq!(end, Some(true));
}
#[test]
fn apply_policy_default_is_noop() {
let fake_dev: Device = unsafe { std::mem::zeroed() };
let result = apply_migration_policy(0x1000, 4096, &MigrationPolicy::Default, &fake_dev);
assert!(result.is_ok());
}
#[test]
fn signature_set_read_mostly() {
let _: fn(&ManagedMemoryHints, &Device) -> CudaResult<()> =
ManagedMemoryHints::set_read_mostly;
}
#[test]
fn signature_unset_read_mostly() {
let _: fn(&ManagedMemoryHints, &Device) -> CudaResult<()> =
ManagedMemoryHints::unset_read_mostly;
}
#[test]
fn signature_prefetch_to() {
let _: fn(&ManagedMemoryHints, &Device, &Stream) -> CudaResult<()> =
ManagedMemoryHints::prefetch_to;
}
#[test]
fn signature_apply_policy() {
let _: fn(&ManagedMemoryHints, &MigrationPolicy, &Device) -> CudaResult<()> =
ManagedMemoryHints::apply_policy;
}
#[test]
fn signature_execute_plan() {
let _: fn(&PrefetchPlan, &Stream) -> CudaResult<()> = PrefetchPlan::execute;
}
}