use std::sync::Arc;
use std::thread;
use atomr_core::dispatch::{DefaultDispatcher, Dispatcher, DispatcherHandle};
use futures_util::future::BoxFuture;
use tokio::sync::oneshot;
/// A [`Dispatcher`] whose tasks are executed by a Tokio runtime living on a
/// dedicated, named OS thread (see [`GpuDispatcher::new`]).
///
/// All state is kept behind an `Arc` in `GpuDispatcherInner`; when the last
/// reference to that inner state is dropped the runtime thread is asked to
/// shut down.
pub struct GpuDispatcher {
// Shared state: the delegate dispatcher plus the thread/shutdown plumbing.
inner: Arc<GpuDispatcherInner>,
}
/// Shared state of a `GpuDispatcher`.
///
/// Its `Drop` implementation fires `shutdown_tx`, which lets the runtime
/// thread leave its `block_on` call.
struct GpuDispatcherInner {
// Dispatcher built on the dedicated runtime's handle; all trait calls are
// forwarded to it.
delegate: DefaultDispatcher,
// Join handle of the thread hosting the runtime. It is stored but never
// joined in the visible code, so the thread is detached when this drops.
_join: Option<thread::JoinHandle<()>>,
// One-shot sender used to tell the runtime thread to stop. Wrapped in
// `Option` so it can be taken and consumed exactly once.
shutdown_tx: parking_lot::Mutex<Option<oneshot::Sender<()>>>,
}
impl GpuDispatcher {
    /// Creates a dispatcher backed by its own thread and Tokio runtime.
    ///
    /// A thread named `atomr-accel-cuda-gpu` is spawned. It builds a
    /// multi-thread runtime restricted to a single worker (named
    /// `atomr-accel-cuda-gpu-worker`), hands the runtime handle back to the
    /// caller, and then parks in `block_on` until the shutdown signal arrives.
    ///
    /// # Errors
    ///
    /// Returns an error if the thread cannot be spawned, if the runtime fails
    /// to build, or if the thread exits before reporting its runtime handle.
    pub fn new() -> std::io::Result<Self> {
        // Handshake channel: the new thread reports either its runtime handle
        // or the error that prevented the runtime from being built.
        let (ready_tx, ready_rx) = std::sync::mpsc::sync_channel(1);
        // Fired from `Drop` to let the hosting thread return.
        let (stop_tx, stop_rx) = oneshot::channel();

        let host_thread = thread::Builder::new()
            .name("atomr-accel-cuda-gpu".into())
            .spawn(move || {
                let built = tokio::runtime::Builder::new_multi_thread()
                    .worker_threads(1)
                    .thread_name("atomr-accel-cuda-gpu-worker")
                    .enable_all()
                    .build();
                let runtime = match built {
                    Ok(runtime) => runtime,
                    Err(build_err) => {
                        // The caller may already be gone; nothing to do then.
                        let _ = ready_tx.send(Err(build_err));
                        return;
                    }
                };
                let _ = ready_tx.send(Ok(runtime.handle().clone()));
                // Keep the runtime alive until shutdown is requested (or the
                // sender is dropped, which also resolves the receiver).
                runtime.block_on(async move {
                    let _ = stop_rx.await;
                });
            })?;

        // Outer `?`: the thread vanished without reporting.
        // Inner `?`: the thread reported a runtime build failure.
        let rt_handle = ready_rx.recv().map_err(|_| {
            std::io::Error::other("GpuDispatcher thread died before yielding its runtime handle")
        })??;

        let state = GpuDispatcherInner {
            // NOTE(review): the meaning of `16` is defined by
            // `DefaultDispatcher::new`, which is not visible here.
            delegate: DefaultDispatcher::new(rt_handle, 16),
            _join: Some(host_thread),
            shutdown_tx: parking_lot::Mutex::new(Some(stop_tx)),
        };
        Ok(Self {
            inner: Arc::new(state),
        })
    }
}
impl Dispatcher for GpuDispatcher {
    /// Forwards `task` to the wrapped `DefaultDispatcher`, which was built on
    /// the dedicated runtime's handle.
    fn spawn_task(&self, task: BoxFuture<'static, ()>) -> DispatcherHandle {
        let delegate = &self.inner.delegate;
        delegate.spawn_task(task)
    }

    /// Reports the throughput value of the wrapped `DefaultDispatcher`.
    fn throughput(&self) -> u32 {
        let delegate = &self.inner.delegate;
        delegate.throughput()
    }
}
impl Drop for GpuDispatcherInner {
    /// Asks the runtime thread to stop by firing the one-shot shutdown signal.
    ///
    /// The thread itself is not joined here; only the signal is sent.
    fn drop(&mut self) {
        // `&mut self` already guarantees exclusive access, so the sender can
        // be reached through `get_mut` without acquiring the lock.
        if let Some(stop_tx) = self.shutdown_tx.get_mut().take() {
            // A send error means the receiving side is already gone, in which
            // case there is nothing left to shut down.
            let _ = stop_tx.send(());
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Spawns three tasks and checks that they all report the same thread id,
    /// and that this id differs from the test's own thread.
    #[test]
    fn pinned_dispatcher_runs_on_dedicated_thread() {
        const TASKS: usize = 3;

        let dispatcher = GpuDispatcher::new().expect("spawn dispatcher");
        let (id_tx, id_rx) = std::sync::mpsc::channel::<thread::ThreadId>();

        for _ in 0..TASKS {
            let id_tx = id_tx.clone();
            dispatcher.spawn_task(Box::pin(async move {
                let _ = id_tx.send(thread::current().id());
            }));
        }

        // Each task reports the id of the thread that polled it.
        let ids: Vec<thread::ThreadId> = (0..TASKS)
            .map(|_| id_rx.recv_timeout(Duration::from_secs(2)).unwrap())
            .collect();

        let first = ids[0];
        assert!(
            ids.iter().all(|id| *id == first),
            "tasks ran on different threads: {:?}",
            ids
        );
        assert_ne!(first, thread::current().id());
    }
}