Skip to main content

dynamo_memory/
nixl.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NIXL registration wrapper for storage types.
5
6mod agent;
7mod config;
8
9use super::{MemoryDescriptor, StorageKind};
10use std::any::Any;
11use std::fmt;
12use std::sync::Arc;
13
14pub use agent::NixlAgent;
15pub use config::NixlBackendConfig;
16
17pub use nixl_sys::{
18    Agent, MemType, NotificationMap, OptArgs, RegistrationHandle, XferDescList, XferOp, XferRequest,
19};
20pub use serde::{Deserialize, Serialize};
21
22/// Trait for storage types that can be registered with NIXL.
23pub trait NixlCompatible {
24    /// Get parameters needed for NIXL registration.
25    ///
26    /// Returns (ptr, size, mem_type, device_id)
27    fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
28}
29
30/// Combined trait for memory that can be registered with NIXL.
31///
32/// This supertrait enables type erasure via `Arc<dyn NixlMemory>`.
33/// Any type implementing both `MemoryDescriptor` and `NixlCompatible`
34/// automatically implements this trait via the blanket implementation.
35pub trait NixlMemory: MemoryDescriptor + NixlCompatible {}
36
37// Blanket impl - any type with both traits automatically implements NixlMemory
38impl<T: MemoryDescriptor + NixlCompatible + ?Sized> NixlMemory for T {}
39
40/// NIXL descriptor containing registration information.
41///
42/// This struct holds the information needed to describe a memory region
43/// to NIXL for transfer operations.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct NixlDescriptor {
46    /// Base address of the memory region.
47    pub addr: u64,
48    /// Size of the memory region in bytes.
49    pub size: usize,
50    /// Type of memory (host, device, etc.).
51    pub mem_type: MemType,
52    /// Device identifier (GPU index for device memory, 0 for host memory).
53    pub device_id: u64,
54}
55
56impl nixl_sys::MemoryRegion for NixlDescriptor {
57    unsafe fn as_ptr(&self) -> *const u8 {
58        self.addr as *const u8
59    }
60
61    fn size(&self) -> usize {
62        self.size
63    }
64}
65
66impl nixl_sys::NixlDescriptor for NixlDescriptor {
67    fn mem_type(&self) -> MemType {
68        self.mem_type
69    }
70
71    fn device_id(&self) -> u64 {
72        self.device_id
73    }
74}
75
76/// View trait for accessing registration information without unwrapping.
77pub trait RegisteredView {
78    /// Get the name of the NIXL agent that registered this memory.
79    fn agent_name(&self) -> &str;
80
81    /// Get the NIXL descriptor for this registered memory.
82    fn descriptor(&self) -> NixlDescriptor;
83}
84
85/// Wrapper for storage that has been registered with NIXL.
86///
87/// This wrapper ensures proper drop order: the registration handle is
88/// dropped before the storage, ensuring deregistration happens before
89/// the memory is freed.
90pub struct NixlRegistered<S: NixlCompatible> {
91    storage: S,
92    handle: Option<RegistrationHandle>,
93    agent_name: String,
94}
95
96impl<S: NixlCompatible> Drop for NixlRegistered<S> {
97    fn drop(&mut self) {
98        // Explicitly drop the registration handle first
99        drop(self.handle.take());
100        // Storage drops naturally after
101    }
102}
103
104impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
105    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
106        f.debug_struct("NixlRegistered")
107            .field("storage", &self.storage)
108            .field("agent_name", &self.agent_name)
109            .field("handle", &self.handle.is_some())
110            .finish()
111    }
112}
113
114impl<S: MemoryDescriptor + NixlCompatible + 'static> MemoryDescriptor for NixlRegistered<S> {
115    fn addr(&self) -> usize {
116        self.storage.addr()
117    }
118
119    fn size(&self) -> usize {
120        self.storage.size()
121    }
122
123    fn storage_kind(&self) -> StorageKind {
124        self.storage.storage_kind()
125    }
126
127    fn as_any(&self) -> &dyn Any {
128        self
129    }
130
131    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
132        Some(self.descriptor())
133    }
134}
135
136impl<S: MemoryDescriptor + NixlCompatible> RegisteredView for NixlRegistered<S> {
137    fn agent_name(&self) -> &str {
138        &self.agent_name
139    }
140
141    fn descriptor(&self) -> NixlDescriptor {
142        let (ptr, size, mem_type, device_id) = self.storage.nixl_params();
143        NixlDescriptor {
144            addr: ptr as u64,
145            size,
146            mem_type,
147            device_id,
148        }
149    }
150}
151
152impl<S: MemoryDescriptor + NixlCompatible> NixlRegistered<S> {
153    /// Get a reference to the underlying storage.
154    pub fn storage(&self) -> &S {
155        &self.storage
156    }
157
158    /// Get a mutable reference to the underlying storage.
159    pub fn storage_mut(&mut self) -> &mut S {
160        &mut self.storage
161    }
162
163    /// Check if the registration handle is still valid.
164    pub fn is_registered(&self) -> bool {
165        self.handle.is_some()
166    }
167
168    /// Consume this wrapper and return the underlying storage.
169    ///
170    /// This will deregister the storage from NIXL.
171    pub fn into_storage(mut self) -> S {
172        drop(self.handle.take());
173        let mut this = std::mem::ManuallyDrop::new(self);
174        unsafe {
175            let storage = std::ptr::read(&this.storage);
176            std::ptr::drop_in_place(&mut this.agent_name);
177            storage
178        }
179    }
180}
181
182/// Register storage with a NIXL agent.
183///
184/// This consumes the storage and returns a `NixlRegistered` wrapper that
185/// manages the registration lifetime. The registration handle will be
186/// automatically dropped when the wrapper is dropped, ensuring proper
187/// cleanup order.
188///
189/// # Arguments
190/// * `storage` - The storage to register (consumed)
191/// * `agent` - The NIXL agent to register with
192/// * `opt` - Optional arguments for registration
193///
194/// # Returns
195/// A `NixlRegistered` wrapper containing the storage and registration handle.
196pub fn register_with_nixl<S>(
197    storage: S,
198    agent: &Agent,
199    opt: Option<&OptArgs>,
200) -> std::result::Result<NixlRegistered<S>, S>
201where
202    S: MemoryDescriptor + NixlCompatible,
203{
204    // let storage_kind = storage.storage_kind();
205
206    // // Determine if registration is needed based on storage type and available backends
207    // let should_register = match storage_kind {
208    //     StorageKind::System | StorageKind::Pinned => {
209    //         // System/Pinned memory needs UCX for remote transfers
210    //         agent.has_backend("UCX") || agent.has_backend("POSIX")
211    //     }
212    //     StorageKind::Device(_) => {
213    //         // Device memory needs UCX for remote transfers OR GDS for direct disk transfers
214    //         agent.has_backend("UCX") || agent.has_backend("GDS_MT")
215    //     }
216    //     StorageKind::Disk(_) => {
217    //         // Disk storage needs POSIX for regular I/O OR GDS for GPU direct I/O
218    //         agent.has_backend("POSIX") || agent.has_backend("GDS_MT")
219    //     } // StorageKind::Object(_) => {
220    //       //     // Object storage is always registered via NIXL's OBJ plugin
221    //       //     agent.has_backend("OBJ")
222    //       // }
223    // };
224
225    // this is not true for our future object storage. so let's rethink this.
226    // for object, if there is no device_id or device_id is 0, then we need to register
227    // alternatively, the object storage holds it's own internal metadata but does not
228    // expose as a nixl descriptor, thus ObjectStorag will by default like all other storage
229    // types have a None for nixl_descriptor(), and we will use the internal
230    if storage.nixl_descriptor().is_some() {
231        return Ok(NixlRegistered {
232            storage,
233            handle: None,
234            agent_name: agent.name().to_string(),
235        });
236    }
237
238    // Get NIXL parameters
239    let (ptr, size, mem_type, device_id) = storage.nixl_params();
240
241    // Create a NIXL descriptor for registration
242    let descriptor = NixlDescriptor {
243        addr: ptr as u64,
244        size,
245        mem_type,
246        device_id,
247    };
248
249    match agent.register_memory(&descriptor, opt) {
250        Ok(handle) => Ok(NixlRegistered {
251            storage,
252            handle: Some(handle),
253            agent_name: agent.name().to_string(),
254        }),
255        Err(_) => Err(storage),
256    }
257}
258
259// =============================================================================
260// Arc<dyn NixlMemory> support
261// =============================================================================
262
263impl NixlCompatible for Arc<dyn NixlMemory + Send + Sync> {
264    fn nixl_params(&self) -> (*const u8, usize, MemType, u64) {
265        (**self).nixl_params()
266    }
267}
268
269impl MemoryDescriptor for Arc<dyn NixlMemory + Send + Sync> {
270    fn addr(&self) -> usize {
271        (**self).addr()
272    }
273
274    fn size(&self) -> usize {
275        (**self).size()
276    }
277
278    fn storage_kind(&self) -> StorageKind {
279        (**self).storage_kind()
280    }
281
282    fn as_any(&self) -> &dyn Any {
283        (**self).as_any()
284    }
285
286    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
287        (**self).nixl_descriptor()
288    }
289}
290
291// =============================================================================
292// Extension trait for ergonomic API
293// =============================================================================
294
295/// Extension trait providing ergonomic `.register()` method for NIXL registration.
296///
297/// This trait is automatically implemented for all types that implement both
298/// `MemoryDescriptor` and `NixlCompatible`. Import this trait to use the
299/// method syntax:
300///
301///
302pub trait NixlRegisterExt: MemoryDescriptor + NixlCompatible + Sized {
303    /// Get this memory as NIXL-registered.
304    ///
305    /// This operation is idempotent - it's a no-op if the memory is already registered.
306    ///
307    /// # Arguments
308    /// * `agent` - The NIXL agent to register with
309    /// * `opt` - Optional arguments for registration
310    ///
311    /// # Returns
312    /// A `NixlRegistered` wrapper on success, or the original storage on failure.
313    fn register(
314        self,
315        agent: &NixlAgent,
316        opt: Option<&OptArgs>,
317    ) -> std::result::Result<NixlRegistered<Self>, Self> {
318        register_with_nixl(self, agent, opt)
319    }
320}
321
322// Blanket impl for all compatible types
323impl<T: MemoryDescriptor + NixlCompatible + Sized> NixlRegisterExt for T {}