use axum::{
    extract::State,
    http::{HeaderMap, StatusCode},
    response::Json,
};
use rand::Rng;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::time::{SystemTime, UNIX_EPOCH};

use crate::server_state::ServerState;
16#[derive(Deserialize)]
17pub struct ChatCompletionRequest {
18 pub messages: Option<Vec<Value>>,
19 pub model: Option<String>,
20 #[serde(flatten)]
21 pub _other: Value,
22}
23
24#[derive(Serialize)]
25pub struct ChatCompletionResponse {
26 pub id: String,
27 pub object: String,
28 pub created: u64,
29 pub model: String,
30 pub choices: Vec<Choice>,
31 pub usage: Usage,
32}
33
34#[derive(Serialize)]
35pub struct Choice {
36 pub index: u32,
37 pub message: Message,
38 pub finish_reason: String,
39}
40
41#[derive(Serialize)]
42pub struct Message {
43 pub role: String,
44 pub content: String,
45}
46
47#[derive(Serialize)]
48pub struct Usage {
49 pub prompt_tokens: u32,
50 pub completion_tokens: u32,
51 pub total_tokens: u32,
52}
53
54pub async fn chat_completions(
55 state: State<ServerState>,
56 Json(payload): Json<ChatCompletionRequest>,
57) -> Result<(HeaderMap, Json<Value>), (StatusCode, HeaderMap, Json<Value>)> {
58 if state.check_request_limit_exceeded() {
59 let headers = state.get_rate_limit_headers();
60 let error_body = json!({
61 "error": {
62 "message": "Too many requests",
63 "type": "rate_limit_error",
64 "code": "rate_limit_exceeded"
65 }
66 });
67 return Err((StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)));
68 }
69 state.increment_request_count();
70
71 if let Some(error_code) = state.should_return_error() {
72 let headers = state.get_rate_limit_headers();
73 let status_code =
74 StatusCode::from_u16(error_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
75
76 let error_body = json!({
77 "error": {
78 "message": format!("Simulated error with code {}", error_code),
79 "type": "api_error",
80 "code": error_code.to_string()
81 }
82 });
83
84 return Err((status_code, headers, Json(error_body)));
85 }
86
87 let response_length = state.get_response_length();
88
89 if response_length == 0 {
90 let headers = state.get_rate_limit_headers();
91 return Err((StatusCode::NO_CONTENT, headers, Json(json!({}))));
92 }
93
94 let content = state.generate_lorem_content(response_length);
95
96 let prompt_text = payload
97 .messages
98 .as_ref()
99 .map(|msgs| serde_json::to_string(msgs).unwrap_or_default())
100 .unwrap_or_default();
101
102 let prompt_tokens = state.count_tokens(&prompt_text).unwrap_or(0);
103 let completion_tokens = state.count_tokens(&content).unwrap_or(0);
104 let total_tokens = prompt_tokens + completion_tokens;
105
106 if state.check_token_limit_exceeded(total_tokens) {
107 let headers = state.get_rate_limit_headers();
108 let error_body = json!({
109 "error": {
110 "message": "You have exceeded your token quota.",
111 "type": "rate_limit_error",
112 "code": "rate_limit_exceeded"
113 }
114 });
115 return Err((StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)));
116 }
117 state.add_token_usage(total_tokens);
118
119 let response = ChatCompletionResponse {
120 id: format!("chatcmpl-{}", rand::thread_rng().gen::<u32>()),
121 object: "chat.completion".to_string(),
122 created: SystemTime::now()
123 .duration_since(UNIX_EPOCH)
124 .expect("should be able to get duration")
125 .as_secs(),
126 model: payload.model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
127 choices: vec![Choice {
128 index: 0,
129 message: Message {
130 role: "assistant".to_string(),
131 content,
132 },
133 finish_reason: "stop".to_string(),
134 }],
135 usage: Usage {
136 prompt_tokens,
137 completion_tokens,
138 total_tokens,
139 },
140 };
141
142 let headers = state.get_rate_limit_headers();
143 Ok((headers, Json(json!(response))))
144}