1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use hyper::service::{make_service_fn, service_fn};
7use hyper::{Body, Client, Method, Request, Response, Server, StatusCode};
8use hyper_tls::HttpsConnector;
9use serde_json::Value;
10use tokio::sync::{Mutex, RwLock};
11use tracing::{info, warn};
12
13use crate::config::ConfigManager;
14use crate::redaction::{RedactionEngine, RedactionService};
15use crate::{Error, Result};
16
/// Per-stream scratch state used while redacting a server-sent-event response.
#[derive(Clone, Debug)]
struct SseAccumulator {
    /// Streamed text that has been received but not yet redacted and forwarded.
    buffer: String,
}
21
22pub struct ProxyServer {
23 port: u16,
24 request_count: Arc<Mutex<u64>>,
25 redaction_engine: RedactionEngine,
26 redaction_service: RedactionService,
27 config_manager: Arc<RwLock<ConfigManager>>,
28 client: Client<HttpsConnector<hyper::client::HttpConnector>>,
29 sse_content_accumulators: Arc<RwLock<HashMap<u64, SseAccumulator>>>,
30}
31
32impl ProxyServer {
33 pub async fn new(port: u16, config_path: Option<String>, redaction_api_url: Option<String>) -> Result<Self> {
34 let https = HttpsConnector::new();
35 let client = Client::builder().build::<_, hyper::Body>(https);
36
37 let mut config_manager = ConfigManager::new_with_path(config_path);
38 if let Err(e) = config_manager.load_config().await {
39 warn!("Warning: Could not load config file, using defaults: {}", e);
40 }
41
42 Ok(Self {
43 port,
44 request_count: Arc::new(Mutex::new(0)),
45 redaction_engine: RedactionEngine::new(),
46 redaction_service: RedactionService::new(redaction_api_url),
47 config_manager: Arc::new(RwLock::new(config_manager)),
48 client,
49 sse_content_accumulators: Arc::new(RwLock::new(HashMap::new())),
50 })
51 }
52
53 pub async fn start(&self) -> Result<()> {
54 let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
55
56 let server_clone = Arc::new(self.clone());
57 let make_svc = make_service_fn(move |_conn| {
58 let server = Arc::clone(&server_clone);
59 async move {
60 Ok::<_, Infallible>(service_fn(move |req| {
61 let server = Arc::clone(&server);
62 async move { server.handle_request(req).await }
63 }))
64 }
65 });
66
67 let server = Server::bind(&addr).serve(make_svc);
68 info!("Proxy server listening on http://{}", addr);
69
70 if let Err(e) = server.await {
71 return Err(Error::Server(e.to_string()));
72 }
73
74 Ok(())
75 }
76
77 pub async fn stop(&self) {
78 info!("Stopping proxy server...");
79 }
80
81 async fn handle_request(&self, req: Request<Body>) -> std::result::Result<Response<Body>, Infallible> {
82 let result = self.handle_request_inner(req).await;
83 Ok(result.unwrap_or_else(|e| {
84 Response::builder()
85 .status(StatusCode::INTERNAL_SERVER_ERROR)
86 .body(Body::from(format!("Internal Server Error: {}", e)))
87 .unwrap()
88 }))
89 }
90
91 async fn redact_content(&self, content: &mut Value) -> Result<Value> {
92 match content {
93 Value::String(text) => {
94 let redacted = self.redaction_service.redact_user_prompt(text).await?;
95 Ok(Value::String(redacted))
96 }
97 Value::Array(content_blocks) => {
98 let mut redacted_blocks = Vec::new();
99 for block in content_blocks.iter() {
100 if let Some(block_obj) = block.as_object() {
101 if let Some(block_type) = block_obj.get("type") {
102 if block_type == "text" {
103 if let Some(text) = block_obj.get("text").and_then(|t| t.as_str()) {
104 let redacted_text = self.redaction_service.redact_user_prompt(text).await?;
105 let mut new_block = block_obj.clone();
106 new_block.insert("text".to_string(), Value::String(redacted_text));
107 redacted_blocks.push(Value::Object(new_block));
108 } else {
109 redacted_blocks.push(block.clone());
110 }
111 } else {
112 redacted_blocks.push(block.clone());
114 }
115 } else {
116 redacted_blocks.push(block.clone());
117 }
118 } else {
119 redacted_blocks.push(block.clone());
120 }
121 }
122 Ok(Value::Array(redacted_blocks))
123 }
124 _ => Ok(content.clone()),
125 }
126 }
127
128 async fn handle_request_inner(&self, req: Request<Body>) -> Result<Response<Body>> {
129 if req.uri().path() == "/health" && req.method() == Method::GET {
131 let request_count = self.request_count.lock().await;
132 let count = *request_count;
133 drop(request_count);
134
135 let health_data = serde_json::json!({
136 "status": "healthy",
137 "uptime": std::time::SystemTime::now()
138 .duration_since(std::time::UNIX_EPOCH)
139 .unwrap()
140 .as_secs(),
141 "timestamp": chrono::Utc::now(),
142 "requestCount": count
143 });
144
145 return Ok(Response::builder()
146 .status(StatusCode::OK)
147 .header("content-type", "application/json")
148 .body(Body::from(health_data.to_string()))?);
149 }
150
151 let mut request_count = self.request_count.lock().await;
152 *request_count += 1;
153 let request_id = *request_count;
154 drop(request_count);
155
156 let (parts, body) = req.into_parts();
158 let body_bytes = hyper::body::to_bytes(body).await?;
159 let request_body = String::from_utf8_lossy(&body_bytes);
160
161
162 self.process_request_with_config(request_body.to_string(), parts, request_id).await
163 }
164
165 async fn process_request_with_config(
166 &self,
167 request_body: String,
168 parts: http::request::Parts,
169 request_id: u64,
170 ) -> Result<Response<Body>> {
171 let mut model_name = None;
173 let processed_request_body = if !request_body.is_empty() {
174 if let Ok(mut parsed_body) = serde_json::from_str::<Value>(&request_body) {
175 if let Some(model) = parsed_body.get("model") {
176 if let Some(model_str) = model.as_str() {
177 model_name = Some(model_str.to_string());
178 }
179 }
180
181 if let Some(messages) = parsed_body.get_mut("messages") {
183 if let Some(messages_array) = messages.as_array_mut() {
184 for message in messages_array.iter_mut() {
185 if let Some(message_obj) = message.as_object_mut() {
186 if let Some(role) = message_obj.get("role") {
187 if role == "user" {
188 if let Some(content) = message_obj.get_mut("content") {
189 match self.redact_content(content).await {
190 Ok(redacted_content) => {
191 *content = redacted_content;
192 }
193 Err(e) => {
194 warn!("Failed to redact user message: {}", e);
195 }
197 }
198 }
199 }
200 }
201 }
202 }
203 }
204 }
205
206 serde_json::to_string(&parsed_body).unwrap_or(request_body)
207 } else {
208 request_body
209 }
210 } else {
211 request_body
212 };
213
214 let target_url = if parts.uri.path().starts_with('/') {
216 let model = model_name.ok_or_else(|| {
218 Error::Server("Bad Request: Model must be specified in request body".to_string())
219 })?;
220
221 let config_manager = self.config_manager.read().await;
222 let model_config = config_manager.get_model_config(&model);
223 let base_url = if model_config.api_base.ends_with('/') {
224 model_config.api_base
225 } else {
226 format!("{}/", model_config.api_base)
227 };
228
229
230 let path = if parts.uri.path().starts_with('/') {
231 &parts.uri.path()[1..]
232 } else {
233 parts.uri.path()
234 };
235
236 let final_url = format!("{}{}{}", base_url, path,
237 parts.uri.query().map_or(String::new(), |q| format!("?{}", q)));
238
239 final_url
240 } else {
241 parts.uri.to_string()
243 };
244
245 let reqwest_method = match parts.method.as_str() {
247 "GET" => reqwest::Method::GET,
248 "POST" => reqwest::Method::POST,
249 "PUT" => reqwest::Method::PUT,
250 "DELETE" => reqwest::Method::DELETE,
251 "PATCH" => reqwest::Method::PATCH,
252 "HEAD" => reqwest::Method::HEAD,
253 "OPTIONS" => reqwest::Method::OPTIONS,
254 _ => reqwest::Method::GET, };
256
257 let mut req_builder = reqwest::Client::new()
258 .request(reqwest_method, &target_url);
259
260 for (name, value) in parts.headers.iter() {
262 if !matches!(name.as_str(), "host" | "proxy-connection" | "proxy-authorization") {
263 if let Ok(header_value) = value.to_str() {
264 req_builder = req_builder.header(name.as_str(), header_value);
265 }
266 }
267 }
268
269 if !processed_request_body.is_empty() {
271 req_builder = req_builder.body(processed_request_body.clone());
272 }
273
274 let proxy_response = req_builder.send().await?;
276
277
278 let is_sse = proxy_response
280 .headers()
281 .get("content-type")
282 .and_then(|v| v.to_str().ok())
283 .map_or(false, |ct| ct.contains("text/event-stream"));
284
285 if is_sse {
286 self.handle_sse_response(proxy_response, &target_url, request_id).await
287 } else {
288 let status = proxy_response.status();
290 let response_body = proxy_response.text().await?;
291 let redacted_response = self.redaction_engine.redact_sensitive_content(&response_body);
292
293
294 Ok(Response::builder()
295 .status(status)
296 .header("content-type", "application/json")
297 .body(Body::from(redacted_response))?)
298 }
299 }
300
301 async fn handle_sse_response(
302 &self,
303 proxy_response: reqwest::Response,
304 target_url: &str,
305 request_id: u64,
306 ) -> Result<Response<Body>> {
307 use futures::StreamExt;
308
309
310 self.sse_content_accumulators.write().await.insert(
312 request_id,
313 SseAccumulator {
314 buffer: String::new(),
315 },
316 );
317
318 let is_openai_format = target_url.contains("openai.com") || target_url.contains("/responses");
320
321 let (tx, rx) = tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, Box<dyn std::error::Error + Send + Sync>>>(100);
323
324 let redaction_engine = self.redaction_engine.clone();
326 let sse_accumulators = Arc::clone(&self.sse_content_accumulators);
327
328 tokio::spawn(async move {
330 let mut event_buffer = String::new();
331 let mut stream = proxy_response.bytes_stream();
332
333 while let Some(chunk_result) = stream.next().await {
334 match chunk_result {
335 Ok(chunk) => {
336 let chunk_str = String::from_utf8_lossy(&chunk);
337 event_buffer.push_str(&chunk_str);
338
339 let buffer_clone = event_buffer.clone();
341 let events: Vec<&str> = buffer_clone.split("\n\n").collect();
342
343 if events.len() > 1 {
344 event_buffer = events.last().unwrap_or(&"").to_string();
346
347 for event_data in &events[..events.len() - 1] {
349 if !event_data.trim().is_empty() {
350 if let Some(processed_event) = Self::process_sse_event(
351 event_data,
352 request_id,
353 is_openai_format,
354 &redaction_engine,
355 &sse_accumulators,
356 ).await {
357 let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
358 if tx.send(Ok(event_bytes)).await.is_err() {
359 break;
360 }
361 }
362 }
363 }
364 }
365 }
366 Err(e) => {
367 let _ = tx.send(Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)).await;
368 break;
369 }
370 }
371 }
372
373 if !event_buffer.trim().is_empty() {
375 if let Some(processed_event) = Self::process_sse_event(
376 &event_buffer,
377 request_id,
378 is_openai_format,
379 &redaction_engine,
380 &sse_accumulators,
381 ).await {
382 let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
383 let _ = tx.send(Ok(event_bytes)).await;
384 }
385 }
386
387 sse_accumulators.write().await.remove(&request_id);
389 });
390
391 let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
393 let body = Body::wrap_stream(stream.map(|item| {
394 item.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
395 }));
396
397 Ok(Response::builder()
398 .status(StatusCode::OK)
399 .header("content-type", "text/event-stream")
400 .header("cache-control", "no-cache")
401 .header("connection", "keep-alive")
402 .body(body)?)
403 }
404
405 async fn process_sse_event(
406 event_data: &str,
407 request_id: u64,
408 is_openai_format: bool,
409 redaction_engine: &RedactionEngine,
410 sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
411 ) -> Option<String> {
412 let lines: Vec<&str> = event_data.split('\n').collect();
413 let mut event_type = String::new();
414 let mut event_data_json: Option<serde_json::Value> = None;
415
416 for line in lines {
418 if let Some(event_value) = line.strip_prefix("event: ") {
419 event_type = event_value.to_string();
420 } else if let Some(data_value) = line.strip_prefix("data: ") {
421 if let Ok(data) = serde_json::from_str::<serde_json::Value>(data_value) {
422 event_data_json = Some(data);
423 } else {
424 if data_value == "[DONE]" {
426 return Some(event_data.to_string());
427 }
428 }
429 }
430 }
431
432 let mut should_redact_and_forward = false;
434 let mut text_content = String::new();
435
436 if let Some(data) = &event_data_json {
437
438 if is_openai_format {
439 if event_type == "response.output_text.delta" {
441 if let Some(delta) = data.get("delta").and_then(|d| d.as_str()) {
442 text_content = delta.to_string();
443 should_redact_and_forward = true;
444 }
445 } else if event_type == "response.output_text.done" {
446 if let Some(text) = data.get("text").and_then(|t| t.as_str()) {
447 text_content = text.to_string();
448 should_redact_and_forward = true;
449 }
450 }
451 } else {
452 if event_type == "content_block_delta" {
454 if let Some(delta) = data.get("delta") {
455 if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
456 text_content = text.to_string();
457 should_redact_and_forward = true;
458 }
459 }
460 }
461 }
462
463 if !should_redact_and_forward {
464 }
465 }
466
467 if should_redact_and_forward && !text_content.is_empty() {
468
469 let redacted_content = if is_openai_format && event_type == "response.output_text.done" {
471 let redacted = redaction_engine.redact_sensitive_content(&text_content);
473 Some(redacted)
474 } else {
475 let result = Self::process_chunk_with_buffer(&text_content, request_id, redaction_engine, sse_accumulators).await;
477 result
478 };
479
480 if let Some(redacted) = redacted_content {
481 if let Some(mut data) = event_data_json {
482 if is_openai_format {
483 if event_type == "response.output_text.delta" {
485 data["delta"] = serde_json::Value::String(redacted);
486 } else if event_type == "response.output_text.done" {
487 data["text"] = serde_json::Value::String(redacted);
488 }
489 } else {
490 if let Some(delta) = data.get_mut("delta") {
492 delta["text"] = serde_json::Value::String(redacted);
493 }
494 }
495
496 return Some(format!("event: {}\ndata: {}", event_type, data));
497 }
498 } else {
499 return None;
501 }
502 } else if event_type == "content_block_stop" && !is_openai_format {
503 if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
505 let final_event_data = serde_json::json!({
506 "type": "content_block_delta",
507 "index": 0,
508 "delta": {
509 "type": "text_delta",
510 "text": final_chunk
511 }
512 });
513 return Some(format!("event: content_block_delta\ndata: {}\n\nevent: content_block_stop\ndata: {}",
514 final_event_data, event_data_json.unwrap_or(serde_json::Value::Null)));
515 }
516 } else if is_openai_format && (event_type == "response.completed" || event_data.contains("[DONE]")) {
517 if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
519 let final_data = serde_json::json!({
520 "type": "response.output_text.delta",
521 "delta": final_chunk
522 });
523 return Some(format!("event: response.output_text.delta\ndata: {}\n\n{}",
524 final_data, event_data));
525 }
526 }
527
528 Some(event_data.to_string())
530 }
531
532 async fn process_chunk_with_buffer(
533 new_chunk: &str,
534 request_id: u64,
535 redaction_engine: &RedactionEngine,
536 sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
537 ) -> Option<String> {
538 let mut accumulators = sse_accumulators.write().await;
539 if let Some(accumulator) = accumulators.get_mut(&request_id) {
540 accumulator.buffer.push_str(new_chunk);
542
543 let buffer_copy = accumulator.buffer.clone();
545 let mut lines: Vec<&str> = buffer_copy.split('\n').collect();
546
547 accumulator.buffer = lines.pop().unwrap_or("").to_string();
549
550 if !lines.is_empty() {
552 let complete_lines = lines.join("\n") + "\n";
553 let redacted = redaction_engine.redact_sensitive_content(&complete_lines);
554 return Some(redacted);
555 }
556
557 return None;
559 }
560 None
561 }
562
563 async fn flush_buffer(
564 request_id: u64,
565 redaction_engine: &RedactionEngine,
566 sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
567 ) -> Option<String> {
568 let mut accumulators = sse_accumulators.write().await;
569 if let Some(accumulator) = accumulators.get_mut(&request_id) {
570 if accumulator.buffer.is_empty() {
571 return None;
572 }
573
574 let remaining = accumulator.buffer.clone();
576 accumulator.buffer.clear();
577
578 let redacted = redaction_engine.redact_sensitive_content(&remaining);
579 Some(redacted)
580 } else {
581 None
582 }
583 }
584}
585
586impl Clone for ProxyServer {
587 fn clone(&self) -> Self {
588 let https = HttpsConnector::new();
589 let client = Client::builder().build::<_, hyper::Body>(https);
590
591 Self {
592 port: self.port,
593 request_count: Arc::clone(&self.request_count),
594 redaction_engine: self.redaction_engine.clone(),
595 redaction_service: self.redaction_service.clone(),
596 config_manager: Arc::clone(&self.config_manager),
597 client,
598 sse_content_accumulators: Arc::clone(&self.sse_content_accumulators),
599 }
600 }
601}
602
603impl Clone for RedactionEngine {
604 fn clone(&self) -> Self {
605 RedactionEngine::new()
606 }
607}