Skip to main content

cognis_core/wrappers/
middleware.rs

//! Runnable-level middleware — generic before/after/error hooks that
//! wrap any [`Runnable`] without subclassing.
//!
//! Differs from the agent-orchestration middleware in `cognis::middleware`:
//! that one is specific to LLM-call shapes; this one is generic over
//! `(I, O)` and works for any runnable in the framework.
//!
//! Three integration points:
//! - [`Middleware<I, O>`] trait — implement to plug behavior in.
//! - [`WithMiddleware<R, I, O>`] wrapper — wraps a single runnable with
//!   zero or more middlewares. `before_invoke` hooks run in registration
//!   order; `after_invoke` and `on_error` hooks run in reverse
//!   registration order.
//! - [`MiddlewareStack<I, O>`] — explicit ordered list, useful when a
//!   caller wants to compose middlewares once and reuse them.
//!
//! Customization: implement `Middleware<I, O>` for full control. For
//! lighter cases, use [`fn_middleware`] to lift a `before_invoke` closure
//! or [`InspectMiddleware`] to attach a read-only observer of successful
//! outputs.
18
19use std::marker::PhantomData;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23
24use crate::runnable::{Runnable, RunnableConfig};
25use crate::{CognisError, Result};
26
27/// Object-safe middleware trait. The three hooks fire in this order:
28///
29/// 1. `before_invoke` — receives the input by `&mut`, may rewrite it or
30///    short-circuit by returning an error.
31/// 2. (the wrapped runnable runs)
32/// 3. `after_invoke` (on `Ok`) — receives the output by `&mut`, may
33///    rewrite it.
34/// 4. `on_error` (on `Err`) — receives the error by `&mut`, may rewrite
35///    or substitute. Returning `Ok(Some(new_output))` recovers the call.
36///
37/// All hooks have default no-op impls; implement only what you need.
38#[async_trait]
39pub trait Middleware<I, O>: Send + Sync
40where
41    I: Send + 'static,
42    O: Send + 'static,
43{
44    /// Called with the input before the wrapped runnable is invoked.
45    /// Mutate `input` in place to rewrite it. Return `Err(...)` to
46    /// short-circuit (the wrapped runnable will not be called).
47    async fn before_invoke(&self, _input: &mut I, _config: &RunnableConfig) -> Result<()> {
48        Ok(())
49    }
50
51    /// Called on the success path with the output. Mutate to rewrite.
52    async fn after_invoke(&self, _output: &mut O, _config: &RunnableConfig) -> Result<()> {
53        Ok(())
54    }
55
56    /// Called on the error path. Returning `Ok(Some(o))` recovers — the
57    /// invocation completes successfully with that output. Returning
58    /// `Ok(None)` (the default) re-propagates the error.
59    /// Returning `Err(e)` substitutes the original error.
60    async fn on_error(
61        &self,
62        _err: &mut CognisError,
63        _config: &RunnableConfig,
64    ) -> Result<Option<O>> {
65        Ok(None)
66    }
67
68    /// Friendly name for telemetry / diagnostics.
69    fn name(&self) -> &str {
70        std::any::type_name::<Self>()
71    }
72}
73
74// ---------------------------------------------------------------------------
75// fn_middleware — lift a closure into a Middleware.
76// ---------------------------------------------------------------------------
77
78/// Build a middleware from a `before_invoke` closure only. For the
79/// after / error hooks, use [`Middleware`] directly.
80pub fn fn_middleware<I, O, F>(before: F) -> FnMiddleware<I, O, F>
81where
82    I: Send + 'static,
83    O: Send + 'static,
84    F: Fn(&mut I, &RunnableConfig) -> Result<()> + Send + Sync + 'static,
85{
86    FnMiddleware {
87        before,
88        _t: PhantomData,
89    }
90}
91
92/// Closure-backed middleware (before-only).
93pub struct FnMiddleware<I, O, F> {
94    before: F,
95    _t: PhantomData<fn(I) -> O>,
96}
97
98#[async_trait]
99impl<I, O, F> Middleware<I, O> for FnMiddleware<I, O, F>
100where
101    I: Send + 'static,
102    O: Send + 'static,
103    F: Fn(&mut I, &RunnableConfig) -> Result<()> + Send + Sync + 'static,
104{
105    async fn before_invoke(&self, input: &mut I, config: &RunnableConfig) -> Result<()> {
106        (self.before)(input, config)
107    }
108}
109
110// ---------------------------------------------------------------------------
111// InspectMiddleware — pure read-only observation, no mutation.
112// ---------------------------------------------------------------------------
113
114/// Read-only inspector. Closure receives `(&I, &RunnableConfig)` after
115/// each successful invocation (purely for telemetry / metrics).
116pub struct InspectMiddleware<I, O, F> {
117    on_ok: F,
118    _t: PhantomData<fn(I) -> O>,
119}
120
121impl<I, O, F> InspectMiddleware<I, O, F>
122where
123    I: Send + Sync + 'static,
124    O: Send + Sync + 'static,
125    F: Fn(&O, &RunnableConfig) + Send + Sync + 'static,
126{
127    /// Build an inspector that fires on `after_invoke`.
128    pub fn new(on_ok: F) -> Self {
129        Self {
130            on_ok,
131            _t: PhantomData,
132        }
133    }
134}
135
136#[async_trait]
137impl<I, O, F> Middleware<I, O> for InspectMiddleware<I, O, F>
138where
139    I: Send + Sync + 'static,
140    O: Send + Sync + 'static,
141    F: Fn(&O, &RunnableConfig) + Send + Sync + 'static,
142{
143    async fn after_invoke(&self, output: &mut O, config: &RunnableConfig) -> Result<()> {
144        (self.on_ok)(output, config);
145        Ok(())
146    }
147}
148
149// ---------------------------------------------------------------------------
150// MiddlewareStack — explicit ordered list of middlewares.
151// ---------------------------------------------------------------------------
152
153/// Ordered list of middlewares. `before_invoke` runs first→last;
154/// `after_invoke` runs in reverse (last registered → first registered),
155/// matching standard middleware ordering.
156pub struct MiddlewareStack<I, O> {
157    inner: Vec<Arc<dyn Middleware<I, O>>>,
158}
159
160impl<I, O> Default for MiddlewareStack<I, O>
161where
162    I: Send + 'static,
163    O: Send + 'static,
164{
165    fn default() -> Self {
166        Self::new()
167    }
168}
169
170impl<I, O> Clone for MiddlewareStack<I, O> {
171    fn clone(&self) -> Self {
172        Self {
173            inner: self.inner.clone(),
174        }
175    }
176}
177
178impl<I, O> MiddlewareStack<I, O>
179where
180    I: Send + 'static,
181    O: Send + 'static,
182{
183    /// Empty stack.
184    pub fn new() -> Self {
185        Self { inner: Vec::new() }
186    }
187
188    /// Append a middleware to the end of the chain.
189    pub fn push(mut self, m: Arc<dyn Middleware<I, O>>) -> Self {
190        self.inner.push(m);
191        self
192    }
193
194    /// Number of registered middlewares.
195    pub fn len(&self) -> usize {
196        self.inner.len()
197    }
198
199    /// True if no middlewares are registered.
200    pub fn is_empty(&self) -> bool {
201        self.inner.is_empty()
202    }
203
204    /// Borrow the registered middlewares (read-only).
205    pub fn middlewares(&self) -> &[Arc<dyn Middleware<I, O>>] {
206        &self.inner
207    }
208}
209
210// ---------------------------------------------------------------------------
211// WithMiddleware — wrap a Runnable with a MiddlewareStack.
212// ---------------------------------------------------------------------------
213
214/// Wraps a `Runnable<I, O>` with a stack of middlewares.
215pub struct WithMiddleware<R, I, O> {
216    inner: R,
217    stack: MiddlewareStack<I, O>,
218    _phantom: PhantomData<fn(I) -> O>,
219}
220
221impl<R, I, O> WithMiddleware<R, I, O>
222where
223    R: Runnable<I, O>,
224    I: Send + 'static,
225    O: Send + 'static,
226{
227    /// Wrap with an empty stack — equivalent to `inner` itself but
228    /// chainable via [`Self::push`].
229    pub fn new(inner: R) -> Self {
230        Self {
231            inner,
232            stack: MiddlewareStack::new(),
233            _phantom: PhantomData,
234        }
235    }
236
237    /// Wrap with an existing stack.
238    pub fn with_stack(inner: R, stack: MiddlewareStack<I, O>) -> Self {
239        Self {
240            inner,
241            stack,
242            _phantom: PhantomData,
243        }
244    }
245
246    /// Append a middleware.
247    pub fn push(mut self, m: Arc<dyn Middleware<I, O>>) -> Self {
248        self.stack = self.stack.push(m);
249        self
250    }
251}
252
253#[async_trait]
254impl<R, I, O> Runnable<I, O> for WithMiddleware<R, I, O>
255where
256    R: Runnable<I, O>,
257    I: Send + 'static,
258    O: Send + 'static,
259{
260    async fn invoke(&self, mut input: I, config: RunnableConfig) -> Result<O> {
261        // before — first → last
262        for m in self.stack.inner.iter() {
263            m.before_invoke(&mut input, &config).await?;
264        }
265        let result = self.inner.invoke(input, config.clone()).await;
266        match result {
267            Ok(mut output) => {
268                // after — last → first (standard onion order)
269                for m in self.stack.inner.iter().rev() {
270                    m.after_invoke(&mut output, &config).await?;
271                }
272                Ok(output)
273            }
274            Err(mut err) => {
275                // on_error — last → first; first to recover wins.
276                for m in self.stack.inner.iter().rev() {
277                    if let Some(o) = m.on_error(&mut err, &config).await? {
278                        return Ok(o);
279                    }
280                }
281                Err(err)
282            }
283        }
284    }
285
286    fn name(&self) -> &str {
287        self.inner.name()
288    }
289
290    fn input_schema(&self) -> Option<serde_json::Value> {
291        self.inner.input_schema()
292    }
293
294    fn output_schema(&self) -> Option<serde_json::Value> {
295        self.inner.output_schema()
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Runnable that hands its input straight back.
    struct Echo;

    #[async_trait]
    impl Runnable<String, String> for Echo {
        async fn invoke(&self, input: String, _: RunnableConfig) -> Result<String> {
            Ok(input)
        }
    }

    /// Upper-cases the input before the runnable sees it.
    struct UppercaseInput;

    #[async_trait]
    impl Middleware<String, String> for UppercaseInput {
        async fn before_invoke(&self, input: &mut String, _: &RunnableConfig) -> Result<()> {
            let shouted = input.to_uppercase();
            *input = shouted;
            Ok(())
        }
    }

    /// Appends a fixed suffix to the output.
    struct AppendOutput(&'static str);

    #[async_trait]
    impl Middleware<String, String> for AppendOutput {
        async fn after_invoke(&self, output: &mut String, _: &RunnableConfig) -> Result<()> {
            let AppendOutput(suffix) = self;
            output.push_str(suffix);
            Ok(())
        }
    }

    #[tokio::test]
    async fn before_invoke_rewrites_input() {
        let wrapped = WithMiddleware::new(Echo).push(Arc::new(UppercaseInput));
        let cfg = RunnableConfig::default();
        let result = wrapped.invoke("hello".into(), cfg).await;
        assert_eq!(result.unwrap(), "HELLO");
    }

    #[tokio::test]
    async fn after_invoke_rewrites_output() {
        let wrapped = WithMiddleware::new(Echo).push(Arc::new(AppendOutput("!")));
        let cfg = RunnableConfig::default();
        let result = wrapped.invoke("hi".into(), cfg).await;
        assert_eq!(result.unwrap(), "hi!");
    }

    #[tokio::test]
    async fn middlewares_run_in_onion_order() {
        // `before` hooks go outer → inner; `after` hooks go inner → outer.
        // Trace for input "x" with Outer registered first, Inner second:
        //   before(Outer): "(x"
        //   before(Inner): "([x"
        //   Echo:          "([x"
        //   after(Inner):  "([x]"
        //   after(Outer):  "([x])"
        struct Outer;
        struct Inner;

        #[async_trait]
        impl Middleware<String, String> for Outer {
            async fn before_invoke(&self, input: &mut String, _: &RunnableConfig) -> Result<()> {
                input.insert(0, '(');
                Ok(())
            }
            async fn after_invoke(&self, output: &mut String, _: &RunnableConfig) -> Result<()> {
                output.push(')');
                Ok(())
            }
        }

        #[async_trait]
        impl Middleware<String, String> for Inner {
            async fn before_invoke(&self, input: &mut String, _: &RunnableConfig) -> Result<()> {
                // Outer has already prefixed "(", so the input is "(x";
                // slot "[" in right behind that paren to get "([x".
                if let Some(paren_at) = input.find('(') {
                    input.insert(paren_at + 1, '[');
                }
                Ok(())
            }
            async fn after_invoke(&self, output: &mut String, _: &RunnableConfig) -> Result<()> {
                output.push(']');
                Ok(())
            }
        }

        let wrapped = WithMiddleware::new(Echo)
            .push(Arc::new(Outer))
            .push(Arc::new(Inner));
        let cfg = RunnableConfig::default();
        let result = wrapped.invoke("x".into(), cfg).await;
        assert_eq!(result.unwrap(), "([x])");
    }

    #[tokio::test]
    async fn before_invoke_can_short_circuit() {
        /// Refuses every call from `before_invoke`.
        struct Reject;

        #[async_trait]
        impl Middleware<String, String> for Reject {
            async fn before_invoke(&self, _: &mut String, _: &RunnableConfig) -> Result<()> {
                Err(CognisError::Configuration("rejected by middleware".into()))
            }
        }

        let wrapped = WithMiddleware::new(Echo).push(Arc::new(Reject));
        let cfg = RunnableConfig::default();
        let failure = wrapped.invoke("x".into(), cfg).await.unwrap_err();
        assert!(matches!(failure, CognisError::Configuration(_)));
    }

    #[tokio::test]
    async fn on_error_can_recover() {
        /// Runnable that always fails.
        struct Failing;

        #[async_trait]
        impl Runnable<String, String> for Failing {
            async fn invoke(&self, _: String, _: RunnableConfig) -> Result<String> {
                Err(CognisError::Internal("boom".into()))
            }
        }

        /// Turns any failure into a fixed successful output.
        struct Recover;

        #[async_trait]
        impl Middleware<String, String> for Recover {
            async fn on_error(
                &self,
                _: &mut CognisError,
                _: &RunnableConfig,
            ) -> Result<Option<String>> {
                Ok(Some("recovered".into()))
            }
        }

        let wrapped = WithMiddleware::new(Failing).push(Arc::new(Recover));
        let cfg = RunnableConfig::default();
        let result = wrapped.invoke("x".into(), cfg).await;
        assert_eq!(result.unwrap(), "recovered");
    }

    #[tokio::test]
    async fn fn_middleware_lifts_closure() {
        let called = Arc::new(AtomicBool::new(false));
        let called_in_hook = Arc::clone(&called);
        let hook = fn_middleware::<String, String, _>(move |input, _| {
            called_in_hook.store(true, Ordering::SeqCst);
            input.push('!');
            Ok(())
        });

        let wrapped = WithMiddleware::new(Echo).push(Arc::new(hook));
        let cfg = RunnableConfig::default();
        let result = wrapped.invoke("hi".into(), cfg).await.unwrap();

        assert!(called.load(Ordering::SeqCst));
        assert_eq!(result, "hi!");
    }

    #[tokio::test]
    async fn inspect_middleware_does_not_mutate() {
        let hits = Arc::new(AtomicUsize::new(0));
        let hits_in_hook = Arc::clone(&hits);
        let observer = InspectMiddleware::<String, String, _>::new(move |_out, _cfg| {
            hits_in_hook.fetch_add(1, Ordering::SeqCst);
        });

        let wrapped = WithMiddleware::new(Echo).push(Arc::new(observer));
        let cfg = RunnableConfig::default();
        let result = wrapped.invoke("hi".into(), cfg).await.unwrap();

        assert_eq!(result, "hi");
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn middleware_stack_clone_independent() {
        let shared = MiddlewareStack::<String, String>::new()
            .push(Arc::new(UppercaseInput))
            .push(Arc::new(AppendOutput("!")));

        // One wrapper gets a clone of the stack, the other the original.
        let from_clone = WithMiddleware::with_stack(Echo, shared.clone());
        let from_original = WithMiddleware::with_stack(Echo, shared);

        let cfg = RunnableConfig::default();
        let first = from_clone.invoke("hi".into(), cfg.clone()).await.unwrap();
        let second = from_original.invoke("hi".into(), cfg).await.unwrap();

        assert_eq!(first, "HI!");
        assert_eq!(second, "HI!");
    }
}