1use std::future::Future;
45use std::pin::Pin;
46use std::task::{Context, Poll, Waker};
47use std::time::{Duration, Instant};
48
49use oxicuda_driver::error::{CudaError, CudaResult};
50use oxicuda_driver::event::Event;
51use oxicuda_driver::stream::Stream;
52
53use crate::kernel::{Kernel, KernelArgs};
54use crate::params::LaunchParams;
55
/// Result of a non-blocking completion query.
///
/// The `Error` variant carries the failure rendered as text, which keeps the
/// type `Clone` and `Eq` regardless of the underlying error type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompletionStatus {
    /// The work has not finished yet.
    Pending,
    /// The work has finished.
    Complete,
    /// The query itself failed; the payload is the error message.
    Error(String),
}

impl CompletionStatus {
    /// Returns `true` only for [`CompletionStatus::Complete`].
    #[inline]
    pub fn is_complete(&self) -> bool {
        match self {
            Self::Complete => true,
            Self::Pending | Self::Error(_) => false,
        }
    }

    /// Returns `true` only for [`CompletionStatus::Pending`].
    #[inline]
    pub fn is_pending(&self) -> bool {
        match self {
            Self::Pending => true,
            Self::Complete | Self::Error(_) => false,
        }
    }

    /// Returns `true` only for [`CompletionStatus::Error`], whatever the message.
    #[inline]
    pub fn is_error(&self) -> bool {
        match self {
            Self::Error(_) => true,
            Self::Pending | Self::Complete => false,
        }
    }
}

impl std::fmt::Display for CompletionStatus {
    /// Writes `Pending`, `Complete`, or `Error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pending => f.write_str("Pending"),
            Self::Complete => f.write_str("Complete"),
            Self::Error(msg) => {
                f.write_str("Error: ")?;
                f.write_str(msg)
            }
        }
    }
}
100
/// What the helper thread does before it wakes a future that is still pending.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollStrategy {
    /// Wake again immediately, with no delay.
    Spin,

    /// Call `std::thread::yield_now` once, then wake.
    Yield,

    /// Sleep for the given number of microseconds, then wake.
    BackoffMicros(u64),
}

impl Default for PollStrategy {
    /// The default strategy is [`PollStrategy::Yield`].
    #[inline]
    fn default() -> Self {
        PollStrategy::Yield
    }
}

/// Settings applied to the futures returned by the asynchronous launch APIs.
#[derive(Debug, Clone)]
pub struct AsyncLaunchConfig {
    /// How re-polling of a pending launch is paced.
    pub poll_strategy: PollStrategy,
    /// Optional upper bound on how long a launch may stay pending;
    /// `None` means no limit.
    pub timeout: Option<Duration>,
}

impl Default for AsyncLaunchConfig {
    /// [`PollStrategy::Yield`] with no timeout.
    #[inline]
    fn default() -> Self {
        Self::new(PollStrategy::Yield)
    }
}

impl AsyncLaunchConfig {
    /// Creates a configuration with the given strategy and no timeout.
    #[inline]
    pub fn new(poll_strategy: PollStrategy) -> Self {
        let timeout = None;
        Self {
            poll_strategy,
            timeout,
        }
    }

    /// Builder-style setter: returns the configuration with `timeout` applied.
    #[inline]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}
178
/// Elapsed time of a timed launch, stored in microseconds.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LaunchTiming {
    /// Elapsed time in microseconds.
    pub elapsed_us: f64,
}

impl LaunchTiming {
    /// Elapsed time in milliseconds (`elapsed_us / 1000`).
    #[inline]
    pub fn elapsed_ms(&self) -> f64 {
        let micros = self.elapsed_us;
        micros / 1000.0
    }

    /// Elapsed time in seconds (`elapsed_us / 1_000_000`).
    #[inline]
    pub fn elapsed_secs(&self) -> f64 {
        let micros = self.elapsed_us;
        micros / 1_000_000.0
    }
}

