ai_lib_rust/plugins/
middleware.rs1use crate::Result;
4use async_trait::async_trait;
5use std::sync::Arc;
6
7#[derive(Debug, Clone)]
8pub struct MiddlewareContext {
9 pub request: serde_json::Value,
10 pub response: Option<serde_json::Value>,
11 pub request_id: Option<String>,
12 pub model: Option<String>,
13 pub metadata: std::collections::HashMap<String, serde_json::Value>,
14}
15
16impl MiddlewareContext {
17 pub fn new(request: serde_json::Value) -> Self {
18 Self {
19 request,
20 response: None,
21 request_id: None,
22 model: None,
23 metadata: std::collections::HashMap::new(),
24 }
25 }
26 pub fn set_response(&mut self, r: serde_json::Value) {
27 self.response = Some(r);
28 }
29 pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
30 self.request_id = Some(id.into());
31 self
32 }
33 pub fn with_model(mut self, m: impl Into<String>) -> Self {
34 self.model = Some(m.into());
35 self
36 }
37}
38
/// One-shot, boxed continuation handed to [`Middleware::process`].
///
/// Calling it with a [`MiddlewareContext`] yields a pinned, boxed, `Send`
/// future resolving to `Result<MiddlewareContext>`. Being `FnOnce`, it can be
/// invoked at most once. The lifetime `'a` bounds both what the closure
/// captures and the future it returns.
pub type NextFn<'a> = Box<
    dyn FnOnce(
            MiddlewareContext,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<MiddlewareContext>> + Send + 'a>,
        > + Send
        + 'a,
>;
47
/// An asynchronous interceptor that processes a [`MiddlewareContext`].
///
/// Implementations must be `Send + Sync` so they can be shared behind an
/// `Arc` (which is how [`MiddlewareChain`] stores them).
#[async_trait]
pub trait Middleware: Send + Sync {
    /// Handles `ctx`, optionally delegating to the `next` continuation.
    ///
    /// `next` is a one-shot [`NextFn`]; what it does is decided by whoever
    /// builds it (see [`MiddlewareChain::execute`]). The returned context is
    /// what this middleware hands back to its caller.
    ///
    /// # Errors
    /// Returns whatever error the implementation, or `next`, produces.
    async fn process(&self, ctx: MiddlewareContext, next: NextFn<'_>) -> Result<MiddlewareContext>;

    /// Identifier for this middleware; defaults to `"unnamed"`.
    fn name(&self) -> &str {
        "unnamed"
    }
}
55
56pub struct MiddlewareChain {
57 middlewares: Vec<Arc<dyn Middleware>>,
58}
59impl MiddlewareChain {
60 pub fn new() -> Self {
61 Self {
62 middlewares: Vec::new(),
63 }
64 }
65 pub fn add(mut self, m: Arc<dyn Middleware>) -> Self {
66 self.middlewares.push(m);
67 self
68 }
69 pub fn len(&self) -> usize {
70 self.middlewares.len()
71 }
72 pub fn is_empty(&self) -> bool {
73 self.middlewares.is_empty()
74 }
75
76 pub async fn execute<F, Fut>(
77 &self,
78 ctx: MiddlewareContext,
79 handler: F,
80 ) -> Result<MiddlewareContext>
81 where
82 F: FnOnce(MiddlewareContext) -> Fut + Send + 'static,
83 Fut: std::future::Future<Output = Result<MiddlewareContext>> + Send + 'static,
84 {
85 if self.middlewares.is_empty() {
86 return handler(ctx).await;
87 }
88 let mut current = ctx;
89 for mw in &self.middlewares {
90 let next: NextFn<'_> = Box::new(move |c| Box::pin(async move { Ok(c) }));
91 current = mw.process(current, next).await?;
92 }
93 handler(current).await
94 }
95}
96impl Default for MiddlewareChain {
97 fn default() -> Self {
98 Self::new()
99 }
100}