1use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::{future::Future, str::FromStr, sync::Arc};
5
6use http::{Method, Request, Response};
7
8use engineioxide_core::Sid;
9
10use crate::{
11 body::ResponseBody,
12 config::EngineIoConfig,
13 engine::EngineIo,
14 handler::EngineIoHandler,
15 service::futures::ResponseFuture,
16 transport::{polling, ws},
17};
18
19pub fn dispatch_req<F, H, ReqBody, ResBody>(
21 req: Request<ReqBody>,
22 engine: Arc<EngineIo<H>>,
23) -> ResponseFuture<F, ResBody>
24where
25 ReqBody: http_body::Body + Send + Unpin + 'static,
26 ReqBody::Data: Send,
27 ReqBody::Error: std::fmt::Debug,
28 ResBody: Send + 'static,
29 H: EngineIoHandler,
30 F: Future,
31{
32 match RequestInfo::parse(&req, &engine.config) {
33 Ok(RequestInfo {
34 protocol,
35 sid: None,
36 transport: TransportType::Polling,
37 method: Method::GET,
38 #[cfg(feature = "v3")]
39 b64,
40 }) => ResponseFuture::ready(polling::open_req(
41 engine,
42 protocol,
43 req,
44 #[cfg(feature = "v3")]
45 !b64,
46 )),
47 Ok(RequestInfo {
48 protocol,
49 sid: Some(sid),
50 transport: TransportType::Polling,
51 method: Method::GET,
52 ..
53 }) => ResponseFuture::async_response(Box::pin(polling::polling_req(engine, protocol, sid))),
54 Ok(RequestInfo {
55 protocol,
56 sid: Some(sid),
57 transport: TransportType::Polling,
58 method: Method::POST,
59 ..
60 }) => {
61 ResponseFuture::async_response(Box::pin(polling::post_req(engine, protocol, sid, req)))
62 }
63 Ok(RequestInfo {
64 protocol,
65 sid,
66 transport: TransportType::Websocket,
67 method: Method::GET,
68 ..
69 }) => ResponseFuture::ready(ws::new_req(engine, protocol, sid, req)),
70 Err(e) => {
71 #[cfg(feature = "tracing")]
72 tracing::debug!("error parsing request: {:?}", e);
73 ResponseFuture::ready(Ok(e.into()))
74 }
75 _req => {
76 #[cfg(feature = "tracing")]
77 tracing::debug!("invalid request: {:?}", _req);
78 ResponseFuture::empty_response(400)
79 }
80 }
81}
82
83#[derive(thiserror::Error, Debug)]
84pub enum ParseError {
85 #[error("transport unknown")]
86 UnknownTransport,
87 #[error("bad handshake method")]
88 BadHandshakeMethod,
89 #[error("transport mismatch")]
90 TransportMismatch,
91 #[error("unsupported protocol version")]
92 UnsupportedProtocolVersion,
93}
94
95impl<B> From<ParseError> for Response<ResponseBody<B>> {
99 fn from(err: ParseError) -> Self {
100 use ParseError::*;
101 let conn_err_resp = |message: &'static str| {
102 Response::builder()
103 .status(400)
104 .header("Content-Type", "application/json")
105 .body(ResponseBody::custom_response(message.into()))
106 .unwrap()
107 };
108 match err {
109 UnknownTransport => conn_err_resp("{\"code\":\"0\",\"message\":\"Transport unknown\"}"),
110 BadHandshakeMethod => {
111 conn_err_resp("{\"code\":\"2\",\"message\":\"Bad handshake method\"}")
112 }
113 TransportMismatch => conn_err_resp("{\"code\":\"3\",\"message\":\"Bad request\"}"),
114 UnsupportedProtocolVersion => {
115 conn_err_resp("{\"code\":\"5\",\"message\":\"Unsupported protocol version\"}")
116 }
117 }
118 }
119}
120
/// Engine.io protocol revision requested through the `EIO` query parameter.
///
/// The discriminants equal the revision numbers, so `version as u8` yields `3` or `4`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ProtocolVersion {
    /// Protocol revision 3 (only accepted by the parser when the `v3` feature is enabled).
    V3 = 3,
    /// Protocol revision 4.
    V4 = 4,
}
129
130impl FromStr for ProtocolVersion {
131 type Err = ParseError;
132
133 #[cfg(feature = "v3")]
134 fn from_str(s: &str) -> Result<Self, Self::Err> {
135 match s {
136 "3" => Ok(ProtocolVersion::V3),
137 "4" => Ok(ProtocolVersion::V4),
138 _ => Err(ParseError::UnsupportedProtocolVersion),
139 }
140 }
141
142 #[cfg(not(feature = "v3"))]
143 fn from_str(s: &str) -> Result<Self, Self::Err> {
144 match s {
145 "4" => Ok(ProtocolVersion::V4),
146 _ => Err(ParseError::UnsupportedProtocolVersion),
147 }
148 }
149}
150
/// Transport a client asks for through the `transport` query parameter.
///
/// The discriminants are distinct single bits (`0x01`, `0x02`). This is presumably so that
/// several transports can be combined into a `u8` mask (TODO: confirm against
/// `EngineIoConfig`).
#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
pub enum TransportType {
    /// The `polling` transport.
    Polling = 0x01,
    /// The `websocket` transport.
    Websocket = 0x02,
}
159
160impl Serialize for TransportType {
161 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
162 where
163 S: Serializer,
164 {
165 serializer.serialize_str((*self).into())
166 }
167}
168
169impl<'de> Deserialize<'de> for TransportType {
170 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
171 where
172 D: Deserializer<'de>,
173 {
174 let s = String::deserialize(deserializer)?;
175 Self::from_str(&s).map_err(serde::de::Error::custom)
176 }
177}
178
179impl From<u8> for TransportType {
180 fn from(t: u8) -> Self {
181 match t {
182 0x01 => TransportType::Polling,
183 0x02 => TransportType::Websocket,
184 _ => panic!("unknown transport type"),
185 }
186 }
187}
188
189impl FromStr for TransportType {
190 type Err = ParseError;
191
192 fn from_str(s: &str) -> Result<Self, Self::Err> {
193 match s {
194 "websocket" => Ok(TransportType::Websocket),
195 "polling" => Ok(TransportType::Polling),
196 _ => Err(ParseError::UnknownTransport),
197 }
198 }
199}
200impl From<TransportType> for &'static str {
201 fn from(t: TransportType) -> Self {
202 match t {
203 TransportType::Polling => "polling",
204 TransportType::Websocket => "websocket",
205 }
206 }
207}
208impl From<TransportType> for String {
209 fn from(t: TransportType) -> Self {
210 match t {
211 TransportType::Polling => "polling".into(),
212 TransportType::Websocket => "websocket".into(),
213 }
214 }
215}
216
217#[derive(Debug)]
219pub struct RequestInfo {
220 pub protocol: ProtocolVersion,
222 pub sid: Option<Sid>,
224 pub transport: TransportType,
226 pub method: Method,
228 #[cfg(feature = "v3")]
230 pub b64: bool,
231}
232
233impl RequestInfo {
234 fn parse<B>(req: &Request<B>, config: &EngineIoConfig) -> Result<Self, ParseError> {
236 use ParseError::*;
237 let query = req.uri().query().ok_or(UnknownTransport)?;
238
239 let protocol: ProtocolVersion = query
240 .split('&')
241 .find(|s| s.starts_with("EIO="))
242 .and_then(|s| s.split('=').nth(1))
243 .ok_or(UnsupportedProtocolVersion)
244 .and_then(|t| t.parse())?;
245
246 let sid = query
247 .split('&')
248 .find(|s| s.starts_with("sid="))
249 .and_then(|s| s.split('=').nth(1).map(|s1| s1.parse().ok()))
250 .flatten();
251
252 let transport: TransportType = query
253 .split('&')
254 .find(|s| s.starts_with("transport="))
255 .and_then(|s| s.split('=').nth(1))
256 .ok_or(UnknownTransport)
257 .and_then(|t| t.parse())?;
258
259 if !config.allowed_transport(transport) {
260 return Err(TransportMismatch);
261 }
262
263 #[cfg(feature = "v3")]
264 let b64: bool = query
265 .split('&')
266 .find(|s| s.starts_with("b64="))
267 .map(|_| true)
268 .unwrap_or_default();
269
270 let method = req.method().clone();
271 if !matches!(method, Method::GET) && sid.is_none() {
272 Err(BadHandshakeMethod)
273 } else {
274 Ok(RequestInfo {
275 protocol,
276 sid,
277 transport,
278 method,
279 #[cfg(feature = "v3")]
280 b64,
281 })
282 }
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    fn build_request(path: &str) -> Request<()> {
        Request::get(path).body(()).unwrap()
    }

    #[test]
    fn request_info_polling() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling");
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, None);
        assert_eq!(info.transport, TransportType::Polling);
        assert_eq!(info.protocol, ProtocolVersion::V4);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    fn request_info_websocket() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket");
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, None);
        assert_eq!(info.transport, TransportType::Websocket);
        assert_eq!(info.protocol, ProtocolVersion::V4);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    #[cfg(feature = "v3")]
    fn request_info_polling_with_sid() {
        let req = build_request(
            "http://localhost:3000/socket.io/?EIO=3&transport=polling&sid=AAAAAAAAAAAAAAHs",
        );
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap()));
        assert_eq!(info.transport, TransportType::Polling);
        assert_eq!(info.protocol, ProtocolVersion::V3);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    fn request_info_websocket_with_sid() {
        let req = build_request(
            "http://localhost:3000/socket.io/?EIO=4&transport=websocket&sid=AAAAAAAAAAAAAAHs",
        );
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap()));
        assert_eq!(info.transport, TransportType::Websocket);
        assert_eq!(info.protocol, ProtocolVersion::V4);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    #[cfg(feature = "v3")]
    fn request_info_polling_with_bin_by_default() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling");
        let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert!(!req.b64);
    }

    #[test]
    #[cfg(feature = "v3")]
    fn request_info_polling_withb64() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling&b64=1");
        let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert!(req.b64);
    }

    #[test]
    fn request_info_post_with_sid() {
        // A non-GET method is fine once a session id is present.
        let req = Request::post(
            "http://localhost:3000/socket.io/?EIO=4&transport=polling&sid=AAAAAAAAAAAAAAHs",
        )
        .body(())
        .unwrap();
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap()));
        assert_eq!(info.transport, TransportType::Polling);
        assert_eq!(info.method, Method::POST);
    }

    #[test]
    fn transport_unknown_err() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=grpc");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnknownTransport));
    }

    #[test]
    fn missing_query_is_unknown_transport() {
        let req = build_request("http://localhost:3000/socket.io/");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnknownTransport));
    }

    #[test]
    fn unsupported_protocol_version() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=2&transport=polling");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnsupportedProtocolVersion));
    }

    #[test]
    fn missing_protocol_version() {
        let req = build_request("http://localhost:3000/socket.io/?transport=polling");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnsupportedProtocolVersion));
    }

    #[test]
    fn bad_handshake_method() {
        let req = Request::post("http://localhost:3000/socket.io/?EIO=4&transport=polling")
            .body(())
            .unwrap();
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::BadHandshakeMethod));
    }

    #[test]
    fn bad_handshake_method_websocket() {
        let req = Request::post("http://localhost:3000/socket.io/?EIO=4&transport=websocket")
            .body(())
            .unwrap();
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::BadHandshakeMethod));
    }

    #[test]
    fn unsupported_transport() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling");
        let err = RequestInfo::parse(
            &req,
            &EngineIoConfig::builder()
                .transports([TransportType::Websocket])
                .build(),
        )
        .unwrap_err();

        assert!(matches!(err, ParseError::TransportMismatch));

        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket");
        let err = RequestInfo::parse(
            &req,
            &EngineIoConfig::builder()
                .transports([TransportType::Polling])
                .build(),
        )
        .unwrap_err();

        assert!(matches!(err, ParseError::TransportMismatch))
    }
}