use std::future::Future;
use std::io::ErrorKind;
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD;
use compio_io::compat::SyncStream;
use compio_ws::tungstenite;
pub use compio_ws::tungstenite::Message;
use compio_ws::tungstenite::protocol::CloseFrame;
use compio_ws::tungstenite::protocol::Role;
use compio_ws::tungstenite::protocol::WebSocketConfig;
use futures_util::FutureExt;
use http::StatusCode;
use http::header;
use hyper::upgrade::Upgraded;
use sha1::Digest;
use sha1::Sha1;
use crate::body::TakoBody;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Wrapper around a hyper [`Upgraded`] connection.
///
/// This file implements `compio_io::AsyncRead` and `compio_io::AsyncWrite`
/// for the wrapper by polling hyper's `Read`/`Write` traits on the inner
/// connection.
pub struct UpgradedStream {
// The upgraded hyper connection that all reads and writes are forwarded to.
inner: Upgraded,
}
impl UpgradedStream {
pub fn new(upgraded: Upgraded) -> Self {
Self { inner: upgraded }
}
}
impl compio_io::AsyncRead for UpgradedStream {
  /// Reads from the upgraded connection into `buf`.
  ///
  /// The data is first read into a temporary, zero-initialized `Vec<u8>` with
  /// the same length as `buf.as_mut_slice()`, then copied into `buf`, and
  /// finally `buf` is told how many bytes were initialized.
  ///
  /// Returns the number of bytes read together with the buffer; on error the
  /// buffer is handed back untouched alongside the error.
  async fn read<B: compio_buf::IoBufMut>(&mut self, mut buf: B) -> compio_buf::BufResult<usize, B> {
    use std::pin::Pin;
    use std::task::Context;

    use hyper::rt::Read;

    let slice = buf.as_mut_slice();
    let len = slice.len();
    let mut temp_buf = vec![0u8; len];

    // A fresh `ReadBuf` is built on every poll: a `Pending` poll reports no
    // data, so nothing is lost between polls.
    let result = std::future::poll_fn(|cx: &mut Context<'_>| {
      let mut read_buf = hyper::rt::ReadBuf::new(&mut temp_buf);
      Pin::new(&mut self.inner)
        .poll_read(cx, read_buf.unfilled())
        .map(|outcome| outcome.map(|()| read_buf.filled().len()))
    })
    .await;

