// oxigdal_gpu_advanced/multi_gpu/sync.rs
use crate::error::{GpuAdvancedError, Result};
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Notify, Semaphore};
use wgpu::{Buffer, Device, Queue};

/// Coordinates synchronization primitives (barriers, events, fences) across a
/// set of GPUs. Cloning is cheap: all state is behind `Arc`s, so clones share
/// the same barriers, events and fence pool.
#[derive(Clone)]
pub struct SyncManager {
    // Index-aligned: queues[i] is the submission queue for devices[i].
    devices: Arc<Vec<Arc<Device>>>,
    queues: Arc<Vec<Arc<Queue>>>,
    // Named barriers/events, looked up by string key.
    barriers: Arc<RwLock<HashMap<String, Arc<Barrier>>>>,
    events: Arc<RwLock<HashMap<String, Arc<Event>>>>,
    // Free-list of reusable fence ids.
    fence_pool: Arc<Mutex<FencePool>>,
}
23
24impl SyncManager {
25 pub fn new(devices: Vec<Arc<Device>>, queues: Vec<Arc<Queue>>) -> Result<Self> {
27 if devices.len() != queues.len() {
28 return Err(GpuAdvancedError::InvalidConfiguration(
29 "Device and queue count mismatch".to_string(),
30 ));
31 }
32
33 Ok(Self {
34 devices: Arc::new(devices),
35 queues: Arc::new(queues),
36 barriers: Arc::new(RwLock::new(HashMap::new())),
37 events: Arc::new(RwLock::new(HashMap::new())),
38 fence_pool: Arc::new(Mutex::new(FencePool::new())),
39 })
40 }
41
42 pub fn create_barrier(&self, name: &str, gpu_count: usize) -> Result<Arc<Barrier>> {
44 if gpu_count == 0 || gpu_count > self.devices.len() {
45 return Err(GpuAdvancedError::ConfigError(format!(
46 "Invalid GPU count {} for barrier (available: {})",
47 gpu_count,
48 self.devices.len()
49 )));
50 }
51
52 let barrier = Arc::new(Barrier::new(gpu_count));
53 self.barriers
54 .write()
55 .insert(name.to_string(), barrier.clone());
56 Ok(barrier)
57 }
58
59 pub fn get_barrier(&self, name: &str) -> Option<Arc<Barrier>> {
61 self.barriers.read().get(name).cloned()
62 }
63
64 pub fn create_event(&self, name: &str) -> Arc<Event> {
66 let event = Arc::new(Event::new());
67 self.events.write().insert(name.to_string(), event.clone());
68 event
69 }
70
71 pub fn get_event(&self, name: &str) -> Option<Arc<Event>> {
73 self.events.read().get(name).cloned()
74 }
75
76 pub async fn transfer_between_gpus(
78 &self,
79 src_gpu_idx: usize,
80 dst_gpu_idx: usize,
81 src_buffer: &Buffer,
82 dst_buffer: &Buffer,
83 size: u64,
84 ) -> Result<Duration> {
85 if src_gpu_idx >= self.devices.len() || dst_gpu_idx >= self.devices.len() {
86 return Err(GpuAdvancedError::InvalidConfiguration(
87 "GPU index out of bounds".to_string(),
88 ));
89 }
90
91 let start = Instant::now();
92
93 let staging_buffer = self.devices[src_gpu_idx].create_buffer(&wgpu::BufferDescriptor {
95 label: Some("cross_gpu_staging"),
96 size,
97 usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
98 mapped_at_creation: false,
99 });
100
101 let mut encoder =
103 self.devices[src_gpu_idx].create_command_encoder(&wgpu::CommandEncoderDescriptor {
104 label: Some("cross_gpu_copy_src"),
105 });
106 encoder.copy_buffer_to_buffer(src_buffer, 0, &staging_buffer, 0, size);
107 self.queues[src_gpu_idx].submit(Some(encoder.finish()));
108
109 let slice = staging_buffer.slice(..);
111 let (tx, rx) = futures::channel::oneshot::channel();
112 slice.map_async(wgpu::MapMode::Read, move |result| {
113 let _ = tx.send(result);
114 });
115 rx.await
118 .map_err(|_| GpuAdvancedError::SyncError("Transfer channel closed".to_string()))?
119 .map_err(|e| GpuAdvancedError::SyncError(format!("Map async failed: {:?}", e)))?;
120
121 let data = slice.get_mapped_range();
123 let vec_data: Vec<u8> = data.to_vec();
124 drop(data);
125 staging_buffer.unmap();
126
127 let dst_staging = self.devices[dst_gpu_idx].create_buffer(&wgpu::BufferDescriptor {
129 label: Some("cross_gpu_staging_dst"),
130 size,
131 usage: wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::MAP_WRITE,
132 mapped_at_creation: true,
133 });
134
135 {
137 let mut mapped = dst_staging.slice(..).get_mapped_range_mut();
138 mapped.copy_from_slice(&vec_data);
139 }
140 dst_staging.unmap();
141
142 let mut encoder =
144 self.devices[dst_gpu_idx].create_command_encoder(&wgpu::CommandEncoderDescriptor {
145 label: Some("cross_gpu_copy_dst"),
146 });
147 encoder.copy_buffer_to_buffer(&dst_staging, 0, dst_buffer, 0, size);
148 self.queues[dst_gpu_idx].submit(Some(encoder.finish()));
149
150 Ok(start.elapsed())
153 }
154
155 pub fn acquire_fence(&self) -> Fence {
157 self.fence_pool.lock().acquire()
158 }
159
160 pub fn release_fence(&self, fence: Fence) {
162 self.fence_pool.lock().release(fence);
163 }
164
165 pub fn gpu_count(&self) -> usize {
167 self.devices.len()
168 }
169}
170
/// An async, reusable rendezvous point for `count` tasks.
pub struct Barrier {
    // Number of participants required to release one generation.
    count: usize,
    // How many participants have arrived in the current generation.
    arrived: Mutex<usize>,
    // Bumped on each release so late-waking waiters can detect it happened.
    generation: Mutex<usize>,
    // Wakes waiters when a generation completes.
    notify: Notify,
}
178
179impl Barrier {
180 pub fn new(count: usize) -> Self {
182 Self {
183 count,
184 arrived: Mutex::new(0),
185 generation: Mutex::new(0),
186 notify: Notify::new(),
187 }
188 }
189
190 pub async fn wait(&self) -> Result<()> {
192 let current_gen = *self.generation.lock();
193
194 let arrived = {
195 let mut arrived = self.arrived.lock();
196 *arrived += 1;
197 *arrived
198 };
199
200 if arrived == self.count {
201 {
203 let mut arrived = self.arrived.lock();
204 *arrived = 0;
205 }
206 {
207 let mut gen_val = self.generation.lock();
208 *gen_val += 1;
209 }
210 self.notify.notify_waiters();
211 Ok(())
212 } else {
213 loop {
215 self.notify.notified().await;
216 let gen_val = *self.generation.lock();
217 if gen_val > current_gen {
218 break;
219 }
220 }
221 Ok(())
222 }
223 }
224
225 pub async fn wait_timeout(&self, timeout: Duration) -> Result<bool> {
227 let wait_future = self.wait();
228 match tokio::time::timeout(timeout, wait_future).await {
229 Ok(Ok(())) => Ok(true),
230 Ok(Err(e)) => Err(e),
231 Err(_) => Ok(false), }
233 }
234
235 pub fn count(&self) -> usize {
237 self.count
238 }
239
240 pub fn waiting(&self) -> usize {
242 *self.arrived.lock()
243 }
244}
245
/// A manually-reset async event: once signaled, all current and future
/// waiters proceed until `reset` is called.
pub struct Event {
    // True after `signal()`, false after `reset()` or construction.
    signaled: Mutex<bool>,
    // Wakes tasks blocked in `wait` / `wait_timeout`.
    notify: Notify,
    // When the event was last signaled; cleared by `reset()`.
    timestamp: Mutex<Option<Instant>>,
}
252
253impl Event {
254 pub fn new() -> Self {
256 Self {
257 signaled: Mutex::new(false),
258 notify: Notify::new(),
259 timestamp: Mutex::new(None),
260 }
261 }
262
263 pub fn signal(&self) {
265 *self.signaled.lock() = true;
266 *self.timestamp.lock() = Some(Instant::now());
267 self.notify.notify_waiters();
268 }
269
270 pub fn reset(&self) {
272 *self.signaled.lock() = false;
273 *self.timestamp.lock() = None;
274 }
275
276 pub async fn wait(&self) {
278 if *self.signaled.lock() {
279 return;
280 }
281 self.notify.notified().await;
282 }
283
284 pub async fn wait_timeout(&self, timeout: Duration) -> bool {
286 if *self.signaled.lock() {
287 return true;
288 }
289 tokio::time::timeout(timeout, self.notify.notified())
290 .await
291 .is_ok()
292 }
293
294 pub fn is_signaled(&self) -> bool {
296 *self.signaled.lock()
297 }
298
299 pub fn timestamp(&self) -> Option<Instant> {
301 *self.timestamp.lock()
302 }
303}
304
305impl Default for Event {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
/// A lightweight, copyable synchronization token identified by a unique id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Fence {
    id: u64,
}

