use crate::{
network_client::{NetworkRetry, WebSocketRequest},
transfer::CancelReason,
Error, NetworkChangeEvent, Result, WireEncodeDecode,
};
use futures::{
stream::{Map, SplitStream},
Future, FutureExt, StreamExt,
};
use prost::bytes::Bytes;
use sos_core::{AccountId, Origin};
use sos_signer::ed25519::BoxedEd25519Signer;
use std::pin::Pin;
use tokio::{net::TcpStream, sync::watch, time::Duration};
use tokio_tungstenite::{
connect_async,
tungstenite::{
self,
protocol::{
frame::{coding::CloseCode, Utf8Bytes},
CloseFrame, Message,
},
},
MaybeTlsStream, WebSocketStream,
};
use super::{bearer_prefix, encode_device_signature};
/// Options for a [`WebSocketChangeListener`].
#[derive(Clone)]
pub struct ListenOptions {
    /// Identifier appended to the connection request as the
    /// `connection_id` query parameter.
    pub(crate) connection_id: String,
    /// Retry policy used when connecting fails or a connection drops.
    pub(crate) retry: NetworkRetry,
}

impl ListenOptions {
    /// Create options for `connection_id` using the default retry
    /// policy `NetworkRetry::new(16, 1000)`.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn new(connection_id: String) -> Result<Self> {
        Self::new_retry(connection_id, NetworkRetry::new(16, 1000))
    }

    /// Create options for `connection_id` with a caller-supplied retry
    /// policy.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn new_retry(
        connection_id: String,
        retry: NetworkRetry,
    ) -> Result<Self> {
        let options = Self {
            connection_id,
            retry,
        };
        Ok(options)
    }
}
/// Sign the request URI path with `device` and return the bearer value
/// produced by `bearer_prefix` from the encoded signature.
///
/// As a side effect, appends `connection_id` to the request URI as a
/// query pair.
///
/// # Errors
///
/// Propagates failures from signing or from encoding the signature.
async fn request_bearer(
    request: &mut WebSocketRequest,
    device: &BoxedEd25519Signer,
    connection_id: &str,
) -> Result<String> {
    // Only the path is signed: the query pair is appended afterwards,
    // so it is not part of the signed bytes.
    let path_bytes = request.uri.path().as_bytes();
    let raw_signature = device.sign(path_bytes).await?;
    let encoded_signature = encode_device_signature(raw_signature).await?;
    let bearer = bearer_prefix(&encoded_signature);

    request
        .uri
        .query_pairs_mut()
        .append_pair("connection_id", connection_id);

    Ok(bearer)
}
/// Websocket stream type returned by [`connect`].
pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;

/// Open an authenticated websocket connection to the
/// `api/v1/sync/changes` endpoint of `origin` for `account_id`.
///
/// The request is authenticated with a bearer value derived from a
/// `device` signature (see `request_bearer`) and carries `connection_id`
/// as a query parameter. The handshake response is discarded.
///
/// # Errors
///
/// Fails when the request cannot be built, when signing fails, or when
/// the websocket handshake fails.
pub async fn connect(
    account_id: AccountId,
    origin: Origin,
    device: BoxedEd25519Signer,
    connection_id: String,
) -> Result<WsStream> {
    let mut request =
        WebSocketRequest::new(account_id, origin.url(), "api/v1/sync/changes")?;

    let token = request_bearer(&mut request, &device, &connection_id).await?;
    request.set_bearer(token);

    tracing::debug!(uri = %request.uri, "ws_client::connect");
    Ok(connect_async(request).await?.0)
}
pub fn changes(
stream: WsStream,
) -> Map<
SplitStream<WsStream>,
impl FnMut(
std::result::Result<Message, tungstenite::Error>,
) -> Result<
Pin<Box<dyn Future<Output = Result<NetworkChangeEvent>> + Send>>,
>,
> {
let (_, read) = stream.split();
read.map(
move |message| -> Result<
Pin<Box<dyn Future<Output = Result<NetworkChangeEvent>> + Send>>,
> {
match message {
Ok(message) => Ok(Box::pin(async move {
Ok(decode_notification(message).await?)
})),
Err(e) => Ok(Box::pin(async move { Err(e.into()) })),
}
},
)
}
async fn decode_notification(message: Message) -> Result<NetworkChangeEvent> {
match message {
Message::Binary(buffer) => {
let buf: Bytes = buffer.into();
let notification = NetworkChangeEvent::decode(buf).await?;
Ok(notification)
}
_ => Err(Error::NotBinaryWebsocketMessageType),
}
}
/// Handle for stopping a listener started with
/// [`WebSocketChangeListener::spawn`].
#[derive(Clone)]
pub struct WebSocketHandle {
    /// Notifies an open connection that it should close.
    notify: watch::Sender<()>,
    /// Notifies the reconnect loop that it should stop.
    cancel_retry: watch::Sender<CancelReason>,
}

