llm_edge_agent/proxy.rs

//! Proxy request handler
//!
//! This module implements the request/response flow through all layers:
//! 1. Request validation and transformation
//! 2. Cache lookup (L1/L2)
//! 3. Shield/PII detection (if cache miss)
//! 4. Provider routing decision
//! 5. Provider request execution
//! 6. Response validation
//! 7. Cache write (async)
//! 8. Response transformation and return
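//!
//! A minimal sketch of how the handler below might be wired into an axum
//! router; the route path and the `app_state` value are illustrative
//! assumptions, not fixed by this module:
//!
//! ```rust,ignore
//! use axum::{routing::post, Router};
//! use std::sync::Arc;
//!
//! // `app_state` is assumed to be an already-constructed crate::integration::AppState.
//! let app: Router = Router::new()
//!     .route("/v1/chat/completions", post(crate::proxy::handle_chat_completions))
//!     .with_state(Arc::new(app_state));
//! ```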

use axum::{
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use llm_edge_cache::CacheLookupResult;
use llm_edge_monitoring::metrics;
use llm_edge_providers::{LLMProvider, UnifiedRequest, UnifiedResponse};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, error, info, instrument, warn};
use uuid::Uuid;

use crate::integration::AppState;

/// OpenAI-compatible chat completion request
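///
/// The body follows the OpenAI chat completions wire format, e.g.
/// `{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}`;
/// only the fields declared below are deserialized and unknown fields are
/// ignored.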
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default)]
    pub temperature: Option<f32>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub stream: bool,
}

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

/// OpenAI-compatible chat completion response
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Usage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<ResponseMetadata>,
}

#[derive(Debug, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: String,
}

#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Debug, Serialize)]
pub struct ResponseMetadata {
    pub provider: String,
    pub cached: bool,
    pub cache_tier: Option<String>,
    pub latency_ms: u64,
    pub cost_usd: Option<f64>,
}

/// Error type for proxy operations
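///
/// Variants are mapped to HTTP responses by the `IntoResponse` impl below:
/// validation errors become 400, provider errors 502, and cache or internal
/// errors 500, each serialized as a JSON `{"error": {...}}` body.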
#[derive(Debug)]
pub enum ProxyError {
    CacheError(String),
    ProviderError(String),
    ValidationError(String),
    InternalError(String),
}

impl IntoResponse for ProxyError {
    fn into_response(self) -> Response {
        let (status, message) = match self {
            ProxyError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg),
            ProxyError::ProviderError(msg) => (StatusCode::BAD_GATEWAY, msg),
            ProxyError::CacheError(msg) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Cache error: {}", msg),
            ),
            ProxyError::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
        };

        let body = serde_json::json!({
            "error": {
                "message": message,
                "type": "proxy_error",
            }
        });

        (status, Json(body)).into_response()
    }
}

/// Main chat completions proxy handler
///
/// This is the core handler that processes all chat completion requests.
/// It orchestrates the entire request flow through caching, routing, and provider layers.
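///
/// On success the response mirrors the OpenAI chat completion schema, plus a
/// `metadata` object describing the serving provider, whether the result was
/// cached (and at which tier), end-to-end latency, and estimated cost.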
#[instrument(name = "proxy_chat_completions", skip(state, request), fields(
    request_id = %Uuid::new_v4(),
    model = %request.model,
    message_count = request.messages.len(),
))]
pub async fn handle_chat_completions(
    State(state): State<Arc<AppState>>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, ProxyError> {
    let start_time = Instant::now();
    let request_id = Uuid::new_v4().to_string();

    info!(
        request_id = %request_id,
        model = %request.model,
        "Processing chat completion request"
    );

    // Step 1: Validate request
    validate_request(&request)?;

    // Step 2: Convert to cacheable format
    let cacheable_req = convert_to_cacheable(&request);

    // Step 3: Check cache (L1 -> L2)
    let cache_lookup = state.cache_manager.lookup(&cacheable_req).await;

    match cache_lookup {
        CacheLookupResult::L1Hit(cached_response) => {
            info!(request_id = %request_id, "Cache HIT: L1");
            metrics::record_cache_hit("l1");

            let response = build_response_from_cache(
                &request,
                &cached_response,
                "l1",
                start_time.elapsed().as_millis() as u64,
            );

            return Ok(Json(response));
        }
        CacheLookupResult::L2Hit(cached_response) => {
            info!(request_id = %request_id, "Cache HIT: L2");
            metrics::record_cache_hit("l2");

            let response = build_response_from_cache(
                &request,
                &cached_response,
                "l2",
                start_time.elapsed().as_millis() as u64,
            );

            return Ok(Json(response));
        }
        CacheLookupResult::Miss => {
            debug!(request_id = %request_id, "Cache MISS - routing to provider");
            metrics::record_cache_miss("all");
        }
    }

    // Step 4: Route to provider
    let (provider, provider_name) = select_provider(&state, &request)?;

    // Step 5: Convert to unified request format
    let unified_request = convert_to_unified(&request);

    // Step 6: Send to provider
    info!(
        request_id = %request_id,
        provider = %provider_name,
        "Sending request to provider"
    );

    let provider_start = Instant::now();
    let provider_response = provider.send(unified_request).await.map_err(|e| {
        error!(
            request_id = %request_id,
            provider = %provider_name,
            error = %e,
            "Provider request failed"
        );
        metrics::record_request_failure(&provider_name, &request.model, "provider_error");
        ProxyError::ProviderError(format!("Provider error: {}", e))
    })?;

    let provider_latency = provider_start.elapsed().as_millis() as u64;

    // Step 7: Calculate cost
    let cost_usd = calculate_cost(&provider, &request.model, &provider_response);

    // Step 8: Record metrics
    metrics::record_request_success(&provider_name, &request.model, provider_latency);
    metrics::record_token_usage(
        &provider_name,
        &request.model,
        provider_response.usage.prompt_tokens,
        provider_response.usage.completion_tokens,
    );
    if let Some(cost) = cost_usd {
        metrics::record_cost(&provider_name, &request.model, cost);
    }

    // Step 9: Store in cache (async, non-blocking)
    let cache_response = convert_provider_to_cache(&provider_response);
    tokio::spawn({
        let cache_manager = state.cache_manager.clone();
        let cacheable_req = cacheable_req.clone();
        async move {
            cache_manager.store(&cacheable_req, cache_response).await;
        }
    });

    // Step 10: Build and return response
    let total_latency = start_time.elapsed().as_millis() as u64;
    let response = build_response_from_provider(
        &request,
        provider_response,
        &provider_name,
        total_latency,
        cost_usd,
    );

    info!(
        request_id = %request_id,
        provider = %provider_name,
        total_latency_ms = total_latency,
        provider_latency_ms = provider_latency,
        "Request completed successfully"
    );

    Ok(Json(response))
}

/// Validate the incoming request
fn validate_request(request: &ChatCompletionRequest) -> Result<(), ProxyError> {
    if request.model.is_empty() {
        return Err(ProxyError::ValidationError("Model is required".to_string()));
    }

    if request.messages.is_empty() {
        return Err(ProxyError::ValidationError(
            "Messages cannot be empty".to_string(),
        ));
    }

    if request.stream {
        return Err(ProxyError::ValidationError(
            "Streaming is not yet supported".to_string(),
        ));
    }

    Ok(())
}

/// Convert chat completion request to cacheable format
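///
/// Messages are flattened into a single `role: content` prompt joined by
/// newlines, so a user "Hello" followed by an assistant "Hi" keys the cache
/// as "user: Hello\nassistant: Hi".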
fn convert_to_cacheable(request: &ChatCompletionRequest) -> llm_edge_cache::key::CacheableRequest {
    // Concatenate all messages into a single prompt for caching
    let prompt = request
        .messages
        .iter()
        .map(|m| format!("{}: {}", m.role, m.content))
        .collect::<Vec<_>>()
        .join("\n");

    let mut cacheable = llm_edge_cache::key::CacheableRequest::new(&request.model, prompt);

    if let Some(temp) = request.temperature {
        cacheable = cacheable.with_temperature(temp);
    }

    if let Some(max_tokens) = request.max_tokens {
        cacheable = cacheable.with_max_tokens(max_tokens);
    }

    cacheable
}

/// Convert chat completion request to unified format
fn convert_to_unified(request: &ChatCompletionRequest) -> UnifiedRequest {
    use std::collections::HashMap;

    UnifiedRequest {
        model: request.model.clone(),
        messages: request
            .messages
            .iter()
            .map(|m| llm_edge_providers::Message {
                role: m.role.clone(),
                content: m.content.clone(),
            })
            .collect(),
        temperature: request.temperature,
        max_tokens: request.max_tokens.map(|t| t as usize),
        stream: request.stream,
        metadata: HashMap::new(),
    }
}

/// Select the appropriate provider for the request
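///
/// Routing is a simple substring match on the lowercased model name: models
/// containing "gpt" or "openai" go to the OpenAI provider, "claude" or
/// "anthropic" to the Anthropic provider, and anything else falls back to
/// whichever provider is configured.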
fn select_provider(
    state: &AppState,
    request: &ChatCompletionRequest,
) -> Result<(Arc<dyn LLMProvider>, String), ProxyError> {
    // For MVP, use simple model-based routing
    // In production, this would use the routing engine

    let model_lower = request.model.to_lowercase();

    if model_lower.contains("gpt") || model_lower.contains("openai") {
        if let Some(provider) = &state.openai_provider {
            return Ok((provider.clone(), "openai".to_string()));
        }
    }

    if model_lower.contains("claude") || model_lower.contains("anthropic") {
        if let Some(provider) = &state.anthropic_provider {
            return Ok((provider.clone(), "anthropic".to_string()));
        }
    }

    // Fallback to first available provider
    if let Some(provider) = &state.openai_provider {
        warn!("Using fallback provider: openai");
        return Ok((provider.clone(), "openai".to_string()));
    }

    if let Some(provider) = &state.anthropic_provider {
        warn!("Using fallback provider: anthropic");
        return Ok((provider.clone(), "anthropic".to_string()));
    }

    Err(ProxyError::InternalError(
        "No providers configured".to_string(),
    ))
}

/// Calculate the cost of a request
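///
/// Pricing is expressed per 1K tokens, so the total is
/// `prompt_tokens / 1000 * input_cost_per_1k + completion_tokens / 1000 * output_cost_per_1k`.
/// For example, at an assumed $0.03 / 1K input and $0.06 / 1K output, a response
/// using 1,000 prompt and 500 completion tokens costs 0.03 + 0.03 = $0.06.
/// Returns `None` when the provider has no pricing for the model.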
fn calculate_cost(
    provider: &Arc<dyn LLMProvider>,
    model: &str,
    response: &UnifiedResponse,
) -> Option<f64> {
    provider.get_pricing(model).map(|pricing| {
        let input_cost = (response.usage.prompt_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
        let output_cost =
            (response.usage.completion_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
        input_cost + output_cost
    })
}

/// Build response from cached data
fn build_response_from_cache(
    request: &ChatCompletionRequest,
    cached: &llm_edge_cache::l1::CachedResponse,
    cache_tier: &str,
    latency_ms: u64,
) -> ChatCompletionResponse {
    ChatCompletionResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4()),
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: request.model.clone(),
        choices: vec![ChatChoice {
            index: 0,
            message: ChatMessage {
                role: "assistant".to_string(),
                content: cached.content.clone(),
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens: cached.tokens.as_ref().map(|t| t.prompt_tokens).unwrap_or(0),
            completion_tokens: cached
                .tokens
                .as_ref()
                .map(|t| t.completion_tokens)
                .unwrap_or(0),
            total_tokens: cached.tokens.as_ref().map(|t| t.total_tokens).unwrap_or(0),
        },
        metadata: Some(ResponseMetadata {
            provider: "cache".to_string(),
            cached: true,
            cache_tier: Some(cache_tier.to_string()),
            latency_ms,
            cost_usd: Some(0.0), // Cached responses have zero cost
        }),
    }
}

/// Convert provider response to cache format
fn convert_provider_to_cache(response: &UnifiedResponse) -> llm_edge_cache::l1::CachedResponse {
    let content = response
        .choices
        .first()
        .map(|c| c.message.content.clone())
        .unwrap_or_default();

    llm_edge_cache::l1::CachedResponse {
        content,
        tokens: Some(llm_edge_cache::l1::TokenUsage {
            prompt_tokens: response.usage.prompt_tokens as u32,
            completion_tokens: response.usage.completion_tokens as u32,
            total_tokens: response.usage.total_tokens as u32,
        }),
        model: response.model.clone(),
        cached_at: chrono::Utc::now().timestamp(),
    }
}

/// Build response from provider data
fn build_response_from_provider(
    request: &ChatCompletionRequest,
    provider_response: UnifiedResponse,
    provider_name: &str,
    latency_ms: u64,
    cost_usd: Option<f64>,
) -> ChatCompletionResponse {
    ChatCompletionResponse {
        id: provider_response.id,
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: request.model.clone(),
        choices: provider_response
            .choices
            .into_iter()
            .map(|c| ChatChoice {
                index: c.index as u32,
                message: ChatMessage {
                    role: c.message.role,
                    content: c.message.content,
                },
                finish_reason: c.finish_reason.unwrap_or_else(|| "stop".to_string()),
            })
            .collect(),
        usage: Usage {
            prompt_tokens: provider_response.usage.prompt_tokens as u32,
            completion_tokens: provider_response.usage.completion_tokens as u32,
            total_tokens: provider_response.usage.total_tokens as u32,
        },
        metadata: Some(ResponseMetadata {
            provider: provider_name.to_string(),
            cached: false,
            cache_tier: None,
            latency_ms,
            cost_usd,
        }),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_request_valid() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: Some(0.7),
            max_tokens: Some(100),
            stream: false,
        };

        assert!(validate_request(&request).is_ok());
    }

    #[test]
    fn test_validate_request_empty_model() {
        let request = ChatCompletionRequest {
            model: "".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: None,
            max_tokens: None,
            stream: false,
        };

        assert!(validate_request(&request).is_err());
    }

    #[test]
    fn test_validate_request_empty_messages() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            stream: false,
        };

        assert!(validate_request(&request).is_err());
    }
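
    // Streaming requests are rejected by validate_request for now; cover that
    // path explicitly so the behaviour is pinned by a test.
    #[test]
    fn test_validate_request_stream_rejected() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: None,
            max_tokens: None,
            stream: true,
        };

        assert!(validate_request(&request).is_err());
    }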

    #[test]
    fn test_convert_to_cacheable() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                ChatMessage {
                    role: "user".to_string(),
                    content: "Hello".to_string(),
                },
                ChatMessage {
                    role: "assistant".to_string(),
                    content: "Hi".to_string(),
                },
            ],
            temperature: Some(0.7),
            max_tokens: Some(100),
            stream: false,
        };

        let cacheable = convert_to_cacheable(&request);
        assert_eq!(cacheable.model, "gpt-4");
        assert_eq!(cacheable.temperature, Some(0.7));
        assert_eq!(cacheable.max_tokens, Some(100));
    }
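
    // Sanity check for convert_to_unified; it only reads the UnifiedRequest
    // fields this module already populates, so no provider is needed.
    #[test]
    fn test_convert_to_unified() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: Some(0.2),
            max_tokens: Some(64),
            stream: false,
        };

        let unified = convert_to_unified(&request);
        assert_eq!(unified.model, "gpt-4");
        assert_eq!(unified.messages.len(), 1);
        assert_eq!(unified.messages[0].role, "user");
        assert_eq!(unified.max_tokens, Some(64));
        assert!(!unified.stream);
    }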
}