    match result {
      Ok(filled_len) => {
        // Bounds-checked view of the bytes that were actually read.
        let src = &temp_buf[..filled_len];
        // SAFETY: `src` holds `src.len()` initialized bytes; the destination
        // is the start of `buf`'s writable region, which is `len` bytes long
        // and `src.len() <= len` because `temp_buf` has exactly `len` bytes.
        // The two regions belong to different allocations, so they cannot
        // overlap. Raw pointers are used so that no `&mut [u8]` is formed
        // over destination memory that may still be uninitialized
        // (NOTE(review): assumes `as_mut_slice` exposes possibly
        // uninitialized bytes, as the pointer cast suggests — confirm).
        unsafe {
          std::ptr::copy_nonoverlapping(src.as_ptr(), slice.as_mut_ptr() as *mut u8, src.len());
        }
        // SAFETY: the first `filled_len` bytes of `buf` were just written
        // by the copy above.
        unsafe { buf.set_buf_init(filled_len) };
        (Ok(filled_len), buf).into()
      }
      Err(e) => (Err(e), buf).into(),
    }
  }
}
impl compio_io::AsyncWrite for UpgradedStream {
  /// Writes the contents of `buf` to the upgraded connection.
  ///
  /// Returns the result of a single successful `poll_write` (the number of
  /// bytes accepted, or the error) together with the buffer.
  async fn write<T: compio_buf::IoBuf>(&mut self, buf: T) -> compio_buf::BufResult<usize, T> {
    use std::pin::Pin;
    use std::task::Context;

    use hyper::rt::Write;

    let outcome = {
      let bytes = buf.as_slice();
      std::future::poll_fn(|cx: &mut Context<'_>| Pin::new(&mut self.inner).poll_write(cx, bytes))
        .await
    };
    // Hand the buffer back to the caller alongside the outcome.
    (outcome, buf).into()
  }

  /// Drives hyper's `poll_flush` on the inner connection to completion.
  async fn flush(&mut self) -> std::io::Result<()> {
    use std::pin::Pin;
    use std::task::Context;

    use hyper::rt::Write;

    let mut io = Pin::new(&mut self.inner);
    std::future::poll_fn(|cx: &mut Context<'_>| io.as_mut().poll_flush(cx)).await
  }

  /// Drives hyper's `poll_shutdown` on the inner connection to completion.
  async fn shutdown(&mut self) -> std::io::Result<()> {
    use std::pin::Pin;
    use std::task::Context;

    use hyper::rt::Write;

    let mut io = Pin::new(&mut self.inner);
    std::future::poll_fn(|cx: &mut Context<'_>| io.as_mut().poll_shutdown(cx)).await
  }
}
/// WebSocket connection over a compio stream `S`.
///
/// Holds a tungstenite `WebSocket` that operates on a `SyncStream<S>`; the
/// async methods on this type retry tungstenite calls that report
/// `WouldBlock` after awaiting the `SyncStream` buffer operations.
pub struct CompioWebSocket<S> {
// tungstenite protocol state layered over the `SyncStream` adapter.
inner: tungstenite::WebSocket<SyncStream<S>>,
}
impl<S> CompioWebSocket<S>
where
S: compio_io::AsyncRead + compio_io::AsyncWrite,
{
const DEFAULT_BUF_SIZE: usize = 128 * 1024;
const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;
pub fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self {
let sync_stream =
SyncStream::with_limits(Self::DEFAULT_BUF_SIZE, Self::DEFAULT_MAX_BUFFER, stream);
let ws = tungstenite::WebSocket::from_raw_socket(sync_stream, role, config);
Self { inner: ws }
}
pub async fn send(&mut self, message: Message) -> Result<(), tungstenite::Error> {
self.inner.send(message)?;
self.flush().await
}
pub async fn read(&mut self) -> Result<Message, tungstenite::Error> {
loop {
match self.inner.read() {
Ok(msg) => {
let _ = self.flush().await;
return Ok(msg);
}
Err(tungstenite::Error::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
self
.inner
.get_mut()
.fill_read_buf()
.await
.map_err(tungstenite::Error::Io)?;
}
Err(e) => {
let _ = self.flush().await;
return Err(e);
}
}
}
}
pub async fn flush(&mut self) -> Result<(), tungstenite::Error> {
loop {
match self.inner.flush() {
Ok(()) => break,
Err(tungstenite::Error::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
self
.inner
.get_mut()
.flush_write_buf()
.await
.map_err(tungstenite::Error::Io)?;
}
Err(tungstenite::Error::ConnectionClosed) => break,
Err(e) => return Err(e),
}
}
self
.inner
.get_mut()
.flush_write_buf()
.await
.map_err(tungstenite::Error::Io)?;
Ok(())
}
pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), tungstenite::Error> {
loop {
match self.inner.close(close_frame.clone()) {
Ok(()) => break,
Err(tungstenite::Error::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
let sync_stream = self.inner.get_mut();
let flushed = sync_stream
.flush_write_buf()
.await
.map_err(tungstenite::Error::Io)?;
if flushed == 0 {
sync_stream
.fill_read_buf()
.await
.map_err(tungstenite::Error::Io)?;
}
}
Err(tungstenite::Error::ConnectionClosed) => break,
Err(e) => return Err(e),
}
}
self.flush().await
}
pub fn get_ref(&self) -> &S {
self.inner.get_ref().get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.inner.get_mut().get_mut()
}
}
/// Responder that upgrades an HTTP request to a WebSocket connection.
///
/// Stores the incoming request together with a one-shot `handler`; the
/// `Responder` implementation in this file answers the handshake and runs
/// the handler with a `CompioWebSocket<UpgradedStream>` once the connection
/// has been upgraded.
#[doc(alias = "websocket")]
#[doc(alias = "ws")]
pub struct TakoWsCompio<H, Fut>
where
H: FnOnce(CompioWebSocket<UpgradedStream>) -> Fut + 'static,
Fut: Future<Output = ()> + 'static,
{
// The request carrying the WebSocket handshake headers.
request: Request,
// Callback invoked with the established WebSocket.
handler: H,
}
impl<H, Fut> TakoWsCompio<H, Fut>
where
H: FnOnce(CompioWebSocket<UpgradedStream>) -> Fut + 'static,
Fut: Future<Output = ()> + 'static,
{
pub fn new(request: Request, handler: H) -> Self {
Self { request, handler }
}
}
impl<H, Fut> Responder for TakoWsCompio<H, Fut>
where
  H: FnOnce(CompioWebSocket<UpgradedStream>) -> Fut + 'static,
  Fut: Future<Output = ()> + 'static,
{
  /// Answers the WebSocket handshake and schedules the handler.
  ///
  /// Responds with `400 Bad Request` when the `Sec-WebSocket-Key` header is
  /// absent. Otherwise it responds with `101 Switching Protocols` carrying
  /// the computed `Sec-WebSocket-Accept` value. If the request has an
  /// `OnUpgrade` extension, a detached task is spawned that awaits the
  /// upgrade and runs the handler.
  fn into_response(self) -> Response {
    let (parts, body) = self.request.into_parts();
    let req = http::Request::from_parts(parts, body);

    // Guard: without a client key there is nothing to accept.
    let Some(key) = req.headers().get("Sec-WebSocket-Key") else {
      return http::Response::builder()
        .status(StatusCode::BAD_REQUEST)
        .body(TakoBody::from("Missing Sec-WebSocket-Key".to_string()))
        .expect("valid bad request response");
    };

    // Accept token: base64(SHA-1(key bytes followed by the fixed GUID)).
    let mut hasher = Sha1::new();
    hasher.update(key.as_bytes());
    hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
    let accept = STANDARD.encode(hasher.finalize());

    let response = http::Response::builder()
      .status(StatusCode::SWITCHING_PROTOCOLS)
      .header(header::UPGRADE, "websocket")
      .header(header::CONNECTION, "Upgrade")
      .header("Sec-WebSocket-Accept", accept)
      .body(TakoBody::empty())
      .expect("valid WebSocket upgrade response");

    let pending_upgrade = req.extensions().get::<hyper::upgrade::OnUpgrade>().cloned();
    if let Some(on_upgrade) = pending_upgrade {
      let handler = self.handler;
      compio::runtime::spawn(async move {
        // A failed upgrade is ignored: the handler simply never runs.
        let Ok(upgraded) = on_upgrade.await else {
          return;
        };
        let ws =
          CompioWebSocket::from_raw_socket(UpgradedStream::new(upgraded), Role::Server, None);
        // A panic inside the handler is caught and discarded.
        let _ = std::panic::AssertUnwindSafe(handler(ws))
          .catch_unwind()
          .await;
      })
      .detach();
    }

    response
  }
}