use std::net::SocketAddr;
use base64::{engine::general_purpose::STANDARD, Engine};
use sha1::{Digest, Sha1};
use tokio::sync::mpsc::Sender;
use webparse::{
ws::{OwnedMessage, WsError},
Response, WebError,
};
use crate::{Body, ProtError, ProtResult, RecvRequest, RecvResponse};
/// Fixed GUID (RFC 6455 §1.3) appended to the client's `Sec-WebSocket-Key` before
/// SHA-1 hashing to derive the `Sec-WebSocket-Accept` value.
static MAGIC_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Groups the data associated with one websocket upgrade handshake.
pub struct WsHandshake {
/// Channel sender for [`OwnedMessage`] values; presumably feeds the connection's
/// outgoing frames — NOTE(review): confirm against the code that owns the receiver.
pub sender: Sender<OwnedMessage>,
/// The upgrade request, if the caller chose to keep it.
pub request: Option<RecvRequest>,
/// The handshake response; presumably the one built by `WsHandshake::build_request`
/// — TODO confirm at the construction sites.
pub response: RecvResponse,
/// Remote peer address, when known.
pub peer_addr: Option<SocketAddr>,
/// Local socket address, when known. `WsHandshake::new` always initialises this to `None`.
pub local_addr: Option<SocketAddr>,
}
impl WsHandshake {
    /// Creates a handshake record from its parts.
    ///
    /// `local_addr` is initialised to `None`; assign the public field afterwards if it
    /// is known.
    pub fn new(
        sender: Sender<OwnedMessage>,
        request: Option<RecvRequest>,
        response: RecvResponse,
        peer_addr: Option<SocketAddr>,
    ) -> Self {
        Self {
            sender,
            request,
            response,
            peer_addr,
            local_addr: None,
        }
    }

    /// Computes the `Sec-WebSocket-Accept` value for a client's `Sec-WebSocket-Key`.
    ///
    /// The result is `base64(SHA-1(key ++ MAGIC_GUID))` (RFC 6455 §4.2.2). `key` is
    /// hashed as the base64 text the client sent; it is decoded only to validate it.
    ///
    /// # Errors
    ///
    /// Returns a websocket protocol error if `key` is not valid base64, or if it does
    /// not decode to exactly 16 bytes.
    pub fn build_accept(key: &str) -> ProtResult<String> {
        let decoded = STANDARD.decode(key).map_err(|_| {
            ProtError::from(WebError::Ws(WsError::ProtocolError(
                "Invalid Sec-WebSocket-Key",
            )))
        })?;
        if decoded.len() != 16 {
            return Err(ProtError::from(WebError::Ws(WsError::ProtocolError(
                "Sec-WebSocket-Key must be 16 bytes",
            ))));
        }
        // Feed key and GUID into the hasher in turn; this equals hashing their
        // concatenation without building a temporary String.
        let mut hasher = Sha1::new();
        hasher.update(key.as_bytes());
        hasher.update(MAGIC_GUID.as_bytes());
        let digest: [u8; 20] = hasher.finalize().into();
        Ok(STANDARD.encode(digest))
    }

    /// Builds the server's response to a websocket upgrade request.
    ///
    /// * Missing `Sec-WebSocket-Key`, or a `Sec-WebSocket-Version` other than `13`:
    ///   returns a `400` response with body `"invalid websocket version"`.
    /// * Otherwise: returns `101` with `Upgrade`, `Connection` and
    ///   `Sec-WebSocket-Accept` headers.
    ///
    /// When the client lists subprotocols in `Sec-WebSocket-Protocol`, the first
    /// non-empty entry is echoed back. When it lists none, the header is omitted.
    /// RFC 6455 §4.1 requires a client to fail the connection if the server names a
    /// subprotocol the client did not request.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`WsHandshake::build_accept`] when the key is malformed.
    pub fn build_request(req: &RecvRequest) -> ProtResult<RecvResponse> {
        let headers = req.headers();
        let key = headers.get_str_value(&"Sec-WebSocket-Key");
        let protocol = headers.get_str_value(&"Sec-WebSocket-Protocol");
        let version = headers.get_str_value(&"Sec-WebSocket-Version");

        let key = match key {
            Some(key) if version.as_deref() == Some("13") => key,
            _ => {
                return Ok(Response::builder()
                    .status(400)
                    .body("invalid websocket version")
                    .unwrap()
                    .into_type());
            }
        };

        let accept = Self::build_accept(&key)?;

        // The offered subprotocols form a comma/space separated list. Take the first
        // non-empty token, so an empty or separator-only header value cannot panic.
        let chosen_protocol = protocol.as_deref().and_then(|list| {
            list.split(|c| c == ',' || c == ' ')
                .find(|s| !s.is_empty())
                .map(str::to_string)
        });

        let builder = Response::builder()
            .status(101)
            .header("Upgrade", "websocket")
            .header("Connection", "Upgrade")
            .header("Sec-WebSocket-Accept", accept);
        let builder = match chosen_protocol {
            Some(chosen) => builder.header("Sec-WebSocket-Protocol", chosen),
            None => builder,
        };
        Ok(builder.body(Body::empty()).unwrap())
    }
}