use std::env;
/// Settings used when running a job under NVIDIA MPS.
#[derive(Clone, Debug)]
pub struct MpsConfig {
    /// Share of GPU threads granted to this job, in percent; exported by
    /// `setup_mps_env` as `CUDA_MPS_ACTIVE_THREAD_PERCENTAGE`.
    pub thread_percentage: u32,
    /// Optional pinned device memory cap in megabytes; `None` means no cap is exported.
    pub pinned_mem_limit_mb: Option<u64>,
    /// Number of steps between checkpoints, as reported by `print_mps_warning`.
    pub checkpoint_every_steps: usize,
}
impl Default for MpsConfig {
    /// Returns the baseline configuration: 50% of GPU threads, no pinned
    /// memory cap, and a checkpoint every 100 steps.
    fn default() -> Self {
        Self {
            thread_percentage: 50,
            pinned_mem_limit_mb: None,
            checkpoint_every_steps: 100,
        }
    }
}
impl MpsConfig {
    /// Creates a configuration that grants `thread_pct` percent of the GPU
    /// threads and leaves every other field at its default.
    ///
    /// # Panics
    ///
    /// Panics with `"thread_pct must be 1-100"` when `thread_pct` is 0 or
    /// greater than 100.
    #[must_use]
    pub fn with_share(thread_pct: u32) -> Self {
        assert!((1..=100).contains(&thread_pct), "thread_pct must be 1-100");
        Self { thread_percentage: thread_pct, ..Self::default() }
    }

