//! fastmcp 0.0.0
//!
//! A Rust framework for building Model Context Protocol (MCP) services.
use std::fmt;

use async_trait::async_trait;
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, stdin, stdout};
use tokio::sync::mpsc;
use tracing::{debug, error, info};

use super::{RequestHandler, Transport, TransportMessage};
use crate::error::{Error, Result};
use crate::protocol::{ErrorMessage, Request, Response};

/// Transport that speaks newline-delimited JSON over the process's standard
/// input (incoming messages) and standard output (outgoing messages).
pub struct StdioTransport {
    /// Dispatch callback for incoming requests; stays `None` until a handler
    /// is installed with `set_request_handler`.
    request_handler: Option<RequestHandler>,
}

// `RequestHandler` does not implement `Debug`, so instead of printing the
// handler itself this impl only reports whether one has been installed.
impl fmt::Debug for StdioTransport {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let installed = self.request_handler.is_some();
        let summary = format!("<handler: {}>", installed);

        let mut out = f.debug_struct("StdioTransport");
        out.field("request_handler", &summary);
        out.finish()
    }
}

impl Default for StdioTransport {
    fn default() -> Self {
        Self::new()
    }
}

impl StdioTransport {
    /// Create a new StdioTransport
    pub fn new() -> Self {
        Self {
            request_handler: None,
        }
    }

    /// Process an incoming message
    async fn process_message(&self, message_text: String) -> Result<()> {
        // Parse the message
        let value: Value = serde_json::from_str(&message_text)
            .map_err(|e| Error::Protocol(format!("Failed to parse message: {e}")))?;

        // Check if it's a request
        if let Ok(request) = serde_json::from_value::<Request>(value.clone()) {
            debug!("Received request: {}", request.tool);

            // Process the request
            if let Some(handler) = &self.request_handler {
                let tx = handler(request.clone());

                // Spawn a task to process this request
                tokio::spawn(async move {
                    // Create a response channel
                    let (resp_tx, mut resp_rx) = mpsc::channel(10);

                    // Send the request to be processed, along with the response channel
                    if let Err(e) = tx.send((request, resp_tx)).await {
                        error!("Failed to send request: {}", e);
                        return;
                    }

                    // Wait for and process responses
                    while let Some(message) = resp_rx.recv().await {
                        match message {
                            TransportMessage::Response(response) => {
                                if let Err(e) = Self::send_response(&response).await {
                                    error!("Failed to send response: {}", e);
                                }
                            }
                            TransportMessage::Error(error) => {
                                if let Err(e) = Self::send_error(&error).await {
                                    error!("Failed to send error: {}", e);
                                }
                            }
                            TransportMessage::Notification(notification) => {
                                if let Err(e) = Self::send_json(&notification).await {
                                    error!("Failed to send notification: {}", e);
                                }
                            }
                        }
                    }
                });
            } else {
                error!("No request handler set");
                return Err(Error::Transport("No request handler set".to_string()));
            }
        } else {
            // Not a request - we'll ignore other message types for now
            error!("Unsupported message type");
            return Err(Error::Protocol("Unsupported message type".to_string()));
        }

        Ok(())
    }

    /// Send a response back to the client
    async fn send_response(response: &Response) -> Result<()> {
        let json = serde_json::to_string(response).map_err(Error::Json)?;

        Self::send_raw(&json).await
    }

    /// Send an error back to the client
    async fn send_error(error: &ErrorMessage) -> Result<()> {
        let json = serde_json::to_string(error).map_err(Error::Json)?;

        Self::send_raw(&json).await
    }

    /// Send any serializable object as JSON
    async fn send_json<T: serde::Serialize>(value: &T) -> Result<()> {
        let json = serde_json::to_string(value).map_err(Error::Json)?;

        Self::send_raw(&json).await
    }

    /// Send raw text to stdout
    async fn send_raw(text: &str) -> Result<()> {
        let mut stdout = stdout();
        stdout.write_all(text.as_bytes()).await?;
        stdout.write_all(b"\n").await?;
        stdout.flush().await?;

        Ok(())
    }
}

#[async_trait]
impl Transport for StdioTransport {
    /// Start the transport
    async fn start(&self) -> Result<()> {
        info!("Starting stdio transport");

        // Ensure we have a request handler
        if self.request_handler.is_none() {
            return Err(Error::Transport("No request handler set".to_string()));
        }

        // Read from stdin
        let mut reader = tokio::io::BufReader::new(stdin()).lines();

        while let Some(line) = reader.next_line().await? {
            if let Err(e) = self.process_message(line).await {
                error!("Error processing message: {}", e);
            }
        }

        Ok(())
    }

    /// Set the request handler
    fn set_request_handler(&mut self, handler: RequestHandler) {
        self.request_handler = Some(handler);
    }
}