1use std::marker::PhantomData;
20use std::sync::Arc;
21
22use async_trait::async_trait;
23
24use crate::runnable::{Runnable, RunnableConfig};
25use crate::{CognisError, Result};
26
27#[async_trait]
39pub trait Middleware<I, O>: Send + Sync
40where
41 I: Send + 'static,
42 O: Send + 'static,
43{
44 async fn before_invoke(&self, _input: &mut I, _config: &RunnableConfig) -> Result<()> {
48 Ok(())
49 }
50
51 async fn after_invoke(&self, _output: &mut O, _config: &RunnableConfig) -> Result<()> {
53 Ok(())
54 }
55
56 async fn on_error(
61 &self,
62 _err: &mut CognisError,
63 _config: &RunnableConfig,
64 ) -> Result<Option<O>> {
65 Ok(None)
66 }
67
68 fn name(&self) -> &str {
70 std::any::type_name::<Self>()
71 }
72}
73
74pub fn fn_middleware<I, O, F>(before: F) -> FnMiddleware<I, O, F>
81where
82 I: Send + 'static,
83 O: Send + 'static,
84 F: Fn(&mut I, &RunnableConfig) -> Result<()> + Send + Sync + 'static,
85{
86 FnMiddleware {
87 before,
88 _t: PhantomData,
89 }
90}
91
92pub struct FnMiddleware<I, O, F> {
94 before: F,
95 _t: PhantomData<fn(I) -> O>,
96}
97
98#[async_trait]
99impl<I, O, F> Middleware<I, O> for FnMiddleware<I, O, F>
100where
101 I: Send + 'static,
102 O: Send + 'static,
103 F: Fn(&mut I, &RunnableConfig) -> Result<()> + Send + Sync + 'static,
104{
105 async fn before_invoke(&self, input: &mut I, config: &RunnableConfig) -> Result<()> {
106 (self.before)(input, config)
107 }
108}
109
110pub struct InspectMiddleware<I, O, F> {
117 on_ok: F,
118 _t: PhantomData<fn(I) -> O>,
119}
120
121impl<I, O, F> InspectMiddleware<I, O, F>
122where
123 I: Send + Sync + 'static,
124 O: Send + Sync + 'static,
125 F: Fn(&O, &RunnableConfig) + Send + Sync + 'static,
126{
127 pub fn new(on_ok: F) -> Self {
129 Self {
130 on_ok,
131 _t: PhantomData,
132 }
133 }
134}
135
136#[async_trait]
137impl<I, O, F> Middleware<I, O> for InspectMiddleware<I, O, F>
138where
139 I: Send + Sync + 'static,
140 O: Send + Sync + 'static,
141 F: Fn(&O, &RunnableConfig) + Send + Sync + 'static,
142{
143 async fn after_invoke(&self, output: &mut O, config: &RunnableConfig) -> Result<()> {
144 (self.on_ok)(output, config);
145 Ok(())
146 }
147}
148
149pub struct MiddlewareStack<I, O> {
157 inner: Vec<Arc<dyn Middleware<I, O>>>,
158}
159
160impl<I, O> Default for MiddlewareStack<I, O>
161where
162 I: Send + 'static,
163 O: Send + 'static,
164{
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170impl<I, O> Clone for MiddlewareStack<I, O> {
171 fn clone(&self) -> Self {
172 Self {
173 inner: self.inner.clone(),
174 }
175 }
176}
177
178impl<I, O> MiddlewareStack<I, O>
179where
180 I: Send + 'static,
181 O: Send + 'static,
182{
183 pub fn new() -> Self {
185 Self { inner: Vec::new() }
186 }
187
188 pub fn push(mut self, m: Arc<dyn Middleware<I, O>>) -> Self {
190 self.inner.push(m);
191 self
192 }
193
194 pub fn len(&self) -> usize {
196 self.inner.len()
197 }
198
199 pub fn is_empty(&self) -> bool {
201 self.inner.is_empty()
202 }
203
204 pub fn middlewares(&self) -> &[Arc<dyn Middleware<I, O>>] {
206 &self.inner
207 }
208}
209
210pub struct WithMiddleware<R, I, O> {
216 inner: R,
217 stack: MiddlewareStack<I, O>,
218 _phantom: PhantomData<fn(I) -> O>,
219}
220
221impl<R, I, O> WithMiddleware<R, I, O>
222where
223 R: Runnable<I, O>,
224 I: Send + 'static,
225 O: Send + 'static,
226{
227 pub fn new(inner: R) -> Self {
230 Self {
231 inner,
232 stack: MiddlewareStack::new(),
233 _phantom: PhantomData,
234 }
235 }
236
237 pub fn with_stack(inner: R, stack: MiddlewareStack<I, O>) -> Self {
239 Self {
240 inner,
241 stack,
242 _phantom: PhantomData,
243 }
244 }
245
246 pub fn push(mut self, m: Arc<dyn Middleware<I, O>>) -> Self {
248 self.stack = self.stack.push(m);
249 self
250 }
251}
252
253#[async_trait]
254impl<R, I, O> Runnable<I, O> for WithMiddleware<R, I, O>
255where
256 R: Runnable<I, O>,
257 I: Send + 'static,
258 O: Send + 'static,
259{
260 async fn invoke(&self, mut input: I, config: RunnableConfig) -> Result<O> {
261 for m in self.stack.inner.iter() {
263 m.before_invoke(&mut input, &config).await?;
264 }
265 let result = self.inner.invoke(input, config.clone()).await;
266 match result {
267 Ok(mut output) => {
268 for m in self.stack.inner.iter().rev() {
270 m.after_invoke(&mut output, &config).await?;
271 }
272 Ok(output)
273 }
274 Err(mut err) => {
275 for m in self.stack.inner.iter().rev() {
277 if let Some(o) = m.on_error(&mut err, &config).await? {
278 return Ok(o);
279 }
280 }
281 Err(err)
282 }
283 }
284 }
285
286 fn name(&self) -> &str {
287 self.inner.name()
288 }
289
290 fn input_schema(&self) -> Option<serde_json::Value> {
291 self.inner.input_schema()
292 }
293
294 fn output_schema(&self) -> Option<serde_json::Value> {
295 self.inner.output_schema()
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Runnable that returns its input unchanged.
    struct Echo;

    #[async_trait]
    impl Runnable<String, String> for Echo {
        async fn invoke(&self, input: String, _: RunnableConfig) -> Result<String> {
            Ok(input)
        }
    }

    /// Echo runnable that also counts how many times it was invoked.
    struct CountingEcho(Arc<AtomicUsize>);

    #[async_trait]
    impl Runnable<String, String> for CountingEcho {
        async fn invoke(&self, input: String, _: RunnableConfig) -> Result<String> {
            self.0.fetch_add(1, Ordering::SeqCst);
            Ok(input)
        }
    }

    /// Runnable that always fails with `CognisError::Internal`.
    struct Failing;

    #[async_trait]
    impl Runnable<String, String> for Failing {
        async fn invoke(&self, _: String, _: RunnableConfig) -> Result<String> {
            Err(CognisError::Internal("boom".into()))
        }
    }

    /// Middleware that upper-cases the input.
    struct UppercaseInput;

    #[async_trait]
    impl Middleware<String, String> for UppercaseInput {
        async fn before_invoke(&self, input: &mut String, _: &RunnableConfig) -> Result<()> {
            *input = input.to_uppercase();
            Ok(())
        }
    }

    /// Middleware that appends a fixed suffix to the output.
    struct AppendOutput(&'static str);

    #[async_trait]
    impl Middleware<String, String> for AppendOutput {
        async fn after_invoke(&self, output: &mut String, _: &RunnableConfig) -> Result<()> {
            output.push_str(self.0);
            Ok(())
        }
    }

    /// Middleware whose `before_invoke` always fails.
    struct Reject;

    #[async_trait]
    impl Middleware<String, String> for Reject {
        async fn before_invoke(&self, _: &mut String, _: &RunnableConfig) -> Result<()> {
            Err(CognisError::Configuration("rejected by middleware".into()))
        }
    }

    /// Middleware that recovers every error with a fixed output.
    struct Recover;

    #[async_trait]
    impl Middleware<String, String> for Recover {
        async fn on_error(
            &self,
            _: &mut CognisError,
            _: &RunnableConfig,
        ) -> Result<Option<String>> {
            Ok(Some("recovered".into()))
        }
    }

    #[tokio::test]
    async fn before_invoke_rewrites_input() {
        let chain = WithMiddleware::new(Echo).push(Arc::new(UppercaseInput));
        let out = chain
            .invoke("hello".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, "HELLO");
    }

    #[tokio::test]
    async fn after_invoke_rewrites_output() {
        let chain = WithMiddleware::new(Echo).push(Arc::new(AppendOutput("!")));
        let out = chain
            .invoke("hi".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, "hi!");
    }

    #[tokio::test]
    async fn middlewares_run_in_onion_order() {
        struct Outer;
        struct Inner;

        #[async_trait]
        impl Middleware<String, String> for Outer {
            async fn before_invoke(&self, input: &mut String, _: &RunnableConfig) -> Result<()> {
                *input = format!("({input}");
                Ok(())
            }
            async fn after_invoke(&self, output: &mut String, _: &RunnableConfig) -> Result<()> {
                output.push(')');
                Ok(())
            }
        }
        #[async_trait]
        impl Middleware<String, String> for Inner {
            async fn before_invoke(&self, input: &mut String, _: &RunnableConfig) -> Result<()> {
                // Insert '[' right after the '(' added by Outer; this only
                // produces "([x" if Outer's before hook already ran.
                if let Some(idx) = input.find('(') {
                    input.insert(idx + 1, '[');
                }
                Ok(())
            }
            async fn after_invoke(&self, output: &mut String, _: &RunnableConfig) -> Result<()> {
                output.push(']');
                Ok(())
            }
        }

        let chain = WithMiddleware::new(Echo)
            .push(Arc::new(Outer))
            .push(Arc::new(Inner));
        let out = chain
            .invoke("x".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, "([x])");
    }

    #[tokio::test]
    async fn before_invoke_can_short_circuit() {
        let chain = WithMiddleware::new(Echo).push(Arc::new(Reject));
        let err = chain
            .invoke("x".into(), RunnableConfig::default())
            .await
            .unwrap_err();
        assert!(matches!(err, CognisError::Configuration(_)));
    }

    #[tokio::test]
    async fn before_invoke_rejection_skips_inner() {
        let calls = Arc::new(AtomicUsize::new(0));
        let chain = WithMiddleware::new(CountingEcho(calls.clone())).push(Arc::new(Reject));
        let result = chain.invoke("x".into(), RunnableConfig::default()).await;
        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn on_error_can_recover() {
        let chain = WithMiddleware::new(Failing).push(Arc::new(Recover));
        let out = chain
            .invoke("x".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, "recovered");
    }

    #[tokio::test]
    async fn unhandled_error_propagates() {
        // AppendOutput keeps the default `on_error`, which does not recover.
        let chain = WithMiddleware::new(Failing).push(Arc::new(AppendOutput("!")));
        let err = chain
            .invoke("x".into(), RunnableConfig::default())
            .await
            .unwrap_err();
        assert!(matches!(err, CognisError::Internal(_)));
    }

    #[tokio::test]
    async fn on_error_can_rewrite_error() {
        struct Rewrite;
        #[async_trait]
        impl Middleware<String, String> for Rewrite {
            async fn on_error(
                &self,
                err: &mut CognisError,
                _: &RunnableConfig,
            ) -> Result<Option<String>> {
                *err = CognisError::Configuration("rewritten".into());
                Ok(None)
            }
        }
        let chain = WithMiddleware::new(Failing).push(Arc::new(Rewrite));
        let err = chain
            .invoke("x".into(), RunnableConfig::default())
            .await
            .unwrap_err();
        assert!(matches!(err, CognisError::Configuration(_)));
    }

    #[tokio::test]
    async fn on_error_stops_at_innermost_recovery() {
        struct Spy(Arc<AtomicBool>);
        #[async_trait]
        impl Middleware<String, String> for Spy {
            async fn on_error(
                &self,
                _: &mut CognisError,
                _: &RunnableConfig,
            ) -> Result<Option<String>> {
                self.0.store(true, Ordering::SeqCst);
                Ok(None)
            }
        }
        let outer_saw_error = Arc::new(AtomicBool::new(false));
        let chain = WithMiddleware::new(Failing)
            .push(Arc::new(Spy(outer_saw_error.clone())))
            .push(Arc::new(Recover));
        let out = chain
            .invoke("x".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, "recovered");
        // Recover is innermost, so it handles the error before the outer Spy.
        assert!(!outer_saw_error.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn fn_middleware_lifts_closure() {
        let saw = Arc::new(AtomicBool::new(false));
        let saw_for_mw = saw.clone();
        let mw = fn_middleware::<String, String, _>(move |input, _| {
            saw_for_mw.store(true, Ordering::SeqCst);
            input.push('!');
            Ok(())
        });
        let chain = WithMiddleware::new(Echo).push(Arc::new(mw));
        let out = chain
            .invoke("hi".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert!(saw.load(Ordering::SeqCst));
        assert_eq!(out, "hi!");
    }

    #[tokio::test]
    async fn inspect_middleware_does_not_mutate() {
        let count = Arc::new(AtomicUsize::new(0));
        let count_for_mw = count.clone();
        let inspector = InspectMiddleware::<String, String, _>::new(move |_out, _cfg| {
            count_for_mw.fetch_add(1, Ordering::SeqCst);
        });
        let chain = WithMiddleware::new(Echo).push(Arc::new(inspector));
        let out = chain
            .invoke("hi".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(out, "hi");
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn stack_len_and_is_empty() {
        let empty = MiddlewareStack::<String, String>::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let stack = empty.push(Arc::new(UppercaseInput));
        assert!(!stack.is_empty());
        assert_eq!(stack.len(), 1);
        assert_eq!(stack.middlewares().len(), 1);
    }

    #[tokio::test]
    async fn middleware_stack_clone_independent() {
        let stack = MiddlewareStack::<String, String>::new()
            .push(Arc::new(UppercaseInput))
            .push(Arc::new(AppendOutput("!")));
        let chain1 = WithMiddleware::with_stack(Echo, stack.clone());
        let chain2 = WithMiddleware::with_stack(Echo, stack);
        let cfg = RunnableConfig::default();
        let o1 = chain1.invoke("hi".into(), cfg.clone()).await.unwrap();
        let o2 = chain2.invoke("hi".into(), cfg).await.unwrap();
        assert_eq!(o1, "HI!");
        assert_eq!(o2, "HI!");
    }
}