use async_trait::async_trait;
use std::collections::VecDeque;
use std::io;
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{ChildStdin, ChildStdout};
use tokio::sync::mpsc;
use tracing::{error, trace};
/// Size in bytes of the scratch buffer filled by each `read` on the child's stdout.
const READ_BUFFER_SIZE: usize = 4 * 1024;
/// Initial capacity of the buffer that accumulates stdout bytes until they form
/// complete UTF-8; `compact` shrinks the buffer back toward this size.
const UTF8_ACCUMULATION_BUFFER_CAPACITY: usize = 8 * 1024;
/// A bidirectional, string-based message channel.
///
/// Implemented in this module by [`StdioTransport`] (a child process's
/// stdin/stdout) and [`MockTransport`] (an in-memory stand-in).
#[async_trait]
pub trait Transport: Send + Sync {
    /// Error type returned by every fallible operation.
    type Error: std::error::Error + Send + Sync + 'static;
    /// Sends one message to the peer.
    async fn send(&mut self, message: &str) -> Result<(), Self::Error>;
    /// Returns the next message (or chunk of text) received from the peer.
    async fn receive(&mut self) -> Result<String, Self::Error>;
    /// Closes the transport. Both implementations in this module reject
    /// `send`/`receive` afterwards.
    async fn close(&mut self) -> Result<(), Self::Error>;
    /// Reports whether the transport is still considered usable.
    fn is_connected(&self) -> bool;
}
/// Errors produced by [`StdioTransport`].
#[derive(Debug, thiserror::Error)]
pub enum StdioTransportError {
    /// Wraps an [`io::Error`]; `#[from]` lets `?` convert one automatically.
    /// NOTE(review): not constructed by any code visible in this module —
    /// confirm whether it is still needed.
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
    /// The transport was closed, its channel handles were already released,
    /// or the stdout reader task has stopped and its output is exhausted.
    #[error("Transport is disconnected")]
    Disconnected,
    /// Handing a message to the stdin writer task failed; holds the channel
    /// error's text.
    #[error("Channel error: {0}")]
    Channel(String),
}
/// [`Transport`] over a child process's stdin/stdout pipes.
///
/// The actual reading and writing is done by background tasks spawned in
/// `StdioTransport::new`; this handle talks to them over unbounded channels.
#[derive(Debug)]
pub struct StdioTransport {
    /// Queue feeding the stdin writer task; taken (set to `None`) by `close`.
    stdin_sender: Option<mpsc::UnboundedSender<String>>,
    /// Text decoded by the stdout reader task; taken (set to `None`) by `close`.
    stdout_receiver: Option<mpsc::UnboundedReceiver<String>>,
    /// Checked by `send`/`receive` before touching the channels; cleared by `close`.
    connected: bool,
}
/// Accumulates raw bytes read from a child's stdout and hands them out as
/// UTF-8.
///
/// A multi-byte character split across two reads is held back until its
/// remaining bytes arrive. Bytes that can never form valid UTF-8 are replaced
/// with U+FFFD, so they cannot stall the stream.
struct StdoutReaderState {
    /// Bytes received but not yet handed out.
    byte_buffer: Vec<u8>,
    /// Baseline capacity that `compact` shrinks `byte_buffer` back toward.
    buffer_capacity: usize,
}

impl StdoutReaderState {
    /// Creates an empty state with `UTF8_ACCUMULATION_BUFFER_CAPACITY` bytes preallocated.
    fn new() -> Self {
        Self {
            byte_buffer: Vec::with_capacity(UTF8_ACCUMULATION_BUFFER_CAPACITY),
            buffer_capacity: UTF8_ACCUMULATION_BUFFER_CAPACITY,
        }
    }

    /// Appends freshly read bytes to the end of the pending buffer.
    fn add_bytes(&mut self, bytes: &[u8]) {
        self.byte_buffer.extend_from_slice(bytes);
    }

