riglr_core/
spawn.rs

1// riglr-core/src/spawn.rs
2
3use crate::SignerContext;
4use std::future::Future;
5use tokio::task::JoinHandle;
6
7/// Spawns a new task while preserving SignerContext if available.
8///
9/// # Why This Is Necessary
10///
11/// The `SignerContext` is stored in `tokio::task_local!` storage, which means it is
12/// **NOT** automatically propagated to new tasks spawned with `tokio::spawn`. This is
13/// a fundamental limitation of task-local storage - it's only accessible within the
14/// same task where it was set.
15///
16/// Without this function, the following would fail:
17///
18/// ```rust,ignore
19/// // WRONG - This will fail with "No signer context"
20/// SignerContext::with_signer(signer, async {
21///     let handle = tokio::spawn(async {
22///         // This will fail - SignerContext is not available here!
23///         let current = SignerContext::current().await?; // ERROR
24///         transfer_sol("recipient", 1.0).await
25///     });
26///     handle.await?
27/// }).await
28/// ```
29///
30/// The correct approach using `spawn_with_context`:
31///
32/// ```rust,ignore
33/// // CORRECT - This properly propagates the SignerContext
34/// SignerContext::with_signer(signer, async {
35///     let handle = spawn_with_context(async {
36///         // SignerContext is available here!
37///         transfer_sol("recipient", 1.0).await
38///     }).await;
39///     handle.await?
40/// }).await
41/// ```
42///
43/// # How It Works
44///
45/// This function:
46/// 1. Checks if a SignerContext exists in the current task
47/// 2. If yes, captures it and wraps the spawned future with `SignerContext::with_signer`
48/// 3. If no, spawns the task normally without context
49///
50/// This ensures that tools requiring signing operations work correctly when
51/// executed through agent frameworks that use task spawning for parallelism.
52///
53/// Note: This function is async because it needs to check for the current signer.
54/// The future passed in should return a Result<T, SignerError>.
55pub async fn spawn_with_context<F, T>(future: F) -> JoinHandle<Result<T, crate::SignerError>>
56where
57    F: Future<Output = Result<T, crate::SignerError>> + Send + 'static,
58    T: Send + 'static,
59{
60    // Try to get the current signer from the context
61    if let Ok(signer) = SignerContext::current().await {
62        // We have a signer context - propagate it to the spawned task
63        tokio::task::spawn(async move { SignerContext::with_signer(signer, future).await })
64    } else {
65        // No signer context - spawn normally
66        tokio::task::spawn(future)
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a surrounding `SignerContext`, the spawned task must not see a signer.
    #[tokio::test]
    async fn test_spawn_with_context_without_signer() {
        // No signer context is set in this task.
        let handle = spawn_with_context(async move {
            // Report visibility as a plain bool. Previously a found signer was signalled with
            // `Err(NoSignerContext)`, which produced a misleading panic message.
            Ok::<bool, crate::SignerError>(SignerContext::current().await.is_ok())
        })
        .await;

        let signer_visible = handle
            .await
            .expect("spawned task panicked or was cancelled")
            .expect("spawned future returned an error");
        assert!(
            !signer_visible,
            "No SignerContext should be available when not set"
        );
    }

    /// Inside a `SignerContext`, the spawned task must see the *same* signer instance.
    #[tokio::test]
    async fn test_spawn_with_context_with_signer() {
        use crate::signer::granular_traits::{SignerBase, UnifiedSigner};
        use std::{any::Any, sync::Arc};

        // Minimal signer: the test only needs something to place into the context.
        #[derive(Debug)]
        struct MockSigner;

        impl SignerBase for MockSigner {
            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        impl UnifiedSigner for MockSigner {
            fn supports_solana(&self) -> bool {
                false
            }
            fn supports_evm(&self) -> bool {
                false
            }
            fn as_solana(&self) -> Option<&dyn crate::signer::granular_traits::SolanaSigner> {
                None
            }
            fn as_evm(&self) -> Option<&dyn crate::signer::granular_traits::EvmSigner> {
                None
            }
            fn as_multi_chain(
                &self,
            ) -> Option<&dyn crate::signer::granular_traits::MultiChainSigner> {
                None
            }
        }

        let mock_signer: Arc<dyn UnifiedSigner> = Arc::new(MockSigner);
        // Compare allocation addresses (vtable metadata stripped) rather than `Debug`
        // output. Every `MockSigner` formats identically, so a string comparison cannot tell
        // whether the *same* signer was propagated. Stored as `usize` because raw pointers
        // are `!Send` and this value is captured by a spawned future.
        let expected_addr = Arc::as_ptr(&mock_signer).cast::<()>() as usize;

        // Set the signer context and spawn a task within it.
        let result = SignerContext::with_signer(Arc::clone(&mock_signer), async move {
            let handle = spawn_with_context(async move {
                // Fails with a SignerError if the context was not propagated.
                let current_signer = SignerContext::current().await?;
                let current_addr = Arc::as_ptr(&current_signer).cast::<()>() as usize;
                Ok::<bool, crate::SignerError>(current_addr == expected_addr)
            })
            .await;

            // Await the spawned task and propagate its result.
            handle.await.expect("spawned task panicked or was cancelled")
        })
        .await;

        let same_signer = result.expect("Should successfully get result from spawned task");
        assert!(
            same_signer,
            "SignerContext should be propagated to spawned task"
        );
    }
}