#[cfg(unix)]
use std::os::unix::io::AsFd;
#[cfg(windows)]
use std::os::windows::io::AsSocket;
use std::{collections::VecDeque, io, net::TcpStream as SyncTcpStream, sync::Arc};
use log::*;
use native_tls::{HandshakeError as TlsHandshakeError, TlsConnector, TlsStream};
use serde::de::DeserializeOwned;
use smol::{Async, net::TcpStream as AsyncTcpStream};
use tungstenite::HandshakeError as WsHandshakeError;
use tungstenite::client::IntoClientRequest;
use tungstenite::error::{TlsError, UrlError};
use tungstenite::handshake::client::ClientHandshake;
use tungstenite::stream::{MaybeTlsStream, Mode};
use tungstenite::{Message, WebSocket};
use crate::error::{Error, ProtocolError};
use crate::protocol::{ClientMessage, ServerMessage};
/// Port used when the connection URI does not specify one explicitly.
const DEFAULT_PORT: u16 = 38_281;
/// A WebSocket connection that exchanges JSON-encoded protocol messages with a server.
///
/// Two handles refer to the same TCP socket: `inner` owns a `std` stream (a duplicated OS
/// handle of the connection) through which all WebSocket reads and writes go, while
/// `async_stream` is kept only to await socket readiness from async code.
///
/// `S` is forwarded to [`ServerMessage`] and must be deserializable.
pub(crate) struct Socket<S: DeserializeOwned> {
/// Async wrapper around the socket; used to await readability/writability.
async_stream: Arc<Async<SyncTcpStream>>,
/// WebSocket state machine over the (optionally TLS-wrapped) duplicated `std` stream.
inner: WebSocket<MaybeTlsStream<SyncTcpStream>>,
/// Messages decoded from a server text frame that have not yet been handed out.
messages: VecDeque<ServerMessage<S>>,
}
impl<S: DeserializeOwned> Socket<S> {
    /// Opens a WebSocket connection described by `request`.
    ///
    /// The TCP port defaults to [`DEFAULT_PORT`] when the URI does not specify one. For TLS
    /// URIs a TLS upgrade is attempted first; if it fails with a TLS error, the failed
    /// connection is closed, a fresh plain TCP connection is opened, and the WebSocket
    /// handshake proceeds unencrypted.
    ///
    /// # Errors
    ///
    /// Returns an error if the request is invalid, has no host, or has an unsupported
    /// scheme; if the TCP connection cannot be established; if the TLS attempt fails with a
    /// non-TLS error; or if the WebSocket handshake fails.
    pub(crate) async fn connect(request: impl IntoClientRequest) -> Result<Self, Error> {
        let request = request.into_client_request()?;
        let domain = request
            .uri()
            .host()
            .map(|h| h.to_string())
            .ok_or(tungstenite::Error::Url(UrlError::NoHostName))?;
        let port = request.uri().port_u16().unwrap_or(DEFAULT_PORT);
        // Validate the scheme before opening a connection that would otherwise be wasted.
        let mode = tungstenite::client::uri_mode(request.uri())?;

        let (stream, async_stream) = Self::connect_tcp(&domain, port).await?;
        async_stream.writable().await?;

        let (maybe_tls_stream, async_stream) = match mode {
            Mode::Plain => (MaybeTlsStream::Plain(stream), async_stream),
            Mode::Tls => {
                debug!("Upgrading to TLS...");
                match Self::try_tls(async_stream.as_ref(), stream, domain.as_str()).await {
                    Ok(tls_stream) => (MaybeTlsStream::NativeTls(tls_stream), async_stream),
                    // Only TLS-level failures trigger the plaintext fallback; every other
                    // error is propagated unchanged.
                    Err(Error::WebSocket(tungstenite::Error::Tls(err))) => {
                        debug!(
                            "Upgrading to TLS failed ({err}), attempting plain TCP connection..."
                        );
                        // The sync half was consumed by the failed TLS attempt; release the
                        // remaining handle of that connection before opening a new one.
                        drop(async_stream);
                        let (stream, async_stream) = Self::connect_tcp(&domain, port).await?;
                        (MaybeTlsStream::Plain(stream), async_stream)
                    }
                    Err(err) => return Err(err),
                }
            }
        };
        debug!("Upgrading to WebSocket...");

        let mut handshake = ClientHandshake::start(maybe_tls_stream, request, None)?;
        loop {
            match handshake.handshake() {
                Ok((inner, response)) => {
                    debug!(
                        "WebSocket response: {}\n{:?}",
                        response.status(),
                        response.headers()
                    );
                    return Ok(Socket {
                        async_stream,
                        inner,
                        messages: VecDeque::new(),
                    });
                }
                Err(WsHandshakeError::Interrupted(new_handshake)) => {
                    handshake = new_handshake;
                    // NOTE(review): only read readiness is awaited; a handshake step that is
                    // blocked on writing would wait for the peer to send first.
                    async_stream.readable().await?;
                }
                Err(WsHandshakeError::Failure(err)) => return Err(err.into()),
            }
        }
    }

