Skip to main content

oxigdal_gpu_advanced/multi_gpu/
affinity.rs

1//! GPU affinity and pinning management.
2
3use super::GpuDevice;
4use dashmap::DashMap;
5use std::sync::Arc;
6use std::thread::ThreadId;
7
8/// GPU affinity manager for thread-GPU binding
9pub struct AffinityManager {
10    /// Thread to GPU mapping
11    thread_gpu_map: DashMap<ThreadId, usize>,
12    /// GPU to threads mapping
13    gpu_threads_map: DashMap<usize, Vec<ThreadId>>,
14    /// Available GPU devices
15    devices: Vec<Arc<GpuDevice>>,
16}
17
18impl AffinityManager {
19    /// Create a new affinity manager
20    pub fn new(devices: Vec<Arc<GpuDevice>>) -> Self {
21        Self {
22            thread_gpu_map: DashMap::new(),
23            gpu_threads_map: DashMap::new(),
24            devices,
25        }
26    }
27
28    /// Pin current thread to a specific GPU
29    pub fn pin_thread(&self, gpu_index: usize) -> Result<(), String> {
30        if gpu_index >= self.devices.len() {
31            return Err(format!(
32                "Invalid GPU index: {} (total: {})",
33                gpu_index,
34                self.devices.len()
35            ));
36        }
37
38        let thread_id = std::thread::current().id();
39
40        // Update thread->GPU mapping
41        self.thread_gpu_map.insert(thread_id, gpu_index);
42
43        // Update GPU->threads mapping
44        self.gpu_threads_map
45            .entry(gpu_index)
46            .or_default()
47            .push(thread_id);
48
49        Ok(())
50    }
51
52    /// Unpin current thread
53    pub fn unpin_thread(&self) {
54        let thread_id = std::thread::current().id();
55
56        if let Some((_, gpu_index)) = self.thread_gpu_map.remove(&thread_id) {
57            // Remove thread from GPU->threads mapping
58            if let Some(mut threads) = self.gpu_threads_map.get_mut(&gpu_index) {
59                threads.retain(|&tid| tid != thread_id);
60            }
61        }
62    }
63
64    /// Get GPU index for current thread (if pinned)
65    pub fn get_thread_gpu(&self) -> Option<usize> {
66        let thread_id = std::thread::current().id();
67        self.thread_gpu_map.get(&thread_id).map(|v| *v)
68    }
69
70    /// Get device for current thread (if pinned)
71    pub fn get_thread_device(&self) -> Option<Arc<GpuDevice>> {
72        self.get_thread_gpu()
73            .and_then(|idx| self.devices.get(idx).cloned())
74    }
75
76    /// Get all threads pinned to a GPU
77    pub fn get_gpu_threads(&self, gpu_index: usize) -> Vec<ThreadId> {
78        self.gpu_threads_map
79            .get(&gpu_index)
80            .map(|threads| threads.clone())
81            .unwrap_or_default()
82    }
83
84    /// Auto-pin current thread to least loaded GPU
85    pub fn auto_pin_thread(&self) -> Result<usize, String> {
86        if self.devices.is_empty() {
87            return Err("No GPU devices available".to_string());
88        }
89
90        // Find GPU with fewest pinned threads
91        let mut min_threads = usize::MAX;
92        let mut best_gpu = 0;
93
94        for i in 0..self.devices.len() {
95            let thread_count = self
96                .gpu_threads_map
97                .get(&i)
98                .map(|threads| threads.len())
99                .unwrap_or(0);
100
101            if thread_count < min_threads {
102                min_threads = thread_count;
103                best_gpu = i;
104            }
105        }
106
107        self.pin_thread(best_gpu)?;
108        Ok(best_gpu)
109    }
110
111    /// Get affinity statistics
112    pub fn get_stats(&self) -> AffinityStats {
113        let mut threads_per_gpu = vec![0; self.devices.len()];
114        let mut total_threads = 0;
115
116        for entry in self.gpu_threads_map.iter() {
117            let gpu_index = *entry.key();
118            let threads = entry.value();
119            let count = threads.len();
120            if let Some(slot) = threads_per_gpu.get_mut(gpu_index) {
121                *slot = count;
122            }
123            total_threads += count;
124        }
125
126        AffinityStats {
127            threads_per_gpu,
128            total_threads,
129            total_gpus: self.devices.len(),
130        }
131    }
132
133    /// Print affinity information
134    pub fn print_affinity_info(&self) {
135        let stats = self.get_stats();
136        println!("\nGPU Affinity Information:");
137        println!("  Total GPUs: {}", stats.total_gpus);
138        println!("  Total pinned threads: {}", stats.total_threads);
139
140        for (gpu_index, thread_count) in stats.threads_per_gpu.iter().enumerate() {
141            if let Some(device) = self.devices.get(gpu_index) {
142                println!(
143                    "  GPU {}: {} - {} thread(s) pinned",
144                    gpu_index, device.info.name, thread_count
145                );
146            }
147        }
148    }
149
150    /// Clear all affinity mappings
151    pub fn clear_all(&self) {
152        self.thread_gpu_map.clear();
153        self.gpu_threads_map.clear();
154    }
155
156    /// Get total number of devices
157    pub fn device_count(&self) -> usize {
158        self.devices.len()
159    }
160}
161
/// Affinity statistics
#[derive(Debug, Clone)]
pub struct AffinityStats {
    /// Number of threads per GPU
    pub threads_per_gpu: Vec<usize>,
    /// Total number of pinned threads
    pub total_threads: usize,
    /// Total number of GPUs
    pub total_gpus: usize,
}

