use crate::server::protocol::{ErrorData, ResponseEnvelope};
use crate::server::sink::{WsSink, WsSinkError};
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Failures reported by the `WsOpSink` send methods.
#[derive(Debug)]
pub enum WsOpSinkError {
    /// The underlying `WsSink` rejected or failed to send the envelope.
    Sink(WsSinkError),
    /// A terminal message (result or error) has already been sent for this operation.
    TerminalAlreadySent,
    /// A progress/stream message was requested but the sink carries no `op_id`.
    MissingOpId,
}
impl std::fmt::Display for WsOpSinkError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
WsOpSinkError::Sink(e) => write!(f, "{e}"),
WsOpSinkError::TerminalAlreadySent => write!(f, "terminal message already sent"),
WsOpSinkError::MissingOpId => write!(f, "missing op_id for streaming message"),
}
}
}
impl std::error::Error for WsOpSinkError {}
impl From<WsSinkError> for WsOpSinkError {
fn from(e: WsSinkError) -> Self {
WsOpSinkError::Sink(e)
}
}
/// Reply handle for a single request, layered over a [`WsSink`].
///
/// Every clone shares the same terminal flag (it lives behind an `Arc`), so the
/// one-terminal-message rule applies across all clones of a given handle.
#[derive(Clone)]
pub struct WsOpSink {
    /// Connection-level sink that envelopes are written to.
    sink: WsSink,
    /// Request id copied into every outgoing envelope.
    id: String,
    /// Optional operation id; progress and stream messages require it.
    op_id: Option<String>,
    /// Shared flag recording whether a terminal message has been sent.
    terminal_sent: Arc<Mutex<bool>>,
}
impl WsOpSink {
pub fn new(sink: WsSink, id: String, op_id: Option<String>) -> Self {
Self {
sink,
id,
op_id,
terminal_sent: Arc::new(Mutex::new(false)),
}
}
pub fn id(&self) -> &str {
&self.id
}
pub fn op_id(&self) -> Option<&str> {
self.op_id.as_deref()
}
pub async fn send_progress<T: Serialize>(&self, data: T) -> Result<(), WsOpSinkError> {
self.ensure_not_terminal().await?;
let op_id = self.op_id.as_ref().ok_or(WsOpSinkError::MissingOpId)?;
Ok(self
.sink
.send_envelope(ResponseEnvelope::progress(
self.id.clone(),
op_id.clone(),
data,
))
.await?)
}
pub async fn send_stream<T: Serialize>(&self, data: T) -> Result<(), WsOpSinkError> {
self.ensure_not_terminal().await?;
let op_id = self.op_id.as_ref().ok_or(WsOpSinkError::MissingOpId)?;
Ok(self
.sink
.send_envelope(ResponseEnvelope::stream(
self.id.clone(),
op_id.clone(),
data,
))
.await?)
}
pub async fn send_result<T: Serialize>(&self, data: T) -> Result<(), WsOpSinkError> {
self.mark_terminal().await?;
match &self.op_id {
Some(op_id) => Ok(self
.sink
.send_envelope(ResponseEnvelope::result_with_op(
self.id.clone(),
op_id.clone(),
data,
))
.await?),
None => Ok(self
.sink
.send_envelope(ResponseEnvelope::result(self.id.clone(), data))
.await?),
}
}
pub async fn send_error(&self, error: ErrorData) -> Result<(), WsOpSinkError> {
self.mark_terminal().await?;
Ok(self
.sink
.send_envelope(ResponseEnvelope::error(
self.id.clone(),
self.op_id.clone(),
error,
))
.await?)
}
pub fn inner(&self) -> &WsSink {
&self.sink
}
async fn ensure_not_terminal(&self) -> Result<(), WsOpSinkError> {
let sent = *self.terminal_sent.lock().await;
if sent {
Err(WsOpSinkError::TerminalAlreadySent)
} else {
Ok(())
}
}
async fn mark_terminal(&self) -> Result<(), WsOpSinkError> {
let mut sent = self.terminal_sent.lock().await;
if *sent {
return Err(WsOpSinkError::TerminalAlreadySent);
}
*sent = true;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal stand-in for `WsOpSink`'s shared terminal flag (`Arc<Mutex<bool>>`).
    /// Building a real `WsOpSink` requires a `WsSink`, so the check/mark protocol is
    /// exercised here in isolation.
    struct Guard(Arc<Mutex<bool>>);

    impl Guard {
        async fn ensure_not_terminal(&self) -> Result<(), WsOpSinkError> {
            if *self.0.lock().await {
                Err(WsOpSinkError::TerminalAlreadySent)
            } else {
                Ok(())
            }
        }

        async fn mark_terminal(&self) -> Result<(), WsOpSinkError> {
            let mut sent = self.0.lock().await;
            if *sent {
                return Err(WsOpSinkError::TerminalAlreadySent);
            }
            *sent = true;
            Ok(())
        }
    }

    #[tokio::test]
    async fn terminal_guard_allows_one_terminal() {
        let g = Guard(Arc::new(Mutex::new(false)));
        g.ensure_not_terminal().await.unwrap();
        g.mark_terminal().await.unwrap();
        assert!(matches!(
            g.mark_terminal().await,
            Err(WsOpSinkError::TerminalAlreadySent)
        ));
        assert!(matches!(
            g.ensure_not_terminal().await,
            Err(WsOpSinkError::TerminalAlreadySent)
        ));
    }

    /// The `Display` text is observable by callers (logs, error replies), so pin it.
    #[test]
    fn error_messages_are_stable() {
        assert_eq!(
            WsOpSinkError::TerminalAlreadySent.to_string(),
            "terminal message already sent"
        );
        assert_eq!(
            WsOpSinkError::MissingOpId.to_string(),
            "missing op_id for streaming message"
        );
    }
}