oxigdal_gpu_advanced/multi_gpu/
affinity.rs1use super::GpuDevice;
4use dashmap::DashMap;
5use std::sync::Arc;
6use std::thread::ThreadId;
7
8pub struct AffinityManager {
10 thread_gpu_map: DashMap<ThreadId, usize>,
12 gpu_threads_map: DashMap<usize, Vec<ThreadId>>,
14 devices: Vec<Arc<GpuDevice>>,
16}
17
18impl AffinityManager {
19 pub fn new(devices: Vec<Arc<GpuDevice>>) -> Self {
21 Self {
22 thread_gpu_map: DashMap::new(),
23 gpu_threads_map: DashMap::new(),
24 devices,
25 }
26 }
27
28 pub fn pin_thread(&self, gpu_index: usize) -> Result<(), String> {
30 if gpu_index >= self.devices.len() {
31 return Err(format!(
32 "Invalid GPU index: {} (total: {})",
33 gpu_index,
34 self.devices.len()
35 ));
36 }
37
38 let thread_id = std::thread::current().id();
39
40 self.thread_gpu_map.insert(thread_id, gpu_index);
42
43 self.gpu_threads_map
45 .entry(gpu_index)
46 .or_default()
47 .push(thread_id);
48
49 Ok(())
50 }
51
52 pub fn unpin_thread(&self) {
54 let thread_id = std::thread::current().id();
55
56 if let Some((_, gpu_index)) = self.thread_gpu_map.remove(&thread_id) {
57 if let Some(mut threads) = self.gpu_threads_map.get_mut(&gpu_index) {
59 threads.retain(|&tid| tid != thread_id);
60 }
61 }
62 }
63
64 pub fn get_thread_gpu(&self) -> Option<usize> {
66 let thread_id = std::thread::current().id();
67 self.thread_gpu_map.get(&thread_id).map(|v| *v)
68 }
69
70 pub fn get_thread_device(&self) -> Option<Arc<GpuDevice>> {
72 self.get_thread_gpu()
73 .and_then(|idx| self.devices.get(idx).cloned())
74 }
75
76 pub fn get_gpu_threads(&self, gpu_index: usize) -> Vec<ThreadId> {
78 self.gpu_threads_map
79 .get(&gpu_index)
80 .map(|threads| threads.clone())
81 .unwrap_or_default()
82 }
83
84 pub fn auto_pin_thread(&self) -> Result<usize, String> {
86 if self.devices.is_empty() {
87 return Err("No GPU devices available".to_string());
88 }
89
90 let mut min_threads = usize::MAX;
92 let mut best_gpu = 0;
93
94 for i in 0..self.devices.len() {
95 let thread_count = self
96 .gpu_threads_map
97 .get(&i)
98 .map(|threads| threads.len())
99 .unwrap_or(0);
100
101 if thread_count < min_threads {
102 min_threads = thread_count;
103 best_gpu = i;
104 }
105 }
106
107 self.pin_thread(best_gpu)?;
108 Ok(best_gpu)
109 }
110
111 pub fn get_stats(&self) -> AffinityStats {
113 let mut threads_per_gpu = vec![0; self.devices.len()];
114 let mut total_threads = 0;
115
116 for entry in self.gpu_threads_map.iter() {
117 let gpu_index = *entry.key();
118 let threads = entry.value();
119 let count = threads.len();
120 if let Some(slot) = threads_per_gpu.get_mut(gpu_index) {
121 *slot = count;
122 }
123 total_threads += count;
124 }
125
126 AffinityStats {
127 threads_per_gpu,
128 total_threads,
129 total_gpus: self.devices.len(),
130 }
131 }
132
133 pub fn print_affinity_info(&self) {
135 let stats = self.get_stats();
136 println!("\nGPU Affinity Information:");
137 println!(" Total GPUs: {}", stats.total_gpus);
138 println!(" Total pinned threads: {}", stats.total_threads);
139
140 for (gpu_index, thread_count) in stats.threads_per_gpu.iter().enumerate() {
141 if let Some(device) = self.devices.get(gpu_index) {
142 println!(
143 " GPU {}: {} - {} thread(s) pinned",
144 gpu_index, device.info.name, thread_count
145 );
146 }
147 }
148 }
149
150 pub fn clear_all(&self) {
152 self.thread_gpu_map.clear();
153 self.gpu_threads_map.clear();
154 }
155
156 pub fn device_count(&self) -> usize {
158 self.devices.len()
159 }
160}
161
/// Snapshot of how pinned threads are spread across GPUs.
#[derive(Debug, Clone)]
pub struct AffinityStats {
    /// Thread count per GPU, indexed by GPU index.
    pub threads_per_gpu: Vec<usize>,

