1use super::{InterceptAction, MessageInterceptor, ProxyConfig, ProxyLogger};
4use std::sync::Arc;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{TcpListener, TcpStream};
7
8pub struct ProxyServer {
10 config: ProxyConfig,
11 interceptor: Arc<MessageInterceptor>,
12 logger: Arc<ProxyLogger>,
13}
14
15impl ProxyServer {
16 pub fn new(config: ProxyConfig) -> std::io::Result<Self> {
18 let interceptor = Arc::new(MessageInterceptor::new(
19 config.block_mode,
20 config.min_block_severity,
21 ));
22
23 let logger = Arc::new(ProxyLogger::new(
24 config.log_file.as_deref(),
25 config.verbose,
26 )?);
27
28 Ok(Self {
29 config,
30 interceptor,
31 logger,
32 })
33 }
34
35 pub async fn run(&self) -> std::io::Result<()> {
37 let listener = TcpListener::bind(self.config.listen_addr).await?;
38
39 eprintln!(
40 "Proxy listening on {} -> {}",
41 self.config.listen_addr, self.config.target_addr
42 );
43
44 if self.config.block_mode {
45 eprintln!(
46 "Block mode enabled (min severity: {:?})",
47 self.config.min_block_severity
48 );
49 } else {
50 eprintln!("Log-only mode (no blocking)");
51 }
52
53 loop {
54 let (client_stream, client_addr) = listener.accept().await?;
55
56 let target_addr = self.config.target_addr;
57 let interceptor = Arc::clone(&self.interceptor);
58 let logger = Arc::clone(&self.logger);
59 let block_mode = self.config.block_mode;
60
61 tokio::spawn(async move {
62 if let Err(e) = handle_connection(
63 client_stream,
64 target_addr,
65 interceptor,
66 logger,
67 block_mode,
68 client_addr.to_string(),
69 )
70 .await
71 {
72 eprintln!("Connection error: {}", e);
73 }
74 });
75 }
76 }
77}
78
79async fn handle_connection(
81 client: TcpStream,
82 target_addr: std::net::SocketAddr,
83 interceptor: Arc<MessageInterceptor>,
84 logger: Arc<ProxyLogger>,
85 block_mode: bool,
86 client_addr: String,
87) -> std::io::Result<()> {
88 let target = TcpStream::connect(target_addr).await?;
90
91 let (client_read, client_write) = client.into_split();
93 let (target_read, target_write) = target.into_split();
94
95 let interceptor_req = Arc::clone(&interceptor);
96 let interceptor_resp = Arc::clone(&interceptor);
97 let logger_req = Arc::clone(&logger);
98 let logger_resp = Arc::clone(&logger);
99 let client_addr_req = client_addr.clone();
100 let client_addr_resp = client_addr;
101
102 let client_write = Arc::new(tokio::sync::Mutex::new(client_write));
104 let target_write = Arc::new(tokio::sync::Mutex::new(target_write));
105
106 let client_write_clone = Arc::clone(&client_write);
107
108 let client_to_target = async move {
110 let mut client_read = client_read;
111 let mut buf = vec![0u8; 65536];
112 loop {
113 let n = client_read.read(&mut buf).await?;
114 if n == 0 {
115 break;
116 }
117
118 let data = &buf[..n];
119
120 let action = interceptor_req.intercept(data);
122 let method = extract_method(data);
123
124 match &action {
125 InterceptAction::Allow => {
126 target_write.lock().await.write_all(data).await?;
127 }
128 InterceptAction::Log(findings) => {
129 logger_req.log_request(
130 method.as_deref(),
131 findings,
132 "logged",
133 Some(&client_addr_req),
134 n,
135 );
136 target_write.lock().await.write_all(data).await?;
137 }
138 InterceptAction::Block(findings) => {
139 logger_req.log_request(
140 method.as_deref(),
141 findings,
142 "blocked",
143 Some(&client_addr_req),
144 n,
145 );
146
147 if block_mode {
148 let error_response = create_error_response(findings);
150 client_write
151 .lock()
152 .await
153 .write_all(error_response.as_bytes())
154 .await?;
155 break;
156 } else {
157 target_write.lock().await.write_all(data).await?;
158 }
159 }
160 }
161 }
162 Ok::<_, std::io::Error>(())
163 };
164
165 let target_to_client = async move {
167 let mut target_read = target_read;
168 let mut buf = vec![0u8; 65536];
169 loop {
170 let n = target_read.read(&mut buf).await?;
171 if n == 0 {
172 break;
173 }
174
175 let data = &buf[..n];
176
177 let action = interceptor_resp.intercept(data);
179 let method = extract_method(data);
180
181 match &action {
182 InterceptAction::Allow => {
183 client_write_clone.lock().await.write_all(data).await?;
184 }
185 InterceptAction::Log(findings) => {
186 logger_resp.log_response(
187 method.as_deref(),
188 findings,
189 "logged",
190 Some(&client_addr_resp),
191 n,
192 );
193 client_write_clone.lock().await.write_all(data).await?;
194 }
195 InterceptAction::Block(findings) => {
196 logger_resp.log_response(
197 method.as_deref(),
198 findings,
199 "blocked",
200 Some(&client_addr_resp),
201 n,
202 );
203
204 if block_mode {
205 let error_response = create_error_response(findings);
207 client_write_clone
208 .lock()
209 .await
210 .write_all(error_response.as_bytes())
211 .await?;
212 break;
213 } else {
214 client_write_clone.lock().await.write_all(data).await?;
215 }
216 }
217 }
218 }
219 Ok::<_, std::io::Error>(())
220 };
221
222 tokio::select! {
224 result = client_to_target => result?,
225 result = target_to_client => result?,
226 }
227
228 Ok(())
229}
230
231fn extract_method(data: &[u8]) -> Option<String> {
233 let json: serde_json::Value = serde_json::from_slice(data).ok()?;
234 json.get("method")
235 .and_then(|m| m.as_str())
236 .map(|s| s.to_string())
237}
238
239fn create_error_response(findings: &[crate::rules::Finding]) -> String {
241 let messages: Vec<String> = findings.iter().map(|f| f.message.clone()).collect();
242 let error_msg = if messages.is_empty() {
243 "Request blocked by security policy".to_string()
244 } else {
245 format!("Request blocked: {}", messages.join("; "))
246 };
247
248 serde_json::json!({
249 "jsonrpc": "2.0",
250 "error": {
251 "code": -32600,
252 "message": error_msg
253 },
254 "id": null
255 })
256 .to_string()
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proxy::ProxyConfig;
    use crate::test_utils::fixtures::create_finding;

    // --- extract_method -------------------------------------------------

    /// A string `method` field is returned verbatim.
    #[test]
    fn test_extract_method() {
        let payload = br#"{"jsonrpc":"2.0","method":"tools/call","id":1}"#;
        assert_eq!(extract_method(payload).as_deref(), Some("tools/call"));
    }

    /// JSON without a `method` key yields `None`.
    #[test]
    fn test_extract_method_no_method() {
        let payload = br#"{"jsonrpc":"2.0","result":{},"id":1}"#;
        assert_eq!(extract_method(payload), None);
    }

    /// Input that is not JSON yields `None`.
    #[test]
    fn test_extract_method_invalid_json() {
        assert_eq!(extract_method(b"not valid json"), None);
    }

    /// A non-string `method` value yields `None`.
    #[test]
    fn test_extract_method_method_not_string() {
        let payload = br#"{"jsonrpc":"2.0","method":123,"id":1}"#;
        assert_eq!(extract_method(payload), None);
    }

    // --- create_error_response ------------------------------------------

    /// With no findings the generic policy message is used.
    #[test]
    fn test_create_error_response() {
        let findings = vec![];
        let body = create_error_response(&findings);

        for needle in ["blocked by security policy", "-32600"] {
            assert!(body.contains(needle));
        }
    }

    /// With findings, their messages are included in the error text.
    #[test]
    fn test_create_error_response_with_findings() {
        use crate::rules::{Category, Severity};

        let exfiltration = create_finding(
            "EX-001",
            Severity::High,
            Category::Exfiltration,
            "test",
            "test.md",
            1,
        );
        let injection = create_finding(
            "PI-001",
            Severity::Medium,
            Category::PromptInjection,
            "test2",
            "test.md",
            2,
        );
        let findings = vec![exfiltration, injection];

        let body = create_error_response(&findings);

        for needle in ["Request blocked:", "test message", "-32600"] {
            assert!(body.contains(needle));
        }
    }

    // --- ProxyServer::new -----------------------------------------------

    /// The default configuration constructs successfully.
    #[test]
    fn test_proxy_server_new() {
        assert!(ProxyServer::new(ProxyConfig::default()).is_ok());
    }

    /// Verbose logging does not prevent construction.
    #[test]
    fn test_proxy_server_new_with_verbose() {
        let verbose_config = ProxyConfig::default().with_verbose();
        assert!(ProxyServer::new(verbose_config).is_ok());
    }

    /// A log file inside a fresh temporary directory can be used.
    #[test]
    fn test_proxy_server_new_with_log_file() {
        use tempfile::TempDir;

        // Keep the directory guard alive until the server has been built.
        let scratch = TempDir::new().unwrap();
        let log_path = scratch.path().join("proxy.log");

        let file_config = ProxyConfig::default().with_log_file(log_path);
        assert!(ProxyServer::new(file_config).is_ok());
    }

    /// Block mode with a minimum severity constructs successfully.
    #[test]
    fn test_proxy_server_new_with_block_mode() {
        use crate::Severity;

        let blocking_config = ProxyConfig::default().with_block_mode(Severity::High);
        assert!(ProxyServer::new(blocking_config).is_ok());
    }
}