use rustls::ClientConfig;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::{Connector, WebSocketStream, tungstenite};
use tungstenite::client::IntoClientRequest;
use websock_proto::{ConnectOptions, Error, Message, Result};
/// Transport-level metadata captured when a [`Connection`] is established.
///
/// This is a small `Copy` value. It can be compared and hashed, e.g. to
/// de-duplicate connections by endpoint pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionInfo {
    /// Remote address of the underlying TCP socket.
    pub peer: SocketAddr,
    /// Local address of the underlying TCP socket.
    pub local: SocketAddr,
    /// Whether the byte stream is wrapped in TLS.
    pub is_tls: bool,
}
pub async fn connect(url: &str, opts: ConnectOptions) -> Result<Connection> {
connect_with_tls(url, opts, None).await
}
pub async fn connect_with_tls(
url: &str,
opts: ConnectOptions,
tls: Option<Arc<ClientConfig>>,
) -> Result<Connection> {
let mut req = url
.into_client_request()
.map_err(|e| Error::InvalidUrl(e.to_string()))?;
{
let headers = req.headers_mut();
for (k, v) in opts.headers {
let name = tungstenite::http::header::HeaderName::from_bytes(k.as_bytes())
.map_err(|e| Error::Protocol(format!("invalid header name: {e}")))?;
let value = tungstenite::http::header::HeaderValue::from_str(&v)
.map_err(|e| Error::Protocol(format!("invalid header value: {e}")))?;
headers.append(name, value);
}
if !opts.protocols.is_empty() {
let joined = opts.protocols.join(",");
let value = tungstenite::http::header::HeaderValue::from_str(&joined)
.map_err(|e| Error::Protocol(format!("invalid protocol value: {e}")))?;
headers.insert(tungstenite::http::header::SEC_WEBSOCKET_PROTOCOL, value);
}
}
let connector = tls.map(Connector::Rustls);
let (ws, _resp) = tokio_tungstenite::connect_async_tls_with_config(req, None, false, connector)
.await
.map_err(map_tungstenite_err)?;
let info = ConnectionInfo {
peer: ws
.get_ref()
.get_ref()
.peer_addr()
.map_err(|e| Error::Io(e.to_string()))?,
local: ws
.get_ref()
.get_ref()
.local_addr()
.map_err(|e| Error::Io(e.to_string()))?,
is_tls: matches!(ws.get_ref(), tokio_tungstenite::MaybeTlsStream::Rustls(_)),
};
Ok(Connection { ws, info })
}
/// A client-side WebSocket connection plus the transport metadata captured
/// when it was opened.
///
/// `S` is the underlying byte stream. It defaults to the plain-or-TLS TCP
/// stream type produced by `connect` / `connect_with_tls`.
pub struct Connection<S = tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>> {
    /// Framed WebSocket stream used for sending and receiving messages.
    pub(crate) ws: WebSocketStream<S>,
    /// Peer/local addresses and TLS flag recorded at connect time.
    pub(crate) info: ConnectionInfo,
}
impl<S> Connection<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub async fn send(&mut self, msg: Message) -> Result<()> {
use futures_util::SinkExt;
let tmsg = match msg {
Message::Text(s) => tungstenite::Message::Text(s.into()),
Message::Binary(b) => tungstenite::Message::Binary(b),
};
self.ws.send(tmsg).await.map_err(map_tungstenite_err)?;
Ok(())
}
pub async fn recv(&mut self) -> Result<Message> {
use futures_util::{SinkExt, StreamExt};
loop {
let item = self.ws.next().await.ok_or(Error::Closed)?;
let msg = item.map_err(map_tungstenite_err)?;
match msg {
tungstenite::Message::Ping(p) => {
self.ws
.send(tungstenite::Message::Pong(p))
.await
.map_err(map_tungstenite_err)?;
continue;
}
tungstenite::Message::Pong(_) => continue,
tungstenite::Message::Text(s) => return Ok(Message::Text(s.to_string())),
tungstenite::Message::Binary(b) => return Ok(Message::Binary(b)),
tungstenite::Message::Close(_) => {
let _ = self.ws.close(None).await;
return Err(Error::Closed);
}
_ => return Err(Error::Protocol("unsupported ws message".into())),
}
}
}
pub async fn close(&mut self) -> Result<()> {
self.ws.close(None).await.map_err(map_tungstenite_err)?;
Ok(())
}
pub fn get_ref(&self) -> &S {
self.ws.get_ref()
}
pub fn get_mut(&mut self) -> &mut S {
self.ws.get_mut()
}
}
impl<S> Connection<S> {
pub fn peer_addr(&self) -> SocketAddr {
self.info.peer
}
pub fn local_addr(&self) -> SocketAddr {
self.info.local
}
pub fn is_tls(&self) -> bool {
self.info.is_tls
}
pub fn info(&self) -> ConnectionInfo {
self.info
}
}
pub(crate) fn map_tungstenite_err(e: tungstenite::Error) -> Error {
use tungstenite::Error as E;
match e {
E::ConnectionClosed | E::AlreadyClosed => Error::Closed,
E::Io(io) => Error::Io(io.to_string()),
E::Tls(tls) => Error::Tls(tls.to_string()),
E::Url(url) => Error::InvalidUrl(url.to_string()),
E::Protocol(err) => Error::Protocol(err.to_string()),
E::Utf8(err) => Error::Protocol(err),
E::Capacity(err) => Error::Protocol(err.to_string()),
E::HttpFormat(err) => Error::Protocol(err.to_string()),
other => Error::Other(other.to_string()),
}
}