Skip to main content

ai_lib_rust/plugins/
middleware.rs

1//! Middleware system.
2
3use crate::Result;
4use async_trait::async_trait;
5use std::sync::Arc;
6
7#[derive(Debug, Clone)]
8pub struct MiddlewareContext {
9    pub request: serde_json::Value,
10    pub response: Option<serde_json::Value>,
11    pub request_id: Option<String>,
12    pub model: Option<String>,
13    pub metadata: std::collections::HashMap<String, serde_json::Value>,
14}
15
16impl MiddlewareContext {
17    pub fn new(request: serde_json::Value) -> Self {
18        Self {
19            request,
20            response: None,
21            request_id: None,
22            model: None,
23            metadata: std::collections::HashMap::new(),
24        }
25    }
26    pub fn set_response(&mut self, r: serde_json::Value) {
27        self.response = Some(r);
28    }
29    pub fn with_request_id(mut self, id: impl Into<String>) -> Self {
30        self.request_id = Some(id.into());
31        self
32    }
33    pub fn with_model(mut self, m: impl Into<String>) -> Self {
34        self.model = Some(m.into());
35        self
36    }
37}
38
39pub type NextFn<'a> = Box<
40    dyn FnOnce(
41            MiddlewareContext,
42        ) -> std::pin::Pin<
43            Box<dyn std::future::Future<Output = Result<MiddlewareContext>> + Send + 'a>,
44        > + Send
45        + 'a,
46>;
47
48#[async_trait]
49pub trait Middleware: Send + Sync {
50    async fn process(&self, ctx: MiddlewareContext, next: NextFn<'_>) -> Result<MiddlewareContext>;
51    fn name(&self) -> &str {
52        "unnamed"
53    }
54}
55
56pub struct MiddlewareChain {
57    middlewares: Vec<Arc<dyn Middleware>>,
58}
59impl MiddlewareChain {
60    pub fn new() -> Self {
61        Self {
62            middlewares: Vec::new(),
63        }
64    }
65    pub fn add(mut self, m: Arc<dyn Middleware>) -> Self {
66        self.middlewares.push(m);
67        self
68    }
69    pub fn len(&self) -> usize {
70        self.middlewares.len()
71    }
72    pub fn is_empty(&self) -> bool {
73        self.middlewares.is_empty()
74    }
75
76    pub async fn execute<F, Fut>(
77        &self,
78        ctx: MiddlewareContext,
79        handler: F,
80    ) -> Result<MiddlewareContext>
81    where
82        F: FnOnce(MiddlewareContext) -> Fut + Send + 'static,
83        Fut: std::future::Future<Output = Result<MiddlewareContext>> + Send + 'static,
84    {
85        if self.middlewares.is_empty() {
86            return handler(ctx).await;
87        }
88        let mut current = ctx;
89        for mw in &self.middlewares {
90            let next: NextFn<'_> = Box::new(move |c| Box::pin(async move { Ok(c) }));
91            current = mw.process(current, next).await?;
92        }
93        handler(current).await
94    }
95}
96impl Default for MiddlewareChain {
97    fn default() -> Self {
98        Self::new()
99    }
100}