routerify_websocket/upgrade.rs
1use crate::{WebSocket, WebSocketConfig};
2use futures::future::{ok, Ready};
3use headers::{Connection, Header, SecWebsocketAccept, SecWebsocketKey, Upgrade};
4use hyper::{
5 body::HttpBody,
6 header::{self, HeaderValue},
7 Request, Response, StatusCode,
8};
9use routerify::ext::RequestExt;
10use std::future::Future;
11
12/// Upgrades the http requests to websocket with the provided [config](./struct.WebSocketConfig.html).
13///
14/// # Examples
15///
16/// ```no_run
17/// # use hyper::{Body, Response, Server};
18/// # use routerify::{Router, RouterService};
19/// # // Import websocket types.
20/// use routerify_websocket::{upgrade_ws_with_config, WebSocket, WebSocketConfig};
21/// # use std::{convert::Infallible, net::SocketAddr};
22///
23/// # // A handler for websocket connections.
24/// async fn ws_handler(ws: WebSocket) {
25/// println!("New websocket connection: {}", ws.remote_addr());
26/// // Handle websocket connection.
27/// }
28///
29/// fn router() -> Router<Body, Infallible> {
30/// // Create a router and specify the path and the handler for new websocket connections.
31/// Router::builder()
32/// // Upgrade the http requests at `/ws` path to websocket with the following config.
33/// .any_method("/ws", upgrade_ws_with_config(ws_handler, WebSocketConfig::default()))
34/// .build()
35/// .unwrap()
36/// }
37///
38/// # #[tokio::main]
39/// # async fn main() {
40/// # let router = router();
41/// #
42/// # // Create a Service from the router above to handle incoming requests.
43/// # let service = RouterService::new(router).unwrap();
44/// #
45/// # // The address on which the server will be listening.
46/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
47/// #
48/// # // Create a server by passing the created service to `.serve` method.
49/// # let server = Server::bind(&addr).serve(service);
50/// #
51/// # println!("App is running on: {}", addr);
52/// # if let Err(err) = server.await {
53/// # eprintln!("Server error: {}", err);
54/// # }
55/// # }
56/// ```
57pub fn upgrade_ws_with_config<H, R, B, E>(
58 handler: H,
59 config: WebSocketConfig,
60) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + Sync + 'static
61where
62 H: Fn(WebSocket) -> R + Copy + Send + Sync + 'static,
63 R: Future<Output = ()> + Send + 'static,
64 B: From<&'static str> + HttpBody + Send + 'static,
65 E: std::error::Error + Send + 'static,
66{
67 return move |req: Request<hyper::Body>| {
68 let sec_key = extract_upgradable_key(&req);
69 let remote_addr = req.remote_addr();
70
71 if sec_key.is_none() {
72 return ok(Response::builder()
73 .status(StatusCode::BAD_REQUEST)
74 .body("BAD REQUEST: The request is not websocket".into())
75 .unwrap());
76 }
77
78 tokio::spawn(async move {
79 match hyper::upgrade::on(req).await {
80 Ok(upgraded) => {
81 handler(WebSocket::from_raw_socket(upgraded, remote_addr, config).await).await;
82 }
83 Err(err) => log::error!("{}", crate::WebsocketError::Upgrade(err.into())),
84 }
85 });
86
87 let resp = Response::builder()
88 .status(StatusCode::SWITCHING_PROTOCOLS)
89 .header(header::CONNECTION, encode_header(Connection::upgrade()))
90 .header(header::UPGRADE, encode_header(Upgrade::websocket()))
91 .header(
92 header::SEC_WEBSOCKET_ACCEPT,
93 encode_header(SecWebsocketAccept::from(sec_key.unwrap())),
94 )
95 .body("".into())
96 .unwrap();
97
98 ok(resp)
99 };
100}
101
102/// Upgrades the http requests to websocket.
103///
104/// # Examples
105///
106/// ```no_run
107/// # use hyper::{Body, Response, Server};
108/// # use routerify::{Router, RouterService};
109/// # // Import websocket types.
110/// use routerify_websocket::{upgrade_ws, WebSocket};
111/// # use std::{convert::Infallible, net::SocketAddr};
112///
113/// # // A handler for websocket connections.
114/// async fn ws_handler(ws: WebSocket) {
115/// println!("New websocket connection: {}", ws.remote_addr());
116/// // Handle websocket connection.
117/// }
118///
119/// fn router() -> Router<Body, Infallible> {
120/// // Create a router and specify the path and the handler for new websocket connections.
121/// Router::builder()
122/// // Upgrade the http requests at `/ws` path to websocket.
123/// .any_method("/ws", upgrade_ws(ws_handler))
124/// .build()
125/// .unwrap()
126/// }
127///
128/// # #[tokio::main]
129/// # async fn main() {
130/// # let router = router();
131/// #
132/// # // Create a Service from the router above to handle incoming requests.
133/// # let service = RouterService::new(router).unwrap();
134/// #
135/// # // The address on which the server will be listening.
136/// # let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
137/// #
138/// # // Create a server by passing the created service to `.serve` method.
139/// # let server = Server::bind(&addr).serve(service);
140/// #
141/// # println!("App is running on: {}", addr);
142/// # if let Err(err) = server.await {
143/// # eprintln!("Server error: {}", err);
144/// # }
145/// # }
146/// ```
147pub fn upgrade_ws<H, R, B, E>(
148 handler: H,
149) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + Sync + 'static
150where
151 H: Fn(WebSocket) -> R + Copy + Send + Sync + 'static,
152 R: Future<Output = ()> + Send + 'static,
153 B: From<&'static str> + HttpBody + Send + 'static,
154 E: std::error::Error + Send + 'static,
155{
156 return upgrade_ws_with_config(handler, WebSocketConfig::default());
157}
158
159fn extract_upgradable_key(req: &Request<hyper::Body>) -> Option<SecWebsocketKey> {
160 let hdrs = req.headers();
161
162 hdrs.get(header::CONNECTION)
163 .and_then(|val| decode_header::<Connection>(val))
164 .and_then(|conn| some(conn.contains("upgrade")))
165 .and_then(|_| hdrs.get(header::UPGRADE))
166 .and_then(|val| val.to_str().ok())
167 .and_then(|val| some(val == "websocket"))
168 .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_VERSION))
169 .and_then(|val| val.to_str().ok())
170 .and_then(|val| some(val == "13"))
171 .and_then(|_| hdrs.get(header::SEC_WEBSOCKET_KEY))
172 .and_then(|val| decode_header::<SecWebsocketKey>(val))
173}
174
175fn decode_header<T: Header>(val: &HeaderValue) -> Option<T> {
176 let values = [val];
177 let mut iter = (&values).into_iter().copied();
178 T::decode(&mut iter).ok()
179}
180
181fn encode_header<T: Header>(h: T) -> HeaderValue {
182 let mut val = Vec::with_capacity(1);
183 h.encode(&mut val);
184 val.into_iter().nth(0).unwrap()
185}
186
/// Maps a boolean check onto an `Option` so it can short-circuit an `Option` chain:
/// `true` yields `Some(())`, `false` yields `None`.
fn some(cond: bool) -> Option<()> {
    if !cond {
        return None;
    }
    Some(())
}