Skip to main content

cognis_core/wrappers/
timeout.rs

1//! Timeout wrapper — bound a runnable's execution time.
2
3use std::marker::PhantomData;
4use std::time::Duration;
5
6use async_trait::async_trait;
7
8use crate::runnable::{Runnable, RunnableConfig};
9use crate::{CognisError, Result};
10
/// Wraps a runnable `R` and caps how long a single invocation may take.
///
/// If the inner runnable has not produced a result by the time the deadline
/// passes, the call fails with `CognisError::Timeout` instead.
pub struct Timeout<R, I, O> {
    /// Upper bound on the time one `invoke` of `inner` is allowed to take.
    duration: Duration,
    /// The wrapped runnable that does the actual work.
    inner: R,
    // A `fn(I) -> O` marker records the input/output types without storing
    // any `I` or `O`, so the wrapper's auto traits (Send/Sync) do not depend
    // on them.
    _phantom: PhantomData<fn(I) -> O>,
}
18
19impl<R, I, O> Timeout<R, I, O>
20where
21    R: Runnable<I, O>,
22    I: Send + 'static,
23    O: Send + 'static,
24{
25    /// Wrap a runnable with the given timeout.
26    pub fn new(inner: R, duration: Duration) -> Self {
27        Self {
28            inner,
29            duration,
30            _phantom: PhantomData,
31        }
32    }
33}
34
35#[async_trait]
36impl<R, I, O> Runnable<I, O> for Timeout<R, I, O>
37where
38    R: Runnable<I, O>,
39    I: Send + 'static,
40    O: Send + 'static,
41{
42    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
43        let op_name = self.inner.name().to_string();
44        let dur = self.duration;
45        match tokio::time::timeout(dur, self.inner.invoke(input, config)).await {
46            Ok(r) => r,
47            Err(_) => Err(CognisError::Timeout {
48                operation: op_name,
49                timeout_ms: dur.as_millis() as u64,
50            }),
51        }
52    }
53    fn name(&self) -> &str {
54        "Timeout"
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Test runnable that echoes its input after a fixed 50 ms delay.
    struct Slow;

    #[async_trait]
    impl Runnable<u32, u32> for Slow {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(input)
        }
    }

    #[tokio::test]
    async fn errors_when_exceeded() {
        let t = Timeout::new(Slow, Duration::from_millis(5));
        let err = t.invoke(1, RunnableConfig::default()).await.unwrap_err();
        // Pin the reported budget as well as the variant.
        assert!(matches!(err, CognisError::Timeout { timeout_ms: 5, .. }));
    }

    #[tokio::test]
    async fn succeeds_within_budget() {
        let t = Timeout::new(Slow, Duration::from_millis(500));
        assert_eq!(t.invoke(7, RunnableConfig::default()).await.unwrap(), 7);
    }

    #[test]
    fn reports_wrapper_name() {
        let t = Timeout::new(Slow, Duration::from_millis(5));
        assert_eq!(t.name(), "Timeout");
    }
}