use std::fmt;
use async_trait::async_trait;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, stdin, stdout};
use tokio::sync::mpsc;
use tracing::{debug, error, info};
use super::{RequestHandler, Transport, TransportMessage};
use crate::error::{Error, Result};
use crate::protocol::{ErrorMessage, Request, Response};
/// Transport that exchanges newline-delimited JSON messages over the
/// process's standard input and standard output.
///
/// Each incoming line is parsed as a request and forwarded to the installed
/// [`RequestHandler`]; every outbound message is written back to stdout as a
/// single JSON line.
pub struct StdioTransport {
    /// Callback used to dispatch parsed requests; `None` until a handler is
    /// installed via `set_request_handler`.
    request_handler: Option<RequestHandler>,
}
impl fmt::Debug for StdioTransport {
    /// Formats the transport without requiring the handler itself to be
    /// `Debug`: the handler field is rendered only as a presence flag,
    /// e.g. `StdioTransport { request_handler: "<handler: true>" }`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let handler_summary = format!("<handler: {}>", self.request_handler.is_some());
        let mut builder = f.debug_struct("StdioTransport");
        builder.field("request_handler", &handler_summary);
        builder.finish()
    }
}
impl Default for StdioTransport {
fn default() -> Self {
Self::new()
}
}
impl StdioTransport {
pub fn new() -> Self {
Self {
request_handler: None,
}
}
async fn process_message(&self, message_text: String) -> Result<()> {
let value: Value = serde_json::from_str(&message_text)
.map_err(|e| Error::Protocol(format!("Failed to parse message: {e}")))?;
if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
debug!("Received request: {}", request.tool);
if let Some(handler) = &self.request_handler {
let tx = handler(request.clone());
tokio::spawn(async move {
let (resp_tx, mut resp_rx) = mpsc::channel(10);
if let Err(e) = tx.send((request, resp_tx)).await {
error!("Failed to send request: {}", e);
return;
}
while let Some(message) = resp_rx.recv().await {
match message {
TransportMessage::Response(response) => {
if let Err(e) = Self::send_response(&response).await {
error!("Failed to send response: {}", e);
}
}
TransportMessage::Error(error) => {
if let Err(e) = Self::send_error(&error).await {
error!("Failed to send error: {}", e);
}
}
TransportMessage::Notification(notification) => {
if let Err(e) = Self::send_json(¬ification).await {
error!("Failed to send notification: {}", e);
}
}
}
}
});
} else {
error!("No request handler set");
return Err(Error::Transport("No request handler set".to_string()));
}
} else {
error!("Unsupported message type");
return Err(Error::Protocol("Unsupported message type".to_string()));
}
Ok(())
}
async fn send_response(response: &Response) -> Result<()> {
let json = serde_json::to_string(response).map_err(Error::Json)?;
Self::send_raw(&json).await
}
async fn send_error(error: &ErrorMessage) -> Result<()> {
let json = serde_json::to_string(error).map_err(Error::Json)?;
Self::send_raw(&json).await
}
async fn send_json<T: serde::Serialize>(value: &T) -> Result<()> {
let json = serde_json::to_string(value).map_err(Error::Json)?;
Self::send_raw(&json).await
}
async fn send_raw(text: &str) -> Result<()> {
let mut stdout = stdout();
stdout.write_all(text.as_bytes()).await?;
stdout.write_all(b"\n").await?;
stdout.flush().await?;
Ok(())
}
}
#[async_trait]
impl Transport for StdioTransport {
    /// Reads stdin line by line until EOF and dispatches each non-blank line
    /// through `process_message`.
    ///
    /// Per-message failures are logged and do not stop the loop.
    ///
    /// # Errors
    ///
    /// * `Error::Transport` if no request handler has been set.
    /// * The converted I/O error if reading from stdin fails.
    async fn start(&self) -> Result<()> {
        info!("Starting stdio transport");
        if self.request_handler.is_none() {
            return Err(Error::Transport("No request handler set".to_string()));
        }
        let mut reader = tokio::io::BufReader::new(stdin()).lines();
        while let Some(line) = reader.next_line().await? {
            // Blank separator lines carry no message; skip them instead of
            // logging a spurious JSON parse error for each one.
            if line.trim().is_empty() {
                continue;
            }
            if let Err(e) = self.process_message(line).await {
                error!("Error processing message: {}", e);
            }
        }
        Ok(())
    }

    /// Installs (or replaces) the handler used to dispatch incoming requests.
    fn set_request_handler(&mut self, handler: RequestHandler) {
        self.request_handler = Some(handler);
    }
}