Skip to main content

entrenar/gpu/
mps.rs

//! Experimental CUDA MPS (Multi-Process Service) support (GPU-SHARE §1.5).
//!
//! MPS is **not** auto-started. This module provides opt-in setup for users
//! who understand the risks:
//!
//! - A GPU fault in any MPS client kills ALL clients on that GPU
//! - Thread percentage is static once set (no rebalancing)
//! - Jetson MPS is experimental below 30% thread allocation
//!
//! # Usage
//!
//! ```bash
//! apr finetune model.apr --vram 8 --experimental-mps --gpu-share 50
//! ```
15
16use std::env;
17
/// MPS configuration for experimental GPU sharing.
///
/// Build one with [`MpsConfig::with_share`] (range-checked share) or
/// [`MpsConfig::default`], then optionally chain
/// [`MpsConfig::with_mem_limit`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MpsConfig {
    /// Percentage of GPU SMs allocated to this process (1-100).
    pub thread_percentage: u32,
    /// Pinned device memory limit per client (MB). Prevents OOM cascades.
    pub pinned_mem_limit_mb: Option<u64>,
    /// Override checkpoint frequency to every N steps (limits blast radius).
    pub checkpoint_every_steps: usize,
}

impl Default for MpsConfig {
    /// 50% thread share, no pinned memory limit, checkpoint every 100 steps.
    fn default() -> Self {
        Self { thread_percentage: 50, pinned_mem_limit_mb: None, checkpoint_every_steps: 100 }
    }
}

impl MpsConfig {
    /// Create MPS config with the given thread share percentage.
    ///
    /// All other fields take their [`Default`] values.
    ///
    /// # Panics
    /// Panics if `thread_pct` is 0 or > 100.
    #[must_use]
    pub fn with_share(thread_pct: u32) -> Self {
        // The message keeps the original text as a prefix so callers and tests
        // matching on it keep working; the offending value is appended.
        assert!((1..=100).contains(&thread_pct), "thread_pct must be 1-100, got {thread_pct}");
        Self { thread_percentage: thread_pct, ..Self::default() }
    }

    /// Set pinned device memory limit (MB) to prevent OOM cascades.
    ///
    /// Consumes and returns the config so calls can be chained.
    #[must_use]
    pub fn with_mem_limit(mut self, limit_mb: u64) -> Self {
        self.pinned_mem_limit_mb = Some(limit_mb);
        self
    }
}
53
54/// Set MPS environment variables before CUDA context creation.
55///
56/// **MUST be called before any CUDA API call** (cuCtxCreate, cuInit, etc.).
57/// Environment variables set after context creation have no effect.
58///
59/// # Returns
60/// List of environment variables that were set.
61pub fn setup_mps_env(config: &MpsConfig) -> Vec<(String, String)> {
62    let mut vars = Vec::new();
63
64    // Thread percentage: controls SM allocation for this MPS client
65    let thread_pct = config.thread_percentage.to_string();
66    #[allow(clippy::disallowed_methods)]
67    env::set_var("CUDA_MPS_ACTIVE_THREAD_PERCENTAGE", &thread_pct);
68    vars.push(("CUDA_MPS_ACTIVE_THREAD_PERCENTAGE".to_string(), thread_pct));
69
70    // Pinned memory limit: prevents OOM cascades across MPS clients
71    if let Some(limit_mb) = config.pinned_mem_limit_mb {
72        let limit_str = format!("0={limit_mb}MB");
73        #[allow(clippy::disallowed_methods)]
74        env::set_var("CUDA_MPS_PINNED_DEVICE_MEM_LIMIT", &limit_str);
75        vars.push(("CUDA_MPS_PINNED_DEVICE_MEM_LIMIT".to_string(), limit_str));
76    }
77
78    vars
79}
80
81/// Print MPS safety warning to stderr.
82///
83/// This warning is mandatory whenever `--experimental-mps` is used.
84pub fn print_mps_warning(config: &MpsConfig) {
85    eprintln!("WARNING: MPS enabled โ€” a GPU fault in any job will crash ALL jobs on this GPU.");
86    eprintln!("  Thread allocation: {}%", config.thread_percentage);
87    if let Some(limit) = config.pinned_mem_limit_mb {
88        eprintln!("  Pinned memory limit: {limit} MB");
89    }
90    eprintln!(
91        "  Checkpoint frequency: every {} steps (blast radius limit)",
92        config.checkpoint_every_steps
93    );
94    eprintln!("  Use --experimental-mps only if you understand the risks.");
95    eprintln!();
96}
97
/// Check if MPS daemon appears to be running.
///
/// Checks for the existence of the MPS control pipe (`control`) inside the
/// MPS pipe directory: `$CUDA_MPS_PIPE_DIRECTORY` when that variable is set
/// and non-empty, otherwise the default `/tmp/nvidia-mps`. This is a
/// best-effort check — the daemon may still be unhealthy even if the pipe
/// exists.
#[must_use]
pub fn is_mps_daemon_running() -> bool {
    // A daemon started with a custom CUDA_MPS_PIPE_DIRECTORY creates its
    // control pipe there, so looking only at the default location would
    // report it as not running.
    let pipe_dir = env::var_os("CUDA_MPS_PIPE_DIRECTORY")
        .filter(|dir| !dir.is_empty())
        .map_or_else(|| std::path::PathBuf::from("/tmp/nvidia-mps"), std::path::PathBuf::from);

    pipe_dir.join("control").exists()
}
107
108/// Validate MPS configuration for known issues.
109///
110/// Returns a list of warnings (non-fatal) and errors (fatal).
111pub fn validate_mps_config(config: &MpsConfig) -> MpsValidation {
112    let mut warnings = Vec::new();
113    let mut errors = Vec::new();
114
115    if config.thread_percentage < 30 {
116        warnings.push(
117            "Thread percentage below 30% is unreliable on Jetson (NVIDIA Forum).".to_string(),
118        );
119    }
120
121    if config.thread_percentage < 10 {
122        errors
123            .push("Thread percentage below 10% causes severe performance degradation.".to_string());
124    }
125
126    if config.pinned_mem_limit_mb.is_none() {
127        warnings.push(
128            "No pinned memory limit set. OOM in one job may crash all MPS clients.".to_string(),
129        );
130    }
131
132    MpsValidation { warnings, errors }
133}
134
/// Result of MPS configuration validation.
///
/// The [`Default`] value has no warnings and no errors.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MpsValidation {
    /// Non-fatal warnings.
    pub warnings: Vec<String>,
    /// Fatal errors (should abort).
    pub errors: Vec<String>,
}

