use crate::buffer::{BufferPoolRef, default_buffer_pool_ref};
use crate::service::select::Selectable;
use crate::stream::tcp::TcpStream;
#[cfg(any(feature = "rustls", feature = "openssl"))]
use crate::stream::tls::{IntoTlsStream, TlsReadyStream, TlsStream};
use crate::stream::{BindAndConnect, ConnectionInfoProvider};
use crate::util::NoBlock;
use crate::ws::Error::{Closed, ReceivedCloseFrame};
use crate::ws::decoder::Decoder;
pub use crate::ws::error::Error;
use crate::ws::handshake::Handshaker;
#[cfg(feature = "mio")]
use mio::{Interest, Registry, Token, event::Source};
use std::fmt::Debug;
use std::io;
use std::io::ErrorKind::WouldBlock;
use std::io::{Read, Write};
use thiserror::Error;
use url::Url;
mod decoder;
pub mod ds;
mod encoder;
mod error;
mod handshake;
mod protocol;
pub mod util;
/// A frame decoded from the peer.
///
/// NOTE(review): payloads are typed `&'static [u8]`, but they presumably borrow the decoder's
/// internal buffer; treat them as valid only until the websocket is read again — confirm
/// against `Decoder`.
pub enum WebsocketFrame {
/// Ping payload. `State::next` answers pings with a pong itself and reports no frame for them.
Ping(&'static [u8]),
/// Pong payload.
Pong(&'static [u8]),
/// Text data; the flag is presumably FIN (mirrors `fin` in `Websocket::send_text`) — confirm.
Text(bool, &'static [u8]),
/// Binary data; the flag is presumably FIN (mirrors `fin` in `Websocket::send_binary`).
Binary(bool, &'static [u8]),
/// Continuation of a fragmented message; the flag is presumably FIN — confirm.
Continuation(bool, &'static [u8]),
/// Close payload. `State::next` echoes close frames and reports them as
/// `Error::ReceivedCloseFrame` instead of returning this variant.
Close(&'static [u8]),
}
/// A websocket over the stream `S`, covering both the opening handshake and framing.
#[derive(Debug)]
pub struct Websocket<S> {
// Underlying transport; reads and writes are delegated to it.
stream: S,
// Set after any read/decode/send error and after `send_close`; once set, sending and
// receiving frames fail with `Error::Closed`.
closed: bool,
// Handshake-in-progress or connected protocol state.
state: State,
}
impl<S> Websocket<S> {
    /// Creates a websocket over `stream` that starts in the handshake phase, requesting
    /// `endpoint`.
    ///
    /// The host reported by the stream's connection info is used as the handshake server name.
    pub fn new(stream: S, endpoint: &str) -> Websocket<S>
    where
        S: ConnectionInfoProvider,
    {
        // Build the handshake state while `stream` is still borrowable, so the host can be
        // passed by reference instead of cloning the whole connection info first.
        let server_name = stream.connection_info().host();
        let state = State::handshake(server_name, endpoint, default_buffer_pool_ref());
        Self {
            stream,
            closed: false,
            state,
        }
    }

    /// Creates a websocket over `stream` whose handshake has already been completed elsewhere;
    /// it starts directly in the connected (frame decoding) phase.
    pub fn new_with_handshake_complete(stream: S) -> Websocket<S> {
        Self {
            stream,
            closed: false,
            state: State::connection(default_buffer_pool_ref()),
        }
    }

    /// Returns `true` once the websocket has been closed, either by an error or by
    /// `send_close`.
    pub const fn closed(&self) -> bool {
        self.closed
    }

    /// Returns `true` once the opening handshake has completed.
    #[inline]
    pub const fn handshake_complete(&self) -> bool {
        match self.state {
            State::Handshake(_, _) => false,
            State::Connection(_) => true,
        }
    }
}
impl<S: Read + Write> Websocket<S> {
#[inline]
pub fn read_batch(&mut self) -> Result<Batch<'_, S>, Error> {
match self.state.read(&mut self.stream).no_block() {
Ok(()) => Ok(Batch { websocket: self }),
Err(err) => {
self.closed = true;
Err(err)?
}
}
}
#[inline]
pub fn receive_next(&mut self) -> Option<Result<WebsocketFrame, Error>> {
match self.read_batch() {
Ok(mut batch) => batch.receive_next(),
Err(err) => Some(Err(err)),
}
}
#[inline]
pub fn send_text(&mut self, fin: bool, body: Option<&[u8]>) -> Result<(), Error> {
self.send(fin, protocol::op::TEXT_FRAME, body)
}
#[inline]
pub fn send_binary(&mut self, fin: bool, body: Option<&[u8]>) -> Result<(), Error> {
self.send(fin, protocol::op::BINARY_FRAME, body)
}
#[inline]
pub fn send_pong(&mut self, body: Option<&[u8]>) -> Result<(), Error> {
self.send(true, protocol::op::PONG, body)
}
#[inline]
pub fn send_ping(&mut self, body: Option<&[u8]>) -> Result<(), Error> {
self.send(true, protocol::op::PING, body)
}
#[inline]
pub fn send_close(&mut self) -> Result<(), Error> {
self.send(true, protocol::op::CONNECTION_CLOSE, None)?;
self.closed = true;
Ok(())
}
#[inline]
fn next(&mut self) -> Result<Option<WebsocketFrame>, Error> {
self.ensure_not_closed()?;
match self.state.next(&mut self.stream) {
Ok(frame) => Ok(frame),
Err(err) => {
self.closed = true;
Err(err)?
}
}
}
#[inline]
fn send(&mut self, fin: bool, op_code: u8, body: Option<&[u8]>) -> Result<(), Error> {
self.ensure_not_closed()?;
match self.state.send(&mut self.stream, fin, op_code, body) {
Ok(()) => Ok(()),
Err(err) => {
self.closed = true;
Err(err)?
}
}
}
#[inline]
const fn ensure_not_closed(&self) -> Result<(), Error> {
if self.closed {
return Err(Closed);
}
Ok(())
}
}
#[cfg(feature = "mio")]
impl<S> Source for Websocket<S>
where
    S: Source,
{
    /// Registers the wrapped stream with `registry`.
    fn register(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> {
        Registry::register(registry, &mut self.stream, token, interests)
    }

    /// Re-registers the wrapped stream with `registry`.
    fn reregister(&mut self, registry: &Registry, token: Token, interests: Interest) -> io::Result<()> {
        Registry::reregister(registry, &mut self.stream, token, interests)
    }

    /// Deregisters the wrapped stream from `registry`.
    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
        Registry::deregister(registry, &mut self.stream)
    }
}
/// Readiness handling is fully delegated to the wrapped stream.
impl<S> Selectable for Websocket<S>
where
    S: Selectable,
{
    fn connected(&mut self) -> io::Result<bool> {
        <S as Selectable>::connected(&mut self.stream)
    }

    fn make_writable(&mut self) -> io::Result<()> {
        <S as Selectable>::make_writable(&mut self.stream)
    }

    fn make_readable(&mut self) -> io::Result<()> {
        <S as Selectable>::make_readable(&mut self.stream)
    }
}
/// Protocol phase of a `Websocket`.
#[derive(Debug)]
// NOTE(review): clippy flags a large size difference between the variants; the allow
// presumably accepts it to avoid boxing on the hot path — confirm.
#[allow(clippy::large_enum_variant)]
enum State {
// Opening handshake in progress; the pool is kept so the decoder can be created from it
// once the handshake completes.
Handshake(Handshaker, BufferPoolRef),
// Handshake complete; incoming bytes are decoded into frames.
Connection(Decoder),
}
impl State {
pub fn handshake(server_name: &str, endpoint: &str, mut pool: BufferPoolRef) -> Self {
Self::Handshake(Handshaker::new(server_name, endpoint, &mut pool), pool)
}
pub fn connection(mut pool: BufferPoolRef) -> Self {
Self::Connection(Decoder::new(&mut pool))
}
}
impl State {
#[inline]
fn read<S: Read>(&mut self, stream: &mut S) -> io::Result<()> {
match self {
State::Handshake(handshake, _) => handshake.read(stream),
State::Connection(decoder) => decoder.read(stream),
}
}
#[inline]
fn next<S: Read + Write>(&mut self, stream: &mut S) -> Result<Option<WebsocketFrame>, Error> {
match self {
State::Handshake(handshake, pool) => match handshake.perform_handshake(stream) {
Ok(()) => {
handshake.drain_pending_message_buffer(stream, encoder::send)?;
*self = State::connection(pool.clone());
Ok(None)
}
Err(err) if err.kind() == WouldBlock => Ok(None),
Err(err) => Err(err)?,
},
State::Connection(decoder) => match decoder.decode_next() {
Ok(Some(WebsocketFrame::Ping(payload))) => {
self.send(stream, true, protocol::op::PONG, Some(payload))?;
Ok(None)
}
Ok(Some(WebsocketFrame::Close(payload))) => {
let _ = self.send(stream, true, protocol::op::CONNECTION_CLOSE, Some(payload));
let (status_code, body) = payload.split_at(std::mem::size_of::<u16>());
let status_code = u16::from_be_bytes(status_code.try_into()?);
let body = String::from_utf8_lossy(body).to_string();
Err(ReceivedCloseFrame(status_code, body))
}
Ok(frame) => Ok(frame),
Err(err) => Err(err)?,
},
}
}
#[inline]
fn send<S: Write>(&mut self, stream: &mut S, fin: bool, op_code: u8, body: Option<&[u8]>) -> Result<(), Error> {
match self {
State::Handshake(handshake, _) => {
handshake.buffer_message(fin, op_code, body);
Ok(())
}
State::Connection(_) => {
encoder::send(stream, fin, op_code, body)?;
Ok(())
}
}
}
}
/// Frames obtainable after a single read from the stream (see [`Websocket::read_batch`]).
///
/// The batch holds the websocket mutably borrowed for as long as it lives.
pub struct Batch<'a, S> {
websocket: &'a mut Websocket<S>,
}
impl<'a, S: Read + Write> IntoIterator for Batch<'a, S> {
    type Item = Result<WebsocketFrame, Error>;
    type IntoIter = BatchIter<'a, S>;

    /// Turns the batch into an iterator over its frames.
    fn into_iter(self) -> Self::IntoIter {
        BatchIter {
            batch: self,
            finished: false,
        }
    }
}

impl<S: Read + Write> Batch<'_, S> {
    /// Returns the next frame of this batch.
    ///
    /// `None` means no frame is available right now: the handshake is still in progress, an
    /// automatically answered ping was consumed, or the decoder produced no frame (presumably
    /// because more data is needed).
    pub fn receive_next(&mut self) -> Option<Result<WebsocketFrame, Error>> {
        self.websocket.next().transpose()
    }
}

/// Iterator over the frames of a [`Batch`].
///
/// Iteration ends at the first `None` from [`Batch::receive_next`], or right after the first
/// error is yielded. Every error closes the websocket, after which each further call would
/// return [`Error::Closed`] forever. Without this, a `for` loop that logs errors and continues
/// would never terminate.
pub struct BatchIter<'a, S> {
    batch: Batch<'a, S>,
    // Set once an error has been yielded; the iterator is exhausted from then on.
    finished: bool,
}

impl<S: Read + Write> Iterator for BatchIter<'_, S> {
    type Item = Result<WebsocketFrame, Error>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        let item = self.batch.receive_next();
        // The websocket is now closed if this was an error; surface that error exactly once.
        self.finished = matches!(item, Some(Err(_)));
        item
    }
}
/// Conversion of a stream into a [`Websocket`] that will perform the opening handshake.
pub trait IntoWebsocket {
/// Wraps `self` in a websocket requesting `endpoint` during the handshake
/// (e.g. a path, optionally followed by `?query`).
fn into_websocket(self, endpoint: &str) -> Websocket<Self>
where
Self: Sized;
}
/// Any readable/writable stream that can report its connection info can become a websocket.
impl<T> IntoWebsocket for T
where
    T: Read + Write + ConnectionInfoProvider,
{
    fn into_websocket(self, endpoint: &str) -> Websocket<Self>
    where
        Self: Sized,
    {
        Websocket::<T>::new(self, endpoint)
    }
}
/// Conversion of a stream into a TLS-wrapped [`Websocket`].
#[cfg(any(feature = "rustls", feature = "openssl"))]
pub trait IntoTlsWebsocket {
/// Wraps `self` in TLS, then in a websocket requesting `endpoint`.
///
/// # Errors
/// Fails if the TLS stream cannot be created.
fn into_tls_websocket(self, endpoint: &str) -> io::Result<Websocket<TlsStream<Self>>>
where
Self: Sized;
}
#[cfg(any(feature = "rustls", feature = "openssl"))]
impl<T> IntoTlsWebsocket for T
where
    T: Read + Write + Debug + ConnectionInfoProvider,
{
    /// Wraps `self` in a TLS stream first, then in a websocket requesting `endpoint`.
    fn into_tls_websocket(self, endpoint: &str) -> io::Result<Websocket<TlsStream<Self>>>
    where
        Self: Sized,
    {
        let tls_stream = self.into_tls_stream()?;
        let websocket = tls_stream.into_websocket(endpoint);
        Ok(websocket)
    }
}
/// Connects to a websocket URL, producing a websocket over plain TCP or TLS.
#[cfg(any(feature = "rustls", feature = "openssl"))]
pub trait TryIntoTlsReadyWebsocket {
/// Parses `self` as a URL, connects to it and returns a websocket ready to handshake.
fn try_into_tls_ready_websocket(self) -> io::Result<Websocket<TlsReadyStream<TcpStream>>>
where
Self: Sized;
}
#[cfg(any(feature = "rustls", feature = "openssl"))]
impl<T> TryIntoTlsReadyWebsocket for T
where
    T: AsRef<str>,
{
    /// Parses `self` as a `ws://` or `wss://` URL and connects to it over TCP. For `wss` the
    /// connection is wrapped in TLS. The returned websocket will request the URL's path (plus
    /// query, if any) during the handshake.
    ///
    /// # Errors
    /// Returns an error, instead of panicking, when:
    /// * the URL is invalid, uses another scheme, resolves to no address, or has no host (TLS);
    /// * connecting or setting up TLS fails.
    fn try_into_tls_ready_websocket(self) -> io::Result<Websocket<TlsReadyStream<TcpStream>>>
    where
        Self: Sized,
    {
        let url = Url::parse(self.as_ref()).map_err(io::Error::other)?;

        // Validate the scheme before resolving or connecting, so unsupported URLs fail without
        // opening a socket.
        let (use_tls, default_port) = match url.scheme() {
            "ws" => (false, 80),
            "wss" => (true, 443),
            scheme => return Err(io::Error::other(format!("unrecognised url scheme: {scheme}"))),
        };

        // Resolution can legitimately produce an empty list; don't index blindly.
        let addr = url
            .socket_addrs(|| Some(default_port))?
            .into_iter()
            .next()
            .ok_or_else(|| io::Error::other(format!("no socket address resolved for {url}")))?;

        let endpoint = match url.query() {
            Some(query) => format!("{}?{}", url.path(), query),
            None => url.path().to_string(),
        };

        let stream = std::net::TcpStream::bind_and_connect(addr, None, None)?;
        let stream = TcpStream::new(stream, url.clone().try_into()?);
        let tls_ready_stream = if use_tls {
            let host = url
                .host_str()
                .ok_or_else(|| io::Error::other(format!("url has no host: {url}")))?;
            TlsReadyStream::Tls(TlsStream::new(stream, host)?)
        } else {
            TlsReadyStream::Plain(stream)
        };
        Ok(Websocket::new(tls_ready_stream, &endpoint))
    }
}