use base64;
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use hyper::Body;
use hyper::Request;
use hyper::Response;
use pin_project::pin_project;
use sha1::Digest;
use sha1::Sha1;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use crate::Role;
use crate::WebSocket;
/// Computes the value sent back in the `Sec-WebSocket-Accept` response header.
///
/// The client-supplied `key` bytes are hashed with SHA-1 together with the
/// fixed handshake GUID, and the resulting digest is returned encoded with the
/// standard base64 alphabet.
fn sec_websocket_protocol(key: &[u8]) -> String {
    // Constant GUID appended to the client key before hashing.
    const HANDSHAKE_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::new();
    hasher.update(key);
    hasher.update(HANDSHAKE_GUID);
    let digest = hasher.finalize();

    STANDARD.encode(digest.as_slice())
}
/// Boxed, thread-safe error type used as the failure type throughout this module.
type Error = Box<dyn std::error::Error + Send + Sync>;
/// Future returned by [`upgrade`] that wraps hyper's pending connection upgrade.
///
/// Its `Future` implementation resolves to a server-role [`WebSocket`] over the
/// upgraded connection, or to an error if the upgrade fails.
#[pin_project]
#[derive(Debug)]
pub struct UpgradeFut {
// Pinned so the inner hyper future can be polled through the projection.
#[pin]
inner: hyper::upgrade::OnUpgrade,
}
/// Validates a WebSocket handshake request and prepares the server response.
///
/// On success returns the `101 Switching Protocols` response to send to the
/// client (carrying the `Connection`, `Upgrade` and `Sec-WebSocket-Accept`
/// headers) together with an [`UpgradeFut`] wrapping hyper's upgrade future
/// for this request.
///
/// # Errors
///
/// Returns an error when the `Sec-WebSocket-Key` header is missing, or when
/// the `Sec-WebSocket-Version` header is absent or not exactly `13`.
pub fn upgrade<B>(
    mut request: impl std::borrow::BorrowMut<Request<B>>,
) -> Result<(Response<Body>, UpgradeFut), Error> {
    let request = request.borrow_mut();
    let headers = request.headers();

    // The key is checked first so a missing key is reported before a bad version.
    let key = headers
        .get("Sec-WebSocket-Key")
        .ok_or("Sec-WebSocket-Key header is missing")?;

    let version = headers.get("Sec-WebSocket-Version").map(|v| v.as_bytes());
    if version != Some(b"13") {
        return Err("Sec-WebSocket-Version must be 13".into());
    }

    // Derive the accept token now, while the header borrow is still alive;
    // the request is handed over mutably to hyper further down.
    let accept = sec_websocket_protocol(key.as_bytes());

    let response = Response::builder()
        .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
        .header(hyper::header::CONNECTION, "upgrade")
        .header(hyper::header::UPGRADE, "websocket")
        .header("Sec-WebSocket-Accept", &accept)
        .body(Body::from("switching to websocket protocol"))
        .expect("bug: failed to build response");

    let inner = hyper::upgrade::on(request);
    Ok((response, UpgradeFut { inner }))
}
/// Reports whether `request` asks for a WebSocket upgrade.
///
/// Returns `true` only when the `Connection` header lists `Upgrade` and the
/// `Upgrade` header lists `websocket`; both comparisons ignore ASCII case.
pub fn is_upgrade_request<B>(request: &hyper::Request<B>) -> bool {
    let headers = request.headers();

    if !header_contains_value(headers, hyper::header::CONNECTION, "Upgrade") {
        return false;
    }

    header_contains_value(headers, hyper::header::UPGRADE, "websocket")
}
/// Checks whether any occurrence of `header` lists `value` among its elements.
///
/// Every header line with the given name is split on `,`; each element has
/// surrounding ASCII whitespace removed and is compared to `value` ignoring
/// ASCII case. Returns `true` on the first match, `false` if none is found.
fn header_contains_value(
    headers: &hyper::HeaderMap,
    header: impl hyper::header::AsHeaderName,
    value: impl AsRef<[u8]>,
) -> bool {
    let wanted = value.as_ref();

    headers.get_all(header).into_iter().any(|field| {
        field
            .as_bytes()
            .split(|&byte| byte == b',')
            .any(|element| trim(element).eq_ignore_ascii_case(wanted))
    })
}
/// Returns `data` with ASCII whitespace removed from both ends.
fn trim(data: &[u8]) -> &[u8] {
    let without_leading = trim_start(data);
    trim_end(without_leading)
}
/// Returns `data` with leading ASCII whitespace removed.
///
/// The result is empty when `data` is empty or consists only of whitespace.
fn trim_start(data: &[u8]) -> &[u8] {
    let mut rest = data;
    // Peel bytes off the front for as long as they are ASCII whitespace.
    while let [first, tail @ ..] = rest {
        if !first.is_ascii_whitespace() {
            break;
        }
        rest = tail;
    }
    rest
}
/// Returns `data` with trailing ASCII whitespace removed.
///
/// The result is empty when `data` is empty or consists only of whitespace.
fn trim_end(data: &[u8]) -> &[u8] {
    let mut rest = data;
    // Peel bytes off the back for as long as they are ASCII whitespace.
    while let [head @ .., last] = rest {
        if !last.is_ascii_whitespace() {
            break;
        }
        rest = head;
    }
    rest
}
impl std::future::Future for UpgradeFut {
    type Output = Result<WebSocket<hyper::upgrade::Upgraded>, Error>;

    /// Polls the wrapped hyper upgrade future.
    ///
    /// Stays pending until hyper finishes the upgrade; on success the upgraded
    /// connection is wrapped in a server-role [`WebSocket`], and on failure
    /// the hyper error is converted into this module's boxed [`Error`].
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.project().inner.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(upgraded)) => {
                Poll::Ready(Ok(WebSocket::after_handshake(upgraded, Role::Server)))
            }
            Poll::Ready(Err(err)) => Poll::Ready(Err(err.into())),
        }
    }
}