use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use arc_swap::ArcSwapOption;
use tokio::sync::watch;
/// Context type stored in `DeviceState::current_ctx` when the
/// `cuda-runtime-tests` feature is enabled.
#[cfg(feature = "cuda-runtime-tests")]
type ContextHandle = cudarc::driver::CudaContext;
/// Context type stored in `DeviceState::current_ctx` when the
/// `cuda-runtime-tests` feature is disabled.
// NOTE(review): both cfg branches currently alias the same type, so the
// feature gate has no effect here. Confirm whether this branch was meant to
// point at a stand-in type for builds without a CUDA runtime.
#[cfg(not(feature = "cuda-runtime-tests"))]
type ContextHandle = cudarc::driver::CudaContext;
/// Shared per-device bookkeeping: a generation counter, an "accepting ops"
/// flag, the currently installed context (if any), and a watch channel that
/// carries generation updates.
///
/// Every field is updated through `&self` (atomics, `ArcSwapOption`, and a
/// `watch::Sender`), so none of the methods need `&mut self`.
pub struct DeviceState {
/// Identifier passed to `new`; never modified afterwards.
device_id: u32,
/// Counter that starts at 0 and is incremented by `bump_generation`.
generation: AtomicU64,
/// Starts as `true`; `begin_shutdown` stores `false`.
accepting_ops: AtomicBool,
/// Currently installed context, or empty when none is installed.
current_ctx: ArcSwapOption<ContextHandle>,
/// Sending half of the watch channel on which `bump_generation` publishes
/// the new generation; `generation_watch` subscribes to it.
generation_tx: watch::Sender<u64>,
}
impl DeviceState {
    /// Creates the state for `device_id` with generation 0, the accepting-ops
    /// flag set, and no context installed.
    pub fn new(device_id: u32) -> Self {
        // The initial receiver is dropped right away; observers obtain their
        // own receiver through `generation_watch`.
        let (tx, _rx) = watch::channel(0u64);
        Self {
            device_id,
            generation: AtomicU64::new(0),
            accepting_ops: AtomicBool::new(true),
            current_ctx: ArcSwapOption::empty(),
            generation_tx: tx,
        }
    }

    /// Returns the device id this state was created with.
    pub fn device_id(&self) -> u32 {
        self.device_id
    }

    /// Returns the current value of the generation counter.
    pub fn generation(&self) -> u64 {
        self.generation.load(Ordering::Acquire)
    }

    /// Increments the generation counter, publishes the new value on the
    /// watch channel, and returns it.
    pub fn bump_generation(&self) -> u64 {
        let new = self.generation.fetch_add(1, Ordering::AcqRel) + 1;
        // `watch::Sender::send` does not store the value while no receiver is
        // alive, and the receiver created in `new` is dropped immediately. A
        // bump made before the first `generation_watch` call would therefore
        // be lost, and a later subscriber would read a stale generation.
        // `send_if_modified` updates the stored value regardless of receiver
        // count. Only moving forward also keeps two racing bumps from
        // publishing out of order and leaving the channel behind the counter.
        self.generation_tx.send_if_modified(|published| {
            if new > *published {
                *published = new;
                true
            } else {
                false
            }
        });
        new
    }

    /// Returns a new receiver subscribed to the generation watch channel.
    pub fn generation_watch(&self) -> watch::Receiver<u64> {
        self.generation_tx.subscribe()
    }

    /// Returns the current value of the accepting-ops flag.
    pub fn accepting_ops(&self) -> bool {
        self.accepting_ops.load(Ordering::Acquire)
    }

    /// Clears the accepting-ops flag. Nothing in this `impl` block sets it
    /// back to `true`.
    pub fn begin_shutdown(&self) {
        self.accepting_ops.store(false, Ordering::Release);
    }

    /// Stores `ctx` as the current context, replacing any previous one.
    pub fn install_context(&self, ctx: Arc<ContextHandle>) {
        self.current_ctx.store(Some(ctx));
    }

    /// Removes the current context, leaving the slot empty.
    pub fn clear_context(&self) {
        self.current_ctx.store(None);
    }

    /// Returns an owned `Arc` to the current context, or `None` when no
    /// context is installed.
    pub fn current_context(&self) -> Option<Arc<ContextHandle>> {
        self.current_ctx.load_full()
    }
}
impl std::fmt::Debug for DeviceState {
    /// Writes a `DeviceState { .. }` debug representation. The context is
    /// reported only as a `has_context` boolean rather than being formatted
    /// itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Snapshot the shared fields first, in the order they are printed.
        let generation = self.generation();
        let accepting_ops = self.accepting_ops();
        let has_context = self.current_ctx.load().is_some();

        let mut out = f.debug_struct("DeviceState");
        out.field("device_id", &self.device_id);
        out.field("generation", &generation);
        out.field("accepting_ops", &accepting_ops);
        out.field("has_context", &has_context);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state reports generation 0, and each bump returns the next
    /// integer, which `generation` then reflects.
    #[test]
    fn generation_starts_zero_and_bumps_monotonically() {
        let state = DeviceState::new(0);
        assert_eq!(state.generation(), 0);

        let first = state.bump_generation();
        assert_eq!(first, 1);

        let second = state.bump_generation();
        assert_eq!(second, 2);

        assert_eq!(state.generation(), 2);
    }

    /// The accepting-ops flag is set on construction and cleared by
    /// `begin_shutdown`.
    #[test]
    fn shutdown_flips_accepting_ops() {
        let state = DeviceState::new(0);
        let before = state.accepting_ops();
        assert!(before);

        state.begin_shutdown();

        let after = state.accepting_ops();
        assert!(!after);
    }
}