impl WebSocketHandle {
    /// Ask the listener to close its connection and stop reconnecting.
    ///
    /// Sends the shutdown notification, then `CancelReason::Closed` to the
    /// reconnect loop. A failed send is logged rather than returned; a
    /// `watch` send fails when the channel has no receivers.
    pub async fn close(&self) {
        tracing::debug!(
            receivers = %self.notify.receiver_count(),
            "ws_client::close");

        if let Some(error) = self.notify.send(()).err() {
            tracing::error!(error = ?error);
        }

        if let Some(error) =
            self.cancel_retry.send(CancelReason::Closed).err()
        {
            tracing::error!(error = ?error);
        }
    }
}
pub struct WebSocketChangeListener {
account_id: AccountId,
origin: Origin,
device: BoxedEd25519Signer,
options: ListenOptions,
shutdown: watch::Sender<()>,
cancel_retry: watch::Sender<CancelReason>,
}
impl WebSocketChangeListener {
pub fn new(
account_id: AccountId,
origin: Origin,
device: BoxedEd25519Signer,
options: ListenOptions,
) -> Self {
let (shutdown, _) = watch::channel(());
let (cancel_retry, _) = watch::channel(Default::default());
Self {
account_id,
origin,
device,
options,
shutdown,
cancel_retry,
}
}
pub fn spawn<F>(
self,
handler: impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static,
) -> WebSocketHandle
where
F: Future<Output = ()> + Send + 'static,
{
let notify = self.shutdown.clone();
let cancel_retry = self.cancel_retry.clone();
tokio::task::spawn(async move {
let _ = self.connect_loop(&handler).await;
});
WebSocketHandle {
notify,
cancel_retry,
}
}
async fn listen<F>(
&self,
mut stream: WsStream,
handler: &(impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static),
) -> Result<()>
where
F: Future<Output = ()> + Send + 'static,
{
tracing::debug!("ws_client::connected");
let mut shutdown_rx = self.shutdown.subscribe();
loop {
futures::select! {
_ = shutdown_rx.changed().fuse() => {
tracing::debug!("ws_client::shutting_down");
if let Err(error) = stream.close(Some(CloseFrame {
code: CloseCode::Normal,
reason: Utf8Bytes::from_static("closed"),
})).await {
tracing::warn!(
error = ?error,
"ws_client::websocket::close_error",
);
}
tracing::debug!("ws_client::shutdown");
return Ok(());
}
message = stream.next().fuse() => {
if let Some(message) = message {
match message {
Ok(message) => {
let notification = decode_notification(
message).await?;
let future = handler(notification);
future.await;
}
Err(e) => {
tracing::error!(error = ?e);
break;
}
}
} else {
break;
}
}
}
}
tracing::debug!("ws_client::disconnected");
Ok(())
}
async fn stream(&self) -> Result<WsStream> {
connect(
self.account_id.clone(),
self.origin.clone(),
self.device.clone(),
self.options.connection_id.clone(),
)
.await
}
async fn connect_loop<F>(
&self,
handler: &(impl Fn(NetworkChangeEvent) -> F + Send + Sync + 'static),
) -> Result<()>
where
F: Future<Output = ()> + Send + 'static,
{
let mut cancel_retry_rx = self.cancel_retry.subscribe();
loop {
tokio::select! {
_ = cancel_retry_rx.changed() => {
tracing::debug!("ws_client::retry_canceled");
return Ok(());
}
result = self.stream() => {
match result {
Ok(stream) => {
self.options.retry.reset();
if let Err(e) = self.listen(stream, handler).await {
tracing::error!(
error = ?e,
"ws_client::listen_error");
}
}
Err(e) => {
tracing::error!(
error = ?e,
"ws_client::connect_error");
let retries = self.options.retry.retries();
if self.options.retry.is_exhausted(retries) {
tracing::debug!(
maximum_retries = %self.options.retry.maximum_retries,
"wsclient::retry_attempts_exhausted");
return Ok(());
}
}
}
}
}
let retries = self.options.retry.retries();
let delay = self.options.retry.delay(retries)?;
let maximum = self.options.retry.maximum();
tracing::debug!(
retries = %retries,
delay = %delay,
maximum_retries = %maximum,
"ws_client::retry");
tokio::select! {
_ = tokio::time::sleep(Duration::from_millis(delay)) => {
self.options.retry.increment();
}
_ = cancel_retry_rx.changed() => {
tracing::debug!("ws_client::retry_canceled");
return Ok(());
}
}
}
}
}