use std::{collections::HashMap, io};
use crate::{headers, Request};
use base64::engine::general_purpose::STANDARD as BASE64ENGINE;
use base64::Engine;
use sha1::{Digest, Sha1};
pub(crate) use tungstenite::WebSocket;
/// Builds the response headers for a WebSocket opening handshake.
///
/// The `Sec-WebSocket-Accept` value is derived as in RFC 6455: SHA-1 over the
/// client-supplied `sec_key` followed by the fixed handshake GUID, then base64
/// (standard alphabet) encoded. The returned map also carries the
/// `Upgrade: websocket` and `Connection: Upgrade` headers.
fn build_handshake(sec_key: String) -> HashMap<&'static str, String> {
    // Fixed GUID appended to the client key before hashing (RFC 6455).
    const HANDSHAKE_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let digest = {
        let mut hasher = Sha1::new();
        hasher.update(sec_key.as_bytes());
        hasher.update(HANDSHAKE_GUID);
        hasher.finalize()
    };
    let accept = BASE64ENGINE.encode(digest);

    headers! {
        "Upgrade" => "websocket",
        "Connection" => "Upgrade",
        "Sec-WebSocket-Accept" => accept,
    }
}
impl Request {
    /// Reports whether this request asks to be upgraded to a WebSocket.
    ///
    /// Returns `true` when the `Upgrade` header lists the `websocket` token and a
    /// `Sec-WebSocket-Key` header is present.
    ///
    /// RFC 6455 §4.2.1 requires the `websocket` value to be matched ASCII
    /// case-insensitively. The `Upgrade` header may also carry a comma-separated
    /// list of protocols. So each listed token is trimmed and compared without
    /// regard to case, rather than requiring the whole value to equal
    /// `"websocket"` exactly.
    ///
    /// NOTE(review): header *names* are looked up literally as `"Upgrade"` and
    /// `"Sec-WebSocket-Key"`; whether the header map normalizes name case is not
    /// visible here.
    pub fn is_websocket(&self) -> bool {
        let upgrade_requested = match self.headers.get("Upgrade") {
            Some(value) => value
                .split(',')
                .any(|token| token.trim().eq_ignore_ascii_case("websocket")),
            None => false,
        };
        upgrade_requested && self.headers.contains_key("Sec-WebSocket-Key")
    }

    /// Performs the server side of the WebSocket opening handshake on `stream`.
    ///
    /// If [`Request::is_websocket`] accepts this request, this method does two
    /// things. It writes a `switching_protocols` response carrying the headers
    /// produced from the client's `Sec-WebSocket-Key` to `stream`. It then wraps
    /// `stream` in a server-role [`WebSocket`] with the default configuration.
    ///
    /// Returns `None` without writing anything if the request is not a WebSocket
    /// upgrade. Returns `None` if sending the handshake response fails; the
    /// underlying I/O error is discarded.
    pub fn upgrade<T: io::Write>(&mut self, mut stream: T) -> Option<WebSocket<T>> {
        if !self.is_websocket() {
            return None;
        }
        // `is_websocket` already checked the key exists; `?` keeps this panic-free.
        let ws_key = self.headers.get("Sec-WebSocket-Key")?.clone();
        let handshake = build_handshake(ws_key);
        // The write error (if any) is dropped; callers only observe `None`.
        crate::response!(switching_protocols, [], handshake)
            .send_to(&mut stream)
            .ok()?;
        Some(WebSocket::from_raw_socket(
            stream,
            tungstenite::protocol::Role::Server,
            None,
        ))
    }
}
/// Dispatches `req` to a WebSocket handler when its URL matches the handler's path.
///
/// `handler` is an optional `(path_prefix, callback)` pair. If it is `None`, or
/// `req.url` does not start with `path_prefix`, nothing is done and `false` is
/// returned. Otherwise the request is upgraded via [`Request::upgrade`]. If that
/// succeeds, `callback` is invoked with the resulting socket over `stream`.
///
/// Returns `true` whenever the path prefix matched. This includes the cases where
/// the upgrade did not happen (not a WebSocket request, or the handshake write
/// failed); in those cases `callback` is simply not called.
///
/// NOTE(review): matching is a plain string prefix test, so e.g. a `"/ws"`
/// handler also matches `"/wsfoo"`. Callers presumably treat `true` as "request
/// handled"; confirm that a non-WebSocket request on this path is meant to get
/// no HTTP response.
#[cfg(feature = "websocket")]
pub fn maybe_websocket<Stream: io::Write>(
    handler: Option<(&'static str, fn(WebSocket<&mut Stream>))>,
    stream: &mut Stream,
    req: &mut Request,
) -> bool {
    let handler = match handler {
        Some((path, f)) if req.url.starts_with(path) => f,
        _ => return false,
    };
    // Use `if let` rather than `Option::map` for a side-effect-only call
    // (clippy: `option_map_unit_fn`).
    if let Some(socket) = req.upgrade(stream) {
        handler(socket);
    }
    true
}