use std::sync::Arc;
use crate::{Error, RequestId, Result, WsOut};
use futures::{SinkExt as _, StreamExt as _};
use serde::Serialize;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_tungstenite::tungstenite::{self, Message};
/// Sending half of the interrupt signal that stops the transmission task.
///
/// Sending `Some(error)` makes the task discard (and log) every request still queued.
/// Sending `None`, or dropping the sender without sending, makes the task stop accepting
/// new requests but still transmit the ones already queued before it exits.
pub type Interrupter = oneshot::Sender<Option<Error>>;
/// Receiving half of the interrupt signal, awaited by the transmission task.
type InterruptSignal = oneshot::Receiver<Option<Error>>;

/// A queued request: its correlation id paired with the command text.
type InternalRequest = (RequestId, String);
/// Producing side of the request queue, held by [`Transmitter`].
type RequestSender = mpsc::UnboundedSender<InternalRequest>;
/// Consuming side of the request queue, drained by the transmission task.
type RequestReceiver = mpsc::UnboundedReceiver<InternalRequest>;
/// Spawns the transmission task and returns a handle for queueing requests to it.
///
/// Requests queued through the returned [`Transmitter`] (or any of its clones) are written
/// to `ws_out` by a background task. `interrupt` controls how that task stops.
///
/// # Panics
/// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
pub fn init(ws_out: WsOut, interrupt: InterruptSignal) -> Transmitter {
    let (tx, rx) = mpsc::unbounded_channel();
    // The JoinHandle is intentionally discarded: the task runs detached and finishes on
    // its own once the queue ends, the sink fails, or `interrupt` fires.
    tokio::spawn(transmission_task(ws_out, rx, interrupt));
    Transmitter { sender: tx }
}
#[derive(Clone)]
pub struct Transmitter {
sender: RequestSender,
}
impl Transmitter {
pub fn make_request(&self, corr_id: RequestId, cmd: String) -> Result {
self.sender
.send((corr_id, cmd))
.map_err(|_| Arc::new(tungstenite::Error::AlreadyClosed))
}
}
/// Background task that forwards queued requests to the websocket sink.
///
/// It races two events:
/// - the send loop (`try_send_all`) completing. It returns `Ok` once every [`Transmitter`]
///   has been dropped and the queue is drained, or `Err` when the sink fails;
/// - the interrupt signal resolving.
///
/// On a send error, or an interrupt carrying `Some(error)`, the queue is closed and every
/// request still in it is logged and dropped. On an interrupt carrying `None`, or when the
/// [`Interrupter`] is dropped without sending, the queue is closed and the requests already
/// in it are still sent before the task exits.
///
/// NOTE(review): when the interrupt wins the race, the in-flight `try_send_all` future is
/// cancelled. Any request it had already taken off the queue but not yet handed to the
/// sink is lost unless `try_send_all` is cancellation-safe.
async fn transmission_task(
    mut ws_out: WsOut,
    receiver: RequestReceiver,
    interrupt: InterruptSignal,
) {
    let mut requests = UnboundedReceiverStream::new(receiver);

    tokio::select! {
        sent = try_send_all(&mut ws_out, &mut requests) => match sent {
            Ok(()) => log::debug!("All requests were sent successfully"),
            Err(err) => error_handler(&mut requests, &err).await,
        },
        signal = interrupt => match signal {
            Ok(Some(err)) => error_handler(&mut requests, &err).await,
            // Graceful stop: refuse new requests but still deliver what is already queued.
            Ok(None) | Err(_) => {
                requests.close();
                let drained = try_send_all(&mut ws_out, &mut requests).await;
                if let Err(err) = drained {
                    error_handler(&mut requests, &err).await;
                }
            }
        },
    }

    log::debug!("Transmission task finished");
}
async fn try_send_all(
ws_out: &mut WsOut,
requests: &mut UnboundedReceiverStream<InternalRequest>,
) -> tungstenite::Result<()> {
let mut message_stream = requests.map(|(id, req)| Ok(into_message(id, req)));
ws_out.send_all(&mut message_stream).await
}
/// Closes the request queue and discards every request still in it.
///
/// Logs a warning naming each dropped request's id together with `err`.
async fn error_handler(
    request_stream: &mut UnboundedReceiverStream<InternalRequest>,
    err: impl std::fmt::Display,
) {
    // Close first so no new requests can be enqueued. The drain below then ends once the
    // already-buffered requests have been consumed.
    request_stream.close();
    request_stream
        .for_each(|(id, _)| {
            log::warn!("Dropping request `({id})` due to error: {err}");
            std::future::ready(())
        })
        .await;
}
/// Serializes a request into the JSON text frame sent over the websocket, of the form
/// `{"corrId":"<id>","cmd":"<cmd>"}`.
///
/// # Panics
/// Never in practice. [`Request`] has only `String` fields, and serializing those to JSON
/// cannot fail.
fn into_message(id: RequestId, req: String) -> Message {
    let json = serde_json::to_string(&Request::new(id, req))
        .expect("a Request with only String fields always serializes to JSON");
    Message::text(json)
}
/// JSON payload of a single outgoing request, serialized as `{"corrId": ..., "cmd": ...}`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Request {
    /// Correlation id, produced by `to_string()` on the [`RequestId`].
    corr_id: String,
    /// Command text, forwarded unchanged.
    cmd: String,
}

impl Request {
    /// Builds the payload from a request id and its command text.
    fn new(id: RequestId, cmd: String) -> Self {
        let corr_id = id.to_string();
        Self { corr_id, cmd }
    }
}