use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::time::{Duration, Instant};
use oxicuda_driver::error::{CudaError, CudaResult};
use oxicuda_driver::event::Event;
use oxicuda_driver::stream::Stream;
use crate::kernel::{Kernel, KernelArgs};
use crate::params::LaunchParams;
/// Point-in-time state of an asynchronous launch, as reported by the
/// `status` methods of the completion futures in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompletionStatus {
    /// The awaited event has not completed yet.
    Pending,
    /// The awaited event has completed.
    Complete,
    /// Querying the event failed; the payload is the error rendered as text.
    Error(String),
}

impl CompletionStatus {
    /// Returns `true` only for [`CompletionStatus::Complete`].
    #[inline]
    pub fn is_complete(&self) -> bool {
        self == &Self::Complete
    }

    /// Returns `true` only for [`CompletionStatus::Pending`].
    #[inline]
    pub fn is_pending(&self) -> bool {
        self == &Self::Pending
    }

    /// Returns `true` for [`CompletionStatus::Error`], whatever its message.
    #[inline]
    pub fn is_error(&self) -> bool {
        // The payload is irrelevant here, so a pattern test is used rather
        // than an equality comparison.
        matches!(self, Self::Error(_))
    }
}

impl std::fmt::Display for CompletionStatus {
    /// Writes `Pending`, `Complete`, or `Error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pending => f.write_str("Pending"),
            Self::Complete => f.write_str("Complete"),
            Self::Error(msg) => write!(f, "Error: {msg}"),
        }
    }
}
/// How a completion future paces itself between checks of its event.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollStrategy {
    /// Re-check with no deliberate pause between checks.
    Spin,
    /// Call `std::thread::yield_now` before waking the task for the next check.
    Yield,
    /// Sleep for the given number of microseconds before the next check.
    BackoffMicros(u64),
}

impl Default for PollStrategy {
    /// The default strategy is [`PollStrategy::Yield`].
    #[inline]
    fn default() -> Self {
        PollStrategy::Yield
    }
}
/// Settings applied to the futures produced by asynchronous launches.
#[derive(Debug, Clone)]
pub struct AsyncLaunchConfig {
    /// Pacing between successive completion checks.
    pub poll_strategy: PollStrategy,
    /// Maximum time to wait for completion; `None` disables the timeout.
    pub timeout: Option<Duration>,
}

impl Default for AsyncLaunchConfig {
    /// Uses [`PollStrategy::Yield`] and no timeout.
    #[inline]
    fn default() -> Self {
        Self::new(PollStrategy::Yield)
    }
}

impl AsyncLaunchConfig {
    /// Creates a configuration with the given strategy and no timeout.
    #[inline]
    pub fn new(poll_strategy: PollStrategy) -> Self {
        let timeout = None;
        Self {
            poll_strategy,
            timeout,
        }
    }

    /// Returns the configuration with its timeout set to `timeout`,
    /// replacing any previous value.
    #[inline]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}
/// Elapsed time of a launch, stored in microseconds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LaunchTiming {
    /// Elapsed time in microseconds.
    pub elapsed_us: f64,
}

impl LaunchTiming {
    /// Elapsed time converted to milliseconds.
    #[inline]
    pub fn elapsed_ms(&self) -> f64 {
        let micros = self.elapsed_us;
        micros / 1000.0
    }

    /// Elapsed time converted to seconds.
    #[inline]
    pub fn elapsed_secs(&self) -> f64 {
        let micros = self.elapsed_us;
        micros / 1_000_000.0
    }
}

