use glob::Pattern;
use std::path::Path;
use std::str::FromStr;
/// Settings controlling how many worker threads the global rayon pool gets
/// (see `ParallelConfig::init_thread_pool`) and which source files are
/// selected for a per-file nvcc thread count.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Fraction of the available hardware threads to use when `max_threads`
    /// is `None`; `with_percentage` clamps it to `0.0..=1.0`.
    thread_percentage: f32,
    /// Explicit thread count; when `Some`, it replaces the percentage-based
    /// calculation (still capped at the available hardware threads).
    max_threads: Option<usize>,
    /// Lower bound applied to the computed thread count (the hardware cap is
    /// applied after it, so it cannot push the count above available threads).
    min_threads: usize,
    /// Substring or glob patterns; a file whose path matches any of them is
    /// selected by `should_use_nvcc_threads`.
    nvcc_thread_file_patterns: Vec<String>,
    /// Thread count returned by `nvcc_threads` for selected files; `None`
    /// means disabled. NOTE(review): presumably forwarded to nvcc's
    /// `--threads` option by the caller — confirm at the call site.
    num_nvcc_threads: Option<usize>,
}
impl Default for ParallelConfig {
    /// Baseline configuration: a 50 % share of the available hardware threads,
    /// no explicit maximum, a minimum of one thread, and two nvcc threads for
    /// any file whose path matches `flash_api` or `cutlass`.
    fn default() -> Self {
        let nvcc_thread_file_patterns: Vec<String> = ["flash_api", "cutlass"]
            .iter()
            .map(|name| String::from(*name))
            .collect();

        Self {
            thread_percentage: 0.5,
            max_threads: None,
            min_threads: 1,
            nvcc_thread_file_patterns,
            num_nvcc_threads: Some(2),
        }
    }
}
impl ParallelConfig {
    /// Creates a configuration with the default settings
    /// (equivalent to [`ParallelConfig::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the fraction of available hardware threads to use, clamped to `0.0..=1.0`.
    ///
    /// The percentage is only consulted by [`thread_count`](Self::thread_count) when no
    /// explicit maximum has been set via [`with_max_threads`](Self::with_max_threads).
    pub fn with_percentage(mut self, percentage: f32) -> Self {
        self.thread_percentage = percentage.clamp(0.0, 1.0);
        self
    }

    /// Uses an explicit thread count (raised to at least 1) instead of the percentage.
    ///
    /// [`thread_count`](Self::thread_count) still caps the result at the number of
    /// available hardware threads.
    pub fn with_max_threads(mut self, max: usize) -> Self {
        self.max_threads = Some(max.max(1));
        self
    }

    /// Sets the lower bound for the computed thread count (raised to at least 1).
    ///
    /// [`thread_count`](Self::thread_count) applies the hardware cap *after* this
    /// bound, so a minimum above the available thread count is not honored.
    pub fn with_min_threads(mut self, min: usize) -> Self {
        self.min_threads = min.max(1);
        self
    }

    /// Replaces the nvcc-thread file patterns and sets the nvcc thread count.
    ///
    /// Patterns containing `*`, `?` or `[` are treated as globs; every pattern is
    /// also tried as a plain substring of the path. Passing `num_nvcc_threads == 0`
    /// disables the setting, so [`nvcc_threads`](Self::nvcc_threads) returns `None`.
    pub fn with_nvcc_thread_patterns<S: AsRef<str>>(
        mut self,
        patterns: &[S],
        num_nvcc_threads: usize,
    ) -> Self {
        self.nvcc_thread_file_patterns = patterns
            .iter()
            .map(|s| s.as_ref().to_string())
            .collect();
        self.num_nvcc_threads = if num_nvcc_threads > 0 {
            Some(num_nvcc_threads)
        } else {
            None
        };
        self
    }

    /// Returns `true` if `path_str` matches any configured nvcc-thread pattern.
    ///
    /// Per pattern:
    /// * empty patterns are ignored;
    /// * a glob (contains `*`, `?` or `[`) without a path separator is matched against
    ///   the file-name component, and every valid glob is matched against the full path;
    /// * any pattern also matches when it occurs as a substring of `path_str`.
    pub fn should_use_nvcc_threads(&self, path_str: &str) -> bool {
        let file_name = Path::new(path_str)
            .file_name()
            .and_then(|name| name.to_str())
            .unwrap_or("");

        self.nvcc_thread_file_patterns
            .iter()
            // "" is a substring of every path; without this filter a single empty
            // entry would enable nvcc threads for every file.
            .filter(|pattern| !pattern.is_empty())
            .any(|pattern| {
                let is_glob = pattern.contains(|c: char| matches!(c, '*' | '?' | '['));
                if is_glob {
                    // An invalid glob is not an error: it falls through to the
                    // substring check below.
                    if let Ok(compiled) = Pattern::new(pattern) {
                        let has_separator =
                            pattern.contains(|c: char| matches!(c, '/' | '\\'));
                        if (!has_separator && compiled.matches(file_name))
                            || compiled.matches(path_str)
                        {
                            return true;
                        }
                    }
                }
                path_str.contains(pattern.as_str())
            })
    }

