Skip to main content

oxicuda_memory/
managed_hints.rs

1//! Ergonomic managed memory hints API.
2//!
3//! This module builds on the raw [`crate::memory_info::mem_advise`]
4//! and [`crate::memory_info::mem_prefetch`] functions to provide
5//! a higher-level, builder-style API for controlling unified memory migration
6//! behaviour.
7//!
8//! # Key types
9//!
10//! - [`MigrationPolicy`] — declarative policy for common migration patterns.
11//! - [`ManagedMemoryHints`] — builder for applying hints to a memory region.
12//! - [`PrefetchPlan`] — batch multiple prefetch operations into one plan.
13//!
14//! # Example
15//!
16//! ```rust,no_run
17//! # use oxicuda_memory::managed_hints::*;
18//! # use oxicuda_driver::device::Device;
19//! # use oxicuda_driver::stream::Stream;
20//! // Assume `buf` is a UnifiedBuffer<f32> and `dev`/`stream` are valid.
//! // let hints = ManagedMemoryHints::from_unified(&buf)?;
22//! // hints.set_read_mostly(&dev)?;
23//! // hints.prefetch_to(&dev, &stream)?;
24//! # Ok::<(), oxicuda_driver::error::CudaError>(())
25//! ```
26
27use oxicuda_driver::device::Device;
28use oxicuda_driver::error::{CudaError, CudaResult};
29use oxicuda_driver::stream::Stream;
30
31use crate::memory_info::{MemAdvice, mem_advise, mem_prefetch};
32use crate::unified::UnifiedBuffer;
33
34// ---------------------------------------------------------------------------
35// MigrationPolicy
36// ---------------------------------------------------------------------------
37
/// High-level description of how a unified memory region should migrate.
///
/// A policy is a compact, declarative stand-in for one or more raw
/// [`MemAdvice`] hints; use [`to_advice_pairs`](MigrationPolicy::to_advice_pairs)
/// to obtain the hints a given policy corresponds to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MigrationPolicy {
    /// Leave migration behaviour at the CUDA defaults (no hints are issued).
    Default,
    /// The region is read far more often than written; accessing devices may
    /// keep read-only replicas to avoid repeated migration.
    ReadMostly,
    /// The data should preferably live on the device with this ordinal.
    PreferDevice(i32),
    /// The data should preferably live in host (CPU) memory.
    PreferHost,
}
54
55impl MigrationPolicy {
56    /// Converts this policy into the corresponding [`MemAdvice`] values
57    /// that should be applied.
58    ///
59    /// For [`Default`](MigrationPolicy::Default) the returned vec is empty
60    /// (no advice to set). For compound policies the vec contains all advice
61    /// hints that need to be issued.
62    pub fn to_advice_pairs(&self) -> Vec<MemAdvice> {
63        match self {
64            Self::Default => Vec::new(),
65            Self::ReadMostly => vec![MemAdvice::SetReadMostly],
66            Self::PreferDevice(_) => vec![MemAdvice::SetPreferredLocation],
67            Self::PreferHost => vec![MemAdvice::SetPreferredLocation],
68        }
69    }
70
71    /// Returns whether this is the [`Default`](MigrationPolicy::Default) variant.
72    #[inline]
73    pub fn is_default(&self) -> bool {
74        matches!(self, Self::Default)
75    }
76}
77
78impl std::fmt::Display for MigrationPolicy {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match self {
81            Self::Default => write!(f, "MigrationPolicy::Default"),
82            Self::ReadMostly => write!(f, "MigrationPolicy::ReadMostly"),
83            Self::PreferDevice(ord) => write!(f, "MigrationPolicy::PreferDevice({ord})"),
84            Self::PreferHost => write!(f, "MigrationPolicy::PreferHost"),
85        }
86    }
87}
88
89// ---------------------------------------------------------------------------
90// ManagedMemoryHints
91// ---------------------------------------------------------------------------
92
/// Handle for issuing migration hints and prefetches against one unified
/// memory region.
///
/// The handle stores only the region's start pointer and length in bytes;
/// each method forwards them to the matching [`mem_advise`] or
/// [`mem_prefetch`] driver call.
///
/// # Construction
///
/// Build one from a raw pointer with [`for_buffer`](Self::for_buffer), or
/// from an existing [`UnifiedBuffer`] with [`from_unified`](Self::from_unified).
#[derive(Debug, Clone)]
pub struct ManagedMemoryHints {
    /// Start address (device pointer) of the managed region.
    ptr: u64,
    /// Length of the region, in bytes.
    byte_size: usize,
}
109
110impl ManagedMemoryHints {
111    /// Creates a `ManagedMemoryHints` from a raw device pointer and byte size.
112    ///
113    /// # Errors
114    ///
115    /// Returns [`CudaError::InvalidValue`] if `byte_size` is zero.
116    pub fn for_buffer(ptr: u64, byte_size: usize) -> CudaResult<Self> {
117        if byte_size == 0 {
118            return Err(CudaError::InvalidValue);
119        }
120        Ok(Self { ptr, byte_size })
121    }
122
123    /// Creates a `ManagedMemoryHints` from a [`UnifiedBuffer`] reference.
124    ///
125    /// The pointer and byte size are extracted from the buffer.
126    ///
127    /// # Errors
128    ///
129    /// Returns [`CudaError::InvalidValue`] if the buffer reports zero bytes
130    /// (should not happen for a validly constructed buffer).
131    pub fn from_unified<T: Copy>(buf: &UnifiedBuffer<T>) -> CudaResult<Self> {
132        Self::for_buffer(buf.as_device_ptr(), buf.byte_size())
133    }
134
135    /// Returns the device pointer this hint set targets.
136    #[inline]
137    pub fn ptr(&self) -> u64 {
138        self.ptr
139    }
140
141    /// Returns the byte size of the targeted region.
142    #[inline]
143    pub fn byte_size(&self) -> usize {
144        self.byte_size
145    }
146
147    // -- Individual advice methods ------------------------------------------
148
149    /// Marks the region as read-mostly on `device`, enabling read replicas.
150    pub fn set_read_mostly(&self, device: &Device) -> CudaResult<()> {
151        mem_advise(self.ptr, self.byte_size, MemAdvice::SetReadMostly, device)
152    }
153
154    /// Removes the read-mostly hint for `device`.
155    pub fn unset_read_mostly(&self, device: &Device) -> CudaResult<()> {
156        mem_advise(self.ptr, self.byte_size, MemAdvice::UnsetReadMostly, device)
157    }
158
159    /// Sets the preferred location to `device` for this region.
160    pub fn set_preferred_location(&self, device: &Device) -> CudaResult<()> {
161        mem_advise(
162            self.ptr,
163            self.byte_size,
164            MemAdvice::SetPreferredLocation,
165            device,
166        )
167    }
168
169    /// Removes the preferred-location hint for `device`.
170    pub fn unset_preferred_location(&self, device: &Device) -> CudaResult<()> {
171        mem_advise(
172            self.ptr,
173            self.byte_size,
174            MemAdvice::UnsetPreferredLocation,
175            device,
176        )
177    }
178
179    /// Indicates that `device` will access this memory region.
180    pub fn set_accessed_by(&self, device: &Device) -> CudaResult<()> {
181        mem_advise(self.ptr, self.byte_size, MemAdvice::SetAccessedBy, device)
182    }
183
184    /// Removes the accessed-by hint for `device`.
185    pub fn unset_accessed_by(&self, device: &Device) -> CudaResult<()> {
186        mem_advise(self.ptr, self.byte_size, MemAdvice::UnsetAccessedBy, device)
187    }
188
189    // -- Prefetch methods ---------------------------------------------------
190
191    /// Prefetches the entire region to `device` on `stream`.
192    pub fn prefetch_to(&self, device: &Device, stream: &Stream) -> CudaResult<()> {
193        mem_prefetch(self.ptr, self.byte_size, device, stream)
194    }
195
196    /// Prefetches a sub-range of the region to `device`.
197    ///
198    /// # Parameters
199    ///
200    /// * `offset_bytes` — byte offset from the start of the region.
201    /// * `count_bytes` — number of bytes to prefetch.
202    ///
203    /// # Errors
204    ///
205    /// Returns [`CudaError::InvalidValue`] if the range
206    /// `[offset_bytes, offset_bytes + count_bytes)` exceeds the buffer, or
207    /// if `count_bytes` is zero.
208    pub fn prefetch_range(
209        &self,
210        offset_bytes: usize,
211        count_bytes: usize,
212        device: &Device,
213        stream: &Stream,
214    ) -> CudaResult<()> {
215        if count_bytes == 0 {
216            return Err(CudaError::InvalidValue);
217        }
218        let end = offset_bytes
219            .checked_add(count_bytes)
220            .ok_or(CudaError::InvalidValue)?;
221        if end > self.byte_size {
222            return Err(CudaError::InvalidValue);
223        }
224        let range_ptr = self
225            .ptr
226            .checked_add(offset_bytes as u64)
227            .ok_or(CudaError::InvalidValue)?;
228        mem_prefetch(range_ptr, count_bytes, device, stream)
229    }
230
231    // -- Policy convenience -------------------------------------------------
232
233    /// Applies a [`MigrationPolicy`] to this memory region.
234    ///
235    /// For [`MigrationPolicy::Default`] this is a no-op.
236    /// For other variants the corresponding advice hint(s) are issued.
237    pub fn apply_policy(&self, policy: &MigrationPolicy, device: &Device) -> CudaResult<()> {
238        apply_migration_policy(self.ptr, self.byte_size, policy, device)
239    }
240}
241
242// ---------------------------------------------------------------------------
243// PrefetchPlan
244// ---------------------------------------------------------------------------
245
/// One deferred prefetch recorded in a [`PrefetchPlan`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PrefetchEntry {
    /// Start address (device pointer) of the region to prefetch.
    pub ptr: u64,
    /// Number of bytes to prefetch, starting at `ptr`.
    pub byte_size: usize,
    /// Ordinal of the device the region should be moved to.
    pub device_ordinal: i32,
}
256
257/// Batch multiple prefetch operations into a single plan.
258///
259/// Operations are recorded first, then executed together on a single stream
260/// via [`execute`](PrefetchPlan::execute).
261///
262/// # Example
263///
264/// ```rust,no_run
265/// # use oxicuda_memory::managed_hints::PrefetchPlan;
266/// # use oxicuda_driver::stream::Stream;
267/// let mut plan = PrefetchPlan::new();
268/// plan.add(0x1000, 4096, 0)
269///     .add(0x2000, 8192, 0);
270/// assert_eq!(plan.len(), 2);
271/// // plan.execute(&stream)?;
272/// # Ok::<(), oxicuda_driver::error::CudaError>(())
273/// ```
274#[derive(Debug, Clone)]
275pub struct PrefetchPlan {
276    entries: Vec<PrefetchEntry>,
277}
278
279impl PrefetchPlan {
280    /// Creates an empty prefetch plan.
281    pub fn new() -> Self {
282        Self {
283            entries: Vec::new(),
284        }
285    }
286
287    /// Records a prefetch operation.
288    ///
289    /// The actual prefetch is deferred until [`execute`](Self::execute).
290    pub fn add(&mut self, ptr: u64, byte_size: usize, device_ordinal: i32) -> &mut Self {
291        self.entries.push(PrefetchEntry {
292            ptr,
293            byte_size,
294            device_ordinal,
295        });
296        self
297    }
298
299    /// Returns the number of recorded prefetch operations.
300    #[inline]
301    pub fn len(&self) -> usize {
302        self.entries.len()
303    }
304
305    /// Returns `true` if no operations have been recorded.
306    #[inline]
307    pub fn is_empty(&self) -> bool {
308        self.entries.is_empty()
309    }
310
311    /// Returns a slice of all recorded entries.
312    #[inline]
313    pub fn entries(&self) -> &[PrefetchEntry] {
314        &self.entries
315    }
316
317    /// Executes all recorded prefetch operations on `stream`.
318    ///
319    /// Each entry is issued as a separate `mem_prefetch` call targeting the
320    /// device identified by the entry's `device_ordinal`. Operations are
321    /// enqueued in the order they were added.
322    ///
323    /// # Errors
324    ///
325    /// Returns the first error encountered. Entries before the failing one
326    /// will already have been enqueued.
327    pub fn execute(&self, stream: &Stream) -> CudaResult<()> {
328        for entry in &self.entries {
329            let device = Device::get(entry.device_ordinal)?;
330            mem_prefetch(entry.ptr, entry.byte_size, &device, stream)?;
331        }
332        Ok(())
333    }
334}
335
336impl Default for PrefetchPlan {
337    fn default() -> Self {
338        Self::new()
339    }
340}
341
342// ---------------------------------------------------------------------------
343// Free function convenience
344// ---------------------------------------------------------------------------
345
346/// Applies a [`MigrationPolicy`] to a raw unified memory region.
347///
348/// This is a convenience function that translates the high-level policy into
349/// the appropriate [`mem_advise`] calls.
350///
351/// # Parameters
352///
353/// * `ptr` — device pointer to the managed allocation.
354/// * `byte_size` — size of the region in bytes.
355/// * `policy` — the migration policy to apply.
356/// * `device` — the device to which hints are directed.
357///
358/// # Errors
359///
360/// Forwards any error from the underlying driver call.
361/// Returns [`CudaError::InvalidValue`] if `byte_size` is zero (when
362/// policy is not `Default`).
363pub fn apply_migration_policy(
364    ptr: u64,
365    byte_size: usize,
366    policy: &MigrationPolicy,
367    device: &Device,
368) -> CudaResult<()> {
369    match policy {
370        MigrationPolicy::Default => Ok(()),
371        MigrationPolicy::ReadMostly => mem_advise(ptr, byte_size, MemAdvice::SetReadMostly, device),
372        MigrationPolicy::PreferDevice(_ordinal) => {
373            // The advice targets the device passed by the caller. The ordinal
374            // in the policy variant is informational — the caller is expected
375            // to pass the corresponding Device handle.
376            mem_advise(ptr, byte_size, MemAdvice::SetPreferredLocation, device)
377        }
378        MigrationPolicy::PreferHost => {
379            // For host-preferred, we still issue SetPreferredLocation but
380            // directed at the provided device handle. In a real CUDA
381            // environment the caller would pass CU_DEVICE_CPU (-1).
382            mem_advise(ptr, byte_size, MemAdvice::SetPreferredLocation, device)
383        }
384    }
385}
386
387// ---------------------------------------------------------------------------
388// Tests
389// ---------------------------------------------------------------------------
390
#[cfg(test)]
mod tests {
    use super::*;

