baracuda_forge/
parallel.rs1use glob::Pattern;
4use std::path::Path;
5use std::str::FromStr;
6
/// Settings that decide how many worker threads a build uses and which
/// source files get an nvcc thread count attached.
///
/// Values are assembled through the `with_*` builder methods; the effective
/// worker count is derived by `thread_count`.
#[derive(Clone, Debug)]
pub struct ParallelConfig {
    /// Fraction of the available threads to use when `max_threads` is unset.
    /// The builder clamps it to `0.0..=1.0`.
    thread_percentage: f32,
    /// Explicit thread count; when set, it is used (capped at the available
    /// threads) instead of the percentage calculation.
    max_threads: Option<usize>,
    /// Lower bound for the computed thread count. The result is still capped
    /// at the number of available threads.
    min_threads: usize,
    /// Substring or glob patterns selecting the files for which
    /// `should_use_nvcc_threads` answers `true`.
    nvcc_thread_file_patterns: Vec<String>,
    /// Value reported by `nvcc_threads`; `None` means no nvcc thread count.
    /// NOTE(review): presumably forwarded to nvcc by the caller -- confirm.
    num_nvcc_threads: Option<usize>,
}
16
17impl Default for ParallelConfig {
18 fn default() -> Self {
19 Self {
20 thread_percentage: 0.5,
21 max_threads: None,
22 min_threads: 1,
23 nvcc_thread_file_patterns: vec!["flash_api".to_string(), "cutlass".to_string()],
24 num_nvcc_threads: Some(2),
25 }
26 }
27}
28
29impl ParallelConfig {
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn with_percentage(mut self, percentage: f32) -> Self {
37 self.thread_percentage = percentage.clamp(0.0, 1.0);
38 self
39 }
40
41 pub fn with_max_threads(mut self, max: usize) -> Self {
43 self.max_threads = Some(max.max(1));
44 self
45 }
46
47 pub fn with_min_threads(mut self, min: usize) -> Self {
49 self.min_threads = min.max(1);
50 self
51 }
52
53 pub fn with_nvcc_thread_patterns<S: AsRef<str>>(
57 mut self,
58 patterns: &[S],
59 num_nvcc_threads: usize,
60 ) -> Self {
61 self.nvcc_thread_file_patterns = patterns.iter().map(|s| s.as_ref().to_string()).collect();
62 self.num_nvcc_threads = if num_nvcc_threads > 0 {
63 Some(num_nvcc_threads)
64 } else {
65 None
66 };
67 self
68 }
69
70 pub fn should_use_nvcc_threads(&self, path_str: &str) -> bool {
74 let path = Path::new(path_str);
75 let filename_component = path.file_name().and_then(|s| s.to_str()).unwrap_or("");
76
77 self.nvcc_thread_file_patterns.iter().any(|pattern| {
78 if pattern.contains('*') || pattern.contains('?') || pattern.contains('[') {
79 if let Ok(compiled) = Pattern::new(pattern) {
80 if !pattern.contains('/')
81 && !pattern.contains('\\')
82 && compiled.matches(filename_component)
83 {
84 return true;
85 }
86
87 if compiled.matches(path_str) {
88 return true;
89 }
90 }
91 }
92 path_str.contains(pattern)
93 })
94 }
95
96 pub fn thread_count(&self) -> usize {
98 if let Ok(env_threads) = std::env::var("BARACUDA_FORGE_THREADS") {
99 if let Ok(n) = usize::from_str(&env_threads) {
100 return n.max(1);
101 }
102 }
103
104 if let Ok(env_threads) = std::env::var("RAYON_NUM_THREADS") {
105 if let Ok(n) = usize::from_str(&env_threads) {
106 return n.max(1);
107 }
108 }
109
110 let available = self.detect_available_threads();
111
112 let calculated = if let Some(max) = self.max_threads {
113 max.min(available)
114 } else {
115 (available as f32 * self.thread_percentage).ceil() as usize
116 };
117
118 calculated.max(self.min_threads).min(available)
119 }
120
121 pub fn init_thread_pool(&self) -> Result<(), rayon::ThreadPoolBuildError> {
123 let thread_count = self.thread_count();
124 rayon::ThreadPoolBuilder::new()
125 .num_threads(thread_count)
126 .build_global()
127 }
128
129 pub fn nvcc_threads(&self) -> Option<usize> {
131 self.num_nvcc_threads
132 }
133
134 fn detect_available_threads(&self) -> usize {
135 if let Ok(parallelism) = std::thread::available_parallelism() {
136 return parallelism.get();
137 }
138 num_cpus::get_physical()
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `should_use_nvcc_threads` for every `(path, expected)` pair and
    /// names the offending path on failure.
    fn assert_nvcc_decisions(config: &ParallelConfig, cases: &[(&str, bool)]) {
        for &(path, expected) in cases {
            assert_eq!(
                config.should_use_nvcc_threads(path),
                expected,
                "unexpected nvcc-thread decision for path {:?}",
                path
            );
        }
    }

    #[test]
    fn test_default_config() {
        let config = ParallelConfig::default();
        assert_eq!(config.thread_percentage, 0.5);
        assert!(config.max_threads.is_none());
    }

    #[test]
    fn test_percentage_clamping() {
        let config = ParallelConfig::new().with_percentage(1.5);
        assert_eq!(config.thread_percentage, 1.0);

        let config = ParallelConfig::new().with_percentage(-0.5);
        assert_eq!(config.thread_percentage, 0.0);
    }

    #[test]
    fn test_thread_bounds_clamping() {
        // Zero is not a usable thread count, so both builders raise it to 1.
        let config = ParallelConfig::new().with_max_threads(0).with_min_threads(0);
        assert_eq!(config.max_threads, Some(1));
        assert_eq!(config.min_threads, 1);

        let config = ParallelConfig::new().with_max_threads(6).with_min_threads(3);
        assert_eq!(config.max_threads, Some(6));
        assert_eq!(config.min_threads, 3);
    }

    #[test]
    fn test_nvcc_threads() {
        assert_eq!(ParallelConfig::default().nvcc_threads(), Some(2));

        let config = ParallelConfig::new().with_nvcc_thread_patterns(&["kernel"], 4);
        assert_eq!(config.nvcc_threads(), Some(4));

        // A count of zero disables the nvcc thread count but keeps the patterns.
        let config = ParallelConfig::new().with_nvcc_thread_patterns(&["kernel"], 0);
        assert_eq!(config.nvcc_threads(), None);
        assert!(config.should_use_nvcc_threads("kernel.cu"));
    }

    #[test]
    fn test_thread_count_is_never_zero() {
        // Holds for both the environment-override path and the computed path.
        let config = ParallelConfig::new().with_percentage(0.0);
        assert!(config.thread_count() >= 1);
    }

    #[test]
    fn test_thread_patterns() {
        assert_nvcc_decisions(
            &ParallelConfig::default(),
            &[
                ("flash_api.cu", true),
                ("src/flash_api_v2.cu", true),
                ("cutlass_gemm.cu", true),
                ("simple.cu", false),
            ],
        );

        assert_nvcc_decisions(
            &ParallelConfig::new().with_nvcc_thread_patterns(&["gemm_*.cu", "special"], 4),
            &[
                ("gemm_fp16.cu", true),
                ("src/gemm_int8.cu", true),
                ("special_kernel.cu", true),
                ("flash_api.cu", false),
            ],
        );
    }

    #[test]
    fn test_glob_vs_substring() {
        assert_nvcc_decisions(
            &ParallelConfig::new().with_nvcc_thread_patterns(&["*gemm*.cu"], 2),
            &[
                ("/path/to/my_gemm_kernel.cu", true),
                ("/path/to/other.cu", false),
            ],
        );
    }
}