use crate::error::WebsocketError;
use crate::stream::WebsocketStream;
use crate::util::base64::Base64Encode;
use crate::util::sha1::SHA1Hash;
use crate::MAGIC_STRING;
use humphrey::http::headers::{RequestHeader, ResponseHeader};
use humphrey::http::{Request, Response, StatusCode};
use humphrey::stream::Stream;
use std::io::Write;
use std::sync::mpsc::Sender;
use std::sync::{Arc, Mutex};
/// A function that takes over a WebSocket connection once the handshake has completed.
///
/// It receives the upgraded stream and the shared application state. It must be
/// `Send + Sync` so it can be called from the server's connection handling.
pub trait WebsocketHandler<S>: Fn(WebsocketStream<Stream>, Arc<S>) + Send + Sync {}

/// Every suitably-typed closure or function is automatically a [`WebsocketHandler`].
impl<F: Fn(WebsocketStream<Stream>, Arc<S>) + Send + Sync, S> WebsocketHandler<S> for F {}
/// Wraps a [`WebsocketHandler`] so it can be used as a plain HTTP request handler.
///
/// The returned closure first performs the WebSocket opening handshake on the
/// incoming connection. It calls `handler` with the upgraded stream and the
/// shared state only when the handshake succeeds. Otherwise the connection is
/// dropped without calling `handler`.
pub fn websocket_handler<T, S>(handler: T) -> impl Fn(Request, Stream, Arc<S>)
where
    T: WebsocketHandler<S>,
{
    move |req: Request, mut socket: Stream, shared: Arc<S>| {
        let upgraded = handshake(req, &mut socket).is_ok();
        if upgraded {
            let ws = WebsocketStream::new(socket);
            handler(ws, shared);
        }
    }
}
/// Creates a request handler that upgrades connections to WebSocket and forwards
/// them over a channel, so they can be processed elsewhere.
///
/// The application state passed to the returned closure is ignored.
///
/// When the handshake fails, the connection is dropped. When the receiving end
/// of `hook` has been dropped, the upgraded stream is dropped too, which closes
/// the connection. In that case the handler returns instead of panicking the
/// server's worker.
pub fn async_websocket_handler<S>(
    hook: Arc<Mutex<Sender<WebsocketStream<Stream>>>>,
) -> impl Fn(Request, Stream, Arc<S>) {
    move |request: Request, mut stream: Stream, _: Arc<S>| {
        if handshake(request, &mut stream).is_err() {
            return;
        }

        // A poisoned lock only means another thread panicked while holding it.
        // A `Sender` has no invariant that can be left broken, so keep using it.
        let sender = hook.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

        // `send` fails only when the receiver has been dropped. The stream is
        // handed back inside the error and dropped here, which closes the connection.
        let _ = sender.send(WebsocketStream::new(stream));
    }
}
/// Performs the server side of the WebSocket opening handshake.
///
/// The function:
/// 1. Reads the `sec-websocket-key` request header.
/// 2. Derives the `Sec-WebSocket-Accept` value from it.
/// 3. Writes a `101 Switching Protocols` response (with `Upgrade: websocket` and
///    `Connection: Upgrade`) to `stream`.
/// 4. Flushes the stream, so the response is sent before any frames are exchanged.
///
/// # Errors
///
/// - [`WebsocketError::HandshakeError`] if the request has no `sec-websocket-key` header.
/// - [`WebsocketError::WriteError`] if writing or flushing the response fails.
fn handshake(request: Request, stream: &mut Stream) -> Result<(), WebsocketError> {
    let handshake_key = request
        .headers
        .get(&RequestHeader::Custom {
            name: "sec-websocket-key".into(),
        })
        .ok_or(WebsocketError::HandshakeError)?;

    // Per RFC 6455 the accept value is base64(SHA-1(key + GUID)).
    // MAGIC_STRING is expected to hold that GUID.
    let sec_websocket_accept = format!("{}{}", handshake_key, MAGIC_STRING).hash().encode();

    let response = Response::empty(StatusCode::SwitchingProtocols)
        .with_header(ResponseHeader::Upgrade, "websocket".into())
        .with_header(ResponseHeader::Connection, "Upgrade".into())
        .with_header(
            ResponseHeader::Custom {
                name: "Sec-WebSocket-Accept".into(),
            },
            sec_websocket_accept,
        );

    let response_bytes: Vec<u8> = response.into();
    stream
        .write_all(&response_bytes)
        .map_err(|_| WebsocketError::WriteError)?;

    // Push the response out of any buffering layer the stream may have, so the
    // client sees the 101 before the caller starts reading or writing frames.
    stream.flush().map_err(|_| WebsocketError::WriteError)?;

    Ok(())
}