1use axum::{
5 extract::State,
6 http::StatusCode,
7 response::{sse::Event, IntoResponse, Sse},
8 Json,
9};
10use futures_util::stream;
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::time::{SystemTime, UNIX_EPOCH};
16
17use crate::server_state::ServerState;
18
/// Token accounting block mirroring the OpenAI `usage` object, attached to
/// both non-streaming responses and the final streaming chunk.
#[derive(Serialize, Debug)]
pub struct Usage {
    /// Tokens counted from the serialized request `messages`.
    pub prompt_tokens: u32,
    /// Tokens counted from the generated completion text.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
25
/// The subset of an OpenAI chat-completions request body this mock inspects.
///
/// All unrecognized fields are swept into `_other` via `#[serde(flatten)]`,
/// so arbitrary client payloads deserialize without error.
#[derive(Deserialize)]
pub struct ChatCompletionRequest {
    /// Conversation messages; used here only for prompt-token counting.
    pub messages: Option<Vec<Value>>,
    /// Requested model name; echoed back in the response when present.
    pub model: Option<String>,
    /// When `Some(true)`, the handler responds with an SSE chunk stream.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Catch-all for request fields this mock does not model.
    #[serde(flatten)]
    pub _other: Value,
}
35
/// Non-streaming response body, shaped like OpenAI's `chat.completion` object.
#[derive(Serialize)]
pub struct ChatCompletionResponse {
    /// Response id, formatted as `chatcmpl-<random u32>`.
    pub id: String,
    /// Always `"chat.completion"` for this endpoint.
    pub object: String,
    /// Creation time as seconds since the UNIX epoch.
    pub created: u64,
    /// Model name echoed from the request (or the default).
    pub model: String,
    /// Generated choices; this mock always produces exactly one.
    pub choices: Vec<Choice>,
    /// Token accounting for the request/response pair.
    pub usage: Usage,
}
45
/// One SSE event payload in a streaming response, shaped like OpenAI's
/// `chat.completion.chunk` object.
#[derive(Serialize, Debug)]
pub struct ChatCompletionChunk {
    /// Shared response id; identical across all chunks of one stream.
    pub id: String,
    /// Always `"chat.completion.chunk"`.
    pub object: String,
    /// Creation time as seconds since the UNIX epoch.
    pub created: u64,
    /// Model name echoed from the request (or the default).
    pub model: String,
    /// Incremental choice deltas; this mock always emits exactly one.
    pub choices: Vec<ChunkChoice>,
    /// Populated only on the final chunk; omitted from JSON otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
56
/// A single choice entry inside a streaming chunk.
#[derive(Serialize, Debug)]
pub struct ChunkChoice {
    /// Choice index; always 0 in this mock.
    pub index: u32,
    /// Incremental content/role delta for this chunk.
    pub delta: ChoiceDelta,
    /// `Some("stop")` on the final chunk; omitted from JSON otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
64
/// Incremental message delta; `role` appears only in the first chunk,
/// `content` in the word-by-word chunks, and both are absent in the final
/// (finish/usage) chunk.
#[derive(Serialize, Debug, Default)]
pub struct ChoiceDelta {
    /// Message role (`"assistant"`); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Content fragment for this chunk; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
72
/// A single completed choice in a non-streaming response.
#[derive(Serialize)]
pub struct Choice {
    /// Choice index; always 0 in this mock.
    pub index: u32,
    /// The full assistant message.
    pub message: Message,
    /// Always `"stop"` in this mock.
    pub finish_reason: String,
}
79
/// A complete chat message (role + content) in a non-streaming response.
#[derive(Serialize)]
pub struct Message {
    /// Message role; this handler always emits `"assistant"`.
    pub role: String,
    /// Full generated completion text.
    pub content: String,
}
85
86pub async fn chat_completions(
87 state: State<ServerState>,
88 Json(payload): Json<ChatCompletionRequest>,
89) -> impl IntoResponse {
90 if state.check_request_limit_exceeded() {
91 let headers = state.get_rate_limit_headers();
92 let error_body = json!({
93 "error": {
94 "message": "Too many requests",
95 "type": "rate_limit_error",
96 "code": "rate_limit_exceeded"
97 }
98 });
99 return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
100 }
101 state.increment_request_count();
102
103 if let Some(error_code) = state.should_return_error() {
104 let headers = state.get_rate_limit_headers();
105 let status_code =
106 StatusCode::from_u16(error_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
107
108 let error_body = json!({
109 "error": {
110 "message": format!("Simulated error with code {}", error_code),
111 "type": "api_error",
112 "code": error_code.to_string()
113 }
114 });
115
116 return (status_code, headers, Json(error_body)).into_response();
117 }
118
119 let response_length = state.get_response_length();
120
121 if response_length == 0 {
122 let headers = state.get_rate_limit_headers();
123 return (StatusCode::NO_CONTENT, headers, Json(json!({}))).into_response();
124 }
125
126 let content = state.generate_lorem_content(response_length);
127
128 let prompt_text = payload
129 .messages
130 .as_ref()
131 .map(|msgs| serde_json::to_string(msgs).unwrap_or_default())
132 .unwrap_or_default();
133
134 let prompt_tokens = state.count_tokens(&prompt_text).unwrap_or(0);
135 let completion_tokens = state.count_tokens(&content).unwrap_or(0);
136 let total_tokens = prompt_tokens + completion_tokens;
137
138 if state.check_token_limit_exceeded(total_tokens) {
139 let headers = state.get_rate_limit_headers();
140 let error_body = json!({
141 "error": {
142 "message": "You have exceeded your token quota.",
143 "type": "rate_limit_error",
144 "code": "rate_limit_exceeded"
145 }
146 });
147 return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
148 }
149 state.add_token_usage(total_tokens);
150
151 let stream_response = payload.stream.unwrap_or(false);
152 if stream_response {
153 let id = format!("chatcmpl-{}", rand::thread_rng().gen::<u32>());
154 let created = SystemTime::now()
155 .duration_since(UNIX_EPOCH)
156 .expect("should be able to get duration")
157 .as_secs();
158 let model = payload
159 .model
160 .clone()
161 .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
162 let words = content
163 .split_whitespace()
164 .map(|s| s.to_string())
165 .collect::<Vec<_>>();
166
167 let mut events = vec![];
168
169 let first_chunk = ChatCompletionChunk {
171 id: id.clone(),
172 object: "chat.completion.chunk".to_string(),
173 created,
174 model: model.clone(),
175 choices: vec![ChunkChoice {
176 index: 0,
177 delta: ChoiceDelta {
178 role: Some("assistant".to_string()),
179 content: None,
180 },
181 finish_reason: None,
182 }],
183 usage: None,
184 };
185 events.push(Ok::<_, Infallible>(
186 Event::default().data(serde_json::to_string(&first_chunk).unwrap()),
187 ));
188
189 for word in words {
191 let chunk = ChatCompletionChunk {
192 id: id.clone(),
193 object: "chat.completion.chunk".to_string(),
194 created,
195 model: model.clone(),
196 choices: vec![ChunkChoice {
197 index: 0,
198 delta: ChoiceDelta {
199 role: None,
200 content: Some(format!("{} ", word)),
201 },
202 finish_reason: None,
203 }],
204 usage: None,
205 };
206 events.push(Ok(
207 Event::default().data(serde_json::to_string(&chunk).unwrap())
208 ));
209 }
210
211 let final_chunk = ChatCompletionChunk {
213 id: id.clone(),
214 object: "chat.completion.chunk".to_string(),
215 created,
216 model: model.clone(),
217 choices: vec![ChunkChoice {
218 index: 0,
219 delta: Default::default(),
220 finish_reason: Some("stop".to_string()),
221 }],
222 usage: Some(Usage {
223 prompt_tokens,
224 completion_tokens,
225 total_tokens,
226 }),
227 };
228 events.push(Ok(
229 Event::default().data(serde_json::to_string(&final_chunk).unwrap())
230 ));
231
232 events.push(Ok(Event::default().data("[DONE]")));
234
235 let stream = stream::iter(events);
236
237 return Sse::new(stream).into_response();
238 }
239
240 let response = ChatCompletionResponse {
241 id: format!("chatcmpl-{}", rand::thread_rng().gen::<u32>()),
242 object: "chat.completion".to_string(),
243 created: SystemTime::now()
244 .duration_since(UNIX_EPOCH)
245 .expect("should be able to get duration")
246 .as_secs(),
247 model: payload.model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
248 choices: vec![Choice {
249 index: 0,
250 message: Message {
251 role: "assistant".to_string(),
252 content,
253 },
254 finish_reason: "stop".to_string(),
255 }],
256 usage: Usage {
257 prompt_tokens,
258 completion_tokens,
259 total_tokens,
260 },
261 };
262
263 let headers = state.get_rate_limit_headers();
264 (headers, Json(json!(response))).into_response()
265}