#![cfg(feature = "std")]
use futures_util::{AsyncRead, AsyncWrite, Future as _, future};
use core::{
cmp, mem,
pin::Pin,
task::{Context, Poll},
};
use std::io;
/// Parameters of [`websocket_client_handshake`].
#[derive(Debug)]
pub struct Config<'a, T> {
    /// Socket the WebSocket handshake is performed over and that the resulting
    /// [`Connection`] then reads from and writes to (presumably already connected).
    pub tcp_socket: T,
    /// Host name handed to the `soketto` handshake client (presumably used for the HTTP
    /// `Host` header — TODO confirm).
    pub host: &'a str,
    /// Resource handed to the `soketto` handshake client as the requested URL (presumably the
    /// path part, e.g. `/ws` — TODO confirm).
    pub url: &'a str,
}
pub async fn websocket_client_handshake<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>(
config: Config<'_, T>,
) -> Result<Connection<T>, io::Error> {
let mut client = soketto::handshake::Client::new(config.tcp_socket, config.host, config.url);
let (sender, receiver) = match client.handshake().await {
Ok(soketto::handshake::ServerResponse::Accepted { .. }) => client.into_builder().finish(),
Ok(soketto::handshake::ServerResponse::Redirect { .. }) => {
return Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
"Redirections not implemented",
));
}
Ok(soketto::handshake::ServerResponse::Rejected { status_code }) => {
return Err(io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("Status code {status_code}"),
));
}
Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err)),
};
Ok(Connection {
sender: Write::Idle(sender),
receiver: Read::Idle(receiver, Vec::with_capacity(1024), 0),
})
}
/// WebSocket connection exposed as a byte stream through [`AsyncRead`] and [`AsyncWrite`].
///
/// Each successful `poll_write` queues its bytes as one binary message; on the read side, the
/// payloads of received messages are handed out back to back, without message boundaries.
pub struct Connection<T> {
    /// Sending half, together with the write/flush/close operation currently in flight, if any.
sender: Write<T>,
    /// Receiving half, together with received bytes not yet returned by `poll_read`.
receiver: Read<T>,
}
/// State machine of the receiving half of a [`Connection`].
enum Read<T> {
    /// No receive in flight. Holds the receiver, the payload of the last received message, and
    /// the offset of the first byte of that payload not yet returned by `poll_read`.
Idle(soketto::connection::Receiver<T>, Vec<u8>, usize),
    /// Receiving failed; `poll_read` reports this error again on every subsequent call.
Error(soketto::connection::Error),
    /// A message is being received; resolves to the receiver and the filled buffer.
InProgress(future::BoxFuture<'static, Result<ReadOutcome<T>, soketto::connection::Error>>),
    /// Transient placeholder while the state is moved out with `mem::replace`; never observed
    /// between calls.
Poisoned,
}
/// Output of a successful receive started by `poll_read`.
struct ReadOutcome<T> {
    /// The receiver, handed back so that it can be stored in `Read::Idle` again.
socket: soketto::connection::Receiver<T>,
    /// Buffer into which the received message payload was written.
buffer: Vec<u8>,
}
/// Boxed operation that hands the sender back once it has finished (used for writes and
/// flushes).
type SenderFuture<T> = future::BoxFuture<
    'static,
    Result<soketto::connection::Sender<T>, soketto::connection::Error>,
>;

/// State machine of the sending half of a [`Connection`].
enum Write<T> {
    /// No operation in flight; the sender can accept a new write, flush or close.
    Idle(soketto::connection::Sender<T>),
    /// A binary message is being sent.
    Writing(SenderFuture<T>),
    /// The sender is being flushed.
    Flushing(SenderFuture<T>),
    /// The sender is being closed.
    Closing(future::BoxFuture<'static, Result<(), soketto::connection::Error>>),
    /// Closing completed successfully.
    Closed,
    /// An operation failed; the error is reported again on every subsequent call.
    Error(soketto::connection::Error),
    /// Transient placeholder while the state is moved out with `mem::replace`.
    Poisoned,
}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncRead for Connection<T> {
    /// Copies bytes from received WebSocket messages into `out_buf`.
    ///
    /// A received message's payload is kept in a buffer and handed out across as many calls as
    /// needed. Once it is exhausted, a new message is received. Message boundaries are not
    /// preserved: the stream is the concatenation of all payloads. A message with an empty
    /// payload yields no bytes, and the next message is awaited.
    ///
    /// An empty `out_buf` immediately yields `Ok(0)`, as the `Read` contract requires. Once
    /// receiving has failed, every later call returns an error converted from that failure.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        out_buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // A zero-length buffer is a legal request and must report `Ok(0)` instead of panicking;
        // it also must not start (or consume) a receive.
        if out_buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        loop {
            match mem::replace(&mut self.receiver, Read::Poisoned) {
                // Buffered bytes remain: hand out as many as fit.
                Read::Idle(socket, pending, pending_pos) if pending_pos < pending.len() => {
                    let to_copy = cmp::min(out_buf.len(), pending.len() - pending_pos);
                    debug_assert_ne!(to_copy, 0);
                    out_buf[..to_copy].copy_from_slice(&pending[pending_pos..][..to_copy]);
                    self.receiver = Read::Idle(socket, pending, pending_pos + to_copy);
                    return Poll::Ready(Ok(to_copy));
                }
                // Buffer exhausted: reuse its allocation for the next message.
                Read::Idle(mut socket, mut buffer, _) => {
                    buffer.clear();
                    self.receiver = Read::InProgress(Box::pin(async move {
                        socket.receive_data(&mut buffer).await?;
                        Ok(ReadOutcome { socket, buffer })
                    }));
                }
                Read::InProgress(mut future) => match Pin::new(&mut future).poll(cx) {
                    Poll::Pending => {
                        self.receiver = Read::InProgress(future);
                        return Poll::Pending;
                    }
                    Poll::Ready(Ok(ReadOutcome { socket, buffer })) => {
                        self.receiver = Read::Idle(socket, buffer, 0);
                    }
                    Poll::Ready(Err(err)) => {
                        self.receiver = Read::Error(err);
                    }
                },
                // Keep the error so that every later call reports it too.
                Read::Error(err) => {
                    let out_err = convert_err(&err);
                    self.receiver = Read::Error(err);
                    return Poll::Ready(Err(out_err));
                }
                Read::Poisoned => unreachable!(),
            }
        }
    }
}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> AsyncWrite for Connection<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
match mem::replace(&mut self.sender, Write::Poisoned) {
Write::Idle(mut socket) => {
let len = buf.len();
let buf = buf.to_vec();
self.sender = Write::Writing(Box::pin(async move {
socket.send_binary_mut(buf).await?;
Ok(socket)
}));
return Poll::Ready(Ok(len));
}
Write::Flushing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Flushing(future);
return Poll::Pending;
}
Poll::Ready(Ok(socket)) => {
self.sender = Write::Idle(socket);
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Writing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Writing(future);
return Poll::Pending;
}
Poll::Ready(Ok(socket)) => {
self.sender = Write::Idle(socket);
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Closing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Closing(future);
return Poll::Pending;
}
Poll::Ready(Ok(())) => {
self.sender = Write::Closed;
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Closed => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"called poll_write after poll_close has succeeded",
)));
}
Write::Error(err) => {
let out_err = convert_err(&err);
self.sender = Write::Error(err);
return Poll::Ready(Err(out_err));
}
Write::Poisoned => unreachable!(),
}
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
loop {
match mem::replace(&mut self.sender, Write::Poisoned) {
Write::Idle(mut socket) => {
self.sender = Write::Flushing(Box::pin(async move {
socket.flush().await?;
Ok(socket)
}));
}
Write::Flushing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Flushing(future);
return Poll::Pending;
}
Poll::Ready(Ok(socket)) => {
self.sender = Write::Idle(socket);
return Poll::Ready(Ok(()));
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Writing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Writing(future);
return Poll::Pending;
}
Poll::Ready(Ok(socket)) => {
self.sender = Write::Idle(socket);
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Closing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Closing(future);
return Poll::Pending;
}
Poll::Ready(Ok(())) => {
self.sender = Write::Closed;
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Closed => return Poll::Ready(Ok(())),
Write::Error(err) => {
let out_err = convert_err(&err);
self.sender = Write::Error(err);
return Poll::Ready(Err(out_err));
}
Write::Poisoned => unreachable!(),
}
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
loop {
match mem::replace(&mut self.sender, Write::Poisoned) {
Write::Idle(mut socket) => {
self.sender = Write::Closing(Box::pin(async move {
socket.close().await?;
Ok(())
}));
}
Write::Flushing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Flushing(future);
return Poll::Pending;
}
Poll::Ready(Ok(socket)) => {
self.sender = Write::Idle(socket);
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Writing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Writing(future);
return Poll::Pending;
}
Poll::Ready(Ok(socket)) => {
self.sender = Write::Idle(socket);
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Closing(mut future) => match Pin::new(&mut future).poll(cx) {
Poll::Pending => {
self.sender = Write::Closing(future);
return Poll::Pending;
}
Poll::Ready(Ok(())) => {
self.sender = Write::Closed;
return Poll::Ready(Ok(()));
}
Poll::Ready(Err(err)) => {
self.sender = Write::Error(err);
}
},
Write::Closed => return Poll::Ready(Ok(())),
Write::Error(err) => {
let out_err = convert_err(&err);
self.sender = Write::Error(err);
return Poll::Ready(Err(out_err));
}
Write::Poisoned => unreachable!(),
}
}
}
}
fn convert_err(err: &soketto::connection::Error) -> io::Error {
match err {
soketto::connection::Error::Io(err) => io::Error::new(err.kind(), err.to_string()),
soketto::connection::Error::Codec(err) => {
io::Error::new(io::ErrorKind::InvalidData, err.to_string())
}
soketto::connection::Error::Extension(err) => {
io::Error::new(io::ErrorKind::InvalidData, err.to_string())
}
soketto::connection::Error::UnexpectedOpCode(err) => {
io::Error::new(io::ErrorKind::InvalidData, err.to_string())
}
soketto::connection::Error::Utf8(err) => {
io::Error::new(io::ErrorKind::InvalidData, err.to_string())
}
soketto::connection::Error::MessageTooLarge { .. } => {
io::Error::from(io::ErrorKind::InvalidData)
}
soketto::connection::Error::Closed => io::Error::from(io::ErrorKind::ConnectionAborted),
_ => io::Error::from(io::ErrorKind::Other),
}
}
#[cfg(test)]
mod tests {
    use futures_util::{AsyncRead, AsyncWrite};

    /// Verifies at compile time that `Connection<T>` is `Send` for every `T` accepted by its
    /// trait impls; nothing meaningful happens at run time.
    #[test]
    fn is_send() {
        fn assert_send<S: Send>() {}

        // Never called: type-checking this generic body is what enforces the `Send` bound,
        // hence the lint allowance for the otherwise unused function.
        #[allow(unused)]
        fn check_connection<T: AsyncRead + AsyncWrite + Send + Unpin + 'static>() {
            assert_send::<super::Connection<T>>()
        }
    }
}