impl AffinityStats {
    /// Mean number of pinned threads per GPU; 0.0 when there are no GPUs.
    pub fn avg_threads_per_gpu(&self) -> f64 {
        match self.total_gpus {
            0 => 0.0,
            gpus => self.total_threads as f64 / gpus as f64,
        }
    }

    /// Get load balance factor (0.0 = perfect balance, 1.0 = worst imbalance)
    pub fn load_balance_factor(&self) -> f64 {
        if self.total_gpus == 0 || self.total_threads == 0 {
            return 0.0;
        }

        let mean = self.avg_threads_per_gpu();

        // Population variance of the per-GPU thread counts.
        let sum_of_squares: f64 = self
            .threads_per_gpu
            .iter()
            .map(|&count| {
                let delta = count as f64 - mean;
                delta * delta
            })
            .sum();
        let std_dev = (sum_of_squares / self.total_gpus as f64).sqrt();

        // Add 1.0 to the divisor to avoid division by zero
        std_dev / (mean + 1.0)
    }
}
204
205/// RAII guard for automatic thread unpinning
206pub struct AffinityGuard<'a> {
207    manager: &'a AffinityManager,
208}
209
210impl<'a> AffinityGuard<'a> {
211    /// Create a new affinity guard
212    pub fn new(manager: &'a AffinityManager, gpu_index: usize) -> Result<Self, String> {
213        manager.pin_thread(gpu_index)?;
214        Ok(Self { manager })
215    }
216
217    /// Auto-pin to best available GPU
218    pub fn auto(manager: &'a AffinityManager) -> Result<Self, String> {
219        manager.auto_pin_thread()?;
220        Ok(Self { manager })
221    }
222}
223
224impl Drop for AffinityGuard<'_> {
225    fn drop(&mut self) {
226        self.manager.unpin_thread();
227    }
228}
229
230/// Thread pool with GPU affinity
231pub struct AffinityThreadPool {
232    /// Affinity manager
233    affinity: Arc<AffinityManager>,
234    /// Thread handles
235    handles: Vec<std::thread::JoinHandle<()>>,
236}
237
238impl AffinityThreadPool {
239    /// Create a new thread pool with GPU affinity
240    pub fn new(affinity: Arc<AffinityManager>, threads_per_gpu: usize) -> Self {
241        let mut handles = Vec::new();
242        let gpu_count = affinity.device_count();
243
244        for gpu_index in 0..gpu_count {
245            for _ in 0..threads_per_gpu {
246                let affinity_clone = affinity.clone();
247
248                let handle = std::thread::spawn(move || {
249                    // Pin thread to GPU
250                    if let Err(e) = affinity_clone.pin_thread(gpu_index) {
251                        eprintln!("Failed to pin thread to GPU {}: {}", gpu_index, e);
252                        return;
253                    }
254
255                    // Thread work loop would go here
256                    // For now, just demonstrate the pinning
257                    tracing::info!("Thread pinned to GPU {}", gpu_index);
258                });
259
260                handles.push(handle);
261            }
262        }
263
264        Self { affinity, handles }
265    }
266
267    /// Wait for all threads to complete
268    pub fn join(self) {
269        for handle in self.handles {
270            let _ = handle.join();
271        }
272    }
273
274    /// Get affinity manager
275    pub fn affinity(&self) -> Arc<AffinityManager> {
276        self.affinity.clone()
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an `AffinityStats` from a per-GPU thread distribution.
    fn stats_for(distribution: &[usize]) -> AffinityStats {
        AffinityStats {
            threads_per_gpu: distribution.to_vec(),
            total_threads: distribution.iter().sum(),
            total_gpus: distribution.len(),
        }
    }

    #[test]
    fn test_affinity_stats() {
        // Evenly spread load: two threads on each of three GPUs.
        let stats = stats_for(&[2, 2, 2]);

        assert_eq!(stats.avg_threads_per_gpu(), 2.0);
        assert!(stats.load_balance_factor() < 0.01); // Perfect balance
    }

    #[test]
    fn test_affinity_stats_imbalance() {
        // Same total load, but skewed heavily toward GPU 0.
        let stats = stats_for(&[5, 1, 0]);

        assert_eq!(stats.avg_threads_per_gpu(), 2.0);
        assert!(stats.load_balance_factor() > 0.5); // Significant imbalance
    }
307}