oxicuda_memory/managed_hints.rs
1//! Ergonomic managed memory hints API.
2//!
3//! This module builds on the raw [`crate::memory_info::mem_advise`]
4//! and [`crate::memory_info::mem_prefetch`] functions to provide
5//! a higher-level, builder-style API for controlling unified memory migration
6//! behaviour.
7//!
8//! # Key types
9//!
10//! - [`MigrationPolicy`] — declarative policy for common migration patterns.
11//! - [`ManagedMemoryHints`] — builder for applying hints to a memory region.
12//! - [`PrefetchPlan`] — batch multiple prefetch operations into one plan.
13//!
14//! # Example
15//!
16//! ```rust,no_run
17//! # use oxicuda_memory::managed_hints::*;
18//! # use oxicuda_driver::device::Device;
19//! # use oxicuda_driver::stream::Stream;
20//! // Assume `buf` is a UnifiedBuffer<f32> and `dev`/`stream` are valid.
21//! // let hints = ManagedMemoryHints::from_unified(&buf);
22//! // hints.set_read_mostly(&dev)?;
23//! // hints.prefetch_to(&dev, &stream)?;
24//! # Ok::<(), oxicuda_driver::error::CudaError>(())
25//! ```
26
27use oxicuda_driver::device::Device;
28use oxicuda_driver::error::{CudaError, CudaResult};
29use oxicuda_driver::stream::Stream;
30
31use crate::memory_info::{MemAdvice, mem_advise, mem_prefetch};
32use crate::unified::UnifiedBuffer;
33
34// ---------------------------------------------------------------------------
35// MigrationPolicy
36// ---------------------------------------------------------------------------
37
/// Declarative migration policy for unified memory regions.
///
/// Each variant captures a common access pattern. Use
/// [`to_advice_pairs`](MigrationPolicy::to_advice_pairs) to translate a
/// policy into the [`MemAdvice`] hint(s) it implies.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MigrationPolicy {
    /// No special migration policy. Uses CUDA defaults.
    Default,
    /// Mark the region as read-mostly, enabling read-replica creation on
    /// accessing devices to reduce migration overhead.
    ReadMostly,
    /// Prefer that the data resides on the device with the given ordinal.
    PreferDevice(i32),
    /// Prefer that the data resides in host (CPU) memory.
    PreferHost,
}
54
55impl MigrationPolicy {
56 /// Converts this policy into the corresponding [`MemAdvice`] values
57 /// that should be applied.
58 ///
59 /// For [`Default`](MigrationPolicy::Default) the returned vec is empty
60 /// (no advice to set). For compound policies the vec contains all advice
61 /// hints that need to be issued.
62 pub fn to_advice_pairs(&self) -> Vec<MemAdvice> {
63 match self {
64 Self::Default => Vec::new(),
65 Self::ReadMostly => vec![MemAdvice::SetReadMostly],
66 Self::PreferDevice(_) => vec![MemAdvice::SetPreferredLocation],
67 Self::PreferHost => vec![MemAdvice::SetPreferredLocation],
68 }
69 }
70
71 /// Returns whether this is the [`Default`](MigrationPolicy::Default) variant.
72 #[inline]
73 pub fn is_default(&self) -> bool {
74 matches!(self, Self::Default)
75 }
76}
77
78impl std::fmt::Display for MigrationPolicy {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 match self {
81 Self::Default => write!(f, "MigrationPolicy::Default"),
82 Self::ReadMostly => write!(f, "MigrationPolicy::ReadMostly"),
83 Self::PreferDevice(ord) => write!(f, "MigrationPolicy::PreferDevice({ord})"),
84 Self::PreferHost => write!(f, "MigrationPolicy::PreferHost"),
85 }
86 }
87}
88
89// ---------------------------------------------------------------------------
90// ManagedMemoryHints
91// ---------------------------------------------------------------------------
92
/// Builder-style API for applying memory hints to a unified memory region.
///
/// Wraps a device pointer plus a byte length and exposes methods that issue
/// the matching [`mem_advise`] / [`mem_prefetch`] driver calls.
///
/// # Construction
///
/// Use [`for_buffer`](Self::for_buffer) when you hold a raw pointer, or
/// [`from_unified`](Self::from_unified) when you hold a [`UnifiedBuffer`].
#[derive(Debug, Clone)]
pub struct ManagedMemoryHints {
    /// Device pointer to the start of the managed region.
    ptr: u64,
    /// Total size of the region in bytes.
    byte_size: usize,
}
109
110impl ManagedMemoryHints {
111 /// Creates a `ManagedMemoryHints` from a raw device pointer and byte size.
112 ///
113 /// # Errors
114 ///
115 /// Returns [`CudaError::InvalidValue`] if `byte_size` is zero.
116 pub fn for_buffer(ptr: u64, byte_size: usize) -> CudaResult<Self> {
117 if byte_size == 0 {
118 return Err(CudaError::InvalidValue);
119 }
120 Ok(Self { ptr, byte_size })
121 }
122
123 /// Creates a `ManagedMemoryHints` from a [`UnifiedBuffer`] reference.
124 ///
125 /// The pointer and byte size are extracted from the buffer.
126 ///
127 /// # Errors
128 ///
129 /// Returns [`CudaError::InvalidValue`] if the buffer reports zero bytes
130 /// (should not happen for a validly constructed buffer).
131 pub fn from_unified<T: Copy>(buf: &UnifiedBuffer<T>) -> CudaResult<Self> {
132 Self::for_buffer(buf.as_device_ptr(), buf.byte_size())
133 }
134
135 /// Returns the device pointer this hint set targets.
136 #[inline]
137 pub fn ptr(&self) -> u64 {
138 self.ptr
139 }
140
141 /// Returns the byte size of the targeted region.
142 #[inline]
143 pub fn byte_size(&self) -> usize {
144 self.byte_size
145 }
146
147 // -- Individual advice methods ------------------------------------------
148
149 /// Marks the region as read-mostly on `device`, enabling read replicas.
150 pub fn set_read_mostly(&self, device: &Device) -> CudaResult<()> {
151 mem_advise(self.ptr, self.byte_size, MemAdvice::SetReadMostly, device)
152 }
153
154 /// Removes the read-mostly hint for `device`.
155 pub fn unset_read_mostly(&self, device: &Device) -> CudaResult<()> {
156 mem_advise(self.ptr, self.byte_size, MemAdvice::UnsetReadMostly, device)
157 }
158
159 /// Sets the preferred location to `device` for this region.
160 pub fn set_preferred_location(&self, device: &Device) -> CudaResult<()> {
161 mem_advise(
162 self.ptr,
163 self.byte_size,
164 MemAdvice::SetPreferredLocation,
165 device,
166 )
167 }
168
169 /// Removes the preferred-location hint for `device`.
170 pub fn unset_preferred_location(&self, device: &Device) -> CudaResult<()> {
171 mem_advise(
172 self.ptr,
173 self.byte_size,
174 MemAdvice::UnsetPreferredLocation,
175 device,
176 )
177 }
178
179 /// Indicates that `device` will access this memory region.
180 pub fn set_accessed_by(&self, device: &Device) -> CudaResult<()> {
181 mem_advise(self.ptr, self.byte_size, MemAdvice::SetAccessedBy, device)
182 }
183
184 /// Removes the accessed-by hint for `device`.
185 pub fn unset_accessed_by(&self, device: &Device) -> CudaResult<()> {
186 mem_advise(self.ptr, self.byte_size, MemAdvice::UnsetAccessedBy, device)
187 }
188
189 // -- Prefetch methods ---------------------------------------------------
190
191 /// Prefetches the entire region to `device` on `stream`.
192 pub fn prefetch_to(&self, device: &Device, stream: &Stream) -> CudaResult<()> {
193 mem_prefetch(self.ptr, self.byte_size, device, stream)
194 }
195
196 /// Prefetches a sub-range of the region to `device`.
197 ///
198 /// # Parameters
199 ///
200 /// * `offset_bytes` — byte offset from the start of the region.
201 /// * `count_bytes` — number of bytes to prefetch.
202 ///
203 /// # Errors
204 ///
205 /// Returns [`CudaError::InvalidValue`] if the range
206 /// `[offset_bytes, offset_bytes + count_bytes)` exceeds the buffer, or
207 /// if `count_bytes` is zero.
208 pub fn prefetch_range(
209 &self,
210 offset_bytes: usize,
211 count_bytes: usize,
212 device: &Device,
213 stream: &Stream,
214 ) -> CudaResult<()> {
215 if count_bytes == 0 {
216 return Err(CudaError::InvalidValue);
217 }
218 let end = offset_bytes
219 .checked_add(count_bytes)
220 .ok_or(CudaError::InvalidValue)?;
221 if end > self.byte_size {
222 return Err(CudaError::InvalidValue);
223 }
224 let range_ptr = self
225 .ptr
226 .checked_add(offset_bytes as u64)
227 .ok_or(CudaError::InvalidValue)?;
228 mem_prefetch(range_ptr, count_bytes, device, stream)
229 }
230
231 // -- Policy convenience -------------------------------------------------
232
233 /// Applies a [`MigrationPolicy`] to this memory region.
234 ///
235 /// For [`MigrationPolicy::Default`] this is a no-op.
236 /// For other variants the corresponding advice hint(s) are issued.
237 pub fn apply_policy(&self, policy: &MigrationPolicy, device: &Device) -> CudaResult<()> {
238 apply_migration_policy(self.ptr, self.byte_size, policy, device)
239 }
240}
241
242// ---------------------------------------------------------------------------
243// PrefetchPlan
244// ---------------------------------------------------------------------------
245
/// A single recorded prefetch operation inside a [`PrefetchPlan`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefetchEntry {
    /// Device pointer to the start of the region.
    pub ptr: u64,
    /// Size of the region in bytes.
    pub byte_size: usize,
    /// Target device ordinal.
    pub device_ordinal: i32,
}
256
257/// Batch multiple prefetch operations into a single plan.
258///
259/// Operations are recorded first, then executed together on a single stream
260/// via [`execute`](PrefetchPlan::execute).
261///
262/// # Example
263///
264/// ```rust,no_run
265/// # use oxicuda_memory::managed_hints::PrefetchPlan;
266/// # use oxicuda_driver::stream::Stream;
267/// let mut plan = PrefetchPlan::new();
268/// plan.add(0x1000, 4096, 0)
269/// .add(0x2000, 8192, 0);
270/// assert_eq!(plan.len(), 2);
271/// // plan.execute(&stream)?;
272/// # Ok::<(), oxicuda_driver::error::CudaError>(())
273/// ```
274#[derive(Debug, Clone)]
275pub struct PrefetchPlan {
276 entries: Vec<PrefetchEntry>,
277}
278
279impl PrefetchPlan {
280 /// Creates an empty prefetch plan.
281 pub fn new() -> Self {
282 Self {
283 entries: Vec::new(),
284 }
285 }
286
287 /// Records a prefetch operation.
288 ///
289 /// The actual prefetch is deferred until [`execute`](Self::execute).
290 pub fn add(&mut self, ptr: u64, byte_size: usize, device_ordinal: i32) -> &mut Self {
291 self.entries.push(PrefetchEntry {
292 ptr,
293 byte_size,
294 device_ordinal,
295 });
296 self
297 }
298
299 /// Returns the number of recorded prefetch operations.
300 #[inline]
301 pub fn len(&self) -> usize {
302 self.entries.len()
303 }
304
305 /// Returns `true` if no operations have been recorded.
306 #[inline]
307 pub fn is_empty(&self) -> bool {
308 self.entries.is_empty()
309 }
310
311 /// Returns a slice of all recorded entries.
312 #[inline]
313 pub fn entries(&self) -> &[PrefetchEntry] {
314 &self.entries
315 }
316
317 /// Executes all recorded prefetch operations on `stream`.
318 ///
319 /// Each entry is issued as a separate `mem_prefetch` call targeting the
320 /// device identified by the entry's `device_ordinal`. Operations are
321 /// enqueued in the order they were added.
322 ///
323 /// # Errors
324 ///
325 /// Returns the first error encountered. Entries before the failing one
326 /// will already have been enqueued.
327 pub fn execute(&self, stream: &Stream) -> CudaResult<()> {
328 for entry in &self.entries {
329 let device = Device::get(entry.device_ordinal)?;
330 mem_prefetch(entry.ptr, entry.byte_size, &device, stream)?;
331 }
332 Ok(())
333 }
334}
335
336impl Default for PrefetchPlan {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
342// ---------------------------------------------------------------------------
343// Free function convenience
344// ---------------------------------------------------------------------------
345
346/// Applies a [`MigrationPolicy`] to a raw unified memory region.
347///
348/// This is a convenience function that translates the high-level policy into
349/// the appropriate [`mem_advise`] calls.
350///
351/// # Parameters
352///
353/// * `ptr` — device pointer to the managed allocation.
354/// * `byte_size` — size of the region in bytes.
355/// * `policy` — the migration policy to apply.
356/// * `device` — the device to which hints are directed.
357///
358/// # Errors
359///
360/// Forwards any error from the underlying driver call.
361/// Returns [`CudaError::InvalidValue`] if `byte_size` is zero (when
362/// policy is not `Default`).
363pub fn apply_migration_policy(
364 ptr: u64,
365 byte_size: usize,
366 policy: &MigrationPolicy,
367 device: &Device,
368) -> CudaResult<()> {
369 match policy {
370 MigrationPolicy::Default => Ok(()),
371 MigrationPolicy::ReadMostly => mem_advise(ptr, byte_size, MemAdvice::SetReadMostly, device),
372 MigrationPolicy::PreferDevice(_ordinal) => {
373 // The advice targets the device passed by the caller. The ordinal
374 // in the policy variant is informational — the caller is expected
375 // to pass the corresponding Device handle.
376 mem_advise(ptr, byte_size, MemAdvice::SetPreferredLocation, device)
377 }
378 MigrationPolicy::PreferHost => {
379 // For host-preferred, we still issue SetPreferredLocation but
380 // directed at the provided device handle. In a real CUDA
381 // environment the caller would pass CU_DEVICE_CPU (-1).
382 mem_advise(ptr, byte_size, MemAdvice::SetPreferredLocation, device)
383 }
384 }
385}
386
387// ---------------------------------------------------------------------------
388// Tests
389// ---------------------------------------------------------------------------
390
#[cfg(test)]
mod tests {
    use super::*;

    // -- MigrationPolicy tests ----------------------------------------------

    #[test]
    fn migration_policy_default_produces_empty_advice() {
        assert!(MigrationPolicy::Default.to_advice_pairs().is_empty());
    }

    #[test]
    fn migration_policy_read_mostly_advice() {
        assert_eq!(
            MigrationPolicy::ReadMostly.to_advice_pairs(),
            vec![MemAdvice::SetReadMostly]
        );
    }

    #[test]
    fn migration_policy_prefer_device_advice() {
        assert_eq!(
            MigrationPolicy::PreferDevice(0).to_advice_pairs(),
            vec![MemAdvice::SetPreferredLocation]
        );
    }

    #[test]
    fn migration_policy_prefer_host_advice() {
        assert_eq!(
            MigrationPolicy::PreferHost.to_advice_pairs(),
            vec![MemAdvice::SetPreferredLocation]
        );
    }

    #[test]
    fn migration_policy_is_default() {
        assert!(MigrationPolicy::Default.is_default());
        assert!(!MigrationPolicy::ReadMostly.is_default());
        assert!(!MigrationPolicy::PreferDevice(0).is_default());
        assert!(!MigrationPolicy::PreferHost.is_default());
    }

    #[test]
    fn migration_policy_display() {
        assert!(MigrationPolicy::PreferDevice(2)
            .to_string()
            .contains("PreferDevice(2)"));
        assert!(MigrationPolicy::Default.to_string().contains("Default"));
    }

    // -- ManagedMemoryHints construction tests ------------------------------

    #[test]
    fn hints_for_buffer_rejects_zero_size() {
        assert!(ManagedMemoryHints::for_buffer(0x1000, 0).is_err());
    }

    #[test]
    fn hints_for_buffer_valid() {
        // Construction with a non-zero size is pure and infallible, so
        // `expect` states the invariant directly (the previous `ok()`/`map`
        // gymnastics obscured it and would silently pass on `None`).
        let hints =
            ManagedMemoryHints::for_buffer(0x1000, 4096).expect("non-zero size must succeed");
        assert_eq!(hints.ptr(), 0x1000);
        assert_eq!(hints.byte_size(), 4096);
    }

    #[test]
    fn hints_accessors() {
        let hints = ManagedMemoryHints::for_buffer(0xDEAD, 512).expect("valid construction");
        assert_eq!(hints.ptr(), 0xDEAD);
        assert_eq!(hints.byte_size(), 512);
    }

    // -- PrefetchPlan tests -------------------------------------------------

    #[test]
    fn prefetch_plan_new_is_empty() {
        let plan = PrefetchPlan::new();
        assert!(plan.is_empty());
        assert_eq!(plan.len(), 0);
    }

    #[test]
    fn prefetch_plan_default_is_empty() {
        assert!(PrefetchPlan::default().is_empty());
    }

    #[test]
    fn prefetch_plan_add_and_len() {
        let mut plan = PrefetchPlan::new();
        plan.add(0x1000, 4096, 0).add(0x2000, 8192, 1);
        assert_eq!(plan.len(), 2);
        assert!(!plan.is_empty());

        let expected = [
            PrefetchEntry {
                ptr: 0x1000,
                byte_size: 4096,
                device_ordinal: 0,
            },
            PrefetchEntry {
                ptr: 0x2000,
                byte_size: 8192,
                device_ordinal: 1,
            },
        ];
        assert_eq!(plan.entries(), &expected[..]);
    }

    #[test]
    fn prefetch_plan_chaining() {
        let mut plan = PrefetchPlan::new();
        plan.add(0x100, 100, 0)
            .add(0x200, 200, 0)
            .add(0x300, 300, 0);
        assert_eq!(plan.len(), 3);
    }

    // -- prefetch_range validation tests ------------------------------------

    #[test]
    fn prefetch_range_signature_compiles() {
        // Calling `prefetch_range` requires a live `Device` and `Stream`,
        // which are unavailable without a GPU and context, so this is a
        // compile-time signature check only. (The previous version also
        // fetched a Device it never used.)
        let _: fn(&ManagedMemoryHints, usize, usize, &Device, &Stream) -> CudaResult<()> =
            ManagedMemoryHints::prefetch_range;
    }

    #[test]
    fn prefetch_range_out_of_bounds_detected() {
        // Replicates the internal bounds check without needing a GPU:
        // 4000 + 200 = 4200 exceeds a 4096-byte region.
        let byte_size: usize = 4096;
        let end = 4000usize.checked_add(200);
        assert_eq!(end, Some(4200));
        assert!(end.expect("no overflow") > byte_size);
    }

    // -- apply_migration_policy tests ---------------------------------------

    #[test]
    fn apply_policy_default_is_noop() {
        // `Default` translates to no advice at all...
        assert!(MigrationPolicy::Default.to_advice_pairs().is_empty());
        // ...and `apply_migration_policy` must return Ok before touching the
        // driver. Only exercise it when a genuine handle is available:
        // fabricating a `Device` via `std::mem::zeroed()` (as an earlier
        // version of this test did) is unsound for an opaque handle type and
        // may be undefined behaviour if the type has validity invariants.
        if let Ok(dev) = Device::get(0) {
            let result = apply_migration_policy(0x1000, 4096, &MigrationPolicy::Default, &dev);
            assert!(result.is_ok());
        }
    }

    // -- Compile-time signature checks for GPU-requiring functions ----------

    #[test]
    fn signature_set_read_mostly() {
        let _: fn(&ManagedMemoryHints, &Device) -> CudaResult<()> =
            ManagedMemoryHints::set_read_mostly;
    }

    #[test]
    fn signature_unset_read_mostly() {
        let _: fn(&ManagedMemoryHints, &Device) -> CudaResult<()> =
            ManagedMemoryHints::unset_read_mostly;
    }

    #[test]
    fn signature_prefetch_to() {
        let _: fn(&ManagedMemoryHints, &Device, &Stream) -> CudaResult<()> =
            ManagedMemoryHints::prefetch_to;
    }

    #[test]
    fn signature_apply_policy() {
        let _: fn(&ManagedMemoryHints, &MigrationPolicy, &Device) -> CudaResult<()> =
            ManagedMemoryHints::apply_policy;
    }

    #[test]
    fn signature_execute_plan() {
        let _: fn(&PrefetchPlan, &Stream) -> CudaResult<()> = PrefetchPlan::execute;
    }
}