1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::time::Instant;
6
7use hyper::service::{make_service_fn, service_fn};
8use hyper::{Body, Method, Request, Response, Server, StatusCode};
9use serde_json::Value;
10use tokio::sync::{Mutex, RwLock};
11use tracing::{info, warn};
12use uuid::Uuid;
13use redis::{AsyncCommands, aio::ConnectionManager};
14
15use crate::config::ConfigManager;
16use crate::redaction::{RedactionEngine, RedactionService};
17use crate::{Error, Result};
18
/// Per-entry state stored in `ProxyServer::sse_content_accumulators`.
///
/// NOTE(review): nothing in the visible part of the file reads or writes this type;
/// presumably it collects streamed server-sent-event text — confirm against the SSE
/// handling code.
#[derive(Debug, Clone)]
struct SseAccumulator {
    /// Buffered text (presumably the SSE payload seen so far — TODO confirm).
    buffer: String,
}
23
/// HTTP proxy that screens user prompts, forwards requests to a model provider and
/// redacts the response before returning it.
pub struct ProxyServer {
    /// TCP port that `start` binds on `0.0.0.0`.
    port: u16,
    /// Number of proxied requests so far; incremented in `handle_request_inner` and
    /// reported by the `/health` endpoint.
    request_count: Arc<Mutex<u64>>,
    /// Redacts sensitive content in non-SSE upstream response bodies.
    redaction_engine: RedactionEngine,
    /// Screens user prompt text (`screen_user_prompt`) before it is forwarded.
    redaction_service: RedactionService,
    /// File-based configuration; consulted for model routing and the telemetry webhook
    /// when multitenant mode is off.
    config_manager: Arc<RwLock<ConfigManager>>,
    /// SSE accumulators keyed by `u64` (presumably the request id — TODO confirm).
    sse_content_accumulators: Arc<RwLock<HashMap<u64, SseAccumulator>>>,
    /// Redis connection used as the multitenant config cache. `None` unless
    /// `MULTITENANT=true` and the connection could be established in `new`.
    redis_manager: Option<ConnectionManager>,
    /// Telemetry webhook settings taken from the most recently loaded multitenant
    /// config. NOTE(review): this is a single slot shared by all config ids.
    telemetry_webhook_config: Arc<RwLock<Option<crate::config::TelemetryWebhookConfig>>>,
}
34
35impl ProxyServer {
36    pub async fn new(port: u16, config_path: Option<String>, redaction_api_url: Option<String>) -> Result<Self> {
37        
38        let mut config_manager = ConfigManager::new_with_path(config_path.clone());
39        if let Err(e) = config_manager.load_config().await {
40            warn!(
41                event_type = "config_load_failed",
42                config_path = config_path.as_deref().unwrap_or("default"),
43                error = %e,
44                "Could not load config file, using defaults"
45            );
46        }
47
48        let redis_manager = if std::env::var("MULTITENANT").unwrap_or_default() == "true" {
50            let redis_url = std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://localhost:6379".to_string());
51            match redis::Client::open(redis_url.as_str()) {
52                Ok(client) => {
53                    match ConnectionManager::new(client).await {
54                        Ok(manager) => {
55                            println!("[REDIS] Connection manager initialized successfully");
56                            Some(manager)
57                        },
58                        Err(e) => {
59                            eprintln!("[REDIS] Failed to create connection manager: {}", e);
60                            None
61                        }
62                    }
63                },
64                Err(e) => {
65                    eprintln!("[REDIS] Failed to initialize Redis client: {}", e);
66                    None
67                }
68            }
69        } else {
70            None
71        };
72
73        Ok(Self {
74            port,
75            request_count: Arc::new(Mutex::new(0)),
76            redaction_engine: RedactionEngine::new(),
77            redaction_service: RedactionService::new(redaction_api_url),
78            config_manager: Arc::new(RwLock::new(config_manager)),
79            sse_content_accumulators: Arc::new(RwLock::new(HashMap::new())),
80            redis_manager,
81            telemetry_webhook_config: Arc::new(RwLock::new(None)),
82        })
83    }
84
85    pub async fn start(&self) -> Result<()> {
86        let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
87        
88        info!(
89            event_type = "server_start",
90            service = "ai-firewall-rust",
91            version = "0.0.1",
92            port = self.port,
93            address = %addr,
94            "Proxy server starting"
95        );
96        
97        let server_clone = Arc::new(self.clone());
98        let make_svc = make_service_fn(move |_conn| {
99            let server = Arc::clone(&server_clone);
100            async move {
101                Ok::<_, Infallible>(service_fn(move |req| {
102                    let server = Arc::clone(&server);
103                    async move { server.handle_request(req).await }
104                }))
105            }
106        });
107
108        let server = Server::bind(&addr).serve(make_svc);
109        info!(
110            event_type = "server_listening",
111            address = %addr,
112            "Proxy server listening"
113        );
114
115        if let Err(e) = server.await {
116            return Err(Error::Server(e.to_string()));
117        }
118
119        Ok(())
120    }
121
122    pub async fn stop(&self) {
123        info!(
124            event_type = "server_stop",
125            "Stopping proxy server"
126        );
127    }
128
129    async fn handle_request(&self, req: Request<Body>) -> std::result::Result<Response<Body>, Infallible> {
130        let result = self.handle_request_inner(req).await;
131        Ok(result.unwrap_or_else(|e| {
132            Response::builder()
133                .status(StatusCode::INTERNAL_SERVER_ERROR)
134                .body(Body::from(format!("Internal Server Error: {}", e)))
135                .unwrap()
136        }))
137    }
138
139    async fn get_config_for_multitenant(&self, config_id: &str, model_name: Option<&str>) -> Result<crate::config::ResolvedModelConfig> {
140        println!("[MULTITENANT] Config ID: {}, Model: {:?}", config_id, model_name);
141        
142        let cache_key = format!("config:{}", config_id);
143        let cache_ttl: u64 = std::env::var("CONFIG_CACHE_TTL")
144            .unwrap_or_else(|_| "3600".to_string())
145            .parse()
146            .unwrap_or(3600);
147        
148        let mut all_configs: Option<Vec<serde_json::Value>> = None;
149        
150        if let Some(redis_manager) = &self.redis_manager {
152            let mut conn = redis_manager.clone();
153            match conn.get::<_, String>(&cache_key).await {
154                    Ok(cached_data) => {
155                        println!("[MULTITENANT] Cache HIT for config ID: {}", config_id);
156                        if let Ok(cached_config) = serde_json::from_str::<serde_json::Value>(&cached_data) {
157                            if let Some(models) = cached_config.get("models").and_then(|m| m.as_array()) {
159                                all_configs = Some(models.clone());
160                            }
161                            
162                            if let Some(telemetry_webhook) = cached_config.get("telemetry_webhook") {
164                                if let Ok(webhook_config) = serde_json::from_value::<crate::config::TelemetryWebhookConfig>(telemetry_webhook.clone()) {
165                                    *self.telemetry_webhook_config.write().await = Some(webhook_config);
166                                }
167                            }
168                        }
169                    },
170                Err(_) => {
171                    println!("[MULTITENANT] Cache MISS for config ID: {}", config_id);
172                }
173            }
174        }
175        
176        if all_configs.is_none() {
178            let config_api_url = std::env::var("MULTITENANT_CONFIG_API_URL")
179                .map_err(|_| Error::Server("MULTITENANT_CONFIG_API_URL environment variable is required when MULTITENANT=true".to_string()))?;
180            
181            let full_url = format!("{}{}", config_api_url, config_id);
182            println!("[MULTITENANT] Fetching config from: {}", full_url);
183            
184            let response = reqwest::Client::new().get(&full_url).send().await
185                .map_err(|e| Error::Server(format!("Failed to fetch config for {}: {}", config_id, e)))?;
186            
187            if !response.status().is_success() {
188                return Err(Error::Server(format!("Config API returned {}: {}", response.status(), response.status().canonical_reason().unwrap_or("Unknown"))));
189            }
190            
191            let config_response = response.json::<serde_json::Value>().await
192                .map_err(|e| Error::Server(format!("Failed to parse config response for {}: {}", config_id, e)))?;
193            
194            println!("[MULTITENANT] Received config response: {}", serde_json::to_string_pretty(&config_response).unwrap_or_default());
195            
196            let configs_array = config_response.get("models")
198                .and_then(|m| m.as_array())
199                .ok_or_else(|| Error::Server("Config API returned invalid models array".to_string()))?
200                .clone();
201            
202            if let Some(telemetry_webhook) = config_response.get("telemetry_webhook") {
204                if let Ok(webhook_config) = serde_json::from_value::<crate::config::TelemetryWebhookConfig>(telemetry_webhook.clone()) {
205                    *self.telemetry_webhook_config.write().await = Some(webhook_config);
206                    println!("[MULTITENANT] Stored telemetry webhook config");
207                } else {
208                    eprintln!("[MULTITENANT] Failed to parse telemetry webhook config");
209                }
210            }
211            
212            if configs_array.is_empty() {
213                return Err(Error::Server("Config API returned empty models array".to_string()));
214            }
215            
216            all_configs = Some(configs_array);
217            
218            if let Some(redis_manager) = &self.redis_manager {
220                let mut conn = redis_manager.clone();
221                if let Ok(config_json) = serde_json::to_string(&config_response) {
222                    match conn.set_ex::<_, _, ()>(&cache_key, &config_json, cache_ttl as usize).await {
223                        Ok(_) => println!("[MULTITENANT] Cached complete config for {} with TTL {} seconds", config_id, cache_ttl),
224                        Err(e) => eprintln!("[REDIS] Failed to cache config: {}", e),
225                    }
226                }
227            }
228        }
229        
230        if let (Some(model_name), Some(all_configs)) = (model_name, &all_configs) {
232            for config in all_configs {
233                if let Some(config_model_name) = config.get("model_name").and_then(|v| v.as_str()) {
234                    if config_model_name == model_name {
235                        println!("[MULTITENANT] Found specific config for model: {}", model_name);
236                        return Ok(crate::config::ResolvedModelConfig {
237                            provider: config.get("provider").and_then(|v| v.as_str()).unwrap_or("openai").to_string(),
238                            api_base: config.get("api_base").and_then(|v| v.as_str()).unwrap_or("https://api.openai.com/").to_string(),
239                            model_name: config.get("model_name").and_then(|v| v.as_str()).unwrap_or(model_name).to_string(),
240                        });
241                    }
242                }
243            }
244            
245            return Err(Error::Server(format!("Model '{}' not found in configuration for config_id: {}", model_name, config_id)));
247        }
248        
249        Err(Error::Server("Model must be specified in request body for multitenant configuration".to_string()))
251    }
252
253    fn send_to_telemetry_webhook(&self, request_data: serde_json::Value) {
255        let webhook_config = Arc::clone(&self.telemetry_webhook_config);
256        let config_manager = Arc::clone(&self.config_manager);
257        
258        tokio::spawn(async move {
259            let telemetry_config = {
261                let is_multitenant = std::env::var("MULTITENANT").unwrap_or_default() == "true";
262                
263                if is_multitenant {
264                    let config = webhook_config.read().await;
266                    config.clone()
267                } else {
268                    let config_manager = config_manager.read().await;
270                    config_manager.get_telemetry_webhook_config().cloned()
271                }
272            };
273            
274            if let Some(config) = telemetry_config {
275                let payload = serde_json::json!({
276                    "source": "superagent-proxy",
277                    "event": request_data
278                });
279                
280                let mut req_builder = reqwest::Client::new()
282                    .post(&config.url)
283                    .header("content-type", "application/json");
284                
285                for (key, value) in &config.headers {
286                    req_builder = req_builder.header(key, value);
287                }
288                
289                match req_builder.json(&payload).send().await {
290                    Ok(_) => {}, Err(e) => eprintln!("[TELEMETRY] Failed to send webhook: {}", e),
292                }
293            }
294        });
295    }
296
297    async fn handle_config_cache_update(&self, req: Request<Body>, config_id: &str) -> Result<Response<Body>> {
298        if config_id.is_empty() {
299            return Ok(Response::builder()
300                .status(StatusCode::BAD_REQUEST)
301                .header("content-type", "application/json")
302                .body(Body::from(r#"{"error": "Config ID is required"}"#))?);
303        }
304
305        let body_bytes = hyper::body::to_bytes(req.into_body()).await?;
307        let request_body = String::from_utf8_lossy(&body_bytes);
308
309        if request_body.is_empty() {
310            return Ok(Response::builder()
311                .status(StatusCode::BAD_REQUEST)
312                .header("content-type", "application/json")
313                .body(Body::from(r#"{"error": "Request body with config array is required"}"#))?);
314        }
315
316        let new_config: serde_json::Value = match serde_json::from_str(&request_body) {
318            Ok(config) => config,
319            Err(e) => {
320                eprintln!("[CONFIG-CACHE] Error parsing config for {}: {}", config_id, e);
321                return Ok(Response::builder()
322                    .status(StatusCode::BAD_REQUEST)
323                    .header("content-type", "application/json")
324                    .body(Body::from(format!(r#"{{"error": "Invalid JSON", "details": "{}"}}"#, e)))?);
325            }
326        };
327
328        if !new_config.is_array() {
330            return Ok(Response::builder()
331                .status(StatusCode::BAD_REQUEST)
332                .header("content-type", "application/json")
333                .body(Body::from(r#"{"error": "Config must be an array"}"#))?);
334        }
335
336        let cache_key = format!("config:{}", config_id);
337        let cache_ttl: u64 = std::env::var("CONFIG_CACHE_TTL")
338            .unwrap_or_else(|_| "3600".to_string())
339            .parse()
340            .unwrap_or(3600);
341
342        if let Some(redis_manager) = &self.redis_manager {
344            let mut conn = redis_manager.clone();
345            let config_json = serde_json::to_string(&new_config).unwrap();
346            match conn.set_ex::<_, _, ()>(&cache_key, &config_json, cache_ttl as usize).await {
347                    Ok(_) => {
348                        println!("[CONFIG-CACHE] Updated cache for {} with TTL {} seconds", config_id, cache_ttl);
349                        let models_count = new_config.as_array().unwrap().len();
350                        let response = serde_json::json!({
351                            "success": true,
352                            "message": format!("Config cache updated for {}", config_id),
353                            "models": models_count
354                        });
355                        return Ok(Response::builder()
356                            .status(StatusCode::OK)
357                            .header("content-type", "application/json")
358                            .body(Body::from(response.to_string()))?);
359                    },
360                Err(e) => {
361                    eprintln!("[CONFIG-CACHE] Redis error updating cache for {}: {}", config_id, e);
362                    return Ok(Response::builder()
363                        .status(StatusCode::INTERNAL_SERVER_ERROR)
364                        .header("content-type", "application/json")
365                        .body(Body::from(format!(r#"{{"error": "Redis error", "details": "{}"}}"#, e)))?);
366                }
367            }
368        } else {
369            return Ok(Response::builder()
370                .status(StatusCode::SERVICE_UNAVAILABLE)
371                .header("content-type", "application/json")
372                .body(Body::from(r#"{"error": "Redis client not available"}"#))?);
373        }
374    }
375
376    async fn redact_content(&self, content: &mut Value, jailbreak_detected: &mut bool) -> Result<Value> {
377        match content {
378            Value::String(text) => {
379                let result = self.redaction_service.screen_user_prompt(text).await?;
380                if result.is_jailbreak {
381                    *jailbreak_detected = true;
382                }
383                Ok(Value::String(result.content))
384            }
385            Value::Array(content_blocks) => {
386                let mut redacted_blocks = Vec::new();
387                for block in content_blocks.iter() {
388                    if let Some(block_obj) = block.as_object() {
389                        if let Some(block_type) = block_obj.get("type") {
390                            if block_type == "text" {
391                                if let Some(text) = block_obj.get("text").and_then(|t| t.as_str()) {
392                                    let result = self.redaction_service.screen_user_prompt(text).await?;
393                                    if result.is_jailbreak {
394                                        *jailbreak_detected = true;
395                                    }
396                                    let mut new_block = block_obj.clone();
397                                    new_block.insert("text".to_string(), Value::String(result.content));
398                                    redacted_blocks.push(Value::Object(new_block));
399                                } else {
400                                    redacted_blocks.push(block.clone());
401                                }
402                            } else {
403                                redacted_blocks.push(block.clone());
405                            }
406                        } else {
407                            redacted_blocks.push(block.clone());
408                        }
409                    } else {
410                        redacted_blocks.push(block.clone());
411                    }
412                }
413                Ok(Value::Array(redacted_blocks))
414            }
415            _ => Ok(content.clone()),
416        }
417    }
418
419    async fn handle_request_inner(&self, req: Request<Body>) -> Result<Response<Body>> {
420        let start_time = Instant::now();
421        let trace_id = Uuid::new_v4().to_string();
422
423        if req.uri().path() == "/health" && req.method() == Method::GET {
425            info!(
426                event_type = "health_check",
427                trace_id = %trace_id,
428                "Health check requested"
429            );
430
431            let request_count = self.request_count.lock().await;
432            let count = *request_count;
433            drop(request_count);
434
435            let health_data = serde_json::json!({
436                "status": "healthy",
437                "uptime": std::time::SystemTime::now()
438                    .duration_since(std::time::UNIX_EPOCH)
439                    .unwrap()
440                    .as_secs(),
441                "timestamp": chrono::Utc::now(),
442                "requestCount": count
443            });
444
445            return Ok(Response::builder()
446                .status(StatusCode::OK)
447                .header("content-type", "application/json")
448                .body(Body::from(health_data.to_string()))?);
449        }
450        
451        if req.uri().path().starts_with("/config/") && req.method() == Method::POST {
453            let config_id = req.uri().path().strip_prefix("/config/").unwrap_or("").to_string();
454            return self.handle_config_cache_update(req, &config_id).await;
455        }
456
457        let mut request_count = self.request_count.lock().await;
458        *request_count += 1;
459        let request_id = *request_count;
460        drop(request_count);
461
462        let (parts, body) = req.into_parts();
464        let body_bytes = hyper::body::to_bytes(body).await?;
465        let request_body = String::from_utf8_lossy(&body_bytes);
466
467        self.process_request_with_config(request_body.to_string(), parts, request_id, start_time, trace_id).await
468    }
469
470    async fn process_request_with_config(
471        &self,
472        request_body: String,
473        parts: http::request::Parts,
474        request_id: u64,
475        start_time: Instant,
476        trace_id: String,
477    ) -> Result<Response<Body>> {
478        let mut model_name = None;
480        let mut input_redacted = false;
481        let mut jailbreak_detected = false; let processed_request_body = if !request_body.is_empty() {
483            if let Ok(mut parsed_body) = serde_json::from_str::<Value>(&request_body) {
484                if let Some(model) = parsed_body.get("model") {
485                    if let Some(model_str) = model.as_str() {
486                        model_name = Some(model_str.to_string());
487                    }
488                }
489
490                if let Some(messages) = parsed_body.get_mut("messages") {
492                    info!("Found messages in request body, checking for user messages to redact");
493                    if let Some(messages_array) = messages.as_array_mut() {
494                        for message in messages_array.iter_mut() {
495                            if let Some(message_obj) = message.as_object_mut() {
496                                if let Some(role) = message_obj.get("role") {
497                                    if role == "user" {
498                                        info!("Found user message, attempting redaction");
499                                        if let Some(content) = message_obj.get_mut("content") {
500                                            let original_content = content.clone();
502                                            match self.redact_content(content, &mut jailbreak_detected).await {
503                                                Ok(redacted_content) => {
504                                                    if redacted_content != original_content {
505                                                        info!("User message content was modified by redaction");
506                                                        input_redacted = true;
507                                                    } else {
508                                                        info!("User message content unchanged after redaction");
509                                                    }
510                                                    *content = redacted_content;
511
512                                                    if jailbreak_detected {
514                                                        message_obj.insert("role".to_string(), Value::String("user".to_string()));
515                                                        info!("Changed message role from 'user' to 'assistant' due to jailbreak detection");
516                                                    }
517                                                }
518                                                Err(e) => {
519                                                    warn!("Failed to redact user message: {}", e);
520                                                    }
522                                            }
523                                        } else {
524                                            info!("User message has no content field");
525                                        }
526                                    }
527                                }
528                            }
529                        }
530                    }
531                } else {
532                    info!("No messages found in request body");
533                }
534
535                if input_redacted {
537                    serde_json::to_string(&parsed_body).unwrap_or_else(|_| request_body.clone())
538                } else {
539                    request_body.clone()
540                }
541            } else {
542                request_body.clone()
543            }
544        } else {
545            request_body.clone()
546        };
547
548        let target_url = if parts.uri.path().starts_with('/') {
550            let is_multitenant = std::env::var("MULTITENANT").unwrap_or_default() == "true";
552            let mut config_id: Option<String> = None;
553            let actual_path = if is_multitenant {
554                let path_segments: Vec<&str> = parts.uri.path().split('/').filter(|s| !s.is_empty()).collect();
555                if !path_segments.is_empty() {
556                    config_id = Some(path_segments[0].to_string());
557                    let rewritten_path = format!("/{}", path_segments[1..].join("/"));
558                    
559                    println!("[MULTITENANT] Config ID: {}", config_id.as_ref().unwrap());
560                    println!("[MULTITENANT] Rewritten path: {}", rewritten_path);
561                    
562                    rewritten_path
563                } else {
564                    parts.uri.path().to_string()
565                }
566            } else {
567                parts.uri.path().to_string()
568            };
569            
570            let model_config = if let Some(ref config_id) = config_id {
571                self.get_config_for_multitenant(config_id, model_name.as_deref()).await?
572            } else {
573                let model = model_name.as_ref().ok_or_else(|| {
575                    Error::Server("Bad Request: Model must be specified in request body".to_string())
576                })?;
577                
578                let config_manager = self.config_manager.read().await;
579                let model_config = config_manager.get_model_config(&model);
580                crate::config::ResolvedModelConfig {
581                    provider: model_config.provider,
582                    api_base: model_config.api_base,
583                    model_name: model_config.model_name,
584                }
585            };
586            
587            let mut base_url = if model_config.api_base.ends_with('/') {
589                model_config.api_base
590            } else {
591                format!("{}/", model_config.api_base)
592            };
593
594            let mut path = if actual_path.starts_with('/') {
595                &actual_path[1..]
596            } else {
597                &actual_path
598            };
599
600            let base_ends_with_v1 = base_url.trim_end_matches('/').ends_with("/v1");
602            if base_ends_with_v1 && (path == "v1" || path.starts_with("v1/")) {
603                path = if path.len() > 2 { &path[3..] } else { "" };
605                if !base_url.ends_with('/') { base_url.push('/'); }
607            }
608
609            let final_url = format!(
610                "{}{}{}",
611                base_url,
612                path,
613                parts
614                    .uri
615                    .query()
616                    .map_or(String::new(), |q| format!("?{}", q))
617            );
618            
619            final_url
620        } else {
621            parts.uri.to_string()
623        };
624
625        let reqwest_method = match parts.method.as_str() {
627            "GET" => reqwest::Method::GET,
628            "POST" => reqwest::Method::POST,
629            "PUT" => reqwest::Method::PUT,
630            "DELETE" => reqwest::Method::DELETE,
631            "PATCH" => reqwest::Method::PATCH,
632            "HEAD" => reqwest::Method::HEAD,
633            "OPTIONS" => reqwest::Method::OPTIONS,
634            _ => reqwest::Method::GET, };
636
637        let mut req_builder = reqwest::Client::new()
638            .request(reqwest_method, &target_url);
639
640        for (name, value) in parts.headers.iter() {
642            if !matches!(name.as_str(), "host" | "proxy-connection" | "proxy-authorization" | "content-length") {
643                if let Ok(header_value) = value.to_str() {
644                    req_builder = req_builder.header(name.as_str(), header_value);
645                }
646            }
647        }
648
649        if !processed_request_body.is_empty() {
651            req_builder = req_builder
653                .header("content-length", processed_request_body.len().to_string())
654                .body(processed_request_body.clone());
655        }
656
657        let proxy_response = req_builder.send().await?;
659        let response_status = proxy_response.status();
660
661        let is_sse = proxy_response
663            .headers()
664            .get("content-type")
665            .and_then(|v| v.to_str().ok())
666            .map_or(false, |ct| ct.contains("text/event-stream"));
667
668        let user_agent = parts.headers.get("user-agent")
670            .and_then(|v| v.to_str().ok())
671            .unwrap_or("")
672            .to_string();
673        let originator = parts.headers.get("originator")
674            .and_then(|v| v.to_str().ok())
675            .unwrap_or("")
676            .to_string();
677
678        let response = if is_sse {
679            self.handle_sse_response(proxy_response, &target_url, request_id).await
680        } else {
681            let response_body = proxy_response.text().await?;
683            let response_size = response_body.len();
684            let redacted_response = self.redaction_engine.redact_sensitive_content(&response_body);
685            
686            let result = Ok(Response::builder()
687                .status(response_status)
688                .header("content-type", "application/json")
689                .body(Body::from(redacted_response))?);
690
691            info!(
693                event_type = "request_body",
694                trace_id = %trace_id,
695                method = %parts.method,
696                url = %parts.uri,
697                model = model_name.as_deref().unwrap_or(""),
698                body_size_bytes = request_body.len(),
699                body = %request_body,
700                processed_body = %processed_request_body,
701                redaction_occurred = input_redacted,
702                "Request body received"
703            );
704
705            let duration = start_time.elapsed();
707            let input_redacted = input_redacted;
708            
709            info!(
710                event_type = "request_processed",
711                trace_id = %trace_id,
712                method = %parts.method,
713                url = %parts.uri,
714                model = model_name.as_deref().unwrap_or(""),
715                user_agent = %user_agent,
716                originator = %originator,
717                body_size_bytes = request_body.len(),
718                status = %response_status.as_u16(),
719                duration_ms = duration.as_millis() as u64,
720                response_size_bytes = response_size,
721                is_sse = false,
722                input_redacted = input_redacted,
723                output_redacted = true, target_url = %target_url,
725                model_routing = model_name.is_some(),
726                "Request processed"
727            );
728
729            let telemetry_data = serde_json::json!({
731                "requestId": request_id,
732                "timestamp": chrono::Utc::now().to_rfc3339(),
733                "method": parts.method.as_str(),
734                "originalUrl": parts.uri.to_string(),
735                "targetUrl": target_url.to_string(),
736                "headers": serde_json::json!(parts.headers.iter()
737                    .filter_map(|(k, v)| {
738                        v.to_str().ok().map(|val| (k.as_str(), val))
739                    })
740                    .collect::<std::collections::HashMap<&str, &str>>()),
741                "body": request_body,
742                "userAgent": user_agent,
743                "originator": originator,
744                "contentType": parts.headers.get("content-type")
745                    .and_then(|v| v.to_str().ok())
746                    .unwrap_or(""),
747                "responseStatus": response_status.as_u16(),
748                "responseTime": duration.as_millis() as u64,
749                "isSSE": false,
750                "isJailbreak": jailbreak_detected
751            });
752            self.send_to_telemetry_webhook(telemetry_data);
753
754            result
755        };
756
757        if is_sse {
759            info!(
761                event_type = "request_body",
762                trace_id = %trace_id,
763                method = %parts.method,
764                url = %parts.uri,
765                model = model_name.as_deref().unwrap_or(""),
766                body_size_bytes = request_body.len(),
767                body = %request_body,
768                processed_body = %processed_request_body,
769                redaction_occurred = input_redacted,
770                "Request body received (SSE)"
771            );
772
773            let duration = start_time.elapsed();
774            let input_redacted = input_redacted;
775            
776            info!(
777                event_type = "request_processed",
778                trace_id = %trace_id,
779                method = %parts.method,
780                url = %parts.uri,
781                model = model_name.as_deref().unwrap_or(""),
782                user_agent = %user_agent,
783                originator = %originator,
784                body_size_bytes = request_body.len(),
785                status = %response_status.as_u16(),
786                duration_ms = duration.as_millis() as u64,
787                response_size_bytes = 0, is_sse = true,
789                input_redacted = input_redacted,
790                output_redacted = true, target_url = %target_url,
792                model_routing = model_name.is_some(),
793                "SSE request processed"
794            );
795
796            let telemetry_data = serde_json::json!({
798                "requestId": request_id,
799                "timestamp": chrono::Utc::now().to_rfc3339(),
800                "method": parts.method.as_str(),
801                "originalUrl": parts.uri.to_string(),
802                "targetUrl": target_url.to_string(),
803                "headers": serde_json::json!(parts.headers.iter()
804                    .filter_map(|(k, v)| {
805                        v.to_str().ok().map(|val| (k.as_str(), val))
806                    })
807                    .collect::<std::collections::HashMap<&str, &str>>()),
808                "body": request_body,
809                "userAgent": user_agent,
810                "originator": originator,
811                "contentType": parts.headers.get("content-type")
812                    .and_then(|v| v.to_str().ok())
813                    .unwrap_or(""),
814                "responseStatus": response_status.as_u16(),
815                "responseTime": duration.as_millis() as u64,
816                "isSSE": true,
817                "isJailbreak": jailbreak_detected
818            });
819            self.send_to_telemetry_webhook(telemetry_data);
820        }
821
822        response
823    }
824
825    async fn handle_sse_response(
826        &self,
827        proxy_response: reqwest::Response,
828        target_url: &str,
829        request_id: u64,
830    ) -> Result<Response<Body>> {
831        use futures::StreamExt;
832        
833
834        self.sse_content_accumulators.write().await.insert(
836            request_id,
837            SseAccumulator {
838                buffer: String::new(),
839            },
840        );
841
842        let is_openai_format = target_url.contains("openai.com") || target_url.contains("/responses");
844
845        let (tx, rx) = tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, Box<dyn std::error::Error + Send + Sync>>>(100);
847
848        let redaction_engine = self.redaction_engine.clone();
850        let sse_accumulators = Arc::clone(&self.sse_content_accumulators);
851
852        tokio::spawn(async move {
854            let mut event_buffer = String::new();
855            let mut stream = proxy_response.bytes_stream();
856
857            while let Some(chunk_result) = stream.next().await {
858                match chunk_result {
859                    Ok(chunk) => {
860                        let chunk_str = String::from_utf8_lossy(&chunk);
861                        event_buffer.push_str(&chunk_str);
862
863                        let buffer_clone = event_buffer.clone();
865                        let events: Vec<&str> = buffer_clone.split("\n\n").collect();
866                        
867                        if events.len() > 1 {
868                            event_buffer = events.last().unwrap_or(&"").to_string();
870
871                            for event_data in &events[..events.len() - 1] {
873                                if !event_data.trim().is_empty() {
874                                    if let Some(processed_event) = Self::process_sse_event(
875                                        event_data,
876                                        request_id,
877                                        is_openai_format,
878                                        &redaction_engine,
879                                        &sse_accumulators,
880                                    ).await {
881                                        let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
882                                        if tx.send(Ok(event_bytes)).await.is_err() {
883                                            break;
884                                        }
885                                    }
886                                }
887                            }
888                        }
889                    }
890                    Err(e) => {
891                        let _ = tx.send(Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)).await;
892                        break;
893                    }
894                }
895            }
896
897            if !event_buffer.trim().is_empty() {
899                if let Some(processed_event) = Self::process_sse_event(
900                    &event_buffer,
901                    request_id,
902                    is_openai_format,
903                    &redaction_engine,
904                    &sse_accumulators,
905                ).await {
906                    let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
907                    let _ = tx.send(Ok(event_bytes)).await;
908                }
909            }
910
911            sse_accumulators.write().await.remove(&request_id);
913        });
914
915        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
917        let body = Body::wrap_stream(stream.map(|item| {
918            item.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
919        }));
920
921        Ok(Response::builder()
922            .status(StatusCode::OK)
923            .header("content-type", "text/event-stream")
924            .header("cache-control", "no-cache")
925            .header("connection", "keep-alive")
926            .body(body)?)
927    }
928
929    async fn process_sse_event(
930        event_data: &str,
931        request_id: u64,
932        is_openai_format: bool,
933        redaction_engine: &RedactionEngine,
934        sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
935    ) -> Option<String> {
936        let lines: Vec<&str> = event_data.split('\n').collect();
937        let mut event_type = String::new();
938        let mut event_data_json: Option<serde_json::Value> = None;
939
940        for line in lines {
942            if let Some(event_value) = line.strip_prefix("event: ") {
943                event_type = event_value.to_string();
944            } else if let Some(data_value) = line.strip_prefix("data: ") {
945                if let Ok(data) = serde_json::from_str::<serde_json::Value>(data_value) {
946                    event_data_json = Some(data);
947                } else {
948                    if data_value == "[DONE]" {
950                        return Some(event_data.to_string());
951                    }
952                }
953            }
954        }
955
956        let mut should_redact_and_forward = false;
958        let mut text_content = String::new();
959
960        if let Some(data) = &event_data_json {
961            
962            if is_openai_format {
963                if event_type == "response.output_text.delta" {
965                    if let Some(delta) = data.get("delta").and_then(|d| d.as_str()) {
966                        text_content = delta.to_string();
967                        should_redact_and_forward = true;
968                    }
969                } else if event_type == "response.output_text.done" {
970                    if let Some(text) = data.get("text").and_then(|t| t.as_str()) {
971                        text_content = text.to_string();
972                        should_redact_and_forward = true;
973                    }
974                }
975            } else {
976                if event_type == "content_block_delta" {
978                    if let Some(delta) = data.get("delta") {
979                        if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
980                            text_content = text.to_string();
981                            should_redact_and_forward = true;
982                        }
983                    }
984                }
985            }
986            
987            if !should_redact_and_forward {
988            }
989        }
990
991        if should_redact_and_forward && !text_content.is_empty() {
992            
993            let redacted_content = if is_openai_format && event_type == "response.output_text.done" {
995                let redacted = redaction_engine.redact_sensitive_content(&text_content);
997                Some(redacted)
998            } else {
999                let result = Self::process_chunk_with_buffer(&text_content, request_id, redaction_engine, sse_accumulators).await;
1001                result
1002            };
1003
1004            if let Some(redacted) = redacted_content {
1005                if let Some(mut data) = event_data_json {
1006                    if is_openai_format {
1007                        if event_type == "response.output_text.delta" {
1009                            data["delta"] = serde_json::Value::String(redacted);
1010                        } else if event_type == "response.output_text.done" {
1011                            data["text"] = serde_json::Value::String(redacted);
1012                        }
1013                    } else {
1014                        if let Some(delta) = data.get_mut("delta") {
1016                            delta["text"] = serde_json::Value::String(redacted);
1017                        }
1018                    }
1019
1020                    return Some(format!("event: {}\ndata: {}", event_type, data));
1021                }
1022            } else {
1023                return None;
1025            }
1026        } else if event_type == "content_block_stop" && !is_openai_format {
1027            if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
1029                let final_event_data = serde_json::json!({
1030                    "type": "content_block_delta",
1031                    "index": 0,
1032                    "delta": {
1033                        "type": "text_delta",
1034                        "text": final_chunk
1035                    }
1036                });
1037                return Some(format!("event: content_block_delta\ndata: {}\n\nevent: content_block_stop\ndata: {}", 
1038                    final_event_data, event_data_json.unwrap_or(serde_json::Value::Null)));
1039            }
1040        } else if is_openai_format && (event_type == "response.completed" || event_data.contains("[DONE]")) {
1041            if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
1043                let final_data = serde_json::json!({
1044                    "type": "response.output_text.delta",
1045                    "delta": final_chunk
1046                });
1047                return Some(format!("event: response.output_text.delta\ndata: {}\n\n{}", 
1048                    final_data, event_data));
1049            }
1050        }
1051
1052        Some(event_data.to_string())
1054    }
1055
1056    async fn process_chunk_with_buffer(
1057        new_chunk: &str,
1058        request_id: u64,
1059        redaction_engine: &RedactionEngine,
1060        sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
1061    ) -> Option<String> {
1062        let mut accumulators = sse_accumulators.write().await;
1063        if let Some(accumulator) = accumulators.get_mut(&request_id) {
1064            accumulator.buffer.push_str(new_chunk);
1066
1067            let buffer_copy = accumulator.buffer.clone();
1069            let mut lines: Vec<&str> = buffer_copy.split('\n').collect();
1070            
1071            accumulator.buffer = lines.pop().unwrap_or("").to_string();
1073            
1074            if !lines.is_empty() {
1076                let complete_lines = lines.join("\n") + "\n";
1077                let redacted = redaction_engine.redact_sensitive_content(&complete_lines);
1078                return Some(redacted);
1079            }
1080            
1081            return None;
1083        }
1084        None
1085    }
1086
1087    async fn flush_buffer(
1088        request_id: u64,
1089        redaction_engine: &RedactionEngine,
1090        sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
1091    ) -> Option<String> {
1092        let mut accumulators = sse_accumulators.write().await;
1093        if let Some(accumulator) = accumulators.get_mut(&request_id) {
1094            if accumulator.buffer.is_empty() {
1095                return None;
1096            }
1097            
1098            let remaining = accumulator.buffer.clone();
1100            accumulator.buffer.clear();
1101            
1102            let redacted = redaction_engine.redact_sensitive_content(&remaining);
1103            Some(redacted)
1104        } else {
1105            None
1106        }
1107    }
1108}
1109
1110impl Clone for ProxyServer {
1111    fn clone(&self) -> Self {
1112        
1113        Self {
1114            port: self.port,
1115            request_count: Arc::clone(&self.request_count),
1116            redaction_engine: self.redaction_engine.clone(),
1117            redaction_service: self.redaction_service.clone(),
1118            config_manager: Arc::clone(&self.config_manager),
1119            sse_content_accumulators: Arc::clone(&self.sse_content_accumulators),
1120            redis_manager: self.redis_manager.clone(),
1121            telemetry_webhook_config: Arc::clone(&self.telemetry_webhook_config),
1122        }
1123    }
1124}
1125
1126impl Clone for RedactionEngine {
1127    fn clone(&self) -> Self {
1128        RedactionEngine::new()
1129    }
1130}