Skip to main content

moonpool_core/
task.rs

//! Task spawning abstraction for single-threaded simulation environments.
//!
//! This module provides task provider abstractions for spawning `Send` tasks
//! that work with both simulation and real Tokio execution.
5
6use std::future::Future;
7
8/// Error returned by [`TaskProvider::JoinHandle`] when a task did not complete
9/// normally.
10///
11/// This is the runtime-agnostic error surfaced by the [`TaskProvider`] trait.
12/// Implementations convert their runtime-specific join error into one of these
13/// variants.
14#[derive(Debug, thiserror::Error)]
15pub enum JoinError {
16    /// The task was cancelled (for example, the runtime aborted it).
17    #[error("task was cancelled")]
18    Cancelled,
19    /// The task panicked.
20    #[error("task panicked")]
21    Panicked,
22}
23
24/// Provider for spawning tasks.
25///
26/// This trait abstracts task spawning to enable both real tokio tasks
27/// and simulation-controlled task scheduling. The simulation runtime
28/// runs on a single OS thread, but the spawned futures are Send-bounded
29/// so customer call graphs can use `Arc<RwLock<…>>`, `DashMap`, and other
30/// `Send + Sync` primitives without contortion.
31pub trait TaskProvider: Clone + Send + Sync + 'static {
32    /// Future returned by [`Self::spawn_task`].
33    ///
34    /// Resolves with `Ok(())` on normal completion, or a [`JoinError`] if the
35    /// task was cancelled or panicked.
36    type JoinHandle: Future<Output = Result<(), JoinError>> + Send + Sync + 'static;
37
38    /// Spawn a named task.
39    fn spawn_task<F>(&self, name: &str, future: F) -> Self::JoinHandle
40    where
41        F: Future<Output = ()> + Send + 'static;
42
43    /// Yield control to allow other tasks to run.
44    ///
45    /// This is equivalent to `tokio::task::yield_now()` but abstracted
46    /// to enable simulation control and deterministic behavior.
47    fn yield_now(&self) -> impl Future<Output = ()> + Send;
48}
49
/// [`TaskProvider`] backed by the tokio runtime.
///
/// Tasks are created through `tokio::task::Builder::spawn`. Under the sim
/// runtime (`new_current_thread().build()`) every task runs on a single OS
/// thread, which preserves determinism even though spawned futures must still
/// be `Send + 'static`.
///
/// The type is a stateless unit struct; `Default` is derived so it can be
/// constructed generically (for example through `Default::default()` or a
/// derived `Default` on a struct that contains it).
#[cfg(feature = "tokio-providers")]
#[derive(Clone, Debug, Default)]
pub struct TokioTaskProvider;
59
/// Join handle type returned by [`TokioTaskProvider`].
///
/// A thin newtype around tokio's `JoinHandle<()>`. When polled, it translates
/// the runtime-specific `tokio::task::JoinError` into the runtime-agnostic
/// [`JoinError`] variants.
#[cfg(feature = "tokio-providers")]
#[derive(Debug)]
pub struct TokioJoinHandle(tokio::task::JoinHandle<()>);
68
#[cfg(feature = "tokio-providers")]
impl Future for TokioJoinHandle {
    type Output = Result<(), JoinError>;

    /// Polls the wrapped tokio handle and translates its result.
    ///
    /// `Pending` and `Ok(())` pass through unchanged. A tokio join error that
    /// reports `is_cancelled()` becomes [`JoinError::Cancelled`]; any other
    /// join error becomes [`JoinError::Panicked`].
    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let inner = std::pin::Pin::new(&mut self.0);
        inner.poll(cx).map(|outcome| {
            outcome.map_err(|err| {
                if err.is_cancelled() {
                    JoinError::Cancelled
                } else {
                    JoinError::Panicked
                }
            })
        })
    }
}
86
#[cfg(feature = "tokio-providers")]
impl TaskProvider for TokioTaskProvider {
    type JoinHandle = TokioJoinHandle;

    /// Spawns `future` as a tokio task named `name`.
    ///
    /// The future is wrapped so that a `trace`-level event is emitted just
    /// before it starts and right after it finishes. The tokio handle is
    /// returned wrapped in a [`TokioJoinHandle`].
    ///
    /// # Panics
    ///
    /// Panics with "Failed to spawn task" if the tokio task builder reports an
    /// error when spawning.
    fn spawn_task<F>(&self, name: &str, future: F) -> Self::JoinHandle
    where
        F: Future<Output = ()> + Send + 'static,
    {
        // The wrapped future is moved into the task, so it needs an owned copy
        // of the name rather than the borrowed `name`.
        let label = name.to_owned();
        let traced = async move {
            tracing::trace!("Task {} starting", label);
            future.await;
            tracing::trace!("Task {} completed", label);
        };

        let handle = tokio::task::Builder::new()
            .name(name)
            .spawn(traced)
            .expect("Failed to spawn task");
        TokioJoinHandle(handle)
    }

    /// Yields to the tokio scheduler by awaiting `tokio::task::yield_now()`.
    async fn yield_now(&self) {
        tokio::task::yield_now().await;
    }
}
106
107    async fn yield_now(&self) {
108        tokio::task::yield_now().await;
109    }
110}