#![cfg_attr(docsrs, feature(doc_cfg))]
#![warn(missing_docs)]
#![deny(rustdoc::broken_intra_doc_links)]
#![doc(
html_logo_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
)]
#![doc(
html_favicon_url = "https://github.com/compio-rs/compio-logo/raw/refs/heads/master/generated/colored-bold.svg"
)]
use std::io::ErrorKind;
use compio_buf::IntoInner;
use compio_io::{AsyncRead, AsyncWrite, compat::SyncStream};
use tungstenite::{
Error as WsError, HandshakeError, Message, WebSocket,
client::IntoClientRequest,
handshake::server::{Callback, NoCallback},
protocol::{CloseFrame, Role, WebSocketConfig},
};
mod tls;
pub use tls::*;
pub use tungstenite;
/// Configuration used when creating a [`WebSocketStream`].
///
/// Bundles the optional tungstenite protocol settings ([`WebSocketConfig`])
/// with the buffer sizes handed to the [`SyncStream`] adapter that bridges
/// compio's async I/O to tungstenite's synchronous API.
///
/// Build one with [`Config::new`] (or [`Default`]) and the `with_*` /
/// [`Config::disable_nagle`] builder methods, or convert from a
/// [`WebSocketConfig`] / `Option<WebSocketConfig>`.
#[derive(Debug, Clone)]
pub struct Config {
    /// Protocol settings forwarded to tungstenite; `None` lets tungstenite
    /// use its own defaults.
    websocket: Option<WebSocketConfig>,
    /// Base size passed to `SyncStream::with_limits`.
    buffer_size_base: usize,
    /// Size limit passed to `SyncStream::with_limits`.
    buffer_size_limit: usize,
    /// Whether Nagle's algorithm should be disabled. Not read by the stream
    /// constructors or handshake functions in this module.
    disable_nagle: bool,
}
impl Config {
    /// Default base buffer size: 128 KiB.
    const DEFAULT_BUF_SIZE: usize = 128 * 1024;
    /// Default buffer size limit: 64 MiB.
    const DEFAULT_MAX_BUFFER: usize = 64 * 1024 * 1024;

    /// Creates a configuration with default settings.
    ///
    /// No [`WebSocketConfig`] is set, so tungstenite's defaults apply. The
    /// internal buffers use a base size of 128 KiB and a limit of 64 MiB, and
    /// Nagle's algorithm is not disabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            websocket: None,
            buffer_size_base: Self::DEFAULT_BUF_SIZE,
            buffer_size_limit: Self::DEFAULT_MAX_BUFFER,
            disable_nagle: false,
        }
    }

    /// Returns the tungstenite protocol configuration, or `None` when
    /// tungstenite's defaults are used.
    pub fn websocket_config(&self) -> Option<&WebSocketConfig> {
        self.websocket.as_ref()
    }

    /// Returns the base size of the internal [`SyncStream`] buffers.
    pub fn buffer_size_base(&self) -> usize {
        self.buffer_size_base
    }

    /// Returns the size limit of the internal [`SyncStream`] buffers.
    pub fn buffer_size_limit(&self) -> usize {
        self.buffer_size_limit
    }

    /// Sets the base size of the internal [`SyncStream`] buffers.
    #[must_use]
    pub fn with_buffer_size_base(mut self, size: usize) -> Self {
        self.buffer_size_base = size;
        self
    }

    /// Sets the size limit of the internal [`SyncStream`] buffers.
    #[must_use]
    pub fn with_buffer_size_limit(mut self, size: usize) -> Self {
        self.buffer_size_limit = size;
        self
    }

    /// Sets both the base size and the size limit of the internal
    /// [`SyncStream`] buffers in one call.
    #[must_use]
    pub fn with_buffer_sizes(mut self, base: usize, limit: usize) -> Self {
        self.buffer_size_base = base;
        self.buffer_size_limit = limit;
        self
    }

    /// Sets whether Nagle's algorithm should be disabled.
    ///
    /// The flag is only recorded here. The functions in this module that
    /// take an already-connected stream do not apply it themselves.
    #[must_use]
    pub fn disable_nagle(mut self, disable: bool) -> Self {
        self.disable_nagle = disable;
        self
    }
}
impl Default for Config {
fn default() -> Self {
Self::new()
}
}
impl From<WebSocketConfig> for Config {
fn from(config: WebSocketConfig) -> Self {
Self {
websocket: Some(config),
..Default::default()
}
}
}
impl From<Option<WebSocketConfig>> for Config {
fn from(config: Option<WebSocketConfig>) -> Self {
Self {
websocket: config,
..Default::default()
}
}
}
/// A WebSocket connection running over a compio async stream `S`.
///
/// Internally this wraps a synchronous [`tungstenite::WebSocket`] whose I/O
/// goes through a [`SyncStream`] buffer. The async methods drive that buffer
/// whenever tungstenite reports [`std::io::ErrorKind::WouldBlock`]: they
/// fill it from `S` or flush it to `S`.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite socket operating on the buffered adapter around `S`.
    inner: WebSocket<SyncStream<S>>,
}
impl<S> WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite,
{
    /// Wraps a stream on which the WebSocket handshake has already been
    /// completed.
    ///
    /// `role` states whether this side acts as client or server (see
    /// [`Role`]). No I/O is performed here.
    pub async fn from_raw_socket(stream: S, role: Role, config: impl Into<Config>) -> Self {
        let config = config.into();
        let sync_stream =
            SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
        WebSocketStream {
            inner: WebSocket::from_raw_socket(sync_stream, role, config.websocket),
        }
    }

