routerify_ws/
upgrade.rs

1use crate::{WebSocket, WebSocketConfig};
2use futures::future::{ok, Ready};
3use headers::{Connection, Header, SecWebsocketAccept, SecWebsocketKey, Upgrade};
4use hyper::{
5    body::HttpBody,
6    header::{self, HeaderValue},
7    Body, Request, Response, StatusCode,
8};
9use routerify::ext::RequestExt;
10use std::future::Future;
11
12/// Upgrades the http requests to websocket with the provided [config](./struct.WebSocketConfig.html).
13///
14/// # Examples
15///
16/// ```no_run
17/// # use hyper::{Body, Response, Server};
18/// # use routerify::{Router, RouterService};
19/// # // Import websocket types.
20/// use routerify_ws::{upgrade_ws_with_config, WebSocket, WebSocketConfig};
21/// # use std::{convert::Infallible, net::SocketAddr};
22///
23/// # // A handler for websocket connections.
24/// async fn ws_handler(ws: WebSocket) {
25///     println!("New websocket connection: {}", ws.remote_addr());
26///     // Handle websocket connection.
27/// }
28///
29/// fn router() -> Router<Body, Infallible> {
30///     // Create a router and specify the path and the handler for new websocket connections.
31///     Router::builder()
32///         // Upgrade the http requests at `/ws` path to websocket with the following config.
33///         .any_method("/ws", upgrade_ws_with_config(ws_handler, WebSocketConfig::default()))
34///         .build()
35///         .unwrap()
36/// }
37///
38/// # #[tokio::main]
39/// # async fn main() {
40/// #     let router = router();
41/// #
42/// #     // Create a Service from the router above to handle incoming requests.
43/// #     let service = RouterService::new(router).unwrap();
44/// #
45/// #     // The address on which the server will be listening.
46/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
47/// #
48/// #     // Create a server by passing the created service to `.serve` method.
49/// #      let server = Server::bind(&addr).serve(service);
50/// #
51/// #     println!("App is running on: {}", addr);
52/// #     if let Err(err) = server.await {
53/// #         eprintln!("Server error: {}", err);
54/// #     }
55/// # }
56/// ```
57pub fn upgrade_ws_with_config<H, R, B, E>(
58    handler: H,
59    config: WebSocketConfig,
60) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
61where
62    H: Fn(WebSocket) -> R + Clone + Send + Sync + 'static,
63    R: Future<Output = ()> + Send + 'static,
64    B: From<&'static str> + HttpBody + Send + 'static,
65    E: std::error::Error + Send + 'static,
66{
67    return upgrade_ws_with_config_and_req(move |_, w| handler(w), config);
68}
69/// Upgrades the http requests to websocket with the provided [config](./struct.WebSocketConfig.html) and adds request to the handler to be able to use extensions.
70///
71/// # Examples
72///
73/// ```no_run
74/// # use hyper::{Body, Response, Server,Request};
75/// # use routerify::{Router, RouterService};
76/// # // Import websocket types.
77/// use routerify_ws::{upgrade_ws_with_config_and_req, WebSocket, WebSocketConfig};
78/// # use std::{convert::Infallible, net::SocketAddr};
79///
80/// # // A handler for websocket connections.
81/// async fn ws_handler(req:Request<Body>,ws: WebSocket) {
82///
83///     println!("New websocket connection: {} {:?}", ws.remote_addr(),req.headers());
84///     // Handle websocket connection.
85/// }
86///
87/// fn router() -> Router<Body, Infallible> {
88///     // Create a router and specify the path and the handler for new websocket connections.
89///     Router::builder()
90///         // Upgrade the http requests at `/ws` path to websocket with the following config.
91///         .any_method("/ws", upgrade_ws_with_config_and_req(ws_handler, WebSocketConfig::default()))
92///         .build()
93///         .unwrap()
94/// }
95///
96/// # #[tokio::main]
97/// # async fn main() {
98/// #     let router = router();
99/// #
100/// #     // Create a Service from the router above to handle incoming requests.
101/// #     let service = RouterService::new(router).unwrap();
102/// #
103/// #     // The address on which the server will be listening.
104/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
105/// #
106/// #     // Create a server by passing the created service to `.serve` method.
107/// #      let server = Server::bind(&addr).serve(service);
108/// #
109/// #     println!("App is running on: {}", addr);
110/// #     if let Err(err) = server.await {
111/// #         eprintln!("Server error: {}", err);
112/// #     }
113/// # }
114/// ```
115pub fn upgrade_ws_with_config_and_req<H, R, B, E>(
116    handler: H,
117    config: WebSocketConfig,
118) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
119where
120    H: Fn(Request<Body>, WebSocket) -> R + Clone + Send + Sync + 'static,
121    R: Future<Output = ()> + Send + 'static,
122    B: From<&'static str> + HttpBody + Send + 'static,
123    E: std::error::Error + Send + 'static,
124{
125    return move |mut req: Request<hyper::Body>| {
126        let sec_key = extract_upgradable_key(&req);
127        let remote_addr = req.remote_addr();
128
129        if sec_key.is_none() {
130            return ok(Response::builder()
131                .status(StatusCode::BAD_REQUEST)
132                .body("BAD REQUEST: The request is not websocket".into())
133                .unwrap());
134        }
135        let handler = handler.clone();
136        tokio::spawn(async move {
137            match hyper::upgrade::on(&mut req).await {
138                Ok(upgraded) => {
139                    handler(req, WebSocket::from_raw_socket(upgraded, remote_addr, config).await).await;
140                }
141                Err(err) => log::error!("{}", crate::WebsocketError::Upgrade(err.into())),
142            }
143        });
144
145        let resp = Response::builder()
146            .status(StatusCode::SWITCHING_PROTOCOLS)
147            .header(header::CONNECTION, encode_header(Connection::upgrade()))
148            .header(header::UPGRADE, encode_header(Upgrade::websocket()))
149            .header(
150                header::SEC_WEBSOCKET_ACCEPT,
151                encode_header(SecWebsocketAccept::from(sec_key.unwrap())),
152            )
153            .body("".into())
154            .unwrap();
155
156        ok(resp)
157    };
158}
159
160/// Upgrades the http requests to websocket.
161///
162/// # Examples
163///
164/// ```no_run
165/// # use hyper::{Body, Response, Server};
166/// # use routerify::{Router, RouterService};
167/// # // Import websocket types.
168/// use routerify_ws::{upgrade_ws, WebSocket};
169/// # use std::{convert::Infallible, net::SocketAddr};
170///
171/// # // A handler for websocket connections.
172/// async fn ws_handler(ws: WebSocket) {
173///     println!("New websocket connection: {}", ws.remote_addr());
174///     // Handle websocket connection.
175/// }
176///
177/// fn router() -> Router<Body, Infallible> {
178///     // Create a router and specify the path and the handler for new websocket connections.
179///     Router::builder()
180///         // Upgrade the http requests at `/ws` path to websocket.
181///         .any_method("/ws", upgrade_ws(ws_handler))
182///         .build()
183///         .unwrap()
184/// }
185///
186/// # #[tokio::main]
187/// # async fn main() {
188/// #     let router = router();
189/// #
190/// #     // Create a Service from the router above to handle incoming requests.
191/// #     let service = RouterService::new(router).unwrap();
192/// #
193/// #     // The address on which the server will be listening.
194/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
195/// #
196/// #     // Create a server by passing the created service to `.serve` method.
197/// #      let server = Server::bind(&addr).serve(service);
198/// #
199/// #     println!("App is running on: {}", addr);
200/// #     if let Err(err) = server.await {
201/// #         eprintln!("Server error: {}", err);
202/// #     }
203/// # }
204/// ```
205pub fn upgrade_ws<H, R, B, E>(
206    handler: H,
207) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
208where
209    H: Fn(WebSocket) -> R + Clone + Send + Sync + 'static,
210    R: Future<Output = ()> + Send + 'static,
211    B: From<&'static str> + HttpBody + Send + 'static,
212    E: std::error::Error + Send + 'static,
213{
214    return upgrade_ws_with_req(move |_, w| handler(w));
215}
216/// Upgrades the http requests to websocket while still providing the request for accesing things
217/// like headers or extensions.
218///
219/// # Examples
220///
221/// ```no_run
222/// # use hyper::{Body, Response,Request, Server};
223/// # use routerify::{Router, RouterService};
224/// # // Import websocket types.
225/// use routerify_ws::{upgrade_ws_with_req, WebSocket};
226/// # use std::{convert::Infallible, net::SocketAddr};
227///
228/// # // A handler for websocket connections.
229/// async fn ws_handler(req:Request<Body>,ws: WebSocket) {
230///     println!("New websocket connection: {} {:?}", ws.remote_addr(),req.headers());
231///     // Handle websocket connection.
232/// }
233///
234/// fn router() -> Router<Body, Infallible> {
235///     // Create a router and specify the path and the handler for new websocket connections.
236///     Router::builder()
237///         // Upgrade the http requests at `/ws` path to websocket.
238///         .any_method("/ws", upgrade_ws_with_req(ws_handler))
239///         .build()
240///         .unwrap()
241/// }
242///
243/// # #[tokio::main]
244/// # async fn main() {
245/// #     let router = router();
246/// #
247/// #     // Create a Service from the router above to handle incoming requests.
248/// #     let service = RouterService::new(router).unwrap();
249/// #
250/// #     // The address on which the server will be listening.
251/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
252/// #
253/// #     // Create a server by passing the created service to `.serve` method.
254/// #      let server = Server::bind(&addr).serve(service);
255/// #
256/// #     println!("App is running on: {}", addr);
257/// #     if let Err(err) = server.await {
258/// #         eprintln!("Server error: {}", err);
259/// #     }
260/// # }
261/// ```
262pub fn upgrade_ws_with_req<H, R, B, E>(
263    handler: H,
264) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
265where
266    H: Fn(Request<Body>, WebSocket) -> R + Clone + Send + Sync + 'static,
267    R: Future<Output = ()> + Send + 'static,
268    B: From<&'static str> + HttpBody + Send + 'static,
269    E: std::error::Error + Send + 'static,
270{
271    return upgrade_ws_with_config_and_req(handler, WebSocketConfig::default());
272}
273
274fn extract_upgradable_key(req: &Request<hyper::Body>) -> Option<SecWebsocketKey> {
275    let hdrs = req.headers();
276
277    hdrs.get(header::CONNECTION)
278        .and_then(|val| decode_header::<Connection>(val))
279        .and_then(|conn| some(conn.contains("upgrade")))
280        .and_then(|_| hdrs.get(header::UPGRADE))
281        .and_then(|val| val.to_str().ok())
282        .and_then(|val| some(val == "websocket"))
283        .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_VERSION))
284        .and_then(|val| val.to_str().ok())
285        .and_then(|val| some(val == "13"))
286        .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_KEY))
287        .and_then(|val| decode_header::<SecWebsocketKey>(val))
288}
289
290fn decode_header<T: Header>(val: &HeaderValue) -> Option<T> {
291    let values = [val];
292    let mut iter = (&values).into_iter().copied();
293    T::decode(&mut iter).ok()
294}
295
296fn encode_header<T: Header>(h: T) -> HeaderValue {
297    let mut val = Vec::with_capacity(1);
298    h.encode(&mut val);
299    val.into_iter().nth(0).unwrap()
300}
301
/// Turns a condition into an `Option` gate: `Some(())` when `cond` holds, `None` otherwise.
///
/// This lets boolean checks short-circuit an `Option` chain via `?` or `and_then`.
fn some(cond: bool) -> Option<()> {
    if !cond {
        return None;
    }
    Some(())
}