assay_core/mcp/
proxy.rs

1use super::audit::{AuditEvent, AuditLog};
2use super::jsonrpc::JsonRpcRequest;
3use super::policy::{make_deny_response, McpPolicy, PolicyDecision, PolicyState};
4use std::{
5    io::{self, BufRead, BufReader, Write},
6    process::{Child, Command, Stdio},
7    sync::{Arc, Mutex},
8    thread,
9};
10
/// Runtime options for [`McpProxy`].
///
/// `Default` yields a proxy that enforces policy (no dry run), stays quiet on
/// stderr, and is given no audit log path.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ProxyConfig {
    /// When `true`, requests the policy denies are recorded as `would_deny`
    /// and still forwarded to the server instead of being blocked.
    pub dry_run: bool,
    /// When `true`, policy decisions are also reported on stderr.
    pub verbose: bool,
    /// Path handed to the audit log when the proxy starts.
    // NOTE(review): presumably `None` disables file auditing — confirm in `AuditLog::new`.
    pub audit_log_path: Option<std::path::PathBuf>,
}
17
18pub struct McpProxy {
19    child: Child,
20    policy: McpPolicy,
21    config: ProxyConfig,
22}
23
24impl Drop for McpProxy {
25    fn drop(&mut self) {
26        // Best-effort cleanup
27        let _ = self.child.kill();
28    }
29}
30
31impl McpProxy {
32    pub fn spawn(
33        command: &str,
34        args: &[String],
35        policy: McpPolicy,
36        config: ProxyConfig,
37    ) -> io::Result<Self> {
38        let child = Command::new(command)
39            .args(args)
40            .stdin(Stdio::piped())
41            .stdout(Stdio::piped())
42            .stderr(Stdio::inherit()) // protocol blijft op stdout
43            .spawn()?;
44
45        Ok(Self {
46            child,
47            policy,
48            config,
49        })
50    }
51
52    pub fn run(mut self) -> io::Result<i32> {
53        let mut child_stdin = self.child.stdin.take().expect("child stdin");
54        let child_stdout = self.child.stdout.take().expect("child stdout");
55
56        let stdout = Arc::new(Mutex::new(io::stdout()));
57        let policy = self.policy.clone();
58        let config = self.config.clone();
59
60        // Thread A: server -> client passthrough
61        let stdout_a = stdout.clone();
62        let t_server_to_client = thread::spawn(move || -> io::Result<()> {
63            let mut reader = BufReader::new(child_stdout);
64            let mut line = String::new();
65
66            while reader.read_line(&mut line)? > 0 {
67                let mut out = stdout_a
68                    .lock()
69                    .map_err(|e| io::Error::other(e.to_string()))?;
70                out.write_all(line.as_bytes())?;
71                out.flush()?;
72                line.clear();
73            }
74            Ok(())
75        });
76
77        // Thread B: client -> server passthrough with Policy Check
78        let stdout_b = stdout.clone();
79        let t_client_to_server = thread::spawn(move || -> io::Result<()> {
80            let stdin = io::stdin();
81            let mut reader = stdin.lock();
82            let mut line = String::new();
83
84            let mut state = PolicyState::default();
85            let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
86
87            while reader.read_line(&mut line)? > 0 {
88                // 1. Try Parse as MCP Request
89                match serde_json::from_str::<JsonRpcRequest>(&line) {
90                    Ok(req) => {
91                        // 2. Check Policy
92                        match policy.check(&req, &mut state) {
93                            PolicyDecision::Allow => {
94                                Self::handle_allow(&req, &mut audit_log, config.verbose);
95                            }
96                            PolicyDecision::AllowWithWarning { tool, code, reason } => {
97                                // Log warning about allowing a tool invocation with issues
98                                if config.verbose {
99                                    eprintln!(
100                                        "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
101                                        tool,
102                                        code,
103                                        reason
104                                    );
105                                }
106                                audit_log.log(&AuditEvent {
107                                    timestamp: chrono::Utc::now().to_rfc3339(),
108                                    decision: "allow_with_warning".to_string(),
109                                    tool: Some(tool.clone()),
110                                    reason: Some(reason.clone()),
111                                    request_id: req.id.clone(),
112                                    agentic: None,
113                                });
114                                // Then proceed as a normal allow
115                                Self::handle_allow(&req, &mut audit_log, false);
116                                // false = don't double log ALLOW
117                            }
118                            PolicyDecision::Deny {
119                                tool,
120                                code: _,
121                                reason,
122                                contract,
123                            } => {
124                                // Log Decision
125                                let decision_str =
126                                    if config.dry_run { "would_deny" } else { "deny" };
127
128                                if config.verbose {
129                                    eprintln!(
130                                        "[assay] {} {} (reason: {})",
131                                        decision_str.to_uppercase(),
132                                        tool,
133                                        reason
134                                    );
135                                }
136
137                                audit_log.log(&AuditEvent {
138                                    timestamp: chrono::Utc::now().to_rfc3339(),
139                                    decision: decision_str.to_string(),
140                                    tool: Some(tool.clone()),
141                                    reason: Some(reason.clone()),
142                                    request_id: req.id.clone(),
143                                    agentic: Some(contract.clone()),
144                                });
145
146                                if config.dry_run {
147                                    // DRY RUN: Forward anyway
148                                    // Fallthrough to forward logic below
149                                } else {
150                                    // BLOCK: Send error response
151                                    let id = req.id.unwrap_or(serde_json::Value::Null);
152                                    let response_json = make_deny_response(
153                                        id,
154                                        "Content blocked by policy",
155                                        contract,
156                                    );
157
158                                    let mut out = stdout_b
159                                        .lock()
160                                        .map_err(|e| io::Error::other(e.to_string()))?;
161                                    out.write_all(response_json.as_bytes())?;
162                                    out.flush()?;
163
164                                    line.clear();
165                                    continue; // Skip forwarding
166                                }
167                            }
168                        }
169                    }
170                    Err(_) => {
171                        // Hardening: Suspicious Unparsable JSON
172                        let trimmed = line.trim();
173                        if trimmed.starts_with('{')
174                            && (trimmed.contains("\"method\"")
175                                || trimmed.contains("\"params\"")
176                                || trimmed.contains("\"tool\""))
177                        {
178                            eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
179                        }
180                    }
181                }
182
183                // 3. Forward
184                child_stdin.write_all(line.as_bytes())?;
185                child_stdin.flush()?;
186                line.clear();
187            }
188            Ok(())
189        });
190
191        // Wacht tot client->server eindigt (stdin closed)
192        t_client_to_server
193            .join()
194            .map_err(|_| io::Error::other("client->server thread panicked"))??;
195
196        // Server->client thread kan nog even lopen; join best-effort
197        let _ = t_server_to_client.join();
198
199        // Wacht op child exit
200        let status = self.child.wait()?;
201        Ok(status.code().unwrap_or(1))
202    }
203
204    fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
205        if verbose && req.is_tool_call() {
206            let tool = req
207                .tool_params()
208                .map(|p| p.name)
209                .unwrap_or_else(|| "unknown".to_string());
210            eprintln!("[assay] ALLOW {}", tool);
211        }
212
213        if req.is_tool_call() {
214            let tool = req.tool_params().map(|p| p.name);
215            audit_log.log(&AuditEvent {
216                timestamp: chrono::Utc::now().to_rfc3339(),
217                decision: "allow".to_string(),
218                tool,
219                reason: None,
220                request_id: req.id.clone(),
221                agentic: None,
222            });
223        }
224    }
225}