use klieo_core::{ServerOutbound, ServerOutboundError};
use std::collections::HashMap;
use std::sync::atomic::{AtomicI64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncWrite;
use tokio::sync::{oneshot, Mutex};
use crate::outbound_sink::OutboundFrameSink;
/// JSON-RPC 2.0 "Internal error" code; substituted when a peer's error object
/// carries no integer `code` member.
const JSONRPC_INTERNAL_ERROR: i64 = -32603;
/// Reference-counted handle to an async writer behind a tokio mutex.
/// NOTE(review): not referenced in the visible part of this file; confirm where it is consumed.
pub(crate) type SharedWriter = Arc<Mutex<dyn AsyncWrite + Send + Unpin>>;
/// Error object extracted from a peer's JSON-RPC error response.
#[derive(Debug)]
struct PeerError {
/// The error's `code` member, or `JSONRPC_INTERNAL_ERROR` when it is absent or not an integer.
code: i64,
/// The error's `message` member, or an empty string when it is absent or not a string.
message: String,
}
/// Bookkeeping for outbound JSON-RPC requests that are still awaiting a response.
pub(crate) struct OutboundRequests {
/// In-flight requests keyed by request id; each sender delivers the outcome to the waiting caller.
pending: Mutex<HashMap<i64, oneshot::Sender<Result<serde_json::Value, PeerError>>>>,
/// Next request id to hand out; initialised to 1 and incremented once per request.
next_id: AtomicI64,
/// Destination that outgoing frames are handed to.
sink: Arc<dyn OutboundFrameSink>,
}
impl OutboundRequests {
    /// Creates an empty request table that emits frames through `sink`.
    ///
    /// Request ids are allocated starting at 1.
    pub fn new(sink: Arc<dyn OutboundFrameSink>) -> Self {
        let pending = Mutex::new(HashMap::new());
        let next_id = AtomicI64::new(1);
        Self {
            pending,
            next_id,
            sink,
        }
    }

    /// Routes a response `message` to the caller waiting on request `id`.
    ///
    /// The entry is removed from the pending table first. If no entry exists
    /// for `id`, or the waiting caller has already dropped its receiver, a
    /// warning is logged and the response is discarded.
    pub async fn complete_pending(&self, id: i64, message: serde_json::Value) {
        // Keep the lock only for the removal; parsing and delivery run unlocked.
        let waiter = {
            let mut table = self.pending.lock().await;
            table.remove(&id)
        };
        match waiter {
            None => {
                tracing::warn!(rpc_id = id, "unknown outbound response id; dropping");
            }
            Some(waiter) => {
                let outcome = response_payload(&message);
                if let Err(_undelivered) = waiter.send(outcome) {
                    tracing::warn!(
                        rpc_id = id,
                        "outbound caller dropped receiver before response arrived"
                    );
                }
            }
        }
    }

    /// Forgets the pending entry for `id`, if any, without delivering a result.
    async fn drop_pending(&self, id: i64) {
        let mut table = self.pending.lock().await;
        let _discarded = table.remove(&id);
    }

    /// Sends a JSON-RPC notification named `method` whose `params.padding`
    /// is a string of `payload_bytes` `x` characters.
    ///
    /// Only compiled when both the `http` and `test-fixtures` features are on.
    ///
    /// # Errors
    ///
    /// Returns whatever error the sink reports for the frame.
    #[cfg(all(feature = "http", feature = "test-fixtures"))]
    pub(crate) async fn send_notification_frame(
        &self,
        method: &str,
        payload_bytes: usize,
    ) -> Result<(), crate::outbound_sink::OutboundSinkError> {
        let padding = "x".repeat(payload_bytes);
        let notification = serde_json::json!({
            "jsonrpc": "2.0",
            "method": method,
            "params": {
                "padding": padding,
            },
        });
        self.sink.send_frame(Arc::new(notification)).await
    }

    /// Empties the pending table and logs one warning per abandoned request.
    ///
    /// Every stored sender is dropped without sending, so each waiting
    /// receiver resolves with a receive error.
    #[cfg(feature = "http")]
    pub(crate) async fn drain_pending_as_closed(&self) {
        let mut abandoned: Vec<i64> = Vec::new();
        {
            let mut table = self.pending.lock().await;
            for (id, waiter) in table.drain() {
                // Dropping the sender is what wakes the waiting caller.
                drop(waiter);
                abandoned.push(id);
            }
        }
        // Logging happens after the lock has been released.
        for id in abandoned {
            tracing::warn!(
                target: "mcp.outbound",
                rpc_id = id,
                "transport closed; outbound request abandoned",
            );
        }
    }
}
/// Converts a JSON-RPC response object into the outcome handed to the caller.
///
/// When the response has an `error` member, returns `Err` with that error's
/// `code` (falling back to `JSONRPC_INTERNAL_ERROR` if it is missing or not an
/// integer) and `message` (falling back to an empty string if it is missing or
/// not a string). Otherwise returns a clone of the `result` member, or JSON
/// `null` when there is none.
fn response_payload(message: &serde_json::Value) -> Result<serde_json::Value, PeerError> {
    let Some(error) = message.get("error") else {
        let result = message.get("result").cloned();
        return Ok(result.unwrap_or(serde_json::Value::Null));
    };

    let code = error
        .get("code")
        .and_then(serde_json::Value::as_i64)
        .unwrap_or(JSONRPC_INTERNAL_ERROR);
    let text = error
        .get("message")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
        .unwrap_or_default();

    Err(PeerError {
        code,
        message: text,
    })
}
/// Builds the JSON-RPC 2.0 request object for `method` with the given `id`
/// and a copy of `params`.
fn frame_request(id: i64, method: &str, params: &serde_json::Value) -> serde_json::Value {
    let mut envelope = serde_json::Map::new();
    envelope.insert("jsonrpc".to_owned(), serde_json::Value::from("2.0"));
    envelope.insert("id".to_owned(), serde_json::Value::from(id));
    envelope.insert("method".to_owned(), serde_json::Value::from(method));
    envelope.insert("params".to_owned(), params.clone());
    serde_json::Value::Object(envelope)
}
#[async_trait::async_trait]
impl ServerOutbound for OutboundRequests {
async fn outbound_request(
&self,
method: &str,
params: serde_json::Value,
timeout: Duration,
) -> Result<serde_json::Value, ServerOutboundError> {
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = oneshot::channel();
self.pending.lock().await.insert(id, tx);
let frame = std::sync::Arc::new(frame_request(id, method, ¶ms));
match self.sink.send_frame(frame).await {
Ok(()) => {}
Err(crate::outbound_sink::OutboundSinkError::Serialisation(err)) => {
self.drop_pending(id).await;
return Err(ServerOutboundError::Serialisation(err));
}
Err(_) => {
self.drop_pending(id).await;
return Err(ServerOutboundError::TransportClosed);
}
}
match tokio::time::timeout(timeout, rx).await {
Ok(Ok(Ok(value))) => Ok(value),
Ok(Ok(Err(peer))) => Err(ServerOutboundError::PeerError {
code: peer.code,
message: peer.message,
}),
Ok(Err(_recv_dropped)) => Err(ServerOutboundError::TransportClosed),
Err(_elapsed) => {
self.drop_pending(id).await;
Err(ServerOutboundError::Timeout)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outbound_sink::OutboundSinkError;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Pause between polls of the pending table while a spawned request registers.
    const OUTBOUND_SPIN_WAIT: Duration = Duration::from_millis(2);
    /// Maximum number of polls before `wait_for_pending` gives up.
    const OUTBOUND_SPIN_ATTEMPTS: usize = 100;

    /// Handle to a spawned `outbound_request` call.
    type OutboundCall = tokio::task::JoinHandle<Result<serde_json::Value, ServerOutboundError>>;

    /// Sink that records every frame it is given and always succeeds.
    struct CapturingSink {
        frames: Mutex<Vec<serde_json::Value>>,
    }

    impl CapturingSink {
        fn new() -> Self {
            Self {
                frames: Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait]
    impl OutboundFrameSink for CapturingSink {
        async fn send_frame(
            &self,
            frame: std::sync::Arc<serde_json::Value>,
        ) -> Result<(), OutboundSinkError> {
            self.frames.lock().await.push((*frame).clone());
            Ok(())
        }
    }

    /// Sink that rejects every frame with a serialisation error.
    struct SerializationFailingSink;

    #[async_trait]
    impl OutboundFrameSink for SerializationFailingSink {
        async fn send_frame(
            &self,
            _frame: std::sync::Arc<serde_json::Value>,
        ) -> Result<(), OutboundSinkError> {
            // Parsing invalid JSON is a convenient way to obtain a `serde_json::Error`.
            let err = serde_json::from_str::<serde_json::Value>("{invalid}").unwrap_err();
            Err(OutboundSinkError::Serialisation(err))
        }
    }

    /// Fresh capturing sink as a trait object.
    fn sink() -> Arc<dyn OutboundFrameSink> {
        Arc::new(CapturingSink::new())
    }

    /// Number of requests currently awaiting a response.
    async fn pending_len(reqs: &OutboundRequests) -> usize {
        reqs.pending.lock().await.len()
    }

    /// Polls until the pending table holds `expected` entries or the attempt
    /// budget runs out. Callers assert the final count themselves so a failure
    /// points at the test that was waiting.
    async fn wait_for_pending(reqs: &OutboundRequests, expected: usize) {
        for _ in 0..OUTBOUND_SPIN_ATTEMPTS {
            if pending_len(reqs).await == expected {
                return;
            }
            tokio::time::sleep(OUTBOUND_SPIN_WAIT).await;
        }
    }

    /// Spawns a `sampling/createMessage` request with an empty message list.
    fn spawn_sampling_request(reqs: &Arc<OutboundRequests>, timeout: Duration) -> OutboundCall {
        let reqs = Arc::clone(reqs);
        tokio::spawn(async move {
            reqs.outbound_request(
                "sampling/createMessage",
                serde_json::json!({"messages": []}),
                timeout,
            )
            .await
        })
    }

    #[tokio::test]
    async fn happy_path_routes_result_to_caller() {
        let reqs = Arc::new(OutboundRequests::new(sink()));
        let call = spawn_sampling_request(&reqs, Duration::from_secs(1));
        // The first id handed out by a fresh table is 1.
        let expected_id: i64 = 1;
        wait_for_pending(&reqs, 1).await;
        assert_eq!(pending_len(&reqs).await, 1);
        reqs.complete_pending(
            expected_id,
            serde_json::json!({
                "jsonrpc": "2.0",
                "id": expected_id,
                "result": {"role": "assistant", "model": "stub"}
            }),
        )
        .await;
        let result = call.await.expect("task panicked").expect("outbound err");
        assert_eq!(result["role"], "assistant");
        assert_eq!(result["model"], "stub");
        assert_eq!(pending_len(&reqs).await, 0);
    }

    #[tokio::test]
    async fn timeout_returns_timeout_error_and_clears_pending() {
        let reqs = Arc::new(OutboundRequests::new(sink()));
        let outcome = reqs
            .outbound_request(
                "roots/list",
                serde_json::Value::Null,
                Duration::from_millis(25),
            )
            .await;
        assert!(matches!(outcome, Err(ServerOutboundError::Timeout)));
        assert_eq!(
            pending_len(&reqs).await,
            0,
            "timeout path must purge pending entry"
        );
    }

    #[tokio::test]
    async fn unknown_id_complete_does_not_panic() {
        let reqs = OutboundRequests::new(sink());
        reqs.complete_pending(
            9_999,
            serde_json::json!({"jsonrpc":"2.0","id":9999,"result":{}}),
        )
        .await;
        assert_eq!(pending_len(&reqs).await, 0);
    }

    #[cfg(feature = "http")]
    #[tokio::test]
    async fn drain_pending_as_closed_clears_table_and_wakes_callers() {
        let reqs = Arc::new(OutboundRequests::new(sink()));
        let handles: Vec<OutboundCall> = (0..3)
            .map(|_| spawn_sampling_request(&reqs, Duration::from_secs(5)))
            .collect();
        wait_for_pending(&reqs, 3).await;
        assert_eq!(pending_len(&reqs).await, 3);
        reqs.drain_pending_as_closed().await;
        assert_eq!(pending_len(&reqs).await, 0, "drain must empty the table");
        for handle in handles {
            let outcome = handle.await.expect("task panicked");
            assert!(
                matches!(outcome, Err(ServerOutboundError::TransportClosed)),
                "dropped sender must wake caller as TransportClosed; got {outcome:?}"
            );
        }
    }

    #[tokio::test]
    async fn peer_error_response_maps_to_peer_error() {
        let reqs = Arc::new(OutboundRequests::new(sink()));
        let call = spawn_sampling_request(&reqs, Duration::from_secs(1));
        let expected_id: i64 = 1;
        wait_for_pending(&reqs, 1).await;
        // Fail here, not on a later timeout, if the request never registered.
        assert_eq!(pending_len(&reqs).await, 1);
        reqs.complete_pending(
            expected_id,
            serde_json::json!({
                "jsonrpc": "2.0",
                "id": expected_id,
                "error": {"code": -32601, "message": "Method not found"}
            }),
        )
        .await;
        let outcome = call.await.expect("task panicked");
        match outcome {
            Err(ServerOutboundError::PeerError { code, message }) => {
                assert_eq!(code, -32601);
                assert_eq!(message, "Method not found");
            }
            other => panic!("expected PeerError, got {other:?}"),
        }
        assert_eq!(pending_len(&reqs).await, 0);
    }

    #[tokio::test]
    async fn serialisation_sink_error_maps_to_server_serialisation_and_clears_pending() {
        let reqs = Arc::new(OutboundRequests::new(Arc::new(SerializationFailingSink)));
        let outcome = reqs
            .outbound_request(
                "sampling/createMessage",
                serde_json::json!({"messages": []}),
                Duration::from_secs(1),
            )
            .await;
        assert!(
            matches!(outcome, Err(ServerOutboundError::Serialisation(_))),
            "sink Serialisation must propagate as ServerOutboundError::Serialisation; got {outcome:?}"
        );
        assert_eq!(
            pending_len(&reqs).await,
            0,
            "serialisation failure must drop the pending entry"
        );
    }
}