#![cfg(not(target_arch = "wasm32"))]
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde_json::Value;
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::time::timeout;
use tracing::{debug, warn};
use crate::error::{Error, ErrorCode, Result};
use crate::types::ServerRequest;
/// Default time `ServerRequestDispatcher::dispatch` waits for a response (one minute).
pub const DEFAULT_DISPATCH_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Process-wide counter feeding correlation ids; the first id handed out is `dispatch-1`.
static DISPATCH_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Correlates server-initiated requests with the responses that answer them.
///
/// `dispatch` pushes `(correlation_id, request)` pairs onto an outbound channel
/// and parks on a oneshot receiver; `handle_response` with the same id completes
/// that receiver.
pub struct ServerRequestDispatcher {
    /// How long `dispatch` waits for a response before reporting a timeout.
    timeout_duration: Duration,
    /// Outbound queue of `(correlation_id, request)` pairs awaiting delivery.
    outbound_tx: mpsc::Sender<(String, ServerRequest)>,
    /// In-flight requests keyed by correlation id; each sender completes one `dispatch`.
    pending: Arc<RwLock<HashMap<String, oneshot::Sender<Value>>>>,
}
impl ServerRequestDispatcher {
pub fn new_with_channel(outbound_tx: mpsc::Sender<(String, ServerRequest)>) -> Self {
Self {
outbound_tx,
pending: Arc::new(RwLock::new(HashMap::new())),
timeout_duration: DEFAULT_DISPATCH_TIMEOUT,
}
}
#[must_use]
pub fn with_timeout(mut self, timeout_duration: Duration) -> Self {
self.timeout_duration = timeout_duration;
self
}
fn next_correlation_id() -> String {
let id = DISPATCH_COUNTER.fetch_add(1, Ordering::Relaxed);
format!("dispatch-{id}")
}
pub async fn dispatch(&self, request: ServerRequest) -> Result<Value> {
if self.outbound_tx.is_closed() {
return Err(Error::protocol(
ErrorCode::INTERNAL_ERROR,
"ServerRequestDispatcher outbound channel closed",
));
}
let (tx, rx) = oneshot::channel::<Value>();
let correlation_id = Self::next_correlation_id();
self.pending
.write()
.await
.insert(correlation_id.clone(), tx);
if let Err(e) = self
.outbound_tx
.send((correlation_id.clone(), request))
.await
{
self.pending.write().await.remove(&correlation_id);
return Err(Error::protocol(
ErrorCode::INTERNAL_ERROR,
format!("Failed to enqueue server request: {e}"),
));
}
debug!("Dispatched server request: {}", correlation_id);
match timeout(self.timeout_duration, rx).await {
Ok(Ok(value)) => Ok(value),
Ok(Err(_)) => {
self.pending.write().await.remove(&correlation_id);
Err(Error::protocol(
ErrorCode::INTERNAL_ERROR,
"Dispatch oneshot channel closed",
))
},
Err(_) => {
self.pending.write().await.remove(&correlation_id);
Err(Error::protocol(
ErrorCode::REQUEST_TIMEOUT,
format!("Server request {correlation_id} timed out"),
))
},
}
}
pub async fn handle_response(&self, correlation_id: &str, response: Value) -> Result<()> {
let mut pending = self.pending.write().await;
if let Some(tx) = pending.remove(correlation_id) {
if tx.send(response).is_err() {
warn!("Dispatch response receiver dropped: {}", correlation_id);
}
Ok(())
} else {
warn!(
"Received response for unknown correlation: {}",
correlation_id
);
Err(Error::protocol(
ErrorCode::INVALID_REQUEST,
format!("Unknown correlation id: {correlation_id}"),
))
}
}
#[allow(dead_code)]
pub async fn pending_count(&self) -> usize {
self.pending.read().await.len()
}
}
impl std::fmt::Debug for ServerRequestDispatcher {
    /// Shows only the timeout and whether the outbound channel is closed; the
    /// pending map (and therefore any correlation id) is never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let outbound_closed = self.outbound_tx.is_closed();
        let mut builder = f.debug_struct("ServerRequestDispatcher");
        builder.field("timeout_duration", &self.timeout_duration);
        builder.field("outbound_tx_closed", &outbound_closed);
        builder.finish()
    }
}
/// Spawns a background task that forwards queued server requests to `transport`.
///
/// Each `(correlation_id, request)` pair received from `outbound_rx` is sent as
/// a `TransportMessage::Request` whose id is built from the correlation id.
/// Send failures are logged and the loop keeps going; the task ends once
/// `outbound_rx` yields `None` (all senders dropped).
pub fn spawn_server_request_drain<T>(
    transport: Arc<crate::runtime::RwLock<T>>,
    mut outbound_rx: mpsc::Receiver<(String, ServerRequest)>,
) where
    T: crate::shared::Transport + 'static,
{
    let drain = async move {
        while let Some((correlation_id, server_request)) = outbound_rx.recv().await {
            let message = crate::shared::TransportMessage::Request {
                request: crate::types::Request::Server(Box::new(server_request)),
                id: crate::types::RequestId::from(correlation_id.clone()),
            };
            // The write guard is a temporary of this statement, held across
            // the send and released before any logging.
            let sent = transport.write().await.send(message).await;
            if let Err(e) = sent {
                warn!(
                    "Failed to dispatch server request {}: {}",
                    correlation_id, e
                );
            }
        }
        debug!("Server-request drain task exited");
    };
    tokio::spawn(drain);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_dispatcher_enqueues_on_outbound_channel() {
        let (tx, mut rx) = mpsc::channel::<(String, ServerRequest)>(4);
        let dispatcher =
            ServerRequestDispatcher::new_with_channel(tx).with_timeout(Duration::from_millis(100));
        let dispatch_fut =
            tokio::spawn(async move { dispatcher.dispatch(ServerRequest::ListRoots).await });
        // Generous deadline: this only guards against a hang, not a latency budget.
        let (correlation_id, req) = tokio::time::timeout(Duration::from_secs(1), rx.recv())
            .await
            .expect("recv deadline")
            .expect("channel closed unexpectedly");
        assert!(
            !correlation_id.is_empty(),
            "correlation id must be non-empty"
        );
        assert!(matches!(req, ServerRequest::ListRoots));
        // Nobody answers, so the dispatch must finish with its timeout error.
        let outcome = dispatch_fut.await.expect("dispatch task panicked");
        assert!(outcome.is_err(), "unanswered dispatch must time out");
    }

    #[tokio::test]
    async fn test_dispatcher_fulfills_on_handle_response() {
        let (tx, mut rx) = mpsc::channel::<(String, ServerRequest)>(4);
        let dispatcher = Arc::new(
            ServerRequestDispatcher::new_with_channel(tx).with_timeout(Duration::from_secs(2)),
        );
        let dispatch_fut = {
            let d = dispatcher.clone();
            tokio::spawn(async move { d.dispatch(ServerRequest::ListRoots).await })
        };
        let (correlation_id, _req) = rx.recv().await.expect("outbound must receive");
        let response = serde_json::json!({"roots": []});
        dispatcher
            .handle_response(&correlation_id, response.clone())
            .await
            .expect("handle_response must succeed");
        let result = dispatch_fut.await.unwrap().expect("dispatch must succeed");
        assert_eq!(result, response);
        assert_eq!(dispatcher.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_dispatcher_timeout_cleans_pending() {
        // `_rx` must stay bound so the channel is open and dispatch reaches its timeout.
        let (tx, _rx) = mpsc::channel::<(String, ServerRequest)>(4);
        let dispatcher =
            ServerRequestDispatcher::new_with_channel(tx).with_timeout(Duration::from_millis(40));
        let result = dispatcher.dispatch(ServerRequest::ListRoots).await;
        assert!(result.is_err(), "dispatch must timeout");
        assert_eq!(
            dispatcher.pending_count().await,
            0,
            "timeout must clean pending"
        );
    }

    #[tokio::test]
    async fn test_dispatcher_handle_response_unknown_id_returns_err() {
        let (tx, _rx) = mpsc::channel::<(String, ServerRequest)>(4);
        let dispatcher = ServerRequestDispatcher::new_with_channel(tx);
        let result = dispatcher
            .handle_response("does-not-exist", serde_json::json!({}))
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_dispatcher_debug_does_not_leak_correlation_ids() {
        let (tx, mut rx) = mpsc::channel::<(String, ServerRequest)>(4);
        let dispatcher = Arc::new(
            ServerRequestDispatcher::new_with_channel(tx).with_timeout(Duration::from_secs(5)),
        );
        let d = dispatcher.clone();
        let _fut = tokio::spawn(async move { d.dispatch(ServerRequest::ListRoots).await });
        // The id is inserted before the request is sent, so receiving it proves
        // an entry is registered — no timing-dependent sleep needed.
        let (_correlation_id, _req) = rx.recv().await.expect("outbound must receive");
        assert_eq!(dispatcher.pending_count().await, 1);
        let debug_str = format!("{:?}", dispatcher);
        assert!(
            !debug_str.contains("dispatch-"),
            "debug must not leak correlation id: {debug_str}"
        );
        assert!(debug_str.contains("ServerRequestDispatcher"));
    }
}