    /// Opens a TCP connection to `domain:port`.
    ///
    /// Returns two handles to the same socket: a `std` stream (a duplicated OS handle) on
    /// which TLS/WebSocket I/O is performed, and a shared `Async` wrapper that callers use
    /// only to await readiness.
    ///
    /// # Errors
    ///
    /// Errors carrying an OS error code are rebuilt from that code and reported as a
    /// tungstenite I/O error; other connection errors use `Error`'s own `io::Error`
    /// conversion. Duplicating the socket handle can also fail.
    async fn connect_tcp(
        domain: impl AsRef<str>,
        port: u16,
    ) -> Result<(SyncTcpStream, Arc<Async<SyncTcpStream>>), Error> {
        let domain = domain.as_ref();
        debug!("Establishing TCP connection to {domain}:{port}...");
        let stream = match AsyncTcpStream::connect(format!("{domain}:{port}")).await {
            Ok(stream) => stream,
            Err(err) => {
                return Err(if let Some(os_err) = err.raw_os_error() {
                    tungstenite::Error::Io(io::Error::from_raw_os_error(os_err)).into()
                } else {
                    err.into()
                });
            }
        };
        let async_stream = Arc::<Async<SyncTcpStream>>::from(stream.clone());
        #[cfg(unix)]
        let stream = SyncTcpStream::from(stream.as_fd().try_clone_to_owned()?);
        #[cfg(windows)]
        let stream = SyncTcpStream::from(stream.as_socket().try_clone_to_owned()?);
        Ok((stream, async_stream))
    }

    /// Performs a client TLS handshake for `domain` over `stream`.
    ///
    /// `async_stream` must wrap the same socket as `stream`; it is used to wait whenever the
    /// handshake reports `WouldBlock`.
    ///
    /// # Errors
    ///
    /// Connector creation and handshake failures are returned as
    /// `tungstenite::Error::Tls`; errors while awaiting readiness are returned as-is.
    async fn try_tls(
        async_stream: &Async<SyncTcpStream>,
        stream: SyncTcpStream,
        domain: impl AsRef<str>,
    ) -> Result<TlsStream<SyncTcpStream>, Error> {
        let connector = TlsConnector::new()
            .map_err(|e| tungstenite::Error::Tls(TlsError::Native(Box::new(e))))?;
        let mut handshake = connector.connect(domain.as_ref(), stream);
        loop {
            match handshake {
                // The caller logs the WebSocket upgrade step; don't duplicate it here.
                Ok(socket) => return Ok(socket),
                Err(TlsHandshakeError::Failure(err)) => {
                    return Err(tungstenite::Error::Tls(TlsError::Native(Box::new(err))).into());
                }
                Err(TlsHandshakeError::WouldBlock(new_handshake)) => {
                    // NOTE(review): only read readiness is awaited, as in `connect`.
                    async_stream.readable().await?;
                    handshake = new_handshake.handshake();
                }
            }
        }
    }

