use std::collections::HashMap;
use std::sync::Arc;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::{Mutex, mpsc, oneshot};
use tokio_tungstenite::tungstenite::Message as WsMessage;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use crate::error::MooError;
use crate::message::{MooBody, MooMessage, MooVerb};
use crate::{parse, serialize};
/// Read half of [`WsStream`], consumed by the dispatch task.
type WsSource = SplitStream<WsStream>;
/// Write half of [`WsStream`], owned by the dispatch task.
type WsSink = SplitSink<WsStream, WsMessage>;
/// The client WebSocket (plain TCP or TLS) produced by `tokio_tungstenite::connect_async`.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Destination for the reply (or replies) to one outgoing request, keyed by request id
/// in the connection's pending table.
enum RequestSlot {
    /// Created by `send_request`: the first CONTINUE or COMPLETE for the id resolves it
    /// and retires the slot.
    OneShot(oneshot::Sender<MooMessage>),
    /// Created by `subscribe`: every CONTINUE is forwarded, and the COMPLETE is forwarded
    /// last before the slot is retired.
    Subscription(mpsc::Sender<MooMessage>),
}
/// Callback for inbound `REQUEST` messages addressed to a registered service.
///
/// It is called inline by the connection's dispatch task with the request and a
/// [`ResponseSender`] bound to that request's id. While it runs, no other traffic on
/// the connection is processed, so it should return quickly.
pub type ServiceHandler = Arc<dyn Fn(MooMessage, ResponseSender) + Send + Sync + 'static>;
/// Handle for replying to one inbound request.
///
/// Cloning is cheap: it copies a channel sender and the request id.
#[derive(Clone)]
pub struct ResponseSender {
    /// Queue of outgoing frames that the dispatch task writes to the socket.
    sink: mpsc::Sender<WsMessage>,
    /// Id of the inbound request; copied into every reply.
    request_id: u32,
}
impl ResponseSender {
pub async fn send_complete(
&self,
status: &str,
body: Option<serde_json::Value>,
) -> Result<(), MooError> {
let msg = MooMessage {
verb: MooVerb::Complete,
name: status.to_string(),
request_id: self.request_id,
headers: HashMap::new(),
body: body.map(MooBody::Json),
};
let raw = serialize(&msg);
self.sink
.send(WsMessage::Binary(raw.into()))
.await
.map_err(|_| MooError::ConnectionClosed)
}
pub async fn send_continue(
&self,
status: &str,
body: Option<serde_json::Value>,
) -> Result<(), MooError> {
let msg = MooMessage {
verb: MooVerb::Continue,
name: status.to_string(),
request_id: self.request_id,
headers: HashMap::new(),
body: body.map(MooBody::Json),
};
let raw = serialize(&msg);
self.sink
.send(WsMessage::Binary(raw.into()))
.await
.map_err(|_| MooError::ConnectionClosed)
}
}
/// Client side of a MOO connection over a WebSocket.
///
/// All socket I/O happens on a background dispatch task. This struct only keeps the
/// handles needed to talk to that task.
pub struct MooConnection {
    /// Queue of outgoing frames drained by the dispatch task.
    ws_tx: mpsc::Sender<WsMessage>,
    /// Next request id to hand out.
    next_request_id: Arc<Mutex<u32>>,
    /// Outstanding requests awaiting replies, keyed by request id.
    pending: Arc<Mutex<HashMap<u32, RequestSlot>>>,
    /// The spawned dispatch task.
    task_handle: tokio::task::JoinHandle<()>,
}
impl std::fmt::Debug for MooConnection {
    /// Prints only whether the background dispatch task is still running. The channels
    /// and the pending-request table are not included.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let alive = !self.task_handle.is_finished();
        let mut out = f.debug_struct("MooConnection");
        out.field("alive", &alive);
        out.finish()
    }
}
impl MooConnection {
pub async fn connect(
url: &str,
service_handlers: HashMap<String, ServiceHandler>,
) -> Result<Self, MooError> {
let (ws_stream, _) = tokio_tungstenite::connect_async(url)
.await
.map_err(|e| MooError::WebSocket(e.to_string()))?;
let (ws_sink, ws_source) = ws_stream.split();
let pending: Arc<Mutex<HashMap<u32, RequestSlot>>> = Arc::new(Mutex::new(HashMap::new()));
let (ws_tx, ws_rx) = mpsc::channel::<WsMessage>(64);
let task_handle = tokio::spawn(dispatch_loop(
ws_sink,
ws_source,
ws_rx,
pending.clone(),
ws_tx.clone(),
service_handlers,
));
Ok(MooConnection {
ws_tx,
next_request_id: Arc::new(Mutex::new(0)),
pending,
task_handle,
})
}
pub async fn send_request(
&self,
name: &str,
body: Option<serde_json::Value>,
) -> Result<MooMessage, MooError> {
let request_id = {
let mut id = self.next_request_id.lock().await;
let current = *id;
*id += 1;
current
};
let (tx, rx) = oneshot::channel();
{
let mut pending = self.pending.lock().await;
pending.insert(request_id, RequestSlot::OneShot(tx));
}
let msg = MooMessage {
verb: MooVerb::Request,
name: name.to_string(),
request_id,
headers: HashMap::new(),
body: body.map(MooBody::Json),
};
let raw = serialize(&msg);
self.ws_tx
.send(WsMessage::Binary(raw.into()))
.await
.map_err(|_| MooError::ConnectionClosed)?;
rx.await.map_err(|_| MooError::ConnectionClosed)
}
pub async fn subscribe(
&self,
name: &str,
body: serde_json::Value,
) -> Result<mpsc::Receiver<MooMessage>, MooError> {
let request_id = {
let mut id = self.next_request_id.lock().await;
let current = *id;
*id += 1;
current
};
let (tx, rx) = mpsc::channel(32);
{
let mut pending = self.pending.lock().await;
pending.insert(request_id, RequestSlot::Subscription(tx));
}
let msg = MooMessage {
verb: MooVerb::Request,
name: name.to_string(),
request_id,
headers: HashMap::new(),
body: Some(MooBody::Json(body)),
};
let raw = serialize(&msg);
self.ws_tx
.send(WsMessage::Binary(raw.into()))
.await
.map_err(|_| MooError::ConnectionClosed)?;
Ok(rx)
}
pub async fn close(self) {
let _ = self.ws_tx.send(WsMessage::Close(None)).await;
self.task_handle.abort();
let mut pending = self.pending.lock().await;
pending.clear();
}
pub fn is_alive(&self) -> bool {
!self.task_handle.is_finished()
}
}
/// Background task that owns the WebSocket for one connection.
///
/// It multiplexes three event sources until the connection ends:
/// - **Heartbeat (every 10 s):** each tick sends a Ping. If no Pong has arrived since
///   the previous tick, the loop stops.
/// - **`outgoing_rx`:** frames queued by `MooConnection` / `ResponseSender` are written
///   to `ws_sink` in arrival order.
/// - **`ws_source`:** Binary frames are parsed as MOO messages and routed through
///   `handle_incoming`. Pong marks the peer alive. Close, end of stream, or a transport
///   error stops the loop. Other frame kinds are ignored.
///
/// On exit it closes `outgoing_rx`, so later sends fail immediately instead of being
/// buffered with no reader. It then clears `pending`, which drops every waiting reply
/// sender.
async fn dispatch_loop(
    mut ws_sink: WsSink,
    mut ws_source: WsSource,
    mut outgoing_rx: mpsc::Receiver<WsMessage>,
    pending: Arc<Mutex<HashMap<u32, RequestSlot>>>,
    ws_tx: mpsc::Sender<WsMessage>,
    service_handlers: HashMap<String, ServiceHandler>,
) {
    let mut ping_interval = tokio::time::interval(std::time::Duration::from_secs(10));
    // If the loop stalls (e.g. delivering to a slow subscriber), the default `Burst`
    // behaviour fires the missed ticks back to back. The second tick would then declare
    // a heartbeat timeout before the peer had a chance to answer the first Ping.
    ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    let mut is_alive = true;
    loop {
        tokio::select! {
            _ = ping_interval.tick() => {
                if !is_alive {
                    tracing::warn!("MOO heartbeat timeout: no pong received");
                    break;
                }
                is_alive = false;
                if ws_sink.send(WsMessage::Ping(vec![].into())).await.is_err() {
                    break;
                }
            }
            Some(msg) = outgoing_rx.recv() => {
                if ws_sink.send(msg).await.is_err() {
                    break;
                }
            }
            incoming = ws_source.next() => {
                match incoming {
                    Some(Ok(WsMessage::Binary(data))) => match parse(&data) {
                        Ok(msg) => {
                            handle_incoming(msg, &pending, &ws_tx, &service_handlers).await;
                        }
                        Err(e) => {
                            tracing::warn!("Failed to parse MOO message: {}", e);
                        }
                    },
                    Some(Ok(WsMessage::Pong(_))) => is_alive = true,
                    // A Close frame or the end of the stream both mean the peer is gone.
                    Some(Ok(WsMessage::Close(_))) | None => break,
                    Some(Err(e)) => {
                        tracing::warn!("WebSocket error: {}", e);
                        break;
                    }
                    Some(Ok(_)) => {}
                }
            }
            else => break,
        }
    }
    // Close the queue before clearing `pending`. Otherwise a request registered after
    // the clear could still be buffered here, and its caller would wait forever.
    outgoing_rx.close();
    pending.lock().await.clear();
}
async fn handle_incoming(
msg: MooMessage,
pending: &Arc<Mutex<HashMap<u32, RequestSlot>>>,
ws_tx: &mpsc::Sender<WsMessage>,
service_handlers: &HashMap<String, ServiceHandler>,
) {
match msg.verb {
MooVerb::Request => {
let service = msg.service().unwrap_or("").to_string();
let response_sender = ResponseSender {
sink: ws_tx.clone(),
request_id: msg.request_id,
};
if let Some(handler) = service_handlers.get(&service) {
handler(msg, response_sender);
} else {
let _ = response_sender
.send_complete(
"InvalidRequest",
Some(serde_json::json!({"error": format!("unknown service: {}", service)})),
)
.await;
}
}
MooVerb::Continue => {
let mut pending = pending.lock().await;
match pending.get(&msg.request_id) {
Some(RequestSlot::Subscription(tx)) => {
let request_id = msg.request_id;
if tx.send(msg).await.is_err() {
pending.remove(&request_id);
}
}
Some(RequestSlot::OneShot(_)) => {
if let Some(RequestSlot::OneShot(tx)) = pending.remove(&msg.request_id) {
let _ = tx.send(msg);
}
}
None => {
tracing::warn!(
"CONTINUE for unknown request_id {}: closing connection",
msg.request_id
);
}
}
}
MooVerb::Complete => {
let mut pending = pending.lock().await;
match pending.remove(&msg.request_id) {
Some(RequestSlot::OneShot(tx)) => {
let _ = tx.send(msg);
}
Some(RequestSlot::Subscription(tx)) => {
let _ = tx.send(msg).await;
}
None => {
tracing::warn!(
"COMPLETE for unknown request_id {}: closing connection",
msg.request_id
);
}
}
}
}
}