use futures::FutureExt;
use futures::StreamExt;
use holochain_serialized_bytes::{SerializedBytes, SerializedBytesError};
use serde::{de::DeserializeOwned, Serialize};
use stream_cancel::Valve;
use websocket::PairShutdown;
use websocket::TxToWebsocket;
use crate::websocket;
use crate::WebsocketError;
use crate::WebsocketResult;
use std::convert::TryFrom;
use std::convert::TryInto;
use std::sync::Arc;
/// Cloneable handle for sending requests and signals to the websocket
/// through its outgoing message channel.
#[derive(Debug, Clone)]
pub struct WebsocketSender {
    /// Channel every `OutgoingMessage` produced by this sender is pushed into.
    tx_to_websocket: TxToWebsocket,
    /// Valve wrapped around response streams in `request`, so a pending
    /// request ends with `WebsocketError::Shutdown` once the valve closes.
    listener_shutdown: Valve,
    /// Shared pair-shutdown handle. It is never read in this file;
    /// presumably held so its lifetime is tied to the sender clones — TODO confirm.
    __pair_shutdown: Arc<PairShutdown>,
}
/// One-shot handle that carries the response for a single outgoing request
/// back to the caller waiting in `WebsocketSender::request`.
///
/// A `None` payload is treated by the waiting caller as
/// `WebsocketError::FailedToRecvResp`.
#[derive(Debug)]
pub(crate) struct RegisterResponse {
    /// Sending half of the oneshot channel whose receiver `request` awaits.
    respond: tokio::sync::oneshot::Sender<Option<SerializedBytes>>,
}
/// Drop guard for an in-flight request.
///
/// Fields (positional): `.0` — whether the guard is still armed;
/// `.1` — channel used to send the cleanup message; `.2` — the request id
/// reported back for the request. While armed, dropping the guard asks the
/// websocket side to discard request `.2` (see its `Drop` impl).
#[derive(Debug)]
pub(crate) struct StaleRequest(bool, TxToWebsocket, u64);
/// Oneshot sender through which the websocket side reports the id assigned
/// to a request; `request` uses that id to build its `StaleRequest` guard.
pub(crate) type TxStaleRequest = tokio::sync::oneshot::Sender<u64>;
/// Oneshot sender used by the test-only `WebsocketSender::debug` to receive
/// internal state as `(Vec<u64>, u64)` — presumably pending request ids and an
/// id counter; TODO confirm against the websocket side.
pub(crate) type TxRequestsDebug = tokio::sync::oneshot::Sender<(Vec<u64>, u64)>;
impl RegisterResponse {
pub(crate) fn new(respond: tokio::sync::oneshot::Sender<Option<SerializedBytes>>) -> Self {
Self { respond }
}
pub(crate) fn respond(self, msg: Option<SerializedBytes>) -> WebsocketResult<()> {
tracing::trace!(sending_resp = ?msg);
self.respond
.send(msg)
.map_err(|_| WebsocketError::FailedToSendResp)
}
}
/// Messages sent from senders and guards to the websocket side over
/// `TxToWebsocket`.
#[derive(Debug)]
pub(crate) enum OutgoingMessage {
    /// Not constructed in this file; presumably asks to close the
    /// connection — TODO confirm.
    Close,
    /// Fire-and-forget payload sent by `WebsocketSender::signal`; no
    /// response is awaited.
    Signal(SerializedBytes),
    /// Request payload, the handle to deliver its response through, and the
    /// sender used to report back the id assigned to the request.
    Request(SerializedBytes, RegisterResponse, TxStaleRequest),
    /// Not constructed in this file; presumably a response payload for the
    /// request with the given id — TODO confirm.
    Response(Option<SerializedBytes>, u64),
    /// Sent when a `StaleRequest` guard is dropped while still armed, asking
    /// the websocket side to discard the pending request with this id.
    StaleRequest(u64),
    /// Not constructed in this file; presumably the reply to a ping with the
    /// given payload — TODO confirm.
    Pong(Vec<u8>),
    // Only constructed by the `#[cfg(test)]` `WebsocketSender::debug`, so it
    // is dead code in non-test builds.
    #[allow(dead_code)]
    Debug(TxRequestsDebug),
}
impl WebsocketSender {
    /// How long an outgoing message may wait for room in the channel to the
    /// websocket before the send is treated as a shutdown.
    const SEND_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    /// Creates a sender.
    ///
    /// * `tx_to_websocket` - channel every outgoing message is pushed into.
    /// * `listener_shutdown` - valve wrapped around response streams, so that
    ///   pending requests resolve to [`WebsocketError::Shutdown`] once it closes.
    /// * `pair_shutdown` - shared handle that is stored but never read here
    ///   (presumably kept alive for as long as any sender clone exists — TODO confirm).
    pub(crate) fn new(
        tx_to_websocket: TxToWebsocket,
        listener_shutdown: Valve,
        pair_shutdown: Arc<PairShutdown>,
    ) -> Self {
        Self {
            tx_to_websocket,
            listener_shutdown,
            __pair_shutdown: pair_shutdown,
        }
    }

    /// Like [`WebsocketSender::request`], but gives up after `timeout`.
    ///
    /// # Errors
    /// [`WebsocketError::RespTimeout`] if the whole request does not complete
    /// within `timeout`; otherwise any error `request` can return.
    #[tracing::instrument(skip(self))]
    pub async fn request_timeout<I, O>(
        &mut self,
        msg: I,
        timeout: std::time::Duration,
    ) -> WebsocketResult<O>
    where
        I: std::fmt::Debug,
        O: std::fmt::Debug,
        WebsocketError: From<SerializedBytesError>,
        I: Serialize,
        O: DeserializeOwned,
    {
        match tokio::time::timeout(timeout, self.request(msg)).await {
            Ok(r) => r,
            Err(_) => Err(WebsocketError::RespTimeout),
        }
    }

