// assay_core/mcp/proxy.rs

1use super::audit::{AuditEvent, AuditLog};
2use super::decision::{
3    reason_codes, Decision, DecisionEmitter, DecisionEvent, FileDecisionEmitter,
4    NullDecisionEmitter,
5};
6use super::jsonrpc::JsonRpcRequest;
7use super::policy::{make_deny_response, McpPolicy, PolicyDecision, PolicyState};
8use std::{
9    collections::HashMap,
10    io::{self, BufRead, BufReader, Write},
11    process::{Child, Command, Stdio},
12    sync::{Arc, Mutex},
13    thread,
14};
15
/// Validated proxy configuration.
///
/// Build it from a [`ProxyConfigRaw`] with [`ProxyConfig::try_from_raw`], which
/// trims and validates `event_source` and requires it whenever a log path is set.
#[derive(Clone, Debug)]
pub struct ProxyConfig {
    /// Forward requests the policy would deny instead of blocking them; such
    /// requests are still audited (as `would_deny`).
    pub dry_run: bool,
    /// Print allow / warning / deny decisions to stderr.
    pub verbose: bool,
    /// NDJSON log for mandate lifecycle events (audit trail).
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the wrapped MCP server. Used when computing tool identities
    /// and to derive a fallback event source (`assay://<server_id>`).
    pub server_id: String,
    /// NDJSON log for tool decision events (high volume).
    pub decision_log_path: Option<std::path::PathBuf>,
    /// CloudEvents source URI (trimmed and validated as `scheme://...`; present
    /// whenever logging is enabled if built via `try_from_raw`).
    pub event_source: Option<String>,
}
31
/// Raw config as provided by CLI/config files before validation.
///
/// Field meanings match [`ProxyConfig`]; the difference is that `event_source`
/// here may still be missing, padded with whitespace, or malformed. Convert with
/// [`ProxyConfig::try_from_raw`].
#[derive(Clone, Debug, Default)]
pub struct ProxyConfigRaw {
    /// Forward denied requests instead of blocking them.
    pub dry_run: bool,
    /// Print decisions to stderr.
    pub verbose: bool,
    /// Optional NDJSON audit-trail path; setting it enables logging.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the wrapped MCP server.
    pub server_id: String,
    /// Optional NDJSON decision-log path; setting it enables logging.
    pub decision_log_path: Option<std::path::PathBuf>,
    /// Unvalidated CloudEvents source URI.
    pub event_source: Option<String>,
}
42
43impl ProxyConfig {
44    /// Create validated config from raw input.
45    ///
46    /// Fails if:
47    /// - Logging is enabled but event_source is missing
48    /// - event_source is not a valid absolute URI (scheme://...)
49    pub fn try_from_raw(raw: ProxyConfigRaw) -> anyhow::Result<Self> {
50        let logging_enabled = raw.audit_log_path.is_some() || raw.decision_log_path.is_some();
51
52        let event_source = raw
53            .event_source
54            .map(|s| s.trim().to_string())
55            .filter(|s| !s.is_empty());
56
57        if logging_enabled && event_source.is_none() {
58            anyhow::bail!(
59                "event_source is required when logging is enabled (e.g. --event-source assay://org/app)"
60            );
61        }
62
63        if let Some(ref src) = event_source {
64            validate_event_source(src)?;
65        }
66
67        Ok(ProxyConfig {
68            dry_run: raw.dry_run,
69            verbose: raw.verbose,
70            audit_log_path: raw.audit_log_path,
71            server_id: raw.server_id,
72            decision_log_path: raw.decision_log_path,
73            event_source,
74        })
75    }
76}
77
78/// Validate event_source URI (must be absolute with scheme://).
79fn validate_event_source(s: &str) -> anyhow::Result<()> {
80    let s = s.trim();
81    if s.is_empty() {
82        anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
83    }
84    if s.chars().any(|c| c.is_whitespace()) {
85        anyhow::bail!("event_source must not contain whitespace");
86    }
87
88    // Require scheme://...
89    let Some(pos) = s.find("://") else {
90        anyhow::bail!("event_source must be absolute URI with scheme (e.g. assay://org/app)");
91    };
92    if pos == 0 {
93        anyhow::bail!("event_source must have scheme before :// (e.g. assay://org/app)");
94    }
95
96    // Validate scheme charset (RFC 3986: ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ))
97    let scheme = &s[..pos];
98    let mut chars = scheme.chars();
99    match chars.next() {
100        Some(c) if c.is_ascii_alphabetic() => {}
101        _ => anyhow::bail!("event_source URI scheme must start with a letter"),
102    }
103    if !chars.all(|c| c.is_ascii_alphanumeric() || c == '+' || c == '-' || c == '.') {
104        anyhow::bail!("event_source URI scheme contains invalid characters");
105    }
106
107    Ok(())
108}
109
110pub struct McpProxy {
111    child: Child,
112    policy: McpPolicy,
113    config: ProxyConfig,
114    /// Cache of tool identities discovered during tools/list
115    identity_cache: Arc<Mutex<HashMap<String, super::identity::ToolIdentity>>>,
116}
117
118impl Drop for McpProxy {
119    fn drop(&mut self) {
120        // Best-effort cleanup
121        let _ = self.child.kill();
122    }
123}
124
125impl McpProxy {
126    pub fn spawn(
127        command: &str,
128        args: &[String],
129        policy: McpPolicy,
130        config: ProxyConfig,
131    ) -> io::Result<Self> {
132        let child = Command::new(command)
133            .args(args)
134            .stdin(Stdio::piped())
135            .stdout(Stdio::piped())
136            .stderr(Stdio::inherit()) // protocol blijft op stdout
137            .spawn()?;
138
139        Ok(Self {
140            child,
141            policy,
142            config,
143            identity_cache: Arc::new(Mutex::new(HashMap::new())),
144        })
145    }
146
147    pub fn run(mut self) -> io::Result<i32> {
148        let mut child_stdin = self.child.stdin.take().expect("child stdin");
149        let child_stdout = self.child.stdout.take().expect("child stdout");
150
151        let stdout = Arc::new(Mutex::new(io::stdout()));
152        let policy = self.policy.clone();
153        let config = self.config.clone();
154        let identity_cache_a = self.identity_cache.clone();
155        let identity_cache_b = self.identity_cache.clone();
156
157        // Initialize decision emitter (I1: always emit decision)
158        let decision_emitter: Arc<dyn DecisionEmitter> =
159            if let Some(path) = &config.decision_log_path {
160                Arc::new(FileDecisionEmitter::new(path)?)
161            } else {
162                Arc::new(NullDecisionEmitter)
163            };
164        let event_source = config
165            .event_source
166            .clone()
167            .unwrap_or_else(|| format!("assay://{}", config.server_id));
168
169        // Thread A: server -> client passthrough
170        let stdout_a = stdout.clone();
171        let t_server_to_client = thread::spawn(move || -> io::Result<()> {
172            let mut reader = BufReader::new(child_stdout);
173            let mut line = String::new();
174
175            while reader.read_line(&mut line)? > 0 {
176                let mut processed_line = line.clone();
177
178                // Phase 9: Compute Identities on tools/list response
179                if let Ok(mut v) = serde_json::from_str::<serde_json::Value>(&line) {
180                    if let Some(result) = v.get_mut("result") {
181                        if let Some(tools) = result.get_mut("tools").and_then(|t| t.as_array_mut())
182                        {
183                            for tool in tools {
184                                let name = tool
185                                    .get("name")
186                                    .and_then(|n| n.as_str())
187                                    .unwrap_or("unknown");
188                                let description = tool
189                                    .get("description")
190                                    .and_then(|d| d.as_str())
191                                    .map(|s| s.to_string());
192                                let input_schema = tool
193                                    .get("inputSchema")
194                                    .or_else(|| tool.get("input_schema"))
195                                    .cloned();
196
197                                let identity = super::identity::ToolIdentity::new(
198                                    &config.server_id,
199                                    name,
200                                    &input_schema,
201                                    &description,
202                                );
203
204                                // Cache for runtime verification
205                                let mut cache = identity_cache_a.lock().unwrap();
206                                cache.insert(name.to_string(), identity.clone());
207
208                                // Augment the response with the computed identity for downstream/logging
209                                tool.as_object_mut().and_then(|m| {
210                                    m.insert(
211                                        "tool_identity".to_string(),
212                                        serde_json::to_value(&identity).unwrap(),
213                                    )
214                                });
215                            }
216                            processed_line =
217                                serde_json::to_string(&v).unwrap_or(line.clone()) + "\n";
218                        }
219                    }
220                }
221
222                let mut out = stdout_a
223                    .lock()
224                    .map_err(|e| io::Error::other(e.to_string()))?;
225                out.write_all(processed_line.as_bytes())?;
226                out.flush()?;
227                line.clear();
228            }
229            Ok(())
230        });
231
232        // Thread B: client -> server passthrough with Policy Check
233        let stdout_b = stdout.clone();
234        let emitter_b = decision_emitter.clone();
235        let event_source_b = event_source.clone();
236        let t_client_to_server = thread::spawn(move || -> io::Result<()> {
237            let stdin = io::stdin();
238            let mut reader = stdin.lock();
239            let mut line = String::new();
240
241            let mut state = PolicyState::default();
242            let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
243
244            while reader.read_line(&mut line)? > 0 {
245                // 1. Try Parse as MCP Request
246                match serde_json::from_str::<JsonRpcRequest>(&line) {
247                    Ok(req) => {
248                        // 2. Check Policy with Identity (Phase 9)
249                        let runtime_id = if req.is_tool_call() {
250                            let name = req.tool_params().map(|p| p.name).unwrap_or_default();
251                            let cache = identity_cache_b.lock().unwrap();
252                            cache.get(&name).cloned()
253                        } else {
254                            None
255                        };
256
257                        let tool_name = req.tool_params().map(|p| p.name).unwrap_or_default();
258                        let tool_call_id = Self::extract_tool_call_id(&req);
259
260                        match policy.evaluate(
261                            &tool_name,
262                            &req.tool_params()
263                                .map(|p| p.arguments)
264                                .unwrap_or(serde_json::Value::Null),
265                            &mut state,
266                            runtime_id.as_ref(),
267                        ) {
268                            PolicyDecision::Allow => {
269                                Self::handle_allow(&req, &mut audit_log, config.verbose);
270                                // Emit decision event (I1: always emit)
271                                if req.is_tool_call() {
272                                    Self::emit_decision(
273                                        &emitter_b,
274                                        &event_source_b,
275                                        &tool_call_id,
276                                        &tool_name,
277                                        Decision::Allow,
278                                        reason_codes::P_POLICY_DENY, // TODO: Better code
279                                        None,
280                                        req.id.clone(),
281                                    );
282                                }
283                            }
284                            PolicyDecision::AllowWithWarning { tool, code, reason } => {
285                                // Log warning about allowing a tool invocation with issues
286                                if config.verbose {
287                                    eprintln!(
288                                        "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
289                                        tool,
290                                        code,
291                                        reason
292                                    );
293                                }
294                                audit_log.log(&AuditEvent {
295                                    timestamp: chrono::Utc::now().to_rfc3339(),
296                                    decision: "allow_with_warning".to_string(),
297                                    tool: Some(tool.clone()),
298                                    reason: Some(reason.clone()),
299                                    request_id: req.id.clone(),
300                                    agentic: None,
301                                });
302                                // Emit decision event (I1: always emit)
303                                Self::emit_decision(
304                                    &emitter_b,
305                                    &event_source_b,
306                                    &tool_call_id,
307                                    &tool,
308                                    Decision::Allow,
309                                    &code,
310                                    Some(reason),
311                                    req.id.clone(),
312                                );
313                                // Then proceed as a normal allow
314                                Self::handle_allow(&req, &mut audit_log, false);
315                                // false = don't double log ALLOW
316                            }
317                            PolicyDecision::Deny {
318                                tool,
319                                code,
320                                reason,
321                                contract,
322                            } => {
323                                // Log Decision
324                                let decision_str =
325                                    if config.dry_run { "would_deny" } else { "deny" };
326
327                                if config.verbose {
328                                    eprintln!(
329                                        "[assay] {} {} (reason: {})",
330                                        decision_str.to_uppercase(),
331                                        tool,
332                                        reason
333                                    );
334                                }
335
336                                audit_log.log(&AuditEvent {
337                                    timestamp: chrono::Utc::now().to_rfc3339(),
338                                    decision: decision_str.to_string(),
339                                    tool: Some(tool.clone()),
340                                    reason: Some(reason.clone()),
341                                    request_id: req.id.clone(),
342                                    agentic: Some(contract.clone()),
343                                });
344
345                                // Emit decision event (I1: always emit)
346                                let reason_code = Self::map_policy_code(&code);
347                                Self::emit_decision(
348                                    &emitter_b,
349                                    &event_source_b,
350                                    &tool_call_id,
351                                    &tool,
352                                    if config.dry_run {
353                                        Decision::Allow
354                                    } else {
355                                        Decision::Deny
356                                    },
357                                    &reason_code,
358                                    Some(reason),
359                                    req.id.clone(),
360                                );
361
362                                if config.dry_run {
363                                    // DRY RUN: Forward anyway
364                                    // Fallthrough to forward logic below
365                                } else {
366                                    // BLOCK: Send error response
367                                    let id = req.id.unwrap_or(serde_json::Value::Null);
368                                    let response_json = make_deny_response(
369                                        id,
370                                        "Content blocked by policy",
371                                        contract,
372                                    );
373
374                                    let mut out = stdout_b
375                                        .lock()
376                                        .map_err(|e| io::Error::other(e.to_string()))?;
377                                    out.write_all(response_json.as_bytes())?;
378                                    out.flush()?;
379
380                                    line.clear();
381                                    continue; // Skip forwarding
382                                }
383                            }
384                        }
385                    }
386                    Err(_) => {
387                        // Hardening: Suspicious Unparsable JSON
388                        let trimmed = line.trim();
389                        if trimmed.starts_with('{')
390                            && (trimmed.contains("\"method\"")
391                                || trimmed.contains("\"params\"")
392                                || trimmed.contains("\"tool\""))
393                        {
394                            eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
395                        }
396                    }
397                }
398
399                // 3. Forward
400                child_stdin.write_all(line.as_bytes())?;
401                child_stdin.flush()?;
402                line.clear();
403            }
404            Ok(())
405        });
406
407        // Wacht tot client->server eindigt (stdin closed)
408        t_client_to_server
409            .join()
410            .map_err(|_| io::Error::other("client->server thread panicked"))??;
411
412        // Server->client thread kan nog even lopen; join best-effort
413        let _ = t_server_to_client.join();
414
415        // Wacht op child exit
416        let status = self.child.wait()?;
417        Ok(status.code().unwrap_or(1))
418    }
419
420    fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
421        if verbose && req.is_tool_call() {
422            let tool = req
423                .tool_params()
424                .map(|p| p.name)
425                .unwrap_or_else(|| "unknown".to_string());
426            eprintln!("[assay] ALLOW {}", tool);
427        }
428
429        if req.is_tool_call() {
430            let tool = req.tool_params().map(|p| p.name);
431            audit_log.log(&AuditEvent {
432                timestamp: chrono::Utc::now().to_rfc3339(),
433                decision: "allow".to_string(),
434                tool,
435                reason: None,
436                request_id: req.id.clone(),
437                agentic: None,
438            });
439        }
440    }
441
442    /// Extract tool_call_id from request (I4: idempotency key).
443    fn extract_tool_call_id(request: &JsonRpcRequest) -> String {
444        // Try to get from params._meta.tool_call_id (MCP standard)
445        if let Some(params) = request.tool_params() {
446            if let Some(meta) = params.arguments.get("_meta") {
447                if let Some(id) = meta.get("tool_call_id").and_then(|v| v.as_str()) {
448                    return id.to_string();
449                }
450            }
451        }
452
453        // Fall back to request.id if present
454        if let Some(id) = &request.id {
455            if let Some(s) = id.as_str() {
456                return format!("req_{}", s);
457            }
458            if let Some(n) = id.as_i64() {
459                return format!("req_{}", n);
460            }
461        }
462
463        // Generate one if none found
464        format!("gen_{}", uuid::Uuid::new_v4())
465    }
466
467    /// Map policy error code to reason code.
468    fn map_policy_code(code: &str) -> String {
469        match code {
470            "E_TOOL_DENIED" => reason_codes::P_TOOL_DENIED.to_string(),
471            "E_TOOL_NOT_ALLOWED" => reason_codes::P_TOOL_NOT_ALLOWED.to_string(),
472            "E_ARG_SCHEMA" => reason_codes::P_ARG_SCHEMA.to_string(),
473            "E_RATE_LIMIT" => reason_codes::P_RATE_LIMIT.to_string(),
474            "E_TOOL_DRIFT" => reason_codes::P_TOOL_DRIFT.to_string(),
475            _ => reason_codes::P_POLICY_DENY.to_string(),
476        }
477    }
478
479    /// Emit a decision event (I1: always emit).
480    #[allow(clippy::too_many_arguments)]
481    fn emit_decision(
482        emitter: &Arc<dyn DecisionEmitter>,
483        source: &str,
484        tool_call_id: &str,
485        tool: &str,
486        decision: Decision,
487        reason_code: &str,
488        reason: Option<String>,
489        request_id: Option<serde_json::Value>,
490    ) {
491        let mut event = DecisionEvent::new(
492            source.to_string(),
493            tool_call_id.to_string(),
494            tool.to_string(),
495        );
496        event.data.decision = decision;
497        event.data.reason_code = reason_code.to_string();
498        event.data.reason = reason;
499        event.data.request_id = request_id;
500        emitter.emit(&event);
501    }
502}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw config for server `srv` with the given decision log path and event
    /// source; every other field keeps its default.
    fn raw_with(decision_log: Option<&str>, event_source: Option<&str>) -> ProxyConfigRaw {
        ProxyConfigRaw {
            decision_log_path: decision_log.map(std::path::PathBuf::from),
            event_source: event_source.map(str::to_string),
            server_id: "srv".to_string(),
            ..ProxyConfigRaw::default()
        }
    }

    #[test]
    fn event_source_accepts_assay_uri() {
        validate_event_source("assay://myorg/myapp").unwrap();
    }

    #[test]
    fn event_source_accepts_https_uri() {
        validate_event_source("https://example.com/agent").unwrap();
    }

    #[test]
    fn event_source_accepts_scheme_with_plus_dash_dot() {
        validate_event_source("svn+ssh://host/repo").unwrap();
        validate_event_source("a.b-c://host/path").unwrap();
    }

    #[test]
    fn event_source_ignores_surrounding_whitespace() {
        validate_event_source("  assay://myorg/myapp\n").unwrap();
    }

    #[test]
    fn event_source_rejects_empty() {
        assert!(validate_event_source("").is_err());
        assert!(validate_event_source("   ").is_err());
    }

    #[test]
    fn event_source_rejects_whitespace() {
        assert!(validate_event_source("assay://myorg/my app").is_err());
        assert!(validate_event_source("assay://myorg/\tmyapp").is_err());
    }

    #[test]
    fn event_source_rejects_missing_scheme() {
        assert!(validate_event_source("myorg/myapp").is_err());
        assert!(validate_event_source("://myorg/myapp").is_err());
    }

    #[test]
    fn event_source_rejects_did_and_urn() {
        // We require scheme:// not just scheme:
        assert!(validate_event_source("did:example:123").is_err());
        assert!(validate_event_source("urn:example:foo").is_err());
    }

    #[test]
    fn event_source_rejects_scheme_starting_with_non_letter() {
        assert!(validate_event_source("1assay://myorg/myapp").is_err());
        assert!(validate_event_source("-assay://myorg/myapp").is_err());
    }

    #[test]
    fn event_source_rejects_scheme_with_invalid_chars() {
        assert!(validate_event_source("as_say://myorg/myapp").is_err());
        assert!(validate_event_source("as@say://myorg/myapp").is_err());
    }

    #[test]
    fn config_requires_event_source_when_logging_enabled() {
        let err = ProxyConfig::try_from_raw(raw_with(Some("decisions.ndjson"), None)).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("event_source is required"));
    }

    #[test]
    fn config_requires_event_source_when_only_audit_log_enabled() {
        let raw = ProxyConfigRaw {
            audit_log_path: Some(std::path::PathBuf::from("audit.ndjson")),
            ..raw_with(None, None)
        };
        let err = ProxyConfig::try_from_raw(raw).unwrap_err();
        assert!(format!("{err:#}").contains("event_source is required"));
    }

    #[test]
    fn config_allows_no_event_source_when_logging_disabled() {
        ProxyConfig::try_from_raw(raw_with(None, None)).unwrap();
    }

    #[test]
    fn config_treats_blank_event_source_as_missing() {
        let err = ProxyConfig::try_from_raw(raw_with(Some("decisions.ndjson"), Some("   ")))
            .unwrap_err();
        assert!(format!("{err:#}").contains("event_source is required"));

        let cfg = ProxyConfig::try_from_raw(raw_with(None, Some("   "))).unwrap();
        assert_eq!(cfg.event_source, None);
    }

    #[test]
    fn config_accepts_valid_event_source() {
        let cfg = ProxyConfig::try_from_raw(raw_with(
            Some("decisions.ndjson"),
            Some("assay://myorg/myapp"),
        ))
        .unwrap();
        assert_eq!(cfg.event_source.as_deref(), Some("assay://myorg/myapp"));
    }

    #[test]
    fn config_trims_event_source() {
        let cfg = ProxyConfig::try_from_raw(raw_with(
            Some("decisions.ndjson"),
            Some("  assay://myorg/myapp  "),
        ))
        .unwrap();
        assert_eq!(cfg.event_source.as_deref(), Some("assay://myorg/myapp"));
    }

    #[test]
    fn config_rejects_invalid_event_source_uri() {
        let raw = raw_with(Some("decisions.ndjson"), Some("not a uri"));
        assert!(ProxyConfig::try_from_raw(raw).is_err());
    }
}
609        };
610
611        assert!(ProxyConfig::try_from_raw(raw).is_err());
612    }
613}