cognis_core/wrappers/
timeout.rs1use std::marker::PhantomData;
4use std::time::Duration;
5
6use async_trait::async_trait;
7
8use crate::runnable::{Runnable, RunnableConfig};
9use crate::{CognisError, Result};
10
/// Runnable wrapper that bounds every call to an inner runnable by a fixed
/// time budget.
pub struct Timeout<R, I, O> {
    /// Maximum time a single invocation of `inner` is allowed to take.
    duration: Duration,
    /// The wrapped runnable whose invocations are time-bounded.
    inner: R,
    // `fn(I) -> O` ties the input/output types to the wrapper without storing
    // values of them; function pointers are always `Send + Sync`, so those
    // auto traits depend only on `R`.
    _phantom: PhantomData<fn(I) -> O>,
}
18
19impl<R, I, O> Timeout<R, I, O>
20where
21 R: Runnable<I, O>,
22 I: Send + 'static,
23 O: Send + 'static,
24{
25 pub fn new(inner: R, duration: Duration) -> Self {
27 Self {
28 inner,
29 duration,
30 _phantom: PhantomData,
31 }
32 }
33}
34
35#[async_trait]
36impl<R, I, O> Runnable<I, O> for Timeout<R, I, O>
37where
38 R: Runnable<I, O>,
39 I: Send + 'static,
40 O: Send + 'static,
41{
42 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
43 let op_name = self.inner.name().to_string();
44 let dur = self.duration;
45 match tokio::time::timeout(dur, self.inner.invoke(input, config)).await {
46 Ok(r) => r,
47 Err(_) => Err(CognisError::Timeout {
48 operation: op_name,
49 timeout_ms: dur.as_millis() as u64,
50 }),
51 }
52 }
53 fn name(&self) -> &str {
54 "Timeout"
55 }
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that echoes its input after a fixed 50 ms delay.
    struct Slow;

    #[async_trait]
    impl Runnable<u32, u32> for Slow {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            tokio::time::sleep(Duration::from_millis(50)).await;
            Ok(input)
        }

        // A distinctive name lets the timeout test check that the wrapper
        // reports the *inner* operation, not its own name.
        fn name(&self) -> &str {
            "Slow"
        }
    }

    #[tokio::test]
    async fn errors_when_exceeded() {
        let t = Timeout::new(Slow, Duration::from_millis(5));
        let err = t.invoke(1, RunnableConfig::default()).await.unwrap_err();
        match err {
            CognisError::Timeout {
                operation,
                timeout_ms,
                ..
            } => {
                assert_eq!(operation, "Slow");
                assert_eq!(timeout_ms, 5);
            }
            _ => panic!("expected CognisError::Timeout"),
        }
    }

    #[tokio::test]
    async fn succeeds_within_budget() {
        let t = Timeout::new(Slow, Duration::from_millis(500));
        assert_eq!(t.invoke(7, RunnableConfig::default()).await.unwrap(), 7);
    }
}