1use crate::server_state::ServerState;
5use axum::{
6 extract::State,
7 http::StatusCode,
8 response::{sse::Event, IntoResponse, Json, Sse},
9};
10use rand::distributions::Alphanumeric;
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16use tokio::time::sleep;
17
/// JSON request body accepted by the [`responses`] handler.
///
/// Every named field is optional; any other top-level keys in the body are
/// collected into `_other` through `#[serde(flatten)]`.
#[derive(Deserialize)]
pub struct ResponsesRequest {
    /// Model name echoed back in the response; the handler substitutes a
    /// built-in default when it is absent.
    pub model: Option<String>,
    /// Prompt text; the handler counts its tokens (an absent input counts as
    /// the empty string).
    pub input: Option<String>,
    /// Copied verbatim into the `instructions` field of the response object.
    pub instructions: Option<String>,
    /// Selects the SSE streaming path when `Some(true)`; the handler treats
    /// an absent value as `false`.
    pub stream: Option<bool>,
    /// Catch-all for the remaining request keys; the handler does not read it.
    #[serde(flatten)]
    pub _other: Value,
}
27
28fn generate_id(prefix: &str) -> String {
30 format!("{}_{:x}", prefix, rand::thread_rng().gen::<u128>())
31}
32
/// Text output format descriptor nested inside [`ResponseTextConfig`].
#[derive(Serialize, Clone, Default)]
struct ResponseFormatText {
    /// Serialized under the JSON key `type`; the handler sets it to `"text"`.
    #[serde(rename = "type")]
    _type: String,
}
40
/// Value of the `text` field of a [`Response`].
#[derive(Serialize, Clone, Default)]
struct ResponseTextConfig {
    /// Output format descriptor.
    format: ResponseFormatText,
    /// Verbosity label; the handler sets it to `"medium"`.
    verbosity: String,
}
46
/// Value of the `reasoning` field of a [`Response`].
///
/// The optional fields carry no `skip_serializing_if`, so they serialize as
/// JSON `null` when `None`.
#[derive(Serialize, Clone, Default)]
struct Reasoning {
    /// Effort label; the handler sets it to `"medium"`.
    effort: String,
    /// Left at its `Default` (`None`) by the handler.
    generate_summary: Option<bool>,
    /// Left at its `Default` (`None`) by the handler.
    summary: Option<String>,
}
53
/// Output item of type `"reasoning"`, emitted as output index 0 on the
/// streaming path.
#[derive(Serialize, Clone, Debug)]
struct ResponseReasoningItem {
    /// Item identifier; the handler generates it with the `rs` prefix.
    id: String,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Summary parts; the handler always sends an empty list.
    summary: Vec<Value>,
    /// Serialized as `null` when `None`.
    content: Option<Value>,
    /// Serialized as `null` when `None`.
    encrypted_content: Option<Value>,
    /// Serialized as `null` when `None`.
    status: Option<String>,
}
64
/// A text content part of an assistant message.
#[derive(Serialize, Clone, Debug, Default)]
struct ResponseOutputText {
    /// Serialized under the JSON key `type`; the handler sets it to
    /// `"output_text"`.
    #[serde(rename = "type")]
    _type: String,
    /// The text carried by this part.
    text: String,
    /// Always empty in this handler.
    annotations: Vec<Value>,
    /// Always empty in this handler.
    logprobs: Vec<Value>,
}
73
/// Output item of type `"message"` holding the generated text parts.
#[derive(Serialize, Clone, Debug)]
struct ResponseOutputMessage {
    /// Item identifier; the handler generates it with the `msg` prefix.
    id: String,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Content parts of the message.
    content: Vec<ResponseOutputText>,
    /// The handler sets this to `"assistant"`.
    role: String,
    /// `"in_progress"` while streaming, `"completed"` once finished.
    status: String,
}
83
/// One entry of `Response::output`.
///
/// `#[serde(untagged)]` makes each variant serialize as its inner struct,
/// with no wrapper key.
#[derive(Serialize, Clone, Debug)]
#[serde(untagged)]
enum ResponseOutputItem {
    /// A reasoning item.
    Reasoning(ResponseReasoningItem),
    /// An assistant message item.
    Message(ResponseOutputMessage),
}
90
/// Breakdown of input tokens reported in [`ResponseUsage`].
#[derive(Serialize, Clone)]
struct InputTokensDetails {
    /// The handler always reports 0.
    cached_tokens: u32,
}
95
/// Breakdown of output tokens reported in [`ResponseUsage`].
#[derive(Serialize, Clone)]
struct OutputTokensDetails {
    /// 128 on the streaming path and 0 on the non-streaming path.
    reasoning_tokens: u32,
}
100
/// Token accounting attached to a completed [`Response`].
#[derive(Serialize, Clone)]
struct ResponseUsage {
    /// Token count of the request's `input` text.
    input_tokens: u32,
    /// Breakdown of the input tokens.
    input_tokens_details: InputTokensDetails,
    /// Token count of the generated content (plus the reasoning tokens on the
    /// streaming path).
    output_tokens: u32,
    /// Breakdown of the output tokens.
    output_tokens_details: OutputTokensDetails,
    /// Sum of input and output tokens.
    total_tokens: u32,
}
109
/// The response object returned by the handler: as the JSON body on the
/// non-streaming path, and embedded in [`ResponseEvent`] on the streaming path.
///
/// Fields the handler does not set explicitly keep their `Default` value.
/// Only `usage` is skipped when `None`; every other `Option` serializes as
/// JSON `null`.
#[derive(Serialize, Clone, Default)]
struct Response {
    /// Response identifier; generated with the `resp` prefix.
    id: String,
    /// Creation time as seconds since the Unix epoch (fractional).
    created_at: f64,
    error: Option<Value>,
    incomplete_details: Option<Value>,
    /// Echo of the request's `instructions`.
    instructions: Option<String>,
    /// Left at `Value::default()` (JSON `null`) by the handler.
    metadata: Value,
    /// Model name from the request, or the handler's default.
    model: String,
    /// The handler sets this to `"response"`.
    object: String,
    /// Output items in emission order.
    output: Vec<ResponseOutputItem>,
    parallel_tool_calls: bool,
    temperature: f32,
    tool_choice: String,
    tools: Vec<Value>,
    top_p: f32,
    background: bool,
    max_output_tokens: Option<u32>,
    max_tool_calls: Option<u32>,
    previous_response_id: Option<String>,
    prompt: Option<String>,
    prompt_cache_key: Option<String>,
    reasoning: Reasoning,
    safety_identifier: Option<String>,
    service_tier: String,
    /// `"in_progress"` or `"completed"`.
    status: String,
    text: ResponseTextConfig,
    top_logprobs: u32,
    truncation: String,
    /// Omitted from the JSON entirely while `None` (i.e. before completion on
    /// the streaming path).
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<ResponseUsage>,
    user: Option<String>,
    store: bool,
}
144
/// SSE payload carrying a snapshot of the whole [`Response`]; the handler
/// uses it for `response.created`, `response.in_progress` and
/// `response.completed`.
#[derive(Serialize)]
struct ResponseEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream, starting at 0.
    sequence_number: u32,
    /// Snapshot of the response at the time of the event.
    response: Response,
}
154
/// Item payload of the `response.output_item.*` events.
///
/// Has the same variants as [`ResponseOutputItem`] and, being
/// `#[serde(untagged)]`, serializes each variant as its inner struct.
#[derive(Serialize)]
#[serde(untagged)]
enum OutputItem {
    /// A reasoning item.
    Reasoning(ResponseReasoningItem),
    /// An assistant message item.
    Message(ResponseOutputMessage),
}
161
/// SSE payload for `response.output_item.added`.
#[derive(Serialize)]
struct ResponseOutputItemAddedEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the item in `Response::output`.
    output_index: u32,
    /// The item as it looks when first added.
    item: OutputItem,
}
170
/// SSE payload for `response.output_item.done`.
#[derive(Serialize)]
struct ResponseOutputItemDoneEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the item in `Response::output`.
    output_index: u32,
    /// The item in its final state.
    item: OutputItem,
}
179
/// SSE payload for `response.content_part.added`.
#[derive(Serialize)]
struct ResponseContentPartAddedEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item in `Response::output`.
    output_index: u32,
    /// Identifier of the owning message item.
    item_id: String,
    /// Index of the part within the message content.
    content_index: u32,
    /// The part as first added (the handler sends it with empty text).
    part: ResponseOutputText,
}
190
/// SSE payload for `response.output_text.delta`: one slice of the generated
/// text.
#[derive(Serialize)]
struct ResponseTextDeltaEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item in `Response::output`.
    output_index: u32,
    /// Identifier of the owning message item.
    item_id: String,
    /// Index of the part within the message content.
    content_index: u32,
    /// The text slice carried by this event.
    delta: String,
    /// Always empty in this handler.
    logprobs: Vec<Value>,
    /// Random alphanumeric filler; the handler generates 10 characters per
    /// event.
    obfuscation: String,
}
203
/// SSE payload for `response.output_text.done`, carrying the complete text.
#[derive(Serialize)]
struct ResponseTextDoneEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item in `Response::output`.
    output_index: u32,
    /// Identifier of the owning message item.
    item_id: String,
    /// Index of the part within the message content.
    content_index: u32,
    /// The full generated text.
    text: String,
    /// Always empty in this handler.
    logprobs: Vec<Value>,
}
215
/// SSE payload for `response.content_part.done`.
#[derive(Serialize)]
struct ResponseContentPartDoneEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item in `Response::output`.
    output_index: u32,
    /// Identifier of the owning message item.
    item_id: String,
    /// Index of the part within the message content.
    content_index: u32,
    /// The part in its final state, holding the full text.
    part: ResponseOutputText,
}
226
227pub async fn responses(
228 state: State<ServerState>,
229 Json(payload): Json<ResponsesRequest>,
230) -> impl IntoResponse {
231 if state.check_request_limit_exceeded() {
232 let headers = state.get_rate_limit_headers();
233 let error_body = json!({
234 "error": {
235 "message": "Too many requests",
236 "type": "rate_limit_error",
237 "code": "rate_limit_exceeded"
238 }
239 });
240 return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
241 }
242 state.increment_request_count();
243
244 if let Some(error_code) = state.should_return_error() {
245 let headers = state.get_rate_limit_headers();
246 let status_code =
247 StatusCode::from_u16(error_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
248
249 let error_body = json!({
250 "error": {
251 "message": format!("Simulated error with code {}", error_code),
252 "type": "api_error",
253 "code": error_code.to_string()
254 }
255 });
256
257 return (status_code, headers, Json(error_body)).into_response();
258 }
259
260 let response_length = state.get_response_length();
261
262 if response_length == 0 {
263 let headers = state.get_rate_limit_headers();
264 return (StatusCode::NO_CONTENT, headers, Json(json!({}))).into_response();
265 }
266
267 let content = state.generate_lorem_content(response_length);
268
269 let prompt_text = payload.input.clone().unwrap_or_else(|| "".to_string());
270 let prompt_tokens = state.count_tokens(&prompt_text).unwrap_or(0) as u32;
271 let completion_tokens = state.count_tokens(&content).unwrap_or(0) as u32;
272 let total_tokens = prompt_tokens + completion_tokens;
273
274 if state.check_token_limit_exceeded(total_tokens) {
275 let headers = state.get_rate_limit_headers();
276 let error_body = json!({
277 "error": {
278 "message": "You have exceeded your token quota.",
279 "type": "rate_limit_error",
280 "code": "rate_limit_exceeded"
281 }
282 });
283 return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
284 }
285 state.add_token_usage(total_tokens);
286
287 let headers = state.get_rate_limit_headers();
288 let model = payload
289 .model
290 .clone()
291 .unwrap_or_else(|| "gpt-5-2025-08-07".to_string());
292 let response_id = generate_id("resp");
293 let message_id = generate_id("msg");
294 let created_at = SystemTime::now()
295 .duration_since(UNIX_EPOCH)
296 .expect("should be able to get duration")
297 .as_secs_f64();
298
299 let stream_response = payload.stream.unwrap_or(false);
300 if stream_response {
301 let reasoning_item_id = generate_id("rs");
302 let stream = async_stream::stream! {
303 let mut sequence_number = 0;
304 let mut response = Response {
305 id: response_id.clone(),
306 object: "response".to_string(),
307 created_at,
308 model: model.clone(),
309 status: "in_progress".to_string(),
310 instructions: payload.instructions.clone(),
311 parallel_tool_calls: true,
312 temperature: 1.0,
313 tool_choice: "auto".to_string(),
314 top_p: 1.0,
315 background: false,
316 reasoning: Reasoning {
317 effort: "medium".to_string(),
318 ..Default::default()
319 },
320 service_tier: "auto".to_string(),
321 text: ResponseTextConfig {
322 format: ResponseFormatText { _type: "text".to_string() },
323 verbosity: "medium".to_string(),
324 },
325 top_logprobs: 0,
326 truncation: "disabled".to_string(),
327 store: false,
328 ..Default::default()
329 };
330
331 let created_event = ResponseEvent {
333 _type: "response.created".to_string(),
334 sequence_number,
335 response: response.clone(),
336 };
337 yield Ok::<_, Infallible>(Event::default().event("response.created").data(serde_json::to_string(&created_event).unwrap()));
338 sequence_number += 1;
339
340 let in_progress_event = ResponseEvent {
342 _type: "response.in_progress".to_string(),
343 sequence_number,
344 response: response.clone(),
345 };
346 yield Ok::<_, Infallible>(Event::default().event("response.in_progress").data(serde_json::to_string(&in_progress_event).unwrap()));
347 sequence_number += 1;
348
349 let reasoning_item = ResponseReasoningItem {
351 id: reasoning_item_id.clone(),
352 _type: "reasoning".to_string(),
353 summary: vec![],
354 content: None,
355 encrypted_content: None,
356 status: None,
357 };
358 let output_item_added_event = ResponseOutputItemAddedEvent {
359 _type: "response.output_item.added".to_string(),
360 sequence_number,
361 output_index: 0,
362 item: OutputItem::Reasoning(reasoning_item.clone()),
363 };
364 response.output.push(ResponseOutputItem::Reasoning(reasoning_item.clone()));
365 yield Ok::<_, Infallible>(Event::default().event("response.output_item.added").data(serde_json::to_string(&output_item_added_event).unwrap()));
366 sequence_number += 1;
367
368 let output_item_done_event = ResponseOutputItemDoneEvent {
370 _type: "response.output_item.done".to_string(),
371 sequence_number,
372 output_index: 0,
373 item: OutputItem::Reasoning(reasoning_item.clone()),
374 };
375 yield Ok::<_, Infallible>(Event::default().event("response.output_item.done").data(serde_json::to_string(&output_item_done_event).unwrap()));
376 sequence_number += 1;
377
378 let message_item = ResponseOutputMessage {
380 id: message_id.clone(),
381 _type: "message".to_string(),
382 content: vec![],
383 role: "assistant".to_string(),
384 status: "in_progress".to_string(),
385 };
386 let output_item_added_event = ResponseOutputItemAddedEvent {
387 _type: "response.output_item.added".to_string(),
388 sequence_number,
389 output_index: 1,
390 item: OutputItem::Message(message_item.clone()),
391 };
392 response.output.push(ResponseOutputItem::Message(message_item.clone()));
393 yield Ok::<_, Infallible>(Event::default().event("response.output_item.added").data(serde_json::to_string(&output_item_added_event).unwrap()));
394 sequence_number += 1;
395
396 let part = ResponseOutputText {
398 _type: "output_text".to_string(),
399 text: "".to_string(),
400 annotations: vec![],
401 logprobs: vec![],
402 };
403 let content_part_added_event = ResponseContentPartAddedEvent {
404 _type: "response.content_part.added".to_string(),
405 sequence_number,
406 output_index: 1,
407 item_id: message_id.clone(),
408 content_index: 0,
409 part: part.clone(),
410 };
411 if let Some(ResponseOutputItem::Message(msg)) = response.output.get_mut(1) {
412 msg.content.push(part);
413 }
414 yield Ok::<_, Infallible>(Event::default().event("response.content_part.added").data(serde_json::to_string(&content_part_added_event).unwrap()));
415 sequence_number += 1;
416
417 let chunks = content.as_bytes().chunks(5);
419 for chunk in chunks {
420 let delta = String::from_utf8_lossy(chunk).to_string();
421 let obfuscation: String = rand::thread_rng()
422 .sample_iter(&Alphanumeric)
423 .take(10)
424 .map(char::from)
425 .collect();
426 let delta_event = ResponseTextDeltaEvent {
427 _type: "response.output_text.delta".to_string(),
428 sequence_number,
429 output_index: 1,
430 item_id: message_id.clone(),
431 content_index: 0,
432 delta,
433 logprobs: vec![],
434 obfuscation,
435 };
436 yield Ok::<_, Infallible>(Event::default().event("response.output_text.delta").data(serde_json::to_string(&delta_event).unwrap()));
437 sequence_number += 1;
438 sleep(Duration::from_millis(10)).await;
439 }
440
441 let text_done_event = ResponseTextDoneEvent {
443 _type: "response.output_text.done".to_string(),
444 sequence_number,
445 output_index: 1,
446 item_id: message_id.clone(),
447 content_index: 0,
448 text: content.clone(),
449 logprobs: vec![],
450 };
451 if let Some(ResponseOutputItem::Message(msg)) = response.output.get_mut(1) {
452 if let Some(p) = msg.content.get_mut(0) {
453 p.text = content.clone();
454 }
455 }
456 yield Ok::<_, Infallible>(Event::default().event("response.output_text.done").data(serde_json::to_string(&text_done_event).unwrap()));
457 sequence_number += 1;
458
459 let part = ResponseOutputText {
461 _type: "output_text".to_string(),
462 text: content.clone(),
463 annotations: vec![],
464 logprobs: vec![],
465 };
466 let content_part_done_event = ResponseContentPartDoneEvent {
467 _type: "response.content_part.done".to_string(),
468 sequence_number,
469 output_index: 1,
470 item_id: message_id.clone(),
471 content_index: 0,
472 part,
473 };
474 yield Ok::<_, Infallible>(Event::default().event("response.content_part.done").data(serde_json::to_string(&content_part_done_event).unwrap()));
475 sequence_number += 1;
476
477 let final_message_item = ResponseOutputMessage {
479 id: message_id.clone(),
480 _type: "message".to_string(),
481 content: vec![ResponseOutputText {
482 _type: "output_text".to_string(),
483 text: content.clone(),
484 annotations: vec![],
485 logprobs: vec![],
486 }],
487 role: "assistant".to_string(),
488 status: "completed".to_string(),
489 };
490 let output_item_done_event = ResponseOutputItemDoneEvent {
491 _type: "response.output_item.done".to_string(),
492 sequence_number,
493 output_index: 1,
494 item: OutputItem::Message(final_message_item.clone()),
495 };
496 if let Some(ResponseOutputItem::Message(msg)) = response.output.get_mut(1) {
497 *msg = final_message_item;
498 }
499 yield Ok::<_, Infallible>(Event::default().event("response.output_item.done").data(serde_json::to_string(&output_item_done_event).unwrap()));
500 sequence_number += 1;
501
502 response.status = "completed".to_string();
504 response.usage = Some(ResponseUsage {
505 input_tokens: prompt_tokens,
506 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
507 output_tokens: completion_tokens + 128, output_tokens_details: OutputTokensDetails { reasoning_tokens: 128 },
509 total_tokens: total_tokens + 128,
510 });
511 let completed_event = ResponseEvent {
512 _type: "response.completed".to_string(),
513 sequence_number,
514 response: response.clone(),
515 };
516 yield Ok::<_, Infallible>(Event::default().event("response.completed").data(serde_json::to_string(&completed_event).unwrap()));
517
518 yield Ok::<_, Infallible>(Event::default().data("[DONE]"));
520 };
521
522 return Sse::new(stream).into_response();
523 } else {
524 let output_text = ResponseOutputText {
525 _type: "output_text".to_string(),
526 text: content.clone(),
527 ..Default::default()
528 };
529
530 let message_item = ResponseOutputMessage {
531 id: message_id,
532 _type: "message".to_string(),
533 content: vec![output_text],
534 role: "assistant".to_string(),
535 status: "completed".to_string(),
536 };
537
538 let response = Response {
539 id: response_id,
540 object: "response".to_string(),
541 created_at,
542 model,
543 status: "completed".to_string(),
544 output: vec![ResponseOutputItem::Message(message_item)],
545 usage: Some(ResponseUsage {
546 input_tokens: prompt_tokens,
547 input_tokens_details: InputTokensDetails { cached_tokens: 0 },
548 output_tokens: completion_tokens,
549 output_tokens_details: OutputTokensDetails {
550 reasoning_tokens: 0,
551 },
552 total_tokens,
553 }),
554 instructions: payload.instructions,
555 parallel_tool_calls: true,
556 temperature: 1.0,
557 tool_choice: "auto".to_string(),
558 top_p: 1.0,
559 background: false,
560 reasoning: Reasoning {
561 effort: "medium".to_string(),
562 ..Default::default()
563 },
564 service_tier: "auto".to_string(),
565 text: ResponseTextConfig {
566 format: ResponseFormatText {
567 _type: "text".to_string(),
568 },
569 verbosity: "medium".to_string(),
570 },
571 top_logprobs: 0,
572 truncation: "disabled".to_string(),
573 store: false,
574 ..Default::default()
575 };
576
577 return (headers, Json(json!(response))).into_response();
578 }
579}