impl std::fmt::Display for LaunchTiming {
    /// Picks the unit from the magnitude: below 1000 us prints microseconds
    /// with two decimals, below one second prints milliseconds with three,
    /// and everything else prints seconds with four.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.elapsed_us {
            us if us < 1000.0 => write!(f, "{us:.2} us"),
            us if us < 1_000_000.0 => write!(f, "{:.3} ms", self.elapsed_ms()),
            _ => write!(f, "{:.4} s", self.elapsed_secs()),
        }
    }
}
215
216pub struct LaunchCompletion {
225 event: Event,
227 strategy: PollStrategy,
229 timeout: Option<Duration>,
231 start_time: Option<Instant>,
233 waker: Option<Waker>,
235 poller_spawned: bool,
237}
238
239impl LaunchCompletion {
240 fn new(event: Event, config: &AsyncLaunchConfig) -> Self {
242 Self {
243 event,
244 strategy: config.poll_strategy,
245 timeout: config.timeout,
246 start_time: None,
247 waker: None,
248 poller_spawned: false,
249 }
250 }
251
252 pub fn status(&self) -> CompletionStatus {
254 match self.event.query() {
255 Ok(true) => CompletionStatus::Complete,
256 Ok(false) => CompletionStatus::Pending,
257 Err(e) => CompletionStatus::Error(e.to_string()),
258 }
259 }
260
261 fn check_timeout(&self) -> bool {
263 match (self.timeout, self.start_time) {
264 (Some(timeout), Some(start)) => start.elapsed() >= timeout,
265 _ => false,
266 }
267 }
268
269 fn spawn_poller(strategy: PollStrategy, waker: Waker) {
272 std::thread::spawn(move || {
273 match strategy {
274 PollStrategy::Spin => {
275 }
277 PollStrategy::Yield => {
278 std::thread::yield_now();
279 }
280 PollStrategy::BackoffMicros(us) => {
281 std::thread::sleep(Duration::from_micros(us));
282 }
283 }
284 waker.wake();
285 });
286 }
287}
288
289impl Future for LaunchCompletion {
290 type Output = CudaResult<()>;
291
292 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
293 if self.start_time.is_none() {
295 self.start_time = Some(Instant::now());
296 }
297
298 if self.check_timeout() {
300 return Poll::Ready(Err(CudaError::Timeout));
301 }
302
303 match self.event.query() {
305 Ok(true) => Poll::Ready(Ok(())),
306 Ok(false) => {
307 let waker = cx.waker().clone();
309 self.waker = Some(waker.clone());
310
311 if !self.poller_spawned || self.strategy == PollStrategy::Spin {
312 self.poller_spawned = true;
313 Self::spawn_poller(self.strategy, waker);
314 }
315
316 Poll::Pending
317 }
318 Err(e) => Poll::Ready(Err(e)),
319 }
320 }
321}
322
323pub struct TimedLaunchCompletion {
330 start_event: Event,
332 end_event: Event,
334 strategy: PollStrategy,
336 timeout: Option<Duration>,
338 start_time: Option<Instant>,
340 poller_spawned: bool,
342}
343
344impl TimedLaunchCompletion {
345 fn new(start_event: Event, end_event: Event, config: &AsyncLaunchConfig) -> Self {
347 Self {
348 start_event,
349 end_event,
350 strategy: config.poll_strategy,
351 timeout: config.timeout,
352 start_time: None,
353 poller_spawned: false,
354 }
355 }
356
357 pub fn status(&self) -> CompletionStatus {
359 match self.end_event.query() {
360 Ok(true) => CompletionStatus::Complete,
361 Ok(false) => CompletionStatus::Pending,
362 Err(e) => CompletionStatus::Error(e.to_string()),
363 }
364 }
365
366 fn check_timeout(&self) -> bool {
368 match (self.timeout, self.start_time) {
369 (Some(timeout), Some(start)) => start.elapsed() >= timeout,
370 _ => false,
371 }
372 }
373}
374
375impl Future for TimedLaunchCompletion {
376 type Output = CudaResult<LaunchTiming>;
377
378 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
379 if self.start_time.is_none() {
380 self.start_time = Some(Instant::now());
381 }
382
383 if self.check_timeout() {
384 return Poll::Ready(Err(CudaError::Timeout));
385 }
386
387 match self.end_event.query() {
388 Ok(true) => {
389 match Event::elapsed_time(&self.start_event, &self.end_event) {
391 Ok(ms) => {
392 let elapsed_us = f64::from(ms) * 1000.0;
393 Poll::Ready(Ok(LaunchTiming { elapsed_us }))
394 }
395 Err(e) => Poll::Ready(Err(e)),
396 }
397 }
398 Ok(false) => {
399 let waker = cx.waker().clone();
400
401 if !self.poller_spawned || self.strategy == PollStrategy::Spin {
402 self.poller_spawned = true;
403 LaunchCompletion::spawn_poller(self.strategy, waker);
404 }
405
406 Poll::Pending
407 }
408 Err(e) => Poll::Ready(Err(e)),
409 }
410 }
411}
412
413pub struct AsyncKernel {
422 kernel: Kernel,
424 config: AsyncLaunchConfig,
426}
427
428impl AsyncKernel {
429 #[inline]
431 pub fn new(kernel: Kernel) -> Self {
432 Self {
433 kernel,
434 config: AsyncLaunchConfig::default(),
435 }
436 }
437
438 #[inline]
440 pub fn with_config(kernel: Kernel, config: AsyncLaunchConfig) -> Self {
441 Self { kernel, config }
442 }
443
444 #[inline]
446 pub fn kernel(&self) -> &Kernel {
447 &self.kernel
448 }
449
450 #[inline]
452 pub fn name(&self) -> &str {
453 self.kernel.name()
454 }
455
456 #[inline]
458 pub fn config(&self) -> &AsyncLaunchConfig {
459 &self.config
460 }
461
462 #[inline]
464 pub fn set_config(&mut self, config: AsyncLaunchConfig) {
465 self.config = config;
466 }
467
468 pub fn launch_async<A: KernelArgs>(
480 &self,
481 params: &LaunchParams,
482 stream: &Stream,
483 args: &A,
484 ) -> CudaResult<LaunchCompletion> {
485 self.kernel.launch(params, stream, args)?;
487
488 let event = Event::new()?;
490 event.record(stream)?;
491
492 Ok(LaunchCompletion::new(event, &self.config))
493 }
494
495 pub fn launch_and_time_async<A: KernelArgs>(
506 &self,
507 params: &LaunchParams,
508 stream: &Stream,
509 args: &A,
510 ) -> CudaResult<TimedLaunchCompletion> {
511 let start_event = Event::new()?;
512 start_event.record(stream)?;
513
514 self.kernel.launch(params, stream, args)?;
515
516 let end_event = Event::new()?;
517 end_event.record(stream)?;
518
519 Ok(TimedLaunchCompletion::new(
520 start_event,
521 end_event,
522 &self.config,
523 ))
524 }
525}
526
527impl std::fmt::Debug for AsyncKernel {
528 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
529 f.debug_struct("AsyncKernel")
530 .field("kernel", &self.kernel)
531 .field("config", &self.config)
532 .finish()
533 }
534}
535
536impl std::fmt::Display for AsyncKernel {
537 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
538 write!(f, "AsyncKernel({})", self.kernel.name())
539 }
540}
541
542pub fn multi_launch_async(
565 launches: &[(&Kernel, &LaunchParams)],
566 args_list: &[&dyn ErasedKernelArgs],
567 stream: &Stream,
568 config: &AsyncLaunchConfig,
569) -> CudaResult<LaunchCompletion> {
570 for (i, (kernel, params)) in launches.iter().enumerate() {
571 let args = args_list.get(i).ok_or(CudaError::InvalidValue)?;
572 kernel.launch_erased(params, stream, *args)?;
573 }
574
575 let event = Event::new()?;
576 event.record(stream)?;
577
578 Ok(LaunchCompletion::new(event, config))
579}
580
581pub unsafe trait ErasedKernelArgs {
593 fn erased_param_ptrs(&self) -> Vec<*mut std::ffi::c_void>;
595}
596
597unsafe impl<T: KernelArgs> ErasedKernelArgs for T {
603 #[inline]
604 fn erased_param_ptrs(&self) -> Vec<*mut std::ffi::c_void> {
605 self.as_param_ptrs()
606 }
607}
608
609impl Kernel {
614 pub(crate) fn launch_erased(
618 &self,
619 params: &LaunchParams,
620 stream: &Stream,
621 args: &dyn ErasedKernelArgs,
622 ) -> CudaResult<()> {
623 let driver = oxicuda_driver::loader::try_driver()?;
624 let mut param_ptrs = args.erased_param_ptrs();
625 oxicuda_driver::error::check(unsafe {
626 (driver.cu_launch_kernel)(
627 self.function().raw(),
628 params.grid.x,
629 params.grid.y,
630 params.grid.z,
631 params.block.x,
632 params.block.y,
633 params.block.z,
634 params.shared_mem_bytes,
635 stream.raw(),
636 param_ptrs.as_mut_ptr(),
637 std::ptr::null_mut(),
638 )
639 })
640 }
641}
642
/// Unit tests for the plain-data types in this file. None of them touches the
/// GPU: they only cover status predicates, configuration builders, timing
/// conversions and formatting.
#[cfg(test)]
mod tests {
    use super::*;