impl Fence {
    /// Wraps a raw id; only the pool mints these.
    fn new(id: u64) -> Self {
        Fence { id }
    }

    /// The unique id backing this fence.
    pub fn id(&self) -> u64 {
        self.id
    }
}

/// A bounded free-list of `Fence` tokens: ids are allocated monotonically
/// and recycled through `release` until the pool reaches its cap.
struct FencePool {
    // Next id to mint when the free list is empty.
    next_id: u64,
    // Recycled fences, handed back out LIFO.
    available: Vec<Fence>,
    // Cap on the free list; excess released fences are simply dropped.
    max_pool_size: usize,
}

impl FencePool {
    fn new() -> Self {
        FencePool {
            next_id: 0,
            available: Vec::new(),
            max_pool_size: 256,
        }
    }

    /// Hands out a recycled fence when available, otherwise mints a new id.
    fn acquire(&mut self) -> Fence {
        match self.available.pop() {
            Some(recycled) => recycled,
            None => {
                let id = self.next_id;
                self.next_id += 1;
                Fence::new(id)
            }
        }
    }

    /// Returns a fence to the free list, dropping it when the pool is full.
    fn release(&mut self, fence: Fence) {
        if self.available.len() < self.max_pool_size {
            self.available.push(fence);
        }
    }
}
360
361pub struct GpuSemaphore {
363 inner: Arc<Semaphore>,
364}
365
366impl GpuSemaphore {
367 pub fn new(permits: usize) -> Self {
369 Self {
370 inner: Arc::new(Semaphore::new(permits)),
371 }
372 }
373
374 pub async fn acquire(&self) -> Result<SemaphoreGuard<'_>> {
376 let permit =
377 self.inner.acquire().await.map_err(|e| {
378 GpuAdvancedError::SyncError(format!("Semaphore acquire failed: {}", e))
379 })?;
380 Ok(SemaphoreGuard { _permit: permit })
381 }
382
383 pub fn try_acquire(&self) -> Option<SemaphoreGuard<'_>> {
385 self.inner
386 .try_acquire()
387 .ok()
388 .map(|permit| SemaphoreGuard { _permit: permit })
389 }
390
391 pub fn available_permits(&self) -> usize {
393 self.inner.available_permits()
394 }
395}
396
397impl Clone for GpuSemaphore {
398 fn clone(&self) -> Self {
399 Self {
400 inner: Arc::clone(&self.inner),
401 }
402 }
403}
404
/// Guard returned by `GpuSemaphore::acquire`/`try_acquire`; dropping it
/// returns the held permit to the semaphore (tokio `SemaphorePermit` drop
/// behavior).
pub struct SemaphoreGuard<'a> {
    _permit: tokio::sync::SemaphorePermit<'a>,
}
409
/// Aggregate counters describing synchronization activity.
#[derive(Debug, Clone, Default)]
pub struct SyncStats {
    /// Completed barrier waits.
    pub barrier_waits: u64,
    /// Event signals issued.
    pub event_signals: u64,
    /// Cross-GPU buffer transfers performed.
    pub cross_gpu_transfers: u64,
    /// Wall-clock time spent in transfers.
    pub total_transfer_time: Duration,
    /// Total payload moved, in bytes.
    pub total_bytes_transferred: u64,
}