impl std::fmt::Display for LaunchTiming {
    /// Prints the value in the largest unit that fits: `us` with two
    /// decimals below 1000 us, `ms` with three decimals below 1 s, and `s`
    /// with four decimals otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let micros = self.elapsed_us;
        // Both comparisons are false for NaN, which therefore falls through
        // to the seconds branch.
        let (digits, value, unit): (usize, f64, &str) = if micros < 1000.0 {
            (2, micros, "us")
        } else if micros < 1_000_000.0 {
            (3, self.elapsed_ms(), "ms")
        } else {
            (4, self.elapsed_secs(), "s")
        };
        write!(f, "{:.*} {}", digits, value, unit)
    }
}
/// Future that becomes ready when the event recorded after a launch completes.
///
/// Returned by [`AsyncKernel::launch_async`] and [`multi_launch_async`]. Every
/// `poll` queries the event once; while the event is still pending the future
/// schedules its own wake-up according to the configured [`PollStrategy`].
///
/// # Errors
///
/// Resolves to `Err(CudaError::Timeout)` when a timeout is configured, the
/// event is still pending, and the timeout has elapsed since the first poll.
/// An error returned by the event query is passed through unchanged.
pub struct LaunchCompletion {
    /// Event whose completion is being awaited.
    event: Event,
    /// Pacing between successive event queries.
    strategy: PollStrategy,
    /// Maximum time to wait, measured from the first poll; `None` disables it.
    timeout: Option<Duration>,
    /// When the future was first polled; `None` until then.
    start_time: Option<Instant>,
}

impl LaunchCompletion {
    /// Wraps `event`, copying the poll strategy and timeout out of `config`.
    fn new(event: Event, config: &AsyncLaunchConfig) -> Self {
        Self {
            event,
            strategy: config.poll_strategy,
            timeout: config.timeout,
            start_time: None,
        }
    }

    /// Queries the event once, without blocking, and reports the result.
    ///
    /// A failed query is reported as [`CompletionStatus::Error`] carrying the
    /// error's text.
    pub fn status(&self) -> CompletionStatus {
        match self.event.query() {
            Ok(true) => CompletionStatus::Complete,
            Ok(false) => CompletionStatus::Pending,
            Err(e) => CompletionStatus::Error(e.to_string()),
        }
    }

    /// Returns `true` when a timeout is configured, the future has been
    /// polled at least once, and at least `timeout` has elapsed since then.
    fn check_timeout(&self) -> bool {
        match (self.timeout, self.start_time) {
            (Some(timeout), Some(start)) => start.elapsed() >= timeout,
            _ => false,
        }
    }

    /// Schedules exactly one wake-up of `waker`, delayed according to
    /// `strategy`.
    ///
    /// * `Spin` wakes immediately on the calling thread.
    /// * `Yield` and `BackoffMicros` hand the waker to a short-lived helper
    ///   thread that yields or sleeps first, so the caller is never blocked.
    ///
    /// If the helper thread cannot be created, the waker is woken immediately
    /// instead of panicking, so the awaiting task always has a wake-up
    /// pending.
    fn spawn_poller(strategy: PollStrategy, waker: Waker) {
        if strategy == PollStrategy::Spin {
            // No delay was requested; a helper thread would only add overhead.
            waker.wake();
            return;
        }

        // The closure consumes `waker`, so keep a clone for the failure path.
        let fallback = waker.clone();
        let spawned = std::thread::Builder::new().spawn(move || {
            match strategy {
                PollStrategy::Spin => {}
                PollStrategy::Yield => std::thread::yield_now(),
                PollStrategy::BackoffMicros(us) => {
                    std::thread::sleep(Duration::from_micros(us));
                }
            }
            waker.wake();
        });
        if spawned.is_err() {
            fallback.wake();
        }
    }
}

impl Future for LaunchCompletion {
    type Output = CudaResult<()>;

    /// Queries the event and, while it is pending, re-arms a wake-up.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The timeout clock starts at the first poll.
        if self.start_time.is_none() {
            self.start_time = Some(Instant::now());
        }

