braid_core/core/server/
middleware.rs1use super::resource_state::ResourceStateManager;
7use crate::core::protocol_mod as protocol;
8use crate::core::protocol_mod::constants::headers;
9use crate::core::Version;
10use axum::{extract::Request, middleware::Next, response::Response};
11use futures::StreamExt;
12use std::collections::BTreeMap;
13use std::sync::Arc;
14
15pub use crate::core::protocol_mod::BraidState;
16
/// Request extension recording whether the client's `User-Agent` header
/// contains "firefox" (case-insensitively); `false` when the header is absent
/// or not valid visible ASCII.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct IsFirefox(pub bool);
20
21async fn braid_middleware_handler(
23 axum::extract::State(state): axum::extract::State<BraidLayer>,
24 req: Request,
25 next: Next,
26) -> Response {
27 state.handle_middleware(req, next).await
28}
29
30#[derive(Clone)]
32pub struct BraidLayer {
33 config: super::config::ServerConfig,
34 pub resource_manager: Arc<ResourceStateManager>,
35 pub multiplexer_registry: Arc<super::multiplex::MultiplexerRegistry>,
36}
37
38impl BraidLayer {
39 #[must_use]
41 pub fn new() -> Self {
42 Self {
43 config: super::config::ServerConfig::default(),
44 resource_manager: Arc::new(ResourceStateManager::new()),
45 multiplexer_registry: Arc::new(super::multiplex::MultiplexerRegistry::new()),
46 }
47 }
48
49 #[must_use]
51 pub fn with_config(config: super::config::ServerConfig) -> Self {
52 Self {
53 config,
54 resource_manager: Arc::new(ResourceStateManager::new()),
55 multiplexer_registry: Arc::new(super::multiplex::MultiplexerRegistry::new()),
56 }
57 }
58
59 #[inline]
61 #[must_use]
62 pub fn config(&self) -> &super::config::ServerConfig {
63 &self.config
64 }
65
66 #[must_use]
68 pub fn middleware(
69 &self,
70 ) -> impl tower::Layer<
71 axum::routing::Route,
72 Service = impl tower::Service<
73 Request,
74 Response = Response,
75 Error = std::convert::Infallible,
76 Future = impl Send + 'static,
77 > + Clone
78 + Send
79 + Sync
80 + 'static,
81 > + Clone {
82 axum::middleware::from_fn_with_state(self.clone(), braid_middleware_handler)
83 }
84
85 async fn handle_middleware(&self, mut req: Request, next: Next) -> Response {
86 let resource_manager = self.resource_manager.clone();
87 let multiplexer_registry = self.multiplexer_registry.clone();
88
89 if req.method().as_str() == "MULTIPLEX" {
90 let version = req
91 .headers()
92 .get(headers::MULTIPLEX_VERSION)
93 .and_then(|v| v.to_str().ok())
94 .unwrap_or("1.0");
95
96 if version == "1.0" {
97 let (tx, mut rx) = tokio::sync::mpsc::channel(1024);
98 let id = format!("{:x}", rand::random::<u64>());
99
100 multiplexer_registry.add(id.clone(), tx).await;
101
102 let stream = async_stream::stream! {
103 while let Some(data) = rx.recv().await {
104 yield Ok::<_, std::io::Error>(axum::body::Bytes::from(data));
105 }
106 };
107
108 let body = axum::body::Body::from_stream(stream);
109 return Response::builder()
110 .status(200)
111 .header(headers::MULTIPLEX_VERSION, "1.0")
112 .body(body)
113 .unwrap();
114 }
115 }
116
117 let braid_state = BraidState::from_headers(req.headers());
118 let multiplex_through = braid_state.multiplex_through.clone();
119 let m_registry = multiplexer_registry.clone();
120
121 let is_firefox = req
122 .headers()
123 .get("user-agent")
124 .and_then(|v| v.to_str().ok())
125 .map(|ua| ua.to_lowercase().contains("firefox"))
126 .unwrap_or(false);
127
128 req.extensions_mut().insert(Arc::new(braid_state));
129 req.extensions_mut().insert(resource_manager);
130 req.extensions_mut().insert(multiplexer_registry);
131 req.extensions_mut().insert(IsFirefox(is_firefox));
132
133 let mut response = next.run(req).await;
134
135 let headers = response.headers_mut();
136 headers.insert(
137 axum::http::header::HeaderName::from_static("range-request-allow-methods"),
138 axum::http::header::HeaderValue::from_static("PATCH, PUT"),
139 );
140 headers.insert(
141 axum::http::header::HeaderName::from_static("range-request-allow-units"),
142 axum::http::header::HeaderValue::from_static("json"),
143 );
144
145 if let Some(through) = multiplex_through {
146 let parts: Vec<&str> = through.split('/').collect();
147 if parts.len() >= 5 && parts[1] == ".well-known" && parts[2] == "multiplexer" {
148 let m_id = parts[3];
149 let r_id = parts[4];
150
151 if let Some(conn) = m_registry.get(m_id).await {
152 let sender = conn.sender.clone();
153 let r_id = r_id.to_string();
154
155 let mut cors_headers = axum::http::HeaderMap::new();
156 for (k, v) in response.headers() {
157 let k_str = k.as_str();
158 if k_str.starts_with("access-control-") {
159 cors_headers.insert(k.clone(), v.clone());
160 }
161 }
162
163 tokio::spawn(async move {
164 let mut header_block =
165 format!(":status: {}\r\n", response.status().as_u16());
166 for (name, value) in response.headers() {
167 header_block.push_str(&format!(
168 "{}: {}\r\n",
169 name,
170 value.to_str().unwrap_or("")
171 ));
172 }
173 header_block.push_str("\r\n");
174
175 let start_evt = protocol::multiplex::MultiplexEvent::Data(
176 r_id.clone(),
177 header_block.clone().into_bytes(),
178 );
179 let _ = sender.send(start_evt.to_string().into_bytes()).await;
180 let _ = sender.send(header_block.into_bytes()).await;
181
182 let mut body_stream = response.into_body().into_data_stream();
183 while let Some(Ok(chunk)) = body_stream.next().await {
184 let data_evt = protocol::multiplex::MultiplexEvent::Data(
185 r_id.clone(),
186 chunk.to_vec(),
187 );
188 let _ = sender.send(data_evt.to_string().into_bytes()).await;
189 let _ = sender.send(chunk.to_vec()).await;
190 }
191
192 let close_evt = protocol::multiplex::MultiplexEvent::CloseResponse(r_id);
193 let _ = sender.send(close_evt.to_string().into_bytes()).await;
194 });
195
196 let mut builder = Response::builder()
197 .status(293)
198 .header(headers::MULTIPLEX_VERSION, "1.0");
199
200 if let Some(headers) = builder.headers_mut() {
201 headers.extend(cors_headers);
202 }
203
204 return builder.body(axum::body::Body::empty()).unwrap();
205 }
206 }
207 }
208
209 response
210 }
211}
212
213impl Default for BraidLayer {
214 fn default() -> Self {
215 Self::new()
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// A header listing three quoted versions parses into three entries.
    #[test]
    fn test_parse_version_header() {
        let header = "\"v1\", \"v2\", \"v3\"";
        let versions =
            protocol::parse_version_header(header).expect("version header should parse");
        assert_eq!(versions.len(), 3);
    }
}
229}