use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use cudarc::nccl::{group_end, group_start};
use super::LIB;
use crate::error::GpuError;
/// Shared counters recording how many NCCL group scopes were opened and closed by the
/// [`GroupGuard`]s that hold a reference to this tracker.
///
/// Comparing `begins` with `ends` shows whether every opened scope was also closed.
#[derive(Default, Debug)]
pub struct GroupTracker {
    /// Incremented once per guard when it is created (after a successful `group_start`,
    /// or unconditionally for an inert guard).
    pub begins: AtomicUsize,
    /// Incremented once per guard when it is closed, either by `commit` or by an
    /// uncommitted drop.
    pub ends: AtomicUsize,
}
/// Scope guard pairing an NCCL `group_start` with exactly one `group_end`.
///
/// The group is closed either explicitly through `commit` or implicitly when the guard
/// is dropped without having been committed.
#[must_use = "if unused, the guard is dropped immediately and the group is ended right away"]
pub struct GroupGuard {
    /// Optional counters updated when the guard opens and closes its scope.
    tracker: Option<Arc<GroupTracker>>,
    /// Set once `commit` has run, so that `Drop` does not close the group a second time.
    committed: bool,
    /// When true, only the tracker is updated and no NCCL calls are made.
    inert: bool,
}
impl GroupGuard {
    /// Opens an NCCL group with `group_start` and returns a guard that will close it.
    ///
    /// When `tracker` is `Some`, its `begins` counter is incremented. This happens only
    /// after `group_start` has succeeded.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::LibraryError`] if `group_start` fails. In that case no guard
    /// is created and no counter is touched.
    pub fn begin(tracker: Option<Arc<GroupTracker>>) -> Result<Self, GpuError> {
        group_start().map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("group_start: {e:?}"),
        })?;
        if let Some(t) = &tracker {
            t.begins.fetch_add(1, Ordering::SeqCst);
        }
        Ok(Self {
            tracker,
            committed: false,
            inert: false,
        })
    }

    /// Creates a guard that updates `tracker` exactly like [`GroupGuard::begin`] but
    /// never calls into NCCL, on either open or close.
    #[must_use = "dropping an inert guard immediately records its end on the tracker"]
    pub fn begin_inert(tracker: Option<Arc<GroupTracker>>) -> Self {
        if let Some(t) = &tracker {
            t.begins.fetch_add(1, Ordering::SeqCst);
        }
        Self {
            tracker,
            committed: false,
            inert: true,
        }
    }

    /// Closes the group explicitly, consuming the guard.
    ///
    /// The tracker's `ends` counter is incremented whatever the outcome. Inert guards
    /// return `Ok(())` without calling `group_end`.
    ///
    /// # Errors
    ///
    /// Returns [`GpuError::LibraryError`] if `group_end` fails.
    pub fn commit(mut self) -> Result<(), GpuError> {
        // Mark committed before attempting `group_end`. If it fails, the subsequent
        // `Drop` must not issue a second `group_end` for the same group.
        self.committed = true;
        if let Some(t) = &self.tracker {
            t.ends.fetch_add(1, Ordering::SeqCst);
        }
        if self.inert {
            return Ok(());
        }
        group_end().map(|_| ()).map_err(|e| GpuError::LibraryError {
            lib: LIB,
            msg: format!("group_end: {e:?}"),
        })
    }
}
impl Drop for GroupGuard {
fn drop(&mut self) {
if self.committed {
return;
}
if let Some(t) = &self.tracker {
t.ends.fetch_add(1, Ordering::SeqCst);
}
if self.inert {
return;
}
if let Err(e) = group_end() {
tracing::warn!(error = ?e, "GroupGuard::drop: group_end failed");
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the tracker's counters as a `(begins, ends)` pair.
    fn snapshot(tracker: &GroupTracker) -> (usize, usize) {
        (
            tracker.begins.load(Ordering::SeqCst),
            tracker.ends.load(Ordering::SeqCst),
        )
    }

    /// Dropping an uncommitted inert guard records exactly one begin and one end.
    #[test]
    fn group_scope_guard_emits_begin_end_pair() {
        let tracker = Arc::new(GroupTracker::default());
        drop(GroupGuard::begin_inert(Some(Arc::clone(&tracker))));
        assert_eq!(snapshot(&tracker), (1, 1));
    }

    /// Committing a guard records its end once. The drop that follows inside `commit`
    /// must not add a second end.
    #[test]
    fn commit_then_drop_does_not_double_count() {
        let tracker = Arc::new(GroupTracker::default());
        GroupGuard::begin_inert(Some(Arc::clone(&tracker)))
            .commit()
            .expect("committing an inert guard makes no NCCL call and cannot fail");
        assert_eq!(snapshot(&tracker), (1, 1));
    }
}