routerify_ws/upgrade.rs
1use crate::{WebSocket, WebSocketConfig};
2use futures::future::{ok, Ready};
3use headers::{Connection, Header, SecWebsocketAccept, SecWebsocketKey, Upgrade};
4use hyper::{
5 body::HttpBody,
6 header::{self, HeaderValue},
7 Body, Request, Response, StatusCode,
8};
9use routerify::ext::RequestExt;
10use std::future::Future;
11
12/// Upgrades the http requests to websocket with the provided [config](./struct.WebSocketConfig.html).
13///
14/// # Examples
15///
16/// ```no_run
17/// # use hyper::{Body, Response, Server};
18/// # use routerify::{Router, RouterService};
19/// # // Import websocket types.
20/// use routerify_ws::{upgrade_ws_with_config, WebSocket, WebSocketConfig};
21/// # use std::{convert::Infallible, net::SocketAddr};
22///
23/// # // A handler for websocket connections.
24/// async fn ws_handler(ws: WebSocket) {
25/// println!("New websocket connection: {}", ws.remote_addr());
26/// // Handle websocket connection.
27/// }
28///
29/// fn router() -> Router<Body, Infallible> {
30/// // Create a router and specify the path and the handler for new websocket connections.
31/// Router::builder()
32/// // Upgrade the http requests at `/ws` path to websocket with the following config.
33/// .any_method("/ws", upgrade_ws_with_config(ws_handler, WebSocketConfig::default()))
34/// .build()
35/// .unwrap()
36/// }
37///
38/// # #[tokio::main]
39/// # async fn main() {
40/// # let router = router();
41/// #
42/// # // Create a Service from the router above to handle incoming requests.
43/// # let service = RouterService::new(router).unwrap();
44/// #
45/// # // The address on which the server will be listening.
46/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
47/// #
48/// # // Create a server by passing the created service to `.serve` method.
49/// # let server = Server::bind(&addr).serve(service);
50/// #
51/// # println!("App is running on: {}", addr);
52/// # if let Err(err) = server.await {
53/// # eprintln!("Server error: {}", err);
54/// # }
55/// # }
56/// ```
57pub fn upgrade_ws_with_config<H, R, B, E>(
58 handler: H,
59 config: WebSocketConfig,
60) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
61where
62 H: Fn(WebSocket) -> R + Clone + Send + Sync + 'static,
63 R: Future<Output = ()> + Send + 'static,
64 B: From<&'static str> + HttpBody + Send + 'static,
65 E: std::error::Error + Send + 'static,
66{
67 return upgrade_ws_with_config_and_req(move |_, w| handler(w), config);
68}
69/// Upgrades the http requests to websocket with the provided [config](./struct.WebSocketConfig.html) and adds request to the handler to be able to use extensions.
70///
71/// # Examples
72///
73/// ```no_run
74/// # use hyper::{Body, Response, Server,Request};
75/// # use routerify::{Router, RouterService};
76/// # // Import websocket types.
77/// use routerify_ws::{upgrade_ws_with_config_and_req, WebSocket, WebSocketConfig};
78/// # use std::{convert::Infallible, net::SocketAddr};
79///
80/// # // A handler for websocket connections.
81/// async fn ws_handler(req:Request<Body>,ws: WebSocket) {
82///
83/// println!("New websocket connection: {} {:?}", ws.remote_addr(),req.headers());
84/// // Handle websocket connection.
85/// }
86///
87/// fn router() -> Router<Body, Infallible> {
88/// // Create a router and specify the path and the handler for new websocket connections.
89/// Router::builder()
90/// // Upgrade the http requests at `/ws` path to websocket with the following config.
91/// .any_method("/ws", upgrade_ws_with_config_and_req(ws_handler, WebSocketConfig::default()))
92/// .build()
93/// .unwrap()
94/// }
95///
96/// # #[tokio::main]
97/// # async fn main() {
98/// # let router = router();
99/// #
100/// # // Create a Service from the router above to handle incoming requests.
101/// # let service = RouterService::new(router).unwrap();
102/// #
103/// # // The address on which the server will be listening.
104/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
105/// #
106/// # // Create a server by passing the created service to `.serve` method.
107/// # let server = Server::bind(&addr).serve(service);
108/// #
109/// # println!("App is running on: {}", addr);
110/// # if let Err(err) = server.await {
111/// # eprintln!("Server error: {}", err);
112/// # }
113/// # }
114/// ```
115pub fn upgrade_ws_with_config_and_req<H, R, B, E>(
116 handler: H,
117 config: WebSocketConfig,
118) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
119where
120 H: Fn(Request<Body>, WebSocket) -> R + Clone + Send + Sync + 'static,
121 R: Future<Output = ()> + Send + 'static,
122 B: From<&'static str> + HttpBody + Send + 'static,
123 E: std::error::Error + Send + 'static,
124{
125 return move |mut req: Request<hyper::Body>| {
126 let sec_key = extract_upgradable_key(&req);
127 let remote_addr = req.remote_addr();
128
129 if sec_key.is_none() {
130 return ok(Response::builder()
131 .status(StatusCode::BAD_REQUEST)
132 .body("BAD REQUEST: The request is not websocket".into())
133 .unwrap());
134 }
135 let handler = handler.clone();
136 tokio::spawn(async move {
137 match hyper::upgrade::on(&mut req).await {
138 Ok(upgraded) => {
139 handler(req, WebSocket::from_raw_socket(upgraded, remote_addr, config).await).await;
140 }
141 Err(err) => log::error!("{}", crate::WebsocketError::Upgrade(err.into())),
142 }
143 });
144
145 let resp = Response::builder()
146 .status(StatusCode::SWITCHING_PROTOCOLS)
147 .header(header::CONNECTION, encode_header(Connection::upgrade()))
148 .header(header::UPGRADE, encode_header(Upgrade::websocket()))
149 .header(
150 header::SEC_WEBSOCKET_ACCEPT,
151 encode_header(SecWebsocketAccept::from(sec_key.unwrap())),
152 )
153 .body("".into())
154 .unwrap();
155
156 ok(resp)
157 };
158}
159
160/// Upgrades the http requests to websocket.
161///
162/// # Examples
163///
164/// ```no_run
165/// # use hyper::{Body, Response, Server};
166/// # use routerify::{Router, RouterService};
167/// # // Import websocket types.
168/// use routerify_ws::{upgrade_ws, WebSocket};
169/// # use std::{convert::Infallible, net::SocketAddr};
170///
171/// # // A handler for websocket connections.
172/// async fn ws_handler(ws: WebSocket) {
173/// println!("New websocket connection: {}", ws.remote_addr());
174/// // Handle websocket connection.
175/// }
176///
177/// fn router() -> Router<Body, Infallible> {
178/// // Create a router and specify the path and the handler for new websocket connections.
179/// Router::builder()
180/// // Upgrade the http requests at `/ws` path to websocket.
181/// .any_method("/ws", upgrade_ws(ws_handler))
182/// .build()
183/// .unwrap()
184/// }
185///
186/// # #[tokio::main]
187/// # async fn main() {
188/// # let router = router();
189/// #
190/// # // Create a Service from the router above to handle incoming requests.
191/// # let service = RouterService::new(router).unwrap();
192/// #
193/// # // The address on which the server will be listening.
194/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
195/// #
196/// # // Create a server by passing the created service to `.serve` method.
197/// # let server = Server::bind(&addr).serve(service);
198/// #
199/// # println!("App is running on: {}", addr);
200/// # if let Err(err) = server.await {
201/// # eprintln!("Server error: {}", err);
202/// # }
203/// # }
204/// ```
205pub fn upgrade_ws<H, R, B, E>(
206 handler: H,
207) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
208where
209 H: Fn(WebSocket) -> R + Clone + Send + Sync + 'static,
210 R: Future<Output = ()> + Send + 'static,
211 B: From<&'static str> + HttpBody + Send + 'static,
212 E: std::error::Error + Send + 'static,
213{
214 return upgrade_ws_with_req(move |_, w| handler(w));
215}
216/// Upgrades the http requests to websocket while still providing the request for accesing things
217/// like headers or extensions.
218///
219/// # Examples
220///
221/// ```no_run
222/// # use hyper::{Body, Response,Request, Server};
223/// # use routerify::{Router, RouterService};
224/// # // Import websocket types.
225/// use routerify_ws::{upgrade_ws_with_req, WebSocket};
226/// # use std::{convert::Infallible, net::SocketAddr};
227///
228/// # // A handler for websocket connections.
229/// async fn ws_handler(req:Request<Body>,ws: WebSocket) {
230/// println!("New websocket connection: {} {:?}", ws.remote_addr(),req.headers());
231/// // Handle websocket connection.
232/// }
233///
234/// fn router() -> Router<Body, Infallible> {
235/// // Create a router and specify the path and the handler for new websocket connections.
236/// Router::builder()
237/// // Upgrade the http requests at `/ws` path to websocket.
238/// .any_method("/ws", upgrade_ws_with_req(ws_handler))
239/// .build()
240/// .unwrap()
241/// }
242///
243/// # #[tokio::main]
244/// # async fn main() {
245/// # let router = router();
246/// #
247/// # // Create a Service from the router above to handle incoming requests.
248/// # let service = RouterService::new(router).unwrap();
249/// #
250/// # // The address on which the server will be listening.
251/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
252/// #
253/// # // Create a server by passing the created service to `.serve` method.
254/// # let server = Server::bind(&addr).serve(service);
255/// #
256/// # println!("App is running on: {}", addr);
257/// # if let Err(err) = server.await {
258/// # eprintln!("Server error: {}", err);
259/// # }
260/// # }
261/// ```
262pub fn upgrade_ws_with_req<H, R, B, E>(
263 handler: H,
264) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + 'static
265where
266 H: Fn(Request<Body>, WebSocket) -> R + Clone + Send + Sync + 'static,
267 R: Future<Output = ()> + Send + 'static,
268 B: From<&'static str> + HttpBody + Send + 'static,
269 E: std::error::Error + Send + 'static,
270{
271 return upgrade_ws_with_config_and_req(handler, WebSocketConfig::default());
272}
273
274fn extract_upgradable_key(req: &Request<hyper::Body>) -> Option<SecWebsocketKey> {
275 let hdrs = req.headers();
276
277 hdrs.get(header::CONNECTION)
278 .and_then(|val| decode_header::<Connection>(val))
279 .and_then(|conn| some(conn.contains("upgrade")))
280 .and_then(|_| hdrs.get(header::UPGRADE))
281 .and_then(|val| val.to_str().ok())
282 .and_then(|val| some(val == "websocket"))
283 .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_VERSION))
284 .and_then(|val| val.to_str().ok())
285 .and_then(|val| some(val == "13"))
286 .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_KEY))
287 .and_then(|val| decode_header::<SecWebsocketKey>(val))
288}
289
290fn decode_header<T: Header>(val: &HeaderValue) -> Option<T> {
291 let values = [val];
292 let mut iter = (&values).into_iter().copied();
293 T::decode(&mut iter).ok()
294}
295
296fn encode_header<T: Header>(h: T) -> HeaderValue {
297 let mut val = Vec::with_capacity(1);
298 h.encode(&mut val);
299 val.into_iter().nth(0).unwrap()
300}
301
/// Converts a boolean condition into `Option<()>`: `true` becomes `Some(())` and `false`
/// becomes `None`. This lets a boolean check short-circuit an `Option` chain or `?` sequence.
fn some(cond: bool) -> Option<()> {
    if !cond {
        return None;
    }
    Some(())
}