//! Asynchronous WebSocket streams built on top of [`tungstenite`].
//!
//! The central type is [`WebSocketStream`], which wraps a `tungstenite`
//! WebSocket around any stream implementing the `futures_io`
//! `AsyncRead`/`AsyncWrite` traits. With the `handshake` feature enabled,
//! `client_async*` and `accept*_async*` functions perform the client or
//! server handshake and return such a stream.
#![deny(
missing_docs,
unused_must_use,
unused_mut,
unused_imports,
unused_import_braces
)]
pub use tungstenite;
mod compat;
mod handshake;
#[cfg(any(
feature = "async-tls",
feature = "async-native-tls",
feature = "smol-native-tls",
feature = "futures-rustls-manual-roots",
feature = "futures-rustls-webpki-roots",
feature = "futures-rustls-native-certs",
feature = "futures-rustls-platform-verifier",
feature = "tokio-native-tls",
feature = "tokio-rustls-manual-roots",
feature = "tokio-rustls-native-certs",
feature = "tokio-rustls-platform-verifier",
feature = "tokio-rustls-webpki-roots",
feature = "tokio-openssl",
))]
pub mod stream;
use std::{
io::{Read, Write},
pin::Pin,
sync::{Arc, Mutex, MutexGuard},
task::{ready, Context, Poll},
};
use compat::{cvt, AllowStd, ContextWaker};
use futures_core::stream::{FusedStream, Stream};
use futures_io::{AsyncRead, AsyncWrite};
use log::*;
#[cfg(feature = "handshake")]
use tungstenite::{
client::IntoClientRequest,
handshake::{
client::{ClientHandshake, Response},
server::{Callback, NoCallback},
HandshakeError,
},
};
use tungstenite::{
error::Error as WsError,
protocol::{Message, Role, WebSocket, WebSocketConfig},
};
#[cfg(feature = "async-std-runtime")]
#[deprecated = "async-std is unmaintained upstream. Please use the smol runtime instead."]
pub mod async_std;
#[cfg(feature = "async-tls")]
pub mod async_tls;
#[cfg(feature = "gio-runtime")]
pub mod gio;
#[cfg(feature = "smol-runtime")]
pub mod smol;
#[cfg(feature = "tokio-runtime")]
pub mod tokio;
pub mod bytes;
pub use bytes::ByteReader;
pub use bytes::ByteWriter;
use tungstenite::protocol::CloseFrame;
/// Performs the client side of the WebSocket handshake for `request` over
/// `stream`.
///
/// This is a thin wrapper around [`client_async_with_config`] that passes
/// `None` as the configuration.
///
/// # Errors
///
/// Propagates whatever error [`client_async_with_config`] returns.
#[cfg(feature = "handshake")]
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    // No explicit configuration is supplied by this entry point.
    let connecting = client_async_with_config(request, stream, None);
    connecting.await
}
/// Performs the client side of the WebSocket handshake for `request` over
/// `stream`, forwarding `config` to the handshake.
///
/// The request is converted with `IntoClientRequest::into_client_request`,
/// a `ClientHandshake` is started on the wrapped stream and then driven by
/// `handshake::client_handshake`.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as is. Any other
/// handshake error is returned as `WsError::Io` of kind
/// `std::io::ErrorKind::Other` carrying the error's display text.
#[cfg(feature = "handshake")]
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |allow_std| {
        let request = request.into_client_request()?;
        ClientHandshake::start(allow_std, request, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(connected) => Ok(connected),
        Err(HandshakeError::Failure(cause)) => Err(cause),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
/// Accepts a new WebSocket connection on `stream`.
///
/// This is a thin wrapper around [`accept_hdr_async`] that supplies
/// `NoCallback` as the header callback.
///
/// # Errors
///
/// Propagates whatever error [`accept_hdr_async`] returns.
#[cfg(feature = "handshake")]
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let accepting = accept_hdr_async(stream, NoCallback);
    accepting.await
}
/// Accepts a new WebSocket connection on `stream`, forwarding `config` to
/// the handshake.
///
/// This is a thin wrapper around [`accept_hdr_async_with_config`] that
/// supplies `NoCallback` as the header callback.
///
/// # Errors
///
/// Propagates whatever error [`accept_hdr_async_with_config`] returns.
#[cfg(feature = "handshake")]
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let accepting = accept_hdr_async_with_config(stream, NoCallback, config);
    accepting.await
}
/// Accepts a new WebSocket connection on `stream`, handing `callback` to the
/// server handshake.
///
/// This is a thin wrapper around [`accept_hdr_async_with_config`] that passes
/// `None` as the configuration.
///
/// # Errors
///
/// Propagates whatever error [`accept_hdr_async_with_config`] returns.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    // No explicit configuration is supplied by this entry point.
    let accepting = accept_hdr_async_with_config(stream, callback, None);
    accepting.await
}
/// Accepts a new WebSocket connection on `stream`, handing `callback` and
/// `config` to the server handshake.
///
/// The handshake itself is `tungstenite::accept_hdr_with_config`, driven by
/// `handshake::server_handshake` on the wrapped stream.
///
/// # Errors
///
/// A `HandshakeError::Failure` is unwrapped and returned as is. Any other
/// handshake error is returned as `WsError::Io` of kind
/// `std::io::ErrorKind::Other` carrying the error's display text.
#[cfg(feature = "handshake")]
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |allow_std| {
        tungstenite::accept_hdr_with_config(allow_std, callback, config)
    })
    .await;

    match outcome {
        Ok(accepted) => Ok(accepted),
        Err(HandshakeError::Failure(cause)) => Err(cause),
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
/// A WebSocket connection layered over an asynchronous byte stream `S`.
///
/// Wraps a `tungstenite` [`WebSocket`] whose I/O goes through the crate's
/// `AllowStd` adapter, together with a few flags describing the state of
/// this wrapper.
#[derive(Debug)]
pub struct WebSocketStream<S> {
/// The underlying `tungstenite` socket.
inner: WebSocket<AllowStd<S>>,
/// Set once a close attempt hit `WouldBlock`; later `poll_close` calls then
/// flush instead of calling close again.
#[cfg(feature = "futures-03-sink")]
closing: bool,
/// Set after a read returned an error; from then on `poll_next` yields
/// `None` and `is_terminated` reports `true`.
ended: bool,
/// `true` initially; cleared when `start_send` hits `WouldBlock` and set
/// again once a flush attempt has completed.
ready: bool,
}
impl<S> WebSocketStream<S> {
pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_raw_socket(allow_std, role, config)
})
.await
}
pub async fn from_partially_read(
stream: S,
part: Vec<u8>,
role: Role,
config: Option<WebSocketConfig>,
) -> Self
where
S: AsyncRead + AsyncWrite + Unpin,
{
handshake::without_handshake(stream, move |allow_std| {
WebSocket::from_partially_read(allow_std, part, role, config)
})
.await
}
pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
Self {
inner: ws,
#[cfg(feature = "futures-03-sink")]
closing: false,
ended: false,
ready: true,
}
}
fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
where
F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
AllowStd<S>: Read + Write,
{
#[cfg(feature = "verbose-logging")]
trace!("{}:{} WebSocketStream.with_context", file!(), line!());
if let Some((kind, ctx)) = ctx {
self.inner.get_mut().set_waker(kind, ctx.waker());
}
f(&mut self.inner)
}
pub fn into_inner(self) -> S {
self.inner.into_inner().into_inner()
}
pub fn get_ref(&self) -> &S
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.inner.get_ref().get_ref()
}
pub fn get_mut(&mut self) -> &mut S
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.inner.get_mut().get_mut()
}
pub fn get_config(&self) -> &WebSocketConfig {
self.inner.get_config()
}
pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.send(Message::Close(msg)).await
}
pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
let shared = Arc::new(Shared(Mutex::new(self)));
let sender = WebSocketSender {
shared: shared.clone(),
};
let receiver = WebSocketReceiver { shared };
(sender, receiver)
}
pub fn reunite(
sender: WebSocketSender<S>,
receiver: WebSocketReceiver<S>,
) -> Result<Self, (WebSocketSender<S>, WebSocketReceiver<S>)> {
if sender.is_pair_of(&receiver) {
drop(receiver);
let stream = Arc::try_unwrap(sender.shared)
.ok()
.expect("reunite the stream")
.into_inner();
Ok(stream)
} else {
Err((sender, receiver))
}
}
}
impl<S> WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Polls the inner socket for the next message.
    ///
    /// Once a read has returned an error the stream is marked as ended and
    /// every later call yields `None`. `AlreadyClosed` and `ConnectionClosed`
    /// end the stream without surfacing an error item.
    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Message, WsError>>> {
        #[cfg(feature = "verbose-logging")]
        trace!("{}:{} WebSocketStream.poll_next", file!(), line!());

        // A previous read already failed: nothing more will be produced.
        if self.ended {
            return Poll::Ready(None);
        }

        let outcome = self.with_context(Some((ContextWaker::Read, cx)), |socket| {
            #[cfg(feature = "verbose-logging")]
            trace!(
                "{}:{} WebSocketStream.with_context poll_next -> read()",
                file!(),
                line!()
            );
            cvt(socket.read())
        });

        let failure = match outcome {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Ok(message)) => return Poll::Ready(Some(Ok(message))),
            Poll::Ready(Err(failure)) => failure,
        };

        // Any read error terminates the stream, whether or not it is reported.
        self.ended = true;
        match failure {
            WsError::AlreadyClosed | WsError::ConnectionClosed => Poll::Ready(None),
            other => Poll::Ready(Some(Err(other))),
        }
    }

    /// Reports whether another message may be queued with `start_send`.
    ///
    /// When a previous `start_send` hit `WouldBlock`, this tries to flush the
    /// socket and becomes ready again as soon as that flush attempt completes,
    /// returning the flush result.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
        if !self.ready {
            let flushed = self.with_context(Some((ContextWaker::Write, cx)), |socket| {
                cvt(socket.flush())
            });
            return match flushed {
                Poll::Ready(result) => {
                    self.ready = true;
                    Poll::Ready(result)
                }
                Poll::Pending => Poll::Pending,
            };
        }
        Poll::Ready(Ok(()))
    }

    /// Hands `item` to the inner socket's `write`.
    ///
    /// A `WouldBlock` I/O error is not reported to the caller: it only clears
    /// the `ready` flag so that `poll_ready` flushes before the next message.
    /// Any other error is logged at debug level and returned.
    fn start_send(&mut self, item: Message) -> Result<(), WsError> {
        let written = self.with_context(None, |socket| socket.write(item));
        let would_block = matches!(
            &written,
            Err(WsError::Io(io_err)) if io_err.kind() == std::io::ErrorKind::WouldBlock
        );

        self.ready = !would_block;
        if would_block {
            return Ok(());
        }
        if let Err(e) = &written {
            debug!("websocket start_send error: {}", e);
        }
        written
    }

    /// Flushes the inner socket.
    ///
    /// When the flush attempt completes the `ready` flag is set again, and a
    /// `ConnectionClosed` result is reported as success.
    fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
        let flushed = self.with_context(Some((ContextWaker::Write, cx)), |socket| {
            cvt(socket.flush())
        });
        match flushed {
            Poll::Pending => Poll::Pending,
            Poll::Ready(result) => {
                self.ready = true;
                if matches!(result, Err(WsError::ConnectionClosed)) {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Ready(result)
                }
            }
        }
    }

    /// Closes the inner socket.
    ///
    /// The first call asks the socket to close. If that reports `WouldBlock`,
    /// the `closing` flag is set and later calls flush instead. A
    /// `ConnectionClosed` result is reported as success; other errors are
    /// logged at debug level and returned.
    #[cfg(feature = "futures-03-sink")]
    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
        self.ready = true;

        let flush_only = self.closing;
        let res = self.with_context(Some((ContextWaker::Write, cx)), |socket| {
            if flush_only {
                socket.flush()
            } else {
                socket.close(None)
            }
        });

        match res {
            Ok(()) | Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
                trace!("WouldBlock");
                self.closing = true;
                Poll::Pending
            }
            Err(err) => {
                debug!("websocket close error: {}", err);
                Poll::Ready(Err(err))
            }
        }
    }
}
/// Yields incoming messages by delegating to the inherent `poll_next`.
impl<S> Stream for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<Message, WsError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let stream = Pin::into_inner(self);
        stream.poll_next(cx)
    }
}
/// The stream counts as terminated once its `ended` flag has been set.
impl<S> FusedStream for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    fn is_terminated(&self) -> bool {
        let Self { ended, .. } = self;
        *ended
    }
}
/// `Sink` implementation that forwards every operation to the inherent
/// method of the same name.
#[cfg(feature = "futures-03-sink")]
impl<S> futures_util::Sink<Message> for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let stream = Pin::into_inner(self);
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let stream = Pin::into_inner(self);
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let stream = Pin::into_inner(self);
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let stream = Pin::into_inner(self);
        stream.poll_close(cx)
    }
}
/// Byte-oriented sender used when the `futures-03-sink` feature is disabled.
#[cfg(not(feature = "futures-03-sink"))]
impl<S> bytes::private::SealedSender for WebSocketStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Queues the whole of `buf` as one binary message once the stream is
    /// ready, reporting `buf.len()` bytes as written.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, WsError>> {
        let stream = Pin::into_inner(self);

        // Do not queue anything until the stream can accept a message.
        match stream.poll_ready(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
            Poll::Ready(Ok(())) => {}
        }

        let payload = buf.to_vec();
        let written = payload.len();
        Poll::Ready(stream.start_send(Message::binary(payload)).map(|()| written))
    }

    /// Flushes the stream via the inherent `poll_flush`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
        let stream = Pin::into_inner(self);
        stream.poll_flush(cx)
    }

    /// Sends the message held in `msg`, if any is left, and then flushes.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        msg: &mut Option<Message>,
    ) -> Poll<Result<(), WsError>> {
        let stream = Pin::into_inner(self);
        send_helper(stream, msg, cx)
    }
}
impl<S> WebSocketStream<S> {
    /// Sends `msg` and resolves once the stream has been flushed.
    ///
    /// # Errors
    ///
    /// Returns the error reported while waiting for readiness, queueing the
    /// message or flushing.
    pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let pending = Send {
            ws: self,
            msg: Some(msg),
        };
        pending.await
    }
}
/// Future behind the `send` methods: delivers one message through `ws` and
/// then flushes.
struct Send<W> {
/// Handle to the stream the message is sent on.
ws: W,
/// The message to send; taken (left as `None`) once it has been queued.
msg: Option<Message>,
}
/// Drives one send operation on `ws`.
///
/// While `msg` still holds a message, waits for the stream to become ready
/// and queues it, leaving `msg` as `None`. In every case the call ends by
/// polling a flush, so repeated calls after the message has been queued only
/// flush.
fn send_helper<S>(
    ws: &mut WebSocketStream<S>,
    msg: &mut Option<Message>,
    cx: &mut Context<'_>,
) -> Poll<Result<(), WsError>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Message already queued on an earlier poll: only the flush remains.
    if msg.is_none() {
        return ws.poll_flush(cx);
    }

    // The message is left in place while the stream is not ready, so that
    // the next poll can try again.
    ready!(ws.poll_ready(cx))?;
    let queued = msg.take().expect("unreachable");
    ws.start_send(queued)?;
    ws.poll_flush(cx)
}
/// Sending through an exclusive borrow of the stream.
impl<S> std::future::Future for Send<&mut WebSocketStream<S>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<(), WsError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        send_helper(&mut *this.ws, &mut this.msg, cx)
    }
}
/// Sending through the shared handle: the lock is taken for the duration of
/// each individual poll.
impl<S> std::future::Future for Send<&Shared<S>>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<(), WsError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = Pin::into_inner(self);
        let mut guard = this.ws.lock();
        send_helper(&mut *guard, &mut this.msg, cx)
    }
}
/// The sending half of a [`WebSocketStream`], as produced by
/// [`WebSocketStream::split`].
#[derive(Debug)]
pub struct WebSocketSender<S> {
/// The stream shared with the matching [`WebSocketReceiver`].
shared: Arc<Shared<S>>,
}
impl<S> WebSocketSender<S> {
pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
Send {
ws: &*self.shared,
msg: Some(msg),
}
.await
}
pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
where
S: AsyncRead + AsyncWrite + Unpin,
{
self.send(Message::Close(msg)).await
}
pub fn is_pair_of(&self, other: &WebSocketReceiver<S>) -> bool {
Arc::ptr_eq(&self.shared, &other.shared)
}
}
/// `Sink` implementation that locks the shared stream for each call and
/// forwards to the stream's method of the same name.
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketSender<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        let mut stream = self.shared.lock();
        stream.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let mut stream = self.shared.lock();
        stream.poll_close(cx)
    }
}
/// Byte-oriented sender used when the `futures-03-sink` feature is disabled;
/// every call locks the shared stream for its duration.
#[cfg(not(feature = "futures-03-sink"))]
impl<S> bytes::private::SealedSender for WebSocketSender<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Queues the whole of `buf` as one binary message once the stream is
    /// ready, reporting `buf.len()` bytes as written.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, WsError>> {
        let sender = Pin::into_inner(self);
        let mut stream = sender.shared.lock();

        // Do not queue anything until the stream can accept a message.
        ready!(stream.poll_ready(cx))?;

        let payload = buf.to_vec();
        let written = payload.len();
        stream.start_send(Message::binary(payload))?;
        Poll::Ready(Ok(written))
    }

    /// Flushes the shared stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
        let mut stream = self.shared.lock();
        stream.poll_flush(cx)
    }

    /// Sends the message held in `msg`, if any is left, and then flushes.
    fn poll_close(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        msg: &mut Option<Message>,
    ) -> Poll<Result<(), WsError>> {
        let sender = Pin::into_inner(self);
        let mut stream = sender.shared.lock();
        send_helper(&mut *stream, msg, cx)
    }
}
/// The receiving half of a [`WebSocketStream`], as produced by
/// [`WebSocketStream::split`].
#[derive(Debug)]
pub struct WebSocketReceiver<S> {
/// The stream shared with the matching [`WebSocketSender`].
shared: Arc<Shared<S>>,
}
impl<S> WebSocketReceiver<S> {
    /// Returns `true` if `other` shares the same underlying stream as this
    /// receiver.
    pub fn is_pair_of(&self, other: &WebSocketSender<S>) -> bool {
        let (mine, theirs) = (&self.shared, &other.shared);
        Arc::ptr_eq(mine, theirs)
    }
}
/// Yields incoming messages by locking the shared stream and polling it.
impl<S> Stream for WebSocketReceiver<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<Message, WsError>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut stream = self.shared.lock();
        stream.poll_next(cx)
    }
}
/// The receiver counts as terminated once the shared stream's `ended` flag
/// has been set.
impl<S> FusedStream for WebSocketReceiver<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    fn is_terminated(&self) -> bool {
        let stream = self.shared.lock();
        stream.ended
    }
}
/// A [`WebSocketStream`] behind a mutex, so that the sender and receiver
/// halves can both reach it.
#[derive(Debug)]
struct Shared<S>(Mutex<WebSocketStream<S>>);
impl<S> Shared<S> {
fn lock(&self) -> MutexGuard<'_, WebSocketStream<S>> {
self.0.lock().expect("lock shared stream")
}
fn into_inner(self) -> WebSocketStream<S> {
self.0.into_inner().expect("get shared stream")
}
}
/// Returns the host of `request`'s URI as an owned string.
///
/// A host wrapped in square brackets, such as the IPv6 literal `[::1]`, is
/// returned without the brackets. Any other host is returned unchanged.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` when the URI has
/// no host.
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    request
        .uri()
        .host()
        .map(|host| {
            // Strip the brackets only when both are present. Unlike slicing
            // by byte index, this cannot panic on an unexpected host string.
            host.strip_prefix('[')
                .and_then(|rest| rest.strip_suffix(']'))
                .unwrap_or(host)
                .to_owned()
        })
        .ok_or(tungstenite::Error::Url(
            tungstenite::error::UrlError::NoHostName,
        ))
}
/// Returns the port to connect to for `request`.
///
/// An explicit port in the URI takes precedence. Otherwise the port is
/// derived from the scheme: 443 for `wss` and 80 for `ws`.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::UnsupportedUrlScheme)` when the
/// URI has no explicit port and its scheme is neither `ws` nor `wss`.
#[cfg(any(
    feature = "async-std-runtime",
    feature = "smol-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
#[cfg(test)]
mod tests {
    /// The brackets around an IPv6 literal must not be part of the domain.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "smol-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let request = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&request).unwrap();
        assert_eq!(host, "::1");
    }

    /// URIs with an unterminated bracketed host are rejected when the
    /// request is built.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        for uri in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(uri.into_client_request().is_err());
        }
    }
}