    /// Removes and returns the next chunk of the buffer that is valid UTF-8.
    ///
    /// Returns `None` when the buffer is empty, or when it holds only the
    /// beginning of a multi-byte sequence (which is kept for the next call).
    /// A sequence that can never become valid is dropped and replaced by
    /// U+FFFD, matching `String::from_utf8_lossy`. The returned bytes are
    /// therefore always valid UTF-8.
    fn extract_valid_utf8(&mut self) -> Option<Vec<u8>> {
        if self.byte_buffer.is_empty() {
            return None;
        }
        let utf8_error = match std::str::from_utf8(&self.byte_buffer) {
            Ok(_) => return Some(self.byte_buffer.drain(..).collect()),
            Err(e) => e,
        };
        let valid_end = utf8_error.valid_up_to();
        match utf8_error.error_len() {
            // The buffer ends part-way through a multi-byte sequence: hand out
            // the complete prefix (if any) and keep the tail for later.
            None if valid_end == 0 => None,
            None => Some(self.byte_buffer.drain(..valid_end).collect()),
            // `invalid_len` bytes at `valid_end` can never become valid. Leaving
            // them in place would block every later byte forever, so drop them
            // and emit a replacement character instead.
            Some(invalid_len) => {
                let mut chunk: Vec<u8> =
                    self.byte_buffer.drain(..valid_end + invalid_len).collect();
                // Discard the invalid bytes that were drained along with the prefix.
                chunk.truncate(valid_end);
                chunk.extend_from_slice("\u{FFFD}".as_bytes());
                Some(chunk)
            }
        }
    }

    /// Returns `true` once the buffer's capacity exceeds twice the baseline
    /// (e.g. after a burst of output).
    fn should_compact(&self) -> bool {
        self.byte_buffer.capacity() > self.buffer_capacity * 2
    }

    /// Shrinks an over-grown buffer back toward the baseline capacity. It is
    /// never shrunk below the bytes it still holds.
    fn compact(&mut self) {
        if self.should_compact() {
            self.byte_buffer.shrink_to(self.buffer_capacity);
        }
    }
}
impl StdioTransport {
    /// Wraps a child process's stdin/stdout in a connected transport.
    ///
    /// Spawns two background tasks. One writes queued messages to `stdin`.
    /// The other reads `stdout`, decodes it as UTF-8 and forwards the text to
    /// the receiving side. Forwarded chunks follow the pipe's read boundaries,
    /// not line or message boundaries.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime, because it uses `tokio::spawn`.
    pub fn new(stdin: ChildStdin, stdout: ChildStdout) -> Self {
        let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel();
        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
        tokio::spawn(Self::stdin_writer_task(stdin, outgoing_rx));
        tokio::spawn(Self::stdout_reader_task(stdout, incoming_tx));
        Self {
            connected: true,
            stdin_sender: Some(outgoing_tx),
            stdout_receiver: Some(incoming_rx),
        }
    }

    /// Background task that writes each queued message to the child's stdin
    /// and then flushes it.
    ///
    /// Ends when every sender has been dropped, or on the first write/flush
    /// error. `stdin` is dropped when the task ends.
    async fn stdin_writer_task(
        mut stdin: ChildStdin,
        mut receiver: mpsc::UnboundedReceiver<String>,
    ) {
        while let Some(outgoing) = receiver.recv().await {
            let byte_len = outgoing.len();
            trace!("StdioTransport: Writing message (length: {})", byte_len);
            match stdin.write_all(outgoing.as_bytes()).await {
                Ok(()) => {}
                Err(e) => {
                    error!("Failed to write to stdin: {}", e);
                    break;
                }
            }
            match stdin.flush().await {
                Ok(()) => {}
                Err(e) => {
                    error!("Failed to flush stdin: {}", e);
                    break;
                }
            }
        }
        trace!("StdioTransport: stdin writer task finished");
    }

    /// Background task that reads stdout in chunks of up to `READ_BUFFER_SIZE`
    /// bytes and forwards every complete piece of UTF-8 text to `sender`.
    ///
    /// Ends on EOF, on a read error, or as soon as the receiving side has been
    /// dropped.
    async fn stdout_reader_task(stdout: ChildStdout, sender: mpsc::UnboundedSender<String>) {
        let mut source = BufReader::new(stdout);
        let mut decoder = StdoutReaderState::new();
        let mut scratch = Box::new([0u8; READ_BUFFER_SIZE]);
        loop {
            let filled = match source.read(&mut scratch[..]).await {
                Ok(count) => count,
                Err(e) => {
                    error!("Failed to read from stdout: {}", e);
                    break;
                }
            };
            if filled == 0 {
                Self::handle_eof(&mut decoder, &sender);
                break;
            }
            decoder.add_bytes(&scratch[..filled]);
            if !Self::forward_decoded(&mut decoder, &sender) {
                // Receiver is gone; stop without the normal "finished" trace.
                return;
            }
            decoder.compact();
        }
        trace!("StdioTransport: stdout reader task finished");
    }