    /// Sum of all pinned threads.
    pub total_threads: usize,

    /// Number of GPUs covered by the snapshot.
    pub total_gpus: usize,
}

impl AffinityStats {
    /// Mean number of threads per GPU; `0.0` when there are no GPUs.
    pub fn avg_threads_per_gpu(&self) -> f64 {
        match self.total_gpus {
            0 => 0.0,
            gpus => self.total_threads as f64 / gpus as f64,
        }
    }

    /// Imbalance measure: the population standard deviation of the per-GPU
    /// counts divided by `mean + 1`.
    ///
    /// Returns `0.0` when there are no GPUs or no threads, and `0.0` for a
    /// perfectly even distribution; larger values mean a more uneven spread.
    pub fn load_balance_factor(&self) -> f64 {
        let nothing_to_measure = self.total_gpus == 0 || self.total_threads == 0;
        if nothing_to_measure {
            return 0.0;
        }

        let mean = self.avg_threads_per_gpu();
        let squared_deviations = self.threads_per_gpu.iter().map(|&count| {
            let delta = count as f64 - mean;
            delta * delta
        });
        let variance = squared_deviations.sum::<f64>() / self.total_gpus as f64;

        variance.sqrt() / (mean + 1.0)
    }
}
204
205pub struct AffinityGuard<'a> {
207 manager: &'a AffinityManager,
208}
209
210impl<'a> AffinityGuard<'a> {
211 pub fn new(manager: &'a AffinityManager, gpu_index: usize) -> Result<Self, String> {
213 manager.pin_thread(gpu_index)?;
214 Ok(Self { manager })
215 }
216
217 pub fn auto(manager: &'a AffinityManager) -> Result<Self, String> {
219 manager.auto_pin_thread()?;
220 Ok(Self { manager })
221 }
222}
223
224impl Drop for AffinityGuard<'_> {
225 fn drop(&mut self) {
226 self.manager.unpin_thread();
227 }
228}
229
230pub struct AffinityThreadPool {
232 affinity: Arc<AffinityManager>,
234 handles: Vec<std::thread::JoinHandle<()>>,
236}
237
238impl AffinityThreadPool {
239 pub fn new(affinity: Arc<AffinityManager>, threads_per_gpu: usize) -> Self {
241 let mut handles = Vec::new();
242 let gpu_count = affinity.device_count();
243
244 for gpu_index in 0..gpu_count {
245 for _ in 0..threads_per_gpu {
246 let affinity_clone = affinity.clone();
247
248 let handle = std::thread::spawn(move || {
249 if let Err(e) = affinity_clone.pin_thread(gpu_index) {
251 eprintln!("Failed to pin thread to GPU {}: {}", gpu_index, e);
252 return;
253 }
254
255 tracing::info!("Thread pinned to GPU {}", gpu_index);
258 });
259
260 handles.push(handle);
261 }
262 }
263
264 Self { affinity, handles }
265 }
266
267 pub fn join(self) {
269 for handle in self.handles {
270 let _ = handle.join();
271 }
272 }
273
274 pub fn affinity(&self) -> Arc<AffinityManager> {
276 self.affinity.clone()
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// An even spread has the expected mean and a near-zero imbalance factor.
    #[test]
    fn test_affinity_stats() {
        let balanced = AffinityStats {
            total_gpus: 3,
            total_threads: 6,
            threads_per_gpu: vec![2, 2, 2],
        };

        assert_eq!(balanced.avg_threads_per_gpu(), 2.0);
        assert!(balanced.load_balance_factor() < 0.01);
    }

    /// A skewed spread keeps the same mean but yields a large imbalance factor.
    #[test]
    fn test_affinity_stats_imbalance() {
        let skewed = AffinityStats {
            total_gpus: 3,
            total_threads: 6,
            threads_per_gpu: vec![5, 1, 0],
        };

        assert_eq!(skewed.avg_threads_per_gpu(), 2.0);
        assert!(skewed.load_balance_factor() > 0.5);
    }
}