dynamo_memory/nixl/
config.rs1use anyhow::{Result, bail};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13use dynamo_config::parse_bool;
14
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
41pub struct NixlBackendConfig {
42 #[serde(default)]
48 backends: HashMap<String, HashMap<String, String>>,
49}
50
51impl NixlBackendConfig {
52 pub fn new(backends: HashMap<String, HashMap<String, String>>) -> Self {
56 Self { backends }
57 }
58
59 pub fn from_env() -> Result<Self> {
68 let mut backends = HashMap::new();
69
70 for (key, value) in std::env::vars() {
72 if let Some(remainder) = key.strip_prefix("DYN_KVBM_NIXL_BACKEND_") {
73 if remainder.contains('_') {
75 bail!(
76 "Custom NIXL backend parameters are not yet supported. \
77 Found: {}. Please use only DYN_KVBM_NIXL_BACKEND_<backend>=true \
78 to enable backends with default parameters.",
79 key
80 );
81 }
82
83 let backend_name = remainder.to_uppercase();
85 match parse_bool(&value) {
86 Ok(true) => {
87 backends.insert(backend_name, HashMap::new());
88 }
89 Ok(false) => {
90 continue;
92 }
93 Err(e) => bail!("Invalid value for {}: {}", key, e),
94 }
95 }
96 }
97
98 Ok(Self { backends })
99 }
100
101 pub fn with_backend(mut self, backend: impl Into<String>) -> Self {
104 self.backends
105 .insert(backend.into().to_uppercase(), HashMap::new());
106 self
107 }
108
109 pub fn with_backend_params(
112 mut self,
113 backend: impl Into<String>,
114 params: HashMap<String, String>,
115 ) -> Self {
116 self.backends.insert(backend.into().to_uppercase(), params);
117 self
118 }
119
120 pub fn backends(&self) -> Vec<String> {
122 self.backends.keys().cloned().collect()
123 }
124
125 pub fn backend_params(&self, backend: &str) -> Option<&HashMap<String, String>> {
130 self.backends.get(&backend.to_uppercase())
131 }
132
133 pub fn has_backend(&self, backend: &str) -> bool {
135 self.backends.contains_key(&backend.to_uppercase())
136 }
137
138 pub fn merge(mut self, other: NixlBackendConfig) -> Self {
143 self.backends.extend(other.backends);
144 self
145 }
146
147 pub fn iter(&self) -> impl Iterator<Item = (&String, &HashMap<String, String>)> {
149 self.backends.iter()
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_config_is_empty() {
        // Exercise the explicit constructor, not `Default`.
        let config = NixlBackendConfig::new(HashMap::new());
        assert!(config.backends().is_empty());
        assert_eq!(config.iter().count(), 0);
    }

    #[test]
    fn test_default_is_empty() {
        let config = NixlBackendConfig::default();
        assert!(config.backends().is_empty());
    }

    #[test]
    fn test_with_backend() {
        let config = NixlBackendConfig::default()
            .with_backend("ucx")
            .with_backend("gds_mt");

        assert!(config.has_backend("ucx"));
        assert!(config.has_backend("UCX"));
        assert!(config.has_backend("gds_mt"));
        assert!(config.has_backend("GDS_MT"));
        assert!(!config.has_backend("other"));
    }

    #[test]
    fn test_with_backend_params() {
        let mut params = HashMap::new();
        params.insert("threads".to_string(), "4".to_string());
        params.insert("buffer_size".to_string(), "1048576".to_string());

        let config = NixlBackendConfig::default()
            .with_backend("UCX")
            .with_backend_params("GDS", params);

        let ucx_params = config.backend_params("UCX").expect("UCX was enabled");
        assert!(ucx_params.is_empty());

        let gds_params = config.backend_params("GDS").expect("GDS was enabled");
        assert_eq!(gds_params.get("threads"), Some(&"4".to_string()));
        assert_eq!(gds_params.get("buffer_size"), Some(&"1048576".to_string()));
    }

    #[test]
    fn test_backend_params_missing() {
        let config = NixlBackendConfig::default().with_backend("UCX");
        assert!(config.backend_params("GDS").is_none());
    }

    #[test]
    fn test_merge_configs() {
        let config1 = NixlBackendConfig::default().with_backend("ucx");
        let config2 = NixlBackendConfig::default().with_backend("gds");

        let merged = config1.merge(config2);

        assert!(merged.has_backend("ucx"));
        assert!(merged.has_backend("gds"));
    }

    #[test]
    fn test_merge_other_takes_precedence() {
        let mut params = HashMap::new();
        params.insert("threads".to_string(), "4".to_string());

        let base = NixlBackendConfig::default().with_backend_params("UCX", params);
        let overlay = NixlBackendConfig::default().with_backend("ucx");

        let merged = base.merge(overlay);
        let ucx_params = merged.backend_params("UCX").expect("UCX was enabled");
        assert!(ucx_params.is_empty());
    }

    #[test]
    fn test_backend_name_case_insensitive() {
        let config = NixlBackendConfig::default()
            .with_backend("ucx")
            .with_backend("Gds_mt")
            .with_backend("OTHER");

        assert!(config.has_backend("UCX"));
        assert!(config.has_backend("ucx"));
        assert!(config.has_backend("GDS_MT"));
        assert!(config.has_backend("gds_mt"));
        assert!(config.has_backend("OTHER"));
        assert!(config.has_backend("other"));
    }

    #[test]
    fn test_backends_lists_all_names() {
        let config = NixlBackendConfig::default()
            .with_backend("ucx")
            .with_backend("gds");

        // Sort locally so the assertion does not depend on the listing order.
        let mut names = config.backends();
        names.sort();
        assert_eq!(names, vec!["GDS".to_string(), "UCX".to_string()]);
    }

    #[test]
    fn test_iter() {
        let mut params = HashMap::new();
        params.insert("key".to_string(), "value".to_string());

        let config = NixlBackendConfig::default()
            .with_backend("UCX")
            .with_backend_params("GDS", params);

        let items: Vec<_> = config.iter().collect();
        assert_eq!(items.len(), 2);
    }
}