dynamo_memory/nixl/
agent.rs1use anyhow::Result;
11use nixl_sys::Agent;
12use std::collections::{HashMap, HashSet};
13
14use crate::nixl::NixlBackendConfig;
15
16#[derive(Clone, Debug)]
29pub struct NixlAgent {
30 agent: Agent,
31 available_backends: HashSet<String>,
32}
33
34impl NixlAgent {
35 pub fn new(name: &str) -> Result<Self> {
37 let agent = Agent::new(name)?;
38
39 Ok(Self {
40 agent,
41 available_backends: HashSet::new(),
42 })
43 }
44
45 pub fn from_nixl_backend_config(name: &str, config: NixlBackendConfig) -> Result<Self> {
51 let mut agent = Self::new(name)?;
52 for (backend, params) in config.iter() {
53 agent.add_backend_with_params(backend, params)?;
54 }
55 Ok(agent)
56 }
57
58 pub fn add_backend(&mut self, backend: &str) -> Result<()> {
60 self.add_backend_with_params(backend, &HashMap::new())
61 }
62
63 pub fn add_backend_with_params(
71 &mut self,
72 backend: &str,
73 custom_params: &HashMap<String, String>,
74 ) -> Result<()> {
75 let backend_upper = backend.to_uppercase();
76 if self.available_backends.contains(&backend_upper) {
77 return Ok(());
78 }
79
80 if !custom_params.is_empty() {
82 anyhow::bail!(
83 "Custom NIXL backend parameters for {} are not yet supported. \
84 This feature requires nixl_sys 0.9+. Params provided: {:?}",
85 backend_upper,
86 custom_params.keys().collect::<Vec<_>>()
87 );
88 }
89
90 let (_, params) = match self.agent.get_plugin_params(&backend_upper) {
92 Ok(result) => result,
93 Err(_) => anyhow::bail!("No {} plugin found", backend_upper),
94 };
95
96 match self.agent.create_backend(&backend_upper, ¶ms) {
97 Ok(_) => {
98 self.available_backends.insert(backend_upper);
99 Ok(())
100 }
101 Err(e) => anyhow::bail!("Failed to create nixl backend: {}", e),
102 }
103 }
104
105 pub fn with_backends(name: &str, backends: &[&str]) -> Result<Self> {
123 let mut agent = Self::new(name)?;
124 let mut failed_backends = Vec::new();
125
126 for backend in backends {
127 let backend_upper = backend.to_uppercase();
128 match agent.add_backend(&backend_upper) {
129 Ok(_) => {
130 tracing::debug!("Initialized NIXL backend: {}", backend_upper);
131 }
132 Err(e) => {
133 tracing::error!("Failed to initialize {} backend: {}", backend_upper, e);
134 failed_backends.push((backend_upper, e.to_string()));
135 }
136 }
137 }
138
139 if !failed_backends.is_empty() {
140 let error_details: Vec<String> = failed_backends
141 .iter()
142 .map(|(name, reason)| format!("{}: {}", name, reason))
143 .collect();
144
145 anyhow::bail!(
146 "Failed to initialize required backends: [{}]",
147 error_details.join(", ")
148 );
149 }
150
151 Ok(agent)
152 }
153
154 pub fn raw_agent(&self) -> &Agent {
156 &self.agent
157 }
158
159 pub fn into_raw_agent(self) -> Agent {
164 self.agent
165 }
166
167 pub fn has_backend(&self, backend: &str) -> bool {
169 self.available_backends.contains(&backend.to_uppercase())
170 }
171
172 pub fn backends(&self) -> &HashSet<String> {
174 &self.available_backends
175 }
176
177 pub fn require_backend(&self, backend: &str) -> Result<()> {
185 let backend_upper = backend.to_uppercase();
186 if self.has_backend(&backend_upper) {
187 Ok(())
188 } else {
189 anyhow::bail!(
190 "Operation requires {} backend, but it was not initialized. Available backends: {:?}",
191 backend_upper,
192 self.available_backends
193 )
194 }
195 }
196}
197
198impl std::ops::Deref for NixlAgent {
200 type Target = Agent;
201
202 fn deref(&self) -> &Self::Target {
203 &self.agent
204 }
205}
206
#[cfg(all(test, feature = "testing-nixl"))]
mod tests {
    use super::*;

    #[test]
    fn test_agent_backend_tracking() {
        let agent = NixlAgent::with_backends("test", &["UCX"]).expect("Need UCX for test");

        assert!(agent.has_backend("UCX"));
        // Lookups are case-insensitive.
        assert!(agent.has_backend("ucx"));
        // Names are tracked in uppercase.
        assert!(agent.backends().contains("UCX"));
    }

    #[test]
    fn test_add_backend_is_idempotent() {
        let mut agent =
            NixlAgent::with_backends("test_idempotent", &["UCX"]).expect("Need UCX for test");

        // Re-adding an initialized backend (in any case) succeeds without duplicating it.
        assert!(agent.add_backend("ucx").is_ok());
        assert_eq!(agent.backends().len(), 1);
    }

    #[test]
    fn test_require_backend() {
        let agent = NixlAgent::with_backends("test", &["UCX"]).expect("Need UCX for test");

        assert!(agent.require_backend("UCX").is_ok());

        let err = agent
            .require_backend("gds_mt")
            .expect_err("GDS_MT was never initialized");
        // The error names the requested backend in uppercase.
        assert!(err.to_string().contains("GDS_MT"));
    }

    #[test]
    fn test_require_backends_strict() {
        let agent =
            NixlAgent::with_backends("test_strict", &["UCX"]).expect("Failed to require backends");
        assert!(agent.has_backend("UCX"));

        let err = NixlAgent::with_backends("test_strict_fail", &["UCX", "DUDE"])
            .expect_err("an unknown backend must fail construction");
        // The aggregated error names the backend that failed.
        assert!(err.to_string().contains("DUDE"));
    }

    #[test]
    fn test_add_backend_with_empty_params() {
        let mut agent = NixlAgent::new("test_empty_params").expect("Failed to create agent");

        let result = agent.add_backend_with_params("UCX", &HashMap::new());
        assert!(result.is_ok());
        assert!(agent.has_backend("UCX"));
    }

    #[test]
    fn test_add_backend_with_custom_params_fails() {
        let mut agent = NixlAgent::new("test_custom_params").expect("Failed to create agent");

        let mut params = HashMap::new();
        params.insert("some_key".to_string(), "some_value".to_string());

        let result = agent.add_backend_with_params("UCX", &params);
        assert!(result.is_err());

        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("not yet supported"));
        assert!(err_msg.contains("nixl_sys 0.9"));
        assert!(err_msg.contains("some_key"));
        // A rejected call must not mark the backend as initialized.
        assert!(!agent.has_backend("UCX"));
    }

    #[test]
    fn test_from_nixl_backend_config_with_custom_params_fails() {
        let mut params = HashMap::new();
        params.insert("threads".to_string(), "4".to_string());

        let config = NixlBackendConfig::default().with_backend_params("UCX", params);

        let result = NixlAgent::from_nixl_backend_config("test_config_params", config);
        assert!(result.is_err());

        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("not yet supported"));
        assert!(err_msg.contains("threads"));
    }

    #[test]
    fn test_from_nixl_backend_config_with_empty_params() {
        let config = NixlBackendConfig::default().with_backend("UCX");

        let result = NixlAgent::from_nixl_backend_config("test_config_empty", config);
        assert!(result.is_ok());

        let agent = result.unwrap();
        assert!(agent.has_backend("UCX"));
    }
}