    /// `(is_pending, is_complete, is_error)` for compact comparisons.
    fn flags(status: &CompletionStatus) -> (bool, bool, bool) {
        (
            status.is_pending(),
            status.is_complete(),
            status.is_error(),
        )
    }

    // ---- CompletionStatus -------------------------------------------------

    #[test]
    fn completion_status_is_complete() {
        assert_eq!(flags(&CompletionStatus::Complete), (false, true, false));
    }

    #[test]
    fn completion_status_is_pending() {
        assert_eq!(flags(&CompletionStatus::Pending), (true, false, false));
    }

    #[test]
    fn completion_status_is_error() {
        let failed = CompletionStatus::Error("test error".to_string());
        assert_eq!(flags(&failed), (false, false, true));
    }

    #[test]
    fn completion_status_display() {
        let cases = [
            (CompletionStatus::Pending, "Pending"),
            (CompletionStatus::Complete, "Complete"),
            (CompletionStatus::Error("oops".to_string()), "Error: oops"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.to_string(), expected);
        }
    }

    #[test]
    fn completion_status_eq() {
        let pending = CompletionStatus::Pending;
        let complete = CompletionStatus::Complete;
        let err_a = CompletionStatus::Error("a".into());
        let err_b = CompletionStatus::Error("b".into());

        assert_eq!(pending, CompletionStatus::Pending);
        assert_eq!(complete, CompletionStatus::Complete);
        assert_ne!(pending, complete);
        assert_eq!(err_a, CompletionStatus::Error("a".into()));
        assert_ne!(err_a, err_b);
    }