    // -- MigrationPolicy tests ----------------------------------------------

    #[test]
    fn migration_policy_default_produces_empty_advice() {
        let pairs = MigrationPolicy::Default.to_advice_pairs();
        assert!(pairs.is_empty());
    }

    #[test]
    fn migration_policy_read_mostly_advice() {
        let pairs = MigrationPolicy::ReadMostly.to_advice_pairs();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], MemAdvice::SetReadMostly);
    }

    #[test]
    fn migration_policy_prefer_device_advice() {
        let pairs = MigrationPolicy::PreferDevice(0).to_advice_pairs();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], MemAdvice::SetPreferredLocation);
    }

    #[test]
    fn migration_policy_prefer_host_advice() {
        let pairs = MigrationPolicy::PreferHost.to_advice_pairs();
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], MemAdvice::SetPreferredLocation);
    }

    #[test]
    fn migration_policy_is_default() {
        assert!(MigrationPolicy::Default.is_default());
        assert!(!MigrationPolicy::ReadMostly.is_default());
        assert!(!MigrationPolicy::PreferDevice(0).is_default());
        assert!(!MigrationPolicy::PreferHost.is_default());
    }

    #[test]
    fn migration_policy_display() {
        let s = format!("{}", MigrationPolicy::PreferDevice(2));
        assert!(s.contains("PreferDevice(2)"));

        let s2 = format!("{}", MigrationPolicy::Default);
        assert!(s2.contains("Default"));
    }

    // -- ManagedMemoryHints construction tests ------------------------------

    #[test]
    fn hints_for_buffer_rejects_zero_size() {
        let result = ManagedMemoryHints::for_buffer(0x1000, 0);
        assert!(result.is_err());
    }

    #[test]
    fn hints_for_buffer_valid() {
        // Comparing against `Some(..)` fails loudly if construction errors,
        // without resorting to unwrap/expect.
        let fields = ManagedMemoryHints::for_buffer(0x1000, 4096)
            .ok()
            .map(|h| (h.ptr(), h.byte_size()));
        assert_eq!(fields, Some((0x1000, 4096)));
    }

    #[test]
    fn hints_accessors() {
        // Previously an `if let Ok(..)` silently passed when construction
        // failed; asserting on the `Option` makes that case a test failure.
        let fields = ManagedMemoryHints::for_buffer(0xDEAD, 512)
            .ok()
            .map(|h| (h.ptr(), h.byte_size()));
        assert_eq!(fields, Some((0xDEAD, 512)));
    }

    // -- PrefetchPlan tests -------------------------------------------------

    #[test]
    fn prefetch_plan_new_is_empty() {
        let plan = PrefetchPlan::new();
        assert!(plan.is_empty());
        assert_eq!(plan.len(), 0);
    }

    #[test]
    fn prefetch_plan_default_is_empty() {
        let plan = PrefetchPlan::default();
        assert!(plan.is_empty());
    }

    #[test]
    fn prefetch_plan_add_and_len() {
        let mut plan = PrefetchPlan::new();
        plan.add(0x1000, 4096, 0).add(0x2000, 8192, 1);
        assert_eq!(plan.len(), 2);
        assert!(!plan.is_empty());

        let entries = plan.entries();
        assert_eq!(entries[0].ptr, 0x1000);
        assert_eq!(entries[0].byte_size, 4096);
        assert_eq!(entries[0].device_ordinal, 0);
        assert_eq!(entries[1].ptr, 0x2000);
        assert_eq!(entries[1].byte_size, 8192);
        assert_eq!(entries[1].device_ordinal, 1);
    }

    #[test]
    fn prefetch_plan_chaining() {
        let mut plan = PrefetchPlan::new();
        plan.add(0x100, 100, 0)
            .add(0x200, 200, 0)
            .add(0x300, 300, 0);
        assert_eq!(plan.len(), 3);
    }

    // -- prefetch_range validation tests ------------------------------------

    #[test]
    fn prefetch_range_rejects_zero_count() {
        // A real call needs a `Stream`, which requires a CUDA context, so this
        // test cannot exercise the zero-count path yet. It only checks that the
        // inputs can be built and that the signature is as expected.
        if let Ok(dev) = Device::get(0) {
            let hints = ManagedMemoryHints::for_buffer(0x1000, 4096);
            let _ = (hints, dev);
        }
        // Compile-time signature check
        let _: fn(&ManagedMemoryHints, usize, usize, &Device, &Stream) -> CudaResult<()> =
            ManagedMemoryHints::prefetch_range;
    }

    #[test]
    fn prefetch_range_out_of_bounds_detected() {
        // Verify the bounds checking arithmetic without needing a GPU.
        // NOTE: this replicates the check in `prefetch_range` rather than
        // calling it, so it must be kept in sync manually.
        let byte_size: usize = 4096;
        let offset: usize = 4000;
        let count: usize = 200;
        let end = offset.checked_add(count);
        assert!(end.is_some());
        let end = end.map(|e| e > byte_size);
        // 4000 + 200 = 4200 > 4096
        assert_eq!(end, Some(true));
    }

    // -- apply_migration_policy tests ---------------------------------------

    #[test]
    fn apply_policy_default_is_noop() {
        // Default policy should return Ok without calling the driver.
        // SAFETY: NOTE(review) — sound only if `Device` is a plain handle for
        // which an all-zero bit pattern is a valid value (e.g. a wrapped
        // integer ordinal). The value is never handed to the driver because
        // `Default` returns before any call. TODO confirm against `Device`'s
        // definition.
        let fake_dev: Device = unsafe { std::mem::zeroed() };
        let result = apply_migration_policy(0x1000, 4096, &MigrationPolicy::Default, &fake_dev);
        assert!(result.is_ok());
    }

    // -- Compile-time signature checks for GPU-requiring functions ----------

    #[test]
    fn signature_set_read_mostly() {
        let _: fn(&ManagedMemoryHints, &Device) -> CudaResult<()> =
            ManagedMemoryHints::set_read_mostly;
    }

    #[test]
    fn signature_unset_read_mostly() {
        let _: fn(&ManagedMemoryHints, &Device) -> CudaResult<()> =
            ManagedMemoryHints::unset_read_mostly;
    }

    #[test]
    fn signature_prefetch_to() {
        let _: fn(&ManagedMemoryHints, &Device, &Stream) -> CudaResult<()> =
            ManagedMemoryHints::prefetch_to;
    }

    #[test]
    fn signature_apply_policy() {
        let _: fn(&ManagedMemoryHints, &MigrationPolicy, &Device) -> CudaResult<()> =
            ManagedMemoryHints::apply_policy;
    }

    #[test]
    fn signature_execute_plan() {
        let _: fn(&PrefetchPlan, &Stream) -> CudaResult<()> = PrefetchPlan::execute;
    }
}