1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
use crate::{WebSocket, WebSocketConfig};
use futures::future::{ok, Ready};
use headers::{Connection, Header, SecWebsocketAccept, SecWebsocketKey, Upgrade};
use hyper::{
body::HttpBody,
header::{self, HeaderValue},
Request, Response, StatusCode,
};
use routerify::ext::RequestExt;
use std::future::Future;
/// Upgrades HTTP requests to websocket connections using the provided [config](./struct.WebSocketConfig.html).
///
/// The returned closure validates each request's opening-handshake headers:
///
/// - If the request is not a valid websocket upgrade, it answers `400 Bad Request`.
/// - Otherwise it answers `101 Switching Protocols` and spawns a task with `tokio::spawn`.
///   That task waits for the connection upgrade and hands the resulting
///   [`WebSocket`](./struct.WebSocket.html) to `handler`.
/// - If the upgrade fails, the error is logged with `log::error!` and `handler` is not called.
///
/// Because it uses `tokio::spawn`, the returned closure must run inside a Tokio runtime.
///
/// # Examples
///
/// ```no_run
/// # use hyper::{Body, Response, Server};
/// # use routerify::{Router, RouterService};
/// # // Import websocket types.
/// use routerify_websocket::{upgrade_ws_with_config, WebSocket, WebSocketConfig};
/// # use std::{convert::Infallible, net::SocketAddr};
///
/// # // A handler for websocket connections.
/// async fn ws_handler(ws: WebSocket) {
///     println!("New websocket connection: {}", ws.remote_addr());
///     // Handle websocket connection.
/// }
///
/// fn router() -> Router<Body, Infallible> {
///     // Create a router and specify the path and the handler for new websocket connections.
///     Router::builder()
///         // Upgrade the http requests at `/ws` path to websocket with the following config.
///         .any_method("/ws", upgrade_ws_with_config(ws_handler, WebSocketConfig::default()))
///         .build()
///         .unwrap()
/// }
///
/// # #[tokio::main]
/// # async fn main() {
/// #     let router = router();
/// #
/// #     // Create a Service from the router above to handle incoming requests.
/// #     let service = RouterService::new(router).unwrap();
/// #
/// #     // The address on which the server will be listening.
/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
/// #
/// #     // Create a server by passing the created service to `.serve` method.
/// #     let server = Server::bind(&addr).serve(service);
/// #
/// #     println!("App is running on: {}", addr);
/// #     if let Err(err) = server.await {
/// #         eprintln!("Server error: {}", err);
/// #     }
/// # }
/// ```
pub fn upgrade_ws_with_config<H, R, B, E>(
    handler: H,
    config: WebSocketConfig,
) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + Sync + 'static
where
    H: Fn(WebSocket) -> R + Copy + Send + Sync + 'static,
    R: Future<Output = ()> + Send + 'static,
    B: From<&'static str> + HttpBody + Send + 'static,
    E: std::error::Error + Send + 'static,
{
    move |req: Request<hyper::Body>| {
        let sec_key = extract_upgradable_key(&req);
        let remote_addr = req.remote_addr();

        // Bind the key once instead of checking `is_none()` and unwrapping later.
        let sec_key = match sec_key {
            Some(key) => key,
            None => {
                return ok(Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body("BAD REQUEST: The request is not websocket".into())
                    .expect("a 400 response with a static body is always valid"));
            }
        };

        // The upgrade completes only after the 101 response below has been sent,
        // so wait for it on a separate task.
        tokio::spawn(async move {
            match hyper::upgrade::on(req).await {
                Ok(upgraded) => {
                    handler(WebSocket::from_raw_socket(upgraded, remote_addr, config).await).await;
                }
                Err(err) => log::error!("{}", crate::WebsocketError::Upgrade(err.into())),
            }
        });

        let resp = Response::builder()
            .status(StatusCode::SWITCHING_PROTOCOLS)
            .header(header::CONNECTION, encode_header(Connection::upgrade()))
            .header(header::UPGRADE, encode_header(Upgrade::websocket()))
            .header(
                header::SEC_WEBSOCKET_ACCEPT,
                encode_header(SecWebsocketAccept::from(sec_key)),
            )
            .body("".into())
            .expect("a 101 response built from typed headers is always valid");
        ok(resp)
    }
}
/// Upgrades HTTP requests to websocket connections using the default configuration.
///
/// This is a shorthand for
/// [`upgrade_ws_with_config`](./fn.upgrade_ws_with_config.html)`(handler, WebSocketConfig::default())`.
///
/// # Examples
///
/// ```no_run
/// # use hyper::{Body, Response, Server};
/// # use routerify::{Router, RouterService};
/// # // Import websocket types.
/// use routerify_websocket::{upgrade_ws, WebSocket};
/// # use std::{convert::Infallible, net::SocketAddr};
///
/// # // A handler for websocket connections.
/// async fn ws_handler(ws: WebSocket) {
///     println!("New websocket connection: {}", ws.remote_addr());
///     // Handle websocket connection.
/// }
///
/// fn router() -> Router<Body, Infallible> {
///     // Create a router and specify the path and the handler for new websocket connections.
///     Router::builder()
///         // Upgrade the http requests at `/ws` path to websocket.
///         .any_method("/ws", upgrade_ws(ws_handler))
///         .build()
///         .unwrap()
/// }
///
/// # #[tokio::main]
/// # async fn main() {
/// #     let router = router();
/// #
/// #     // Create a Service from the router above to handle incoming requests.
/// #     let service = RouterService::new(router).unwrap();
/// #
/// #     // The address on which the server will be listening.
/// #     let addr = SocketAddr::from(([127, 0, 0, 1], 3001));
/// #
/// #     // Create a server by passing the created service to `.serve` method.
/// #     let server = Server::bind(&addr).serve(service);
/// #
/// #     println!("App is running on: {}", addr);
/// #     if let Err(err) = server.await {
/// #         eprintln!("Server error: {}", err);
/// #     }
/// # }
/// ```
pub fn upgrade_ws<H, R, B, E>(
    handler: H,
) -> impl Fn(Request<hyper::Body>) -> Ready<Result<Response<B>, E>> + Send + Sync + 'static
where
    H: Fn(WebSocket) -> R + Copy + Send + Sync + 'static,
    R: Future<Output = ()> + Send + 'static,
    B: From<&'static str> + HttpBody + Send + 'static,
    E: std::error::Error + Send + 'static,
{
    upgrade_ws_with_config(handler, WebSocketConfig::default())
}
/// Checks the websocket opening-handshake headers of `req` and, if they are valid,
/// returns the typed `Sec-WebSocket-Key`.
///
/// Returns `None` unless all of the following hold:
/// - `Connection` decodes and contains the `upgrade` option;
/// - `Upgrade` lists the `websocket` protocol, compared ASCII case-insensitively
///   (RFC 6455 §4.2.1), so `WebSocket` is accepted as well;
/// - `Sec-WebSocket-Version` is exactly `13`;
/// - `Sec-WebSocket-Key` decodes as a [`SecWebsocketKey`].
fn extract_upgradable_key(req: &Request<hyper::Body>) -> Option<SecWebsocketKey> {
    let hdrs = req.headers();

    let conn = hdrs
        .get(header::CONNECTION)
        .and_then(|val| decode_header::<Connection>(val))?;
    some(conn.contains("upgrade"))?;

    // `Upgrade` carries a comma-separated protocol list whose tokens are
    // case-insensitive, so an exact `== "websocket"` match is too strict.
    let upgrade = hdrs.get(header::UPGRADE)?.to_str().ok()?;
    some(
        upgrade
            .split(',')
            .any(|proto| proto.trim().eq_ignore_ascii_case("websocket")),
    )?;

    let version = hdrs.get(header::SEC_WEBSOCKET_VERSION)?.to_str().ok()?;
    some(version == "13")?;

    hdrs.get(header::SEC_WEBSOCKET_KEY)
        .and_then(|val| decode_header::<SecWebsocketKey>(val))
}
/// Decodes one raw header value into the typed header `T`.
///
/// The value is passed to `T::decode` as a one-element iterator. Returns `None`
/// when `T` rejects the value.
fn decode_header<T: Header>(val: &HeaderValue) -> Option<T> {
    let mut single_value = std::iter::once(val);
    match T::decode(&mut single_value) {
        Ok(typed) => Some(typed),
        Err(_) => None,
    }
}
/// Encodes a typed header into a single raw `HeaderValue`.
///
/// Only the first value produced by `T::encode` is returned; any further values are dropped.
///
/// # Panics
///
/// Panics if `T::encode` produces no value at all.
fn encode_header<T: Header>(h: T) -> HeaderValue {
    let mut values: Vec<HeaderValue> = Vec::with_capacity(1);
    h.encode(&mut values);
    values
        .into_iter()
        .next()
        .expect("typed header encoded to no values")
}
/// Turns a boolean check into an `Option<()>` so it can be combined with `?` or
/// `and_then`. `true` yields `Some(())` and `false` yields `None`.
fn some(cond: bool) -> Option<()> {
    match cond {
        true => Some(()),
        false => None,
    }
}