Skip to main content

rig_core/
wasm_compat.rs

1use bytes::Bytes;
2use std::pin::Pin;
3
4use futures::Stream;
5
/// Marker bound that means `Send` on native targets.
///
/// This variant is compiled everywhere except wasm32 with the `wasm` feature,
/// and requires `Send` through its supertrait.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub trait WasmCompatSend: Send {}

// Blanket impl: every sized `Send` type satisfies the marker on native targets.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
impl<T: Send> WasmCompatSend for T {}

/// Marker bound that means `Send` on native targets.
///
/// This variant is compiled on wasm32 with the `wasm` feature. It has no
/// supertrait, so generic bounds written against it also accept non-`Send` types.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub trait WasmCompatSend {}

// Blanket impl: every sized type satisfies the marker on wasm32 + `wasm`.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
impl<T> WasmCompatSend for T {}
17
18#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
19/// Streaming response bound that includes `Send` on native targets.
20pub trait WasmCompatSendStream:
21    Stream<Item = Result<Bytes, crate::http_client::Error>> + Send
22{
23    type InnerItem: Send;
24}
25
26#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
27/// Streaming response bound without `Send` on wasm32 with the `wasm` feature.
28pub trait WasmCompatSendStream: Stream<Item = Result<Bytes, crate::http_client::Error>> {
29    type InnerItem;
30}
31
32#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
33impl<T> WasmCompatSendStream for T
34where
35    T: Stream<Item = Result<Bytes, crate::http_client::Error>> + Send,
36{
37    type InnerItem = Result<Bytes, crate::http_client::Error>;
38}
39
40#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
41impl<T> WasmCompatSendStream for T
42where
43    T: Stream<Item = Result<Bytes, crate::http_client::Error>>,
44{
45    type InnerItem = Result<Bytes, crate::http_client::Error>;
46}
47
/// Marker bound that means `Sync` on native targets.
///
/// This variant is compiled everywhere except wasm32 with the `wasm` feature,
/// and requires `Sync` through its supertrait.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub trait WasmCompatSync: Sync {}

// Blanket impl: every sized `Sync` type satisfies the marker on native targets.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
impl<T: Sync> WasmCompatSync for T {}

/// Marker bound that means `Sync` on native targets.
///
/// This variant is compiled on wasm32 with the `wasm` feature. It has no
/// supertrait, so generic bounds written against it also accept non-`Sync` types.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub trait WasmCompatSync {}

// Blanket impl: every sized type satisfies the marker on wasm32 + `wasm`.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
impl<T> WasmCompatSync for T {}
59
/// A pinned, boxed, type-erased future that yields `T` and may borrow for `'a`.
///
/// This native variant requires the future to be `Send`. Its `cfg` deliberately
/// matches [`WasmCompatSend`]/[`WasmCompatSync`] (and the streaming `Box`
/// selection). A `WasmBoxedFuture` returned from a `WasmCompatSend`-bounded API,
/// such as [`ToolDyn::call`](crate::tool::ToolDyn), must drop `Send` exactly
/// where that marker relaxes it. Otherwise the two disagree on wasm.
#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
pub type WasmBoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;

/// A pinned, boxed, type-erased future that yields `T` and may borrow for `'a`.
///
/// This is the wasm32 + `wasm` feature variant, which does not require `Send`.
#[cfg(all(feature = "wasm", target_arch = "wasm32"))]
pub type WasmBoxedFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
72
/// The error [`timeout`] returns when its deadline passes before the wrapped
/// future finishes.
///
/// It carries no data. Its `Display` text is `future timed out`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Elapsed;

impl std::error::Error for Elapsed {}

