Skip to main content

synapse_sdk/
middleware.rs

1//! Middleware system for service handlers
2//!
3//! Middlewares can inspect and modify request state before handlers execute.
4
5use anyhow::Result;
6use async_trait::async_trait;
7use bytes::Bytes;
8use std::{any::Any, collections::HashMap, sync::Arc};
9use synapse_proto::{HeaderEntry, RpcRequest};
10
11/// Request state passed to handlers
12///
13/// Contains request metadata and custom per-request state that
14/// can be set by middlewares and accessed by handlers.
15pub struct RequestState<T = ()> {
16    /// Request headers
17    pub headers: Vec<HeaderEntry>,
18
19    /// When the request was sent (unix milliseconds)
20    pub sent_at_unix_ms: i64,
21
22    /// Interface ID being called
23    pub interface_id: u32,
24
25    /// Method ID being called
26    pub method_id: u32,
27
28    /// Custom service-specific state (set by middlewares)
29    pub custom: T,
30
31    /// Additional dynamic state (set by middlewares)
32    extensions: HashMap<String, Box<dyn Any + Send + Sync>>,
33}
34
35impl<T> RequestState<T> {
36    /// Create a new request state from RpcRequest
37    pub fn from_request(request: &RpcRequest, custom: T) -> Self {
38        Self {
39            headers: request.headers.clone(),
40            sent_at_unix_ms: request.sent_at_unix_ms,
41            interface_id: request.interface_id,
42            method_id: request.method_id,
43            custom,
44            extensions: HashMap::new(),
45        }
46    }
47
48    /// Set an extension value
49    pub fn set_extension<V: Any + Send + Sync>(&mut self, key: impl Into<String>, value: V) {
50        self.extensions.insert(key.into(), Box::new(value));
51    }
52
53    /// Get an extension value
54    pub fn get_extension<V: Any + Send + Sync>(&self, key: &str) -> Option<&V> {
55        self.extensions.get(key).and_then(|v| v.downcast_ref::<V>())
56    }
57
58    /// Get a mutable extension value
59    pub fn get_extension_mut<V: Any + Send + Sync>(&mut self, key: &str) -> Option<&mut V> {
60        self.extensions
61            .get_mut(key)
62            .and_then(|v| v.downcast_mut::<V>())
63    }
64
65    /// Get a header value by key (searches by header key ID)
66    pub fn get_header(&self, key: u32) -> Option<&HeaderEntry> {
67        self.headers.iter().find(|h| h.key == key)
68    }
69}
70
71impl RequestState<()> {
72    /// Create default request state with no custom data
73    pub fn new(request: &RpcRequest) -> Self {
74        Self::from_request(request, ())
75    }
76}
77
/// Middleware trait
///
/// Middlewares can inspect and modify request state before handlers execute.
/// They can also short-circuit the request by returning an error.
///
/// `T` is the custom state type carried in [`RequestState::custom`]. The
/// `Send + Sync` supertraits let implementations be shared behind an `Arc`,
/// which is how [`MiddlewareChain`] stores them.
#[async_trait]
pub trait Middleware<T>: Send + Sync {
    /// Process the request state before the handler executes
    ///
    /// `state` is mutable: changes made here are seen by middlewares that run
    /// later in the chain. `payload` is only borrowed immutably.
    ///
    /// Return Ok(()) to continue to the next middleware/handler.
    /// Return Err(_) to short-circuit and return an error response.
    async fn process(&self, state: &mut RequestState<T>, payload: &Bytes) -> Result<()>;
}
90
91/// Middleware chain
92///
93/// Executes middlewares in order before passing to the handler.
94pub struct MiddlewareChain<T> {
95    middlewares: Vec<Arc<dyn Middleware<T>>>,
96}
97
98impl<T> MiddlewareChain<T> {
99    /// Create a new empty middleware chain
100    pub fn new() -> Self {
101        Self {
102            middlewares: Vec::new(),
103        }
104    }
105
106    /// Add a middleware to the chain
107    pub fn add(&mut self, middleware: Arc<dyn Middleware<T>>) {
108        self.middlewares.push(middleware);
109    }
110
111    /// Process all middlewares in order
112    pub async fn process(&self, state: &mut RequestState<T>, payload: &Bytes) -> Result<()> {
113        for middleware in &self.middlewares {
114            middleware.process(state, payload).await?;
115        }
116        Ok(())
117    }
118
119    /// Check if chain is empty
120    pub fn is_empty(&self) -> bool {
121        self.middlewares.is_empty()
122    }
123}
124
125impl<T> Default for MiddlewareChain<T> {
126    fn default() -> Self {
127        Self::new()
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Extension key under which `RecordingMiddleware` records its label.
    const TRACE_KEY: &str = "trace";

    fn test_request() -> RpcRequest {
        RpcRequest {
            interface_id: 1,
            method_id: 2,
            headers: vec![HeaderEntry {
                key: 100,
                value: Some(synapse_proto::header_entry::Value::StringValue(
                    "test".to_string(),
                )),
            }],
            payload: Bytes::new(),
            sent_at_unix_ms: 12345,
        }
    }

    // ========== RequestState ==========

    #[test]
    fn test_request_state_from_request() {
        let req = test_request();
        let state = RequestState::from_request(&req, "custom");
        assert_eq!(state.interface_id, 1);
        assert_eq!(state.method_id, 2);
        assert_eq!(state.sent_at_unix_ms, 12345);
        assert_eq!(state.custom, "custom");
        assert_eq!(state.headers.len(), 1);
    }

    #[test]
    fn test_request_state_new_unit() {
        let req = test_request();
        // The annotation is the real check: it fails to compile unless
        // `new` produces `RequestState<()>`. (`assert_eq!` against `()`
        // always passes and is rejected by clippy's `unit_cmp` lint.)
        let state: RequestState<()> = RequestState::new(&req);
        assert_eq!(state.interface_id, 1);
    }

    #[test]
    fn test_extensions_set_and_get() {
        let req = test_request();
        let mut state = RequestState::new(&req);
        state.set_extension("user_id", 42u64);
        assert_eq!(state.get_extension::<u64>("user_id"), Some(&42));
    }

    #[test]
    fn test_extensions_wrong_type() {
        let req = test_request();
        let mut state = RequestState::new(&req);
        state.set_extension("user_id", 42u64);
        assert!(state.get_extension::<String>("user_id").is_none());
    }

    #[test]
    fn test_extensions_missing_key() {
        let req = test_request();
        let state = RequestState::new(&req);
        assert!(state.get_extension::<u64>("nonexistent").is_none());
    }

    #[test]
    fn test_extensions_get_mut() {
        let req = test_request();
        let mut state = RequestState::new(&req);
        state.set_extension("counter", 0u32);
        if let Some(val) = state.get_extension_mut::<u32>("counter") {
            *val = 5;
        }
        assert_eq!(state.get_extension::<u32>("counter"), Some(&5));
    }

    #[test]
    fn test_get_header_found() {
        let req = test_request();
        let state = RequestState::new(&req);
        assert_eq!(state.get_header(100).map(|h| h.key), Some(100));
    }

    #[test]
    fn test_get_header_not_found() {
        let req = test_request();
        let state = RequestState::new(&req);
        assert!(state.get_header(999).is_none());
    }

    // ========== MiddlewareChain ==========

    struct CountingMiddleware {
        counter: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Middleware<()> for CountingMiddleware {
        async fn process(&self, _state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    struct FailingMiddleware;

    #[async_trait]
    impl Middleware<()> for FailingMiddleware {
        async fn process(&self, _state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            Err(anyhow::anyhow!("middleware failed"))
        }
    }

    /// Appends `label` to the `Vec<&'static str>` extension stored under
    /// `TRACE_KEY`, so tests can observe the exact order middlewares ran in.
    struct RecordingMiddleware {
        label: &'static str,
    }

    #[async_trait]
    impl Middleware<()> for RecordingMiddleware {
        async fn process(&self, state: &mut RequestState<()>, _payload: &Bytes) -> Result<()> {
            let trace = state
                .get_extension_mut::<Vec<&'static str>>(TRACE_KEY)
                .ok_or_else(|| anyhow::anyhow!("trace extension missing"))?;
            trace.push(self.label);
            Ok(())
        }
    }

    #[test]
    fn test_chain_new_is_empty() {
        let chain = MiddlewareChain::<()>::new();
        assert!(chain.is_empty());
    }

    #[test]
    fn test_chain_default_is_empty() {
        let chain = MiddlewareChain::<()>::default();
        assert!(chain.is_empty());
    }

    #[test]
    fn test_chain_not_empty_after_add() {
        let mut chain = MiddlewareChain::<()>::new();
        chain.add(Arc::new(CountingMiddleware {
            counter: Arc::new(AtomicUsize::new(0)),
        }));
        assert!(!chain.is_empty());
    }

    #[tokio::test]
    async fn test_chain_executes_middleware() {
        let mut chain = MiddlewareChain::<()>::new();
        let counter = Arc::new(AtomicUsize::new(0));
        chain.add(Arc::new(CountingMiddleware {
            counter: counter.clone(),
        }));

        let req = test_request();
        let mut state = RequestState::new(&req);
        chain.process(&mut state, &Bytes::new()).await.unwrap();
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_chain_executes_in_order() {
        let mut chain = MiddlewareChain::<()>::new();
        for label in ["first", "second", "third"] {
            chain.add(Arc::new(RecordingMiddleware { label }));
        }

        let req = test_request();
        let mut state = RequestState::new(&req);
        state.set_extension(TRACE_KEY, Vec::<&'static str>::new());
        chain.process(&mut state, &Bytes::new()).await.unwrap();

        // Each middleware sees the state mutated by the ones before it, so the
        // recorded trace is exactly the registration order.
        assert_eq!(
            state.get_extension::<Vec<&'static str>>(TRACE_KEY),
            Some(&vec!["first", "second", "third"])
        );
    }

    #[tokio::test]
    async fn test_chain_short_circuits_on_error() {
        let mut chain = MiddlewareChain::<()>::new();
        let counter = Arc::new(AtomicUsize::new(0));
        chain.add(Arc::new(FailingMiddleware));
        chain.add(Arc::new(CountingMiddleware {
            counter: counter.clone(),
        }));

        let req = test_request();
        let mut state = RequestState::new(&req);
        let err = chain.process(&mut state, &Bytes::new()).await.unwrap_err();
        // The failing middleware's error is propagated unchanged.
        assert_eq!(err.to_string(), "middleware failed");
        // Second middleware should NOT have been called
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }
}