1use std::env;
17
/// Settings for running a job under NVIDIA MPS (Multi-Process Service).
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly
/// (for example in tests or when detecting a changed setting).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MpsConfig {
    /// Share of GPU threads for this client, exported as
    /// `CUDA_MPS_ACTIVE_THREAD_PERCENTAGE`. Expected to be in `1..=100`.
    pub thread_percentage: u32,
    /// Optional pinned device memory limit in megabytes, exported as
    /// `CUDA_MPS_PINNED_DEVICE_MEM_LIMIT` for device 0. `None` means no limit.
    pub pinned_mem_limit_mb: Option<u64>,
    /// Number of steps between checkpoints; reported in the MPS warning
    /// banner as the "blast radius limit".
    pub checkpoint_every_steps: usize,
}
28
29impl Default for MpsConfig {
30 fn default() -> Self {
31 Self { thread_percentage: 50, pinned_mem_limit_mb: None, checkpoint_every_steps: 100 }
32 }
33}
34
35impl MpsConfig {
36 #[must_use]
41 pub fn with_share(thread_pct: u32) -> Self {
42 assert!(thread_pct > 0 && thread_pct <= 100, "thread_pct must be 1-100");
43 Self { thread_percentage: thread_pct, ..Default::default() }
44 }
45
46 #[must_use]
48 pub fn with_mem_limit(mut self, limit_mb: u64) -> Self {
49 self.pinned_mem_limit_mb = Some(limit_mb);
50 self
51 }
52}
53
54pub fn setup_mps_env(config: &MpsConfig) -> Vec<(String, String)> {
62 let mut vars = Vec::new();
63
64 let thread_pct = config.thread_percentage.to_string();
66 #[allow(clippy::disallowed_methods)]
67 env::set_var("CUDA_MPS_ACTIVE_THREAD_PERCENTAGE", &thread_pct);
68 vars.push(("CUDA_MPS_ACTIVE_THREAD_PERCENTAGE".to_string(), thread_pct));
69
70 if let Some(limit_mb) = config.pinned_mem_limit_mb {
72 let limit_str = format!("0={limit_mb}MB");
73 #[allow(clippy::disallowed_methods)]
74 env::set_var("CUDA_MPS_PINNED_DEVICE_MEM_LIMIT", &limit_str);
75 vars.push(("CUDA_MPS_PINNED_DEVICE_MEM_LIMIT".to_string(), limit_str));
76 }
77
78 vars
79}
80
81pub fn print_mps_warning(config: &MpsConfig) {
85 eprintln!("WARNING: MPS enabled โ a GPU fault in any job will crash ALL jobs on this GPU.");
86 eprintln!(" Thread allocation: {}%", config.thread_percentage);
87 if let Some(limit) = config.pinned_mem_limit_mb {
88 eprintln!(" Pinned memory limit: {limit} MB");
89 }
90 eprintln!(
91 " Checkpoint frequency: every {} steps (blast radius limit)",
92 config.checkpoint_every_steps
93 );
94 eprintln!(" Use --experimental-mps only if you understand the risks.");
95 eprintln!();
96}
97
/// Reports whether an MPS control daemon appears to be running, judged by
/// the presence of its `control` endpoint in the MPS pipe directory.
///
/// The pipe directory is taken from the `CUDA_MPS_PIPE_DIRECTORY` environment
/// variable when it is set to a non-empty value, and otherwise falls back to
/// the default `/tmp/nvidia-mps`. Without honoring the override, a daemon
/// started with a custom pipe directory would be reported as not running.
///
/// This is a filesystem existence check only; it does not try to talk to the
/// daemon.
#[must_use]
pub fn is_mps_daemon_running() -> bool {
    let pipe_dir = std::env::var_os("CUDA_MPS_PIPE_DIRECTORY")
        // An empty value is treated the same as unset.
        .filter(|dir| !dir.is_empty())
        .map_or_else(|| std::path::PathBuf::from("/tmp/nvidia-mps"), std::path::PathBuf::from);
    pipe_dir.join("control").exists()
}
107
108pub fn validate_mps_config(config: &MpsConfig) -> MpsValidation {
112 let mut warnings = Vec::new();
113 let mut errors = Vec::new();
114
115 if config.thread_percentage < 30 {
116 warnings.push(
117 "Thread percentage below 30% is unreliable on Jetson (NVIDIA Forum).".to_string(),
118 );
119 }
120
121 if config.thread_percentage < 10 {
122 errors
123 .push("Thread percentage below 10% causes severe performance degradation.".to_string());
124 }
125
126 if config.pinned_mem_limit_mb.is_none() {
127 warnings.push(
128 "No pinned memory limit set. OOM in one job may crash all MPS clients.".to_string(),
129 );
130 }
131
132 MpsValidation { warnings, errors }
133}
134
/// Outcome of validating an MPS configuration.
///
/// `Default` yields an empty result (no warnings, no errors); `PartialEq`/`Eq`
/// allow results to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MpsValidation {
    /// Non-fatal findings, one human-readable message per entry.
    pub warnings: Vec<String>,
    /// Fatal findings, one human-readable message per entry.
    pub errors: Vec<String>,
}
143
144impl MpsValidation {
145 #[must_use]
147 pub fn has_errors(&self) -> bool {
148 !self.errors.is_empty()
149 }
150}
151
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// The default configuration uses a 50% share, no memory limit, and a
    /// checkpoint every 100 steps.
    #[test]
    fn test_default_config() {
        let defaults = MpsConfig::default();
        assert_eq!(defaults.thread_percentage, 50);
        assert!(defaults.pinned_mem_limit_mb.is_none());
        assert_eq!(defaults.checkpoint_every_steps, 100);
    }

    /// `with_share` stores the requested thread percentage.
    #[test]
    fn test_with_share() {
        assert_eq!(MpsConfig::with_share(33).thread_percentage, 33);
    }

    /// `with_mem_limit` stores the requested limit.
    #[test]
    fn test_with_mem_limit() {
        let limited = MpsConfig::with_share(50).with_mem_limit(8000);
        assert_eq!(limited.pinned_mem_limit_mb, Some(8000));
    }

    /// A zero share violates the 1-100 precondition.
    #[test]
    #[should_panic(expected = "thread_pct must be 1-100")]
    fn test_zero_thread_pct_panics() {
        let _ = MpsConfig::with_share(0);
    }

    /// A share above 100 violates the 1-100 precondition.
    #[test]
    #[should_panic(expected = "thread_pct must be 1-100")]
    fn test_over_100_thread_pct_panics() {
        let _ = MpsConfig::with_share(101);
    }

    /// The returned pairs include the thread percentage variable.
    #[test]
    fn test_setup_mps_env_sets_thread_pct() {
        let exported = setup_mps_env(&MpsConfig::with_share(33));
        let expected = ("CUDA_MPS_ACTIVE_THREAD_PERCENTAGE".to_string(), "33".to_string());
        assert!(exported.contains(&expected));
    }

    /// The returned pairs include the pinned memory limit variable when a
    /// limit is configured.
    #[test]
    fn test_setup_mps_env_sets_mem_limit() {
        let exported = setup_mps_env(&MpsConfig::with_share(50).with_mem_limit(8000));
        let expected = ("CUDA_MPS_PINNED_DEVICE_MEM_LIMIT".to_string(), "0=8000MB".to_string());
        assert!(exported.contains(&expected));
    }

    /// A 50% share with a memory limit produces no findings at all.
    #[test]
    fn test_validate_ok() {
        let outcome = validate_mps_config(&MpsConfig::with_share(50).with_mem_limit(8000));
        assert!(!outcome.has_errors());
        assert!(outcome.warnings.is_empty());
    }

    /// A 25% share without a memory limit yields both warnings but no error.
    #[test]
    fn test_validate_low_thread_warning() {
        let outcome = validate_mps_config(&MpsConfig::with_share(25));
        let warned_about = |needle: &str| outcome.warnings.iter().any(|w| w.contains(needle));
        assert!(!outcome.has_errors());
        assert!(warned_about("below 30%"));
        assert!(warned_about("pinned memory"));
    }

    /// A 5% share is reported as an error.
    #[test]
    fn test_validate_very_low_thread_error() {
        let outcome = validate_mps_config(&MpsConfig::with_share(5));
        assert!(outcome.has_errors());
        assert!(outcome.errors.iter().any(|e| e.contains("below 10%")));
    }

    /// Smoke test: the result depends on the host, so only the absence of a
    /// panic is checked.
    #[test]
    fn test_mps_daemon_check() {
        let _ = is_mps_daemon_running();
    }
}