1use super::audit::{AuditEvent, AuditLog};
2use super::decision::{
3 reason_codes, Decision, DecisionEmitter, DecisionEvent, FileDecisionEmitter,
4 NullDecisionEmitter,
5};
6use super::jsonrpc::JsonRpcRequest;
7use super::policy::{make_deny_response, McpPolicy, PolicyDecision, PolicyState};
8use std::{
9 collections::HashMap,
10 io::{self, BufRead, BufReader, Write},
11 process::{Child, Command, Stdio},
12 sync::{Arc, Mutex},
13 thread,
14};
15
/// Runtime configuration consumed by [`McpProxy`].
///
/// [`ProxyConfig::try_from_raw`] produces this from a [`ProxyConfigRaw`] after
/// trimming and validating `event_source`.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// When `true`, a policy denial is recorded as `would_deny` and the request
    /// is still forwarded to the server instead of being answered with a deny response.
    pub dry_run: bool,
    /// When `true`, allow / warning / deny outcomes are also printed to stderr.
    pub verbose: bool,
    /// Destination handed to the audit log; `None` means no path is supplied.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the wrapped server. Used when computing tool identities and
    /// as the fallback event source (`assay://<server_id>`).
    pub server_id: String,
    /// File that receives decision events; when `None` a no-op emitter is used.
    pub decision_log_path: Option<std::path::PathBuf>,
    /// Source URI stamped on decision events (e.g. `assay://org/app`).
    /// `try_from_raw` requires it whenever either log path is set.
    pub event_source: Option<String>,
}
31
/// Unvalidated proxy settings with the same fields as [`ProxyConfig`].
///
/// Pass it to [`ProxyConfig::try_from_raw`] to obtain a checked configuration.
#[derive(Clone, Debug, Default)]
pub struct ProxyConfigRaw {
    /// Copied unchanged into [`ProxyConfig::dry_run`].
    pub dry_run: bool,
    /// Copied unchanged into [`ProxyConfig::verbose`].
    pub verbose: bool,
    /// Copied unchanged; a `Some` value counts as "logging enabled" during validation.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Copied unchanged into [`ProxyConfig::server_id`].
    pub server_id: String,
    /// Copied unchanged; a `Some` value counts as "logging enabled" during validation.
    pub decision_log_path: Option<std::path::PathBuf>,
    /// Candidate event source; trimmed, and treated as absent when blank.
    pub event_source: Option<String>,
}
42
43impl ProxyConfig {
44 pub fn try_from_raw(raw: ProxyConfigRaw) -> anyhow::Result<Self> {
50 let logging_enabled = raw.audit_log_path.is_some() || raw.decision_log_path.is_some();
51
52 let event_source = raw
53 .event_source
54 .map(|s| s.trim().to_string())
55 .filter(|s| !s.is_empty());
56
57 if logging_enabled && event_source.is_none() {
58 anyhow::bail!(
59 "event_source is required when logging is enabled (e.g. --event-source assay://org/app)"
60 );
61 }
62
63 if let Some(ref src) = event_source {
64 validate_event_source(src)?;
65 }
66
67 Ok(ProxyConfig {
68 dry_run: raw.dry_run,
69 verbose: raw.verbose,
70 audit_log_path: raw.audit_log_path,
71 server_id: raw.server_id,
72 decision_log_path: raw.decision_log_path,
73 event_source,
74 })
75 }
76}
77
78fn validate_event_source(s: &str) -> anyhow::Result<()> {
80 let s = s.trim();
81 if s.is_empty() {
82 anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
83 }
84 if s.chars().any(|c| c.is_whitespace()) {
85 anyhow::bail!("event_source must not contain whitespace");
86 }
87
88 let Some(pos) = s.find("://") else {
90 anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
91 };
92 if pos == 0 {
93 anyhow::bail!("event_source must have scheme before :// (e.g. assay://org/app)");
94 }
95
96 let scheme = &s[..pos];
98 let mut chars = scheme.chars();
99 match chars.next() {
100 Some(c) if c.is_ascii_alphabetic() => {}
101 _ => anyhow::bail!("event_source URI scheme must start with a letter"),
102 }
103 if !chars.all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '-' || c == '.') {
104 anyhow::bail!("event_source URI scheme contains invalid characters");
105 }
106
107 Ok(())
108}
109
/// A policy-enforcing proxy that sits between a client on this process's
/// stdin/stdout and a server running as a child process.
pub struct McpProxy {
    /// The wrapped server process, spawned with piped stdin and stdout.
    child: Child,
    /// Policy evaluated against each parsed client request.
    policy: McpPolicy,
    /// Runtime options (dry-run, verbosity, log destinations, identifiers).
    config: ProxyConfig,
    /// Tool identities keyed by tool name. Filled from `result.tools` arrays in
    /// server responses and consulted when a client issues a tool call.
    identity_cache: Arc<Mutex<HashMap<String, super::identity::ToolIdentity>>>,
}
117
118impl Drop for McpProxy {
119 fn drop(&mut self) {
120 let _ = self.child.kill();
122 }
123}
124
125impl McpProxy {
126 pub fn spawn(
127 command: &str,
128 args: &[String],
129 policy: McpPolicy,
130 config: ProxyConfig,
131 ) -> io::Result<Self> {
132 let child = Command::new(command)
133 .args(args)
134 .stdin(Stdio::piped())
135 .stdout(Stdio::piped())
136 .stderr(Stdio::inherit()) .spawn()?;
138
139 Ok(Self {
140 child,
141 policy,
142 config,
143 identity_cache: Arc::new(Mutex::new(HashMap::new())),
144 })
145 }
146
147 pub fn run(mut self) -> io::Result<i32> {
148 let mut child_stdin = self.child.stdin.take().expect("child stdin");
149 let child_stdout = self.child.stdout.take().expect("child stdout");
150
151 let stdout = Arc::new(Mutex::new(io::stdout()));
152 let policy = self.policy.clone();
153 let config = self.config.clone();
154 let identity_cache_a = self.identity_cache.clone();
155 let identity_cache_b = self.identity_cache.clone();
156
157 let decision_emitter: Arc<dyn DecisionEmitter> =
159 if let Some(path) = &config.decision_log_path {
160 Arc::new(FileDecisionEmitter::new(path)?)
161 } else {
162 Arc::new(NullDecisionEmitter)
163 };
164 let event_source = config
165 .event_source
166 .clone()
167 .unwrap_or_else(|| format!("assay://{}", config.server_id));
168
169 let stdout_a = stdout.clone();
171 let t_server_to_client = thread::spawn(move || -> io::Result<()> {
172 let mut reader = BufReader::new(child_stdout);
173 let mut line = String::new();
174
175 while reader.read_line(&mut line)? > 0 {
176 let mut processed_line = line.clone();
177
178 if let Ok(mut v) = serde_json::from_str::<serde_json::Value>(&line) {
180 if let Some(result) = v.get_mut("result") {
181 if let Some(tools) = result.get_mut("tools").and_then(|t| t.as_array_mut())
182 {
183 for tool in tools {
184 let name = tool
185 .get("name")
186 .and_then(|n| n.as_str())
187 .unwrap_or("unknown");
188 let description = tool
189 .get("description")
190 .and_then(|d| d.as_str())
191 .map(|s| s.to_string());
192 let input_schema = tool
193 .get("inputSchema")
194 .or_else(|| tool.get("input_schema"))
195 .cloned();
196
197 let identity = super::identity::ToolIdentity::new(
198 &config.server_id,
199 name,
200 &input_schema,
201 &description,
202 );
203
204 let mut cache = identity_cache_a.lock().unwrap();
206 cache.insert(name.to_string(), identity.clone());
207
208 tool.as_object_mut().and_then(|m| {
210 m.insert(
211 "tool_identity".to_string(),
212 serde_json::to_value(&identity).unwrap(),
213 )
214 });
215 }
216 processed_line =
217 serde_json::to_string(&v).unwrap_or(line.clone()) + "\n";
218 }
219 }
220 }
221
222 let mut out = stdout_a
223 .lock()
224 .map_err(|e| io::Error::other(e.to_string()))?;
225 out.write_all(processed_line.as_bytes())?;
226 out.flush()?;
227 line.clear();
228 }
229 Ok(())
230 });
231
232 let stdout_b = stdout.clone();
234 let emitter_b = decision_emitter.clone();
235 let event_source_b = event_source.clone();
236 let t_client_to_server = thread::spawn(move || -> io::Result<()> {
237 let stdin = io::stdin();
238 let mut reader = stdin.lock();
239 let mut line = String::new();
240
241 let mut state = PolicyState::default();
242 let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
243
244 while reader.read_line(&mut line)? > 0 {
245 match serde_json::from_str::<JsonRpcRequest>(&line) {
247 Ok(req) => {
248 let runtime_id = if req.is_tool_call() {
250 let name = req.tool_params().map(|p| p.name).unwrap_or_default();
251 let cache = identity_cache_b.lock().unwrap();
252 cache.get(&name).cloned()
253 } else {
254 None
255 };
256
257 let tool_name = req.tool_params().map(|p| p.name).unwrap_or_default();
258 let tool_call_id = Self::extract_tool_call_id(&req);
259
260 match policy.evaluate(
261 &tool_name,
262 &req.tool_params()
263 .map(|p| p.arguments)
264 .unwrap_or(serde_json::Value::Null),
265 &mut state,
266 runtime_id.as_ref(),
267 ) {
268 PolicyDecision::Allow => {
269 Self::handle_allow(&req, &mut audit_log, config.verbose);
270 if req.is_tool_call() {
272 Self::emit_decision(
273 &emitter_b,
274 &event_source_b,
275 &tool_call_id,
276 &tool_name,
277 Decision::Allow,
278 reason_codes::P_POLICY_DENY, None,
280 req.id.clone(),
281 );
282 }
283 }
284 PolicyDecision::AllowWithWarning { tool, code, reason } => {
285 if config.verbose {
287 eprintln!(
288 "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
289 tool,
290 code,
291 reason
292 );
293 }
294 audit_log.log(&AuditEvent {
295 timestamp: chrono::Utc::now().to_rfc3339(),
296 decision: "allow_with_warning".to_string(),
297 tool: Some(tool.clone()),
298 reason: Some(reason.clone()),
299 request_id: req.id.clone(),
300 agentic: None,
301 });
302 Self::emit_decision(
304 &emitter_b,
305 &event_source_b,
306 &tool_call_id,
307 &tool,
308 Decision::Allow,
309 &code,
310 Some(reason),
311 req.id.clone(),
312 );
313 Self::handle_allow(&req, &mut audit_log, false);
315 }
317 PolicyDecision::Deny {
318 tool,
319 code,
320 reason,
321 contract,
322 } => {
323 let decision_str =
325 if config.dry_run { "would_deny" } else { "deny" };
326
327 if config.verbose {
328 eprintln!(
329 "[assay] {} {} (reason: {})",
330 decision_str.to_uppercase(),
331 tool,
332 reason
333 );
334 }
335
336 audit_log.log(&AuditEvent {
337 timestamp: chrono::Utc::now().to_rfc3339(),
338 decision: decision_str.to_string(),
339 tool: Some(tool.clone()),
340 reason: Some(reason.clone()),
341 request_id: req.id.clone(),
342 agentic: Some(contract.clone()),
343 });
344
345 let reason_code = Self::map_policy_code(&code);
347 Self::emit_decision(
348 &emitter_b,
349 &event_source_b,
350 &tool_call_id,
351 &tool,
352 if config.dry_run {
353 Decision::Allow
354 } else {
355 Decision::Deny
356 },
357 &reason_code,
358 Some(reason),
359 req.id.clone(),
360 );
361
362 if config.dry_run {
363 } else {
366 let id = req.id.unwrap_or(serde_json::Value::Null);
368 let response_json = make_deny_response(
369 id,
370 "Content blocked by policy",
371 contract,
372 );
373
374 let mut out = stdout_b
375 .lock()
376 .map_err(|e| io::Error::other(e.to_string()))?;
377 out.write_all(response_json.as_bytes())?;
378 out.flush()?;
379
380 line.clear();
381 continue; }
383 }
384 }
385 }
386 Err(_) => {
387 let trimmed = line.trim();
389 if trimmed.starts_with('{')
390 && (trimmed.contains("\"method\"")
391 || trimmed.contains("\"params\"")
392 || trimmed.contains("\"tool\""))
393 {
394 eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
395 }
396 }
397 }
398
399 child_stdin.write_all(line.as_bytes())?;
401 child_stdin.flush()?;
402 line.clear();
403 }
404 Ok(())
405 });
406
407 t_client_to_server
409 .join()
410 .map_err(|_| io::Error::other("client->server thread panicked"))??;
411
412 let _ = t_server_to_client.join();
414
415 let status = self.child.wait()?;
417 Ok(status.code().unwrap_or(1))
418 }
419
420 fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
421 if verbose && req.is_tool_call() {
422 let tool = req
423 .tool_params()
424 .map(|p| p.name)
425 .unwrap_or_else(|| "unknown".to_string());
426 eprintln!("[assay] ALLOW {}", tool);
427 }
428
429 if req.is_tool_call() {
430 let tool = req.tool_params().map(|p| p.name);
431 audit_log.log(&AuditEvent {
432 timestamp: chrono::Utc::now().to_rfc3339(),
433 decision: "allow".to_string(),
434 tool,
435 reason: None,
436 request_id: req.id.clone(),
437 agentic: None,
438 });
439 }
440 }
441
442 fn extract_tool_call_id(request: &JsonRpcRequest) -> String {
444 if let Some(params) = request.tool_params() {
446 if let Some(meta) = params.arguments.get("_meta") {
447 if let Some(id) = meta.get("tool_call_id").and_then(|v| v.as_str()) {
448 return id.to_string();
449 }
450 }
451 }
452
453 if let Some(id) = &request.id {
455 if let Some(s) = id.as_str() {
456 return format!("req_{}", s);
457 }
458 if let Some(n) = id.as_i64() {
459 return format!("req_{}", n);
460 }
461 }
462
463 format!("gen_{}", uuid::Uuid::new_v4())
465 }
466
467 fn map_policy_code(code: &str) -> String {
469 match code {
470 "E_TOOL_DENIED" => reason_codes::P_TOOL_DENIED.to_string(),
471 "E_TOOL_NOT_ALLOWED" => reason_codes::P_TOOL_NOT_ALLOWED.to_string(),
472 "E_ARG_SCHEMA" => reason_codes::P_ARG_SCHEMA.to_string(),
473 "E_RATE_LIMIT" => reason_codes::P_RATE_LIMIT.to_string(),
474 "E_TOOL_DRIFT" => reason_codes::P_TOOL_DRIFT.to_string(),
475 _ => reason_codes::P_POLICY_DENY.to_string(),
476 }
477 }
478
479 #[allow(clippy::too_many_arguments)]
481 fn emit_decision(
482 emitter: &Arc<dyn DecisionEmitter>,
483 source: &str,
484 tool_call_id: &str,
485 tool: &str,
486 decision: Decision,
487 reason_code: &str,
488 reason: Option<String>,
489 request_id: Option<serde_json::Value>,
490 ) {
491 let mut event = DecisionEvent::new(
492 source.to_string(),
493 tool_call_id.to_string(),
494 tool.to_string(),
495 );
496 event.data.decision = decision;
497 event.data.reason_code = reason_code.to_string();
498 event.data.reason = reason;
499 event.data.request_id = request_id;
500 emitter.emit(&event);
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw config with decision logging switched on and the given event source.
    fn raw_with_decision_log(event_source: Option<&str>) -> ProxyConfigRaw {
        ProxyConfigRaw {
            decision_log_path: Some(std::path::PathBuf::from("decisions.ndjson")),
            event_source: event_source.map(str::to_string),
            server_id: "srv".to_string(),
            ..ProxyConfigRaw::default()
        }
    }

    /// Asserts that every input is refused by `validate_event_source`.
    fn assert_all_rejected(inputs: &[&str]) {
        for input in inputs {
            assert!(
                validate_event_source(input).is_err(),
                "expected {input:?} to be rejected"
            );
        }
    }

    #[test]
    fn event_source_accepts_assay_uri() {
        validate_event_source("assay://myorg/myapp").unwrap();
    }

    #[test]
    fn event_source_accepts_https_uri() {
        validate_event_source("https://example.com/agent").unwrap();
    }

    #[test]
    fn event_source_rejects_empty() {
        assert_all_rejected(&["", " "]);
    }

    #[test]
    fn event_source_rejects_whitespace() {
        assert_all_rejected(&["assay://myorg/my app", "assay://myorg/\tmyapp"]);
    }

    #[test]
    fn event_source_rejects_missing_scheme() {
        assert_all_rejected(&["myorg/myapp", "://myorg/myapp"]);
    }

    #[test]
    fn event_source_rejects_did_and_urn() {
        // Neither form contains "://", so both are refused.
        assert_all_rejected(&["did:example:123", "urn:example:foo"]);
    }

    #[test]
    fn event_source_rejects_scheme_starting_with_non_letter() {
        assert_all_rejected(&["1assay://myorg/myapp", "-assay://myorg/myapp"]);
    }

    #[test]
    fn event_source_rejects_scheme_with_invalid_chars() {
        assert_all_rejected(&["as_say://myorg/myapp", "as@say://myorg/myapp"]);
    }

    #[test]
    fn config_requires_event_source_when_logging_enabled() {
        let err = ProxyConfig::try_from_raw(raw_with_decision_log(None)).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("event_source is required"));
    }

    #[test]
    fn config_allows_no_event_source_when_logging_disabled() {
        let raw = ProxyConfigRaw {
            server_id: "srv".to_string(),
            ..ProxyConfigRaw::default()
        };

        ProxyConfig::try_from_raw(raw).unwrap();
    }

    #[test]
    fn config_accepts_valid_event_source() {
        let cfg =
            ProxyConfig::try_from_raw(raw_with_decision_log(Some("assay://myorg/myapp"))).unwrap();
        assert_eq!(cfg.event_source.as_deref(), Some("assay://myorg/myapp"));
    }

    #[test]
    fn config_rejects_invalid_event_source_uri() {
        assert!(ProxyConfig::try_from_raw(raw_with_decision_log(Some("not a uri"))).is_err());
    }
}