    // ---- PollStrategy -----------------------------------------------------

    #[test]
    fn poll_strategy_default_is_yield() {
        let strategy = PollStrategy::default();
        assert_eq!(strategy, PollStrategy::Yield);
    }

    #[test]
    fn poll_strategy_backoff_value() {
        match PollStrategy::BackoffMicros(100) {
            PollStrategy::BackoffMicros(us) => assert_eq!(us, 100),
            _ => panic!("expected BackoffMicros"),
        }
    }

    // ---- AsyncLaunchConfig ------------------------------------------------

    #[test]
    fn async_launch_config_default() {
        let AsyncLaunchConfig {
            poll_strategy,
            timeout,
        } = AsyncLaunchConfig::default();
        assert_eq!(poll_strategy, PollStrategy::Yield);
        assert!(timeout.is_none());
    }

    #[test]
    fn async_launch_config_new() {
        let AsyncLaunchConfig {
            poll_strategy,
            timeout,
        } = AsyncLaunchConfig::new(PollStrategy::Spin);
        assert_eq!(poll_strategy, PollStrategy::Spin);
        assert!(timeout.is_none());
    }

    #[test]
    fn async_launch_config_with_timeout() {
        let strategy = PollStrategy::BackoffMicros(50);
        let limit = Duration::from_millis(500);

        let config = AsyncLaunchConfig::new(strategy).with_timeout(limit);

        assert_eq!(config.poll_strategy, strategy);
        assert_eq!(config.timeout, Some(limit));
    }

