use std::{
    io::{self, BufRead, BufReader, Write},
    process::{Child, Command, Stdio},
    sync::{Arc, Mutex, MutexGuard},
    thread,
};

use super::audit::{AuditEvent, AuditLog};
use super::jsonrpc::JsonRpcRequest;
use super::policy::{make_deny_response, McpPolicy, PolicyDecision, PolicyState};
10
/// Runtime options for an [`McpProxy`].
///
/// `Default` yields a non-dry-run, quiet proxy with no explicit audit log path.
#[derive(Debug, Default, Clone)]
pub struct ProxyConfig {
    /// When `true`, policy denials are recorded as `would_deny` and the request
    /// is still forwarded to the server instead of being blocked.
    pub dry_run: bool,
    /// When `true`, per-decision diagnostics (`ALLOW`, `DENY`, warnings) are
    /// printed to stderr.
    pub verbose: bool,
    /// Destination passed to `AuditLog::new`; how `None` is handled is decided
    /// by `AuditLog` (NOTE(review): presumably "no file" — confirm there).
    pub audit_log_path: Option<std::path::PathBuf>,
}
17
18pub struct McpProxy {
19 child: Child,
20 policy: McpPolicy,
21 config: ProxyConfig,
22}
23
24impl Drop for McpProxy {
25 fn drop(&mut self) {
26 let _ = self.child.kill();
28 }
29}
30
31impl McpProxy {
32 pub fn spawn(
33 command: &str,
34 args: &[String],
35 policy: McpPolicy,
36 config: ProxyConfig,
37 ) -> io::Result<Self> {
38 let child = Command::new(command)
39 .args(args)
40 .stdin(Stdio::piped())
41 .stdout(Stdio::piped())
42 .stderr(Stdio::inherit()) .spawn()?;
44
45 Ok(Self {
46 child,
47 policy,
48 config,
49 })
50 }
51
52 pub fn run(mut self) -> io::Result<i32> {
53 let mut child_stdin = self.child.stdin.take().expect("child stdin");
54 let child_stdout = self.child.stdout.take().expect("child stdout");
55
56 let stdout = Arc::new(Mutex::new(io::stdout()));
57 let policy = self.policy.clone();
58 let config = self.config.clone();
59
60 let stdout_a = stdout.clone();
62 let t_server_to_client = thread::spawn(move || -> io::Result<()> {
63 let mut reader = BufReader::new(child_stdout);
64 let mut line = String::new();
65
66 while reader.read_line(&mut line)? > 0 {
67 let mut out = stdout_a
68 .lock()
69 .map_err(|e| io::Error::other(e.to_string()))?;
70 out.write_all(line.as_bytes())?;
71 out.flush()?;
72 line.clear();
73 }
74 Ok(())
75 });
76
77 let stdout_b = stdout.clone();
79 let t_client_to_server = thread::spawn(move || -> io::Result<()> {
80 let stdin = io::stdin();
81 let mut reader = stdin.lock();
82 let mut line = String::new();
83
84 let mut state = PolicyState::default();
85 let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
86
87 while reader.read_line(&mut line)? > 0 {
88 match serde_json::from_str::<JsonRpcRequest>(&line) {
90 Ok(req) => {
91 match policy.check(&req, &mut state) {
93 PolicyDecision::Allow => {
94 Self::handle_allow(&req, &mut audit_log, config.verbose);
95 }
96 PolicyDecision::AllowWithWarning { tool, code, reason } => {
97 if config.verbose {
99 eprintln!(
100 "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
101 tool,
102 code,
103 reason
104 );
105 }
106 audit_log.log(&AuditEvent {
107 timestamp: chrono::Utc::now().to_rfc3339(),
108 decision: "allow_with_warning".to_string(),
109 tool: Some(tool.clone()),
110 reason: Some(reason.clone()),
111 request_id: req.id.clone(),
112 agentic: None,
113 });
114 Self::handle_allow(&req, &mut audit_log, false);
116 }
118 PolicyDecision::Deny {
119 tool,
120 code: _,
121 reason,
122 contract,
123 } => {
124 let decision_str =
126 if config.dry_run { "would_deny" } else { "deny" };
127
128 if config.verbose {
129 eprintln!(
130 "[assay] {} {} (reason: {})",
131 decision_str.to_uppercase(),
132 tool,
133 reason
134 );
135 }
136
137 audit_log.log(&AuditEvent {
138 timestamp: chrono::Utc::now().to_rfc3339(),
139 decision: decision_str.to_string(),
140 tool: Some(tool.clone()),
141 reason: Some(reason.clone()),
142 request_id: req.id.clone(),
143 agentic: Some(contract.clone()),
144 });
145
146 if config.dry_run {
147 } else {
150 let id = req.id.unwrap_or(serde_json::Value::Null);
152 let response_json = make_deny_response(
153 id,
154 "Content blocked by policy",
155 contract,
156 );
157
158 let mut out = stdout_b
159 .lock()
160 .map_err(|e| io::Error::other(e.to_string()))?;
161 out.write_all(response_json.as_bytes())?;
162 out.flush()?;
163
164 line.clear();
165 continue; }
167 }
168 }
169 }
170 Err(_) => {
171 let trimmed = line.trim();
173 if trimmed.starts_with('{')
174 && (trimmed.contains("\"method\"")
175 || trimmed.contains("\"params\"")
176 || trimmed.contains("\"tool\""))
177 {
178 eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
179 }
180 }
181 }
182
183 child_stdin.write_all(line.as_bytes())?;
185 child_stdin.flush()?;
186 line.clear();
187 }
188 Ok(())
189 });
190
191 t_client_to_server
193 .join()
194 .map_err(|_| io::Error::other("client->server thread panicked"))??;
195
196 let _ = t_server_to_client.join();
198
199 let status = self.child.wait()?;
201 Ok(status.code().unwrap_or(1))
202 }
203
204 fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
205 if verbose && req.is_tool_call() {
206 let tool = req
207 .tool_params()
208 .map(|p| p.name)
209 .unwrap_or_else(|| "unknown".to_string());
210 eprintln!("[assay] ALLOW {}", tool);
211 }
212
213 if req.is_tool_call() {
214 let tool = req.tool_params().map(|p| p.name);
215 audit_log.log(&AuditEvent {
216 timestamp: chrono::Utc::now().to_rfc3339(),
217 decision: "allow".to_string(),
218 tool,
219 reason: None,
220 request_id: req.id.clone(),
221 agentic: None,
222 });
223 }
224 }
225}