Skip to main content

dynamo_memory/
nixl.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NIXL registration wrapper for storage types.
5
6mod agent;
7mod config;
8
9use super::{MemoryDescription, StorageKind};
10use std::any::Any;
11use std::fmt;
12
13pub use agent::NixlAgent;
14pub use config::NixlBackendConfig;
15
16pub use nixl_sys::{MemType, OptArgs, RegistrationHandle};
17pub use serde::{Deserialize, Serialize};
18
19/// Trait for storage types that can be registered with NIXL.
20pub trait NixlCompatible {
21    /// Get parameters needed for NIXL registration.
22    ///
23    /// Returns (ptr, size, mem_type, device_id)
24    fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
25}
26
27/// NIXL descriptor containing registration information.
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct NixlDescriptor {
30    pub addr: u64,
31    pub size: usize,
32    pub mem_type: MemType,
33    pub device_id: u64,
34}
35
36impl nixl_sys::MemoryRegion for NixlDescriptor {
37    unsafe fn as_ptr(&self) -> *const u8 {
38        self.addr as *const u8
39    }
40
41    fn size(&self) -> usize {
42        self.size
43    }
44}
45
46impl nixl_sys::NixlDescriptor for NixlDescriptor {
47    fn mem_type(&self) -> MemType {
48        self.mem_type
49    }
50
51    fn device_id(&self) -> u64 {
52        self.device_id
53    }
54}
55
56/// View trait for accessing registration information without unwrapping.
57pub trait RegisteredView {
58    /// Get the name of the NIXL agent that registered this memory.
59    fn agent_name(&self) -> &str;
60
61    /// Get the NIXL descriptor for this registered memory.
62    fn descriptor(&self) -> NixlDescriptor;
63}
64
65/// Wrapper for storage that has been registered with NIXL.
66///
67/// This wrapper ensures proper drop order: the registration handle is
68/// dropped before the storage, ensuring deregistration happens before
69/// the memory is freed.
70pub struct NixlRegistered<S: NixlCompatible> {
71    storage: S,
72    handle: Option<RegistrationHandle>,
73    agent_name: String,
74}
75
76impl<S: NixlCompatible> Drop for NixlRegistered<S> {
77    fn drop(&mut self) {
78        // Explicitly drop the registration handle first
79        drop(self.handle.take());
80        // Storage drops naturally after
81    }
82}
83
84impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        f.debug_struct("NixlRegistered")
87            .field("storage", &self.storage)
88            .field("agent_name", &self.agent_name)
89            .field("handle", &self.handle.is_some())
90            .finish()
91    }
92}
93
94impl<S: MemoryDescription + NixlCompatible + 'static> MemoryDescription for NixlRegistered<S> {
95    fn addr(&self) -> usize {
96        self.storage.addr()
97    }
98
99    fn size(&self) -> usize {
100        self.storage.size()
101    }
102
103    fn storage_kind(&self) -> StorageKind {
104        self.storage.storage_kind()
105    }
106
107    fn as_any(&self) -> &dyn Any {
108        self
109    }
110
111    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
112        Some(self.descriptor())
113    }
114}
115
116impl<S: MemoryDescription + NixlCompatible> RegisteredView for NixlRegistered<S> {
117    fn agent_name(&self) -> &str {
118        &self.agent_name
119    }
120
121    fn descriptor(&self) -> NixlDescriptor {
122        let (ptr, size, mem_type, device_id) = self.storage.nixl_params();
123        NixlDescriptor {
124            addr: ptr as u64,
125            size,
126            mem_type,
127            device_id,
128        }
129    }
130}
131
132impl<S: MemoryDescription + NixlCompatible> NixlRegistered<S> {
133    /// Get a reference to the underlying storage.
134    pub fn storage(&self) -> &S {
135        &self.storage
136    }
137
138    /// Get a mutable reference to the underlying storage.
139    pub fn storage_mut(&mut self) -> &mut S {
140        &mut self.storage
141    }
142
143    /// Check if the registration handle is still valid.
144    pub fn is_registered(&self) -> bool {
145        self.handle.is_some()
146    }
147
148    /// Consume this wrapper and return the underlying storage.
149    ///
150    /// This will deregister the storage from NIXL.
151    pub fn into_storage(mut self) -> S {
152        drop(self.handle.take());
153        let mut this = std::mem::ManuallyDrop::new(self);
154        unsafe {
155            let storage = std::ptr::read(&this.storage);
156            std::ptr::drop_in_place(&mut this.agent_name);
157            storage
158        }
159    }
160}
161
162/// Register storage with a NIXL agent.
163///
164/// This consumes the storage and returns a `NixlRegistered` wrapper that
165/// manages the registration lifetime. The registration handle will be
166/// automatically dropped when the wrapper is dropped, ensuring proper
167/// cleanup order.
168///
169/// # Arguments
170/// * `storage` - The storage to register (consumed)
171/// * `agent` - The NIXL agent to register with
172/// * `opt` - Optional arguments for registration
173///
174/// # Returns
175/// A `NixlRegistered` wrapper containing the storage and registration handle.
176pub fn register_with_nixl<S>(
177    storage: S,
178    agent: &NixlAgent,
179    opt: Option<&OptArgs>,
180) -> std::result::Result<NixlRegistered<S>, S>
181where
182    S: MemoryDescription + NixlCompatible,
183{
184    // Get NIXL parameters
185    let (ptr, size, mem_type, device_id) = storage.nixl_params();
186
187    // Create a NIXL descriptor for registration
188    let descriptor = NixlDescriptor {
189        addr: ptr as u64,
190        size,
191        mem_type,
192        device_id,
193    };
194
195    match agent.register_memory(&descriptor, opt) {
196        Ok(handle) => Ok(NixlRegistered {
197            storage,
198            handle: Some(handle),
199            agent_name: agent.name().to_string(),
200        }),
201        Err(_) => Err(storage),
202    }
203}