routerify_websocket/
upgrade.rs

1use crate::{WebSocket, WebSocketConfig};
2use futures::future::{ok, Ready};
3use headers::{Connection, Header, SecWebsocketAccept, SecWebsocketKey, Upgrade};
4use hyper::{
5    body::HttpBody,
6    header::{self, HeaderValue},
7    Request, Response, StatusCode,
8};
9use routerify::ext::RequestExt;
10use std::future::Future;
11
12/// Upgrades the http requests to websocket with the provided [config](./struct.WebSocketConfig.html).
13///
14/// # Examples
15///
16/// ```no_run
17/// # use hyper::{Body, Response, Server};
18/// # use routerify::{Router, RouterService};
19/// # // Import websocket types.
20/// use routerify_websocket::{upgrade_ws_with_config, WebSocket, WebSocketConfig};
21/// # use std::{convert::Infallible, net::SocketAddr};
22///
23/// # // A handler for websocket connections.
24/// async fn ws_handler(ws: WebSocket) {
25///     println!("New websocket connection: {}", ws.remote_addr());
26///     // Handle websocket connection.
27/// }
28///
29/// fn router() -> Router<Body, Infallible> {
30///     // Create a router and specify the path and the handler for new websocket connections.
31///     Router::builder()
32///         // Upgrade the http requests at `/ws` path to websocket with the following config.
33///         .any_method("/ws", upgrade_ws_with_config(ws_handler, WebSocketConfig::default()))
34///         .build()
35///         .unwrap()
36/// }
37///
38/// # #[tokio::main]
39/// # async fn main() {
40/// #     let router = router();
41/// #
42/// #     // Create a Service from the router above to handle incoming requests.
43/// #     let service = RouterService::new(router).unwrap();
44/// #
45/// #     // The address on which the server will be listening.
46/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
47/// #
48/// #     // Create a server by passing the created service to `.serve` method.
49/// #      let server = Server::bind(&addr).serve(service);
50/// #
51/// #     println!("App is running on: {}", addr);
52/// #     if let Err(err) = server.await {
53/// #         eprintln!("Server error: {}", err);
54/// #     }
55/// # }
56/// ```
57pub fn upgrade_ws_with_config<H, R, B, E>(
58    handler: H,
59    config: WebSocketConfig,
60) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + Sync + 'static
61where
62    H: Fn(WebSocket) -> R + Copy + Send + Sync + 'static,
63    R: Future<Output = ()> + Send + 'static,
64    B: From<&'static str> + HttpBody + Send + 'static,
65    E: std::error::Error + Send + 'static,
66{
67    return move |req: Request<hyper::Body>| {
68        let sec_key = extract_upgradable_key(&req);
69        let remote_addr = req.remote_addr();
70
71        if sec_key.is_none() {
72            return ok(Response::builder()
73                .status(StatusCode::BAD_REQUEST)
74                .body("BAD REQUEST: The request is not websocket".into())
75                .unwrap());
76        }
77
78        tokio::spawn(async move {
79            match hyper::upgrade::on(req).await {
80                Ok(upgraded) => {
81                    handler(WebSocket::from_raw_socket(upgraded, remote_addr, config).await).await;
82                }
83                Err(err) => log::error!("{}", crate::WebsocketError::Upgrade(err.into())),
84            }
85        });
86
87        let resp = Response::builder()
88            .status(StatusCode::SWITCHING_PROTOCOLS)
89            .header(header::CONNECTION, encode_header(Connection::upgrade()))
90            .header(header::UPGRADE, encode_header(Upgrade::websocket()))
91            .header(
92                header::SEC_WEBSOCKET_ACCEPT,
93                encode_header(SecWebsocketAccept::from(sec_key.unwrap())),
94            )
95            .body("".into())
96            .unwrap();
97
98        ok(resp)
99    };
100}
101
102/// Upgrades the http requests to websocket.
103///
104/// # Examples
105///
106/// ```no_run
107/// # use hyper::{Body, Response, Server};
108/// # use routerify::{Router, RouterService};
109/// # // Import websocket types.
110/// use routerify_websocket::{upgrade_ws, WebSocket};
111/// # use std::{convert::Infallible, net::SocketAddr};
112///
113/// # // A handler for websocket connections.
114/// async fn ws_handler(ws: WebSocket) {
115///     println!("New websocket connection: {}", ws.remote_addr());
116///     // Handle websocket connection.
117/// }
118///
119/// fn router() -> Router<Body, Infallible> {
120///     // Create a router and specify the path and the handler for new websocket connections.
121///     Router::builder()
122///         // Upgrade the http requests at `/ws` path to websocket.
123///         .any_method("/ws", upgrade_ws(ws_handler))
124///         .build()
125///         .unwrap()
126/// }
127///
128/// # #[tokio::main]
129/// # async fn main() {
130/// #     let router = router();
131/// #
132/// #     // Create a Service from the router above to handle incoming requests.
133/// #     let service = RouterService::new(router).unwrap();
134/// #
135/// #     // The address on which the server will be listening.
136/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
137/// #
138/// #     // Create a server by passing the created service to `.serve` method.
139/// #      let server = Server::bind(&addr).serve(service);
140/// #
141/// #     println!("App is running on: {}", addr);
142/// #     if let Err(err) = server.await {
143/// #         eprintln!("Server error: {}", err);
144/// #     }
145/// # }
146/// ```
147pub fn upgrade_ws<H, R, B, E>(
148    handler: H,
149) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + Sync + 'static
150where
151    H: Fn(WebSocket) -> R + Copy + Send + Sync + 'static,
152    R: Future<Output = ()> + Send + 'static,
153    B: From<&'static str> + HttpBody + Send + 'static,
154    E: std::error::Error + Send + 'static,
155{
156    return upgrade_ws_with_config(handler, WebSocketConfig::default());
157}
158
159fn extract_upgradable_key(req: &Request<hyper::Body>) -> Option<SecWebsocketKey> {
160    let hdrs = req.headers();
161
162    hdrs.get(header::CONNECTION)
163        .and_then(|val| decode_header::<Connection>(val))
164        .and_then(|conn| some(conn.contains("upgrade")))
165        .and_then(|_| hdrs.get(header::UPGRADE))
166        .and_then(|val| val.to_str().ok())
167        .and_then(|val| some(val == "websocket"))
168        .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_VERSION))
169        .and_then(|val| val.to_str().ok())
170        .and_then(|val| some(val == "13"))
171        .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_KEY))
172        .and_then(|val| decode_header::<SecWebsocketKey>(val))
173}
174
175fn decode_header<T: Header>(val: &HeaderValue) -> Option<T> {
176    let values = [val];
177    let mut iter = (&values).into_iter().copied();
178    T::decode(&mut iter).ok()
179}
180
181fn encode_header<T: Header>(h: T) -> HeaderValue {
182    let mut val = Vec::with_capacity(1);
183    h.encode(&mut val);
184    val.into_iter().nth(0).unwrap()
185}
186
/// Turns a boolean check into an `Option<()>` step: `Some(())` when `cond` is true and `None`
/// when it is false. This lets a failed check short-circuit an `Option` chain.
fn some(cond: bool) -> Option<()> {
    match cond {
        true => Some(()),
        false => None,
    }
}