use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::Message;
use crate::error::{Result, SlopError};
use crate::server::{Connection, SlopServer};
enum ConnMessage {
Send(Value),
Close,
}
struct ChannelConnection {
tx: mpsc::UnboundedSender<ConnMessage>,
}
impl Connection for ChannelConnection {
    /// Queues a copy of `message` for the writer task.
    ///
    /// # Errors
    ///
    /// Returns [`SlopError::Transport`] (`"connection closed"`) when the
    /// receiving end of the queue has been dropped, i.e. the writer has exited.
    fn send(&self, message: &Value) -> Result<()> {
        let queued = ConnMessage::Send(message.clone());
        match self.tx.send(queued) {
            Ok(()) => Ok(()),
            Err(_) => Err(SlopError::Transport("connection closed".into())),
        }
    }

    /// Asks the writer task to emit a close frame.
    ///
    /// This is best-effort: it always returns `Ok(())`, even if the writer is
    /// already gone.
    fn close(&self) -> Result<()> {
        self.tx.send(ConnMessage::Close).ok();
        Ok(())
    }
}
pub async fn serve(slop: &SlopServer, addr: &str) -> Result<JoinHandle<()>> {
let listener = TcpListener::bind(addr)
.await
.map_err(|e| SlopError::Transport(e.to_string()))?;
let slop = slop.clone();
let handle = tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let slop = slop.clone();
tokio::spawn(async move {
let ws_stream = match tokio_tungstenite::accept_async(stream).await {
Ok(ws) => ws,
Err(_) => return,
};
let (mut sender, mut receiver) = ws_stream.split();
let (tx, mut rx) = mpsc::unbounded_channel::<ConnMessage>();
let conn: Arc<dyn Connection> = Arc::new(ChannelConnection { tx });
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
match msg {
ConnMessage::Send(val) => {
let json = serde_json::to_string(&val).unwrap_or_default();
if sender.send(Message::Text(json.into())).await.is_err() {
break;
}
}
ConnMessage::Close => {
let _ = sender.send(Message::Close(None)).await;
break;
}
}
}
});
slop.handle_connection(conn.clone());
while let Some(Ok(msg)) = receiver.next().await {
if let Message::Text(text) = msg {
if let Ok(parsed) = serde_json::from_str::<Value>(&text) {
slop.handle_message(&conn, &parsed);
}
}
}
slop.handle_disconnect(&conn);
});
}
});
Ok(handle)
}