use axum::{
    extract::State,
    http::StatusCode,
    response::{IntoResponse, Response},
    Json,
};
use llm_edge_cache::CacheLookupResult;
use llm_edge_monitoring::metrics;
use llm_edge_providers::{LLMProvider, UnifiedRequest, UnifiedResponse};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, error, info, instrument, warn};
use uuid::Uuid;

use crate::integration::AppState;

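/// Incoming chat completion request accepted by the proxy. The sampling
/// parameters are optional and `stream` defaults to `false`.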
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionRequest {
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default)]
    pub temperature: Option<f32>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub stream: bool,
}

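/// A single conversation message: a role (e.g. "user" or "assistant") and its
/// text content.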
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

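/// Chat completion response returned to the client. The optional `metadata`
/// field carries proxy-specific details and is omitted from the JSON when
/// `None`.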
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: i64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Usage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<ResponseMetadata>,
}

#[derive(Debug, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: String,
}

#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

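/// Proxy-added metadata describing how the request was served: which provider
/// (or the cache) answered it, the cache tier on a hit, end-to-end latency,
/// and the estimated cost.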
#[derive(Debug, Serialize)]
pub struct ResponseMetadata {
    pub provider: String,
    pub cached: bool,
    pub cache_tier: Option<String>,
    pub latency_ms: u64,
    pub cost_usd: Option<f64>,
}

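/// Errors surfaced by the proxy handler, mapped to HTTP responses by the
/// `IntoResponse` impl below.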
#[derive(Debug)]
pub enum ProxyError {
    CacheError(String),
    ProviderError(String),
    ValidationError(String),
    InternalError(String),
}

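// Map each error variant to an HTTP status code and a JSON error body.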
impl IntoResponse for ProxyError {
    fn into_response(self) -> Response {
        let (status, message) = match self {
            ProxyError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg),
            ProxyError::ProviderError(msg) => (StatusCode::BAD_GATEWAY, msg),
            ProxyError::CacheError(msg) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("Cache error: {}", msg),
            ),
            ProxyError::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg),
        };

        let body = serde_json::json!({
            "error": {
                "message": message,
                "type": "proxy_error",
            }
        });

        (status, Json(body)).into_response()
    }
}

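/// Handle a chat completion request.
///
/// Flow: validate the request, look it up in the tiered cache (L1/L2) and
/// return immediately on a hit; on a miss, route the request to a provider,
/// record metrics and estimated cost, store the fresh response in the cache
/// asynchronously, and return the provider's answer with proxy metadata.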
#[instrument(name = "proxy_chat_completions", skip(state, request), fields(
    request_id = %Uuid::new_v4(),
    model = %request.model,
    message_count = request.messages.len(),
))]
pub async fn handle_chat_completions(
    State(state): State<Arc<AppState>>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<Json<ChatCompletionResponse>, ProxyError> {
    let start_time = Instant::now();
    let request_id = Uuid::new_v4().to_string();

    info!(
        request_id = %request_id,
        model = %request.model,
        "Processing chat completion request"
    );

    validate_request(&request)?;

    let cacheable_req = convert_to_cacheable(&request);

    let cache_lookup = state.cache_manager.lookup(&cacheable_req).await;

    match cache_lookup {
        CacheLookupResult::L1Hit(cached_response) => {
            info!(request_id = %request_id, "Cache HIT: L1");
            metrics::record_cache_hit("l1");

            let response = build_response_from_cache(
                &request,
                &cached_response,
                "l1",
                start_time.elapsed().as_millis() as u64,
            );

            return Ok(Json(response));
        }
        CacheLookupResult::L2Hit(cached_response) => {
            info!(request_id = %request_id, "Cache HIT: L2");
            metrics::record_cache_hit("l2");

            let response = build_response_from_cache(
                &request,
                &cached_response,
                "l2",
                start_time.elapsed().as_millis() as u64,
            );

            return Ok(Json(response));
        }
        CacheLookupResult::Miss => {
            debug!(request_id = %request_id, "Cache MISS - routing to provider");
            metrics::record_cache_miss("all");
        }
    }

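    // Cache miss: route the request to an upstream provider.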
    let (provider, provider_name) = select_provider(&state, &request)?;

    let unified_request = convert_to_unified(&request);

    info!(
        request_id = %request_id,
        provider = %provider_name,
        "Sending request to provider"
    );

    let provider_start = Instant::now();
    let provider_response = provider.send(unified_request).await.map_err(|e| {
        error!(
            request_id = %request_id,
            provider = %provider_name,
            error = %e,
            "Provider request failed"
        );
        metrics::record_request_failure(&provider_name, &request.model, "provider_error");
        ProxyError::ProviderError(format!("Provider error: {}", e))
    })?;

    let provider_latency = provider_start.elapsed().as_millis() as u64;

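    // Estimate the request cost and record success, token, and cost metrics.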
    let cost_usd = calculate_cost(&provider, &request.model, &provider_response);

    metrics::record_request_success(&provider_name, &request.model, provider_latency);
    metrics::record_token_usage(
        &provider_name,
        &request.model,
        provider_response.usage.prompt_tokens,
        provider_response.usage.completion_tokens,
    );
    if let Some(cost) = cost_usd {
        metrics::record_cost(&provider_name, &request.model, cost);
    }

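    // Write the fresh response to the cache in the background so the client
    // response is not delayed.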
    let cache_response = convert_provider_to_cache(&provider_response);
    tokio::spawn({
        let cache_manager = state.cache_manager.clone();
        let cacheable_req = cacheable_req.clone();
        async move {
            cache_manager.store(&cacheable_req, cache_response).await;
        }
    });

    let total_latency = start_time.elapsed().as_millis() as u64;
    let response = build_response_from_provider(
        &request,
        provider_response,
        &provider_name,
        total_latency,
        cost_usd,
    );

    info!(
        request_id = %request_id,
        provider = %provider_name,
        total_latency_ms = total_latency,
        provider_latency_ms = provider_latency,
        "Request completed successfully"
    );

    Ok(Json(response))
}

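/// Reject requests with an empty model name, no messages, or `stream: true`
/// (streaming is not supported yet).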
fn validate_request(request: &ChatCompletionRequest) -> Result<(), ProxyError> {
    if request.model.is_empty() {
        return Err(ProxyError::ValidationError("Model is required".to_string()));
    }

    if request.messages.is_empty() {
        return Err(ProxyError::ValidationError(
            "Messages cannot be empty".to_string(),
        ));
    }

    if request.stream {
        return Err(ProxyError::ValidationError(
            "Streaming is not yet supported".to_string(),
        ));
    }

    Ok(())
}

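/// Build the cache key request: flatten the messages into a single
/// `role: content` prompt and carry over the model, temperature, and
/// max_tokens.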
fn convert_to_cacheable(request: &ChatCompletionRequest) -> llm_edge_cache::key::CacheableRequest {
    let prompt = request
        .messages
        .iter()
        .map(|m| format!("{}: {}", m.role, m.content))
        .collect::<Vec<_>>()
        .join("\n");

    let mut cacheable = llm_edge_cache::key::CacheableRequest::new(&request.model, prompt);

    if let Some(temp) = request.temperature {
        cacheable = cacheable.with_temperature(temp);
    }

    if let Some(max_tokens) = request.max_tokens {
        cacheable = cacheable.with_max_tokens(max_tokens);
    }

    cacheable
}

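/// Convert the incoming request into the provider-agnostic `UnifiedRequest`.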
fn convert_to_unified(request: &ChatCompletionRequest) -> UnifiedRequest {
    use std::collections::HashMap;

    UnifiedRequest {
        model: request.model.clone(),
        messages: request
            .messages
            .iter()
            .map(|m| llm_edge_providers::Message {
                role: m.role.clone(),
                content: m.content.clone(),
            })
            .collect(),
        temperature: request.temperature,
        max_tokens: request.max_tokens.map(|t| t as usize),
        stream: request.stream,
        metadata: HashMap::new(),
    }
}

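/// Pick a provider based on the model name (GPT models go to OpenAI, Claude
/// models to Anthropic), falling back to whichever provider is configured.
/// Errors if no provider is available.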
fn select_provider(
    state: &AppState,
    request: &ChatCompletionRequest,
) -> Result<(Arc<dyn LLMProvider>, String), ProxyError> {
    let model_lower = request.model.to_lowercase();

    if model_lower.contains("gpt") || model_lower.contains("openai") {
        if let Some(provider) = &state.openai_provider {
            return Ok((provider.clone(), "openai".to_string()));
        }
    }

    if model_lower.contains("claude") || model_lower.contains("anthropic") {
        if let Some(provider) = &state.anthropic_provider {
            return Ok((provider.clone(), "anthropic".to_string()));
        }
    }

    if let Some(provider) = &state.openai_provider {
        warn!("Using fallback provider: openai");
        return Ok((provider.clone(), "openai".to_string()));
    }

    if let Some(provider) = &state.anthropic_provider {
        warn!("Using fallback provider: anthropic");
        return Ok((provider.clone(), "anthropic".to_string()));
    }

    Err(ProxyError::InternalError(
        "No providers configured".to_string(),
    ))
}

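/// Estimate the request cost in USD from the provider's per-1k-token pricing,
/// if pricing is known for the model.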
fn calculate_cost(
    provider: &Arc<dyn LLMProvider>,
    model: &str,
    response: &UnifiedResponse,
) -> Option<f64> {
    provider.get_pricing(model).map(|pricing| {
        let input_cost = (response.usage.prompt_tokens as f64 / 1000.0) * pricing.input_cost_per_1k;
        let output_cost =
            (response.usage.completion_tokens as f64 / 1000.0) * pricing.output_cost_per_1k;
        input_cost + output_cost
    })
}

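/// Build a client response from a cached entry, tagging the cache tier that
/// served it and reporting a cost of zero.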
fn build_response_from_cache(
    request: &ChatCompletionRequest,
    cached: &llm_edge_cache::l1::CachedResponse,
    cache_tier: &str,
    latency_ms: u64,
) -> ChatCompletionResponse {
    ChatCompletionResponse {
        id: format!("chatcmpl-{}", Uuid::new_v4()),
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: request.model.clone(),
        choices: vec![ChatChoice {
            index: 0,
            message: ChatMessage {
                role: "assistant".to_string(),
                content: cached.content.clone(),
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens: cached.tokens.as_ref().map(|t| t.prompt_tokens).unwrap_or(0),
            completion_tokens: cached
                .tokens
                .as_ref()
                .map(|t| t.completion_tokens)
                .unwrap_or(0),
            total_tokens: cached.tokens.as_ref().map(|t| t.total_tokens).unwrap_or(0),
        },
        metadata: Some(ResponseMetadata {
            provider: "cache".to_string(),
            cached: true,
            cache_tier: Some(cache_tier.to_string()),
            latency_ms,
            cost_usd: Some(0.0),
        }),
    }
}

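/// Convert a provider response into its cached form: the first choice's
/// content plus token usage.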
fn convert_provider_to_cache(response: &UnifiedResponse) -> llm_edge_cache::l1::CachedResponse {
    let content = response
        .choices
        .first()
        .map(|c| c.message.content.clone())
        .unwrap_or_default();

    llm_edge_cache::l1::CachedResponse {
        content,
        tokens: Some(llm_edge_cache::l1::TokenUsage {
            prompt_tokens: response.usage.prompt_tokens as u32,
            completion_tokens: response.usage.completion_tokens as u32,
            total_tokens: response.usage.total_tokens as u32,
        }),
        model: response.model.clone(),
        cached_at: chrono::Utc::now().timestamp(),
    }
}

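/// Build the client response from a provider response, attaching proxy
/// metadata (provider name, latency, estimated cost).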
fn build_response_from_provider(
    request: &ChatCompletionRequest,
    provider_response: UnifiedResponse,
    provider_name: &str,
    latency_ms: u64,
    cost_usd: Option<f64>,
) -> ChatCompletionResponse {
    ChatCompletionResponse {
        id: provider_response.id,
        object: "chat.completion".to_string(),
        created: chrono::Utc::now().timestamp(),
        model: request.model.clone(),
        choices: provider_response
            .choices
            .into_iter()
            .map(|c| ChatChoice {
                index: c.index as u32,
                message: ChatMessage {
                    role: c.message.role,
                    content: c.message.content,
                },
                finish_reason: c.finish_reason.unwrap_or_else(|| "stop".to_string()),
            })
            .collect(),
        usage: Usage {
            prompt_tokens: provider_response.usage.prompt_tokens as u32,
            completion_tokens: provider_response.usage.completion_tokens as u32,
            total_tokens: provider_response.usage.total_tokens as u32,
        },
        metadata: Some(ResponseMetadata {
            provider: provider_name.to_string(),
            cached: false,
            cache_tier: None,
            latency_ms,
            cost_usd,
        }),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_request_valid() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: Some(0.7),
            max_tokens: Some(100),
            stream: false,
        };

        assert!(validate_request(&request).is_ok());
    }

    #[test]
    fn test_validate_request_empty_model() {
        let request = ChatCompletionRequest {
            model: "".to_string(),
            messages: vec![ChatMessage {
                role: "user".to_string(),
                content: "Hello".to_string(),
            }],
            temperature: None,
            max_tokens: None,
            stream: false,
        };

        assert!(validate_request(&request).is_err());
    }

    #[test]
    fn test_validate_request_empty_messages() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![],
            temperature: None,
            max_tokens: None,
            stream: false,
        };

        assert!(validate_request(&request).is_err());
    }

    #[test]
    fn test_convert_to_cacheable() {
        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                ChatMessage {
                    role: "user".to_string(),
                    content: "Hello".to_string(),
                },
                ChatMessage {
                    role: "assistant".to_string(),
                    content: "Hi".to_string(),
                },
            ],
            temperature: Some(0.7),
            max_tokens: Some(100),
            stream: false,
        };

        let cacheable = convert_to_cacheable(&request);
        assert_eq!(cacheable.model, "gpt-4");
        assert_eq!(cacheable.temperature, Some(0.7));
        assert_eq!(cacheable.max_tokens, Some(100));
    }
}