    /// Sends every chunk `state` can currently decode. Returns `false` if the
    /// receiver has been dropped, in which case the reader should stop.
    fn forward_decoded(
        state: &mut StdoutReaderState,
        sender: &mpsc::UnboundedSender<String>,
    ) -> bool {
        while let Some(valid_bytes) = state.extract_valid_utf8() {
            let text = match String::from_utf8(valid_bytes) {
                Ok(text) => text,
                Err(e) => {
                    error!(
                        "StdioTransport: Failed to convert validated UTF-8 bytes: {}",
                        e
                    );
                    break;
                }
            };
            if sender.send(text).is_err() {
                trace!("StdioTransport: stdout receiver dropped, stopping reader");
                return false;
            }
        }
        true
    }

    /// Flushes whatever text is still decodable at EOF. Any bytes left over
    /// (an unfinished multi-byte sequence) are logged as an error and dropped.
    fn handle_eof(state: &mut StdoutReaderState, sender: &mpsc::UnboundedSender<String>) {
        trace!("StdioTransport: stdout reader reached EOF");
        match state.extract_valid_utf8().map(String::from_utf8) {
            Some(Ok(text)) => {
                if !text.is_empty() && sender.send(text).is_err() {
                    trace!("StdioTransport: stdout receiver dropped during EOF processing");
                }
            }
            Some(Err(e)) => {
                error!("StdioTransport: Invalid UTF-8 in final bytes: {}", e);
            }
            None => {}
        }
        let leftover = &state.byte_buffer;
        if !leftover.is_empty() {
            error!(
                "StdioTransport: {} incomplete bytes remaining at EOF: {:?}",
                leftover.len(),
                leftover
            );
        }
    }
}
#[async_trait]
impl Transport for StdioTransport {
    type Error = StdioTransportError;

    /// Queues `message` for the background stdin writer task.
    ///
    /// Success only means the message was queued. The write itself happens
    /// asynchronously, and write failures are logged by the writer task.
    ///
    /// # Errors
    /// - [`StdioTransportError::Disconnected`] if the transport is closed or
    ///   has already been marked disconnected.
    /// - [`StdioTransportError::Channel`] if the writer task has exited (for
    ///   example after a failed write). The transport is then marked
    ///   disconnected so that `is_connected` reflects it.
    async fn send(&mut self, message: &str) -> Result<(), Self::Error> {
        if !self.connected {
            return Err(StdioTransportError::Disconnected);
        }
        let sender = self
            .stdin_sender
            .as_ref()
            .ok_or(StdioTransportError::Disconnected)?;
        if let Err(e) = sender.send(message.to_string()) {
            // Only the writer task holds the receiver, so a closed channel means
            // it has stopped and nothing sent later can reach the child.
            self.connected = false;
            return Err(StdioTransportError::Channel(e.to_string()));
        }
        Ok(())
    }

    /// Returns the next chunk of text decoded from the child's stdout.
    ///
    /// Chunks follow the pipe's read boundaries, not line or message
    /// boundaries.
    ///
    /// # Errors
    /// [`StdioTransportError::Disconnected`] in two cases:
    /// - the transport is closed;
    /// - the stdout reader task has stopped (EOF or read error) and every
    ///   buffered chunk has been consumed. In this case the transport is also
    ///   marked disconnected.
    async fn receive(&mut self) -> Result<String, Self::Error> {
        if !self.connected {
            return Err(StdioTransportError::Disconnected);
        }
        let receiver = self
            .stdout_receiver
            .as_mut()
            .ok_or(StdioTransportError::Disconnected)?;
        match receiver.recv().await {
            Some(chunk) => Ok(chunk),
            None => {
                // The reader task dropped its sender: the child's stdout is gone
                // for good, so stop reporting the transport as connected.
                self.connected = false;
                Err(StdioTransportError::Disconnected)
            }
        }
    }

