use std::time::Duration;
use futures::stream::{SplitSink, SplitStream};
use futures::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::{connect_async, MaybeTlsStream};
use url::Url;
use crate::auth;
use crate::error::{Error, ErrorCode};
use crate::options::ClientOptions;
use crate::protocol::{ProtocolMessage, PROTOCOL_VERSION};
use crate::rest::Format;
use crate::Result;
/// Concrete WebSocket stream type used by this module: a tokio-tungstenite
/// `WebSocketStream` over a `MaybeTlsStream` wrapping a Tokio `TcpStream`.
type WsStream = tokio_tungstenite::WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Events the transport's read loop sends to the owner of the event channel.
#[derive(Debug)]
pub(crate) enum TransportEvent {
/// A protocol message decoded from an incoming text or binary frame.
Message(ProtocolMessage),
/// The read loop has stopped. Carries the read or decode error that ended
/// it, or `None` when the stream simply ended.
Disconnected(Option<Error>),
}
/// Handle to one WebSocket connection: a queue feeding the writer task plus
/// the join handles of the spawned reader and writer tasks.
pub(crate) struct Transport {
// Sending half of the bounded queue drained by `write_loop`; dropping it
// ends that loop.
write_tx: mpsc::Sender<ProtocolMessage>,
// Task running `read_loop`.
read_handle: JoinHandle<()>,
// Task running `write_loop`.
write_handle: JoinHandle<()>,
}
impl Transport {
pub async fn connect(
opts: &ClientOptions,
resume_key: Option<&str>,
event_tx: mpsc::Sender<TransportEvent>,
) -> Result<Self> {
let url = build_ws_url(opts, resume_key)?;
let format = opts.format;
let (ws_stream, _response) = connect_async(url.as_str())
.await
.map_err(|e| Error::with_cause(ErrorCode::ConnectionFailed, e, "WebSocket connect failed"))?;
let (sink, stream) = ws_stream.split();
let (write_tx, write_rx) = mpsc::channel::<ProtocolMessage>(64);
let read_handle = {
let event_tx = event_tx.clone();
let fmt = format;
tokio::spawn(async move {
read_loop(stream, event_tx, fmt).await;
})
};
let write_handle = {
let fmt = format;
tokio::spawn(async move {
write_loop(sink, write_rx, fmt).await;
})
};
Ok(Self {
write_tx,
read_handle,
write_handle,
})
}
pub async fn send(&self, msg: ProtocolMessage) -> Result<()> {
self.write_tx
.send(msg)
.await
.map_err(|_| Error::new(ErrorCode::ConnectionFailed, "Transport write channel closed"))
}
pub fn close(self) {
drop(self.write_tx);
let read_handle = self.read_handle;
let write_handle = self.write_handle;
tokio::spawn(async move {
let timeout = tokio::time::timeout(Duration::from_secs(5), async {
let _ = write_handle.await;
});
if timeout.await.is_err() {
}
read_handle.abort();
});
}
pub fn abort(self) {
drop(self.write_tx);
self.read_handle.abort();
self.write_handle.abort();
}
}
async fn read_loop(
mut stream: SplitStream<WsStream>,
event_tx: mpsc::Sender<TransportEvent>,
format: Format,
) {
loop {
match stream.next().await {
Some(Ok(msg)) => {
match deserialize_frame(msg, format) {
Some(Ok(pm)) => {
if event_tx.send(TransportEvent::Message(pm)).await.is_err() {
return;
}
}
Some(Err(e)) => {
let _ = event_tx
.send(TransportEvent::Disconnected(Some(e)))
.await;
return;
}
None => {
}
}
}
Some(Err(e)) => {
let err = Error::with_cause(
ErrorCode::Disconnected,
e,
"WebSocket read error",
);
let _ = event_tx.send(TransportEvent::Disconnected(Some(err))).await;
return;
}
None => {
let _ = event_tx.send(TransportEvent::Disconnected(None)).await;
return;
}
}
}
}
/// Writer task body: serializes queued protocol messages and writes them to
/// `sink`.
///
/// The loop ends without further action as soon as a write fails. When the
/// queue is closed and drained it makes a best-effort attempt to send a
/// WebSocket Close frame.
async fn write_loop(
    mut sink: SplitSink<WsStream, tungstenite::Message>,
    mut write_rx: mpsc::Receiver<ProtocolMessage>,
    format: Format,
) {
    while let Some(outgoing) = write_rx.recv().await {
        // NOTE(review): a message that fails to serialize is dropped here
        // without any report to the sender.
        if let Ok(frame) = serialize_frame(&outgoing, format) {
            if sink.send(frame).await.is_err() {
                return;
            }
        }
    }

    let _ = sink.send(tungstenite::Message::Close(None)).await;
}
/// Encodes `pm` as a WebSocket frame: a text frame holding JSON for
/// `Format::JSON`, or a binary frame holding named-field MessagePack for
/// `Format::MessagePack`.
///
/// # Errors
///
/// Propagates the encoder's error, converted into the crate error type.
fn serialize_frame(
    pm: &ProtocolMessage,
    format: Format,
) -> Result<tungstenite::Message> {
    let frame = match format {
        Format::JSON => tungstenite::Message::Text(serde_json::to_string(pm)?.into()),
        Format::MessagePack => {
            tungstenite::Message::Binary(rmp_serde::to_vec_named(pm)?.into())
        }
    };
    Ok(frame)
}
/// Decodes an incoming WebSocket frame into a protocol message.
///
/// The decoder is chosen by frame type, not by `_format` (which is unused):
/// text frames are parsed as JSON and binary frames as MessagePack.
///
/// Returns `None` for ping, pong, close and raw frames, and `Some(Err(_))`
/// when a text or binary payload fails to decode.
fn deserialize_frame(
    msg: tungstenite::Message,
    _format: Format,
) -> Option<Result<ProtocolMessage>> {
    let decoded: Result<ProtocolMessage> = match msg {
        tungstenite::Message::Text(text) => {
            serde_json::from_str::<ProtocolMessage>(&text).map_err(Into::into)
        }
        tungstenite::Message::Binary(bytes) => {
            rmp_serde::from_read::<_, ProtocolMessage>(&bytes[..]).map_err(Into::into)
        }
        tungstenite::Message::Ping(_)
        | tungstenite::Message::Pong(_)
        | tungstenite::Message::Close(_)
        | tungstenite::Message::Frame(_) => return None,
    };
    Some(decoded)
}
/// Builds the WebSocket URL for a connection.
///
/// The scheme and port follow `opts.tls` (`wss` with `tls_port`, otherwise
/// `ws` with `port`). The host is `<environment>-realtime.ably.io` when an
/// environment is set, otherwise `opts.realtime_host`.
///
/// Query parameters, in order: `v`, `format`, `heartbeats=true`, `echo=true`,
/// then `key` (as `name:value`, only for key credentials), `clientId` and
/// `resume` when available.
///
/// # Errors
///
/// Propagates the URL parse error, converted into the crate error type.
fn build_ws_url(opts: &ClientOptions, resume_key: Option<&str>) -> Result<Url> {
    let (scheme, port) = if opts.tls {
        ("wss", opts.tls_port)
    } else {
        ("ws", opts.port)
    };
    let host = opts.environment.as_ref().map_or_else(
        || opts.realtime_host.clone(),
        |env| format!("{}-realtime.ably.io", env),
    );

    let mut endpoint = Url::parse(&format!("{}://{}:{}/", scheme, host, port))?;
    {
        let wire_format = match opts.format {
            Format::JSON => "json",
            Format::MessagePack => "msgpack",
        };

        let mut query = endpoint.query_pairs_mut();
        query
            .append_pair("v", &PROTOCOL_VERSION.to_string())
            .append_pair("format", wire_format)
            .append_pair("heartbeats", "true")
            .append_pair("echo", "true");

        if let auth::Credential::Key(ref key) = opts.credential {
            query.append_pair("key", &format!("{}:{}", key.name, key.value));
        }
        if let Some(client_id) = &opts.client_id {
            query.append_pair("clientId", client_id);
        }
        if let Some(resume) = resume_key {
            query.append_pair("resume", resume);
        }
    }
    Ok(endpoint)
}