Skip to main content

assay_core/mcp/
proxy.rs

1use super::audit::{AuditEvent, AuditLog};
2use super::jsonrpc::JsonRpcRequest;
3use super::policy::{make_deny_response, McpPolicy, PolicyDecision, PolicyState};
4use std::{
5    collections::HashMap,
6    io::{self, BufRead, BufReader, Write},
7    process::{Child, Command, Stdio},
8    sync::{Arc, Mutex},
9    thread,
10};
11
/// Runtime options for [`McpProxy`].
#[derive(Debug, Clone, Default)]
pub struct ProxyConfig {
    /// When set, policy denials are recorded as `would_deny` and the request is still
    /// forwarded to the server instead of being answered with an error.
    pub dry_run: bool,
    /// When set, allow / warning / deny decisions are printed to stderr.
    pub verbose: bool,
    /// Location handed to `AuditLog::new`; how `None` is treated is decided there.
    pub audit_log_path: Option<std::path::PathBuf>,
    /// Identifier of the wrapped server, passed to `ToolIdentity::new` for every tool
    /// discovered in a `tools/list` response.
    pub server_id: String,
}
19
20pub struct McpProxy {
21    child: Child,
22    policy: McpPolicy,
23    config: ProxyConfig,
24    /// Cache of tool identities discovered during tools/list
25    identity_cache: Arc<Mutex<HashMap<String, super::identity::ToolIdentity>>>,
26}
27
28impl Drop for McpProxy {
29    fn drop(&mut self) {
30        // Best-effort cleanup
31        let _ = self.child.kill();
32    }
33}
34
35impl McpProxy {
36    pub fn spawn(
37        command: &str,
38        args: &[String],
39        policy: McpPolicy,
40        config: ProxyConfig,
41    ) -> io::Result<Self> {
42        let child = Command::new(command)
43            .args(args)
44            .stdin(Stdio::piped())
45            .stdout(Stdio::piped())
46            .stderr(Stdio::inherit()) // protocol blijft op stdout
47            .spawn()?;
48
49        Ok(Self {
50            child,
51            policy,
52            config,
53            identity_cache: Arc::new(Mutex::new(HashMap::new())),
54        })
55    }
56
57    pub fn run(mut self) -> io::Result<i32> {
58        let mut child_stdin = self.child.stdin.take().expect("child stdin");
59        let child_stdout = self.child.stdout.take().expect("child stdout");
60
61        let stdout = Arc::new(Mutex::new(io::stdout()));
62        let policy = self.policy.clone();
63        let config = self.config.clone();
64        let identity_cache_a = self.identity_cache.clone();
65        let identity_cache_b = self.identity_cache.clone();
66
67        // Thread A: server -> client passthrough
68        let stdout_a = stdout.clone();
69        let t_server_to_client = thread::spawn(move || -> io::Result<()> {
70            let mut reader = BufReader::new(child_stdout);
71            let mut line = String::new();
72
73            while reader.read_line(&mut line)? > 0 {
74                let mut processed_line = line.clone();
75
76                // Phase 9: Compute Identities on tools/list response
77                if let Ok(mut v) = serde_json::from_str::<serde_json::Value>(&line) {
78                    if let Some(result) = v.get_mut("result") {
79                        if let Some(tools) = result.get_mut("tools").and_then(|t| t.as_array_mut())
80                        {
81                            for tool in tools {
82                                let name = tool
83                                    .get("name")
84                                    .and_then(|n| n.as_str())
85                                    .unwrap_or("unknown");
86                                let description = tool
87                                    .get("description")
88                                    .and_then(|d| d.as_str())
89                                    .map(|s| s.to_string());
90                                let input_schema = tool
91                                    .get("inputSchema")
92                                    .or_else(|| tool.get("input_schema"))
93                                    .cloned();
94
95                                let identity = super::identity::ToolIdentity::new(
96                                    &config.server_id,
97                                    name,
98                                    &input_schema,
99                                    &description,
100                                );
101
102                                // Cache for runtime verification
103                                let mut cache = identity_cache_a.lock().unwrap();
104                                cache.insert(name.to_string(), identity.clone());
105
106                                // Augment the response with the computed identity for downstream/logging
107                                tool.as_object_mut().and_then(|m| {
108                                    m.insert(
109                                        "tool_identity".to_string(),
110                                        serde_json::to_value(&identity).unwrap(),
111                                    )
112                                });
113                            }
114                            processed_line =
115                                serde_json::to_string(&v).unwrap_or(line.clone()) + "\n";
116                        }
117                    }
118                }
119
120                let mut out = stdout_a
121                    .lock()
122                    .map_err(|e| io::Error::other(e.to_string()))?;
123                out.write_all(processed_line.as_bytes())?;
124                out.flush()?;
125                line.clear();
126            }
127            Ok(())
128        });
129
130        // Thread B: client -> server passthrough with Policy Check
131        let stdout_b = stdout.clone();
132        let t_client_to_server = thread::spawn(move || -> io::Result<()> {
133            let stdin = io::stdin();
134            let mut reader = stdin.lock();
135            let mut line = String::new();
136
137            let mut state = PolicyState::default();
138            let mut audit_log = AuditLog::new(config.audit_log_path.as_deref());
139
140            while reader.read_line(&mut line)? > 0 {
141                // 1. Try Parse as MCP Request
142                match serde_json::from_str::<JsonRpcRequest>(&line) {
143                    Ok(req) => {
144                        // 2. Check Policy with Identity (Phase 9)
145                        let runtime_id = if req.is_tool_call() {
146                            let name = req.tool_params().map(|p| p.name).unwrap_or_default();
147                            let cache = identity_cache_b.lock().unwrap();
148                            cache.get(&name).cloned()
149                        } else {
150                            None
151                        };
152
153                        match policy.evaluate(
154                            &req.tool_params().map(|p| p.name).unwrap_or_default(),
155                            &req.tool_params()
156                                .map(|p| p.arguments)
157                                .unwrap_or(serde_json::Value::Null),
158                            &mut state,
159                            runtime_id.as_ref(),
160                        ) {
161                            PolicyDecision::Allow => {
162                                Self::handle_allow(&req, &mut audit_log, config.verbose);
163                            }
164                            PolicyDecision::AllowWithWarning { tool, code, reason } => {
165                                // Log warning about allowing a tool invocation with issues
166                                if config.verbose {
167                                    eprintln!(
168                                        "[assay] WARNING: Allowing tool '{}' with warning (code: {}, reason: {}).",
169                                        tool,
170                                        code,
171                                        reason
172                                    );
173                                }
174                                audit_log.log(&AuditEvent {
175                                    timestamp: chrono::Utc::now().to_rfc3339(),
176                                    decision: "allow_with_warning".to_string(),
177                                    tool: Some(tool.clone()),
178                                    reason: Some(reason.clone()),
179                                    request_id: req.id.clone(),
180                                    agentic: None,
181                                });
182                                // Then proceed as a normal allow
183                                Self::handle_allow(&req, &mut audit_log, false);
184                                // false = don't double log ALLOW
185                            }
186                            PolicyDecision::Deny {
187                                tool,
188                                code: _,
189                                reason,
190                                contract,
191                            } => {
192                                // Log Decision
193                                let decision_str =
194                                    if config.dry_run { "would_deny" } else { "deny" };
195
196                                if config.verbose {
197                                    eprintln!(
198                                        "[assay] {} {} (reason: {})",
199                                        decision_str.to_uppercase(),
200                                        tool,
201                                        reason
202                                    );
203                                }
204
205                                audit_log.log(&AuditEvent {
206                                    timestamp: chrono::Utc::now().to_rfc3339(),
207                                    decision: decision_str.to_string(),
208                                    tool: Some(tool.clone()),
209                                    reason: Some(reason.clone()),
210                                    request_id: req.id.clone(),
211                                    agentic: Some(contract.clone()),
212                                });
213
214                                if config.dry_run {
215                                    // DRY RUN: Forward anyway
216                                    // Fallthrough to forward logic below
217                                } else {
218                                    // BLOCK: Send error response
219                                    let id = req.id.unwrap_or(serde_json::Value::Null);
220                                    let response_json = make_deny_response(
221                                        id,
222                                        "Content blocked by policy",
223                                        contract,
224                                    );
225
226                                    let mut out = stdout_b
227                                        .lock()
228                                        .map_err(|e| io::Error::other(e.to_string()))?;
229                                    out.write_all(response_json.as_bytes())?;
230                                    out.flush()?;
231
232                                    line.clear();
233                                    continue; // Skip forwarding
234                                }
235                            }
236                        }
237                    }
238                    Err(_) => {
239                        // Hardening: Suspicious Unparsable JSON
240                        let trimmed = line.trim();
241                        if trimmed.starts_with('{')
242                            && (trimmed.contains("\"method\"")
243                                || trimmed.contains("\"params\"")
244                                || trimmed.contains("\"tool\""))
245                        {
246                            eprintln!("[assay] WARNING: Suspicious unparsable JSON, forwarding anyway (potential bypass attempt?): {:.60}...", trimmed);
247                        }
248                    }
249                }
250
251                // 3. Forward
252                child_stdin.write_all(line.as_bytes())?;
253                child_stdin.flush()?;
254                line.clear();
255            }
256            Ok(())
257        });
258
259        // Wacht tot client->server eindigt (stdin closed)
260        t_client_to_server
261            .join()
262            .map_err(|_| io::Error::other("client->server thread panicked"))??;
263
264        // Server->client thread kan nog even lopen; join best-effort
265        let _ = t_server_to_client.join();
266
267        // Wacht op child exit
268        let status = self.child.wait()?;
269        Ok(status.code().unwrap_or(1))
270    }
271
272    fn handle_allow(req: &JsonRpcRequest, audit_log: &mut AuditLog, verbose: bool) {
273        if verbose && req.is_tool_call() {
274            let tool = req
275                .tool_params()
276                .map(|p| p.name)
277                .unwrap_or_else(|| "unknown".to_string());
278            eprintln!("[assay] ALLOW {}", tool);
279        }
280
281        if req.is_tool_call() {
282            let tool = req.tool_params().map(|p| p.name);
283            audit_log.log(&AuditEvent {
284                timestamp: chrono::Utc::now().to_rfc3339(),
285                decision: "allow".to_string(),
286                tool,
287                reason: None,
288                request_id: req.id.clone(),
289                agentic: None,
290            });
291        }
292    }
293}