1use axum::{
4 http::StatusCode,
5 response::{IntoResponse, Json, Response},
6};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Deserialize, Serialize)]
12pub struct ChatCompletionRequest {
13 pub model: String,
14 pub messages: Vec<ChatMessage>,
15 #[serde(default)]
16 pub stream: bool,
17 #[serde(skip_serializing_if = "Option::is_none")]
18 pub temperature: Option<f32>,
19 #[serde(skip_serializing_if = "Option::is_none")]
20 pub max_tokens: Option<u32>,
21 #[serde(skip_serializing_if = "Option::is_none")]
22 pub top_p: Option<f32>,
23 #[serde(skip_serializing_if = "Option::is_none")]
24 pub stop: Option<Vec<String>>,
25 #[serde(skip_serializing_if = "Option::is_none")]
26 pub presence_penalty: Option<f32>,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub frequency_penalty: Option<f32>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub user: Option<String>,
31 #[serde(flatten)]
33 pub extra: HashMap<String, serde_json::Value>,
34}
35
36#[derive(Debug, Clone, Deserialize, Serialize)]
38pub struct ChatMessage {
39 pub role: String,
40 #[serde(flatten)]
41 pub content: MessageContent,
42 #[serde(skip_serializing_if = "Option::is_none", default)]
43 pub name: Option<String>,
44 #[serde(skip_serializing_if = "Option::is_none", default)]
46 pub function_call: Option<FunctionCall>,
47}
48
49#[derive(Debug, Clone, Deserialize, Serialize)]
51pub struct FunctionCall {
52 pub name: String,
53 pub arguments: String,
54}
55
56#[derive(Debug, Clone, Deserialize, Serialize)]
58#[serde(untagged)]
59pub enum MessageContent {
60 Text {
61 #[serde(
62 deserialize_with = "deserialize_nullable_string",
63 serialize_with = "serialize_empty_as_null"
64 )]
65 content: String,
66 },
67 Parts {
68 content: Vec<ContentPart>,
69 },
70}
71
72fn deserialize_nullable_string<'de, D>(deserializer: D) -> Result<String, D::Error>
74where
75 D: serde::Deserializer<'de>,
76{
77 let opt: Option<String> = Option::deserialize(deserializer)?;
78 Ok(opt.unwrap_or_default())
79}
80
81fn serialize_empty_as_null<S>(value: &str, serializer: S) -> Result<S::Ok, S::Error>
83where
84 S: serde::Serializer,
85{
86 if value.is_empty() {
87 serializer.serialize_none()
88 } else {
89 serializer.serialize_str(value)
90 }
91}
92
93#[derive(Debug, Clone, Deserialize, Serialize)]
95pub struct ContentPart {
96 #[serde(rename = "type")]
97 pub part_type: String,
98 #[serde(skip_serializing_if = "Option::is_none")]
99 pub text: Option<String>,
100 #[serde(skip_serializing_if = "Option::is_none")]
101 pub image_url: Option<ImageUrl>,
102}
103
104#[derive(Debug, Clone, Deserialize, Serialize)]
106pub struct ImageUrl {
107 pub url: String,
108}
109
110#[derive(Debug, Clone, Deserialize, Serialize)]
112pub struct Choice {
113 pub index: u32,
114 pub message: ChatMessage,
115 #[serde(skip_serializing_if = "Option::is_none")]
116 pub finish_reason: Option<String>,
117}
118
119#[derive(Debug, Clone, Deserialize, Serialize)]
121pub struct ChatCompletionResponse {
122 pub id: String,
123 pub object: String,
124 pub created: i64,
125 pub model: String,
126 pub choices: Vec<Choice>,
127 #[serde(skip_serializing_if = "Option::is_none")]
128 pub usage: Option<Usage>,
129 #[serde(flatten, default)]
131 pub extra: HashMap<String, serde_json::Value>,
132}
133
134#[derive(Debug, Clone, Deserialize, Serialize)]
136pub struct Usage {
137 pub prompt_tokens: u32,
138 pub completion_tokens: u32,
139 pub total_tokens: u32,
140}
141
142#[derive(Debug, Clone, Deserialize, Serialize)]
144pub struct ChatCompletionChunk {
145 pub id: String,
146 pub object: String,
147 pub created: i64,
148 pub model: String,
149 pub choices: Vec<ChunkChoice>,
150}
151
152#[derive(Debug, Clone, Deserialize, Serialize)]
154pub struct ChunkChoice {
155 pub index: u32,
156 pub delta: ChunkDelta,
157 #[serde(skip_serializing_if = "Option::is_none")]
158 pub finish_reason: Option<String>,
159}
160
161#[derive(Debug, Clone, Deserialize, Serialize)]
163pub struct ChunkDelta {
164 #[serde(skip_serializing_if = "Option::is_none")]
165 pub role: Option<String>,
166 #[serde(skip_serializing_if = "Option::is_none")]
167 pub content: Option<String>,
168}
169
170#[derive(Debug, Clone, Deserialize, Serialize)]
172pub struct ApiError {
173 pub error: ApiErrorBody,
174}
175
176#[derive(Debug, Clone, Deserialize, Serialize)]
178pub struct ApiErrorBody {
179 pub message: String,
180 pub r#type: String,
181 pub param: Option<String>,
182 pub code: Option<String>,
183}
184
185impl ApiError {
186 pub fn bad_request(message: &str) -> Self {
188 Self {
189 error: ApiErrorBody {
190 message: message.to_string(),
191 r#type: "invalid_request_error".to_string(),
192 param: None,
193 code: Some("invalid_request_error".to_string()),
194 },
195 }
196 }
197
198 pub fn model_not_found(model: &str, available: &[String]) -> Self {
200 let hint = if available.is_empty() {
201 "No models available".to_string()
202 } else {
203 format!("Available: {}", available.join(", "))
204 };
205 Self {
206 error: ApiErrorBody {
207 message: format!("Model '{}' not found. {}", model, hint),
208 r#type: "invalid_request_error".to_string(),
209 param: Some("model".to_string()),
210 code: Some("model_not_found".to_string()),
211 },
212 }
213 }
214
215 pub fn bad_gateway(message: &str) -> Self {
217 Self {
218 error: ApiErrorBody {
219 message: message.to_string(),
220 r#type: "server_error".to_string(),
221 param: None,
222 code: Some("bad_gateway".to_string()),
223 },
224 }
225 }
226
227 pub fn gateway_timeout() -> Self {
229 Self {
230 error: ApiErrorBody {
231 message: "Backend request timed out".to_string(),
232 r#type: "server_error".to_string(),
233 param: None,
234 code: Some("gateway_timeout".to_string()),
235 },
236 }
237 }
238
239 pub fn service_unavailable(message: &str) -> Self {
241 Self {
242 error: ApiErrorBody {
243 message: message.to_string(),
244 r#type: "server_error".to_string(),
245 param: None,
246 code: Some("service_unavailable".to_string()),
247 },
248 }
249 }
250
251 pub fn from_backend_json(status_code: u16, json_body: String) -> Result<Self, Self> {
258 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&json_body) {
260 if parsed.get("error").is_some() {
261 return Ok(ApiError::from_raw_json(json_body));
264 }
265 }
266
267 Err(Self::bad_gateway(&format!(
269 "Backend returned {}: {}",
270 status_code, json_body
271 )))
272 }
273
274 fn from_raw_json(json: String) -> Self {
276 let value: serde_json::Value = serde_json::from_str(&json).unwrap();
278 let error_obj = value.get("error").unwrap();
279
280 Self {
281 error: serde_json::from_value(error_obj.clone()).unwrap(),
282 }
283 }
284
285 pub fn from_agent_error(error: crate::agent::AgentError) -> Self {
287 match error {
288 crate::agent::AgentError::Network(msg) => {
289 Self::bad_gateway(&format!("Network error: {}", msg))
290 }
291 crate::agent::AgentError::Timeout(_) => Self::gateway_timeout(),
292 crate::agent::AgentError::Upstream { status, message } => {
293 if status >= 500 {
294 Self::bad_gateway(&format!("Backend returned {}: {}", status, message))
295 } else if status == 404 {
296 Self {
297 error: ApiErrorBody {
298 message: format!("Backend returned 404: {}", message),
299 r#type: "invalid_request_error".to_string(),
300 param: None,
301 code: Some("not_found".to_string()),
302 },
303 }
304 } else {
305 Self::bad_request(&format!("Backend returned {}: {}", status, message))
306 }
307 }
308 crate::agent::AgentError::InvalidResponse(msg) => {
309 Self::bad_gateway(&format!("Invalid backend response: {}", msg))
310 }
311 crate::agent::AgentError::Unsupported(msg) => {
312 Self::service_unavailable(&format!("Feature not supported: {}", msg))
313 }
314 crate::agent::AgentError::Configuration(msg) => {
315 Self::bad_gateway(&format!("Backend configuration error: {}", msg))
316 }
317 }
318 }
319
320 fn status_code(&self) -> StatusCode {
322 match self.error.code.as_deref() {
323 Some("invalid_request_error") => StatusCode::BAD_REQUEST,
324 Some("model_not_found") => StatusCode::NOT_FOUND,
325 Some("bad_gateway") => StatusCode::BAD_GATEWAY,
326 Some("gateway_timeout") => StatusCode::GATEWAY_TIMEOUT,
327 Some("service_unavailable") => StatusCode::SERVICE_UNAVAILABLE,
328 _ => StatusCode::INTERNAL_SERVER_ERROR,
329 }
330 }
331}
332
333impl IntoResponse for ApiError {
334 fn into_response(self) -> Response {
335 (self.status_code(), Json(self)).into_response()
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentError;
    use serde_json::json;

    // ---- request / message deserialization -------------------------------

    #[test]
    fn test_chat_message_deserialize_text() {
        let payload = json!({"role": "user", "content": "Hello"});
        let message: ChatMessage = serde_json::from_value(payload).unwrap();
        assert_eq!(message.role, "user");
        match message.content {
            MessageContent::Text { content } => assert_eq!(content, "Hello"),
            MessageContent::Parts { .. } => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_chat_message_deserialize_multimodal() {
        let payload = json!({
            "role": "user",
            "content": [
                {"type": "text", "text": "What's in this image?"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
            ]
        });
        let message: ChatMessage = serde_json::from_value(payload).unwrap();
        assert_eq!(message.role, "user");
        match message.content {
            MessageContent::Parts { content } => {
                assert_eq!(content.len(), 2);
                assert_eq!(content[0].part_type, "text");
            }
            MessageContent::Text { .. } => panic!("Expected parts content"),
        }
    }

    #[test]
    fn test_chat_request_deserialize_minimal() {
        let payload = json!({
            "model": "llama3:70b",
            "messages": [{"role": "user", "content": "Hi"}]
        });
        let request: ChatCompletionRequest = serde_json::from_value(payload).unwrap();
        assert_eq!(request.model, "llama3:70b");
        // `stream` falls back to its default when the key is absent.
        assert!(!request.stream);
    }

    #[test]
    fn test_chat_request_deserialize_full() {
        let payload = json!({
            "model": "llama3:70b",
            "messages": [{"role": "user", "content": "Hi"}],
            "stream": true,
            "temperature": 0.7,
            "max_tokens": 1000,
            "top_p": 0.9
        });
        let request: ChatCompletionRequest = serde_json::from_value(payload).unwrap();
        assert!(request.stream);
        assert_eq!(request.temperature, Some(0.7));
        assert_eq!(request.max_tokens, Some(1000));
        assert_eq!(request.top_p, Some(0.9));
    }

    #[test]
    fn test_chat_request_stream_default_false() {
        let payload = json!({
            "model": "test",
            "messages": []
        });
        let request: ChatCompletionRequest = serde_json::from_value(payload).unwrap();
        assert!(!request.stream);
    }

    // ---- response / chunk serialization ----------------------------------

    #[test]
    fn test_chat_response_serialize() {
        let response = ChatCompletionResponse {
            id: "chatcmpl-123".to_string(),
            object: "chat.completion".to_string(),
            created: 1699999999,
            model: "llama3:70b".to_string(),
            choices: vec![],
            usage: None,
            extra: HashMap::new(),
        };
        let rendered = serde_json::to_value(&response).unwrap();
        assert_eq!(rendered["object"], "chat.completion");
        assert_eq!(rendered["id"], "chatcmpl-123");
        assert_eq!(rendered["model"], "llama3:70b");
    }

    #[test]
    fn test_chat_chunk_serialize() {
        let chunk = ChatCompletionChunk {
            id: "chatcmpl-123".to_string(),
            object: "chat.completion.chunk".to_string(),
            created: 1699999999,
            model: "llama3:70b".to_string(),
            choices: vec![],
        };
        let rendered = serde_json::to_value(&chunk).unwrap();
        assert_eq!(rendered["object"], "chat.completion.chunk");
        assert_eq!(rendered["id"], "chatcmpl-123");
    }

    #[test]
    fn test_usage_serialize() {
        let usage = Usage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
        };
        let rendered = serde_json::to_value(&usage).unwrap();
        assert_eq!(rendered["prompt_tokens"], 10);
        assert_eq!(rendered["completion_tokens"], 20);
        assert_eq!(rendered["total_tokens"], 30);
    }

    #[test]
    fn test_api_error_serialize() {
        let error = ApiError {
            error: ApiErrorBody {
                message: "Test error".to_string(),
                r#type: "invalid_request_error".to_string(),
                param: Some("model".to_string()),
                code: Some("model_not_found".to_string()),
            },
        };
        let rendered = serde_json::to_value(&error).unwrap();
        let body = &rendered["error"];
        assert_eq!(body["message"], "Test error");
        assert_eq!(body["type"], "invalid_request_error");
        assert_eq!(body["code"], "model_not_found");
    }

    #[test]
    fn test_choice_serialize() {
        let message = ChatMessage {
            role: "assistant".to_string(),
            content: MessageContent::Text {
                content: "Hello!".to_string(),
            },
            name: None,
            function_call: None,
        };
        let choice = Choice {
            index: 0,
            message,
            finish_reason: Some("stop".to_string()),
        };
        let rendered = serde_json::to_value(&choice).unwrap();
        assert_eq!(rendered["index"], 0);
        assert_eq!(rendered["finish_reason"], "stop");
    }

    #[test]
    fn test_chunk_delta_serialize() {
        let delta = ChunkDelta {
            role: Some("assistant".to_string()),
            content: Some("Hello".to_string()),
        };
        let rendered = serde_json::to_value(&delta).unwrap();
        assert_eq!(rendered["role"], "assistant");
        assert_eq!(rendered["content"], "Hello");
    }

    // ---- ApiError constructors -------------------------------------------

    #[test]
    fn test_api_error_serialize_400() {
        let rendered = serde_json::to_value(&ApiError::bad_request("Invalid JSON")).unwrap();
        assert_eq!(rendered["error"]["code"], "invalid_request_error");
        assert_eq!(rendered["error"]["message"], "Invalid JSON");
    }

    #[test]
    fn test_api_error_serialize_404() {
        let available = ["llama3:70b".to_string(), "mistral:7b".to_string()];
        let error = ApiError::model_not_found("gpt-4", &available);
        let rendered = serde_json::to_value(&error).unwrap();
        assert_eq!(rendered["error"]["code"], "model_not_found");
        let message = rendered["error"]["message"].as_str().unwrap();
        assert!(message.contains("gpt-4"));
        assert!(message.contains("llama3:70b"));
    }

    #[test]
    fn test_api_error_serialize_502() {
        let rendered =
            serde_json::to_value(&ApiError::bad_gateway("Connection refused")).unwrap();
        assert_eq!(rendered["error"]["code"], "bad_gateway");
        assert_eq!(rendered["error"]["message"], "Connection refused");
    }

    #[test]
    fn test_api_error_into_response() {
        let response = ApiError::service_unavailable("No backends").into_response();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
    }

    #[test]
    fn test_api_error_model_not_found_empty_available() {
        let error = ApiError::model_not_found("gpt-4", &[]);
        let rendered = serde_json::to_value(&error).unwrap();
        let message = rendered["error"]["message"].as_str().unwrap();
        assert!(message.contains("No models available"));
    }

    #[test]
    fn test_api_error_gateway_timeout() {
        let rendered = serde_json::to_value(&ApiError::gateway_timeout()).unwrap();
        assert_eq!(rendered["error"]["code"], "gateway_timeout");
        let message = rendered["error"]["message"].as_str().unwrap();
        assert!(message.contains("timed out"));
    }

    // ---- code -> HTTP status mapping -------------------------------------

    #[test]
    fn test_api_error_status_codes() {
        let cases = vec![
            (ApiError::bad_request("x"), StatusCode::BAD_REQUEST),
            (ApiError::model_not_found("x", &[]), StatusCode::NOT_FOUND),
            (ApiError::bad_gateway("x"), StatusCode::BAD_GATEWAY),
            (ApiError::gateway_timeout(), StatusCode::GATEWAY_TIMEOUT),
            (
                ApiError::service_unavailable("x"),
                StatusCode::SERVICE_UNAVAILABLE,
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(error.into_response().status(), expected);
        }
    }

    #[test]
    fn test_api_error_unknown_code_returns_500() {
        let body = ApiErrorBody {
            message: "Unknown".to_string(),
            r#type: "server_error".to_string(),
            param: None,
            code: Some("unknown_code".to_string()),
        };
        let response = ApiError { error: body }.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_api_error_no_code_returns_500() {
        let body = ApiErrorBody {
            message: "Unknown".to_string(),
            r#type: "server_error".to_string(),
            param: None,
            code: None,
        };
        let response = ApiError { error: body }.into_response();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    // ---- from_backend_json -----------------------------------------------

    #[test]
    fn test_from_backend_json_valid_error() {
        let body = String::from(
            r#"{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","param":null,"code":"rate_limit"}}"#,
        );
        let outcome = ApiError::from_backend_json(429, body);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().error.message, "Rate limit exceeded");
    }

    #[test]
    fn test_from_backend_json_invalid_json() {
        let body = String::from("this is not json at all");
        let outcome = ApiError::from_backend_json(500, body);
        assert!(outcome.is_err());
        let fallback = outcome.unwrap_err();
        assert_eq!(fallback.error.code.as_deref(), Some("bad_gateway"));
        assert!(fallback.error.message.contains("this is not json at all"));
    }

    #[test]
    fn test_from_backend_json_empty_string() {
        let outcome = ApiError::from_backend_json(500, String::new());
        assert!(outcome.is_err());
        let fallback = outcome.unwrap_err();
        assert_eq!(fallback.error.code.as_deref(), Some("bad_gateway"));
    }

    // ---- from_agent_error ------------------------------------------------

    #[test]
    fn test_from_agent_error_network() {
        let source = AgentError::Network("connection refused".to_string());
        let converted = ApiError::from_agent_error(source);
        assert_eq!(converted.error.code.as_deref(), Some("bad_gateway"));
        assert!(converted.error.message.contains("Network error"));
    }

    #[test]
    fn test_from_agent_error_timeout() {
        let converted = ApiError::from_agent_error(AgentError::Timeout(5000));
        assert_eq!(converted.error.code.as_deref(), Some("gateway_timeout"));
        assert_eq!(
            converted.into_response().status(),
            StatusCode::GATEWAY_TIMEOUT
        );
    }

    #[test]
    fn test_from_agent_error_upstream_5xx() {
        let source = AgentError::Upstream {
            status: 503,
            message: "Service Unavailable".to_string(),
        };
        let converted = ApiError::from_agent_error(source);
        assert_eq!(converted.error.code.as_deref(), Some("bad_gateway"));
        assert!(converted.error.message.contains("503"));
    }

    #[test]
    fn test_from_agent_error_upstream_404() {
        let source = AgentError::Upstream {
            status: 404,
            message: "Model not found".to_string(),
        };
        let converted = ApiError::from_agent_error(source);
        assert_eq!(converted.error.code.as_deref(), Some("not_found"));
        // `not_found` is not one of the codes the status mapping recognises.
        assert_eq!(
            converted.into_response().status(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }

    #[test]
    fn test_from_agent_error_invalid_response() {
        let source = AgentError::InvalidResponse("malformed JSON".to_string());
        let converted = ApiError::from_agent_error(source);
        assert_eq!(converted.error.code.as_deref(), Some("bad_gateway"));
        assert!(converted.error.message.contains("Invalid backend response"));
    }

    #[test]
    fn test_from_agent_error_upstream_4xx_not_404() {
        let source = AgentError::Upstream {
            status: 429,
            message: "Rate limit exceeded".to_string(),
        };
        let converted = ApiError::from_agent_error(source);
        assert_eq!(
            converted.error.code.as_deref(),
            Some("invalid_request_error")
        );
        assert!(converted.error.message.contains("429"));
        assert_eq!(converted.into_response().status(), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_from_agent_error_unsupported() {
        let source = AgentError::Unsupported("embeddings");
        let converted = ApiError::from_agent_error(source);
        assert_eq!(converted.error.code.as_deref(), Some("service_unavailable"));
        assert!(converted.error.message.contains("not supported"));
        assert_eq!(
            converted.into_response().status(),
            StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[test]
    fn test_from_agent_error_configuration() {
        let source = AgentError::Configuration("missing API key".to_string());
        let converted = ApiError::from_agent_error(source);
        assert_eq!(converted.error.code.as_deref(), Some("bad_gateway"));
        assert!(converted.error.message.contains("configuration error"));
        assert_eq!(converted.into_response().status(), StatusCode::BAD_GATEWAY);
    }
}