        // Query before looking at the timeout so that work which has already
        // finished is never reported as timed out.
        match self.event.query() {
            Ok(true) => Poll::Ready(Ok(())),
            Ok(false) if self.check_timeout() => Poll::Ready(Err(CudaError::Timeout)),
            Ok(false) => {
                // A wake-up must be scheduled on *every* pending poll, using
                // the waker of this poll. Each scheduled wake fires only once,
                // so skipping it here would leave the task parked forever.
                Self::spawn_poller(self.strategy, cx.waker().clone());
                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}
/// Future that becomes ready when a timed launch completes, yielding the
/// elapsed time between its start and end events.
///
/// Returned by [`AsyncKernel::launch_and_time_async`]. Every `poll` queries
/// the end event once; while it is still pending the future schedules its own
/// wake-up according to the configured [`PollStrategy`].
///
/// # Errors
///
/// Resolves to `Err(CudaError::Timeout)` when a timeout is configured, the
/// end event is still pending, and the timeout has elapsed since the first
/// poll. Errors from the event query or from the elapsed-time computation are
/// passed through unchanged.
pub struct TimedLaunchCompletion {
    /// Event recorded before the launch.
    start_event: Event,
    /// Event recorded after the launch; its completion is what is awaited.
    end_event: Event,
    /// Pacing between successive event queries.
    strategy: PollStrategy,
    /// Maximum time to wait, measured from the first poll; `None` disables it.
    timeout: Option<Duration>,
    /// When the future was first polled; `None` until then.
    start_time: Option<Instant>,
}

impl TimedLaunchCompletion {
    /// Wraps the two events, copying the poll strategy and timeout out of
    /// `config`.
    fn new(start_event: Event, end_event: Event, config: &AsyncLaunchConfig) -> Self {
        Self {
            start_event,
            end_event,
            strategy: config.poll_strategy,
            timeout: config.timeout,
            start_time: None,
        }
    }

    /// Queries the end event once, without blocking, and reports the result.
    ///
    /// A failed query is reported as [`CompletionStatus::Error`] carrying the
    /// error's text.
    pub fn status(&self) -> CompletionStatus {
        match self.end_event.query() {
            Ok(true) => CompletionStatus::Complete,
            Ok(false) => CompletionStatus::Pending,
            Err(e) => CompletionStatus::Error(e.to_string()),
        }
    }

    /// Returns `true` when a timeout is configured, the future has been
    /// polled at least once, and at least `timeout` has elapsed since then.
    fn check_timeout(&self) -> bool {
        match (self.timeout, self.start_time) {
            (Some(timeout), Some(start)) => start.elapsed() >= timeout,
            _ => false,
        }
    }
}

impl Future for TimedLaunchCompletion {
    type Output = CudaResult<LaunchTiming>;

    /// Queries the end event and, while it is pending, re-arms a wake-up.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The timeout clock starts at the first poll.
        if self.start_time.is_none() {
            self.start_time = Some(Instant::now());
        }

        // Query before looking at the timeout so that work which has already
        // finished is never reported as timed out.
        match self.end_event.query() {
            Ok(true) => Poll::Ready(
                // `elapsed_time` reports milliseconds; store microseconds.
                Event::elapsed_time(&self.start_event, &self.end_event).map(|ms| LaunchTiming {
                    elapsed_us: f64::from(ms) * 1000.0,
                }),
            ),
            Ok(false) if self.check_timeout() => Poll::Ready(Err(CudaError::Timeout)),
            Ok(false) => {
                // A wake-up must be scheduled on *every* pending poll, using
                // the waker of this poll. Each scheduled wake fires only once,
                // so skipping it here would leave the task parked forever.
                LaunchCompletion::spawn_poller(self.strategy, cx.waker().clone());
                Poll::Pending
            }
            Err(e) => Poll::Ready(Err(e)),
        }
    }
}
/// A [`Kernel`] paired with the [`AsyncLaunchConfig`] used for its
/// asynchronous launches.
pub struct AsyncKernel {
    /// Kernel that gets launched.
    kernel: Kernel,
    /// Poll strategy and timeout handed to every completion future.
    config: AsyncLaunchConfig,
}

impl AsyncKernel {
    /// Wraps `kernel` with the default [`AsyncLaunchConfig`].
    #[inline]
    pub fn new(kernel: Kernel) -> Self {
        Self::with_config(kernel, AsyncLaunchConfig::default())
    }

    /// Wraps `kernel` with an explicit configuration.
    #[inline]
    pub fn with_config(kernel: Kernel, config: AsyncLaunchConfig) -> Self {
        Self { kernel, config }
    }

    /// Borrows the wrapped kernel.
    #[inline]
    pub fn kernel(&self) -> &Kernel {
        &self.kernel
    }