impl std::fmt::Display for Elapsed {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("future timed out")
    }
}
84
85/// Await `future`, returning `Err(`[`Elapsed`]`)` if it does not complete within
86/// `duration`.
87///
88/// A cross-platform (native + wasm) replacement for `tokio::time::timeout`: rig's
89/// `tokio` dependency is built without the `time` feature, and `tokio::time` does
90/// not function on wasm. This is built on [`futures_timer::Delay`], which rig
91/// already uses for SSE retry backoff.
92///
93/// On elapse the pending `future` is **dropped** (cancelled by drop); it gets no
94/// chance to run cleanup beyond its own `Drop`. A zero or already-elapsed
95/// `duration` still polls `future` once before electing `Elapsed`, and an absurdly
96/// large `duration` may panic when added to `Instant::now()` inside the timer.
97///
98/// # Wasm
99/// On browser wasm (`wasm32-unknown-unknown`) the `futures-timer` `wasm-bindgen`
100/// (`setTimeout`) backend is selected automatically via a target-scoped
101/// dependency, so the timer fires without depending on any cargo feature. (The
102/// `futures_timer::Delay` SSE retry backoff relies on the same backend.)
103pub async fn timeout<F>(duration: std::time::Duration, future: F) -> Result<F::Output, Elapsed>
104where
105    F: Future,
106{
107    use futures::future::{Either, select};
108
109    let delay = futures_timer::Delay::new(duration);
110    futures::pin_mut!(future);
111    futures::pin_mut!(delay);
112    match select(future, delay).await {
113        Either::Left((output, _)) => Ok(output),
114        Either::Right(((), _)) => Err(Elapsed),
115    }
116}
117
/// Emits the given tokens only on wasm32 with the `wasm` feature enabled.
///
/// The expansion is `#[cfg(all(feature = "wasm", target_arch = "wasm32"))]`
/// followed by the tokens, unchanged.
///
/// # Caveats
/// - The `#[cfg]` attribute attaches only to the **first** item or statement in
///   the tokens. Anything after it is emitted unconditionally, so pass a single
///   item or invoke the macro once per item.
/// - The `cfg` is evaluated where the macro is expanded. When invoked from
///   another crate, `feature = "wasm"` refers to that crate's features, not
///   rig-core's.
#[macro_export]
macro_rules! if_wasm {
    { $($body:tt)* } => {
        #[cfg(all(feature = "wasm", target_arch = "wasm32"))]
        $($body)*
    };
}
126
/// Emits the given tokens everywhere except wasm32 with the `wasm` feature.
///
/// The expansion is `#[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]`
/// followed by the tokens, unchanged.
///
/// # Caveats
/// - The `#[cfg]` attribute attaches only to the **first** item or statement in
///   the tokens. Anything after it is emitted unconditionally, so pass a single
///   item or invoke the macro once per item.
/// - The `cfg` is evaluated where the macro is expanded. When invoked from
///   another crate, `feature = "wasm"` refers to that crate's features, not
///   rig-core's.
#[macro_export]
macro_rules! if_not_wasm {
    { $($body:tt)* } => {
        #[cfg(not(all(feature = "wasm", target_arch = "wasm32")))]
        $($body)*
    };
}
135
#[cfg(test)]
mod tests {
    use super::{Elapsed, timeout};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::Duration;

    #[tokio::test]
    async fn timeout_returns_ok_for_a_future_that_completes_in_time() {
        let result = timeout(Duration::from_secs(5), async { 42 }).await;
        assert_eq!(result, Ok(42));
    }

    #[tokio::test]
    async fn timeout_returns_elapsed_for_a_future_that_never_completes() {
        let result = timeout(Duration::from_millis(20), std::future::pending::<()>()).await;
        assert_eq!(result, Err(Elapsed));
    }

    #[tokio::test]
    async fn timeout_zero_duration_still_polls_a_ready_future_once() {
        // Documented contract: a zero/already-elapsed duration still polls the
        // future once before returning `Elapsed`, so a ready future wins.
        let result = timeout(Duration::ZERO, async { 7 }).await;
        assert_eq!(result, Ok(7));
    }

    #[tokio::test]
    async fn timeout_drops_the_pending_future_when_it_elapses() {
        // Documented contract: on elapse the pending future is cancelled by drop.
        // The guard lives inside the future, so its `Drop` proves the future was dropped.
        struct FlagOnDrop(Arc<AtomicBool>);

        impl Drop for FlagOnDrop {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        let dropped = Arc::new(AtomicBool::new(false));
        let guard = FlagOnDrop(Arc::clone(&dropped));
        let never_finishes = async move {
            let _guard = guard;
            std::future::pending::<()>().await
        };

        let result = timeout(Duration::from_millis(20), never_finishes).await;
        assert_eq!(result, Err(Elapsed));
        assert!(
            dropped.load(Ordering::SeqCst),
            "timed-out future was not dropped"
        );
    }
}
160}