    /// Serializes `msg`, sends it as a request, and waits for the response,
    /// decoded as `O`.
    ///
    /// # Errors
    /// * a serialization error if `msg` cannot be encoded or the response
    ///   cannot be decoded as `O`;
    /// * [`WebsocketError::Shutdown`] if the request cannot be queued within
    ///   30 seconds (or the channel is closed), no request id is reported
    ///   back, or the listener valve closes while waiting;
    /// * [`WebsocketError::FailedToRecvResp`] if the response channel fails
    ///   or carries no payload.
    #[tracing::instrument(skip(self))]
    pub async fn request<I, O>(&mut self, msg: I) -> WebsocketResult<O>
    where
        I: std::fmt::Debug,
        O: std::fmt::Debug,
        WebsocketError: From<SerializedBytesError>,
        I: Serialize,
        O: DeserializeOwned,
    {
        use holochain_serialized_bytes as hsb;
        tracing::trace!("Sending");
        let (tx_resp, rx_resp) = tokio::sync::oneshot::channel();
        let (tx_stale_resp, rx_stale_resp) = tokio::sync::oneshot::channel();
        // Wrapping in the valve ends the stream (-> `Shutdown` below) if the
        // listener shuts down before a response arrives.
        let mut rx_resp = self.listener_shutdown.wrap(rx_resp.into_stream());
        let resp = RegisterResponse::new(tx_resp);
        let msg = OutgoingMessage::Request(
            hsb::UnsafeBytes::from(hsb::encode(&msg)?).into(),
            resp,
            tx_stale_resp,
        );
        self.tx_to_websocket
            .send_timeout(msg, Self::SEND_TIMEOUT)
            .await
            .map_err(|_| WebsocketError::Shutdown)?;
        tracing::trace!("Sent");
        let id = rx_stale_resp.await.map_err(|_| WebsocketError::Shutdown)?;
        // Until disarmed, dropping this guard (on error or cancellation, e.g.
        // from `request_timeout`) asks the websocket side to discard request `id`.
        let stale_request_guard = StaleRequest::new(self.tx_to_websocket.clone(), id);
        let sb: SerializedBytes = rx_resp
            .next()
            .await
            .ok_or(WebsocketError::Shutdown)?
            .map_err(|_| WebsocketError::FailedToRecvResp)?
            .ok_or(WebsocketError::FailedToRecvResp)?;
        // The response has been delivered (the registration was consumed to
        // send it), so disarm before decoding: a decode error must not trigger
        // a redundant stale-request cleanup.
        stale_request_guard.response_received();
        let resp: O = hsb::decode(&Vec::from(hsb::UnsafeBytes::from(sb)))?;
        Ok(resp)
    }

    /// Converts `msg` to [`SerializedBytes`] and sends it as a signal; no
    /// response is awaited.
    ///
    /// # Errors
    /// The conversion error (converted into [`WebsocketError`]), or
    /// [`WebsocketError::Shutdown`] if the signal cannot be queued within
    /// 30 seconds (or the channel is closed).
    #[tracing::instrument(skip(self))]
    pub async fn signal<I, E>(&mut self, msg: I) -> WebsocketResult<()>
    where
        I: std::fmt::Debug,
        WebsocketError: From<E>,
        SerializedBytes: TryFrom<I, Error = E>,
    {
        tracing::trace!("Sending");
        let msg = OutgoingMessage::Signal(msg.try_into()?);
        self.tx_to_websocket
            .send_timeout(msg, Self::SEND_TIMEOUT)
            .await
            .map_err(|_| WebsocketError::Shutdown)?;
        tracing::trace!("Sent");
        Ok(())
    }

    /// Test-only: asks the websocket side for its debug state.
    #[cfg(test)]
    pub(crate) async fn debug(&mut self) -> WebsocketResult<(Vec<u64>, u64)> {
        let (tx_resp, rx_resp) = tokio::sync::oneshot::channel();
        let msg = OutgoingMessage::Debug(tx_resp);
        self.tx_to_websocket
            .send(msg)
            .await
            .map_err(|_| WebsocketError::Shutdown)?;
        rx_resp.await.map_err(|_| WebsocketError::Shutdown)
    }
}
impl StaleRequest {
pub fn new(send_response: TxToWebsocket, id: u64) -> Self {
Self(true, send_response, id)
}
pub fn response_received(mut self) {
self.0 = false;
}
}
impl Drop for StaleRequest {
    /// If the guard is still armed (no response was received), asks the
    /// websocket side to discard request `self.2`. It does this by sending
    /// `OutgoingMessage::StaleRequest` from a spawned task, because `drop`
    /// cannot await.
    fn drop(&mut self) {
        if !self.0 {
            return;
        }
        let id = self.2;
        // `tokio::spawn` panics outside a runtime, and a panic inside `drop`
        // (e.g. while already unwinding) aborts the process. So look the
        // runtime up fallibly, and only log when none is available.
        let runtime = match tokio::runtime::Handle::try_current() {
            Ok(runtime) => runtime,
            Err(e) => {
                tracing::warn!(
                    "No runtime available to remove stale response {} on drop {:?}",
                    id,
                    e
                );
                return;
            }
        };
        let tx = self.1.clone();
        runtime.spawn(async move {
            let result = tx
                .send_timeout(
                    OutgoingMessage::StaleRequest(id),
                    std::time::Duration::from_secs(30),
                )
                .await;
            match result {
                // Only report success when the cleanup message was actually queued.
                Ok(_) => tracing::trace!("Removed stale response on drop"),
                Err(e) => tracing::warn!("Failed to remove stale response on drop {:?}", e),
            }
        });
    }
}