Skip to main content

braid_core/core/server/
middleware.rs

1//! Axum middleware for Braid protocol support.
2//!
3//! Provides an Axum layer that extracts Braid protocol information from incoming
4//! requests and makes it available to handlers.
5
6use super::resource_state::ResourceStateManager;
7use crate::core::protocol_mod as protocol;
8use crate::core::protocol_mod::constants::headers;
9use crate::core::Version;
10use axum::{extract::Request, middleware::Next, response::Response};
11use futures::StreamExt;
12use std::collections::BTreeMap;
13use std::sync::Arc;
14
15pub use crate::core::protocol_mod::BraidState;
16
/// Request extension recording whether the client's `User-Agent` mentions Firefox.
///
/// The Braid middleware inserts this into every request it passes on to the inner
/// service; the wrapped value is `true` when the lower-cased `User-Agent` header
/// contains `"firefox"`. `IsFirefox::default()` is `IsFirefox(false)`, the same value
/// produced for a request without a `User-Agent` header.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct IsFirefox(pub bool);
20
21/// Middleware handler function for use with `from_fn_with_state`.
22async fn braid_middleware_handler(
23    axum::extract::State(state): axum::extract::State<BraidLayer>,
24    req: Request,
25    next: Next,
26) -> Response {
27    state.handle_middleware(req, next).await
28}
29
30/// Axum middleware layer for Braid protocol support.
31#[derive(Clone)]
32pub struct BraidLayer {
33    config: super::config::ServerConfig,
34    pub resource_manager: Arc<ResourceStateManager>,
35    pub multiplexer_registry: Arc<super::multiplex::MultiplexerRegistry>,
36}
37
38impl BraidLayer {
39    /// Create a new Braid layer with default configuration.
40    #[must_use]
41    pub fn new() -> Self {
42        Self {
43            config: super::config::ServerConfig::default(),
44            resource_manager: Arc::new(ResourceStateManager::new()),
45            multiplexer_registry: Arc::new(super::multiplex::MultiplexerRegistry::new()),
46        }
47    }
48
49    /// Create a new Braid layer with custom configuration.
50    #[must_use]
51    pub fn with_config(config: super::config::ServerConfig) -> Self {
52        Self {
53            config,
54            resource_manager: Arc::new(ResourceStateManager::new()),
55            multiplexer_registry: Arc::new(super::multiplex::MultiplexerRegistry::new()),
56        }
57    }
58
59    /// Get a reference to the layer's configuration.
60    #[inline]
61    #[must_use]
62    pub fn config(&self) -> &super::config::ServerConfig {
63        &self.config
64    }
65
66    /// Create the middleware function for use with Axum.
67    #[must_use]
68    pub fn middleware(
69        &self,
70    ) -> impl tower::Layer<
71        axum::routing::Route,
72        Service = impl tower::Service<
73            Request,
74            Response = Response,
75            Error = std::convert::Infallible,
76            Future = impl Send + 'static,
77        > + Clone
78                      + Send
79                      + Sync
80                      + 'static,
81    > + Clone {
82        axum::middleware::from_fn_with_state(self.clone(), braid_middleware_handler)
83    }
84
85    async fn handle_middleware(&self, mut req: Request, next: Next) -> Response {
86        let resource_manager = self.resource_manager.clone();
87        let multiplexer_registry = self.multiplexer_registry.clone();
88
89        if req.method().as_str() == "MULTIPLEX" {
90            let version = req
91                .headers()
92                .get(headers::MULTIPLEX_VERSION)
93                .and_then(|v| v.to_str().ok())
94                .unwrap_or("1.0");
95
96            if version == "1.0" {
97                let (tx, mut rx) = tokio::sync::mpsc::channel(1024);
98                let id = format!("{:x}", rand::random::<u64>());
99
100                multiplexer_registry.add(id.clone(), tx).await;
101
102                let stream = async_stream::stream! {
103                    while let Some(data) = rx.recv().await {
104                        yield Ok::<_, std::io::Error>(axum::body::Bytes::from(data));
105                    }
106                };
107
108                let body = axum::body::Body::from_stream(stream);
109                return Response::builder()
110                    .status(200)
111                    .header(headers::MULTIPLEX_VERSION, "1.0")
112                    .body(body)
113                    .unwrap();
114            }
115        }
116
117        let braid_state = BraidState::from_headers(req.headers());
118        let multiplex_through = braid_state.multiplex_through.clone();
119        let m_registry = multiplexer_registry.clone();
120
121        let is_firefox = req
122            .headers()
123            .get("user-agent")
124            .and_then(|v| v.to_str().ok())
125            .map(|ua| ua.to_lowercase().contains("firefox"))
126            .unwrap_or(false);
127
128        req.extensions_mut().insert(Arc::new(braid_state));
129        req.extensions_mut().insert(resource_manager);
130        req.extensions_mut().insert(multiplexer_registry);
131        req.extensions_mut().insert(IsFirefox(is_firefox));
132
133        let mut response = next.run(req).await;
134
135        let headers = response.headers_mut();
136        headers.insert(
137            axum::http::header::HeaderName::from_static("range-request-allow-methods"),
138            axum::http::header::HeaderValue::from_static("PATCH, PUT"),
139        );
140        headers.insert(
141            axum::http::header::HeaderName::from_static("range-request-allow-units"),
142            axum::http::header::HeaderValue::from_static("json"),
143        );
144
145        if let Some(through) = multiplex_through {
146            let parts: Vec<&str> = through.split('/').collect();
147            if parts.len() >= 5 && parts[1] == ".well-known" && parts[2] == "multiplexer" {
148                let m_id = parts[3];
149                let r_id = parts[4];
150
151                if let Some(conn) = m_registry.get(m_id).await {
152                    let sender = conn.sender.clone();
153                    let r_id = r_id.to_string();
154
155                    let mut cors_headers = axum::http::HeaderMap::new();
156                    for (k, v) in response.headers() {
157                        let k_str = k.as_str();
158                        if k_str.starts_with("access-control-") {
159                            cors_headers.insert(k.clone(), v.clone());
160                        }
161                    }
162
163                    tokio::spawn(async move {
164                        let mut header_block =
165                            format!(":status: {}\r\n", response.status().as_u16());
166                        for (name, value) in response.headers() {
167                            header_block.push_str(&format!(
168                                "{}: {}\r\n",
169                                name,
170                                value.to_str().unwrap_or("")
171                            ));
172                        }
173                        header_block.push_str("\r\n");
174
175                        let start_evt = protocol::multiplex::MultiplexEvent::Data(
176                            r_id.clone(),
177                            header_block.clone().into_bytes(),
178                        );
179                        let _ = sender.send(start_evt.to_string().into_bytes()).await;
180                        let _ = sender.send(header_block.into_bytes()).await;
181
182                        let mut body_stream = response.into_body().into_data_stream();
183                        while let Some(Ok(chunk)) = body_stream.next().await {
184                            let data_evt = protocol::multiplex::MultiplexEvent::Data(
185                                r_id.clone(),
186                                chunk.to_vec(),
187                            );
188                            let _ = sender.send(data_evt.to_string().into_bytes()).await;
189                            let _ = sender.send(chunk.to_vec()).await;
190                        }
191
192                        let close_evt = protocol::multiplex::MultiplexEvent::CloseResponse(r_id);
193                        let _ = sender.send(close_evt.to_string().into_bytes()).await;
194                    });
195
196                    let mut builder = Response::builder()
197                        .status(293)
198                        .header(headers::MULTIPLEX_VERSION, "1.0");
199
200                    if let Some(headers) = builder.headers_mut() {
201                        headers.extend(cors_headers);
202                    }
203
204                    return builder.body(axum::body::Body::empty()).unwrap();
205                }
206            }
207        }
208
209        response
210    }
211}
212
213impl Default for BraidLayer {
214    fn default() -> Self {
215        Self::new()
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// A comma-separated list of three quoted versions parses into three entries.
    #[test]
    fn test_parse_version_header() {
        let parsed = protocol::parse_version_header("\"v1\", \"v2\", \"v3\"");
        let versions = parsed.expect("three quoted versions should parse");
        assert_eq!(versions.len(), 3);
    }
}