engineioxide/service/
parser.rs

1//! A Parser module to parse any `EngineIo` query
2
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use std::{future::Future, str::FromStr, sync::Arc};
5
6use http::{Method, Request, Response};
7
8use engineioxide_core::Sid;
9
10use crate::{
11    body::ResponseBody,
12    config::EngineIoConfig,
13    engine::EngineIo,
14    handler::EngineIoHandler,
15    service::futures::ResponseFuture,
16    transport::{polling, ws},
17};
18
19/// Dispatch a request according to the [`RequestInfo`] to the appropriate [`transport`](crate::transport).
20pub fn dispatch_req<F, H, ReqBody, ResBody>(
21    req: Request<ReqBody>,
22    engine: Arc<EngineIo<H>>,
23) -> ResponseFuture<F, ResBody>
24where
25    ReqBody: http_body::Body + Send + Unpin + 'static,
26    ReqBody::Data: Send,
27    ReqBody::Error: std::fmt::Debug,
28    ResBody: Send + 'static,
29    H: EngineIoHandler,
30    F: Future,
31{
32    match RequestInfo::parse(&req, &engine.config) {
33        Ok(RequestInfo {
34            protocol,
35            sid: None,
36            transport: TransportType::Polling,
37            method: Method::GET,
38            #[cfg(feature = "v3")]
39            b64,
40        }) => ResponseFuture::ready(polling::open_req(
41            engine,
42            protocol,
43            req,
44            #[cfg(feature = "v3")]
45            !b64,
46        )),
47        Ok(RequestInfo {
48            protocol,
49            sid: Some(sid),
50            transport: TransportType::Polling,
51            method: Method::GET,
52            ..
53        }) => ResponseFuture::async_response(Box::pin(polling::polling_req(engine, protocol, sid))),
54        Ok(RequestInfo {
55            protocol,
56            sid: Some(sid),
57            transport: TransportType::Polling,
58            method: Method::POST,
59            ..
60        }) => {
61            ResponseFuture::async_response(Box::pin(polling::post_req(engine, protocol, sid, req)))
62        }
63        Ok(RequestInfo {
64            protocol,
65            sid,
66            transport: TransportType::Websocket,
67            method: Method::GET,
68            ..
69        }) => ResponseFuture::ready(ws::new_req(engine, protocol, sid, req)),
70        Err(e) => {
71            #[cfg(feature = "tracing")]
72            tracing::debug!("error parsing request: {:?}", e);
73            ResponseFuture::ready(Ok(e.into()))
74        }
75        _req => {
76            #[cfg(feature = "tracing")]
77            tracing::debug!("invalid request: {:?}", _req);
78            ResponseFuture::empty_response(400)
79        }
80    }
81}
82
83#[derive(thiserror::Error, Debug)]
84pub enum ParseError {
85    #[error("transport unknown")]
86    UnknownTransport,
87    #[error("bad handshake method")]
88    BadHandshakeMethod,
89    #[error("transport mismatch")]
90    TransportMismatch,
91    #[error("unsupported protocol version")]
92    UnsupportedProtocolVersion,
93}
94
95/// Convert an error into an http response
96/// If it is a known error, return the appropriate http status code
97/// Otherwise, return a 500
98impl<B> From<ParseError> for Response<ResponseBody<B>> {
99    fn from(err: ParseError) -> Self {
100        use ParseError::*;
101        let conn_err_resp = |message: &'static str| {
102            Response::builder()
103                .status(400)
104                .header("Content-Type", "application/json")
105                .body(ResponseBody::custom_response(message.into()))
106                .unwrap()
107        };
108        match err {
109            UnknownTransport => conn_err_resp("{\"code\":\"0\",\"message\":\"Transport unknown\"}"),
110            BadHandshakeMethod => {
111                conn_err_resp("{\"code\":\"2\",\"message\":\"Bad handshake method\"}")
112            }
113            TransportMismatch => conn_err_resp("{\"code\":\"3\",\"message\":\"Bad request\"}"),
114            UnsupportedProtocolVersion => {
115                conn_err_resp("{\"code\":\"5\",\"message\":\"Unsupported protocol version\"}")
116            }
117        }
118    }
119}
120
/// The engine.io protocol version, as sent by the client in the `EIO` query parameter.
///
/// The discriminant equals the version number, e.g. `ProtocolVersion::V4 as u8 == 4`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ProtocolVersion {
    /// The protocol version 3
    V3 = 3,
    /// The protocol version 4
    V4 = 4,
}
129
130impl FromStr for ProtocolVersion {
131    type Err = ParseError;
132
133    #[cfg(feature = "v3")]
134    fn from_str(s: &str) -> Result<Self, Self::Err> {
135        match s {
136            "3" => Ok(ProtocolVersion::V3),
137            "4" => Ok(ProtocolVersion::V4),
138            _ => Err(ParseError::UnsupportedProtocolVersion),
139        }
140    }
141
142    #[cfg(not(feature = "v3"))]
143    fn from_str(s: &str) -> Result<Self, Self::Err> {
144        match s {
145            "4" => Ok(ProtocolVersion::V4),
146            _ => Err(ParseError::UnsupportedProtocolVersion),
147        }
148    }
149}
150
/// The type of `transport` used by the client.
///
/// Discriminants are the single-bit values `0x01` and `0x02`; the `From<u8>` conversion
/// maps them back to variants.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum TransportType {
    /// Polling transport
    Polling = 0x01,
    /// Websocket transport
    Websocket = 0x02,
}
159
160impl Serialize for TransportType {
161    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
162    where
163        S: Serializer,
164    {
165        serializer.serialize_str((*self).into())
166    }
167}
168
169impl<'de> Deserialize<'de> for TransportType {
170    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
171    where
172        D: Deserializer<'de>,
173    {
174        let s = String::deserialize(deserializer)?;
175        Self::from_str(&s).map_err(serde::de::Error::custom)
176    }
177}
178
179impl From<u8> for TransportType {
180    fn from(t: u8) -> Self {
181        match t {
182            0x01 => TransportType::Polling,
183            0x02 => TransportType::Websocket,
184            _ => panic!("unknown transport type"),
185        }
186    }
187}
188
189impl FromStr for TransportType {
190    type Err = ParseError;
191
192    fn from_str(s: &str) -> Result<Self, Self::Err> {
193        match s {
194            "websocket" => Ok(TransportType::Websocket),
195            "polling" => Ok(TransportType::Polling),
196            _ => Err(ParseError::UnknownTransport),
197        }
198    }
199}
200impl From<TransportType> for &'static str {
201    fn from(t: TransportType) -> Self {
202        match t {
203            TransportType::Polling => "polling",
204            TransportType::Websocket => "websocket",
205        }
206    }
207}
208impl From<TransportType> for String {
209    fn from(t: TransportType) -> Self {
210        match t {
211            TransportType::Polling => "polling".into(),
212            TransportType::Websocket => "websocket".into(),
213        }
214    }
215}
216
217/// The request information extracted from the request URI.
218#[derive(Debug)]
219pub struct RequestInfo {
220    /// The protocol version used by the client.
221    pub protocol: ProtocolVersion,
222    /// The socket id if present in the request.
223    pub sid: Option<Sid>,
224    /// The transport type used by the client.
225    pub transport: TransportType,
226    /// The request method.
227    pub method: Method,
228    /// If the client asked for base64 encoding only.
229    #[cfg(feature = "v3")]
230    pub b64: bool,
231}
232
233impl RequestInfo {
234    /// Parse the request URI to extract the [`TransportType`](crate::service::TransportType) and the socket id.
235    fn parse<B>(req: &Request<B>, config: &EngineIoConfig) -> Result<Self, ParseError> {
236        use ParseError::*;
237        let query = req.uri().query().ok_or(UnknownTransport)?;
238
239        let protocol: ProtocolVersion = query
240            .split('&')
241            .find(|s| s.starts_with("EIO="))
242            .and_then(|s| s.split('=').nth(1))
243            .ok_or(UnsupportedProtocolVersion)
244            .and_then(|t| t.parse())?;
245
246        let sid = query
247            .split('&')
248            .find(|s| s.starts_with("sid="))
249            .and_then(|s| s.split('=').nth(1).map(|s1| s1.parse().ok()))
250            .flatten();
251
252        let transport: TransportType = query
253            .split('&')
254            .find(|s| s.starts_with("transport="))
255            .and_then(|s| s.split('=').nth(1))
256            .ok_or(UnknownTransport)
257            .and_then(|t| t.parse())?;
258
259        if !config.allowed_transport(transport) {
260            return Err(TransportMismatch);
261        }
262
263        #[cfg(feature = "v3")]
264        let b64: bool = query
265            .split('&')
266            .find(|s| s.starts_with("b64="))
267            .map(|_| true)
268            .unwrap_or_default();
269
270        let method = req.method().clone();
271        if !matches!(method, Method::GET) && sid.is_none() {
272            Err(BadHandshakeMethod)
273        } else {
274            Ok(RequestInfo {
275                protocol,
276                sid,
277                transport,
278                method,
279                #[cfg(feature = "v3")]
280                b64,
281            })
282        }
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bodiless `GET` request for `path`.
    fn build_request(path: &str) -> Request<()> {
        Request::get(path).body(()).unwrap()
    }

    #[test]
    fn request_info_polling() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling");
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, None);
        assert_eq!(info.transport, TransportType::Polling);
        assert_eq!(info.protocol, ProtocolVersion::V4);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    fn request_info_websocket() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket");
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, None);
        assert_eq!(info.transport, TransportType::Websocket);
        assert_eq!(info.protocol, ProtocolVersion::V4);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    #[cfg(feature = "v3")]
    fn request_info_polling_with_sid() {
        let req = build_request(
            "http://localhost:3000/socket.io/?EIO=3&transport=polling&sid=AAAAAAAAAAAAAAHs",
        );
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap()));
        assert_eq!(info.transport, TransportType::Polling);
        assert_eq!(info.protocol, ProtocolVersion::V3);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    fn request_info_websocket_with_sid() {
        let req = build_request(
            "http://localhost:3000/socket.io/?EIO=4&transport=websocket&sid=AAAAAAAAAAAAAAHs",
        );
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap()));
        assert_eq!(info.transport, TransportType::Websocket);
        assert_eq!(info.protocol, ProtocolVersion::V4);
        assert_eq!(info.method, Method::GET);
    }

    #[test]
    fn request_info_polling_post_with_sid() {
        let req = Request::post(
            "http://localhost:3000/socket.io/?EIO=4&transport=polling&sid=AAAAAAAAAAAAAAHs",
        )
        .body(())
        .unwrap();
        let info = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert_eq!(info.sid, Some("AAAAAAAAAAAAAAHs".parse().unwrap()));
        assert_eq!(info.transport, TransportType::Polling);
        assert_eq!(info.protocol, ProtocolVersion::V4);
        assert_eq!(info.method, Method::POST);
    }

    #[test]
    #[cfg(feature = "v3")]
    fn request_info_polling_with_bin_by_default() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling");
        let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert!(!req.b64);
    }

    #[test]
    #[cfg(feature = "v3")]
    fn request_info_polling_withb64() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=3&transport=polling&b64=1");
        let req = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap();
        assert!(req.b64);
    }

    #[test]
    fn transport_unknown_err() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=grpc");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnknownTransport));
    }

    #[test]
    fn missing_transport_err() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnknownTransport));
    }

    #[test]
    fn missing_query_err() {
        let req = build_request("http://localhost:3000/socket.io/");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnknownTransport));
    }

    #[test]
    fn unsupported_protocol_version() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=2&transport=polling");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnsupportedProtocolVersion));
    }

    #[test]
    fn missing_protocol_version() {
        let req = build_request("http://localhost:3000/socket.io/?transport=polling");
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::UnsupportedProtocolVersion));
    }

    #[test]
    fn bad_handshake_method() {
        let req = Request::post("http://localhost:3000/socket.io/?EIO=4&transport=polling")
            .body(())
            .unwrap();
        let err = RequestInfo::parse(&req, &EngineIoConfig::default()).unwrap_err();
        assert!(matches!(err, ParseError::BadHandshakeMethod));
    }

    #[test]
    fn unsupported_transport() {
        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=polling");
        let err = RequestInfo::parse(
            &req,
            &EngineIoConfig::builder()
                .transports([TransportType::Websocket])
                .build(),
        )
        .unwrap_err();

        assert!(matches!(err, ParseError::TransportMismatch));

        let req = build_request("http://localhost:3000/socket.io/?EIO=4&transport=websocket");
        let err = RequestInfo::parse(
            &req,
            &EngineIoConfig::builder()
                .transports([TransportType::Polling])
                .build(),
        )
        .unwrap_err();

        assert!(matches!(err, ParseError::TransportMismatch))
    }
}