1use super::audit::{AuditEvent, AuditLog};
2use super::jsonrpc::JsonRpcRequest;
3use super::policy::{make_deny_response, McpPolicy, PolicyDecision, PolicyState};
4use std::{
5 collections::HashMap,
6 io::{self, BufRead, BufReader, Write},
7 process::{Child, Command, Stdio},
8 sync::{Arc, Mutex},
9 thread,
10};
11
/// Runtime options for an [`McpProxy`].
///
/// The derived `Default` yields a non-verbose, enforcing proxy with no audit
/// log path and an empty server id.
#[derive(Debug, Clone, Default)]
pub struct ProxyConfig {
    /// When `true`, denied requests are recorded as `would_deny` and are still
    /// forwarded to the server instead of being answered with a deny response.
    pub dry_run: bool,
    /// When `true`, policy decisions are echoed to stderr.
    pub verbose: bool,
    /// Optional path handed to the audit log when the proxy starts running.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the wrapped server; used when building tool identities.
    pub server_id: String,
}
19
20pub struct McpProxy {
21 child: Child,
22 policy: McpPolicy,
23 config: ProxyConfig,
24 identity_cache: Arc<Mutex<HashMap<String, super::identity::ToolIdentity>>>,
26}
27
28impl Drop for McpProxy {
29 fn drop(&mut self) {
30 let _ = self.child.kill();
32 }
33}
34
35impl McpProxy {
36 pub fn spawn(
37 command: &str,
38 args: &[String],
39 policy: McpPolicy,
40 config: ProxyConfig,
41 ) -> io::Result<Self> {
42 let child = Command::new(command)
43 .args(args)
44 .stdin(Stdio::piped())
45 .stdout(Stdio::piped())
46 .stderr(Stdio::inherit()) .spawn()?;
48
49 Ok(Self {
50 child,
51 policy,
52 config,
53 identity_cache: Arc::new(Mutex::new(HashMap::new())),
54 })
55 }
56
57 pub fn run(mut self) -> io::Result<i32> {
58 let mut child_stdin = self.child.stdin.take().expect("child stdin");
59 let child_stdout = self.child.stdout.take().expect("child stdout");
60
61 let stdout = Arc::new(Mutex::new(io::stdout()));
62 let policy = self.policy.clone();
63 let config = self.config.clone();
64 let identity_cache_a = self.identity_cache.clone();
65 let identity_cache_b = self.identity_cache.clone();
66
67 let stdout_a = stdout.clone();
69 let t_server_to_client = thread::spawn(move || -> io::Result<()> {
70 let mut reader = BufReader::new(child_stdout);
71 let mut line = String::new();
72
73 while reader.read_line(&mut line)? > 0 {
74 let mut processed_line = line.clone();
75
76 if let Ok(mut v) = serde_json::from_str::<serde_json::Value>(&line) {
78 if let Some(result) = v.get_mut("result") {
79 if let Some(tools) = result.get_mut("tools").and_then(|t| t.as_array_mut())
80 {
81 for tool in tools {
82 let name = tool
83 .get("name")
84 .and_then(|n| n.as_str())
85 .unwrap_or("unknown");
86 let description = tool
87 .get("description")
88 .and_then(|d| d.as_str())
89 .map(|s| s.to_string());
90 let input_schema = tool
91 .get("inputSchema")
92 .or_else(|| tool.get("input_schema"))
93 .cloned();
94
95 let identity = super::identity::ToolIdentity::new(
96 &config.server_id,
97 name,
98 &input_schema,
99 &description,
100 );
101
102 let mut cache = identity_cache_a.lock().unwrap();
104 cache.insert(name.to_string(), identity.clone());
105
106 tool.as_object_mut().and_then(|m| {
108 m.insert(
109 "tool_identity".to_string(),
110 serde_json::to_value(&identity).unwrap(),
111 )
112 });
113 }
114 processed_line =
115 serde_json::to_string(&v).unwrap_or(line.clone()) + "\n";
116 }
117 }
118 }
119
120 let mut out = stdout_a
121 .lock()
122 .map_err(|e| io::Error::other(e.to_string()))?;
123 out.write_all(processed_line.as_bytes())?;
124 out.flush()?;
125 line.clear();
126 }
127 Ok(())
128 });
129
130 let stdout_b = stdout.clone();
132 let t_client_to_server = thread::spawn(move || -> io::Result<()> {
133 let stdin = io::stdin();
134 let mut reader = stdin.lock();
135 let mut line = String::new();
136
137 let mut state = PolicyState::default();
138 let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
139
140 while reader.read_line(&mut line)? > 0 {
141 match serde_json::from_str::<JsonRpcRequest>(&line) {
143 Ok(req) => {
144 let runtime_id = if req.is_tool_call() {
146 let name = req.tool_params().map(|p| p.name).unwrap_or_default();
147 let cache = identity_cache_b.lock().unwrap();
148 cache.get(&name).cloned()
149 } else {
150 None
151 };
152
153 match policy.evaluate(
154 &req.tool_params().map(|p| p.name).unwrap_or_default(),
155 &req.tool_params()
156 .map(|p| p.arguments)
157 .unwrap_or(serde_json::Value::Null),
158 &mut state,
159 runtime_id.as_ref(),
160 ) {
161 PolicyDecision::Allow => {
162 Self::handle_allow(&req, &mut audit_log, config.verbose);
163 }
164 PolicyDecision::AllowWithWarning { tool, code, reason } => {
165 if config.verbose {
167 eprintln!(
168 "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
169 tool,
170 code,
171 reason
172 );
173 }
174 audit_log.log(&AuditEvent {
175 timestamp: chrono::Utc::now().to_rfc3339(),
176 decision: "allow_with_warning".to_string(),
177 tool: Some(tool.clone()),
178 reason: Some(reason.clone()),
179 request_id: req.id.clone(),
180 agentic: None,
181 });
182 Self::handle_allow(&req, &mut audit_log, false);
184 }
186 PolicyDecision::Deny {
187 tool,
188 code: _,
189 reason,
190 contract,
191 } => {
192 let decision_str =
194 if config.dry_run { "would_deny" } else { "deny" };
195
196 if config.verbose {
197 eprintln!(
198 "[assay] {} {} (reason: {})",
199 decision_str.to_uppercase(),
200 tool,
201 reason
202 );
203 }
204
205 audit_log.log(&AuditEvent {
206 timestamp: chrono::Utc::now().to_rfc3339(),
207 decision: decision_str.to_string(),
208 tool: Some(tool.clone()),
209 reason: Some(reason.clone()),
210 request_id: req.id.clone(),
211 agentic: Some(contract.clone()),
212 });
213
214 if config.dry_run {
215 } else {
218 let id = req.id.unwrap_or(serde_json::Value::Null);
220 let response_json = make_deny_response(
221 id,
222 "Content blocked by policy",
223 contract,
224 );
225
226 let mut out = stdout_b
227 .lock()
228 .map_err(|e| io::Error::other(e.to_string()))?;
229 out.write_all(response_json.as_bytes())?;
230 out.flush()?;
231
232 line.clear();
233 continue; }
235 }
236 }
237 }
238 Err(_) => {
239 let trimmed = line.trim();
241 if trimmed.starts_with('{')
242 && (trimmed.contains("\"method\"")
243 || trimmed.contains("\"params\"")
244 || trimmed.contains("\"tool\""))
245 {
246 eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
247 }
248 }
249 }
250
251 child_stdin.write_all(line.as_bytes())?;
253 child_stdin.flush()?;
254 line.clear();
255 }
256 Ok(())
257 });
258
259 t_client_to_server
261 .join()
262 .map_err(|_| io::Error::other("client->server thread panicked"))??;
263
264 let _ = t_server_to_client.join();
266
267 let status = self.child.wait()?;
269 Ok(status.code().unwrap_or(1))
270 }
271
272 fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
273 if verbose && req.is_tool_call() {
274 let tool = req
275 .tool_params()
276 .map(|p| p.name)
277 .unwrap_or_else(|| "unknown".to_string());
278 eprintln!("[assay] ALLOW {}", tool);
279 }
280
281 if req.is_tool_call() {
282 let tool = req.tool_params().map(|p| p.name);
283 audit_log.log(&AuditEvent {
284 timestamp: chrono::Utc::now().to_rfc3339(),
285 decision: "allow".to_string(),
286 tool,
287 reason: None,
288 request_id: req.id.clone(),
289 agentic: None,
290 });
291 }
292 }
293}