Skip to main content

cognis_core/wrappers/
bind.rs

//! `Bind` — pre-apply a `RunnableConfig` mutation to a runnable.
//!
//! Useful when a chain always needs the same config adjustments (extra tags,
//! a deadline, an observer) and callers should not have to rebuild them on
//! every call.
5
6use std::marker::PhantomData;
7use std::sync::Arc;
8
9use async_trait::async_trait;
10
11use crate::runnable::{Runnable, RunnableConfig};
12use crate::Result;
13
14type ConfigFn = dyn Fn(&mut RunnableConfig) + Send + Sync;
15
16/// Wraps a runnable with a config-mutator that runs before each invocation.
17pub struct Bind<R, I, O> {
18    inner: R,
19    apply: Arc<ConfigFn>,
20    _phantom: PhantomData<fn(I) -> O>,
21}
22
23impl<R, I, O> Bind<R, I, O>
24where
25    R: Runnable<I, O>,
26    I: Send + 'static,
27    O: Send + 'static,
28{
29    /// Build a wrapper that applies `apply` to a clone of the caller's
30    /// config before delegating.
31    pub fn new<F>(inner: R, apply: F) -> Self
32    where
33        F: Fn(&mut RunnableConfig) + Send + Sync + 'static,
34    {
35        Self {
36            inner,
37            apply: Arc::new(apply),
38            _phantom: PhantomData,
39        }
40    }
41}
42
43#[async_trait]
44impl<R, I, O> Runnable<I, O> for Bind<R, I, O>
45where
46    R: Runnable<I, O>,
47    I: Send + 'static,
48    O: Send + 'static,
49{
50    async fn invoke(&self, input: I, mut config: RunnableConfig) -> Result<O> {
51        (self.apply)(&mut config);
52        self.inner.invoke(input, config).await
53    }
54    fn name(&self) -> &str {
55        "Bind"
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Inner runnable that records every tag it receives and echoes its input.
    struct CapturesTag {
        captured: std::sync::Mutex<Vec<String>>,
    }

    impl CapturesTag {
        fn new() -> Self {
            Self {
                captured: std::sync::Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of all tags seen so far, in arrival order.
        fn tags(&self) -> Vec<String> {
            self.captured.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Runnable<u32, u32> for CapturesTag {
        async fn invoke(&self, input: u32, config: RunnableConfig) -> Result<u32> {
            self.captured.lock().unwrap().extend(config.tags.clone());
            Ok(input)
        }
    }

    #[tokio::test]
    async fn bind_applies_mutation() {
        let bound = Bind::new(CapturesTag::new(), |c| c.tags.push("bound".into()));
        let _ = bound.invoke(1, RunnableConfig::default()).await.unwrap();
        assert!(bound.inner.tags().contains(&"bound".to_string()));
    }

    #[tokio::test]
    async fn bind_keeps_caller_tags_and_forwards_input() {
        let bound = Bind::new(CapturesTag::new(), |c| c.tags.push("bound".into()));
        let mut config = RunnableConfig::default();
        config.tags.push("caller".into());
        // Expected: whatever the caller supplied, followed by exactly one
        // tag added by the bound mutation.
        let mut expected: Vec<String> = config.tags.clone().into_iter().collect();
        expected.push("bound".to_string());

        let out = bound.invoke(7, config).await.unwrap();

        assert_eq!(out, 7);
        assert_eq!(bound.inner.tags(), expected);
    }
}