    /// Returns the number of worker threads to use.
    ///
    /// Resolution order:
    /// 1. `CUDAFORGE_THREADS`, then `RAYON_NUM_THREADS`, if set to a value that parses
    ///    as `usize` after trimming surrounding whitespace (`0` is raised to `1`).
    ///    These overrides are *not* capped by the hardware thread count.
    /// 2. Otherwise `max_threads` (capped at available threads) if set, else
    ///    `ceil(available * thread_percentage)`; the result is then raised to
    ///    `min_threads` and finally capped at the available hardware threads.
    pub fn thread_count(&self) -> usize {
        if let Some(n) = Self::env_thread_override() {
            return n;
        }
        let available = self.detect_available_threads();
        let calculated = match self.max_threads {
            Some(max) => max.min(available),
            None => (available as f32 * self.thread_percentage).ceil() as usize,
        };
        // The hardware cap is applied last, so it takes precedence over `min_threads`.
        calculated.max(self.min_threads).min(available)
    }

    /// Returns the first parseable thread-count override from the environment,
    /// checking `CUDAFORGE_THREADS` before `RAYON_NUM_THREADS`.
    fn env_thread_override() -> Option<usize> {
        const OVERRIDE_VARS: [&str; 2] = ["CUDAFORGE_THREADS", "RAYON_NUM_THREADS"];
        OVERRIDE_VARS
            .iter()
            .find_map(|var| {
                let value = std::env::var(var).ok()?;
                // Trim so values like "8\n" from shell scripts are still accepted.
                usize::from_str(value.trim()).ok()
            })
            .map(|n| n.max(1))
    }

    /// Builds rayon's global thread pool with [`thread_count`](Self::thread_count) threads.
    ///
    /// # Errors
    /// Propagates the error from `rayon::ThreadPoolBuilder::build_global`, e.g. when the
    /// global pool has already been initialized.
    pub fn init_thread_pool(&self) -> Result<(), rayon::ThreadPoolBuildError> {
        let thread_count = self.thread_count();
        rayon::ThreadPoolBuilder::new()
            .num_threads(thread_count)
            .build_global()
    }

    /// Returns the nvcc thread count for files selected by
    /// [`should_use_nvcc_threads`](Self::should_use_nvcc_threads), or `None` if disabled.
    pub fn nvcc_threads(&self) -> Option<usize> {
        self.num_nvcc_threads
    }

    /// Number of hardware threads available to this process.
    ///
    /// Uses `std::thread::available_parallelism`. If that fails, it falls back to
    /// `num_cpus::get()`, which counts logical CPUs like the primary path does.
    fn detect_available_threads(&self) -> usize {
        std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or_else(|_| num_cpus::get())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: 50 % share, no explicit maximum.
    #[test]
    fn test_default_config() {
        let cfg = ParallelConfig::default();
        assert!(cfg.max_threads.is_none());
        assert_eq!(cfg.thread_percentage, 0.5);
    }

    // Out-of-range percentages are clamped into 0.0..=1.0.
    #[test]
    fn test_percentage_clamping() {
        for &(input, expected) in &[(1.5_f32, 1.0_f32), (-0.5_f32, 0.0_f32)] {
            let cfg = ParallelConfig::new().with_percentage(input);
            assert_eq!(cfg.thread_percentage, expected);
        }
    }

    // Substring and glob patterns, both default and custom.
    #[test]
    fn test_thread_patterns() {
        let defaults = ParallelConfig::default();
        for path in &["flash_api.cu", "src/flash_api_v2.cu", "cutlass_gemm.cu"] {
            assert!(defaults.should_use_nvcc_threads(path));
        }
        assert!(!defaults.should_use_nvcc_threads("simple.cu"));

        let custom =
            ParallelConfig::new().with_nvcc_thread_patterns(&["gemm_*.cu", "special"], 4);
        for path in &["gemm_fp16.cu", "src/gemm_int8.cu", "special_kernel.cu"] {
            assert!(custom.should_use_nvcc_threads(path));
        }
        assert!(!custom.should_use_nvcc_threads("flash_api.cu"));
    }

    // A glob with leading `*` matches against a full absolute path.
    #[test]
    fn test_glob_vs_substring() {
        let globbed = ParallelConfig::new().with_nvcc_thread_patterns(&["*gemm*.cu"], 2);
        let hit = globbed.should_use_nvcc_threads("/path/to/my_gemm_kernel.cu");
        let miss = globbed.should_use_nvcc_threads("/path/to/other.cu");
        assert!(hit);
        assert!(!miss);
    }
}