// cognis_core/runnable_ext.rs

//! Fluent extension methods for any `Runnable`.
//!
//! Bring the trait into scope through the prelude, or import it directly:
//!
//! ```ignore
//! use std::time::Duration;
//! use cognis_core::RunnableExt;
//!
//! let chain = prompt.pipe(model).pipe(parser);
//! let resilient = model.with_max_retries(3).with_timeout(Duration::from_secs(30));
//! ```
11
12use std::sync::Arc;
13use std::time::Duration;
14
15use crate::compose::{Each, Pipe};
16use crate::runnable::Runnable;
17use crate::wrappers::{Cache, Fallback, MemoryCache, Retry, RetryPolicy, Timeout};
18
19/// Adds composition + wrapper methods to any `Runnable`.
20pub trait RunnableExt<I, O>: Runnable<I, O> + Sized
21where
22    I: Send + 'static,
23    O: Send + 'static,
24{
25    /// Pipe this runnable into another, building a `Pipe<Self, Next>`.
26    fn pipe<R2, O2>(self, next: R2) -> Pipe<Self, R2, I, O, O2>
27    where
28        R2: Runnable<O, O2>,
29        O2: Send + 'static,
30    {
31        Pipe::new(self, next)
32    }
33
34    /// Wrap with a retry policy.
35    fn with_retry(self, policy: RetryPolicy) -> Retry<Self, I, O>
36    where
37        I: Clone,
38    {
39        Retry::new(self, policy)
40    }
41
42    /// Shortcut: retry with default policy and N attempts.
43    fn with_max_retries(self, attempts: u32) -> Retry<Self, I, O>
44    where
45        I: Clone,
46    {
47        Retry::new(self, RetryPolicy::new(attempts))
48    }
49
50    /// Wrap with a per-call timeout.
51    fn with_timeout(self, duration: Duration) -> Timeout<Self, I, O> {
52        Timeout::new(self, duration)
53    }
54
55    /// Wrap with a fallback runnable.
56    fn with_fallback<F>(self, fallback: F) -> Fallback<Self, F, I, O>
57    where
58        F: Runnable<I, O>,
59        I: Clone,
60    {
61        Fallback::new(self, fallback)
62    }
63
64    /// Wrap with an in-memory cache keyed by `key_fn(&I)`.
65    fn with_memory_cache<K, F>(self, key_fn: F) -> Cache<Self, I, O, K, MemoryCache<K, O>>
66    where
67        K: std::hash::Hash + Eq + Clone + Send + Sync + 'static,
68        O: Clone + Send + Sync + 'static,
69        F: Fn(&I) -> K + Send + Sync + 'static,
70    {
71        Cache::new(self, Arc::new(MemoryCache::new()), key_fn)
72    }
73
74    /// Apply this runnable to each element of a `Vec<I>` (preserves order).
75    fn each(self) -> Each<Self, I, O> {
76        Each::new(self)
77    }
78}
79
80impl<R, I, O> RunnableExt<I, O> for R
81where
82    R: Runnable<I, O>,
83    I: Send + 'static,
84    O: Send + 'static,
85{
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runnable::RunnableConfig;
    use crate::Result;
    use async_trait::async_trait;

    /// Test stage that adds one to its input.
    struct AddOne;

    #[async_trait]
    impl Runnable<u32, u32> for AddOne {
        async fn invoke(&self, value: u32, _config: RunnableConfig) -> Result<u32> {
            Ok(value + 1)
        }
    }

    /// Test stage that doubles its input.
    struct Twice;

    #[async_trait]
    impl Runnable<u32, u32> for Twice {
        async fn invoke(&self, value: u32, _config: RunnableConfig) -> Result<u32> {
            Ok(2 * value)
        }
    }

    #[tokio::test]
    async fn pipe_method_works() {
        // Stages run left to right: ((3 + 1) * 2) + 1 == 9.
        let composed = AddOne.pipe(Twice).pipe(AddOne);
        let cfg = RunnableConfig::default();
        let result = composed.invoke(3, cfg).await.unwrap();
        assert_eq!(result, 9);
    }

    #[tokio::test]
    async fn each_works() {
        let per_item = AddOne.each();
        let cfg = RunnableConfig::default();
        let mapped = per_item.invoke(vec![1, 2, 3], cfg).await.unwrap();
        assert_eq!(mapped, vec![2, 3, 4]);
    }
}