    /// Wraps a stream on which the handshake has already been completed,
    /// where some bytes following the handshake were already read.
    ///
    /// `part` holds those bytes; tungstenite processes them before reading
    /// from `stream`. No I/O is performed here.
    pub async fn from_partially_read(
        stream: S,
        part: Vec<u8>,
        role: Role,
        config: impl Into<Config>,
    ) -> Self {
        let config = config.into();
        let sync_stream =
            SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
        WebSocketStream {
            inner: WebSocket::from_partially_read(sync_stream, part, role, config.websocket),
        }
    }

    /// Sends `message` and flushes it to the underlying stream.
    ///
    /// # Errors
    ///
    /// Returns any protocol error reported by tungstenite, or any I/O error
    /// from the underlying stream while flushing.
    pub async fn send(&mut self, message: Message) -> Result<(), WsError> {
        match self.inner.send(message) {
            Ok(()) => {}
            // tungstenite keeps the frame queued when the write or flush cannot
            // complete because the `SyncStream` buffer is full. Draining that
            // buffer via `flush` below finishes the send, so this is not an error.
            Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {}
            Err(e) => return Err(e),
        }
        self.flush().await
    }

    /// Reads the next message.
    ///
    /// Waits for more data from the underlying stream whenever tungstenite
    /// needs it. Writes that tungstenite queues while reading, such as
    /// automatic pong replies, are flushed before the message is returned.
    ///
    /// # Errors
    ///
    /// Returns the error reported by tungstenite (for example
    /// [`tungstenite::Error::ConnectionClosed`] once the connection is
    /// closed) or an I/O error from the underlying stream. On a tungstenite
    /// error, a best-effort flush is attempted first.
    pub async fn read(&mut self) -> Result<Message, WsError> {
        loop {
            match self.inner.read() {
                Ok(msg) => {
                    self.flush().await?;
                    return Ok(msg);
                }
                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
                    // tungstenite ran out of buffered input: pull more from `S`.
                    self.inner
                        .get_mut()
                        .fill_read_buf()
                        .await
                        .map_err(WsError::Io)?;
                }
                Err(e) => {
                    // Best effort: push out anything queued (e.g. a close reply)
                    // before surfacing the original error.
                    let _ = self.flush().await;
                    return Err(e);
                }
            }
        }
    }

    /// Flushes all pending frames to the underlying stream.
    ///
    /// A [`tungstenite::Error::ConnectionClosed`] from tungstenite is treated
    /// as success; the internal buffer is still flushed afterwards.
    ///
    /// # Errors
    ///
    /// Returns any other tungstenite error, or an I/O error from the
    /// underlying stream.
    pub async fn flush(&mut self) -> Result<(), WsError> {
        loop {
            match self.inner.flush() {
                Ok(()) => break,
                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
                    // The `SyncStream` write buffer is full: drain it, then retry.
                    self.inner
                        .get_mut()
                        .flush_write_buf()
                        .await
                        .map_err(WsError::Io)?;
                }
                Err(WsError::ConnectionClosed) => break,
                Err(e) => return Err(e),
            }
        }
        // Push whatever tungstenite wrote into the buffer out to `S`.
        self.inner
            .get_mut()
            .flush_write_buf()
            .await
            .map_err(WsError::Io)?;
        Ok(())
    }

    /// Queues a close frame (with optional status code and reason) and
    /// flushes it.
    ///
    /// Returns once the close frame has been flushed. This method does not
    /// wait for the peer's close reply; keep calling
    /// [`WebSocketStream::read`] to observe it.
    ///
    /// # Errors
    ///
    /// Returns any tungstenite error other than
    /// [`tungstenite::Error::ConnectionClosed`], or an I/O error from the
    /// underlying stream.
    pub async fn close(&mut self, close_frame: Option<CloseFrame>) -> Result<(), WsError> {
        loop {
            match self.inner.close(close_frame.clone()) {
                Ok(()) => break,
                Err(WsError::Io(ref e)) if e.kind() == ErrorKind::WouldBlock => {
                    let sync_stream = self.inner.get_mut();
                    let flushed = sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
                    // NOTE(review): when nothing could be flushed, wait for
                    // incoming data instead — presumably so the loop makes
                    // progress rather than spinning; confirm the intended trigger.
                    if flushed == 0 {
                        sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
                    }
                }
                Err(WsError::ConnectionClosed) => break,
                Err(e) => return Err(e),
            }
        }
        self.flush().await
    }

    /// Returns a shared reference to the underlying stream.
    pub fn get_ref(&self) -> &S {
        self.inner.get_ref().get_ref()
    }

    /// Returns a mutable reference to the underlying stream.
    ///
    /// Data may still be held in the internal buffer. Reading from or writing
    /// to the stream directly can interleave with that buffered data.
    pub fn get_mut(&mut self) -> &mut S {
        self.inner.get_mut().get_mut()
    }
}
impl<S> IntoInner for WebSocketStream<S> {
    type Inner = WebSocket<SyncStream<S>>;

