use crate::server::protocol::ResponseEnvelope;
use axum::extract::ws::{Message, WebSocket};
use futures::stream::SplitSink;
use futures::SinkExt;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Cloneable handle to the write half of a WebSocket connection.
///
/// Every clone shares one `SplitSink` behind an `Arc<Mutex<_>>` (tokio's async
/// mutex), so concurrent callers take turns: each send keeps the lock until the
/// send future has completed.
#[derive(Clone)]
pub struct WsSink {
    /// Shared, lock-protected sink; `pub(crate)` so other modules in this crate
    /// can reach the underlying sink directly.
    pub(crate) inner: Arc<Mutex<SplitSink<WebSocket, Message>>>,
}

impl WsSink {
    /// Wraps the write half obtained by splitting a `WebSocket`.
    pub fn new(sender: SplitSink<WebSocket, Message>) -> Self {
        let shared = Mutex::new(sender);
        WsSink {
            inner: Arc::new(shared),
        }
    }

    /// Sends an already-built `Message` unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`WsSinkError::SendError`] carrying the sink error's text when
    /// the send fails.
    pub async fn send_message_raw(&self, msg: Message) -> Result<(), WsSinkError> {
        // The guard lives until the send future resolves, so two callers can
        // never interleave frames on the same connection.
        let mut guard = self.inner.lock().await;
        match guard.send(msg).await {
            Ok(()) => Ok(()),
            Err(err) => Err(WsSinkError::SendError(err.to_string())),
        }
    }

    /// Serializes `envelope` to JSON and sends it as a text frame.
    ///
    /// Encoding happens before the lock is taken, so a failed serialization
    /// returns without ever touching the shared sink.
    ///
    /// # Errors
    ///
    /// * [`WsSinkError::SerializationError`] if `serde_json` cannot encode the envelope.
    /// * Otherwise, whatever [`Self::send_message_raw`] returns.
    pub async fn send_envelope(&self, envelope: ResponseEnvelope) -> Result<(), WsSinkError> {
        let payload = match serde_json::to_string(&envelope) {
            Ok(text) => text,
            Err(err) => return Err(WsSinkError::SerializationError(err.to_string())),
        };
        self.send_message_raw(Message::Text(payload)).await
    }
}
/// Errors produced while pushing data through a `WsSink`.
///
/// Each variant carries the underlying error already rendered as a `String`,
/// which keeps the type `Clone` and comparable regardless of the source error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsSinkError {
    /// Encoding an outgoing value as JSON failed.
    SerializationError(String),
    /// Sending through the underlying WebSocket sink failed.
    SendError(String),
}

impl std::fmt::Display for WsSinkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WsSinkError::SerializationError(e) => write!(f, "Serialization error: {}", e),
            WsSinkError::SendError(e) => write!(f, "Send error: {}", e),
        }
    }
}

/// Marker impl so `WsSinkError` can be boxed as `dyn std::error::Error`;
/// the message comes from the `Display` impl and there is no `source()`.
impl std::error::Error for WsSinkError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ws_sink_error_display() {
        // Pin the full rendered message (prefix, separator and payload), not
        // just the prefix, so regressions in the Display impl are caught.
        let err = WsSinkError::SerializationError("test".to_string());
        assert_eq!(err.to_string(), "Serialization error: test");
        let err = WsSinkError::SendError("connection closed".to_string());
        assert_eq!(err.to_string(), "Send error: connection closed");
    }

    #[test]
    fn test_ws_sink_error_is_std_error() {
        // The type must be usable wherever a `dyn std::error::Error` is expected.
        let err: Box<dyn std::error::Error> =
            Box::new(WsSinkError::SendError("boom".to_string()));
        assert_eq!(err.to_string(), "Send error: boom");
    }
}