    /// Returns this configuration with the pinned memory cap set to
    /// `limit_mb` megabytes; all other fields are carried over unchanged.
    #[must_use]
    pub fn with_mem_limit(self, limit_mb: u64) -> Self {
        Self { pinned_mem_limit_mb: Some(limit_mb), ..self }
    }
}
/// Exports the MPS settings from `config` into this process's environment.
///
/// `CUDA_MPS_ACTIVE_THREAD_PERCENTAGE` is always set. When a pinned memory
/// limit is configured, `CUDA_MPS_PINNED_DEVICE_MEM_LIMIT` is set as well,
/// formatted as `0=<limit>MB` (the device ordinal is fixed at 0).
///
/// Returns the `(name, value)` pairs that were exported, in the order they
/// were set.
///
/// NOTE(review): this mutates process-global state; confirm callers invoke it
/// before other threads read or write the environment.
pub fn setup_mps_env(config: &MpsConfig) -> Vec<(String, String)> {
    const THREAD_PCT_VAR: &str = "CUDA_MPS_ACTIVE_THREAD_PERCENTAGE";
    const MEM_LIMIT_VAR: &str = "CUDA_MPS_PINNED_DEVICE_MEM_LIMIT";

    // Collect every pair first so the export loop below is the only place
    // that touches the environment.
    let mut exported = vec![(THREAD_PCT_VAR.to_owned(), config.thread_percentage.to_string())];
    exported.extend(
        config
            .pinned_mem_limit_mb
            .map(|mb| (MEM_LIMIT_VAR.to_owned(), format!("0={mb}MB"))),
    );

    for (name, value) in &exported {
        #[allow(clippy::disallowed_methods)]
        env::set_var(name, value);
    }
    exported
}
/// Writes the MPS risk warning and a summary of `config` to standard error.
///
/// The summary lists the thread allocation, the pinned memory limit (only
/// when one is set), and the checkpoint interval, followed by a blank line.
pub fn print_mps_warning(config: &MpsConfig) {
    let thread_pct = config.thread_percentage;
    let every_steps = config.checkpoint_every_steps;

    eprintln!("WARNING: MPS enabled — a GPU fault in any job will crash ALL jobs on this GPU.");
    eprintln!(" Thread allocation: {thread_pct}%");
    // The memory line is omitted entirely when no limit is configured.
    if let Some(limit_mb) = config.pinned_mem_limit_mb {
        eprintln!(" Pinned memory limit: {limit_mb} MB");
    }
    eprintln!(" Checkpoint frequency: every {every_steps} steps (blast radius limit)");
    eprintln!(" Use --experimental-mps only if you understand the risks.");
    eprintln!();
}
/// Reports whether an MPS control daemon appears to be running.
///
/// The check looks for a `control` entry inside the MPS pipe directory. That
/// directory is taken from the `CUDA_MPS_PIPE_DIRECTORY` environment variable
/// when it is set to a non-empty value, and otherwise defaults to
/// `/tmp/nvidia-mps`.
///
/// Returns `true` if the entry exists. This is a filesystem heuristic only: it
/// does not contact the daemon, so a stale entry left behind by a crashed
/// daemon still yields `true`.
#[must_use]
pub fn is_mps_daemon_running() -> bool {
    // Honor the same override the daemon and its clients use; without this a
    // daemon started with a custom pipe directory is reported as not running.
    let pipe_dir = std::env::var_os("CUDA_MPS_PIPE_DIRECTORY")
        .filter(|dir| !dir.is_empty())
        .map(std::path::PathBuf::from)
        .unwrap_or_else(|| std::path::PathBuf::from("/tmp/nvidia-mps"));
    pipe_dir.join("control").exists()
}
/// Checks `config` for settings that are invalid or known to be risky.
///
/// Errors are reported when the thread percentage is outside 1-100 or below
/// 10%. Warnings are reported when the thread percentage is below 30% or when
/// no pinned memory limit is set. A percentage of 0 therefore produces both
/// the range error and the low-percentage findings.
///
/// Returns the collected findings; this function never panics and does not
/// modify `config`.
pub fn validate_mps_config(config: &MpsConfig) -> MpsValidation {
    let pct = config.thread_percentage;
    let mut warnings = Vec::new();
    let mut errors = Vec::new();

    // The struct fields are public, so a config can be built without going
    // through `MpsConfig::with_share` and its 1-100 assertion.
    if !(1..=100).contains(&pct) {
        errors.push(format!("Thread percentage must be between 1 and 100 (got {pct})."));
    }
    if pct < 30 {
        warnings.push(
            "Thread percentage below 30% is unreliable on Jetson (NVIDIA Forum).".to_string(),
        );
    }
    if pct < 10 {
        errors
            .push("Thread percentage below 10% causes severe performance degradation.".to_string());
    }
    if config.pinned_mem_limit_mb.is_none() {
        warnings.push(
            "No pinned memory limit set. OOM in one job may crash all MPS clients.".to_string(),
        );
    }
    MpsValidation { warnings, errors }
}
/// Findings produced by `validate_mps_config`.
#[derive(Clone, Debug)]
pub struct MpsValidation {
    /// Human-readable notes about risky but accepted settings.
    pub warnings: Vec<String>,
    /// Human-readable descriptions of settings that were rejected.
    pub errors: Vec<String>,
}
impl MpsValidation {
    /// Returns `true` when at least one error was recorded.
    #[must_use]
    pub fn has_errors(&self) -> bool {
        self.errors.first().is_some()
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Builds an owned `(name, value)` pair to compare against the output of
    /// `setup_mps_env`.
    fn pair(name: &str, value: &str) -> (String, String) {
        (name.to_owned(), value.to_owned())
    }

    // --- MpsConfig construction ---

    #[test]
    fn test_default_config() {
        let MpsConfig { thread_percentage, pinned_mem_limit_mb, checkpoint_every_steps } =
            MpsConfig::default();
        assert_eq!(thread_percentage, 50);
        assert_eq!(pinned_mem_limit_mb, None);
        assert_eq!(checkpoint_every_steps, 100);
    }

    #[test]
    fn test_with_share() {
        assert_eq!(MpsConfig::with_share(33).thread_percentage, 33);
    }

    #[test]
    fn test_with_mem_limit() {
        let cfg = MpsConfig::with_share(50).with_mem_limit(8000);
        assert_eq!(cfg.pinned_mem_limit_mb, Some(8000));
    }

    #[test]
    #[should_panic(expected = "thread_pct must be 1-100")]
    fn test_zero_thread_pct_panics() {
        let _ = MpsConfig::with_share(0);
    }

    #[test]
    #[should_panic(expected = "thread_pct must be 1-100")]
    fn test_over_100_thread_pct_panics() {
        let _ = MpsConfig::with_share(101);
    }

    // --- Environment export (assertions use only the returned pairs) ---

    #[test]
    fn test_setup_mps_env_sets_thread_pct() {
        let exported = setup_mps_env(&MpsConfig::with_share(33));
        assert!(exported.contains(&pair("CUDA_MPS_ACTIVE_THREAD_PERCENTAGE", "33")));
    }

    #[test]
    fn test_setup_mps_env_sets_mem_limit() {
        let cfg = MpsConfig::with_share(50).with_mem_limit(8000);
        let exported = setup_mps_env(&cfg);
        assert!(exported.contains(&pair("CUDA_MPS_PINNED_DEVICE_MEM_LIMIT", "0=8000MB")));
    }

    // --- Validation ---

    #[test]
    fn test_validate_ok() {
        let report = validate_mps_config(&MpsConfig::with_share(50).with_mem_limit(8000));
        assert!(!report.has_errors());
        assert!(report.warnings.is_empty());
    }

    #[test]
    fn test_validate_low_thread_warning() {
        let report = validate_mps_config(&MpsConfig::with_share(25));
        let has_warning = |needle: &str| report.warnings.iter().any(|msg| msg.contains(needle));
        assert!(!report.has_errors());
        assert!(has_warning("below 30%"));
        assert!(has_warning("pinned memory"));
    }

    #[test]
    fn test_validate_very_low_thread_error() {
        let report = validate_mps_config(&MpsConfig::with_share(5));
        assert!(report.has_errors());
        assert!(report.errors.iter().any(|msg| msg.contains("below 10%")));
    }

    // --- Daemon detection ---

    /// Smoke test only: the result depends on the host, so nothing is asserted.
    #[test]
    fn test_mps_daemon_check() {
        let _ = is_mps_daemon_running();
    }
}