    /// Consumes the wrapper and hands back the tungstenite socket, together
    /// with its buffered [`SyncStream`] adapter.
    fn into_inner(self) -> Self::Inner {
        let Self { inner } = self;
        inner
    }
}
/// Accepts a WebSocket connection on `stream`.
///
/// Performs the server side of the opening handshake with the default
/// [`Config`] and no header callback.
///
/// # Errors
///
/// Returns the handshake failure reported by tungstenite, or an I/O error
/// from `stream`.
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite,
{
    accept_hdr_async(stream, NoCallback).await
}
/// Accepts a WebSocket connection on `stream` using a custom configuration.
///
/// Performs the server side of the opening handshake without a header
/// callback.
///
/// # Errors
///
/// Returns the handshake failure reported by tungstenite, or an I/O error
/// from `stream`.
pub async fn accept_async_with_config<S>(
    stream: S,
    config: impl Into<Config>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite,
{
    accept_hdr_with_config_async(stream, NoCallback, config).await
}
/// Accepts a WebSocket connection on `stream`, passing the client's request
/// to `callback`.
///
/// See [`Callback`] for what the callback can inspect or change. The default
/// [`Config`] is used.
///
/// # Errors
///
/// Returns the handshake failure reported by tungstenite, or an I/O error
/// from `stream`.
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite,
    C: Callback,
{
    accept_hdr_with_config_async(stream, callback, None).await
}
pub async fn accept_hdr_with_config_async<S, C>(
stream: S,
callback: C,
config: impl Into<Config>,
) -> Result<WebSocketStream<S>, WsError>
where
S: AsyncRead + AsyncWrite,
C: Callback,
{
let config = config.into();
let sync_stream =
SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
let mut handshake_result =
tungstenite::accept_hdr_with_config(sync_stream, callback, config.websocket);
loop {
match handshake_result {
Ok(mut websocket) => {
websocket
.get_mut()
.flush_write_buf()
.await
.map_err(WsError::Io)?;
return Ok(WebSocketStream { inner: websocket });
}
Err(HandshakeError::Interrupted(mut mid_handshake)) => {
let sync_stream = mid_handshake.get_mut().get_mut();
sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
handshake_result = mid_handshake.handshake();
}
Err(HandshakeError::Failure(error)) => {
return Err(error);
}
}
}
}
/// Performs the client side of the WebSocket handshake over an
/// already-connected `stream`.
///
/// `request` is anything convertible into a handshake request (see
/// [`IntoClientRequest`]). The default [`Config`] is used. On success, returns
/// the stream together with the server's handshake response.
///
/// # Errors
///
/// Returns the handshake failure reported by tungstenite, or an I/O error
/// from `stream`.
pub async fn client_async<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
where
    R: IntoClientRequest,
    S: AsyncRead + AsyncWrite,
{
    client_async_with_config(request, stream, None).await
}
pub async fn client_async_with_config<R, S>(
request: R,
stream: S,
config: impl Into<Config>,
) -> Result<(WebSocketStream<S>, tungstenite::handshake::client::Response), WsError>
where
R: IntoClientRequest,
S: AsyncRead + AsyncWrite,
{
let config = config.into();
let sync_stream =
SyncStream::with_limits(config.buffer_size_base, config.buffer_size_limit, stream);
let mut handshake_result =
tungstenite::client::client_with_config(request, sync_stream, config.websocket);
loop {
match handshake_result {
Ok((mut websocket, response)) => {
websocket
.get_mut()
.flush_write_buf()
.await
.map_err(WsError::Io)?;
return Ok((WebSocketStream { inner: websocket }, response));
}
Err(HandshakeError::Interrupted(mut mid_handshake)) => {
let sync_stream = mid_handshake.get_mut().get_mut();
sync_stream.flush_write_buf().await.map_err(WsError::Io)?;
sync_stream.fill_read_buf().await.map_err(WsError::Io)?;
handshake_result = mid_handshake.handshake();
}
Err(HandshakeError::Failure(error)) => {
return Err(error);
}
}
}
}