Skip to main content

assay_core/mcp/
proxy.rs

1use super::audit::{AuditEvent, AuditLog};
2use super::decision::{
3    reason_codes, refresh_contract_projections, Decision, DecisionEmitter, DecisionEvent,
4    FileDecisionEmitter, NullDecisionEmitter,
5};
6use super::jsonrpc::JsonRpcRequest;
7use super::policy::{
8    make_deny_response, McpPolicy, PolicyDecision, PolicyMatchMetadata, PolicyState,
9};
10use std::{
11    collections::HashMap,
12    io::{self, BufRead, BufReader, Write},
13    process::{Child, Command, Stdio},
14    sync::{Arc, Mutex},
15    thread,
16};
17
/// Proxy settings that have passed validation.
///
/// Build one with `ProxyConfig::try_from_raw()` from CLI/config input rather
/// than filling the fields by hand, so the event-source rules are enforced.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Forward requests even when the policy denies them (decisions are still recorded).
    pub dry_run: bool,
    /// Print per-request decisions to stderr.
    pub verbose: bool,
    /// NDJSON log for mandate lifecycle events (audit trail).
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the wrapped MCP server.
    pub server_id: String,
    /// NDJSON log for tool decision events (high volume).
    pub decision_log_path: Option<std::path::PathBuf>,
    /// CloudEvents source URI; validated, and required whenever either log is enabled.
    pub event_source: Option<String>,
}
33
/// Unvalidated proxy settings exactly as they arrive from the CLI or config files.
///
/// Convert with `ProxyConfig::try_from_raw()`; nothing here has been checked yet.
#[derive(Clone, Debug, Default)]
pub struct ProxyConfigRaw {
    /// Forward denied requests anyway.
    pub dry_run: bool,
    /// Print per-request decisions to stderr.
    pub verbose: bool,
    /// Optional audit-trail NDJSON path.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the wrapped MCP server.
    pub server_id: String,
    /// Optional decision-event NDJSON path.
    pub decision_log_path: Option<std::path::PathBuf>,
    /// Event source URI as typed; may be blank or padded with whitespace.
    pub event_source: Option<String>,
}
44
45impl ProxyConfig {
46    /// Create validated config from raw input.
47    ///
48    /// Fails if:
49    /// - Logging is enabled but event_source is missing
50    /// - event_source is not a valid absolute URI (scheme://...)
51    pub fn try_from_raw(raw: ProxyConfigRaw) -> anyhow::Result<Self> {
52        let logging_enabled = raw.audit_log_path.is_some() || raw.decision_log_path.is_some();
53
54        let event_source = raw
55            .event_source
56            .map(|s| s.trim().to_string())
57            .filter(|s| !s.is_empty());
58
59        if logging_enabled && event_source.is_none() {
60            anyhow::bail!(
61                "event_source is required when logging is enabled (e.g. --event-source assay://org/app)"
62            );
63        }
64
65        if let Some(ref src) = event_source {
66            validate_event_source(src)?;
67        }
68
69        Ok(ProxyConfig {
70            dry_run: raw.dry_run,
71            verbose: raw.verbose,
72            audit_log_path: raw.audit_log_path,
73            server_id: raw.server_id,
74            decision_log_path: raw.decision_log_path,
75            event_source,
76        })
77    }
78}
79
80/// Validate event_source URI (must be absolute with scheme://).
81fn validate_event_source(s: &str) -> anyhow::Result<()> {
82    let s = s.trim();
83    if s.is_empty() {
84        anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
85    }
86    if s.chars().any(|c| c.is_whitespace()) {
87        anyhow::bail!("event_source must not contain whitespace");
88    }
89
90    // Require scheme://...
91    let Some(pos) = s.find("://") else {
92        anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
93    };
94    if pos == 0 {
95        anyhow::bail!("event_source must have scheme before :// (e.g. assay://org/app)");
96    }
97
98    // Validate scheme charset (RFC 3986: ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ))
99    let scheme = &s[..pos];
100    let mut chars = scheme.chars();
101    match chars.next() {
102        Some(c) if c.is_ascii_alphabetic() => {}
103        _ => anyhow::bail!("event_source URI scheme must start with a letter"),
104    }
105    if !chars.all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '-' || c == '.') {
106        anyhow::bail!("event_source URI scheme contains invalid characters");
107    }
108
109    Ok(())
110}
111
112pub struct McpProxy {
113    child: Child,
114    policy: McpPolicy,
115    config: ProxyConfig,
116    /// Cache of tool identities discovered during tools/list
117    identity_cache: Arc<Mutex<HashMap<String, super::identity::ToolIdentity>>>,
118}
119
120impl Drop for McpProxy {
121    fn drop(&mut self) {
122        // Best-effort cleanup
123        let _ = self.child.kill();
124    }
125}
126
127impl McpProxy {
128    pub fn spawn(
129        command: &str,
130        args: &[String],
131        policy: McpPolicy,
132        config: ProxyConfig,
133    ) -> io::Result<Self> {
134        let child = Command::new(command)
135            .args(args)
136            .stdin(Stdio::piped())
137            .stdout(Stdio::piped())
138            .stderr(Stdio::inherit()) // protocol blijft op stdout
139            .spawn()?;
140
141        Ok(Self {
142            child,
143            policy,
144            config,
145            identity_cache: Arc::new(Mutex::new(HashMap::new())),
146        })
147    }
148
149    pub fn run(mut self) -> io::Result<i32> {
150        let mut child_stdin = self.child.stdin.take().expect("child stdin");
151        let child_stdout = self.child.stdout.take().expect("child stdout");
152
153        let stdout = Arc::new(Mutex::new(io::stdout()));
154        let policy = self.policy.clone();
155        let config = self.config.clone();
156        let identity_cache_a = self.identity_cache.clone();
157        let identity_cache_b = self.identity_cache.clone();
158
159        // Initialize decision emitter (I1: always emit decision)
160        let decision_emitter: Arc<dyn DecisionEmitter> =
161            if let Some(path) = &config.decision_log_path {
162                Arc::new(FileDecisionEmitter::new(path)?)
163            } else {
164                Arc::new(NullDecisionEmitter)
165            };
166        let event_source = config
167            .event_source
168            .clone()
169            .unwrap_or_else(|| format!("assay://{}", config.server_id));
170
171        // Thread A: server -> client passthrough
172        let stdout_a = stdout.clone();
173        let t_server_to_client = thread::spawn(move || -> io::Result<()> {
174            let mut reader = BufReader::new(child_stdout);
175            let mut line = String::new();
176
177            while reader.read_line(&mut line)? > 0 {
178                let mut processed_line = line.clone();
179
180                // Phase 9: Compute Identities on tools/list response
181                if let Ok(mut v) = serde_json::from_str::<serde_json::Value>(&line) {
182                    if let Some(result) = v.get_mut("result") {
183                        if let Some(tools) = result.get_mut("tools").and_then(|t| t.as_array_mut())
184                        {
185                            for tool in tools {
186                                let name = tool
187                                    .get("name")
188                                    .and_then(|n| n.as_str())
189                                    .unwrap_or("unknown");
190                                let description = tool
191                                    .get("description")
192                                    .and_then(|d| d.as_str())
193                                    .map(|s| s.to_string());
194                                let input_schema = tool
195                                    .get("inputSchema")
196                                    .or_else(|| tool.get("input_schema"))
197                                    .cloned();
198
199                                let identity = super::identity::ToolIdentity::new(
200                                    &config.server_id,
201                                    name,
202                                    &input_schema,
203                                    &description,
204                                );
205
206                                // Cache for runtime verification
207                                let mut cache = identity_cache_a.lock().unwrap();
208                                cache.insert(name.to_string(), identity.clone());
209
210                                // Augment the response with the computed identity for downstream/logging
211                                tool.as_object_mut().and_then(|m| {
212                                    m.insert(
213                                        "tool_identity".to_string(),
214                                        serde_json::to_value(&identity).unwrap(),
215                                    )
216                                });
217                            }
218                            processed_line =
219                                serde_json::to_string(&v).unwrap_or(line.clone()) + "\n";
220                        }
221                    }
222                }
223
224                let mut out = stdout_a
225                    .lock()
226                    .map_err(|e| io::Error::other(e.to_string()))?;
227                out.write_all(processed_line.as_bytes())?;
228                out.flush()?;
229                line.clear();
230            }
231            Ok(())
232        });
233
234        // Thread B: client -> server passthrough with Policy Check
235        let stdout_b = stdout.clone();
236        let emitter_b = decision_emitter.clone();
237        let event_source_b = event_source.clone();
238        let t_client_to_server = thread::spawn(move || -> io::Result<()> {
239            let stdin = io::stdin();
240            let mut reader = stdin.lock();
241            let mut line = String::new();
242
243            let mut state = PolicyState::default();
244            let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
245
246            while reader.read_line(&mut line)? > 0 {
247                // 1. Try Parse as MCP Request
248                match serde_json::from_str::<JsonRpcRequest>(&line) {
249                    Ok(req) => {
250                        // 2. Check Policy with Identity (Phase 9)
251                        let runtime_id = if req.is_tool_call() {
252                            let name = req.tool_params().map(|p| p.name).unwrap_or_default();
253                            let cache = identity_cache_b.lock().unwrap();
254                            cache.get(&name).cloned()
255                        } else {
256                            None
257                        };
258
259                        let tool_name = req.tool_params().map(|p| p.name).unwrap_or_default();
260                        let tool_call_id = Self::extract_tool_call_id(&req);
261
262                        let policy_eval = policy.evaluate_with_metadata(
263                            &tool_name,
264                            &req.tool_params()
265                                .map(|p| p.arguments)
266                                .unwrap_or(serde_json::Value::Null),
267                            &mut state,
268                            runtime_id.as_ref(),
269                        );
270
271                        match policy_eval.decision {
272                            PolicyDecision::Allow => {
273                                Self::handle_allow(&req, &mut audit_log, config.verbose);
274                                // Emit decision event (I1: always emit)
275                                if req.is_tool_call() {
276                                    Self::emit_decision(
277                                        &emitter_b,
278                                        &event_source_b,
279                                        &tool_call_id,
280                                        &tool_name,
281                                        Decision::Allow,
282                                        reason_codes::P_POLICY_ALLOW,
283                                        None,
284                                        req.id.clone(),
285                                        &policy_eval.metadata,
286                                    );
287                                }
288                            }
289                            PolicyDecision::AllowWithWarning { tool, code, reason } => {
290                                // Log warning about allowing a tool invocation with issues
291                                if config.verbose {
292                                    eprintln!(
293                                        "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
294                                        tool,
295                                        code,
296                                        reason
297                                    );
298                                }
299                                audit_log.log(&AuditEvent {
300                                    timestamp: chrono::Utc::now().to_rfc3339(),
301                                    decision: "allow_with_warning".to_string(),
302                                    tool: Some(tool.clone()),
303                                    reason: Some(reason.clone()),
304                                    request_id: req.id.clone(),
305                                    agentic: None,
306                                });
307                                // Emit decision event (I1: always emit)
308                                Self::emit_decision(
309                                    &emitter_b,
310                                    &event_source_b,
311                                    &tool_call_id,
312                                    &tool,
313                                    Decision::Allow,
314                                    &code,
315                                    Some(reason),
316                                    req.id.clone(),
317                                    &policy_eval.metadata,
318                                );
319                                // Then proceed as a normal allow
320                                Self::handle_allow(&req, &mut audit_log, false);
321                                // false = don't double log ALLOW
322                            }
323                            PolicyDecision::Deny {
324                                tool,
325                                code,
326                                reason,
327                                contract,
328                            } => {
329                                // Log Decision
330                                let decision_str =
331                                    if config.dry_run { "would_deny" } else { "deny" };
332
333                                if config.verbose {
334                                    eprintln!(
335                                        "[assay] {} {} (reason: {})",
336                                        decision_str.to_uppercase(),
337                                        tool,
338                                        reason
339                                    );
340                                }
341
342                                audit_log.log(&AuditEvent {
343                                    timestamp: chrono::Utc::now().to_rfc3339(),
344                                    decision: decision_str.to_string(),
345                                    tool: Some(tool.clone()),
346                                    reason: Some(reason.clone()),
347                                    request_id: req.id.clone(),
348                                    agentic: Some(contract.clone()),
349                                });
350
351                                // Emit decision event (I1: always emit)
352                                let reason_code = Self::map_policy_code(&code);
353                                Self::emit_decision(
354                                    &emitter_b,
355                                    &event_source_b,
356                                    &tool_call_id,
357                                    &tool,
358                                    if config.dry_run {
359                                        Decision::Allow
360                                    } else {
361                                        Decision::Deny
362                                    },
363                                    &reason_code,
364                                    Some(reason),
365                                    req.id.clone(),
366                                    &policy_eval.metadata,
367                                );
368
369                                if config.dry_run {
370                                    // DRY RUN: Forward anyway
371                                    // Fallthrough to forward logic below
372                                } else {
373                                    // BLOCK: Send error response
374                                    let id = req.id.unwrap_or(serde_json::Value::Null);
375                                    let response_json = make_deny_response(
376                                        id,
377                                        "Content blocked by policy",
378                                        contract,
379                                    );
380
381                                    let mut out = stdout_b
382                                        .lock()
383                                        .map_err(|e| io::Error::other(e.to_string()))?;
384                                    out.write_all(response_json.as_bytes())?;
385                                    out.flush()?;
386
387                                    line.clear();
388                                    continue; // Skip forwarding
389                                }
390                            }
391                        }
392                    }
393                    Err(_) => {
394                        // Hardening: Suspicious Unparsable JSON
395                        let trimmed = line.trim();
396                        if trimmed.starts_with('{')
397                            && (trimmed.contains("\"method\"")
398                                || trimmed.contains("\"params\"")
399                                || trimmed.contains("\"tool\""))
400                        {
401                            eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
402                        }
403                    }
404                }
405
406                // 3. Forward
407                child_stdin.write_all(line.as_bytes())?;
408                child_stdin.flush()?;
409                line.clear();
410            }
411            Ok(())
412        });
413
414        // Wacht tot client->server eindigt (stdin closed)
415        t_client_to_server
416            .join()
417            .map_err(|_| io::Error::other("client->server thread panicked"))??;
418
419        // Server->client thread kan nog even lopen; join best-effort
420        let _ = t_server_to_client.join();
421
422        // Wacht op child exit
423        let status = self.child.wait()?;
424        Ok(status.code().unwrap_or(1))
425    }
426
427    fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
428        if verbose && req.is_tool_call() {
429            let tool = req
430                .tool_params()
431                .map(|p| p.name)
432                .unwrap_or_else(|| "unknown".to_string());
433            eprintln!("[assay] ALLOW {}", tool);
434        }
435
436        if req.is_tool_call() {
437            let tool = req.tool_params().map(|p| p.name);
438            audit_log.log(&AuditEvent {
439                timestamp: chrono::Utc::now().to_rfc3339(),
440                decision: "allow".to_string(),
441                tool,
442                reason: None,
443                request_id: req.id.clone(),
444                agentic: None,
445            });
446        }
447    }
448
449    /// Extract tool_call_id from request (I4: idempotency key).
450    fn extract_tool_call_id(request: &JsonRpcRequest) -> String {
451        // Try to get from params._meta.tool_call_id (MCP standard)
452        if let Some(params) = request.tool_params() {
453            if let Some(meta) = params.arguments.get("_meta") {
454                if let Some(id) = meta.get("tool_call_id").and_then(|v| v.as_str()) {
455                    return id.to_string();
456                }
457            }
458        }
459
460        // Fall back to request.id if present
461        if let Some(id) = &request.id {
462            if let Some(s) = id.as_str() {
463                return format!("req_{}", s);
464            }
465            if let Some(n) = id.as_i64() {
466                return format!("req_{}", n);
467            }
468        }
469
470        // Generate one if none found
471        format!("gen_{}", uuid::Uuid::new_v4())
472    }
473
474    /// Map policy error code to reason code.
475    fn map_policy_code(code: &str) -> String {
476        match code {
477            "E_TOOL_DENIED" => reason_codes::P_TOOL_DENIED.to_string(),
478            "E_TOOL_NOT_ALLOWED" => reason_codes::P_TOOL_NOT_ALLOWED.to_string(),
479            "E_ARG_SCHEMA" => reason_codes::P_ARG_SCHEMA.to_string(),
480            "E_RATE_LIMIT" => reason_codes::P_RATE_LIMIT.to_string(),
481            "E_TOOL_DRIFT" => reason_codes::P_TOOL_DRIFT.to_string(),
482            _ => reason_codes::P_POLICY_DENY.to_string(),
483        }
484    }
485
486    /// Emit a decision event (I1: always emit).
487    #[allow(clippy::too_many_arguments)]
488    fn emit_decision(
489        emitter: &Arc<dyn DecisionEmitter>,
490        source: &str,
491        tool_call_id: &str,
492        tool: &str,
493        decision: Decision,
494        reason_code: &str,
495        reason: Option<String>,
496        request_id: Option<serde_json::Value>,
497        metadata: &PolicyMatchMetadata,
498    ) {
499        let mut event = DecisionEvent::new(
500            source.to_string(),
501            tool_call_id.to_string(),
502            tool.to_string(),
503        );
504        event.data.decision = decision;
505        event.data.reason_code = reason_code.to_string();
506        event.data.reason = reason;
507        event.data.request_id = request_id;
508        event.data.tool_classes = metadata.tool_classes.clone();
509        event.data.matched_tool_classes = metadata.matched_tool_classes.clone();
510        event.data.match_basis = metadata.match_basis.as_str().map(ToString::to_string);
511        event.data.matched_rule = metadata.matched_rule.clone();
512        event.data.typed_decision = metadata.typed_decision;
513        event.data.policy_version = metadata.policy_version.clone();
514        event.data.policy_digest = metadata.policy_digest.clone();
515        event.data.obligations = metadata.obligations.clone();
516        event.data.obligation_outcomes =
517            super::obligations::execute_log_only(&metadata.obligations, tool);
518        event.data.approval_state = metadata.approval_state.clone();
519        if let Some(artifact) = &metadata.approval_artifact {
520            event.data.approval_id = Some(artifact.approval_id.clone());
521            event.data.approver = Some(artifact.approver.clone());
522            event.data.issued_at = Some(artifact.issued_at.clone());
523            event.data.expires_at = Some(artifact.expires_at.clone());
524            event.data.scope = Some(artifact.scope.clone());
525            event.data.approval_bound_tool = Some(artifact.bound_tool.clone());
526            event.data.approval_bound_resource = Some(artifact.bound_resource.clone());
527        }
528        event.data.approval_freshness = metadata.approval_freshness;
529        event.data.approval_failure_reason = metadata.approval_failure_reason.clone();
530        event.data.scope_type = metadata.scope_type.clone();
531        event.data.scope_value = metadata.scope_value.clone();
532        event.data.scope_match_mode = metadata.scope_match_mode.clone();
533        event.data.scope_evaluation_state = metadata.scope_evaluation_state.clone();
534        event.data.scope_failure_reason = metadata.scope_failure_reason.clone();
535        event.data.restrict_scope_present = metadata.restrict_scope_present;
536        event.data.restrict_scope_target = metadata.restrict_scope_target.clone();
537        event.data.restrict_scope_match = metadata.restrict_scope_match;
538        event.data.restrict_scope_reason = metadata.restrict_scope_reason.clone();
539        event.data.redaction_target = metadata.redaction_target.clone();
540        event.data.redaction_mode = metadata.redaction_mode.clone();
541        event.data.redaction_scope = metadata.redaction_scope.clone();
542        event.data.redaction_applied_state = metadata.redaction_applied_state.clone();
543        event.data.redaction_reason = metadata.redaction_reason.clone();
544        event.data.redaction_failure_reason = metadata.redaction_failure_reason.clone();
545        event.data.redact_args_present = metadata.redact_args_present;
546        event.data.redact_args_target = metadata.redact_args_target.clone();
547        event.data.redact_args_mode = metadata.redact_args_mode.clone();
548        event.data.redact_args_result = metadata.redact_args_result.clone();
549        event.data.redact_args_reason = metadata.redact_args_reason.clone();
550        event.data.fail_closed = metadata.fail_closed.clone();
551        event.data.lane = metadata.lane.clone();
552        event.data.principal = metadata.principal.clone();
553        event.data.auth_context_summary = metadata.auth_context_summary.clone();
554        event.data.auth_scheme = metadata.auth_scheme.clone();
555        event.data.auth_issuer = metadata.auth_issuer.clone();
556        event.data.delegated_from = metadata.delegated_from.clone();
557        event.data.delegation_depth = metadata.delegation_depth;
558        refresh_contract_projections(&mut event.data);
559        emitter.emit(&event);
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw config for server `srv` with everything else defaulted
    /// (no dry-run, not verbose, no audit log).
    fn raw_config(decision_log: Option<&str>, event_source: Option<&str>) -> ProxyConfigRaw {
        ProxyConfigRaw {
            decision_log_path: decision_log.map(std::path::PathBuf::from),
            event_source: event_source.map(str::to_string),
            server_id: "srv".to_string(),
            ..ProxyConfigRaw::default()
        }
    }

    /// Assert that every input is rejected, naming the offender on failure.
    fn assert_all_rejected(inputs: &[&str]) {
        for input in inputs {
            assert!(
                validate_event_source(input).is_err(),
                "{input:?} should be rejected"
            );
        }
    }

    #[test]
    fn event_source_accepts_assay_uri() {
        validate_event_source("assay://myorg/myapp").expect("assay:// URI must be accepted");
    }

    #[test]
    fn event_source_accepts_https_uri() {
        validate_event_source("https://example.com/agent").expect("https:// URI must be accepted");
    }

    #[test]
    fn event_source_rejects_empty() {
        assert_all_rejected(&["", "   "]);
    }

    #[test]
    fn event_source_rejects_whitespace() {
        assert_all_rejected(&["assay://myorg/my app", "assay://myorg/\tmyapp"]);
    }

    #[test]
    fn event_source_rejects_missing_scheme() {
        assert_all_rejected(&["myorg/myapp", "://myorg/myapp"]);
    }

    #[test]
    fn event_source_rejects_did_and_urn() {
        // `scheme:` alone is not enough; `scheme://` is required.
        assert_all_rejected(&["did:example:123", "urn:example:foo"]);
    }

    #[test]
    fn event_source_rejects_scheme_starting_with_non_letter() {
        assert_all_rejected(&["1assay://myorg/myapp", "-assay://myorg/myapp"]);
    }

    #[test]
    fn event_source_rejects_scheme_with_invalid_chars() {
        assert_all_rejected(&["as_say://myorg/myapp", "as@say://myorg/myapp"]);
    }

    #[test]
    fn config_requires_event_source_when_logging_enabled() {
        let err = ProxyConfig::try_from_raw(raw_config(Some("decisions.ndjson"), None))
            .expect_err("missing event_source must be rejected when logging");
        assert!(format!("{err:#}").contains("event_source is required"));
    }

    #[test]
    fn config_allows_no_event_source_when_logging_disabled() {
        ProxyConfig::try_from_raw(raw_config(None, None))
            .expect("no logging means no event_source is needed");
    }

    #[test]
    fn config_accepts_valid_event_source() {
        let cfg = ProxyConfig::try_from_raw(raw_config(
            Some("decisions.ndjson"),
            Some("assay://myorg/myapp"),
        ))
        .expect("valid event_source must be accepted");
        assert_eq!(cfg.event_source.as_deref(), Some("assay://myorg/myapp"));
    }

    #[test]
    fn config_rejects_invalid_event_source_uri() {
        let raw = raw_config(Some("decisions.ndjson"), Some("not a uri"));
        assert!(ProxyConfig::try_from_raw(raw).is_err());
    }
}