1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Instant;
6
7use hyper::service::{make_service_fn, service_fn};
8use hyper::{Body, Method, Request, Response, Server, StatusCode};
9use serde_json::Value;
10use tokio::sync::{Mutex, RwLock};
11use tracing::{info, warn};
12use uuid::Uuid;
13use redis::{AsyncCommands, aio::ConnectionManager};
14
15use crate::config::ConfigManager;
16use crate::redaction::{RedactionEngine, RedactionService};
17use crate::{Error, Result};
18
/// Per-request carry-over buffer used while redacting streamed (SSE) text.
///
/// Text is appended as deltas arrive; everything up to the last newline is emitted
/// (after redaction) and only the unterminated tail is kept here until more text or a
/// flush arrives.
#[derive(Clone, Debug)]
struct SseAccumulator {
    /// Received text that has not been emitted yet (the part after the last `\n`).
    buffer: String,
}
23
24pub struct ProxyServer {
25 port: u16,
26 request_count: Arc<Mutex<u64>>,
27 redaction_engine: RedactionEngine,
28 redaction_service: RedactionService,
29 config_manager: Arc<RwLock<ConfigManager>>,
30 sse_content_accumulators: Arc<RwLock<HashMap<u64, SseAccumulator>>>,
31 redis_manager: Option<ConnectionManager>,
32 telemetry_webhook_config: Arc<RwLock<Option<crate::config::TelemetryWebhookConfig>>>,
33}
34
35impl ProxyServer {
36 pub async fn new(port: u16, config_path: Option<String>, redaction_api_url: Option<String>) -> Result<Self> {
37
38 let mut config_manager = ConfigManager::new_with_path(config_path.clone());
39 if let Err(e) = config_manager.load_config().await {
40 warn!(
41 event_type = "config_load_failed",
42 config_path = config_path.as_deref().unwrap_or("default"),
43 error = %e,
44 "Could not load config file, using defaults"
45 );
46 }
47
48 let redis_manager = if std::env::var("MULTITENANT").unwrap_or_default() == "true" {
50 let redis_url = std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://localhost:6379".to_string());
51 match redis::Client::open(redis_url.as_str()) {
52 Ok(client) => {
53 match ConnectionManager::new(client).await {
54 Ok(manager) => {
55 println!("[REDIS] Connection manager initialized successfully");
56 Some(manager)
57 },
58 Err(e) => {
59 eprintln!("[REDIS] Failed to create connection manager: {}", e);
60 None
61 }
62 }
63 },
64 Err(e) => {
65 eprintln!("[REDIS] Failed to initialize Redis client: {}", e);
66 None
67 }
68 }
69 } else {
70 None
71 };
72
73 Ok(Self {
74 port,
75 request_count: Arc::new(Mutex::new(0)),
76 redaction_engine: RedactionEngine::new(),
77 redaction_service: RedactionService::new(redaction_api_url),
78 config_manager: Arc::new(RwLock::new(config_manager)),
79 sse_content_accumulators: Arc::new(RwLock::new(HashMap::new())),
80 redis_manager,
81 telemetry_webhook_config: Arc::new(RwLock::new(None)),
82 })
83 }
84
85 pub async fn start(&self) -> Result<()> {
86 let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
87
88 info!(
89 event_type = "server_start",
90 service = "ai-firewall-rust",
91 version = "0.0.1",
92 port = self.port,
93 address = %addr,
94 "Proxy server starting"
95 );
96
97 let server_clone = Arc::new(self.clone());
98 let make_svc = make_service_fn(move |_conn| {
99 let server = Arc::clone(&server_clone);
100 async move {
101 Ok::<_, Infallible>(service_fn(move |req| {
102 let server = Arc::clone(&server);
103 async move { server.handle_request(req).await }
104 }))
105 }
106 });
107
108 let server = Server::bind(&addr).serve(make_svc);
109 info!(
110 event_type = "server_listening",
111 address = %addr,
112 "Proxy server listening"
113 );
114
115 if let Err(e) = server.await {
116 return Err(Error::Server(e.to_string()));
117 }
118
119 Ok(())
120 }
121
122 pub async fn stop(&self) {
123 info!(
124 event_type = "server_stop",
125 "Stopping proxy server"
126 );
127 }
128
129 async fn handle_request(&self, req: Request<Body>) -> std::result::Result<Response<Body>, Infallible> {
130 let result = self.handle_request_inner(req).await;
131 Ok(result.unwrap_or_else(|e| {
132 Response::builder()
133 .status(StatusCode::INTERNAL_SERVER_ERROR)
134 .body(Body::from(format!("Internal Server Error: {}", e)))
135 .unwrap()
136 }))
137 }
138
139 async fn get_config_for_multitenant(&self, config_id: &str, model_name: Option<&str>) -> Result<crate::config::ResolvedModelConfig> {
140 println!("[MULTITENANT] Config ID: {}, Model: {:?}", config_id, model_name);
141
142 let cache_key = format!("config:{}", config_id);
143 let cache_ttl: u64 = std::env::var("CONFIG_CACHE_TTL")
144 .unwrap_or_else(|_| "3600".to_string())
145 .parse()
146 .unwrap_or(3600);
147
148 let mut all_configs: Option<Vec<serde_json::Value>> = None;
149
150 if let Some(redis_manager) = &self.redis_manager {
152 let mut conn = redis_manager.clone();
153 match conn.get::<_, String>(&cache_key).await {
154 Ok(cached_data) => {
155 println!("[MULTITENANT] Cache HIT for config ID: {}", config_id);
156 if let Ok(cached_config) = serde_json::from_str::<serde_json::Value>(&cached_data) {
157 if let Some(models) = cached_config.get("models").and_then(|m| m.as_array()) {
159 all_configs = Some(models.clone());
160 }
161
162 if let Some(telemetry_webhook) = cached_config.get("telemetry_webhook") {
164 if let Ok(webhook_config) = serde_json::from_value::<crate::config::TelemetryWebhookConfig>(telemetry_webhook.clone()) {
165 *self.telemetry_webhook_config.write().await = Some(webhook_config);
166 }
167 }
168 }
169 },
170 Err(_) => {
171 println!("[MULTITENANT] Cache MISS for config ID: {}", config_id);
172 }
173 }
174 }
175
176 if all_configs.is_none() {
178 let config_api_url = std::env::var("MULTITENANT_CONFIG_API_URL")
179 .map_err(|_| Error::Server("MULTITENANT_CONFIG_API_URL environment variable is required when MULTITENANT=true".to_string()))?;
180
181 let full_url = format!("{}{}", config_api_url, config_id);
182 println!("[MULTITENANT] Fetching config from: {}", full_url);
183
184 let response = reqwest::Client::new().get(&full_url).send().await
185 .map_err(|e| Error::Server(format!("Failed to fetch config for {}: {}", config_id, e)))?;
186
187 if !response.status().is_success() {
188 return Err(Error::Server(format!("Config API returned {}: {}", response.status(), response.status().canonical_reason().unwrap_or("Unknown"))));
189 }
190
191 let config_response = response.json::<serde_json::Value>().await
192 .map_err(|e| Error::Server(format!("Failed to parse config response for {}: {}", config_id, e)))?;
193
194 println!("[MULTITENANT] Received config response: {}", serde_json::to_string_pretty(&config_response).unwrap_or_default());
195
196 let configs_array = config_response.get("models")
198 .and_then(|m| m.as_array())
199 .ok_or_else(|| Error::Server("Config API returned invalid models array".to_string()))?
200 .clone();
201
202 if let Some(telemetry_webhook) = config_response.get("telemetry_webhook") {
204 if let Ok(webhook_config) = serde_json::from_value::<crate::config::TelemetryWebhookConfig>(telemetry_webhook.clone()) {
205 *self.telemetry_webhook_config.write().await = Some(webhook_config);
206 println!("[MULTITENANT] Stored telemetry webhook config");
207 } else {
208 eprintln!("[MULTITENANT] Failed to parse telemetry webhook config");
209 }
210 }
211
212 if configs_array.is_empty() {
213 return Err(Error::Server("Config API returned empty models array".to_string()));
214 }
215
216 all_configs = Some(configs_array);
217
218 if let Some(redis_manager) = &self.redis_manager {
220 let mut conn = redis_manager.clone();
221 if let Ok(config_json) = serde_json::to_string(&config_response) {
222 match conn.set_ex::<_, _, ()>(&cache_key, &config_json, cache_ttl as usize).await {
223 Ok(_) => println!("[MULTITENANT] Cached complete config for {} with TTL {} seconds", config_id, cache_ttl),
224 Err(e) => eprintln!("[REDIS] Failed to cache config: {}", e),
225 }
226 }
227 }
228 }
229
230 if let (Some(model_name), Some(all_configs)) = (model_name, &all_configs) {
232 for config in all_configs {
233 if let Some(config_model_name) = config.get("model_name").and_then(|v| v.as_str()) {
234 if config_model_name == model_name {
235 println!("[MULTITENANT] Found specific config for model: {}", model_name);
236 return Ok(crate::config::ResolvedModelConfig {
237 provider: config.get("provider").and_then(|v| v.as_str()).unwrap_or("openai").to_string(),
238 api_base: config.get("api_base").and_then(|v| v.as_str()).unwrap_or("https://api.openai.com/").to_string(),
239 model_name: config.get("model_name").and_then(|v| v.as_str()).unwrap_or(model_name).to_string(),
240 });
241 }
242 }
243 }
244
245 return Err(Error::Server(format!("Model '{}' not found in configuration for config_id: {}", model_name, config_id)));
247 }
248
249 Err(Error::Server("Model must be specified in request body for multitenant configuration".to_string()))
251 }
252
253 fn send_to_telemetry_webhook(&self, request_data: serde_json::Value) {
255 let webhook_config = Arc::clone(&self.telemetry_webhook_config);
256 let config_manager = Arc::clone(&self.config_manager);
257
258 tokio::spawn(async move {
259 let telemetry_config = {
261 let is_multitenant = std::env::var("MULTITENANT").unwrap_or_default() == "true";
262
263 if is_multitenant {
264 let config = webhook_config.read().await;
266 config.clone()
267 } else {
268 let config_manager = config_manager.read().await;
270 config_manager.get_telemetry_webhook_config().cloned()
271 }
272 };
273
274 if let Some(config) = telemetry_config {
275 let payload = serde_json::json!({
276 "source": "superagent-proxy",
277 "event": request_data
278 });
279
280 let mut req_builder = reqwest::Client::new()
282 .post(&config.url)
283 .header("content-type", "application/json");
284
285 for (key, value) in &config.headers {
286 req_builder = req_builder.header(key, value);
287 }
288
289 match req_builder.json(&payload).send().await {
290 Ok(_) => {}, Err(e) => eprintln!("[TELEMETRY] Failed to send webhook: {}", e),
292 }
293 }
294 });
295 }
296
297 async fn handle_config_cache_update(&self, req: Request<Body>, config_id: &str) -> Result<Response<Body>> {
298 if config_id.is_empty() {
299 return Ok(Response::builder()
300 .status(StatusCode::BAD_REQUEST)
301 .header("content-type", "application/json")
302 .body(Body::from(r#"{"error": "Config ID is required"}"#))?);
303 }
304
305 let body_bytes = hyper::body::to_bytes(req.into_body()).await?;
307 let request_body = String::from_utf8_lossy(&body_bytes);
308
309 if request_body.is_empty() {
310 return Ok(Response::builder()
311 .status(StatusCode::BAD_REQUEST)
312 .header("content-type", "application/json")
313 .body(Body::from(r#"{"error": "Request body with config array is required"}"#))?);
314 }
315
316 let new_config: serde_json::Value = match serde_json::from_str(&request_body) {
318 Ok(config) => config,
319 Err(e) => {
320 eprintln!("[CONFIG-CACHE] Error parsing config for {}: {}", config_id, e);
321 return Ok(Response::builder()
322 .status(StatusCode::BAD_REQUEST)
323 .header("content-type", "application/json")
324 .body(Body::from(format!(r#"{{"error": "Invalid JSON", "details": "{}"}}"#, e)))?);
325 }
326 };
327
328 if !new_config.is_array() {
330 return Ok(Response::builder()
331 .status(StatusCode::BAD_REQUEST)
332 .header("content-type", "application/json")
333 .body(Body::from(r#"{"error": "Config must be an array"}"#))?);
334 }
335
336 let cache_key = format!("config:{}", config_id);
337 let cache_ttl: u64 = std::env::var("CONFIG_CACHE_TTL")
338 .unwrap_or_else(|_| "3600".to_string())
339 .parse()
340 .unwrap_or(3600);
341
342 if let Some(redis_manager) = &self.redis_manager {
344 let mut conn = redis_manager.clone();
345 let config_json = serde_json::to_string(&new_config).unwrap();
346 match conn.set_ex::<_, _, ()>(&cache_key, &config_json, cache_ttl as usize).await {
347 Ok(_) => {
348 println!("[CONFIG-CACHE] Updated cache for {} with TTL {} seconds", config_id, cache_ttl);
349 let models_count = new_config.as_array().unwrap().len();
350 let response = serde_json::json!({
351 "success": true,
352 "message": format!("Config cache updated for {}", config_id),
353 "models": models_count
354 });
355 return Ok(Response::builder()
356 .status(StatusCode::OK)
357 .header("content-type", "application/json")
358 .body(Body::from(response.to_string()))?);
359 },
360 Err(e) => {
361 eprintln!("[CONFIG-CACHE] Redis error updating cache for {}: {}", config_id, e);
362 return Ok(Response::builder()
363 .status(StatusCode::INTERNAL_SERVER_ERROR)
364 .header("content-type", "application/json")
365 .body(Body::from(format!(r#"{{"error": "Redis error", "details": "{}"}}"#, e)))?);
366 }
367 }
368 } else {
369 return Ok(Response::builder()
370 .status(StatusCode::SERVICE_UNAVAILABLE)
371 .header("content-type", "application/json")
372 .body(Body::from(r#"{"error": "Redis client not available"}"#))?);
373 }
374 }
375
376 async fn redact_content(&self, content: &mut Value, jailbreak_detected: &mut bool) -> Result<Value> {
377 match content {
378 Value::String(text) => {
379 let result = self.redaction_service.screen_user_prompt(text).await?;
380 if result.is_jailbreak {
381 *jailbreak_detected = true;
382 }
383 Ok(Value::String(result.content))
384 }
385 Value::Array(content_blocks) => {
386 let mut redacted_blocks = Vec::new();
387 for block in content_blocks.iter() {
388 if let Some(block_obj) = block.as_object() {
389 if let Some(block_type) = block_obj.get("type") {
390 if block_type == "text" {
391 if let Some(text) = block_obj.get("text").and_then(|t| t.as_str()) {
392 let result = self.redaction_service.screen_user_prompt(text).await?;
393 if result.is_jailbreak {
394 *jailbreak_detected = true;
395 }
396 let mut new_block = block_obj.clone();
397 new_block.insert("text".to_string(), Value::String(result.content));
398 redacted_blocks.push(Value::Object(new_block));
399 } else {
400 redacted_blocks.push(block.clone());
401 }
402 } else {
403 redacted_blocks.push(block.clone());
405 }
406 } else {
407 redacted_blocks.push(block.clone());
408 }
409 } else {
410 redacted_blocks.push(block.clone());
411 }
412 }
413 Ok(Value::Array(redacted_blocks))
414 }
415 _ => Ok(content.clone()),
416 }
417 }
418
419 async fn handle_request_inner(&self, req: Request<Body>) -> Result<Response<Body>> {
420 let start_time = Instant::now();
421 let trace_id = Uuid::new_v4().to_string();
422
423 if req.uri().path() == "/health" && req.method() == Method::GET {
425 info!(
426 event_type = "health_check",
427 trace_id = %trace_id,
428 "Health check requested"
429 );
430
431 let request_count = self.request_count.lock().await;
432 let count = *request_count;
433 drop(request_count);
434
435 let health_data = serde_json::json!({
436 "status": "healthy",
437 "uptime": std::time::SystemTime::now()
438 .duration_since(std::time::UNIX_EPOCH)
439 .unwrap()
440 .as_secs(),
441 "timestamp": chrono::Utc::now(),
442 "requestCount": count
443 });
444
445 return Ok(Response::builder()
446 .status(StatusCode::OK)
447 .header("content-type", "application/json")
448 .body(Body::from(health_data.to_string()))?);
449 }
450
451 if req.uri().path().starts_with("/config/") && req.method() == Method::POST {
453 let config_id = req.uri().path().strip_prefix("/config/").unwrap_or("").to_string();
454 return self.handle_config_cache_update(req, &config_id).await;
455 }
456
457 let mut request_count = self.request_count.lock().await;
458 *request_count += 1;
459 let request_id = *request_count;
460 drop(request_count);
461
462 let (parts, body) = req.into_parts();
464 let body_bytes = hyper::body::to_bytes(body).await?;
465 let request_body = String::from_utf8_lossy(&body_bytes);
466
467 self.process_request_with_config(request_body.to_string(), parts, request_id, start_time, trace_id).await
468 }
469
470 async fn process_request_with_config(
471 &self,
472 request_body: String,
473 parts: http::request::Parts,
474 request_id: u64,
475 start_time: Instant,
476 trace_id: String,
477 ) -> Result<Response<Body>> {
478 let mut model_name = None;
480 let mut input_redacted = false;
481 let mut jailbreak_detected = false; let processed_request_body = if !request_body.is_empty() {
483 if let Ok(mut parsed_body) = serde_json::from_str::<Value>(&request_body) {
484 if let Some(model) = parsed_body.get("model") {
485 if let Some(model_str) = model.as_str() {
486 model_name = Some(model_str.to_string());
487 }
488 }
489
490 if let Some(messages) = parsed_body.get_mut("messages") {
492 info!("Found messages in request body, checking for user messages to redact");
493 if let Some(messages_array) = messages.as_array_mut() {
494 for message in messages_array.iter_mut() {
495 if let Some(message_obj) = message.as_object_mut() {
496 if let Some(role) = message_obj.get("role") {
497 if role == "user" {
498 info!("Found user message, attempting redaction");
499 if let Some(content) = message_obj.get_mut("content") {
500 let original_content = content.clone();
502 match self.redact_content(content, &mut jailbreak_detected).await {
503 Ok(redacted_content) => {
504 if redacted_content != original_content {
505 info!("User message content was modified by redaction");
506 input_redacted = true;
507 } else {
508 info!("User message content unchanged after redaction");
509 }
510 *content = redacted_content;
511
512 if jailbreak_detected {
514 message_obj.insert("role".to_string(), Value::String("user".to_string()));
515 info!("Changed message role from 'user' to 'assistant' due to jailbreak detection");
516 }
517 }
518 Err(e) => {
519 warn!("Failed to redact user message: {}", e);
520 }
522 }
523 } else {
524 info!("User message has no content field");
525 }
526 }
527 }
528 }
529 }
530 }
531 } else {
532 info!("No messages found in request body");
533 }
534
535 if input_redacted {
537 serde_json::to_string(&parsed_body).unwrap_or_else(|_| request_body.clone())
538 } else {
539 request_body.clone()
540 }
541 } else {
542 request_body.clone()
543 }
544 } else {
545 request_body.clone()
546 };
547
548 let target_url = if parts.uri.path().starts_with('/') {
550 let is_multitenant = std::env::var("MULTITENANT").unwrap_or_default() == "true";
552 let mut config_id: Option<String> = None;
553 let actual_path = if is_multitenant {
554 let path_segments: Vec<&str> = parts.uri.path().split('/').filter(|s| !s.is_empty()).collect();
555 if !path_segments.is_empty() {
556 config_id = Some(path_segments[0].to_string());
557 let rewritten_path = format!("/{}", path_segments[1..].join("/"));
558
559 println!("[MULTITENANT] Config ID: {}", config_id.as_ref().unwrap());
560 println!("[MULTITENANT] Rewritten path: {}", rewritten_path);
561
562 rewritten_path
563 } else {
564 parts.uri.path().to_string()
565 }
566 } else {
567 parts.uri.path().to_string()
568 };
569
570 let model_config = if let Some(ref config_id) = config_id {
571 self.get_config_for_multitenant(config_id, model_name.as_deref()).await?
572 } else {
573 let model = model_name.as_ref().ok_or_else(|| {
575 Error::Server("Bad Request: Model must be specified in request body".to_string())
576 })?;
577
578 let config_manager = self.config_manager.read().await;
579 let model_config = config_manager.get_model_config(&model);
580 crate::config::ResolvedModelConfig {
581 provider: model_config.provider,
582 api_base: model_config.api_base,
583 model_name: model_config.model_name,
584 }
585 };
586
587 let mut base_url = if model_config.api_base.ends_with('/') {
589 model_config.api_base
590 } else {
591 format!("{}/", model_config.api_base)
592 };
593
594 let mut path = if actual_path.starts_with('/') {
595 &actual_path[1..]
596 } else {
597 &actual_path
598 };
599
600 let base_ends_with_v1 = base_url.trim_end_matches('/').ends_with("/v1");
602 if base_ends_with_v1 && (path == "v1" || path.starts_with("v1/")) {
603 path = if path.len() > 2 { &path[3..] } else { "" };
605 if !base_url.ends_with('/') { base_url.push('/'); }
607 }
608
609 let final_url = format!(
610 "{}{}{}",
611 base_url,
612 path,
613 parts
614 .uri
615 .query()
616 .map_or(String::new(), |q| format!("?{}", q))
617 );
618
619 final_url
620 } else {
621 parts.uri.to_string()
623 };
624
625 let reqwest_method = match parts.method.as_str() {
627 "GET" => reqwest::Method::GET,
628 "POST" => reqwest::Method::POST,
629 "PUT" => reqwest::Method::PUT,
630 "DELETE" => reqwest::Method::DELETE,
631 "PATCH" => reqwest::Method::PATCH,
632 "HEAD" => reqwest::Method::HEAD,
633 "OPTIONS" => reqwest::Method::OPTIONS,
634 _ => reqwest::Method::GET, };
636
637 let mut req_builder = reqwest::Client::new()
638 .request(reqwest_method, &target_url);
639
640 for (name, value) in parts.headers.iter() {
642 if !matches!(name.as_str(), "host" | "proxy-connection" | "proxy-authorization" | "content-length") {
643 if let Ok(header_value) = value.to_str() {
644 req_builder = req_builder.header(name.as_str(), header_value);
645 }
646 }
647 }
648
649 if !processed_request_body.is_empty() {
651 req_builder = req_builder
653 .header("content-length", processed_request_body.len().to_string())
654 .body(processed_request_body.clone());
655 }
656
657 let proxy_response = req_builder.send().await?;
659 let response_status = proxy_response.status();
660
661 let is_sse = proxy_response
663 .headers()
664 .get("content-type")
665 .and_then(|v| v.to_str().ok())
666 .map_or(false, |ct| ct.contains("text/event-stream"));
667
668 let user_agent = parts.headers.get("user-agent")
670 .and_then(|v| v.to_str().ok())
671 .unwrap_or("")
672 .to_string();
673 let originator = parts.headers.get("originator")
674 .and_then(|v| v.to_str().ok())
675 .unwrap_or("")
676 .to_string();
677
678 let response = if is_sse {
679 self.handle_sse_response(proxy_response, &target_url, request_id).await
680 } else {
681 let response_body = proxy_response.text().await?;
683 let response_size = response_body.len();
684 let redacted_response = self.redaction_engine.redact_sensitive_content(&response_body);
685
686 let result = Ok(Response::builder()
687 .status(response_status)
688 .header("content-type", "application/json")
689 .body(Body::from(redacted_response))?);
690
691 info!(
693 event_type = "request_body",
694 trace_id = %trace_id,
695 method = %parts.method,
696 url = %parts.uri,
697 model = model_name.as_deref().unwrap_or(""),
698 body_size_bytes = request_body.len(),
699 body = %request_body,
700 processed_body = %processed_request_body,
701 redaction_occurred = input_redacted,
702 "Request body received"
703 );
704
705 let duration = start_time.elapsed();
707 let input_redacted = input_redacted;
708
709 info!(
710 event_type = "request_processed",
711 trace_id = %trace_id,
712 method = %parts.method,
713 url = %parts.uri,
714 model = model_name.as_deref().unwrap_or(""),
715 user_agent = %user_agent,
716 originator = %originator,
717 body_size_bytes = request_body.len(),
718 status = %response_status.as_u16(),
719 duration_ms = duration.as_millis() as u64,
720 response_size_bytes = response_size,
721 is_sse = false,
722 input_redacted = input_redacted,
723 output_redacted = true, target_url = %target_url,
725 model_routing = model_name.is_some(),
726 "Request processed"
727 );
728
729 let telemetry_data = serde_json::json!({
731 "requestId": request_id,
732 "timestamp": chrono::Utc::now().to_rfc3339(),
733 "method": parts.method.as_str(),
734 "originalUrl": parts.uri.to_string(),
735 "targetUrl": target_url.to_string(),
736 "headers": serde_json::json!(parts.headers.iter()
737 .filter_map(|(k, v)| {
738 v.to_str().ok().map(|val| (k.as_str(), val))
739 })
740 .collect::<std::collections::HashMap<&str, &str>>()),
741 "body": request_body,
742 "userAgent": user_agent,
743 "originator": originator,
744 "contentType": parts.headers.get("content-type")
745 .and_then(|v| v.to_str().ok())
746 .unwrap_or(""),
747 "responseStatus": response_status.as_u16(),
748 "responseTime": duration.as_millis() as u64,
749 "isSSE": false,
750 "isJailbreak": jailbreak_detected
751 });
752 self.send_to_telemetry_webhook(telemetry_data);
753
754 result
755 };
756
757 if is_sse {
759 info!(
761 event_type = "request_body",
762 trace_id = %trace_id,
763 method = %parts.method,
764 url = %parts.uri,
765 model = model_name.as_deref().unwrap_or(""),
766 body_size_bytes = request_body.len(),
767 body = %request_body,
768 processed_body = %processed_request_body,
769 redaction_occurred = input_redacted,
770 "Request body received (SSE)"
771 );
772
773 let duration = start_time.elapsed();
774 let input_redacted = input_redacted;
775
776 info!(
777 event_type = "request_processed",
778 trace_id = %trace_id,
779 method = %parts.method,
780 url = %parts.uri,
781 model = model_name.as_deref().unwrap_or(""),
782 user_agent = %user_agent,
783 originator = %originator,
784 body_size_bytes = request_body.len(),
785 status = %response_status.as_u16(),
786 duration_ms = duration.as_millis() as u64,
787 response_size_bytes = 0, is_sse = true,
789 input_redacted = input_redacted,
790 output_redacted = true, target_url = %target_url,
792 model_routing = model_name.is_some(),
793 "SSE request processed"
794 );
795
796 let telemetry_data = serde_json::json!({
798 "requestId": request_id,
799 "timestamp": chrono::Utc::now().to_rfc3339(),
800 "method": parts.method.as_str(),
801 "originalUrl": parts.uri.to_string(),
802 "targetUrl": target_url.to_string(),
803 "headers": serde_json::json!(parts.headers.iter()
804 .filter_map(|(k, v)| {
805 v.to_str().ok().map(|val| (k.as_str(), val))
806 })
807 .collect::<std::collections::HashMap<&str, &str>>()),
808 "body": request_body,
809 "userAgent": user_agent,
810 "originator": originator,
811 "contentType": parts.headers.get("content-type")
812 .and_then(|v| v.to_str().ok())
813 .unwrap_or(""),
814 "responseStatus": response_status.as_u16(),
815 "responseTime": duration.as_millis() as u64,
816 "isSSE": true,
817 "isJailbreak": jailbreak_detected
818 });
819 self.send_to_telemetry_webhook(telemetry_data);
820 }
821
822 response
823 }
824
825 async fn handle_sse_response(
826 &self,
827 proxy_response: reqwest::Response,
828 target_url: &str,
829 request_id: u64,
830 ) -> Result<Response<Body>> {
831 use futures::StreamExt;
832
833
834 self.sse_content_accumulators.write().await.insert(
836 request_id,
837 SseAccumulator {
838 buffer: String::new(),
839 },
840 );
841
842 let is_openai_format = target_url.contains("openai.com") || target_url.contains("/responses");
844
845 let (tx, rx) = tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, Box<dyn std::error::Error + Send + Sync>>>(100);
847
848 let redaction_engine = self.redaction_engine.clone();
850 let sse_accumulators = Arc::clone(&self.sse_content_accumulators);
851
852 tokio::spawn(async move {
854 let mut event_buffer = String::new();
855 let mut stream = proxy_response.bytes_stream();
856
857 while let Some(chunk_result) = stream.next().await {
858 match chunk_result {
859 Ok(chunk) => {
860 let chunk_str = String::from_utf8_lossy(&chunk);
861 event_buffer.push_str(&chunk_str);
862
863 let buffer_clone = event_buffer.clone();
865 let events: Vec<&str> = buffer_clone.split("\n\n").collect();
866
867 if events.len() > 1 {
868 event_buffer = events.last().unwrap_or(&"").to_string();
870
871 for event_data in &events[..events.len() - 1] {
873 if !event_data.trim().is_empty() {
874 if let Some(processed_event) = Self::process_sse_event(
875 event_data,
876 request_id,
877 is_openai_format,
878 &redaction_engine,
879 &sse_accumulators,
880 ).await {
881 let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
882 if tx.send(Ok(event_bytes)).await.is_err() {
883 break;
884 }
885 }
886 }
887 }
888 }
889 }
890 Err(e) => {
891 let _ = tx.send(Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)).await;
892 break;
893 }
894 }
895 }
896
897 if !event_buffer.trim().is_empty() {
899 if let Some(processed_event) = Self::process_sse_event(
900 &event_buffer,
901 request_id,
902 is_openai_format,
903 &redaction_engine,
904 &sse_accumulators,
905 ).await {
906 let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
907 let _ = tx.send(Ok(event_bytes)).await;
908 }
909 }
910
911 sse_accumulators.write().await.remove(&request_id);
913 });
914
915 let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
917 let body = Body::wrap_stream(stream.map(|item| {
918 item.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
919 }));
920
921 Ok(Response::builder()
922 .status(StatusCode::OK)
923 .header("content-type", "text/event-stream")
924 .header("cache-control", "no-cache")
925 .header("connection", "keep-alive")
926 .body(body)?)
927 }
928
929 async fn process_sse_event(
930 event_data: &str,
931 request_id: u64,
932 is_openai_format: bool,
933 redaction_engine: &RedactionEngine,
934 sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
935 ) -> Option<String> {
936 let lines: Vec<&str> = event_data.split('\n').collect();
937 let mut event_type = String::new();
938 let mut event_data_json: Option<serde_json::Value> = None;
939
940 for line in lines {
942 if let Some(event_value) = line.strip_prefix("event: ") {
943 event_type = event_value.to_string();
944 } else if let Some(data_value) = line.strip_prefix("data: ") {
945 if let Ok(data) = serde_json::from_str::<serde_json::Value>(data_value) {
946 event_data_json = Some(data);
947 } else {
948 if data_value == "[DONE]" {
950 return Some(event_data.to_string());
951 }
952 }
953 }
954 }
955
956 let mut should_redact_and_forward = false;
958 let mut text_content = String::new();
959
960 if let Some(data) = &event_data_json {
961
962 if is_openai_format {
963 if event_type == "response.output_text.delta" {
965 if let Some(delta) = data.get("delta").and_then(|d| d.as_str()) {
966 text_content = delta.to_string();
967 should_redact_and_forward = true;
968 }
969 } else if event_type == "response.output_text.done" {
970 if let Some(text) = data.get("text").and_then(|t| t.as_str()) {
971 text_content = text.to_string();
972 should_redact_and_forward = true;
973 }
974 }
975 } else {
976 if event_type == "content_block_delta" {
978 if let Some(delta) = data.get("delta") {
979 if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
980 text_content = text.to_string();
981 should_redact_and_forward = true;
982 }
983 }
984 }
985 }
986
987 if !should_redact_and_forward {
988 }
989 }
990
991 if should_redact_and_forward && !text_content.is_empty() {
992
993 let redacted_content = if is_openai_format && event_type == "response.output_text.done" {
995 let redacted = redaction_engine.redact_sensitive_content(&text_content);
997 Some(redacted)
998 } else {
999 let result = Self::process_chunk_with_buffer(&text_content, request_id, redaction_engine, sse_accumulators).await;
1001 result
1002 };
1003
1004 if let Some(redacted) = redacted_content {
1005 if let Some(mut data) = event_data_json {
1006 if is_openai_format {
1007 if event_type == "response.output_text.delta" {
1009 data["delta"] = serde_json::Value::String(redacted);
1010 } else if event_type == "response.output_text.done" {
1011 data["text"] = serde_json::Value::String(redacted);
1012 }
1013 } else {
1014 if let Some(delta) = data.get_mut("delta") {
1016 delta["text"] = serde_json::Value::String(redacted);
1017 }
1018 }
1019
1020 return Some(format!("event: {}\ndata: {}", event_type, data));
1021 }
1022 } else {
1023 return None;
1025 }
1026 } else if event_type == "content_block_stop" && !is_openai_format {
1027 if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
1029 let final_event_data = serde_json::json!({
1030 "type": "content_block_delta",
1031 "index": 0,
1032 "delta": {
1033 "type": "text_delta",
1034 "text": final_chunk
1035 }
1036 });
1037 return Some(format!("event: content_block_delta\ndata: {}\n\nevent: content_block_stop\ndata: {}",
1038 final_event_data, event_data_json.unwrap_or(serde_json::Value::Null)));
1039 }
1040 } else if is_openai_format && (event_type == "response.completed" || event_data.contains("[DONE]")) {
1041 if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
1043 let final_data = serde_json::json!({
1044 "type": "response.output_text.delta",
1045 "delta": final_chunk
1046 });
1047 return Some(format!("event: response.output_text.delta\ndata: {}\n\n{}",
1048 final_data, event_data));
1049 }
1050 }
1051
1052 Some(event_data.to_string())
1054 }
1055
1056 async fn process_chunk_with_buffer(
1057 new_chunk: &str,
1058 request_id: u64,
1059 redaction_engine: &RedactionEngine,
1060 sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
1061 ) -> Option<String> {
1062 let mut accumulators = sse_accumulators.write().await;
1063 if let Some(accumulator) = accumulators.get_mut(&request_id) {
1064 accumulator.buffer.push_str(new_chunk);
1066
1067 let buffer_copy = accumulator.buffer.clone();
1069 let mut lines: Vec<&str> = buffer_copy.split('\n').collect();
1070
1071 accumulator.buffer = lines.pop().unwrap_or("").to_string();
1073
1074 if !lines.is_empty() {
1076 let complete_lines = lines.join("\n") + "\n";
1077 let redacted = redaction_engine.redact_sensitive_content(&complete_lines);
1078 return Some(redacted);
1079 }
1080
1081 return None;
1083 }
1084 None
1085 }
1086
1087 async fn flush_buffer(
1088 request_id: u64,
1089 redaction_engine: &RedactionEngine,
1090 sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
1091 ) -> Option<String> {
1092 let mut accumulators = sse_accumulators.write().await;
1093 if let Some(accumulator) = accumulators.get_mut(&request_id) {
1094 if accumulator.buffer.is_empty() {
1095 return None;
1096 }
1097
1098 let remaining = accumulator.buffer.clone();
1100 accumulator.buffer.clear();
1101
1102 let redacted = redaction_engine.redact_sensitive_content(&remaining);
1103 Some(redacted)
1104 } else {
1105 None
1106 }
1107 }
1108}
1109
1110impl Clone for ProxyServer {
1111 fn clone(&self) -> Self {
1112
1113 Self {
1114 port: self.port,
1115 request_count: Arc::clone(&self.request_count),
1116 redaction_engine: self.redaction_engine.clone(),
1117 redaction_service: self.redaction_service.clone(),
1118 config_manager: Arc::clone(&self.config_manager),
1119 sse_content_accumulators: Arc::clone(&self.sse_content_accumulators),
1120 redis_manager: self.redis_manager.clone(),
1121 telemetry_webhook_config: Arc::clone(&self.telemetry_webhook_config),
1122 }
1123 }
1124}
1125
1126impl Clone for RedactionEngine {
1127 fn clone(&self) -> Self {
1128 RedactionEngine::new()
1129 }
1130}