Skip to main content

baracuda_forge/
parallel.rs

1//! Parallel build configuration.
2
3use glob::Pattern;
4use std::path::Path;
5use std::str::FromStr;
6
/// Parallel build configuration.
///
/// Holds the limits from which the worker-thread count is derived, plus the
/// selection of source files that should be compiled with nvcc's own
/// `--threads=N` flag. Values are compared field by field; `Eq` is not
/// derived because the percentage is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct ParallelConfig {
    /// Fraction of the detected threads to use; the builder clamps it to
    /// `0.0..=1.0`.
    thread_percentage: f32,
    /// Optional upper bound on the thread count; `None` means no explicit
    /// maximum was configured.
    max_threads: Option<usize>,
    /// Lower bound on the thread count; the builder keeps it at 1 or more.
    min_threads: usize,
    /// Glob or substring patterns selecting the files that should use nvcc's
    /// internal `--threads=N` flag.
    nvcc_thread_file_patterns: Vec<String>,
    /// Value for nvcc's `--threads` argument; `None` when no count is
    /// configured.
    num_nvcc_threads: Option<usize>,
}
16
17impl Default for ParallelConfig {
18    fn default() -> Self {
19        Self {
20            thread_percentage: 0.5,
21            max_threads: None,
22            min_threads: 1,
23            nvcc_thread_file_patterns: vec!["flash_api".to_string(), "cutlass".to_string()],
24            num_nvcc_threads: Some(2),
25        }
26    }
27}
28
29impl ParallelConfig {
30    /// Create a new parallel config with default settings.
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Set the percentage of available threads to use (clamped to 0.0..=1.0).
36    pub fn with_percentage(mut self, percentage: f32) -> Self {
37        self.thread_percentage = percentage.clamp(0.0, 1.0);
38        self
39    }
40
41    /// Set the maximum number of threads.
42    pub fn with_max_threads(mut self, max: usize) -> Self {
43        self.max_threads = Some(max.max(1));
44        self
45    }
46
47    /// Set the minimum number of threads.
48    pub fn with_min_threads(mut self, min: usize) -> Self {
49        self.min_threads = min.max(1);
50        self
51    }
52
53    /// Set patterns for files that should use nvcc's internal `--threads=N` flag.
54    ///
55    /// Replaces the default patterns (`"flash_api"`, `"cutlass"`).
56    pub fn with_nvcc_thread_patterns<S: AsRef<str>>(
57        mut self,
58        patterns: &[S],
59        num_nvcc_threads: usize,
60    ) -> Self {
61        self.nvcc_thread_file_patterns = patterns.iter().map(|s| s.as_ref().to_string()).collect();
62        self.num_nvcc_threads = if num_nvcc_threads > 0 {
63            Some(num_nvcc_threads)
64        } else {
65            None
66        };
67        self
68    }
69
70    /// Check if a file matches any of the thread patterns.
71    ///
72    /// Supports glob patterns (e.g. `"gemm_*.cu"`) and substring matching.
73    pub fn should_use_nvcc_threads(&self, path_str: &str) -> bool {
74        let path = Path::new(path_str);
75        let filename_component = path.file_name().and_then(|s| s.to_str()).unwrap_or("");
76
77        self.nvcc_thread_file_patterns.iter().any(|pattern| {
78            if pattern.contains('*') || pattern.contains('?') || pattern.contains('[') {
79                if let Ok(compiled) = Pattern::new(pattern) {
80                    if !pattern.contains('/')
81                        && !pattern.contains('\\')
82                        && compiled.matches(filename_component)
83                    {
84                        return true;
85                    }
86
87                    if compiled.matches(path_str) {
88                        return true;
89                    }
90                }
91            }
92            path_str.contains(pattern)
93        })
94    }
95
96    /// Calculate the number of threads to use.
97    pub fn thread_count(&self) -> usize {
98        if let Ok(env_threads) = std::env::var("BARACUDA_FORGE_THREADS") {
99            if let Ok(n) = usize::from_str(&env_threads) {
100                return n.max(1);
101            }
102        }
103
104        if let Ok(env_threads) = std::env::var("RAYON_NUM_THREADS") {
105            if let Ok(n) = usize::from_str(&env_threads) {
106                return n.max(1);
107            }
108        }
109
110        let available = self.detect_available_threads();
111
112        let calculated = if let Some(max) = self.max_threads {
113            max.min(available)
114        } else {
115            (available as f32 * self.thread_percentage).ceil() as usize
116        };
117
118        calculated.max(self.min_threads).min(available)
119    }
120
121    /// Initialize the rayon thread pool with configured settings.
122    pub fn init_thread_pool(&self) -> Result<(), rayon::ThreadPoolBuildError> {
123        let thread_count = self.thread_count();
124        rayon::ThreadPoolBuilder::new()
125            .num_threads(thread_count)
126            .build_global()
127    }
128
129    /// Get thread count for nvcc's `--threads` argument.
130    pub fn nvcc_threads(&self) -> Option<usize> {
131        self.num_nvcc_threads
132    }
133
134    fn detect_available_threads(&self) -> usize {
135        if let Ok(parallelism) = std::thread::available_parallelism() {
136            return parallelism.get();
137        }
138        num_cpus::get_physical()
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config uses half of the threads and sets no maximum.
    #[test]
    fn test_default_config() {
        let ParallelConfig {
            thread_percentage,
            max_threads,
            ..
        } = ParallelConfig::default();

        assert_eq!(thread_percentage, 0.5);
        assert!(max_threads.is_none());
    }

    /// Out-of-range percentages are clamped to the nearest bound.
    #[test]
    fn test_percentage_clamping() {
        let cases = [(1.5_f32, 1.0_f32), (-0.5, 0.0)];

        for &(requested, stored) in cases.iter() {
            let config = ParallelConfig::new().with_percentage(requested);
            assert_eq!(config.thread_percentage, stored);
        }
    }

    /// Default and custom patterns select the expected files.
    #[test]
    fn test_thread_patterns() {
        let defaults = ParallelConfig::default();
        let default_hits = ["flash_api.cu", "src/flash_api_v2.cu", "cutlass_gemm.cu"];
        for path in default_hits.iter() {
            assert!(
                defaults.should_use_nvcc_threads(path),
                "default patterns should match {}",
                path
            );
        }
        assert!(!defaults.should_use_nvcc_threads("simple.cu"));

        let custom = ParallelConfig::new().with_nvcc_thread_patterns(&["gemm_*.cu", "special"], 4);
        let custom_hits = ["gemm_fp16.cu", "src/gemm_int8.cu", "special_kernel.cu"];
        for path in custom_hits.iter() {
            assert!(
                custom.should_use_nvcc_threads(path),
                "custom patterns should match {}",
                path
            );
        }
        // Custom patterns replace the defaults instead of extending them.
        assert!(!custom.should_use_nvcc_threads("flash_api.cu"));
    }

    /// A glob without a path separator is matched against the file name.
    #[test]
    fn test_glob_vs_substring() {
        let patterns = ["*gemm*.cu"];
        let config = ParallelConfig::new().with_nvcc_thread_patterns(&patterns, 2);

        let matched = config.should_use_nvcc_threads("/path/to/my_gemm_kernel.cu");
        let unmatched = config.should_use_nvcc_threads("/path/to/other.cu");

        assert!(matched);
        assert!(!unmatched);
    }
}