1use super::audit::{AuditEvent, AuditLog};
2use super::decision::{
3 reason_codes, refresh_contract_projections, Decision, DecisionEmitter, DecisionEvent,
4 FileDecisionEmitter, NullDecisionEmitter,
5};
6use super::jsonrpc::JsonRpcRequest;
7use super::policy::{
8 make_deny_response, McpPolicy, PolicyDecision, PolicyMatchMetadata, PolicyState,
9};
10use std::{
11 collections::HashMap,
12 io::{self, BufRead, BufReader, Write},
13 process::{Child, Command, Stdio},
14 sync::{Arc, Mutex},
15 thread,
16};
17
/// Validated proxy configuration, normally produced by [`ProxyConfig::try_from_raw`].
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// When set, policy denials are audited as `"would_deny"` and the request is still
    /// forwarded to the server instead of being answered with a deny response.
    pub dry_run: bool,
    /// When set, allow / allow-with-warning / deny decisions are also printed to stderr.
    pub verbose: bool,
    /// Optional path handed to `AuditLog::new` for audit events.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the upstream server; used when computing tool identities and as the
    /// fallback event source (`assay://<server_id>`) when `event_source` is `None`.
    pub server_id: String,
    /// Optional path for decision events; when `None` a `NullDecisionEmitter` is used.
    pub decision_log_path: Option<std::path::PathBuf>,
    /// Trimmed `scheme://...` URI stamped on decision events. `try_from_raw` requires it
    /// whenever either log path is set.
    pub event_source: Option<String>,
}
33
/// Unvalidated proxy configuration (presumably assembled from CLI flags — confirm with
/// callers); convert it with [`ProxyConfig::try_from_raw`], which checks `event_source`.
#[derive(Clone, Debug, Default)]
pub struct ProxyConfigRaw {
    /// See [`ProxyConfig::dry_run`].
    pub dry_run: bool,
    /// See [`ProxyConfig::verbose`].
    pub verbose: bool,
    /// See [`ProxyConfig::audit_log_path`]; setting it makes `event_source` mandatory.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// See [`ProxyConfig::server_id`].
    pub server_id: String,
    /// See [`ProxyConfig::decision_log_path`]; setting it makes `event_source` mandatory.
    pub decision_log_path: Option<std::path::PathBuf>,
    /// Raw event source; surrounding whitespace is trimmed and a blank value counts as
    /// absent during validation.
    pub event_source: Option<String>,
}
44
45impl ProxyConfig {
46 pub fn try_from_raw(raw: ProxyConfigRaw) -> anyhow::Result<Self> {
52 let logging_enabled = raw.audit_log_path.is_some() || raw.decision_log_path.is_some();
53
54 let event_source = raw
55 .event_source
56 .map(|s| s.trim().to_string())
57 .filter(|s| !s.is_empty());
58
59 if logging_enabled && event_source.is_none() {
60 anyhow::bail!(
61 "event_source is required when logging is enabled (e.g. --event-source assay://org/app)"
62 );
63 }
64
65 if let Some(ref src) = event_source {
66 validate_event_source(src)?;
67 }
68
69 Ok(ProxyConfig {
70 dry_run: raw.dry_run,
71 verbose: raw.verbose,
72 audit_log_path: raw.audit_log_path,
73 server_id: raw.server_id,
74 decision_log_path: raw.decision_log_path,
75 event_source,
76 })
77 }
78}
79
80fn validate_event_source(s: &str) -> anyhow::Result<()> {
82 let s = s.trim();
83 if s.is_empty() {
84 anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
85 }
86 if s.chars().any(|c| c.is_whitespace()) {
87 anyhow::bail!("event_source must not contain whitespace");
88 }
89
90 let Some(pos) = s.find("://") else {
92 anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
93 };
94 if pos == 0 {
95 anyhow::bail!("event_source must have scheme before :// (e.g. assay://org/app)");
96 }
97
98 let scheme = &s[..pos];
100 let mut chars = scheme.chars();
101 match chars.next() {
102 Some(c) if c.is_ascii_alphabetic() => {}
103 _ => anyhow::bail!("event_source URI scheme must start with a letter"),
104 }
105 if !chars.all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '-' || c == '.') {
106 anyhow::bail!("event_source URI scheme contains invalid characters");
107 }
108
109 Ok(())
110}
111
/// Stdio proxy sitting between an MCP client (our stdin/stdout) and a spawned MCP server.
pub struct McpProxy {
    /// The upstream server process; killed when the proxy is dropped.
    child: Child,
    /// Policy evaluated against every parsed client request.
    policy: McpPolicy,
    /// Validated proxy settings.
    config: ProxyConfig,
    /// Tool identities keyed by tool name, filled from the server's `result.tools`
    /// listings and read when evaluating tool calls; shared by both relay threads.
    identity_cache: Arc<Mutex<HashMap<String, super::identity::ToolIdentity>>>,
}
119
120impl Drop for McpProxy {
121 fn drop(&mut self) {
122 let _ = self.child.kill();
124 }
125}
126
127impl McpProxy {
128 pub fn spawn(
129 command: &str,
130 args: &[String],
131 policy: McpPolicy,
132 config: ProxyConfig,
133 ) -> io::Result<Self> {
134 let child = Command::new(command)
135 .args(args)
136 .stdin(Stdio::piped())
137 .stdout(Stdio::piped())
138 .stderr(Stdio::inherit()) .spawn()?;
140
141 Ok(Self {
142 child,
143 policy,
144 config,
145 identity_cache: Arc::new(Mutex::new(HashMap::new())),
146 })
147 }
148
149 pub fn run(mut self) -> io::Result<i32> {
150 let mut child_stdin = self.child.stdin.take().expect("child stdin");
151 let child_stdout = self.child.stdout.take().expect("child stdout");
152
153 let stdout = Arc::new(Mutex::new(io::stdout()));
154 let policy = self.policy.clone();
155 let config = self.config.clone();
156 let identity_cache_a = self.identity_cache.clone();
157 let identity_cache_b = self.identity_cache.clone();
158
159 let decision_emitter: Arc<dyn DecisionEmitter> =
161 if let Some(path) = &config.decision_log_path {
162 Arc::new(FileDecisionEmitter::new(path)?)
163 } else {
164 Arc::new(NullDecisionEmitter)
165 };
166 let event_source = config
167 .event_source
168 .clone()
169 .unwrap_or_else(|| format!("assay://{}", config.server_id));
170
171 let stdout_a = stdout.clone();
173 let t_server_to_client = thread::spawn(move || -> io::Result<()> {
174 let mut reader = BufReader::new(child_stdout);
175 let mut line = String::new();
176
177 while reader.read_line(&mut line)? > 0 {
178 let mut processed_line = line.clone();
179
180 if let Ok(mut v) = serde_json::from_str::<serde_json::Value>(&line) {
182 if let Some(result) = v.get_mut("result") {
183 if let Some(tools) = result.get_mut("tools").and_then(|t| t.as_array_mut())
184 {
185 for tool in tools {
186 let name = tool
187 .get("name")
188 .and_then(|n| n.as_str())
189 .unwrap_or("unknown");
190 let description = tool
191 .get("description")
192 .and_then(|d| d.as_str())
193 .map(|s| s.to_string());
194 let input_schema = tool
195 .get("inputSchema")
196 .or_else(|| tool.get("input_schema"))
197 .cloned();
198
199 let identity = super::identity::ToolIdentity::new(
200 &config.server_id,
201 name,
202 &input_schema,
203 &description,
204 );
205
206 let mut cache = identity_cache_a.lock().unwrap();
208 cache.insert(name.to_string(), identity.clone());
209
210 tool.as_object_mut().and_then(|m| {
212 m.insert(
213 "tool_identity".to_string(),
214 serde_json::to_value(&identity).unwrap(),
215 )
216 });
217 }
218 processed_line =
219 serde_json::to_string(&v).unwrap_or(line.clone()) + "\n";
220 }
221 }
222 }
223
224 let mut out = stdout_a
225 .lock()
226 .map_err(|e| io::Error::other(e.to_string()))?;
227 out.write_all(processed_line.as_bytes())?;
228 out.flush()?;
229 line.clear();
230 }
231 Ok(())
232 });
233
234 let stdout_b = stdout.clone();
236 let emitter_b = decision_emitter.clone();
237 let event_source_b = event_source.clone();
238 let t_client_to_server = thread::spawn(move || -> io::Result<()> {
239 let stdin = io::stdin();
240 let mut reader = stdin.lock();
241 let mut line = String::new();
242
243 let mut state = PolicyState::default();
244 let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
245
246 while reader.read_line(&mut line)? > 0 {
247 match serde_json::from_str::<JsonRpcRequest>(&line) {
249 Ok(req) => {
250 let runtime_id = if req.is_tool_call() {
252 let name = req.tool_params().map(|p| p.name).unwrap_or_default();
253 let cache = identity_cache_b.lock().unwrap();
254 cache.get(&name).cloned()
255 } else {
256 None
257 };
258
259 let tool_name = req.tool_params().map(|p| p.name).unwrap_or_default();
260 let tool_call_id = Self::extract_tool_call_id(&req);
261
262 let policy_eval = policy.evaluate_with_metadata(
263 &tool_name,
264 &req.tool_params()
265 .map(|p| p.arguments)
266 .unwrap_or(serde_json::Value::Null),
267 &mut state,
268 runtime_id.as_ref(),
269 );
270
271 match policy_eval.decision {
272 PolicyDecision::Allow => {
273 Self::handle_allow(&req, &mut audit_log, config.verbose);
274 if req.is_tool_call() {
276 Self::emit_decision(
277 &emitter_b,
278 &event_source_b,
279 &tool_call_id,
280 &tool_name,
281 Decision::Allow,
282 reason_codes::P_POLICY_ALLOW,
283 None,
284 req.id.clone(),
285 &policy_eval.metadata,
286 );
287 }
288 }
289 PolicyDecision::AllowWithWarning { tool, code, reason } => {
290 if config.verbose {
292 eprintln!(
293 "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
294 tool,
295 code,
296 reason
297 );
298 }
299 audit_log.log(&AuditEvent {
300 timestamp: chrono::Utc::now().to_rfc3339(),
301 decision: "allow_with_warning".to_string(),
302 tool: Some(tool.clone()),
303 reason: Some(reason.clone()),
304 request_id: req.id.clone(),
305 agentic: None,
306 });
307 Self::emit_decision(
309 &emitter_b,
310 &event_source_b,
311 &tool_call_id,
312 &tool,
313 Decision::Allow,
314 &code,
315 Some(reason),
316 req.id.clone(),
317 &policy_eval.metadata,
318 );
319 Self::handle_allow(&req, &mut audit_log, false);
321 }
323 PolicyDecision::Deny {
324 tool,
325 code,
326 reason,
327 contract,
328 } => {
329 let decision_str =
331 if config.dry_run { "would_deny" } else { "deny" };
332
333 if config.verbose {
334 eprintln!(
335 "[assay] {} {} (reason: {})",
336 decision_str.to_uppercase(),
337 tool,
338 reason
339 );
340 }
341
342 audit_log.log(&AuditEvent {
343 timestamp: chrono::Utc::now().to_rfc3339(),
344 decision: decision_str.to_string(),
345 tool: Some(tool.clone()),
346 reason: Some(reason.clone()),
347 request_id: req.id.clone(),
348 agentic: Some(contract.clone()),
349 });
350
351 let reason_code = Self::map_policy_code(&code);
353 Self::emit_decision(
354 &emitter_b,
355 &event_source_b,
356 &tool_call_id,
357 &tool,
358 if config.dry_run {
359 Decision::Allow
360 } else {
361 Decision::Deny
362 },
363 &reason_code,
364 Some(reason),
365 req.id.clone(),
366 &policy_eval.metadata,
367 );
368
369 if config.dry_run {
370 } else {
373 let id = req.id.unwrap_or(serde_json::Value::Null);
375 let response_json = make_deny_response(
376 id,
377 "Content blocked by policy",
378 contract,
379 );
380
381 let mut out = stdout_b
382 .lock()
383 .map_err(|e| io::Error::other(e.to_string()))?;
384 out.write_all(response_json.as_bytes())?;
385 out.flush()?;
386
387 line.clear();
388 continue; }
390 }
391 }
392 }
393 Err(_) => {
394 let trimmed = line.trim();
396 if trimmed.starts_with('{')
397 && (trimmed.contains("\"method\"")
398 || trimmed.contains("\"params\"")
399 || trimmed.contains("\"tool\""))
400 {
401 eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
402 }
403 }
404 }
405
406 child_stdin.write_all(line.as_bytes())?;
408 child_stdin.flush()?;
409 line.clear();
410 }
411 Ok(())
412 });
413
414 t_client_to_server
416 .join()
417 .map_err(|_| io::Error::other("client->server thread panicked"))??;
418
419 let _ = t_server_to_client.join();
421
422 let status = self.child.wait()?;
424 Ok(status.code().unwrap_or(1))
425 }
426
427 fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
428 if verbose && req.is_tool_call() {
429 let tool = req
430 .tool_params()
431 .map(|p| p.name)
432 .unwrap_or_else(|| "unknown".to_string());
433 eprintln!("[assay] ALLOW {}", tool);
434 }
435
436 if req.is_tool_call() {
437 let tool = req.tool_params().map(|p| p.name);
438 audit_log.log(&AuditEvent {
439 timestamp: chrono::Utc::now().to_rfc3339(),
440 decision: "allow".to_string(),
441 tool,
442 reason: None,
443 request_id: req.id.clone(),
444 agentic: None,
445 });
446 }
447 }
448
449 fn extract_tool_call_id(request: &JsonRpcRequest) -> String {
451 if let Some(params) = request.tool_params() {
453 if let Some(meta) = params.arguments.get("_meta") {
454 if let Some(id) = meta.get("tool_call_id").and_then(|v| v.as_str()) {
455 return id.to_string();
456 }
457 }
458 }
459
460 if let Some(id) = &request.id {
462 if let Some(s) = id.as_str() {
463 return format!("req_{}", s);
464 }
465 if let Some(n) = id.as_i64() {
466 return format!("req_{}", n);
467 }
468 }
469
470 format!("gen_{}", uuid::Uuid::new_v4())
472 }
473
474 fn map_policy_code(code: &str) -> String {
476 match code {
477 "E_TOOL_DENIED" => reason_codes::P_TOOL_DENIED.to_string(),
478 "E_TOOL_NOT_ALLOWED" => reason_codes::P_TOOL_NOT_ALLOWED.to_string(),
479 "E_ARG_SCHEMA" => reason_codes::P_ARG_SCHEMA.to_string(),
480 "E_RATE_LIMIT" => reason_codes::P_RATE_LIMIT.to_string(),
481 "E_TOOL_DRIFT" => reason_codes::P_TOOL_DRIFT.to_string(),
482 _ => reason_codes::P_POLICY_DENY.to_string(),
483 }
484 }
485
486 #[allow(clippy::too_many_arguments)]
488 fn emit_decision(
489 emitter: &Arc<dyn DecisionEmitter>,
490 source: &str,
491 tool_call_id: &str,
492 tool: &str,
493 decision: Decision,
494 reason_code: &str,
495 reason: Option<String>,
496 request_id: Option<serde_json::Value>,
497 metadata: &PolicyMatchMetadata,
498 ) {
499 let mut event = DecisionEvent::new(
500 source.to_string(),
501 tool_call_id.to_string(),
502 tool.to_string(),
503 );
504 event.data.decision = decision;
505 event.data.reason_code = reason_code.to_string();
506 event.data.reason = reason;
507 event.data.request_id = request_id;
508 event.data.tool_classes = metadata.tool_classes.clone();
509 event.data.matched_tool_classes = metadata.matched_tool_classes.clone();
510 event.data.match_basis = metadata.match_basis.as_str().map(ToString::to_string);
511 event.data.matched_rule = metadata.matched_rule.clone();
512 event.data.typed_decision = metadata.typed_decision;
513 event.data.policy_version = metadata.policy_version.clone();
514 event.data.policy_digest = metadata.policy_digest.clone();
515 event.data.obligations = metadata.obligations.clone();
516 event.data.obligation_outcomes =
517 super::obligations::execute_log_only(&metadata.obligations, tool);
518 event.data.approval_state = metadata.approval_state.clone();
519 if let Some(artifact) = &metadata.approval_artifact {
520 event.data.approval_id = Some(artifact.approval_id.clone());
521 event.data.approver = Some(artifact.approver.clone());
522 event.data.issued_at = Some(artifact.issued_at.clone());
523 event.data.expires_at = Some(artifact.expires_at.clone());
524 event.data.scope = Some(artifact.scope.clone());
525 event.data.approval_bound_tool = Some(artifact.bound_tool.clone());
526 event.data.approval_bound_resource = Some(artifact.bound_resource.clone());
527 }
528 event.data.approval_freshness = metadata.approval_freshness;
529 event.data.approval_failure_reason = metadata.approval_failure_reason.clone();
530 event.data.scope_type = metadata.scope_type.clone();
531 event.data.scope_value = metadata.scope_value.clone();
532 event.data.scope_match_mode = metadata.scope_match_mode.clone();
533 event.data.scope_evaluation_state = metadata.scope_evaluation_state.clone();
534 event.data.scope_failure_reason = metadata.scope_failure_reason.clone();
535 event.data.restrict_scope_present = metadata.restrict_scope_present;
536 event.data.restrict_scope_target = metadata.restrict_scope_target.clone();
537 event.data.restrict_scope_match = metadata.restrict_scope_match;
538 event.data.restrict_scope_reason = metadata.restrict_scope_reason.clone();
539 event.data.redaction_target = metadata.redaction_target.clone();
540 event.data.redaction_mode = metadata.redaction_mode.clone();
541 event.data.redaction_scope = metadata.redaction_scope.clone();
542 event.data.redaction_applied_state = metadata.redaction_applied_state.clone();
543 event.data.redaction_reason = metadata.redaction_reason.clone();
544 event.data.redaction_failure_reason = metadata.redaction_failure_reason.clone();
545 event.data.redact_args_present = metadata.redact_args_present;
546 event.data.redact_args_target = metadata.redact_args_target.clone();
547 event.data.redact_args_mode = metadata.redact_args_mode.clone();
548 event.data.redact_args_result = metadata.redact_args_result.clone();
549 event.data.redact_args_reason = metadata.redact_args_reason.clone();
550 event.data.fail_closed = metadata.fail_closed.clone();
551 event.data.lane = metadata.lane.clone();
552 event.data.principal = metadata.principal.clone();
553 event.data.auth_context_summary = metadata.auth_context_summary.clone();
554 event.data.auth_scheme = metadata.auth_scheme.clone();
555 event.data.auth_issuer = metadata.auth_issuer.clone();
556 event.data.delegated_from = metadata.delegated_from.clone();
557 event.data.delegation_depth = metadata.delegation_depth;
558 refresh_contract_projections(&mut event.data);
559 emitter.emit(&event);
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn event_source_accepts_assay_uri() {
        validate_event_source("assay://myorg/myapp").unwrap();
    }

    #[test]
    fn event_source_accepts_https_uri() {
        validate_event_source("https://example.com/agent").unwrap();
    }

    #[test]
    fn event_source_accepts_scheme_with_plus_dot_and_dash() {
        validate_event_source("git+ssh.v-2://host/repo").unwrap();
    }

    #[test]
    fn event_source_rejects_empty() {
        assert!(validate_event_source("").is_err());
        assert!(validate_event_source(" ").is_err());
    }

    #[test]
    fn event_source_rejects_whitespace() {
        assert!(validate_event_source("assay://myorg/my app").is_err());
        assert!(validate_event_source("assay://myorg/\tmyapp").is_err());
    }

    #[test]
    fn event_source_rejects_missing_scheme() {
        assert!(validate_event_source("myorg/myapp").is_err());
        assert!(validate_event_source("://myorg/myapp").is_err());
    }

    #[test]
    fn event_source_rejects_did_and_urn() {
        assert!(validate_event_source("did:example:123").is_err());
        assert!(validate_event_source("urn:example:foo").is_err());
    }

    #[test]
    fn event_source_rejects_scheme_starting_with_non_letter() {
        assert!(validate_event_source("1assay://myorg/myapp").is_err());
        assert!(validate_event_source("-assay://myorg/myapp").is_err());
    }

    #[test]
    fn event_source_rejects_scheme_with_invalid_chars() {
        assert!(validate_event_source("as_say://myorg/myapp").is_err());
        assert!(validate_event_source("as@say://myorg/myapp").is_err());
    }

    #[test]
    fn config_requires_event_source_when_logging_enabled() {
        let raw = ProxyConfigRaw {
            dry_run: false,
            verbose: false,
            audit_log_path: None,
            decision_log_path: Some(std::path::PathBuf::from("decisions.ndjson")),
            event_source: None,
            server_id: "srv".to_string(),
        };

        let err = ProxyConfig::try_from_raw(raw).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("event_source is required"));
    }

    #[test]
    fn config_requires_event_source_when_only_audit_log_enabled() {
        let raw = ProxyConfigRaw {
            audit_log_path: Some(std::path::PathBuf::from("audit.ndjson")),
            server_id: "srv".to_string(),
            ..Default::default()
        };

        let err = ProxyConfig::try_from_raw(raw).unwrap_err();
        assert!(format!("{err:#}").contains("event_source is required"));
    }

    #[test]
    fn config_treats_blank_event_source_as_missing() {
        let raw = ProxyConfigRaw {
            decision_log_path: Some(std::path::PathBuf::from("decisions.ndjson")),
            event_source: Some("   ".to_string()),
            server_id: "srv".to_string(),
            ..Default::default()
        };

        let err = ProxyConfig::try_from_raw(raw).unwrap_err();
        assert!(format!("{err:#}").contains("event_source is required"));
    }

    #[test]
    fn config_allows_no_event_source_when_logging_disabled() {
        let raw = ProxyConfigRaw {
            dry_run: false,
            verbose: false,
            audit_log_path: None,
            decision_log_path: None,
            event_source: None,
            server_id: "srv".to_string(),
        };

        ProxyConfig::try_from_raw(raw).unwrap();
    }

    #[test]
    fn config_accepts_valid_event_source() {
        let raw = ProxyConfigRaw {
            dry_run: false,
            verbose: false,
            audit_log_path: None,
            decision_log_path: Some(std::path::PathBuf::from("decisions.ndjson")),
            event_source: Some("assay://myorg/myapp".to_string()),
            server_id: "srv".to_string(),
        };

        let cfg = ProxyConfig::try_from_raw(raw).unwrap();
        assert_eq!(cfg.event_source.as_deref(), Some("assay://myorg/myapp"));
    }

    #[test]
    fn config_trims_event_source() {
        let raw = ProxyConfigRaw {
            decision_log_path: Some(std::path::PathBuf::from("decisions.ndjson")),
            event_source: Some("  assay://myorg/myapp \n".to_string()),
            server_id: "srv".to_string(),
            ..Default::default()
        };

        let cfg = ProxyConfig::try_from_raw(raw).unwrap();
        assert_eq!(cfg.event_source.as_deref(), Some("assay://myorg/myapp"));
    }

    #[test]
    fn config_rejects_invalid_event_source_uri() {
        let raw = ProxyConfigRaw {
            dry_run: false,
            verbose: false,
            audit_log_path: None,
            decision_log_path: Some(std::path::PathBuf::from("decisions.ndjson")),
            event_source: Some("not a uri".to_string()),
            server_id: "srv".to_string(),
        };

        assert!(ProxyConfig::try_from_raw(raw).is_err());
    }

    #[test]
    fn config_validates_event_source_even_when_logging_disabled() {
        let raw = ProxyConfigRaw {
            event_source: Some("not a uri".to_string()),
            server_id: "srv".to_string(),
            ..Default::default()
        };

        assert!(ProxyConfig::try_from_raw(raw).is_err());
    }
}
672}