mcp_rs_sdk/
agent.rs

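//! The `agent` module provides the main runner loops for an MCP agent: [`run_agent`] for
//! handlers that return a complete response, and (with the `streaming` feature) a streaming
//! runner for handlers that return a stream of partial responses.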
2use crate::error::McpError;
3use crate::types::{GetContextRequest, GetContextResponse, Response};
4use std::io::{self, BufRead, Write};
5
6/// Runs the main MCP agent loop for non-streaming responses.
7///
8/// This function listens for incoming JSON requests from stdin, deserializes them into
9/// [`GetContextRequest`] objects, and passes them to the provided handler function.
10/// The handler's response is then serialized to JSON and written to stdout.
11///
12/// The loop continues until stdin is closed or an unrecoverable error occurs.
13///
14/// # Arguments
15///
16/// * `handler` - A function or closure that takes a `GetContextRequest` and
17///   returns a `Result<Response, E>` where E is an error type that
18///   can be converted into a `Box<dyn std::error::Error>`.
19///
20/// # Returns
21///
22/// This function returns `Ok(())` if the loop completes successfully (e.g., stdin is closed).
23/// It returns an `Err(McpError)` if a non-recoverable error occurs, such as an
24/// I/O error or a JSON processing error. Handler errors are wrapped but do not
25/// terminate the loop.
26pub fn run_agent<F, E>(mut handler: F) -> Result<(), McpError>
27where
28    F: FnMut(GetContextRequest) -> Result<Response, E>,
29    E: std::error::Error + Send + Sync + 'static,
30{
31    let stdin = io::stdin();
32    let mut stdout = io::stdout();
33
34    for line in stdin.lock().lines() {
35        let line = line?; // Propagates I/O errors
36        if line.trim().is_empty() {
37            continue;
38        }
39
40        let request: GetContextRequest = match serde_json::from_str(&line) {
41            Ok(req) => req,
42            Err(e) => {
43                return Err(McpError::Json(e));
44            }
45        };
46
47        match handler(request) {
48            Ok(response) => {
49                let final_response = GetContextResponse { response };
50                let response_json = serde_json::to_string(&final_response)?;
51                writeln!(stdout, "{}", response_json)?;
52                stdout.flush()?;
53            }
54            Err(e) => {
55                let handler_error = McpError::Handler(Box::new(e));
56                eprintln!("{}", handler_error);
57            }
58        }
59    }
60    Ok(())
61}
62
63#[cfg(feature = "streaming")]
64use crate::types::PartialResponse;
65#[cfg(feature = "streaming")]
66use futures::Stream;
67
/// Runs the main MCP agent loop for streaming responses.
///
/// This function is only available when the `streaming` feature is enabled.
///
/// Reads newline-delimited JSON requests from tokio's async stdin. Each non-blank line is
/// deserialized into a [`GetContextRequest`] and passed to `handler`, which returns a
/// `Stream` of `Result<PartialResponse, E>`. That stream is polled to completion before the
/// next request line is read. Every `Ok` item is serialized as a single line of JSON,
/// written to stdout, and flushed immediately, so the client sees partial results as they
/// are produced.
///
/// Output goes through tokio's async stdout. This way, a slow reader on the other end of
/// the pipe does not block the async runtime's worker thread.
///
/// # Arguments
///
/// * `handler` - A function that takes a `GetContextRequest` and returns a `Stream`
///   of `Result<PartialResponse, E>`.
///
/// # Errors
///
/// Returns `Ok(())` once stdin reaches EOF. Returns `Err(McpError)` and stops when reading
/// stdin fails, when writing to or flushing stdout fails, or when a `PartialResponse` cannot
/// be serialized.
///
/// Unlike [`run_agent`], a request line that is not valid JSON is *not* fatal: the error is
/// printed to stderr and the line is skipped. Errors yielded by the handler's stream are
/// likewise wrapped in `McpError::Handler` and printed to stderr. The stream then keeps
/// being polled for further items.
#[cfg(feature = "streaming")]
pub async fn run_streaming_agent<F, S, E>(mut handler: F) -> Result<(), McpError>
where
    F: FnMut(GetContextRequest) -> S,
    S: Stream<Item = Result<PartialResponse, E>>,
    E: std::error::Error + Send + Sync + 'static,
{
    use futures::pin_mut;
    use futures::StreamExt;
    use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

    let mut lines = BufReader::new(tokio::io::stdin()).lines();
    // Async stdout: std's blocking stdout would stall the executor thread whenever the
    // output pipe is full.
    let mut stdout = tokio::io::stdout();

    while let Some(line) = lines.next_line().await? {
        if line.trim().is_empty() {
            continue;
        }

        let request: GetContextRequest = match serde_json::from_str(&line) {
            Ok(req) => req,
            Err(e) => {
                // A single malformed request should not tear down a long-lived streaming
                // session. Report it and move on to the next line.
                eprintln!("{}", McpError::Json(e));
                continue;
            }
        };

        let stream = handler(request);
        pin_mut!(stream);

        while let Some(item) = stream.next().await {
            match item {
                Ok(partial_response) => {
                    // One JSON document per line, newline-terminated (same bytes as
                    // `to_string` followed by `writeln!`).
                    let mut frame = serde_json::to_vec(&partial_response)?;
                    frame.push(b'\n');
                    stdout.write_all(&frame).await?;
                    // Flush per item so each partial result reaches the client immediately.
                    stdout.flush().await?;
                }
                Err(e) => {
                    // Best-effort: report the item error and keep draining the stream.
                    let handler_error = McpError::Handler(Box::new(e));
                    eprintln!("{}", handler_error);
                }
            }
        }
    }
    Ok(())
}