use anyhow::Result;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use super::{error_codes, AnalyzerMcpServer, JsonRpcError, Request, Response};
/// Default upper bound, in bytes, on a single serialized stdio response
/// (including its trailing newline).
const DEFAULT_MAX_RESPONSE_BYTES: usize = 2_000_000;

/// Returns the stdio response size ceiling in bytes.
///
/// Reads `TRUSTY_MCP_MAX_RESPONSE_BYTES`, ignoring surrounding whitespace.
/// Falls back to [`DEFAULT_MAX_RESPONSE_BYTES`] in any of these cases:
/// - the variable is unset or not valid Unicode;
/// - it does not parse as a `usize`;
/// - it is zero.
fn response_size_ceiling() -> usize {
    std::env::var("TRUSTY_MCP_MAX_RESPONSE_BYTES")
        .ok()
        // Values loaded from `.env` files or shell scripts often carry stray
        // whitespace/newlines; don't let that silently disable the override.
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        // Every response is at least one byte (its newline), so a ceiling of 0
        // would replace every reply with a truncation notice. Treat it as
        // misconfiguration rather than honoring it.
        .filter(|&limit| limit > 0)
        .unwrap_or(DEFAULT_MAX_RESPONSE_BYTES)
}
pub fn guard_response_size(bytes: Vec<u8>, ceiling: usize) -> Vec<u8> {
if bytes.len() <= ceiling {
return bytes;
}
let n = bytes.len();
tracing::warn!(
response_bytes = n,
ceiling,
"MCP response exceeds stdio size ceiling — replacing with truncation notice"
);
let id = serde_json::from_slice::<serde_json::Value>(&bytes)
.ok()
.and_then(|v| v.get("id").cloned())
.unwrap_or(serde_json::Value::Null);
let notice = serde_json::json!({
"result": {
"isError": true,
"content": [{
"type": "text",
"text": format!(
"Response truncated: {n} bytes exceeded limit {ceiling}. \
Use limit/offset pagination to retrieve results in smaller pages."
),
}]
},
"jsonrpc": "2.0",
"id": id,
});
let mut out = serde_json::to_vec(¬ice).unwrap_or_else(|_| b"{}".to_vec());
out.push(b'\n');
out
}
/// Serves MCP requests over stdio until stdin reaches EOF.
///
/// Each newline-delimited line on stdin is parsed as a JSON-RPC [`Request`] and
/// dispatched to `server`; blank lines are skipped. A line that fails to
/// deserialize, including one that is not valid UTF-8, is answered with an
/// `INVALID_REQUEST` error response whose id is `null`. Responses flagged
/// `suppress` are not written. Every other response is handled as follows:
/// 1. serialized as one line;
/// 2. passed through [`guard_response_size`] with the ceiling from
///    `response_size_ceiling`;
/// 3. written to stdout and flushed.
///
/// # Errors
///
/// Returns an error if any of the following fails:
/// - reading stdin;
/// - serializing a response;
/// - writing to or flushing stdout.
pub async fn run(server: AnalyzerMcpServer) -> Result<()> {
    let mut reader = BufReader::new(tokio::io::stdin());
    let mut stdout = tokio::io::stdout();
    let ceiling = response_size_ceiling();
    // Read raw bytes rather than using `lines()`. `lines()` yields an
    // `InvalidData` I/O error for a non-UTF-8 line, which `?` would turn into
    // server shutdown. Reading bytes lets serde_json reject such a line as a
    // parse error instead. The buffer is reused across iterations.
    let mut line = Vec::new();
    loop {
        line.clear();
        if reader.read_until(b'\n', &mut line).await? == 0 {
            break; // EOF
        }
        // Strips the trailing "\n" / "\r\n" along with any other ASCII padding.
        let trimmed = line.trim_ascii();
        if trimmed.is_empty() {
            continue;
        }
        // NOTE(review): JSON-RPC 2.0 reserves -32700 for invalid JSON; this loop
        // reports every deserialization failure as INVALID_REQUEST.
        let resp: Response = match serde_json::from_slice::<Request>(trimmed) {
            Ok(req) => server.dispatch(req).await,
            Err(e) => Response {
                jsonrpc: "2.0".into(),
                id: serde_json::Value::Null,
                result: None,
                error: Some(JsonRpcError {
                    code: error_codes::INVALID_REQUEST,
                    message: format!("parse error: {e}"),
                    data: None,
                }),
                suppress: false,
            },
        };
        if resp.suppress {
            continue;
        }
        let mut bytes = serde_json::to_vec(&resp)?;
        bytes.push(b'\n');
        let bytes = guard_response_size(bytes, ceiling);
        stdout.write_all(&bytes).await?;
        stdout.flush().await?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload exactly at the ceiling is passed through untouched.
    #[test]
    fn stdio_size_guard_passes_within_limit() {
        let payload = b"hello world\n".to_vec();
        let ceiling = payload.len();
        let out = guard_response_size(payload.clone(), ceiling);
        assert_eq!(out, payload, "payload within limit must be unchanged");
    }

    /// One byte over the ceiling is enough to trigger replacement, and the
    /// replacement keeps the newline framing used on stdio.
    #[test]
    fn stdio_size_guard_replaces_one_byte_over_limit() {
        let payload = b"hello world\n".to_vec();
        let ceiling = payload.len() - 1;
        let out = guard_response_size(payload, ceiling);
        let text = std::str::from_utf8(&out).expect("utf8");
        assert!(
            text.contains("Response truncated: 12 bytes exceeded limit 11"),
            "boundary notice expected, got: {text:.200}"
        );
        assert!(out.ends_with(b"\n"), "notice must be newline-terminated");
    }

    #[test]
    fn stdio_size_guard_truncates_oversized_response() {
        let large = vec![b'x'; 3_000_000];
        let ceiling = 2_000_000usize;
        let out = guard_response_size(large, ceiling);
        let text = std::str::from_utf8(&out).expect("utf8");
        assert!(
            text.contains("Response truncated"),
            "truncation notice expected, got: {text:.200}"
        );
        assert!(
            text.contains("3000000"),
            "original size expected in notice, got: {text:.200}"
        );
        assert!(out.len() < ceiling, "notice must be smaller than ceiling");
    }

    #[test]
    fn stdio_size_guard_echoes_request_id() {
        let large_response = serde_json::json!({
            "jsonrpc": "2.0",
            "id": 42,
            "result": { "data": "x".repeat(3_000_000) },
        });
        let bytes = serde_json::to_vec(&large_response).unwrap();
        assert!(
            bytes.len() > 2_000_000,
            "pre-condition: bytes must exceed ceiling"
        );
        let out = guard_response_size(bytes, 2_000_000);
        let trimmed = out.trim_ascii_end();
        let v: serde_json::Value =
            serde_json::from_slice(trimmed).expect("truncation notice must be valid JSON");
        assert_eq!(
            v["id"],
            serde_json::Value::from(42i64),
            "truncation notice must echo the request id"
        );
        assert!(v["result"]["isError"].as_bool().unwrap_or(false));
    }

    /// When the oversized payload is not JSON, the id falls back to `null`.
    #[test]
    fn stdio_size_guard_uses_null_id_for_unparseable_payload() {
        let out = guard_response_size(vec![b'x'; 64], 8);
        let v: serde_json::Value = serde_json::from_slice(out.trim_ascii_end())
            .expect("truncation notice must be valid JSON");
        assert!(v["id"].is_null(), "unparseable payload must yield a null id");
    }

    #[test]
    fn stdio_size_guard_notice_is_valid_json() {
        let large = vec![b'x'; 3_000_000];
        let out = guard_response_size(large, 2_000_000);
        let trimmed = out.trim_ascii_end();
        let v: serde_json::Value =
            serde_json::from_slice(trimmed).expect("truncation notice must be valid JSON");
        assert_eq!(v["jsonrpc"], "2.0");
        assert_eq!(v["result"]["isError"], true);
    }
}