ai_proxy/
server.rs

1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use hyper::service::{make_service_fn, service_fn};
7use hyper::{Body, Client, Method, Request, Response, Server, StatusCode};
8use hyper_tls::HttpsConnector;
9use serde_json::Value;
10use tokio::sync::{Mutex, RwLock};
11use tracing::{info, warn};
12
13use crate::config::ConfigManager;
14use crate::redaction::{RedactionEngine, RedactionService};
15use crate::{Error, Result};
16
/// Text buffered for one streamed (SSE) response.
///
/// `buffer` holds streamed text that has been received but not yet emitted
/// downstream; it is appended to by `process_chunk_with_buffer` and drained by
/// `flush_buffer`.
#[derive(Debug, Clone, Default)]
struct SseAccumulator {
    /// Pending text that has not been forwarded to the client yet.
    buffer: String,
}
21
22pub struct ProxyServer {
23    port: u16,
24    request_count: Arc<Mutex<u64>>,
25    redaction_engine: RedactionEngine,
26    redaction_service: RedactionService,
27    config_manager: Arc<RwLock<ConfigManager>>,
28    client: Client<HttpsConnector<hyper::client::HttpConnector>>,
29    sse_content_accumulators: Arc<RwLock<HashMap<u64, SseAccumulator>>>,
30}
31
32impl ProxyServer {
33    pub async fn new(port: u16, config_path: Option<String>, redaction_api_url: Option<String>) -> Result<Self> {
34        let https = HttpsConnector::new();
35        let client = Client::builder().build::<_, hyper::Body>(https);
36        
37        let mut config_manager = ConfigManager::new_with_path(config_path);
38        if let Err(e) = config_manager.load_config().await {
39            warn!("Warning: Could not load config file, using defaults: {}", e);
40        }
41
42        Ok(Self {
43            port,
44            request_count: Arc::new(Mutex::new(0)),
45            redaction_engine: RedactionEngine::new(),
46            redaction_service: RedactionService::new(redaction_api_url),
47            config_manager: Arc::new(RwLock::new(config_manager)),
48            client,
49            sse_content_accumulators: Arc::new(RwLock::new(HashMap::new())),
50        })
51    }
52
53    pub async fn start(&self) -> Result<()> {
54        let addr = SocketAddr::from(([0, 0, 0, 0], self.port));
55        
56        let server_clone = Arc::new(self.clone());
57        let make_svc = make_service_fn(move |_conn| {
58            let server = Arc::clone(&server_clone);
59            async move {
60                Ok::<_, Infallible>(service_fn(move |req| {
61                    let server = Arc::clone(&server);
62                    async move { server.handle_request(req).await }
63                }))
64            }
65        });
66
67        let server = Server::bind(&addr).serve(make_svc);
68        info!("Proxy server listening on http://{}", addr);
69
70        if let Err(e) = server.await {
71            return Err(Error::Server(e.to_string()));
72        }
73
74        Ok(())
75    }
76
77    pub async fn stop(&self) {
78        info!("Stopping proxy server...");
79    }
80
81    async fn handle_request(&self, req: Request<Body>) -> std::result::Result<Response<Body>, Infallible> {
82        let result = self.handle_request_inner(req).await;
83        Ok(result.unwrap_or_else(|e| {
84            Response::builder()
85                .status(StatusCode::INTERNAL_SERVER_ERROR)
86                .body(Body::from(format!("Internal Server Error: {}", e)))
87                .unwrap()
88        }))
89    }
90
91    async fn redact_content(&self, content: &mut Value) -> Result<Value> {
92        match content {
93            Value::String(text) => {
94                let redacted = self.redaction_service.redact_user_prompt(text).await?;
95                Ok(Value::String(redacted))
96            }
97            Value::Array(content_blocks) => {
98                let mut redacted_blocks = Vec::new();
99                for block in content_blocks.iter() {
100                    if let Some(block_obj) = block.as_object() {
101                        if let Some(block_type) = block_obj.get("type") {
102                            if block_type == "text" {
103                                if let Some(text) = block_obj.get("text").and_then(|t| t.as_str()) {
104                                    let redacted_text = self.redaction_service.redact_user_prompt(text).await?;
105                                    let mut new_block = block_obj.clone();
106                                    new_block.insert("text".to_string(), Value::String(redacted_text));
107                                    redacted_blocks.push(Value::Object(new_block));
108                                } else {
109                                    redacted_blocks.push(block.clone());
110                                }
111                            } else {
112                                // Non-text blocks (like tool_result) pass through unchanged
113                                redacted_blocks.push(block.clone());
114                            }
115                        } else {
116                            redacted_blocks.push(block.clone());
117                        }
118                    } else {
119                        redacted_blocks.push(block.clone());
120                    }
121                }
122                Ok(Value::Array(redacted_blocks))
123            }
124            _ => Ok(content.clone()),
125        }
126    }
127
128    async fn handle_request_inner(&self, req: Request<Body>) -> Result<Response<Body>> {
129        // Health check endpoint
130        if req.uri().path() == "/health" && req.method() == Method::GET {
131            let request_count = self.request_count.lock().await;
132            let count = *request_count;
133            drop(request_count);
134
135            let health_data = serde_json::json!({
136                "status": "healthy",
137                "uptime": std::time::SystemTime::now()
138                    .duration_since(std::time::UNIX_EPOCH)
139                    .unwrap()
140                    .as_secs(),
141                "timestamp": chrono::Utc::now(),
142                "requestCount": count
143            });
144
145            return Ok(Response::builder()
146                .status(StatusCode::OK)
147                .header("content-type", "application/json")
148                .body(Body::from(health_data.to_string()))?);
149        }
150
151        let mut request_count = self.request_count.lock().await;
152        *request_count += 1;
153        let request_id = *request_count;
154        drop(request_count);
155
156        // Capture request body
157        let (parts, body) = req.into_parts();
158        let body_bytes = hyper::body::to_bytes(body).await?;
159        let request_body = String::from_utf8_lossy(&body_bytes);
160
161
162        self.process_request_with_config(request_body.to_string(), parts, request_id).await
163    }
164
165    async fn process_request_with_config(
166        &self,
167        request_body: String,
168        parts: http::request::Parts,
169        request_id: u64,
170    ) -> Result<Response<Body>> {
171        // Extract model name from request body and process user message redaction
172        let mut model_name = None;
173        let processed_request_body = if !request_body.is_empty() {
174            if let Ok(mut parsed_body) = serde_json::from_str::<Value>(&request_body) {
175                if let Some(model) = parsed_body.get("model") {
176                    if let Some(model_str) = model.as_str() {
177                        model_name = Some(model_str.to_string());
178                    }
179                }
180
181                // Process user messages for redaction
182                if let Some(messages) = parsed_body.get_mut("messages") {
183                    if let Some(messages_array) = messages.as_array_mut() {
184                        for message in messages_array.iter_mut() {
185                            if let Some(message_obj) = message.as_object_mut() {
186                                if let Some(role) = message_obj.get("role") {
187                                    if role == "user" {
188                                        if let Some(content) = message_obj.get_mut("content") {
189                                            match self.redact_content(content).await {
190                                                Ok(redacted_content) => {
191                                                    *content = redacted_content;
192                                                }
193                                                Err(e) => {
194                                                    warn!("Failed to redact user message: {}", e);
195                                                    // Continue with original content on redaction failure
196                                                }
197                                            }
198                                        }
199                                    }
200                                }
201                            }
202                        }
203                    }
204                }
205
206                serde_json::to_string(&parsed_body).unwrap_or(request_body)
207            } else {
208                request_body
209            }
210        } else {
211            request_body
212        };
213
214        // Parse the target URL - handle relative URLs by prepending API base
215        let target_url = if parts.uri.path().starts_with('/') {
216            // Use model from config - always require model to be specified
217            let model = model_name.ok_or_else(|| {
218                Error::Server("Bad Request: Model must be specified in request body".to_string())
219            })?;
220            
221            let config_manager = self.config_manager.read().await;
222            let model_config = config_manager.get_model_config(&model);
223            let base_url = if model_config.api_base.ends_with('/') {
224                model_config.api_base
225            } else {
226                format!("{}/", model_config.api_base)
227            };
228            
229            
230            let path = if parts.uri.path().starts_with('/') {
231                &parts.uri.path()[1..]
232            } else {
233                parts.uri.path()
234            };
235            
236            let final_url = format!("{}{}{}", base_url, path, 
237                parts.uri.query().map_or(String::new(), |q| format!("?{}", q)));
238            
239            final_url
240        } else {
241            // Absolute URL
242            parts.uri.to_string()
243        };
244
245        // Build the proxied request
246        let reqwest_method = match parts.method.as_str() {
247            "GET" => reqwest::Method::GET,
248            "POST" => reqwest::Method::POST,
249            "PUT" => reqwest::Method::PUT,
250            "DELETE" => reqwest::Method::DELETE,
251            "PATCH" => reqwest::Method::PATCH,
252            "HEAD" => reqwest::Method::HEAD,
253            "OPTIONS" => reqwest::Method::OPTIONS,
254            _ => reqwest::Method::GET, // fallback
255        };
256
257        let mut req_builder = reqwest::Client::new()
258            .request(reqwest_method, &target_url);
259
260        // Copy headers (clean up proxy-specific ones)
261        for (name, value) in parts.headers.iter() {
262            if !matches!(name.as_str(), "host" | "proxy-connection" | "proxy-authorization") {
263                if let Ok(header_value) = value.to_str() {
264                    req_builder = req_builder.header(name.as_str(), header_value);
265                }
266            }
267        }
268
269        // Add request body (using processed body with redacted user messages)
270        if !processed_request_body.is_empty() {
271            req_builder = req_builder.body(processed_request_body.clone());
272        }
273
274        // Make the proxied request
275        let proxy_response = req_builder.send().await?;
276        
277
278        // Check if this is an SSE response
279        let is_sse = proxy_response
280            .headers()
281            .get("content-type")
282            .and_then(|v| v.to_str().ok())
283            .map_or(false, |ct| ct.contains("text/event-stream"));
284
285        if is_sse {
286            self.handle_sse_response(proxy_response, &target_url, request_id).await
287        } else {
288            // Handle regular responses
289            let status = proxy_response.status();
290            let response_body = proxy_response.text().await?;
291            let redacted_response = self.redaction_engine.redact_sensitive_content(&response_body);
292            
293
294            Ok(Response::builder()
295                .status(status)
296                .header("content-type", "application/json")
297                .body(Body::from(redacted_response))?)
298        }
299    }
300
301    async fn handle_sse_response(
302        &self,
303        proxy_response: reqwest::Response,
304        target_url: &str,
305        request_id: u64,
306    ) -> Result<Response<Body>> {
307        use futures::StreamExt;
308        
309
310        // Initialize SSE accumulator for this request
311        self.sse_content_accumulators.write().await.insert(
312            request_id,
313            SseAccumulator {
314                buffer: String::new(),
315            },
316        );
317
318        // Detect if this is OpenAI format
319        let is_openai_format = target_url.contains("openai.com") || target_url.contains("/responses");
320
321        // Create streaming body
322        let (tx, rx) = tokio::sync::mpsc::channel::<std::result::Result<bytes::Bytes, Box<dyn std::error::Error + Send + Sync>>>(100);
323
324        // Clone necessary data for the async task
325        let redaction_engine = self.redaction_engine.clone();
326        let sse_accumulators = Arc::clone(&self.sse_content_accumulators);
327
328        // Spawn task to process SSE stream
329        tokio::spawn(async move {
330            let mut event_buffer = String::new();
331            let mut stream = proxy_response.bytes_stream();
332
333            while let Some(chunk_result) = stream.next().await {
334                match chunk_result {
335                    Ok(chunk) => {
336                        let chunk_str = String::from_utf8_lossy(&chunk);
337                        event_buffer.push_str(&chunk_str);
338
339                        // Split by double newlines to get complete events
340                        let buffer_clone = event_buffer.clone();
341                        let events: Vec<&str> = buffer_clone.split("\n\n").collect();
342                        
343                        if events.len() > 1 {
344                            // Keep the last (potentially incomplete) event in buffer
345                            event_buffer = events.last().unwrap_or(&"").to_string();
346
347                            // Process complete events
348                            for event_data in &events[..events.len() - 1] {
349                                if !event_data.trim().is_empty() {
350                                    if let Some(processed_event) = Self::process_sse_event(
351                                        event_data,
352                                        request_id,
353                                        is_openai_format,
354                                        &redaction_engine,
355                                        &sse_accumulators,
356                                    ).await {
357                                        let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
358                                        if tx.send(Ok(event_bytes)).await.is_err() {
359                                            break;
360                                        }
361                                    }
362                                }
363                            }
364                        }
365                    }
366                    Err(e) => {
367                        let _ = tx.send(Err(Box::new(e) as Box<dyn std::error::Error + Send + Sync>)).await;
368                        break;
369                    }
370                }
371            }
372
373            // Handle any remaining buffered event
374            if !event_buffer.trim().is_empty() {
375                if let Some(processed_event) = Self::process_sse_event(
376                    &event_buffer,
377                    request_id,
378                    is_openai_format,
379                    &redaction_engine,
380                    &sse_accumulators,
381                ).await {
382                    let event_bytes = bytes::Bytes::from(format!("{}\n\n", processed_event));
383                    let _ = tx.send(Ok(event_bytes)).await;
384                }
385            }
386
387            // Clean up
388            sse_accumulators.write().await.remove(&request_id);
389        });
390
391        // Create streaming response
392        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
393        let body = Body::wrap_stream(stream.map(|item| {
394            item.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
395        }));
396
397        Ok(Response::builder()
398            .status(StatusCode::OK)
399            .header("content-type", "text/event-stream")
400            .header("cache-control", "no-cache")
401            .header("connection", "keep-alive")
402            .body(body)?)
403    }
404
405    async fn process_sse_event(
406        event_data: &str,
407        request_id: u64,
408        is_openai_format: bool,
409        redaction_engine: &RedactionEngine,
410        sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
411    ) -> Option<String> {
412        let lines: Vec<&str> = event_data.split('\n').collect();
413        let mut event_type = String::new();
414        let mut event_data_json: Option<serde_json::Value> = None;
415
416        // Parse SSE event
417        for line in lines {
418            if let Some(event_value) = line.strip_prefix("event: ") {
419                event_type = event_value.to_string();
420            } else if let Some(data_value) = line.strip_prefix("data: ") {
421                if let Ok(data) = serde_json::from_str::<serde_json::Value>(data_value) {
422                    event_data_json = Some(data);
423                } else {
424                    // Handle non-JSON data (like [DONE])
425                    if data_value == "[DONE]" {
426                        return Some(event_data.to_string());
427                    }
428                }
429            }
430        }
431
432        // Extract text content for redaction
433        let mut should_redact_and_forward = false;
434        let mut text_content = String::new();
435
436        if let Some(data) = &event_data_json {
437            
438            if is_openai_format {
439                // Handle OpenAI streaming format
440                if event_type == "response.output_text.delta" {
441                    if let Some(delta) = data.get("delta").and_then(|d| d.as_str()) {
442                        text_content = delta.to_string();
443                        should_redact_and_forward = true;
444                    }
445                } else if event_type == "response.output_text.done" {
446                    if let Some(text) = data.get("text").and_then(|t| t.as_str()) {
447                        text_content = text.to_string();
448                        should_redact_and_forward = true;
449                    }
450                }
451            } else {
452                // Handle Claude streaming format  
453                if event_type == "content_block_delta" {
454                    if let Some(delta) = data.get("delta") {
455                        if let Some(text) = delta.get("text").and_then(|t| t.as_str()) {
456                            text_content = text.to_string();
457                            should_redact_and_forward = true;
458                        }
459                    }
460                }
461            }
462            
463            if !should_redact_and_forward {
464            }
465        }
466
467        if should_redact_and_forward && !text_content.is_empty() {
468            
469            // Apply redaction with buffering (same as Node.js)
470            let redacted_content = if is_openai_format && event_type == "response.output_text.done" {
471                // For complete text events, apply redaction directly
472                let redacted = redaction_engine.redact_sensitive_content(&text_content);
473                Some(redacted)
474            } else {
475                // For streaming deltas, use buffering
476                let result = Self::process_chunk_with_buffer(&text_content, request_id, redaction_engine, sse_accumulators).await;
477                result
478            };
479
480            if let Some(redacted) = redacted_content {
481                if let Some(mut data) = event_data_json {
482                    if is_openai_format {
483                        // Update OpenAI format
484                        if event_type == "response.output_text.delta" {
485                            data["delta"] = serde_json::Value::String(redacted);
486                        } else if event_type == "response.output_text.done" {
487                            data["text"] = serde_json::Value::String(redacted);
488                        }
489                    } else {
490                        // Update Claude format
491                        if let Some(delta) = data.get_mut("delta") {
492                            delta["text"] = serde_json::Value::String(redacted);
493                        }
494                    }
495
496                    return Some(format!("event: {}\ndata: {}", event_type, data));
497                }
498            } else {
499                // Don't forward this event since we're buffering
500                return None;
501            }
502        } else if event_type == "content_block_stop" && !is_openai_format {
503            // Claude-specific: Flush any remaining buffer content
504            if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
505                let final_event_data = serde_json::json!({
506                    "type": "content_block_delta",
507                    "index": 0,
508                    "delta": {
509                        "type": "text_delta",
510                        "text": final_chunk
511                    }
512                });
513                return Some(format!("event: content_block_delta\ndata: {}\n\nevent: content_block_stop\ndata: {}", 
514                    final_event_data, event_data_json.unwrap_or(serde_json::Value::Null)));
515            }
516        } else if is_openai_format && (event_type == "response.completed" || event_data.contains("[DONE]")) {
517            // OpenAI-specific: Handle end of stream and flush buffer
518            if let Some(final_chunk) = Self::flush_buffer(request_id, redaction_engine, sse_accumulators).await {
519                let final_data = serde_json::json!({
520                    "type": "response.output_text.delta",
521                    "delta": final_chunk
522                });
523                return Some(format!("event: response.output_text.delta\ndata: {}\n\n{}", 
524                    final_data, event_data));
525            }
526        }
527
528        // Forward all other events as-is
529        Some(event_data.to_string())
530    }
531
532    async fn process_chunk_with_buffer(
533        new_chunk: &str,
534        request_id: u64,
535        redaction_engine: &RedactionEngine,
536        sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
537    ) -> Option<String> {
538        let mut accumulators = sse_accumulators.write().await;
539        if let Some(accumulator) = accumulators.get_mut(&request_id) {
540            // Add new chunk to buffer
541            accumulator.buffer.push_str(new_chunk);
542
543            // Split by lines and process complete lines
544            let buffer_copy = accumulator.buffer.clone();
545            let mut lines: Vec<&str> = buffer_copy.split('\n').collect();
546            
547            // Keep the last (potentially incomplete) line in buffer (like Node.js pop())
548            accumulator.buffer = lines.pop().unwrap_or("").to_string();
549            
550            // Process complete lines
551            if !lines.is_empty() {
552                let complete_lines = lines.join("\n") + "\n";
553                let redacted = redaction_engine.redact_sensitive_content(&complete_lines);
554                return Some(redacted);
555            }
556            
557            // No complete lines yet, don't send anything
558            return None;
559        }
560        None
561    }
562
563    async fn flush_buffer(
564        request_id: u64,
565        redaction_engine: &RedactionEngine,
566        sse_accumulators: &Arc<RwLock<HashMap<u64, SseAccumulator>>>,
567    ) -> Option<String> {
568        let mut accumulators = sse_accumulators.write().await;
569        if let Some(accumulator) = accumulators.get_mut(&request_id) {
570            if accumulator.buffer.is_empty() {
571                return None;
572            }
573            
574            // Process remaining buffer content
575            let remaining = accumulator.buffer.clone();
576            accumulator.buffer.clear();
577            
578            let redacted = redaction_engine.redact_sensitive_content(&remaining);
579            Some(redacted)
580        } else {
581            None
582        }
583    }
584}
585
586impl Clone for ProxyServer {
587    fn clone(&self) -> Self {
588        let https = HttpsConnector::new();
589        let client = Client::builder().build::<_, hyper::Body>(https);
590        
591        Self {
592            port: self.port,
593            request_count: Arc::clone(&self.request_count),
594            redaction_engine: self.redaction_engine.clone(),
595            redaction_service: self.redaction_service.clone(),
596            config_manager: Arc::clone(&self.config_manager),
597            client,
598            sse_content_accumulators: Arc::clone(&self.sse_content_accumulators),
599        }
600    }
601}
602
603impl Clone for RedactionEngine {
604    fn clone(&self) -> Self {
605        RedactionEngine::new()
606    }
607}