    // ---- LaunchTiming -----------------------------------------------------

    #[test]
    fn launch_timing_conversions() {
        let timing = LaunchTiming {
            elapsed_us: 1_500_000.0,
        };
        let ms_error = (timing.elapsed_ms() - 1500.0).abs();
        let secs_error = (timing.elapsed_secs() - 1.5).abs();
        assert!(ms_error < f64::EPSILON);
        assert!(secs_error < f64::EPSILON);
    }

    #[test]
    fn launch_timing_display_microseconds() {
        let display = LaunchTiming { elapsed_us: 42.5 }.to_string();
        assert!(display.contains("us"), "expected 'us' in: {display}");
    }

    #[test]
    fn launch_timing_display_milliseconds() {
        let display = LaunchTiming {
            elapsed_us: 5_000.0,
        }
        .to_string();
        assert!(display.contains("ms"), "expected 'ms' in: {display}");
    }

    #[test]
    fn launch_timing_display_seconds() {
        let display = LaunchTiming {
            elapsed_us: 2_500_000.0,
        }
        .to_string();
        assert!(display.contains('s'), "expected 's' in: {display}");
        assert!(
            !display.contains("us"),
            "should not contain 'us' in: {display}"
        );
        assert!(
            !display.contains("ms"),
            "should not contain 'ms' in: {display}"
        );
    }

    #[test]
    fn launch_timing_zero() {
        let zero = LaunchTiming { elapsed_us: 0.0 };
        assert!(zero.elapsed_ms().abs() < f64::EPSILON);
        assert!(zero.elapsed_secs().abs() < f64::EPSILON);
        let rendered = zero.to_string();
        assert!(rendered.contains("us"));
    }

    // ---- Misc -------------------------------------------------------------

    #[test]
    fn async_launch_status_pending_initially() {
        let initial = CompletionStatus::Pending;
        assert!(initial.is_pending(), "Newly created status must be Pending");
        assert!(!initial.is_complete());
        assert!(!initial.is_error());
    }

    #[test]
    fn async_launch_debug_impl() {
        let dbg = format!("{:?}", AsyncLaunchConfig::new(PollStrategy::Yield));
        assert!(
            dbg.contains("AsyncLaunchConfig"),
            "Debug output must contain type name, got: {dbg}"
        );

        let strategy_dbg = format!("{:?}", PollStrategy::BackoffMicros(200));
        assert!(
            strategy_dbg.contains("BackoffMicros"),
            "PollStrategy Debug must contain variant name, got: {strategy_dbg}"
        );
    }

    #[test]
    fn async_completion_event_created() {
        // Built with a struct literal.
        let literal = AsyncLaunchConfig {
            poll_strategy: PollStrategy::Spin,
            timeout: Some(Duration::from_secs(5)),
        };
        assert_eq!(literal.poll_strategy, PollStrategy::Spin);
        assert_eq!(literal.timeout, Some(Duration::from_secs(5)));

        // Built with the builder-style API.
        let built = AsyncLaunchConfig::new(PollStrategy::BackoffMicros(100))
            .with_timeout(Duration::from_millis(250));
        assert_eq!(built.poll_strategy, PollStrategy::BackoffMicros(100));
        assert_eq!(built.timeout, Some(Duration::from_millis(250)));
    }
}