use crate::error::WebsocketError;
use crate::stream::WebsocketStream;
use crate::util::base64::Base64Encode;
use crate::util::sha1::SHA1Hash;
use crate::MAGIC_STRING;
use humphrey::http::headers::HeaderType;
use humphrey::http::{Request, Response, StatusCode};
use humphrey::stream::Stream;
use std::io::Write;
use std::sync::mpsc::Sender;
use std::sync::{Arc, Mutex};
/// A callback that takes ownership of a connection once its WebSocket handshake has
/// succeeded.
///
/// It is invoked with the upgraded [`WebsocketStream`] and a shared `Arc<State>` value.
/// Any closure or function of shape `Fn(WebsocketStream, Arc<State>)` that is also
/// `Send + Sync` implements this trait automatically through the blanket impl below,
/// so it never needs to be implemented by hand.
pub trait WebsocketHandler<State>: Fn(WebsocketStream, Arc<State>) + Send + Sync {}

impl<State, F: Fn(WebsocketStream, Arc<State>) + Send + Sync> WebsocketHandler<State> for F {}
/// Adapts a [`WebsocketHandler`] into a connection handler of shape
/// `Fn(Request, Stream, Arc<S>)`.
///
/// The returned closure first runs the WebSocket opening handshake on the incoming
/// connection. Only if that succeeds is the stream wrapped in a [`WebsocketStream`]
/// and passed to `handler` together with `state`. If the handshake fails, `handler`
/// is never called, and the stream is dropped when the closure returns.
pub fn websocket_handler<T, S>(handler: T) -> impl Fn(Request, Stream, Arc<S>)
where
    T: WebsocketHandler<S>,
{
    move |request: Request, mut stream: Stream, state: Arc<S>| {
        // Guard clause: a failed handshake means there is nothing to hand over.
        if handshake(request, &mut stream).is_err() {
            return;
        }

        handler(WebsocketStream::new(stream), state);
    }
}
/// Creates a connection handler that forwards upgraded WebSocket streams over a channel.
///
/// The returned closure runs the WebSocket opening handshake on the incoming
/// connection. On success, it sends the resulting [`WebsocketStream`] through the
/// `Sender` held in `hook`, so whoever owns the matching receiver can take over the
/// connection. The `Arc<S>` argument exists only to match the
/// `(Request, Stream, Arc<S>)` handler shape and is ignored.
pub fn async_websocket_handler<S>(
    hook: Arc<Mutex<Sender<WebsocketStream>>>,
) -> impl Fn(Request, Stream, Arc<S>) {
    move |request: Request, mut stream: Stream, _: Arc<S>| {
        if handshake(request, &mut stream).is_err() {
            return;
        }

        // Build the stream *before* taking the lock. In the previous one-liner the
        // receiver (`hook.lock().unwrap()`) was evaluated first, so the guard was held
        // while the argument was constructed.
        let websocket = WebsocketStream::new(stream);

        // A poisoned mutex only means some thread panicked while holding it. A
        // `Sender` has no invariant that such a panic could break, so recover the
        // guard rather than panicking on this and every subsequent connection.
        let sender = hook.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

        // Best effort: `send` fails only when the receiving end has been dropped.
        // In that case nobody is left to take the connection, and it is dropped.
        let _ = sender.send(websocket);
    }
}
/// Performs the server side of the WebSocket opening handshake on `stream`.
///
/// The accept value is computed in three steps:
/// 1. Append [`MAGIC_STRING`] to the request's `Sec-WebSocket-Key` header value.
/// 2. Hash the result (via `SHA1Hash::hash`).
/// 3. Encode the digest (via `Base64Encode::encode`).
///
/// It is then written back as `Sec-WebSocket-Accept` in a `101 Switching Protocols`
/// response that also carries `Upgrade: websocket` and `Connection: Upgrade`.
///
/// # Errors
///
/// - [`WebsocketError::HandshakeError`] if the request has no `Sec-WebSocket-Key`
///   header. A `400 Bad Request` response is written first (best effort), so the
///   client gets an explicit rejection instead of a silently closed connection.
/// - [`WebsocketError::WriteError`] if the `101` response cannot be written to
///   `stream`.
fn handshake(request: Request, stream: &mut Stream) -> Result<(), WebsocketError> {
    let handshake_key = match request.headers.get("Sec-WebSocket-Key") {
        Some(key) => key,
        None => {
            // RFC 6455 §4.2.1: a server that rejects a client handshake must reply
            // with an HTTP error status (e.g. 400). The write is best effort; the
            // handshake has already failed regardless of its outcome.
            let rejection: Vec<u8> = Response::empty(StatusCode::BadRequest).into();
            let _ = stream.write_all(&rejection);
            return Err(WebsocketError::HandshakeError);
        }
    };

    let sec_websocket_accept = format!("{}{}", handshake_key, MAGIC_STRING).hash().encode();

    let response = Response::empty(StatusCode::SwitchingProtocols)
        .with_header(HeaderType::Upgrade, "websocket")
        .with_header(HeaderType::Connection, "Upgrade")
        .with_header("Sec-WebSocket-Accept", sec_websocket_accept);
    let response_bytes: Vec<u8> = response.into();

    stream
        .write_all(&response_bytes)
        .map_err(|_| WebsocketError::WriteError)
}