    /// Returns the next server message if one is available, without waiting.
    ///
    /// Each server text frame is decoded as a JSON array of messages; the array is queued
    /// and handed out one message per call. Message kinds other than text and binary are
    /// skipped. Returns `Ok(None)` when no complete frame is ready on the socket.
    ///
    /// # Errors
    ///
    /// Returns an error on a binary frame, on text that fails to deserialize, or on any
    /// WebSocket/I/O failure other than `WouldBlock`.
    pub(crate) fn try_recv(&mut self) -> Result<Option<ServerMessage<S>>, Error> {
        // Loop instead of recursing so a run of empty arrays cannot grow the stack.
        loop {
            if let Some(message) = self.messages.pop_front() {
                return Ok(Some(message));
            }
            match self.inner.read() {
                Ok(Message::Text(bytes)) => {
                    debug!("--> {bytes}");
                    self.messages = serde_json::from_str(&bytes).map_err(|error| {
                        ProtocolError::Deserialize {
                            json: bytes.to_string(),
                            error,
                        }
                    })?;
                }
                Ok(Message::Binary(bytes)) => {
                    return Err(ProtocolError::BinaryMessage(bytes.to_vec()).into());
                }
                Ok(_) => {}
                Err(err) if Self::is_would_block(&err) => {
                    // Push out anything tungstenite has queued; if the socket cannot accept
                    // it yet, it stays buffered and is retried on a later call.
                    Self::tolerate_would_block(self.inner.flush())?;
                    return Ok(None);
                }
                Err(err) => return Err(err.into()),
            }
        }
    }

    /// Waits until a server message is available and returns it.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::try_recv`] or from awaiting socket readability.
    pub(crate) async fn recv_async(&mut self) -> Result<ServerMessage<S>, Error> {
        loop {
            if let Some(message) = self.try_recv()? {
                return Ok(message);
            }
            self.async_stream.readable().await?;
        }
    }

    /// Serializes `message` as a one-element JSON array and sends it as a text frame.
    ///
    /// On the non-blocking socket the frame may not be fully written immediately; it then
    /// stays in tungstenite's write buffer and is retried by later writes/flushes (such as
    /// the flush in [`Self::try_recv`]).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Serialize`] if `message` cannot be serialized, or a WebSocket error
    /// if writing fails for a reason other than `WouldBlock`.
    pub(crate) fn send(&mut self, message: ClientMessage) -> Result<(), Error> {
        let text = serde_json::to_string(&[&message]).map_err(Error::Serialize)?;
        debug!("<-- {text}");
        // `write` + a single `flush` (what `WebSocket::send` does internally), so that a
        // `WouldBlock` from either step is not misreported as a failed send.
        Self::tolerate_would_block(self.inner.write(Message::Text(text.into())))?;
        Self::tolerate_would_block(self.inner.flush())
    }

    /// Returns `true` if `err` is an I/O error of kind [`io::ErrorKind::WouldBlock`].
    fn is_would_block(err: &tungstenite::Error) -> bool {
        matches!(
            err,
            tungstenite::Error::Io(io_err) if io_err.kind() == io::ErrorKind::WouldBlock
        )
    }

    /// Converts a tungstenite write/flush result, treating `WouldBlock` as success.
    ///
    /// tungstenite keeps data it could not write in its internal write buffer and retries
    /// it on the next write or flush, so `WouldBlock` means "queued", not "lost".
    fn tolerate_would_block(result: Result<(), tungstenite::Error>) -> Result<(), Error> {
        match result {
            Err(err) if Self::is_would_block(&err) => Ok(()),
            other => other.map_err(Error::from),
        }
    }
}