    /// Marks the transport disconnected and drops both channel handles.
    ///
    /// Dropping the sender lets the stdin writer task finish, which closes the
    /// child's stdin. Dropping the receiver makes the stdout reader's next
    /// forward fail, so the reader stops. Always succeeds.
    async fn close(&mut self) -> Result<(), Self::Error> {
        self.connected = false;
        self.stdin_sender.take();
        self.stdout_receiver.take();
        Ok(())
    }

    /// Returns `false` after `close`, or once `send`/`receive` has observed
    /// that the corresponding background task has stopped.
    fn is_connected(&self) -> bool {
        self.connected
    }
}
/// Errors produced by [`MockTransport`].
#[derive(Debug, thiserror::Error)]
// NOTE(review): presumably allowed because the mock is only exercised from the
// `#[cfg(test)]` tests, so non-test builds would warn — confirm.
#[allow(dead_code)]
pub enum MockTransportError {
    /// Returned by `send`/`receive` after `close`.
    #[error("Transport is disconnected")]
    Disconnected,
    /// Returned by `receive` when the canned-response queue is empty.
    #[error("No more responses available")]
    NoMoreResponses,
}
/// In-memory [`Transport`] for tests. It records every sent message and
/// replays a FIFO queue of canned responses.
// NOTE(review): `dead_code` is presumably allowed because only the
// `#[cfg(test)]` tests use the mock — confirm.
#[allow(dead_code)]
pub struct MockTransport {
    /// Messages passed to `send`, oldest first.
    sent_messages: Arc<Mutex<Vec<String>>>,
    /// Canned responses; `receive` pops from the front.
    responses: Arc<Mutex<VecDeque<String>>>,
    /// Whether `send`/`receive` are still accepted; cleared by `close`.
    connected: bool,
}

#[allow(dead_code)]
impl MockTransport {
    /// Creates a connected mock with no recorded messages and no responses.
    pub fn new() -> Self {
        let sent_messages = Arc::new(Mutex::new(Vec::new()));
        let responses = Arc::new(Mutex::new(VecDeque::new()));
        Self {
            connected: true,
            responses,
            sent_messages,
        }
    }

    /// Creates a connected mock whose `receive` yields `responses` in order.
    pub fn with_responses(responses: Vec<String>) -> Self {
        let mut mock = Self::new();
        mock.add_responses(responses);
        mock
    }

    /// Appends several responses to the back of the queue.
    fn add_responses(&mut self, responses: Vec<String>) {
        self.responses.lock().unwrap().extend(responses);
    }

    /// Appends one response to the back of the queue.
    pub fn add_response(&mut self, response: String) {
        self.responses.lock().unwrap().push_back(response);
    }

    /// Returns a copy of every message sent so far, oldest first.
    pub fn sent_messages(&self) -> Vec<String> {
        let recorded = self.sent_messages.lock().unwrap();
        recorded.to_vec()
    }

    /// Forgets every recorded message.
    pub fn clear_sent_messages(&mut self) {
        let mut recorded = self.sent_messages.lock().unwrap();
        recorded.clear();
    }

    /// Returns `true` while at least one canned response is still queued.
    pub fn has_responses(&self) -> bool {
        let pending = self.responses.lock().unwrap().len();
        pending != 0
    }
}

impl Default for MockTransport {
    /// Same as [`MockTransport::new`].
    fn default() -> Self {
        MockTransport::new()
    }
}
#[async_trait]
impl Transport for MockTransport {
    type Error = MockTransportError;

    /// Records `message`. Fails with `Disconnected` after `close`.
    async fn send(&mut self, message: &str) -> Result<(), Self::Error> {
        if self.connected {
            self.sent_messages
                .lock()
                .unwrap()
                .push(String::from(message));
            Ok(())
        } else {
            Err(MockTransportError::Disconnected)
        }
    }

    /// Pops the next canned response.
    ///
    /// Fails with `Disconnected` after `close`, or with `NoMoreResponses`
    /// once the queue is empty.
    async fn receive(&mut self) -> Result<String, Self::Error> {
        if !self.connected {
            return Err(MockTransportError::Disconnected);
        }
        let next = self.responses.lock().unwrap().pop_front();
        next.ok_or(MockTransportError::NoMoreResponses)
    }