impl MpsValidation {
    /// Whether there are fatal errors.
    ///
    /// Warnings alone never make this return `true`.
    #[must_use]
    pub fn has_errors(&self) -> bool {
        !self.errors.is_empty()
    }
}
151
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// True when `vars` holds the exact `(name, value)` pair.
    fn has_var(vars: &[(String, String)], name: &str, value: &str) -> bool {
        vars.iter().any(|(k, v)| k == name && v == value)
    }

    /// True when any message in `messages` contains `needle`.
    fn mentions(messages: &[String], needle: &str) -> bool {
        messages.iter().any(|m| m.contains(needle))
    }

    #[test]
    fn test_default_config() {
        let MpsConfig { thread_percentage, pinned_mem_limit_mb, checkpoint_every_steps } =
            MpsConfig::default();
        assert_eq!(thread_percentage, 50);
        assert!(pinned_mem_limit_mb.is_none());
        assert_eq!(checkpoint_every_steps, 100);
    }

    #[test]
    fn test_with_share() {
        assert_eq!(MpsConfig::with_share(33).thread_percentage, 33);
    }

    #[test]
    fn test_with_mem_limit() {
        let limited = MpsConfig::with_share(50).with_mem_limit(8000);
        assert_eq!(limited.pinned_mem_limit_mb, Some(8000));
    }

    #[test]
    #[should_panic(expected = "thread_pct must be 1-100")]
    fn test_zero_thread_pct_panics() {
        let _ = MpsConfig::with_share(0);
    }

    #[test]
    #[should_panic(expected = "thread_pct must be 1-100")]
    fn test_over_100_thread_pct_panics() {
        let _ = MpsConfig::with_share(101);
    }

    #[test]
    fn test_setup_mps_env_sets_thread_pct() {
        let exported = setup_mps_env(&MpsConfig::with_share(33));
        assert!(has_var(&exported, "CUDA_MPS_ACTIVE_THREAD_PERCENTAGE", "33"));
    }

    #[test]
    fn test_setup_mps_env_sets_mem_limit() {
        let exported = setup_mps_env(&MpsConfig::with_share(50).with_mem_limit(8000));
        assert!(has_var(&exported, "CUDA_MPS_PINNED_DEVICE_MEM_LIMIT", "0=8000MB"));
    }

    #[test]
    fn test_validate_ok() {
        let report = validate_mps_config(&MpsConfig::with_share(50).with_mem_limit(8000));
        assert!(!report.has_errors());
        assert!(report.warnings.is_empty());
    }

    #[test]
    fn test_validate_low_thread_warning() {
        let report = validate_mps_config(&MpsConfig::with_share(25));
        assert!(!report.has_errors());
        assert!(mentions(&report.warnings, "below 30%"));
        // No mem limit → also warns
        assert!(mentions(&report.warnings, "pinned memory"));
    }

    #[test]
    fn test_validate_very_low_thread_error() {
        let report = validate_mps_config(&MpsConfig::with_share(5));
        assert!(report.has_errors());
        assert!(mentions(&report.errors, "below 10%"));
    }

    #[test]
    fn test_mps_daemon_check() {
        // Smoke test only: the result depends on the host, so nothing is asserted.
        let _ = is_mps_daemon_running();
    }
}