#[cfg(unix)]
use std::os::unix::io::AsFd;
#[cfg(windows)]
use std::os::windows::io::AsSocket;
use std::{collections::VecDeque, io, mem, net::TcpStream as SyncTcpStream, sync::Arc};
use log::*;
#[cfg(feature = "native-tls")]
use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector, TlsStream};
#[cfg(feature = "rustls")]
use rustls::{
ClientConfig, ClientConnection, KeyLogFile, OtherError, RootCertStore, pki_types::ServerName,
};
use serde::de::DeserializeOwned;
use smol::{Async, net::TcpStream as AsyncTcpStream};
use tungstenite::HandshakeError as WsHandshakeError;
use tungstenite::client::IntoClientRequest;
#[cfg(any(feature = "rustls", feature = "native-tls"))]
use tungstenite::error::TlsError;
use tungstenite::error::UrlError;
use tungstenite::handshake::client::ClientHandshake;
use tungstenite::stream::{MaybeTlsStream, Mode};
use tungstenite::{Message, WebSocket};
use crate::error::{Error, ProtocolError};
use crate::protocol::{ClientMessage, ServerMessage};
const DEFAULT_PORT: u16 = 38281;
pub(crate) struct Socket<S: DeserializeOwned + 'static> {
async_stream: Arc<Async<SyncTcpStream>>,
inner: WebSocket<MaybeTlsStream<SyncTcpStream>>,
messages: VecDeque<Result<ServerMessage<S>, Error>>,
}
impl<S: DeserializeOwned + 'static> Socket<S> {
pub(crate) async fn connect(
request: impl IntoClientRequest,
#[cfg(feature = "rustls")] rustls_config: Option<Arc<ClientConfig>>,
) -> Result<Self, Error> {
let request = request.into_client_request()?;
let domain = request
.uri()
.host()
.map(|h| h.to_string())
.ok_or(tungstenite::Error::Url(UrlError::NoHostName))?;
let port = request.uri().port_u16().unwrap_or(DEFAULT_PORT);
#[cfg(any(feature = "rustls", feature = "native-tls"))]
let (stream, mut async_stream) = Self::connect_tcp(&domain, port).await?;
#[cfg(not(any(feature = "rustls", feature = "native-tls")))]
let (stream, async_stream) = Self::connect_tcp(&domain, port).await?;
async_stream.writable().await?;
let maybe_tls_stream = match tungstenite::client::uri_mode(request.uri())? {
Mode::Plain => {
debug!("Upgrading to WebSocket...");
MaybeTlsStream::Plain(stream)
}
#[cfg(any(feature = "rustls", feature = "native-tls"))]
Mode::Tls => match Self::try_tls(
async_stream.as_ref(),
stream,
&domain,
#[cfg(feature = "rustls")]
rustls_config,
)
.await
{
Ok(stream) => {
debug!("Upgrading to WebSocket...");
stream
}
Err(Error::WebSocket(tungstenite::Error::Tls(err))) => {
debug!(
"Upgrading to TLS failed: {err}\n\
Attempting plain TCP connection..."
);
let (stream, async_stream_) = Self::connect_tcp(&domain, port).await?;
async_stream = async_stream_;
debug!("Upgrading to WebSocket...");
MaybeTlsStream::Plain(stream)
}
Err(err) => return Err(err),
},
#[cfg(not(any(feature = "rustls", feature = "native-tls")))]
Mode::Tls => {
debug!("No TLS features are enabled, upgrading to unencrypted WebSocket...");
MaybeTlsStream::Plain(stream)
}
};
let mut handshake = ClientHandshake::start(maybe_tls_stream, request, None)?;
loop {
match handshake.handshake() {
Ok((inner, response)) => {
debug!(
"WebSocket response: {}\n{:?}",
response.status(),
response.headers()
);
return Ok(Socket {
async_stream,
inner,
messages: Default::default(),
});
}
Err(WsHandshakeError::Interrupted(new_handshake)) => {
handshake = new_handshake;
async_stream.readable().await?;
}
Err(WsHandshakeError::Failure(err)) => return Err(err.into()),
}
}
}
async fn connect_tcp(
domain: &str,
port: u16,
) -> Result<(SyncTcpStream, Arc<Async<SyncTcpStream>>), Error> {
debug!("Establishing TCP connection to {domain}:{port}...");
let stream = match AsyncTcpStream::connect(format!("{domain}:{port}")).await {
Ok(stream) => stream,
Err(err) => {
return Err(if let Some(os_err) = err.raw_os_error() {
tungstenite::Error::Io(io::Error::from_raw_os_error(os_err)).into()
} else {
err.into()
});
}
};
let async_stream = Arc::<Async<SyncTcpStream>>::from(stream.clone());
#[cfg(unix)]
let stream = SyncTcpStream::from(stream.as_fd().try_clone_to_owned()?);
#[cfg(windows)]
let stream = SyncTcpStream::from(stream.as_socket().try_clone_to_owned()?);
Ok((stream, async_stream))
}
#[cfg(any(feature = "rustls", feature = "native-tls"))]
async fn try_tls(
async_stream: &Async<SyncTcpStream>,
#[cfg(feature = "rustls")] mut stream: SyncTcpStream,
#[cfg(not(feature = "rustls"))] stream: SyncTcpStream,
domain: &str,
#[cfg(feature = "rustls")] rustls_config: Option<Arc<ClientConfig>>,
) -> Result<MaybeTlsStream<SyncTcpStream>, Error> {
async_stream.writable().await?;
#[cfg(feature = "rustls")]
{
debug!("Upgrading to TLS using rustls...");
match Self::try_rust_tls(async_stream, &mut stream, domain, rustls_config).await {
Ok(conn) => {
return Ok(MaybeTlsStream::Rustls(rustls::StreamOwned::new(
conn, stream,
)));
}
#[cfg(feature = "native-tls")]
Err(Error::WebSocket(tungstenite::Error::Tls(err))) => {
debug!(
"Upgrading to TLS using rustls failed: {err}\n\
Falling back to native-tls..."
);
}
Err(err) => return Err(err),
}
}
#[cfg(feature = "native-tls")]
{
debug!("Upgrading to TLS using native-tls...");
Self::try_native_tls(async_stream, stream, domain)
.await
.map(MaybeTlsStream::NativeTls)
}
}
#[cfg(feature = "rustls")]
async fn try_rust_tls(
async_stream: &Async<SyncTcpStream>,
stream: &mut SyncTcpStream,
domain: &str,
config: Option<Arc<ClientConfig>>,
) -> Result<ClientConnection, Error> {
let mut conn = ClientConnection::new(
config.unwrap_or_else(|| {
let mut root_store = RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
let mut config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
config.key_log = Arc::new(KeyLogFile::new());
Arc::new(config)
}),
ServerName::try_from(domain.to_string())
.map_err(|_| tungstenite::Error::Tls(TlsError::InvalidDnsName))?,
)
.map_err(|e| tungstenite::Error::Tls(TlsError::Rustls(Box::new(e))))?;
while conn.is_handshaking() {
match conn.complete_io(stream) {
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
async_stream.readable().await?;
}
Err(e) => {
return Err(tungstenite::Error::Tls(TlsError::Rustls(Box::new(
rustls::Error::Other(OtherError(Arc::new(e))),
)))
.into());
}
}
}
Ok(conn)
}
#[cfg(feature = "native-tls")]
async fn try_native_tls(
async_stream: &Async<SyncTcpStream>,
stream: SyncTcpStream,
domain: &str,
) -> Result<TlsStream<SyncTcpStream>, Error> {
info!("Using nativels!");
let connector = TlsConnector::new()
.map_err(|e| tungstenite::Error::Tls(TlsError::Native(Box::new(e))))?;
let mut handshake = connector.connect(domain, stream);
loop {
match handshake {
Ok(socket) => {
debug!("Upgrading to WebSocket...");
return Ok(socket);
}
Err(TlsHandshakeError::Failure(err)) => {
return Err(tungstenite::Error::Tls(TlsError::Native(Box::new(err))).into());
}
Err(TlsHandshakeError::WouldBlock(new_handshake)) => {
async_stream.readable().await?;
handshake = new_handshake.handshake();
}
}
}
}
pub(crate) fn recv_all(
&mut self,
) -> impl IntoIterator<Item = Result<ServerMessage<S>, Error>> + use<S> {
self.read_inner();
mem::take(&mut self.messages)
}
pub(crate) fn try_recv(&mut self) -> Option<Result<ServerMessage<S>, Error>> {
self.read_inner();
self.messages.pop_front()
}
pub(crate) async fn recv_async(&mut self) -> Result<ServerMessage<S>, Error> {
loop {
match self.try_recv() {
Some(result) => return result,
None => self.async_stream.readable().await?,
}
}
}
fn read_inner(&mut self) {
while self.inner.can_read() {
match self.inner.read() {
Ok(Message::Text(bytes)) => {
debug!("--> {bytes}");
match serde_json::from_str::<Vec<ServerMessage<S>>>(&bytes) {
Ok(messages) => self.messages.extend(messages.into_iter().map(Ok)),
Err(err) => self.messages.push_back(Err(ProtocolError::Deserialize {
json: bytes.to_string(),
error: err,
}
.into())),
}
}
Ok(Message::Binary(bytes)) => {
self.messages
.push_back(Err(ProtocolError::BinaryMessage(bytes.to_vec()).into()));
}
Ok(Message::Close(_)) => {
debug!("--> [closed]");
let err = Error::WebSocket(tungstenite::Error::ConnectionClosed);
self.messages.push_back(Err(err));
break;
}
Ok(_) => {}
Err(tungstenite::Error::Io(err)) if err.kind() == io::ErrorKind::WouldBlock => {
break;
}
Err(err) => {
debug!("--> [error] {err:?}");
let err = Error::from(err);
let fatal = err.is_fatal();
self.messages.push_back(Err(err));
if fatal {
break;
}
}
}
}
}
pub(crate) fn send(&mut self, message: ClientMessage) -> Result<(), Error> {
self.inner
.send(Message::Text(match serde_json::to_string(&[&message]) {
Ok(text) => {
debug!("<-- {text}");
text.into()
}
Err(error) => return Err(Error::Serialize(error)),
}))?;
self.inner.flush()?;
Ok(())
}
}