    /// Marks the mock disconnected. Always succeeds.
    async fn close(&mut self) -> Result<(), Self::Error> {
        self.connected = false;
        Ok(())
    }

    /// Returns `false` once `close` has been called.
    fn is_connected(&self) -> bool {
        self.connected
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::Stdio;
    use tokio::process::Command;

    #[tokio::test]
    async fn test_stdio_transport_echo() {
        let mut child = Command::new("echo")
            .arg("hello world")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .spawn()
            .expect("Failed to spawn echo command");
        let child_stdin = child.stdin.take().unwrap();
        let child_stdout = child.stdout.take().unwrap();
        let mut transport = StdioTransport::new(child_stdin, child_stdout);

        let received = transport.receive().await.unwrap();
        assert_eq!(received.trim(), "hello world");
        assert!(transport.is_connected());

        transport.close().await.unwrap();
        let _ = child.wait().await;
    }

    #[tokio::test]
    async fn test_mock_transport_send_receive() {
        let canned = vec!["response1".to_string(), "response2".to_string()];
        let mut transport = MockTransport::with_responses(canned);
        for outgoing in &["message1", "message2"] {
            transport.send(outgoing).await.unwrap();
        }
        for expected in &["response1", "response2"] {
            assert_eq!(transport.receive().await.unwrap(), *expected);
        }
        assert_eq!(transport.sent_messages(), vec!["message1", "message2"]);
        // The canned queue is now exhausted.
        assert!(transport.receive().await.is_err());
    }

    #[tokio::test]
    async fn test_mock_transport_disconnect() {
        let mut transport = MockTransport::new();
        assert!(transport.is_connected());

        transport.close().await.unwrap();
        assert!(!transport.is_connected());
        assert!(transport.send("test").await.is_err());
        assert!(transport.receive().await.is_err());
    }

    #[tokio::test]
    async fn test_stdout_reader_state_accumulation() {
        let mut state = StdoutReaderState::new();

        // First two bytes of the three-byte encoding of '世' (E4 B8 96).
        state.add_bytes(&[0xE4, 0xB8]);
        assert!(state.extract_valid_utf8().is_none());

        // The final byte completes the character.
        state.add_bytes(&[0x96]);
        let extracted = state
            .extract_valid_utf8()
            .expect("Should extract complete UTF-8");
        let result = String::from_utf8(extracted).expect("Should be valid UTF-8");
        assert_eq!(result, "世");
        assert!(state.extract_valid_utf8().is_none());
        assert!(state.byte_buffer.is_empty());
    }

    #[tokio::test]
    async fn test_stdout_reader_mixed_boundaries() {
        let mut state = StdoutReaderState::new();

        // Plain ASCII comes out immediately.
        state.add_bytes("Hello ".as_bytes());
        let hello = state.extract_valid_utf8().expect("Should extract 'Hello '");
        assert_eq!(String::from_utf8(hello).unwrap(), "Hello ");

        // First two bytes of '世' (E4 B8 96): nothing complete yet.
        state.add_bytes(&[0xE4, 0xB8]);
        assert!(state.extract_valid_utf8().is_none());

        // Completes '世' and starts '界' (E7 95 8C).
        state.add_bytes(&[0x96, 0xE7, 0x95]);
        let shi = state.extract_valid_utf8().expect("Should extract '世'");
        assert_eq!(String::from_utf8(shi).unwrap(), "世");

        // Completes '界', adds a space, and starts '🌍' (F0 9F 8C 8D).
        state.add_bytes(&[0x8C, 0x20, 0xF0, 0x9F]);
        let jie = state.extract_valid_utf8().expect("Should extract '界 '");
        assert_eq!(String::from_utf8(jie).unwrap(), "界 ");

        // Completes '🌍'.
        state.add_bytes(&[0x8C, 0x8D]);
        let globe = state.extract_valid_utf8().expect("Should extract '🌍'");
        assert_eq!(String::from_utf8(globe).unwrap(), "🌍");
        assert!(state.byte_buffer.is_empty());
    }
}