Skip to main content

dynamo_memory/nixl/
agent.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NIXL agent wrapper and configuration.
5//!
6//! This module provides:
7//! - `NixlAgent`: Wrapper around nixl_sys::Agent that tracks initialized backends
8//! - `NixlBackendConfig`: Configuration for NIXL backends from environment variables
9
10use anyhow::Result;
11use nixl_sys::Agent;
12use std::collections::{HashMap, HashSet};
13
14use crate::nixl::NixlBackendConfig;
15
16/// A NIXL agent wrapper that tracks which backends were successfully initialized.
17///
18/// This wrapper provides:
19/// - Runtime validation of backend availability
20/// - Clear error messages when operations need unavailable backends
21/// - Single source of truth for backend state in tests and production
22///
23/// # Backend Tracking
24///
25/// Since `nixl_sys::Agent` doesn't provide a method to query active backends,
26/// we track them during initialization. The `available_backends` set is populated
27/// based on successful `create_backend()` calls.
28#[derive(Clone, Debug)]
29pub struct NixlAgent {
30    agent: Agent,
31    available_backends: HashSet<String>,
32}
33
34impl NixlAgent {
35    /// Create a NIXL agent without any backends.
36    pub fn new(name: &str) -> Result<Self> {
37        let agent = Agent::new(name)?;
38
39        Ok(Self {
40            agent,
41            available_backends: HashSet::new(),
42        })
43    }
44
45    /// Creates a new agent configured with backends from the given config.
46    ///
47    /// This method iterates over all backends in the config and initializes them
48    /// with their associated parameters. If a backend has custom parameters defined
49    /// in the config, those are used; otherwise, default plugin parameters are used.
50    pub fn from_nixl_backend_config(name: &str, config: NixlBackendConfig) -> Result<Self> {
51        let mut agent = Self::new(name)?;
52        for (backend, params) in config.iter() {
53            agent.add_backend_with_params(backend, params)?;
54        }
55        Ok(agent)
56    }
57
58    /// Add a backend to the agent with default parameters.
59    pub fn add_backend(&mut self, backend: &str) -> Result<()> {
60        self.add_backend_with_params(backend, &HashMap::new())
61    }
62
63    /// Add a backend to the agent with optional custom parameters.
64    ///
65    /// If `custom_params` is non-empty, those parameters are used instead of
66    /// the plugin defaults. If empty, default parameters from the plugin are used.
67    ///
68    /// # Errors
69    /// Returns an error if custom parameters are provided (not yet supported until nixl_sys 0.9).
70    pub fn add_backend_with_params(
71        &mut self,
72        backend: &str,
73        custom_params: &HashMap<String, String>,
74    ) -> Result<()> {
75        let backend_upper = backend.to_uppercase();
76        if self.available_backends.contains(&backend_upper) {
77            return Ok(());
78        }
79
80        // TODO(DIS-1310): Custom params require nixl_sys 0.9+ which adds nixl_capi_params_add
81        if !custom_params.is_empty() {
82            anyhow::bail!(
83                "Custom NIXL backend parameters for {} are not yet supported. \
84                 This feature requires nixl_sys 0.9+. Params provided: {:?}",
85                backend_upper,
86                custom_params.keys().collect::<Vec<_>>()
87            );
88        }
89
90        // Get default params from plugin
91        let (_, params) = match self.agent.get_plugin_params(&backend_upper) {
92            Ok(result) => result,
93            Err(_) => anyhow::bail!("No {} plugin found", backend_upper),
94        };
95
96        match self.agent.create_backend(&backend_upper, &params) {
97            Ok(_) => {
98                self.available_backends.insert(backend_upper);
99                Ok(())
100            }
101            Err(e) => anyhow::bail!("Failed to create nixl backend: {}", e),
102        }
103    }
104
105    /// Create a NIXL agent requiring ALL specified backends to be available.
106    ///
107    /// Unlike `new_with_backends()` which continues if some backends fail, this method
108    /// will return an error if ANY backend fails to initialize. Use this in production
109    /// when specific backends are mandatory.
110    ///
111    /// # Arguments
112    /// * `name` - Agent name
113    /// * `backends` - List of backend names that MUST be available
114    ///
115    /// # Returns
116    /// A `NixlAgent` with all requested backends initialized.
117    ///
118    /// # Errors
119    /// Returns an error if:
120    /// - Agent creation fails
121    /// - Any backend fails to initialize
122    pub fn with_backends(name: &str, backends: &[&str]) -> Result<Self> {
123        let mut agent = Self::new(name)?;
124        let mut failed_backends = Vec::new();
125
126        for backend in backends {
127            let backend_upper = backend.to_uppercase();
128            match agent.add_backend(&backend_upper) {
129                Ok(_) => {
130                    tracing::debug!("Initialized NIXL backend: {}", backend_upper);
131                }
132                Err(e) => {
133                    tracing::error!("Failed to initialize {} backend: {}", backend_upper, e);
134                    failed_backends.push((backend_upper, e.to_string()));
135                }
136            }
137        }
138
139        if !failed_backends.is_empty() {
140            let error_details: Vec<String> = failed_backends
141                .iter()
142                .map(|(name, reason)| format!("{}: {}", name, reason))
143                .collect();
144
145            anyhow::bail!(
146                "Failed to initialize required backends: [{}]",
147                error_details.join(", ")
148            );
149        }
150
151        Ok(agent)
152    }
153
154    /// Get a reference to the underlying raw NIXL agent.
155    pub fn raw_agent(&self) -> &Agent {
156        &self.agent
157    }
158
159    /// Consume and return the underlying raw NIXL agent.
160    ///
161    /// **Warning**: Once consumed, backend tracking is lost. Use this only when
162    /// interfacing with code that requires `nixl_sys::Agent` directly.
163    pub fn into_raw_agent(self) -> Agent {
164        self.agent
165    }
166
167    /// Check if a specific backend is available.
168    pub fn has_backend(&self, backend: &str) -> bool {
169        self.available_backends.contains(&backend.to_uppercase())
170    }
171
172    /// Get all available backends.
173    pub fn backends(&self) -> &HashSet<String> {
174        &self.available_backends
175    }
176
177    /// Require a specific backend, returning an error if unavailable.
178    ///
179    /// Use this at the start of operations that need specific backends.
180    ///
181    /// Note: In general, you want to instantiate all your backends before you start registering memory.
182    /// We may change this to a builder pattern in the future to enforce all backends are instantiated
183    /// before you start registering memory.
184    pub fn require_backend(&self, backend: &str) -> Result<()> {
185        let backend_upper = backend.to_uppercase();
186        if self.has_backend(&backend_upper) {
187            Ok(())
188        } else {
189            anyhow::bail!(
190                "Operation requires {} backend, but it was not initialized. Available backends: {:?}",
191                backend_upper,
192                self.available_backends
193            )
194        }
195    }
196}
197
198// Delegate common methods to the underlying agent
199impl std::ops::Deref for NixlAgent {
200    type Target = Agent;
201
202    fn deref(&self) -> &Self::Target {
203        &self.agent
204    }
205}
206
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;

    /// A backend registered via `with_backends` is reported regardless of case.
    #[test]
    fn test_agent_backend_tracking() {
        let ucx_agent = NixlAgent::with_backends("test", &["UCX"]).expect("Need UCX for test");

        // Exact spelling.
        assert!(ucx_agent.has_backend("UCX"));
        // Lookup is case-insensitive.
        assert!(ucx_agent.has_backend("ucx"));
    }

    /// `require_backend` accepts initialized backends and rejects the rest.
    #[test]
    fn test_require_backend() {
        let ucx_agent = NixlAgent::with_backends("test", &["UCX"]).expect("Need UCX for test");

        // UCX was initialized above.
        assert!(ucx_agent.require_backend("UCX").is_ok());

        // GDS_MT was never initialized on this agent.
        assert!(ucx_agent.require_backend("GDS_MT").is_err());
    }

    /// `with_backends` is all-or-nothing.
    #[test]
    fn test_require_backends_strict() {
        // Only UCX requested: expected to succeed.
        let ucx_only =
            NixlAgent::with_backends("test_strict", &["UCX"]).expect("Failed to require backends");
        assert!(ucx_only.has_backend("UCX"));

        // "DUDE" is not a real backend, so the whole call must fail.
        let with_bogus = NixlAgent::with_backends("test_strict_fail", &["UCX", "DUDE"]);
        assert!(with_bogus.is_err());
    }

    /// An empty parameter map falls back to the plugin defaults.
    #[test]
    fn test_add_backend_with_empty_params() {
        let mut agent = NixlAgent::new("test_empty_params").expect("Failed to create agent");

        let no_params = HashMap::new();
        let outcome = agent.add_backend_with_params("UCX", &no_params);

        assert!(outcome.is_ok());
        assert!(agent.has_backend("UCX"));
    }

    /// Non-empty custom parameters are rejected until nixl_sys 0.9.
    #[test]
    fn test_add_backend_with_custom_params_fails() {
        let mut agent = NixlAgent::new("test_custom_params").expect("Failed to create agent");

        let mut custom = HashMap::new();
        custom.insert("some_key".to_string(), "some_value".to_string());

        let outcome = agent.add_backend_with_params("UCX", &custom);
        assert!(outcome.is_err());

        // The message explains the limitation and names the offending key.
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not yet supported"));
        assert!(message.contains("nixl_sys 0.9"));
        assert!(message.contains("some_key"));
    }

    /// A config carrying custom parameters fails for the same reason.
    #[test]
    fn test_from_nixl_backend_config_with_custom_params_fails() {
        let mut custom = HashMap::new();
        custom.insert("threads".to_string(), "4".to_string());

        let config = NixlBackendConfig::default().with_backend_params("UCX", custom);

        let outcome = NixlAgent::from_nixl_backend_config("test_config_params", config);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("not yet supported"));
        assert!(message.contains("threads"));
    }

    /// A config listing a backend without parameters initializes it.
    #[test]
    fn test_from_nixl_backend_config_with_empty_params() {
        let config = NixlBackendConfig::default().with_backend("UCX");

        let outcome = NixlAgent::from_nixl_backend_config("test_config_empty", config);
        assert!(outcome.is_ok());

        let agent = outcome.unwrap();
        assert!(agent.has_backend("UCX"));
    }
}