use async_trait::async_trait;
use serde_json::Value;
use tokio::io::AsyncWriteExt;
use crate::outbound::SharedWriter;
/// Destination for outbound JSON frames.
///
/// The `Send + Sync` supertraits allow an implementation to be shared across
/// threads, e.g. behind an `Arc<dyn OutboundFrameSink>`. Frames are passed as
/// `Arc<Value>` so a single frame can be handed to several consumers without
/// deep-copying the JSON tree.
#[async_trait]
pub trait OutboundFrameSink: Send + Sync {
    /// Delivers one frame to the underlying transport.
    ///
    /// # Errors
    ///
    /// Returns an [`OutboundSinkError`] describing why the frame could not be
    /// delivered.
    async fn send_frame(&self, frame: std::sync::Arc<Value>) -> Result<(), OutboundSinkError>;
}
/// Reasons an [`OutboundFrameSink`] can fail to deliver a frame.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum OutboundSinkError {
    /// The transport can no longer accept frames: a write/flush failed, the
    /// queue's receiver was dropped, or the owning session no longer exists.
    /// The frame was not (fully) delivered.
    #[error("transport closed")]
    TransportClosed,
    /// The frame could not be encoded as JSON. The `serde_json` error is
    /// exposed through `source()` and is also rendered in this variant's
    /// `Display` output.
    #[error("frame serialisation failed: {0}")]
    Serialisation(#[source] serde_json::Error),
}
/// Sink that writes each frame as one newline-terminated line of JSON to a
/// shared async writer.
pub(crate) struct StdioFrameSink {
    /// Lockable writer; the lock is held for the duration of each frame.
    writer: SharedWriter,
}
impl StdioFrameSink {
pub(crate) fn new(writer: SharedWriter) -> Self {
Self { writer }
}
}
#[async_trait]
impl OutboundFrameSink for StdioFrameSink {
    /// Serialises `frame` as compact JSON, writes it followed by `\n`, then
    /// flushes the writer.
    ///
    /// The writer's lock is held across the write and the flush, so frames
    /// from callers sharing the same `SharedWriter` are never interleaved.
    ///
    /// # Errors
    ///
    /// * [`OutboundSinkError::Serialisation`] if `frame` cannot be encoded.
    /// * [`OutboundSinkError::TransportClosed`] if writing or flushing fails;
    ///   the underlying I/O error is logged before being mapped.
    async fn send_frame(&self, frame: std::sync::Arc<Value>) -> Result<(), OutboundSinkError> {
        let mut line = serde_json::to_vec(&*frame).map_err(|err| {
            tracing::error!(
                target: "klieo::mcp::stdio",
                error = %err,
                "outbound frame serialisation failed",
            );
            OutboundSinkError::Serialisation(err)
        })?;
        // Compact serde_json output escapes embedded newlines, so appending the
        // delimiter here keeps one frame per line and lets the whole frame go
        // out in a single `write_all` instead of two separate writes.
        line.push(b'\n');

        let mut w = self.writer.lock().await;
        let io_result = match w.write_all(&line).await {
            Ok(()) => w.flush().await,
            Err(err) => Err(err),
        };
        // Previously the I/O error was discarded; keep it visible for
        // diagnostics while still reporting a closed transport to the caller.
        io_result.map_err(|err| {
            tracing::debug!(
                target: "klieo::mcp::stdio",
                error = %err,
                "outbound frame write failed; reporting transport closed",
            );
            OutboundSinkError::TransportClosed
        })
    }
}
/// Capacity of the outbound frame queue for the HTTP transport
/// (crate-private variant, used when `test-fixtures` is off).
///
/// NOTE(review): not referenced elsewhere in this file; presumably used to
/// size the `outbound_ring` — confirm at the construction site.
#[cfg(all(feature = "http", not(feature = "test-fixtures")))]
pub(crate) const OUTBOUND_QUEUE_CAPACITY: usize = 1_024;
/// Capacity of the outbound frame queue for the HTTP transport (public
/// variant, exported only when `test-fixtures` is enabled).
#[cfg(all(feature = "http", feature = "test-fixtures"))]
pub const OUTBOUND_QUEUE_CAPACITY: usize = 1_024;
/// Sink for the HTTP transport. Each frame is tagged with a per-session event
/// id, optionally retained for SSE replay, and queued on a bounded ring.
#[cfg(feature = "http")]
pub(crate) struct HttpFrameSink {
    /// Weak handle so the sink does not keep the session alive; a failed
    /// upgrade is reported as a closed transport.
    session: std::sync::Weak<crate::session::Session>,
    /// Producer half of the outbound ring; items are `(event_id, frame)`.
    tx: crate::outbound_ring::RingSender<(u64, std::sync::Arc<Value>)>,
    /// Maximum number of frames kept in the session's SSE replay buffer;
    /// `0` disables replay buffering.
    sse_replay_capacity: usize,
}
#[cfg(feature = "http")]
impl HttpFrameSink {
    /// Creates a sink bound to `session` that pushes frames onto `tx`.
    ///
    /// `sse_replay_capacity` bounds the session's replay buffer for frames
    /// sent through this sink; pass `0` to disable replay buffering.
    pub(crate) fn new(
        session: std::sync::Weak<crate::session::Session>,
        tx: crate::outbound_ring::RingSender<(u64, std::sync::Arc<Value>)>,
        sse_replay_capacity: usize,
    ) -> Self {
        HttpFrameSink {
            session,
            tx,
            sse_replay_capacity,
        }
    }
}
#[cfg(feature = "http")]
#[async_trait]
impl OutboundFrameSink for HttpFrameSink {
    /// Assigns the next session event id to `frame` and records it in the
    /// session's SSE replay buffer (when `sse_replay_capacity > 0`). It then
    /// pushes `(event_id, frame)` onto the outbound ring.
    ///
    /// A full ring is not an error: the oldest queued frame(s) are evicted.
    /// The eviction is logged and counted in `klieo_mcp_outbound_dropped_total`.
    ///
    /// # Errors
    ///
    /// Returns [`OutboundSinkError::TransportClosed`] if the ring's receiver
    /// has been dropped or the session has already been released.
    async fn send_frame(&self, frame: std::sync::Arc<Value>) -> Result<(), OutboundSinkError> {
        if self.tx.is_receiver_dropped() {
            return Err(OutboundSinkError::TransportClosed);
        }
        let Some(session) = self.session.upgrade() else {
            return Err(OutboundSinkError::TransportClosed);
        };
        // Only take the extra reference when replay buffering is enabled.
        let buffered = (self.sse_replay_capacity > 0).then(|| std::sync::Arc::clone(&frame));
        // The replay-buffer lock is held across id allocation, buffering and
        // the ring push, so senders using this path cannot interleave: event
        // ids appear in the same order in the replay buffer and on the ring.
        // No `.await` happens while the guard is held.
        let dropped = {
            let mut buffer = session.sse_replay_buffer.lock();
            let event_id = session
                .next_event_id
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            if let Some(buffered) = buffered {
                // `while` (not `if`) also trims a buffer that already holds more
                // than this sink's capacity. It terminates because capacity > 0
                // on this branch.
                while buffer.len() >= self.sse_replay_capacity {
                    buffer.pop_front();
                }
                buffer.push_back((event_id, buffered));
            }
            let dropped = self.tx.push((event_id, frame));
            drop(buffer);
            dropped
        };
        if dropped > 0 {
            // `target:` (colon) sets the event's target; `target = ...` would
            // only add a field named "target" and bypass target-based filters.
            tracing::warn!(
                target: "klieo::mcp::outbound",
                policy = "drop_oldest",
                dropped = dropped,
                "outbound ring full; dropped oldest frame(s)"
            );
            metrics::counter!(
                "klieo_mcp_outbound_dropped_total",
                "policy" => "oldest"
            )
            .increment(dropped as u64);
        }
        Ok(())
    }
}
/// In-memory `AsyncWrite` for benchmarks; written bytes are appended to the
/// shared buffer.
#[cfg(feature = "bench")]
struct BenchWriter(std::sync::Arc<tokio::sync::Mutex<Vec<u8>>>);
#[cfg(feature = "bench")]
impl tokio::io::AsyncWrite for BenchWriter {
    /// Appends `buf` to the shared buffer and always reports the whole slice
    /// as written, completing immediately.
    ///
    /// NOTE(review): if the buffer's lock is held elsewhere at this moment the
    /// bytes are discarded while still reported as written. That is fine for
    /// throughput benchmarks, but the captured output may then be incomplete.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        let accepted = buf.len();
        if let Ok(mut captured) = self.0.try_lock() {
            captured.extend_from_slice(buf);
        }
        std::task::Poll::Ready(Ok(accepted))
    }

    /// Nothing is buffered internally, so flushing completes immediately.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::task::Poll::Ready(Ok(()))
    }

    /// Shutdown has nothing to release; it behaves exactly like a flush.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        self.poll_flush(cx)
    }
}
/// Builds a [`StdioFrameSink`] that writes into an in-memory buffer, for
/// benchmarks.
///
/// Returns the type-erased sink together with the buffer it writes into, so
/// the caller can inspect the captured bytes.
#[cfg(feature = "bench")]
pub fn bench_stdio_sink() -> (
    std::sync::Arc<dyn OutboundFrameSink>,
    std::sync::Arc<tokio::sync::Mutex<Vec<u8>>>,
) {
    let captured = std::sync::Arc::new(tokio::sync::Mutex::new(Vec::new()));
    let writer: crate::outbound::SharedWriter = std::sync::Arc::new(tokio::sync::Mutex::new(
        BenchWriter(std::sync::Arc::clone(&captured)),
    ));
    let sink: std::sync::Arc<dyn OutboundFrameSink> =
        std::sync::Arc::new(StdioFrameSink::new(writer));
    (sink, captured)
}
// Unit tests for the stdio sink and, behind the `http` feature, the HTTP sink.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    // A successful send writes a single JSON document terminated by `\n`.
    #[tokio::test]
    async fn stdio_sink_writes_newline_delimited_payload() {
        let buffer: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
        let writer: SharedWriter = Arc::new(Mutex::new(CapturingWriter(buffer.clone())));
        let sink = StdioFrameSink::new(writer);
        sink.send_frame(std::sync::Arc::new(
            serde_json::json!({"jsonrpc": "2.0", "id": 7}),
        ))
        .await
        .expect("write succeeds against in-memory writer");
        let written = buffer.lock().await.clone();
        let text = String::from_utf8(written).expect("frame is utf-8");
        assert!(
            text.ends_with('\n'),
            "stdio framing requires a trailing newline, got {text:?}"
        );
        // Strip the delimiter so only the JSON body is parsed.
        let parsed: serde_json::Value =
            serde_json::from_str(text.trim_end()).expect("frame is valid json");
        assert_eq!(parsed["jsonrpc"], "2.0");
        assert_eq!(parsed["id"], 7);
    }

    // Any write failure surfaces as `TransportClosed`, not a panic or a
    // serialisation error.
    #[tokio::test]
    async fn stdio_sink_returns_transport_closed_on_write_error() {
        let writer: SharedWriter = Arc::new(Mutex::new(BrokenWriter));
        let sink = StdioFrameSink::new(writer);
        let outcome = sink
            .send_frame(std::sync::Arc::new(serde_json::json!({"jsonrpc": "2.0"})))
            .await;
        assert!(matches!(outcome, Err(OutboundSinkError::TransportClosed)));
    }

    // In-memory writer that appends every write to the shared buffer. The
    // `try_lock` cannot contend because the tests only read the buffer after
    // the send has completed.
    struct CapturingWriter(Arc<Mutex<Vec<u8>>>);
    impl tokio::io::AsyncWrite for CapturingWriter {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let mut guard = self.0.try_lock().expect("test writer is uncontended");
            guard.extend_from_slice(buf);
            std::task::Poll::Ready(Ok(buf.len()))
        }
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::task::Poll::Ready(Ok(()))
        }
        fn poll_shutdown(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::task::Poll::Ready(Ok(()))
        }
    }

    // Writer whose every write fails with `BrokenPipe`, simulating a peer that
    // has gone away. Flush and shutdown succeed, so only the write path fails.
    struct BrokenWriter;
    impl tokio::io::AsyncWrite for BrokenWriter {
        fn poll_write(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
            _buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "test broken pipe",
            )))
        }
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::task::Poll::Ready(Ok(()))
        }
        fn poll_shutdown(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::task::Poll::Ready(Ok(()))
        }
    }

    // HTTP sink tests. `Session::new_stdio()` is used as a convenient session
    // that owns the replay buffer and event-id counter.
    // NOTE(review): assumes a stdio session's replay state behaves like an HTTP
    // session's for these purposes — confirm.
    #[cfg(feature = "http")]
    mod http_sink_tests {
        use super::*;
        use crate::outbound_ring::bounded_ring;

        // A frame pushed through the sink arrives on the ring tagged with the
        // session's first event id (1).
        #[tokio::test]
        async fn http_send_delivers_through_ring() {
            let session = std::sync::Arc::new(crate::session::Session::new_stdio());
            let (tx, mut rx) = bounded_ring::<(u64, std::sync::Arc<Value>)>(4);
            let sink = HttpFrameSink::new(std::sync::Arc::downgrade(&session), tx, 8);
            sink.send_frame(std::sync::Arc::new(
                serde_json::json!({"jsonrpc": "2.0", "id": 1}),
            ))
            .await
            .expect("send_frame succeeds against open ring");
            let (event_id, received) = rx.recv().await.expect("ring delivers value");
            assert_eq!(event_id, 1, "first frame gets event id 1");
            assert_eq!(received["id"], 1);
        }

        // Dropping the ring's receiver closes the transport for the sink.
        #[tokio::test]
        async fn http_send_returns_transport_closed_when_rx_dropped() {
            let session = std::sync::Arc::new(crate::session::Session::new_stdio());
            let (tx, rx) = bounded_ring::<(u64, std::sync::Arc<Value>)>(4);
            drop(rx);
            let sink = HttpFrameSink::new(std::sync::Arc::downgrade(&session), tx, 8);
            let outcome = sink
                .send_frame(std::sync::Arc::new(serde_json::json!({"id": 1})))
                .await;
            assert!(matches!(outcome, Err(OutboundSinkError::TransportClosed)));
        }

        // The sink holds only a weak session handle; once the last strong
        // reference is gone, sends report a closed transport.
        #[tokio::test]
        async fn http_send_returns_transport_closed_when_session_dropped() {
            let session = std::sync::Arc::new(crate::session::Session::new_stdio());
            let weak = std::sync::Arc::downgrade(&session);
            let (tx, _rx) = bounded_ring::<(u64, std::sync::Arc<Value>)>(4);
            let sink = HttpFrameSink::new(weak, tx, 8);
            drop(session);
            let outcome = sink
                .send_frame(std::sync::Arc::new(serde_json::json!({"id": 1})))
                .await;
            assert!(matches!(outcome, Err(OutboundSinkError::TransportClosed)));
        }

        // With a ring of capacity 2, the third send evicts frame 1, so frames
        // 2 and 3 remain and exactly one drop is counted.
        #[tokio::test]
        async fn http_send_drops_oldest_when_ring_full() {
            let session = std::sync::Arc::new(crate::session::Session::new_stdio());
            let (tx, mut rx) = bounded_ring::<(u64, std::sync::Arc<Value>)>(2);
            // Keep a clone of `tx` to inspect the drop counter afterwards.
            let sink = HttpFrameSink::new(std::sync::Arc::downgrade(&session), tx.clone(), 8);
            for id in 1..=3 {
                sink.send_frame(std::sync::Arc::new(serde_json::json!({"id": id})))
                    .await
                    .expect("ring accepts each push");
            }
            assert_eq!(tx.dropped_oldest_count(), 1);
            let (_, first) = rx.recv().await.unwrap();
            let (_, second) = rx.recv().await.unwrap();
            assert_eq!(first["id"], 2);
            assert_eq!(second["id"], 3);
        }

        // With replay enabled, every frame is recorded in the session's replay
        // buffer under consecutive event ids starting at 1.
        #[tokio::test]
        async fn http_send_writes_to_sse_replay_buffer() {
            let session = std::sync::Arc::new(crate::session::Session::new_stdio());
            let (tx, _rx) = bounded_ring::<(u64, std::sync::Arc<Value>)>(4);
            let sink = HttpFrameSink::new(std::sync::Arc::downgrade(&session), tx, 8);
            for i in 1..=3 {
                sink.send_frame(std::sync::Arc::new(serde_json::json!({"id": i})))
                    .await
                    .expect("send_frame ok");
            }
            let buffer = session.sse_replay_buffer.lock();
            assert_eq!(buffer.len(), 3);
            let ids: Vec<u64> = buffer.iter().map(|(id, _)| *id).collect();
            assert_eq!(ids, vec![1, 2, 3]);
        }

        // A replay capacity of 0 leaves the replay buffer untouched.
        #[tokio::test]
        async fn http_send_skips_sse_replay_buffer_when_disabled() {
            let session = std::sync::Arc::new(crate::session::Session::new_stdio());
            let (tx, _rx) = bounded_ring::<(u64, std::sync::Arc<Value>)>(4);
            let sink = HttpFrameSink::new(std::sync::Arc::downgrade(&session), tx, 0);
            sink.send_frame(std::sync::Arc::new(serde_json::json!({"id": 1})))
                .await
                .expect("send_frame ok");
            let buffer = session.sse_replay_buffer.lock();
            assert!(buffer.is_empty(), "sse_replay_capacity=0 disables writes");
        }

        // Compile-time guards on the queue capacity constant: non-zero and at
        // least 256.
        #[test]
        fn capacity_constant_is_named_and_nonzero() {
            const _: () = assert!(OUTBOUND_QUEUE_CAPACITY > 0);
            const _: () = assert!(OUTBOUND_QUEUE_CAPACITY >= 256);
        }
    }
}