    /// Name of the wrapped kernel, as reported by `Kernel::name`.
    #[inline]
    pub fn name(&self) -> &str {
        self.kernel.name()
    }

    /// Borrows the configuration applied to future launches.
    #[inline]
    pub fn config(&self) -> &AsyncLaunchConfig {
        &self.config
    }

    /// Replaces the configuration. Futures that were already returned keep
    /// the values they copied at launch time.
    #[inline]
    pub fn set_config(&mut self, config: AsyncLaunchConfig) {
        self.config = config;
    }

    /// Launches the kernel on `stream` and returns a future that resolves
    /// when the event recorded right after the launch has completed.
    ///
    /// # Errors
    ///
    /// Returns any error from creating the event, launching the kernel, or
    /// recording the event.
    pub fn launch_async<A: KernelArgs>(
        &self,
        params: &LaunchParams,
        stream: &Stream,
        args: &A,
    ) -> CudaResult<LaunchCompletion> {
        // Allocate the event first: if that fails, nothing has been enqueued
        // yet, so the caller is never left with untracked work in flight.
        let event = Event::new()?;
        self.kernel.launch(params, stream, args)?;
        event.record(stream)?;
        Ok(LaunchCompletion::new(event, &self.config))
    }

    /// Launches the kernel on `stream` bracketed by two events and returns a
    /// future that resolves to the elapsed time between them.
    ///
    /// # Errors
    ///
    /// Returns any error from creating the events, recording them, or
    /// launching the kernel.
    pub fn launch_and_time_async<A: KernelArgs>(
        &self,
        params: &LaunchParams,
        stream: &Stream,
        args: &A,
    ) -> CudaResult<TimedLaunchCompletion> {
        // Allocate both events before enqueueing anything, so an allocation
        // failure cannot happen after the kernel is already running.
        let start_event = Event::new()?;
        let end_event = Event::new()?;
        start_event.record(stream)?;
        self.kernel.launch(params, stream, args)?;
        end_event.record(stream)?;
        Ok(TimedLaunchCompletion::new(
            start_event,
            end_event,
            &self.config,
        ))
    }
}
impl std::fmt::Debug for AsyncKernel {
    /// Renders as a struct named `AsyncKernel` with `kernel` and `config`
    /// fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to the struct without
        // updating this impl becomes a compile error.
        let Self { kernel, config } = self;
        let mut out = f.debug_struct("AsyncKernel");
        out.field("kernel", kernel);
        out.field("config", config);
        out.finish()
    }
}
impl std::fmt::Display for AsyncKernel {
    /// Renders as `AsyncKernel(<kernel name>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.kernel.name();
        write!(f, "AsyncKernel({name})")
    }
}
/// Launches several kernels back-to-back on `stream` and returns one future
/// that resolves when the event recorded after the last launch has completed.
///
/// `launches[i]` is launched with `args_list[i]`. Entries of `args_list`
/// beyond `launches.len()` are ignored.
///
/// # Errors
///
/// * `CudaError::InvalidValue` if `args_list` is shorter than `launches`.
///   This is checked before anything is enqueued, so no kernel is launched in
///   that case.
/// * Any error from creating or recording the event, or from an individual
///   launch. A launch error stops the sequence; kernels enqueued before it
///   are not cancelled.
pub fn multi_launch_async(
    launches: &[(&Kernel, &LaunchParams)],
    args_list: &[&dyn ErasedKernelArgs],
    stream: &Stream,
    config: &AsyncLaunchConfig,
) -> CudaResult<LaunchCompletion> {
    // Validate the input up front rather than discovering a missing argument
    // pack after some kernels are already in flight.
    if args_list.len() < launches.len() {
        return Err(CudaError::InvalidValue);
    }

    // Allocate the event before launching for the same reason.
    let event = Event::new()?;

    for (&(kernel, params), &args) in launches.iter().zip(args_list) {
        kernel.launch_erased(params, stream, args)?;
    }

    event.record(stream)?;
    Ok(LaunchCompletion::new(event, config))
}
/// Object-safe view of a kernel argument pack.
///
/// It lets argument packs of different concrete types be passed together as
/// `&dyn ErasedKernelArgs`, as [`multi_launch_async`] does.
///
/// # Safety
///
/// The returned pointers are handed to the driver's kernel-launch call as the
/// kernel parameter array. NOTE(review): the exact validity requirements are
/// presumably those of `KernelArgs::as_param_ptrs` — confirm against that
/// trait's documentation.
pub unsafe trait ErasedKernelArgs {
    /// Returns one pointer per kernel parameter.
    fn erased_param_ptrs(&self) -> Vec<*mut std::ffi::c_void>;
}

