Skip to main content

cc_audit/proxy/
server.rs

1//! Async TCP proxy server for MCP message interception.
2
3use super::{InterceptAction, MessageInterceptor, ProxyConfig, ProxyLogger};
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{TcpListener, TcpStream};
7
8/// TCP proxy server for MCP message interception.
9pub struct ProxyServer {
10    config: ProxyConfig,
11    interceptor: Arc<MessageInterceptor>,
12    logger: Arc<ProxyLogger>,
13}
14
15impl ProxyServer {
16    /// Create a new proxy server with the given configuration.
17    pub fn new(config: ProxyConfig) -> std::io::Result<Self> {
18        let interceptor = Arc::new(MessageInterceptor::new(
19            config.block_mode,
20            config.min_block_severity,
21        ));
22
23        let logger = Arc::new(ProxyLogger::new(
24            config.log_file.as_deref(),
25            config.verbose,
26        )?);
27
28        Ok(Self {
29            config,
30            interceptor,
31            logger,
32        })
33    }
34
35    /// Run the proxy server.
36    pub async fn run(&self) -> std::io::Result<()> {
37        let listener = TcpListener::bind(self.config.listen_addr).await?;
38
39        eprintln!(
40            "Proxy listening on {} -> {}",
41            self.config.listen_addr, self.config.target_addr
42        );
43
44        if self.config.block_mode {
45            eprintln!(
46                "Block mode enabled (min severity: {:?})",
47                self.config.min_block_severity
48            );
49        } else {
50            eprintln!("Log-only mode (no blocking)");
51        }
52
53        loop {
54            let (client_stream, client_addr) = listener.accept().await?;
55
56            let target_addr = self.config.target_addr;
57            let interceptor = Arc::clone(&self.interceptor);
58            let logger = Arc::clone(&self.logger);
59            let block_mode = self.config.block_mode;
60
61            tokio::spawn(async move {
62                if let Err(e) = handle_connection(
63                    client_stream,
64                    target_addr,
65                    interceptor,
66                    logger,
67                    block_mode,
68                    client_addr.to_string(),
69                )
70                .await
71                {
72                    eprintln!("Connection error: {}", e);
73                }
74            });
75        }
76    }
77}
78
79/// Handle a single client connection.
80async fn handle_connection(
81    client: TcpStream,
82    target_addr: std::net::SocketAddr,
83    interceptor: Arc<MessageInterceptor>,
84    logger: Arc<ProxyLogger>,
85    block_mode: bool,
86    client_addr: String,
87) -> std::io::Result<()> {
88    // Connect to target
89    let target = TcpStream::connect(target_addr).await?;
90
91    // Split into owned halves
92    let (client_read, client_write) = client.into_split();
93    let (target_read, target_write) = target.into_split();
94
95    let interceptor_req = Arc::clone(&interceptor);
96    let interceptor_resp = Arc::clone(&interceptor);
97    let logger_req = Arc::clone(&logger);
98    let logger_resp = Arc::clone(&logger);
99    let client_addr_req = client_addr.clone();
100    let client_addr_resp = client_addr;
101
102    // Wrap writes in Arc<Mutex> for shared access
103    let client_write = Arc::new(tokio::sync::Mutex::new(client_write));
104    let target_write = Arc::new(tokio::sync::Mutex::new(target_write));
105
106    let client_write_clone = Arc::clone(&client_write);
107
108    // Forward client -> target
109    let client_to_target = async move {
110        let mut client_read = client_read;
111        let mut buf = vec![0u8; 65536];
112        loop {
113            let n = client_read.read(&mut buf).await?;
114            if n == 0 {
115                break;
116            }
117
118            let data = &buf[..n];
119
120            // Intercept and analyze
121            let action = interceptor_req.intercept(data);
122            let method = extract_method(data);
123
124            match &action {
125                InterceptAction::Allow => {
126                    target_write.lock().await.write_all(data).await?;
127                }
128                InterceptAction::Log(findings) => {
129                    logger_req.log_request(
130                        method.as_deref(),
131                        findings,
132                        "logged",
133                        Some(&client_addr_req),
134                        n,
135                    );
136                    target_write.lock().await.write_all(data).await?;
137                }
138                InterceptAction::Block(findings) => {
139                    logger_req.log_request(
140                        method.as_deref(),
141                        findings,
142                        "blocked",
143                        Some(&client_addr_req),
144                        n,
145                    );
146
147                    if block_mode {
148                        // Send error response to client
149                        let error_response = create_error_response(findings);
150                        client_write
151                            .lock()
152                            .await
153                            .write_all(error_response.as_bytes())
154                            .await?;
155                        break;
156                    } else {
157                        target_write.lock().await.write_all(data).await?;
158                    }
159                }
160            }
161        }
162        Ok::<_, std::io::Error>(())
163    };
164
165    // Forward target -> client
166    let target_to_client = async move {
167        let mut target_read = target_read;
168        let mut buf = vec![0u8; 65536];
169        loop {
170            let n = target_read.read(&mut buf).await?;
171            if n == 0 {
172                break;
173            }
174
175            let data = &buf[..n];
176
177            // Intercept and analyze response
178            let action = interceptor_resp.intercept(data);
179            let method = extract_method(data);
180
181            match &action {
182                InterceptAction::Allow => {
183                    client_write_clone.lock().await.write_all(data).await?;
184                }
185                InterceptAction::Log(findings) => {
186                    logger_resp.log_response(
187                        method.as_deref(),
188                        findings,
189                        "logged",
190                        Some(&client_addr_resp),
191                        n,
192                    );
193                    client_write_clone.lock().await.write_all(data).await?;
194                }
195                InterceptAction::Block(findings) => {
196                    logger_resp.log_response(
197                        method.as_deref(),
198                        findings,
199                        "blocked",
200                        Some(&client_addr_resp),
201                        n,
202                    );
203
204                    if block_mode {
205                        // Don't forward blocked response
206                        let error_response = create_error_response(findings);
207                        client_write_clone
208                            .lock()
209                            .await
210                            .write_all(error_response.as_bytes())
211                            .await?;
212                        break;
213                    } else {
214                        client_write_clone.lock().await.write_all(data).await?;
215                    }
216                }
217            }
218        }
219        Ok::<_, std::io::Error>(())
220    };
221
222    // Run both directions concurrently
223    tokio::select! {
224        result = client_to_target => result?,
225        result = target_to_client => result?,
226    }
227
228    Ok(())
229}
230
231/// Extract the JSON-RPC method from a message.
232fn extract_method(data: &[u8]) -> Option<String> {
233    let json: serde_json::Value = serde_json::from_slice(data).ok()?;
234    json.get("method")
235        .and_then(|m| m.as_str())
236        .map(|s| s.to_string())
237}
238
239/// Create a JSON-RPC error response for blocked messages.
240fn create_error_response(findings: &[crate::rules::Finding]) -> String {
241    let messages: Vec<String> = findings.iter().map(|f| f.message.clone()).collect();
242    let error_msg = if messages.is_empty() {
243        "Request blocked by security policy".to_string()
244    } else {
245        format!("Request blocked: {}", messages.join("; "))
246    };
247
248    serde_json::json!({
249        "jsonrpc": "2.0",
250        "error": {
251            "code": -32600,
252            "message": error_msg
253        },
254        "id": null
255    })
256    .to_string()
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proxy::ProxyConfig;
    use crate::test_utils::fixtures::create_finding;

    /// Parse an error response back into JSON so its fields can be checked exactly.
    fn parse(response: &str) -> serde_json::Value {
        serde_json::from_str(response).expect("error response must be valid JSON")
    }

    #[test]
    fn test_extract_method() {
        let data = br#"{"jsonrpc":"2.0","method":"tools/call","id":1}"#;
        let method = extract_method(data);
        assert_eq!(method, Some("tools/call".to_string()));
    }

    #[test]
    fn test_extract_method_trailing_newline() {
        // Newline-terminated messages are still valid JSON to serde_json.
        let data = b"{\"jsonrpc\":\"2.0\",\"method\":\"ping\",\"id\":1}\n";
        assert_eq!(extract_method(data), Some("ping".to_string()));
    }

    #[test]
    fn test_extract_method_no_method() {
        let data = br#"{"jsonrpc":"2.0","result":{},"id":1}"#;
        let method = extract_method(data);
        assert!(method.is_none());
    }

    #[test]
    fn test_extract_method_invalid_json() {
        let data = b"not valid json";
        let method = extract_method(data);
        assert!(method.is_none());
    }

    #[test]
    fn test_extract_method_method_not_string() {
        let data = br#"{"jsonrpc":"2.0","method":123,"id":1}"#;
        let method = extract_method(data);
        assert!(method.is_none());
    }

    #[test]
    fn test_extract_method_non_object_json() {
        assert!(extract_method(br#"[{"method":"tools/call"}]"#).is_none());
        assert!(extract_method(b"42").is_none());
    }

    #[test]
    fn test_create_error_response() {
        let findings: Vec<crate::rules::Finding> = Vec::new();
        let response = create_error_response(&findings);

        assert!(response.contains("blocked by security policy"));
        assert!(response.contains("-32600"));

        // The reply must be a well-formed JSON-RPC 2.0 error object.
        let json = parse(&response);
        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["error"]["code"], -32600);
        assert_eq!(json["error"]["message"], "Request blocked by security policy");
        assert!(json["id"].is_null());
    }

    #[test]
    fn test_create_error_response_with_findings() {
        use crate::rules::{Category, Severity};

        let findings = vec![
            create_finding(
                "EX-001",
                Severity::High,
                Category::Exfiltration,
                "test",
                "test.md",
                1,
            ),
            create_finding(
                "PI-001",
                Severity::Medium,
                Category::PromptInjection,
                "test2",
                "test.md",
                2,
            ),
        ];

        let response = create_error_response(&findings);

        assert!(response.contains("Request blocked:"));
        assert!(response.contains("test message"));
        assert!(response.contains("-32600"));

        let json = parse(&response);
        assert_eq!(json["error"]["code"], -32600);
        assert!(json["id"].is_null());
        let message = json["error"]["message"]
            .as_str()
            .expect("error message must be a string");
        assert!(message.starts_with("Request blocked: "));
        // Messages from multiple findings are joined with "; ".
        assert!(message.contains("; "));
    }

    #[test]
    fn test_proxy_server_new() {
        let config = ProxyConfig::default();
        let server = ProxyServer::new(config);

        assert!(server.is_ok());
    }

    #[test]
    fn test_proxy_server_new_with_verbose() {
        let config = ProxyConfig::default().with_verbose();
        let server = ProxyServer::new(config);

        assert!(server.is_ok());
    }

    #[test]
    fn test_proxy_server_new_with_log_file() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let log_path = temp_dir.path().join("proxy.log");

        let config = ProxyConfig::default().with_log_file(log_path);
        let server = ProxyServer::new(config);

        assert!(server.is_ok());
    }

    #[test]
    fn test_proxy_server_new_with_block_mode() {
        use crate::Severity;

        let config = ProxyConfig::default().with_block_mode(Severity::High);
        let server = ProxyServer::new(config);

        assert!(server.is_ok());
    }
}