impl SyncStats {
    /// Mean transfer bandwidth in GB/s (decimal gigabytes), or `None` when
    /// no bytes or no time have been recorded.
    pub fn average_bandwidth_gbs(&self) -> Option<f64> {
        if self.total_transfer_time.is_zero() || self.total_bytes_transferred == 0 {
            return None;
        }
        let bytes_per_sec =
            self.total_bytes_transferred as f64 / self.total_transfer_time.as_secs_f64();
        Some(bytes_per_sec / 1_000_000_000.0)
    }
}
437
#[cfg(test)]
mod tests {
    use super::*;

    // Three tasks rendezvous at a 3-participant barrier: all must pass, and
    // afterwards the arrival counter must read zero (the barrier resets for
    // the next generation).
    #[tokio::test]
    async fn test_barrier() {
        let barrier = Arc::new(Barrier::new(3));
        let mut handles = Vec::new();

        for i in 0..3 {
            let b = barrier.clone();
            let handle = tokio::spawn(async move {
                println!("Task {} waiting at barrier", i);
                b.wait().await.ok();
                println!("Task {} passed barrier", i);
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.ok();
        }

        assert_eq!(barrier.waiting(), 0);
    }

    // A waiter blocks until signal() fires; afterwards the event reports
    // signaled and the waiter task completes.
    #[tokio::test]
    async fn test_event() {
        let event = Arc::new(Event::new());
        assert!(!event.is_signaled());

        let e = event.clone();
        let handle = tokio::spawn(async move {
            e.wait().await;
        });

        // Give the spawned task a chance to start waiting before signaling.
        tokio::time::sleep(Duration::from_millis(10)).await;
        event.signal();
        assert!(event.is_signaled());

        handle.await.ok();
    }

    // Permits decrease on acquire and come back when a guard is dropped.
    #[tokio::test]
    async fn test_semaphore() {
        let sem = GpuSemaphore::new(2);
        assert_eq!(sem.available_permits(), 2);

        let _guard1 = sem.acquire().await.ok();
        assert_eq!(sem.available_permits(), 1);

        let _guard2 = sem.acquire().await.ok();
        assert_eq!(sem.available_permits(), 0);

        drop(_guard1);
        assert_eq!(sem.available_permits(), 1);
    }

    // Fresh fences get distinct ids; a released fence is recycled verbatim
    // (LIFO free list).
    #[test]
    fn test_fence_pool() {
        let mut pool = FencePool::new();
        let f1 = pool.acquire();
        let f2 = pool.acquire();

        assert_ne!(f1.id(), f2.id());

        pool.release(f1);
        let f3 = pool.acquire();
        assert_eq!(f1, f3);
    }
}
508}