Skip to main content

roy_cli/
responses.rs

1// Copyright 2025 Massimiliano Pippi
2// SPDX-License-Identifier: MIT
3
4use crate::server_state::ServerState;
5use axum::{
6    extract::State,
7    http::StatusCode,
8    response::{sse::Event, IntoResponse, Json, Sse},
9};
10use rand::distributions::Alphanumeric;
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::time::{Duration, SystemTime, UNIX_EPOCH};
16use tokio::time::sleep;
17
/// JSON request body accepted by the `responses` handler.
///
/// Every named field is optional; anything else sent by the client is
/// collected into `_other` through `#[serde(flatten)]`.
#[derive(Deserialize)]
pub struct ResponsesRequest {
    /// Model name reported back in the reply; the handler falls back to a
    /// built-in default when this is absent.
    pub model: Option<String>,
    /// Prompt text; the handler counts its tokens to compute input usage.
    // NOTE(review): only a JSON string is accepted here; a request whose
    // `input` is an array or object would presumably be rejected during
    // deserialization — confirm this is intended.
    pub input: Option<String>,
    /// Instructions string copied unchanged into the generated response.
    pub instructions: Option<String>,
    /// When `Some(true)` the handler replies with a server-sent event stream.
    pub stream: Option<bool>,
    /// Catch-all for the remaining request fields; not read in this file.
    #[serde(flatten)]
    pub _other: Value,
}
27
28// Helper to generate random IDs
29fn generate_id(prefix: &str) -> String {
30    format!("{}_{:x}", prefix, rand::thread_rng().gen::<u128>())
31}
32
33// Data Models
34
/// Output-format descriptor nested inside `ResponseTextConfig`.
#[derive(Serialize, Clone, Default)]
struct ResponseFormatText {
    /// Format discriminator, serialized under the JSON key `type`
    /// (the handler in this file always sets it to `"text"`).
    #[serde(rename = "type")]
    _type: String,
}
40
/// Value of the `text` field of a `Response`: output format plus verbosity.
#[derive(Serialize, Clone, Default)]
struct ResponseTextConfig {
    /// Output format descriptor.
    format: ResponseFormatText,
    /// Verbosity label (the handler in this file always sets `"medium"`).
    verbosity: String,
}
46
/// Value of the `reasoning` field of a `Response`.
#[derive(Serialize, Clone, Default)]
struct Reasoning {
    /// Effort label (the handler in this file always sets `"medium"`).
    effort: String,
    /// Left at its default (`None`, serialized as `null`) by this file.
    generate_summary: Option<bool>,
    /// Left at its default (`None`, serialized as `null`) by this file.
    summary: Option<String>,
}
53
/// Output item describing a reasoning step.
///
/// Only the streaming path of the handler creates one, with an empty
/// `summary` and all optional fields set to `None`.
#[derive(Serialize, Clone, Debug)]
struct ResponseReasoningItem {
    /// Item identifier.
    id: String,
    /// Item discriminator, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Summary parts as free-form JSON values.
    summary: Vec<Value>,
    /// Optional free-form content; `None` serializes as `null`.
    content: Option<Value>,
    /// Optional free-form encrypted content; `None` serializes as `null`.
    encrypted_content: Option<Value>,
    /// Optional status label; `None` serializes as `null`.
    status: Option<String>,
}
64
/// A text content part of an assistant message.
#[derive(Serialize, Clone, Debug, Default)]
struct ResponseOutputText {
    /// Part discriminator, serialized under the JSON key `type`
    /// (the handler uses `"output_text"`).
    #[serde(rename = "type")]
    _type: String,
    /// The text carried by this part.
    text: String,
    /// Free-form annotation values; always empty in this file.
    annotations: Vec<Value>,
    /// Free-form log-probability values; always empty in this file.
    logprobs: Vec<Value>,
}
73
/// Output item holding an assistant message and its content parts.
#[derive(Serialize, Clone, Debug)]
struct ResponseOutputMessage {
    /// Message identifier.
    id: String,
    /// Item discriminator, serialized under the JSON key `type`
    /// (the handler uses `"message"`).
    #[serde(rename = "type")]
    _type: String,
    /// Content parts of the message.
    content: Vec<ResponseOutputText>,
    /// Author role (the handler uses `"assistant"`).
    role: String,
    /// Status label (`"in_progress"` or `"completed"` in this file).
    status: String,
}
83
/// An entry of `Response::output`.
///
/// Because of `#[serde(untagged)]` the variant name never appears in the
/// JSON; each variant serializes exactly like the struct it wraps.
#[derive(Serialize, Clone, Debug)]
#[serde(untagged)]
enum ResponseOutputItem {
    /// A reasoning item.
    Reasoning(ResponseReasoningItem),
    /// An assistant message.
    Message(ResponseOutputMessage),
}
90
/// Breakdown of input tokens reported inside `ResponseUsage`.
#[derive(Serialize, Clone)]
struct InputTokensDetails {
    /// Cached-token count; the handler in this file always reports 0.
    cached_tokens: u32,
}
95
/// Breakdown of output tokens reported inside `ResponseUsage`.
#[derive(Serialize, Clone)]
struct OutputTokensDetails {
    /// Reasoning-token count (0 for plain replies, a fixed mock value for
    /// streamed replies in this file).
    reasoning_tokens: u32,
}
100
/// Token accounting attached to a completed `Response`.
#[derive(Serialize, Clone)]
struct ResponseUsage {
    /// Tokens counted for the request's `input` text.
    input_tokens: u32,
    /// Breakdown of the input tokens.
    input_tokens_details: InputTokensDetails,
    /// Tokens counted for the generated content (plus mock reasoning tokens
    /// on the streaming path).
    output_tokens: u32,
    /// Breakdown of the output tokens.
    output_tokens_details: OutputTokensDetails,
    /// Sum of input and output tokens.
    total_tokens: u32,
}
109
/// The response object returned as the JSON body of a non-streaming reply
/// and embedded in the `response.*` lifecycle events of a streamed reply.
///
/// Fields are serialized in declaration order. `Option` fields that are
/// `None` serialize as `null`, except `usage`, which is omitted entirely.
/// Fields the handler does not set explicitly keep their `Default` value.
#[derive(Serialize, Clone, Default)]
struct Response {
    /// Response identifier.
    id: String,
    /// Creation time as fractional seconds since the Unix epoch.
    created_at: f64,
    error: Option<Value>,
    incomplete_details: Option<Value>,
    /// Copied from the request's `instructions` field.
    instructions: Option<String>,
    metadata: Value,
    /// Model name from the request, or the handler's default.
    model: String,
    /// Object kind (the handler uses `"response"`).
    object: String,
    /// Output items produced so far.
    output: Vec<ResponseOutputItem>,
    parallel_tool_calls: bool,
    temperature: f32,
    tool_choice: String,
    tools: Vec<Value>,
    top_p: f32,
    background: bool,
    max_output_tokens: Option<u32>,
    max_tool_calls: Option<u32>,
    previous_response_id: Option<String>,
    prompt: Option<String>,
    prompt_cache_key: Option<String>,
    reasoning: Reasoning,
    safety_identifier: Option<String>,
    service_tier: String,
    /// Lifecycle status (`"in_progress"` or `"completed"` in this file).
    status: String,
    text: ResponseTextConfig,
    top_logprobs: u32,
    truncation: String,
    /// Token usage; left out of the JSON while it is `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    usage: Option<ResponseUsage>,
    user: Option<String>,
    store: bool,
}
144
145// SSE
146
/// Payload of the response lifecycle events sent by the streaming path
/// (`response.created`, `response.in_progress`, `response.completed`).
#[derive(Serialize)]
struct ResponseEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream, starting at 0.
    sequence_number: u32,
    /// Snapshot of the response at the time the event is emitted.
    response: Response,
}
154
/// The `item` carried by output-item events.
///
/// Serialized untagged, i.e. exactly like the wrapped struct.
// NOTE(review): same variants and payloads as `ResponseOutputItem`; the two
// enums could probably be unified — confirm nothing relies on the split.
#[derive(Serialize)]
#[serde(untagged)]
enum OutputItem {
    /// A reasoning item.
    Reasoning(ResponseReasoningItem),
    /// An assistant message.
    Message(ResponseOutputMessage),
}
161
/// Payload of a `response.output_item.added` event.
#[derive(Serialize)]
struct ResponseOutputItemAddedEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the item inside the response's `output` list.
    output_index: u32,
    /// The item that was added.
    item: OutputItem,
}
170
/// Payload of a `response.output_item.done` event.
#[derive(Serialize)]
struct ResponseOutputItemDoneEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the item inside the response's `output` list.
    output_index: u32,
    /// The item in its final state.
    item: OutputItem,
}
179
/// Payload of a `response.content_part.added` event.
#[derive(Serialize)]
struct ResponseContentPartAddedEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item inside the response's `output` list.
    output_index: u32,
    /// Identifier of the message that owns the part.
    item_id: String,
    /// Index of the part inside the message's `content` list.
    content_index: u32,
    /// The newly added part (empty text when the handler emits it).
    part: ResponseOutputText,
}
190
/// Payload of a `response.output_text.delta` event: one slice of the text.
#[derive(Serialize)]
struct ResponseTextDeltaEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item inside the response's `output` list.
    output_index: u32,
    /// Identifier of the message being extended.
    item_id: String,
    /// Index of the part inside the message's `content` list.
    content_index: u32,
    /// The slice of text added by this event.
    delta: String,
    /// Free-form log-probability values; always empty in this file.
    logprobs: Vec<Value>,
    /// Filler string; the handler fills it with 10 random alphanumeric
    /// characters per event.
    obfuscation: String,
}
203
/// Payload of a `response.output_text.done` event carrying the full text.
#[derive(Serialize)]
struct ResponseTextDoneEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item inside the response's `output` list.
    output_index: u32,
    /// Identifier of the message that owns the text.
    item_id: String,
    /// Index of the part inside the message's `content` list.
    content_index: u32,
    /// The complete text of the part.
    text: String,
    /// Free-form log-probability values; always empty in this file.
    logprobs: Vec<Value>,
}
215
/// Payload of a `response.content_part.done` event.
#[derive(Serialize)]
struct ResponseContentPartDoneEvent {
    /// Event name, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    _type: String,
    /// Position of this event within the stream.
    sequence_number: u32,
    /// Index of the owning item inside the response's `output` list.
    output_index: u32,
    /// Identifier of the message that owns the part.
    item_id: String,
    /// Index of the part inside the message's `content` list.
    content_index: u32,
    /// The part in its final state, holding the complete text.
    part: ResponseOutputText,
}
226
227pub async fn responses(
228    state: State<ServerState>,
229    Json(payload): Json<ResponsesRequest>,
230) -> impl IntoResponse {
231    if state.check_request_limit_exceeded() {
232        let headers = state.get_rate_limit_headers();
233        let error_body = json!({
234            "error": {
235                "message": "Too many requests",
236                "type": "rate_limit_error",
237                "code": "rate_limit_exceeded"
238            }
239        });
240        return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
241    }
242    state.increment_request_count();
243
244    if let Some(error_code) = state.should_return_error() {
245        let headers = state.get_rate_limit_headers();
246        let status_code =
247            StatusCode::from_u16(error_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
248
249        let error_body = json!({
250            "error": {
251                "message": format!("Simulated error with code {}", error_code),
252                "type": "api_error",
253                "code": error_code.to_string()
254            }
255        });
256
257        return (status_code, headers, Json(error_body)).into_response();
258    }
259
260    let response_length = state.get_response_length();
261
262    if response_length == 0 {
263        let headers = state.get_rate_limit_headers();
264        return (StatusCode::NO_CONTENT, headers, Json(json!({}))).into_response();
265    }
266
267    let content = state.generate_lorem_content(response_length);
268
269    let prompt_text = payload.input.clone().unwrap_or_else(|| "".to_string());
270    let prompt_tokens = state.count_tokens(&prompt_text).unwrap_or(0) as u32;
271    let completion_tokens = state.count_tokens(&content).unwrap_or(0) as u32;
272    let total_tokens = prompt_tokens + completion_tokens;
273
274    if state.check_token_limit_exceeded(total_tokens) {
275        let headers = state.get_rate_limit_headers();
276        let error_body = json!({
277            "error": {
278                "message": "You have exceeded your token quota.",
279                "type": "rate_limit_error",
280                "code": "rate_limit_exceeded"
281            }
282        });
283        return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
284    }
285    state.add_token_usage(total_tokens);
286
287    let headers = state.get_rate_limit_headers();
288    let model = payload
289        .model
290        .clone()
291        .unwrap_or_else(|| "gpt-5-2025-08-07".to_string());
292    let response_id = generate_id("resp");
293    let message_id = generate_id("msg");
294    let created_at = SystemTime::now()
295        .duration_since(UNIX_EPOCH)
296        .expect("should be able to get duration")
297        .as_secs_f64();
298
299    let stream_response = payload.stream.unwrap_or(false);
300    if stream_response {
301        let reasoning_item_id = generate_id("rs");
302        let stream = async_stream::stream! {
303            let mut sequence_number = 0;
304            let mut response = Response {
305                id: response_id.clone(),
306                object: "response".to_string(),
307                created_at,
308                model: model.clone(),
309                status: "in_progress".to_string(),
310                instructions: payload.instructions.clone(),
311                parallel_tool_calls: true,
312                temperature: 1.0,
313                tool_choice: "auto".to_string(),
314                top_p: 1.0,
315                background: false,
316                reasoning: Reasoning {
317                    effort: "medium".to_string(),
318                    ..Default::default()
319                },
320                service_tier: "auto".to_string(),
321                text: ResponseTextConfig {
322                    format: ResponseFormatText { _type: "text".to_string() },
323                    verbosity: "medium".to_string(),
324                },
325                top_logprobs: 0,
326                truncation: "disabled".to_string(),
327                store: false,
328                ..Default::default()
329            };
330
331            // 1. response.created
332            let created_event = ResponseEvent {
333                _type: "response.created".to_string(),
334                sequence_number,
335                response: response.clone(),
336            };
337            yield Ok::<_, Infallible>(Event::default().event("response.created").data(serde_json::to_string(&created_event).unwrap()));
338            sequence_number += 1;
339
340            // 2. response.in_progress
341            let in_progress_event = ResponseEvent {
342                _type: "response.in_progress".to_string(),
343                sequence_number,
344                response: response.clone(),
345            };
346            yield Ok::<_, Infallible>(Event::default().event("response.in_progress").data(serde_json::to_string(&in_progress_event).unwrap()));
347            sequence_number += 1;
348
349            // 3. response.output_item.added (simulate a reasoning event)
350            let reasoning_item = ResponseReasoningItem {
351                id: reasoning_item_id.clone(),
352                _type: "reasoning".to_string(),
353                summary: vec![],
354                content: None,
355                encrypted_content: None,
356                status: None,
357            };
358            let output_item_added_event = ResponseOutputItemAddedEvent {
359                _type: "response.output_item.added".to_string(),
360                sequence_number,
361                output_index: 0,
362                item: OutputItem::Reasoning(reasoning_item.clone()),
363            };
364            response.output.push(ResponseOutputItem::Reasoning(reasoning_item.clone()));
365            yield Ok::<_, Infallible>(Event::default().event("response.output_item.added").data(serde_json::to_string(&output_item_added_event).unwrap()));
366            sequence_number += 1;
367
368            // 4. response.output_item.done (reasoning)
369            let output_item_done_event = ResponseOutputItemDoneEvent {
370                _type: "response.output_item.done".to_string(),
371                sequence_number,
372                output_index: 0,
373                item: OutputItem::Reasoning(reasoning_item.clone()),
374            };
375            yield Ok::<_, Infallible>(Event::default().event("response.output_item.done").data(serde_json::to_string(&output_item_done_event).unwrap()));
376            sequence_number += 1;
377
378            // 5. response.output_item.added (message)
379            let message_item = ResponseOutputMessage {
380                id: message_id.clone(),
381                _type: "message".to_string(),
382                content: vec![],
383                role: "assistant".to_string(),
384                status: "in_progress".to_string(),
385            };
386            let output_item_added_event = ResponseOutputItemAddedEvent {
387                _type: "response.output_item.added".to_string(),
388                sequence_number,
389                output_index: 1,
390                item: OutputItem::Message(message_item.clone()),
391            };
392            response.output.push(ResponseOutputItem::Message(message_item.clone()));
393            yield Ok::<_, Infallible>(Event::default().event("response.output_item.added").data(serde_json::to_string(&output_item_added_event).unwrap()));
394            sequence_number += 1;
395
396            // 6. response.content_part.added
397            let part = ResponseOutputText {
398                _type: "output_text".to_string(),
399                text: "".to_string(),
400                annotations: vec![],
401                logprobs: vec![],
402            };
403            let content_part_added_event = ResponseContentPartAddedEvent {
404                _type: "response.content_part.added".to_string(),
405                sequence_number,
406                output_index: 1,
407                item_id: message_id.clone(),
408                content_index: 0,
409                part: part.clone(),
410            };
411            if let Some(ResponseOutputItem::Message(msg)) = response.output.get_mut(1) {
412                msg.content.push(part);
413            }
414            yield Ok::<_, Infallible>(Event::default().event("response.content_part.added").data(serde_json::to_string(&content_part_added_event).unwrap()));
415            sequence_number += 1;
416
417            // 7. response.output_text.delta
418            let chunks = content.as_bytes().chunks(5);
419            for chunk in chunks {
420                let delta = String::from_utf8_lossy(chunk).to_string();
421                let obfuscation: String = rand::thread_rng()
422                    .sample_iter(&Alphanumeric)
423                    .take(10)
424                    .map(char::from)
425                    .collect();
426                let delta_event = ResponseTextDeltaEvent {
427                    _type: "response.output_text.delta".to_string(),
428                    sequence_number,
429                    output_index: 1,
430                    item_id: message_id.clone(),
431                    content_index: 0,
432                    delta,
433                    logprobs: vec![],
434                    obfuscation,
435                };
436                yield Ok::<_, Infallible>(Event::default().event("response.output_text.delta").data(serde_json::to_string(&delta_event).unwrap()));
437                sequence_number += 1;
438                sleep(Duration::from_millis(10)).await;
439            }
440
441            // 8. response.output_text.done
442            let text_done_event = ResponseTextDoneEvent {
443                _type: "response.output_text.done".to_string(),
444                sequence_number,
445                output_index: 1,
446                item_id: message_id.clone(),
447                content_index: 0,
448                text: content.clone(),
449                logprobs: vec![],
450            };
451            if let Some(ResponseOutputItem::Message(msg)) = response.output.get_mut(1) {
452                if let Some(p) = msg.content.get_mut(0) {
453                    p.text = content.clone();
454                }
455            }
456            yield Ok::<_, Infallible>(Event::default().event("response.output_text.done").data(serde_json::to_string(&text_done_event).unwrap()));
457            sequence_number += 1;
458
459            // 9. response.content_part.done
460            let part = ResponseOutputText {
461                _type: "output_text".to_string(),
462                text: content.clone(),
463                annotations: vec![],
464                logprobs: vec![],
465            };
466            let content_part_done_event = ResponseContentPartDoneEvent {
467                _type: "response.content_part.done".to_string(),
468                sequence_number,
469                output_index: 1,
470                item_id: message_id.clone(),
471                content_index: 0,
472                part,
473            };
474            yield Ok::<_, Infallible>(Event::default().event("response.content_part.done").data(serde_json::to_string(&content_part_done_event).unwrap()));
475            sequence_number += 1;
476
477            // 10. response.output_item.done (message)
478            let final_message_item = ResponseOutputMessage {
479                id: message_id.clone(),
480                _type: "message".to_string(),
481                content: vec![ResponseOutputText {
482                    _type: "output_text".to_string(),
483                    text: content.clone(),
484                    annotations: vec![],
485                    logprobs: vec![],
486                }],
487                role: "assistant".to_string(),
488                status: "completed".to_string(),
489            };
490            let output_item_done_event = ResponseOutputItemDoneEvent {
491                _type: "response.output_item.done".to_string(),
492                sequence_number,
493                output_index: 1,
494                item: OutputItem::Message(final_message_item.clone()),
495            };
496            if let Some(ResponseOutputItem::Message(msg)) = response.output.get_mut(1) {
497                *msg = final_message_item;
498            }
499            yield Ok::<_, Infallible>(Event::default().event("response.output_item.done").data(serde_json::to_string(&output_item_done_event).unwrap()));
500            sequence_number += 1;
501
502            // 11. response.completed
503            response.status = "completed".to_string();
504            response.usage = Some(ResponseUsage {
505                input_tokens: prompt_tokens,
506                input_tokens_details: InputTokensDetails { cached_tokens: 0 },
507                output_tokens: completion_tokens + 128, // mock reasoning tokens
508                output_tokens_details: OutputTokensDetails { reasoning_tokens: 128 },
509                total_tokens: total_tokens + 128,
510            });
511            let completed_event = ResponseEvent {
512                _type: "response.completed".to_string(),
513                sequence_number,
514                response: response.clone(),
515            };
516            yield Ok::<_, Infallible>(Event::default().event("response.completed").data(serde_json::to_string(&completed_event).unwrap()));
517
518            // End of stream
519            yield Ok::<_, Infallible>(Event::default().data("[DONE]"));
520        };
521
522        return Sse::new(stream).into_response();
523    } else {
524        let output_text = ResponseOutputText {
525            _type: "output_text".to_string(),
526            text: content.clone(),
527            ..Default::default()
528        };
529
530        let message_item = ResponseOutputMessage {
531            id: message_id,
532            _type: "message".to_string(),
533            content: vec![output_text],
534            role: "assistant".to_string(),
535            status: "completed".to_string(),
536        };
537
538        let response = Response {
539            id: response_id,
540            object: "response".to_string(),
541            created_at,
542            model,
543            status: "completed".to_string(),
544            output: vec![ResponseOutputItem::Message(message_item)],
545            usage: Some(ResponseUsage {
546                input_tokens: prompt_tokens,
547                input_tokens_details: InputTokensDetails { cached_tokens: 0 },
548                output_tokens: completion_tokens,
549                output_tokens_details: OutputTokensDetails {
550                    reasoning_tokens: 0,
551                },
552                total_tokens,
553            }),
554            instructions: payload.instructions,
555            parallel_tool_calls: true,
556            temperature: 1.0,
557            tool_choice: "auto".to_string(),
558            top_p: 1.0,
559            background: false,
560            reasoning: Reasoning {
561                effort: "medium".to_string(),
562                ..Default::default()
563            },
564            service_tier: "auto".to_string(),
565            text: ResponseTextConfig {
566                format: ResponseFormatText {
567                    _type: "text".to_string(),
568                },
569                verbosity: "medium".to_string(),
570            },
571            top_logprobs: 0,
572            truncation: "disabled".to_string(),
573            store: false,
574            ..Default::default()
575        };
576
577        return (headers, Json(json!(response))).into_response();
578    }
579}