mcp_rs_sdk/agent.rs
1//! The `agent` module provides the main runner for an MCP agent.
2use crate::error::McpError;
3use crate::types::{GetContextRequest, GetContextResponse, Response};
4use std::io::{self, BufRead, Write};
5
6/// Runs the main MCP agent loop for non-streaming responses.
7///
8/// This function listens for incoming JSON requests from stdin, deserializes them into
9/// [`GetContextRequest`] objects, and passes them to the provided handler function.
10/// The handler's response is then serialized to JSON and written to stdout.
11///
12/// The loop continues until stdin is closed or an unrecoverable error occurs.
13///
14/// # Arguments
15///
16/// * `handler` - A function or closure that takes a `GetContextRequest` and
17/// returns a `Result<Response, E>` where E is an error type that
18/// can be converted into a `Box<dyn std::error::Error>`.
19///
20/// # Returns
21///
22/// This function returns `Ok(())` if the loop completes successfully (e.g., stdin is closed).
23/// It returns an `Err(McpError)` if a non-recoverable error occurs, such as an
24/// I/O error or a JSON processing error. Handler errors are wrapped but do not
25/// terminate the loop.
26pub fn run_agent<F, E>(mut handler: F) -> Result<(), McpError>
27where
28 F: FnMut(GetContextRequest) -> Result<Response, E>,
29 E: std::error::Error + Send + Sync + 'static,
30{
31 let stdin = io::stdin();
32 let mut stdout = io::stdout();
33
34 for line in stdin.lock().lines() {
35 let line = line?; // Propagates I/O errors
36 if line.trim().is_empty() {
37 continue;
38 }
39
40 let request: GetContextRequest = match serde_json::from_str(&line) {
41 Ok(req) => req,
42 Err(e) => {
43 return Err(McpError::Json(e));
44 }
45 };
46
47 match handler(request) {
48 Ok(response) => {
49 let final_response = GetContextResponse { response };
50 let response_json = serde_json::to_string(&final_response)?;
51 writeln!(stdout, "{}", response_json)?;
52 stdout.flush()?;
53 }
54 Err(e) => {
55 let handler_error = McpError::Handler(Box::new(e));
56 eprintln!("{}", handler_error);
57 }
58 }
59 }
60 Ok(())
61}
62
63#[cfg(feature = "streaming")]
64use crate::types::PartialResponse;
65#[cfg(feature = "streaming")]
66use futures::Stream;
67
68/// Runs the main MCP agent loop for streaming responses.
69///
70/// This function is only available when the `streaming` feature is enabled.
71///
72/// It listens for incoming requests and calls the provided handler, which must
73/// return a `Stream` of `PartialResponse` items. Each item from the stream is
74/// immediately serialized and written to stdout.
75///
76/// # Arguments
77///
78/// * `handler` - A function that takes a `GetContextRequest` and returns a `Stream`
79/// of `Result<PartialResponse, E>`.
80///
81/// # Returns
82///
83/// Returns `Ok(())` if the loop completes successfully, or an `Err(McpError)`
84/// if a non-recoverable I/O or JSON error occurs.
85#[cfg(feature = "streaming")]
86pub async fn run_streaming_agent<F, S, E>(mut handler: F) -> Result<(), McpError>
87where
88 F: FnMut(GetContextRequest) -> S,
89 S: Stream<Item = Result<PartialResponse, E>>,
90 E: std::error::Error + Send + Sync + 'static,
91{
92 use futures::pin_mut;
93 use futures::StreamExt;
94 use tokio::io::{AsyncBufReadExt, BufReader};
95
96 // Use tokio's async stdin
97 let stdin = tokio::io::stdin();
98 let mut lines = BufReader::new(stdin).lines();
99 let mut stdout = io::stdout();
100
101 while let Some(line) = lines.next_line().await? {
102 if line.trim().is_empty() {
103 continue;
104 }
105
106 let request: GetContextRequest = match serde_json::from_str(&line) {
107 Ok(req) => req,
108 Err(e) => {
109 // Unlike the sync version, we'll report the error and continue,
110 // as async streams might be more resilient to single malformed requests.
111 eprintln!("{}", McpError::Json(e));
112 continue;
113 }
114 };
115
116 let stream = handler(request);
117 pin_mut!(stream);
118
119 while let Some(item) = stream.next().await {
120 match item {
121 Ok(partial_response) => {
122 let response_json = serde_json::to_string(&partial_response)?;
123 writeln!(stdout, "{}", response_json)?;
124 stdout.flush()?;
125 }
126 Err(e) => {
127 let handler_error = McpError::Handler(Box::new(e));
128 eprintln!("{}", handler_error);
129 }
130 }
131 }
132 }
133 Ok(())
134}