// Every `KernelArgs` implementor is usable in erased form by forwarding to
// its own pointer list.
unsafe impl<T: KernelArgs> ErasedKernelArgs for T {
    #[inline]
    fn erased_param_ptrs(&self) -> Vec<*mut std::ffi::c_void> {
        T::as_param_ptrs(self)
    }
}
impl Kernel {
pub(crate) fn launch_erased(
&self,
params: &LaunchParams,
stream: &Stream,
args: &dyn ErasedKernelArgs,
) -> CudaResult<()> {
let driver = oxicuda_driver::loader::try_driver()?;
let mut param_ptrs = args.erased_param_ptrs();
oxicuda_driver::error::check(unsafe {
(driver.cu_launch_kernel)(
self.function().raw(),
params.grid.x,
params.grid.y,
params.grid.z,
params.block.x,
params.block.y,
params.block.z,
params.shared_mem_bytes,
stream.raw(),
param_ptrs.as_mut_ptr(),
std::ptr::null_mut(),
)
})
}
}
// Unit tests for the plain-data types of this module. None of them creates an
// event, a stream or a kernel, so the completion futures themselves are not
// exercised here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn completion_status_is_complete() {
        let status = CompletionStatus::Complete;
        assert!(status.is_complete());
        assert!(!status.is_pending());
        assert!(!status.is_error());
    }

    #[test]
    fn completion_status_is_pending() {
        let status = CompletionStatus::Pending;
        assert!(status.is_pending());
        assert!(!status.is_complete());
        assert!(!status.is_error());
    }

    #[test]
    fn completion_status_is_error() {
        let status = CompletionStatus::Error("test error".to_string());
        assert!(status.is_error());
        assert!(!status.is_complete());
        assert!(!status.is_pending());
    }

    #[test]
    fn completion_status_display() {
        assert_eq!(CompletionStatus::Pending.to_string(), "Pending");
        assert_eq!(CompletionStatus::Complete.to_string(), "Complete");
        assert_eq!(
            CompletionStatus::Error("oops".to_string()).to_string(),
            "Error: oops"
        );
    }

    #[test]
    fn completion_status_eq() {
        assert_eq!(CompletionStatus::Pending, CompletionStatus::Pending);
        assert_eq!(CompletionStatus::Complete, CompletionStatus::Complete);
        assert_ne!(CompletionStatus::Pending, CompletionStatus::Complete);
        assert_eq!(
            CompletionStatus::Error("a".into()),
            CompletionStatus::Error("a".into())
        );
        assert_ne!(
            CompletionStatus::Error("a".into()),
            CompletionStatus::Error("b".into())
        );
    }

    #[test]
    fn poll_strategy_default_is_yield() {
        assert_eq!(PollStrategy::default(), PollStrategy::Yield);
    }

    #[test]
    fn poll_strategy_backoff_value() {
        let strategy = PollStrategy::BackoffMicros(100);
        // Equality covers both the variant and its payload.
        assert_eq!(strategy, PollStrategy::BackoffMicros(100));
        assert_ne!(strategy, PollStrategy::BackoffMicros(101));
    }

    #[test]
    fn async_launch_config_default() {
        let config = AsyncLaunchConfig::default();
        assert_eq!(config.poll_strategy, PollStrategy::Yield);
        assert!(config.timeout.is_none());
    }

    #[test]
    fn async_launch_config_new() {
        let config = AsyncLaunchConfig::new(PollStrategy::Spin);
        assert_eq!(config.poll_strategy, PollStrategy::Spin);
        assert!(config.timeout.is_none());
    }

    #[test]
    fn async_launch_config_with_timeout() {
        let config = AsyncLaunchConfig::new(PollStrategy::BackoffMicros(50))
            .with_timeout(Duration::from_millis(500));
        assert_eq!(config.poll_strategy, PollStrategy::BackoffMicros(50));
        assert_eq!(config.timeout, Some(Duration::from_millis(500)));

        // A second call replaces the timeout and leaves the strategy alone.
        let config = config.with_timeout(Duration::from_secs(2));
        assert_eq!(config.poll_strategy, PollStrategy::BackoffMicros(50));
        assert_eq!(config.timeout, Some(Duration::from_secs(2)));
    }

    #[test]
    fn launch_timing_conversions() {
        let timing = LaunchTiming {
            elapsed_us: 1_500_000.0,
        };
        assert!((timing.elapsed_ms() - 1500.0).abs() < f64::EPSILON);
        assert!((timing.elapsed_secs() - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn launch_timing_display_microseconds() {
        // Exact strings pin both the unit and the number of decimals.
        assert_eq!(LaunchTiming { elapsed_us: 42.5 }.to_string(), "42.50 us");
        // Largest whole value that still prints in microseconds.
        assert_eq!(LaunchTiming { elapsed_us: 999.0 }.to_string(), "999.00 us");
    }

    #[test]
    fn launch_timing_display_milliseconds() {
        assert_eq!(
            LaunchTiming {
                elapsed_us: 5_000.0
            }
            .to_string(),
            "5.000 ms"
        );
        // The switch from us to ms happens exactly at 1000 us.
        assert_eq!(
            LaunchTiming {
                elapsed_us: 1_000.0
            }
            .to_string(),
            "1.000 ms"
        );
    }

    #[test]
    fn launch_timing_display_seconds() {
        assert_eq!(
            LaunchTiming {
                elapsed_us: 2_500_000.0
            }
            .to_string(),
            "2.5000 s"
        );
        // The switch from ms to s happens exactly at 1_000_000 us.
        assert_eq!(
            LaunchTiming {
                elapsed_us: 1_000_000.0
            }
            .to_string(),
            "1.0000 s"
        );
    }

    #[test]
    fn launch_timing_zero() {
        let timing = LaunchTiming { elapsed_us: 0.0 };
        assert!(timing.elapsed_ms().abs() < f64::EPSILON);
        assert!(timing.elapsed_secs().abs() < f64::EPSILON);
        assert_eq!(timing.to_string(), "0.00 us");
    }

    #[test]
    fn async_launch_status_pending_initially() {
        // Only the enum value is checked; no launch is performed here.
        let status = CompletionStatus::Pending;
        assert!(status.is_pending(), "Newly created status must be Pending");
        assert!(!status.is_complete());
        assert!(!status.is_error());
    }

    #[test]
    fn async_launch_debug_impl() {
        let config = AsyncLaunchConfig::new(PollStrategy::Yield);
        let dbg = format!("{config:?}");
        assert!(
            dbg.contains("AsyncLaunchConfig"),
            "Debug output must contain type name, got: {dbg}"
        );
        let strategy_dbg = format!("{:?}", PollStrategy::BackoffMicros(200));
        assert!(
            strategy_dbg.contains("BackoffMicros"),
            "PollStrategy Debug must contain variant name, got: {strategy_dbg}"
        );
    }

    #[test]
    fn async_completion_event_created() {
        // Despite the name, no event is created: this checks that a config
        // built by struct literal and one built through the builder both
        // carry the values they were given.
        let config = AsyncLaunchConfig {
            poll_strategy: PollStrategy::Spin,
            timeout: Some(Duration::from_secs(5)),
        };
        assert_eq!(config.poll_strategy, PollStrategy::Spin);
        assert_eq!(config.timeout, Some(Duration::from_secs(5)));
        let config2 = AsyncLaunchConfig::new(PollStrategy::BackoffMicros(100))
            .with_timeout(Duration::from_millis(250));
        assert_eq!(config2.poll_strategy, PollStrategy::BackoffMicros(100));
        assert_eq!(config2.timeout, Some(Duration::from_millis(250)));
    }
}