// File: roy_cli/chat_completions.rs
1// Copyright 2025 Massimiliano Pippi
2// SPDX-License-Identifier: MIT
3
4use axum::{
5    extract::State,
6    http::StatusCode,
7    response::{sse::Event, IntoResponse, Sse},
8    Json,
9};
10use futures_util::stream;
11use rand::Rng;
12use serde::{Deserialize, Serialize};
13use serde_json::{json, Value};
14use std::convert::Infallible;
15use std::time::{SystemTime, UNIX_EPOCH};
16
17use crate::server_state::ServerState;
18
/// Token accounting reported in OpenAI-compatible responses.
#[derive(Serialize, Debug)]
pub struct Usage {
    // Tokens counted from the serialized request messages.
    pub prompt_tokens: u32,
    // Tokens counted from the generated completion text.
    pub completion_tokens: u32,
    // Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
25
/// Incoming request body for the chat-completions endpoint.
///
/// Only the fields this handler inspects are modeled; any other fields
/// are captured untouched in `_other` via `#[serde(flatten)]`.
#[derive(Deserialize)]
pub struct ChatCompletionRequest {
    // Conversation messages; used here only for prompt token counting.
    pub messages: Option<Vec<Value>>,
    // Model name echoed back in the response; defaults to "gpt-3.5-turbo".
    pub model: Option<String>,
    // When `Some(true)`, the response is delivered as an SSE stream.
    #[serde(default)]
    pub stream: Option<bool>,
    // Catch-all for additional request fields, accepted and ignored.
    #[serde(flatten)]
    pub _other: Value,
}
35
/// Non-streaming response body, shaped like OpenAI's `chat.completion` object.
#[derive(Serialize)]
pub struct ChatCompletionResponse {
    // Response identifier of the form `chatcmpl-<random u32>`.
    pub id: String,
    // Always set to "chat.completion" by this handler.
    pub object: String,
    // Unix timestamp (seconds) when the response was produced.
    pub created: u64,
    // Model name echoed from the request (or the default).
    pub model: String,
    // Generated completions; this handler always produces exactly one.
    pub choices: Vec<Choice>,
    // Token usage for the exchange.
    pub usage: Usage,
}
45
/// One SSE event payload in a streamed response, shaped like OpenAI's
/// `chat.completion.chunk` object.
#[derive(Serialize, Debug)]
pub struct ChatCompletionChunk {
    // Same identifier across every chunk of one stream.
    pub id: String,
    // Always set to "chat.completion.chunk" by this handler.
    pub object: String,
    // Unix timestamp (seconds); shared by all chunks of one stream.
    pub created: u64,
    // Model name echoed from the request (or the default).
    pub model: String,
    // Incremental deltas; this handler always emits exactly one choice.
    pub choices: Vec<ChunkChoice>,
    // Only populated on the final chunk of the stream; omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
56
/// A single choice entry inside a streamed chunk.
#[derive(Serialize, Debug)]
pub struct ChunkChoice {
    // Choice index; this handler always uses 0.
    pub index: u32,
    // Incremental content/role update carried by this chunk.
    pub delta: ChoiceDelta,
    // Set to "stop" on the final content chunk; omitted before that.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
64
/// Incremental message fragment inside a chunk choice. The first chunk
/// carries only `role`, subsequent chunks only `content`, and the final
/// chunk neither (both `None`, serialized as an empty object).
#[derive(Serialize, Debug, Default)]
pub struct ChoiceDelta {
    // "assistant" on the opening chunk; omitted afterwards.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    // Next fragment of the completion text; omitted when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
72
/// A single completion choice in a non-streaming response.
#[derive(Serialize)]
pub struct Choice {
    // Choice index; this handler always uses 0.
    pub index: u32,
    // The full generated assistant message.
    pub message: Message,
    // Always "stop" for this handler.
    pub finish_reason: String,
}
79
/// A chat message as returned inside a `Choice`.
#[derive(Serialize)]
pub struct Message {
    // Always "assistant" for messages produced by this handler.
    pub role: String,
    // The generated completion text.
    pub content: String,
}
85
86pub async fn chat_completions(
87    state: State<ServerState>,
88    Json(payload): Json<ChatCompletionRequest>,
89) -> impl IntoResponse {
90    if state.check_request_limit_exceeded() {
91        let headers = state.get_rate_limit_headers();
92        let error_body = json!({
93            "error": {
94                "message": "Too many requests",
95                "type": "rate_limit_error",
96                "code": "rate_limit_exceeded"
97            }
98        });
99        return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
100    }
101    state.increment_request_count();
102
103    if let Some(error_code) = state.should_return_error() {
104        let headers = state.get_rate_limit_headers();
105        let status_code =
106            StatusCode::from_u16(error_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
107
108        let error_body = json!({
109            "error": {
110                "message": format!("Simulated error with code {}", error_code),
111                "type": "api_error",
112                "code": error_code.to_string()
113            }
114        });
115
116        return (status_code, headers, Json(error_body)).into_response();
117    }
118
119    let response_length = state.get_response_length();
120
121    if response_length == 0 {
122        let headers = state.get_rate_limit_headers();
123        return (StatusCode::NO_CONTENT, headers, Json(json!({}))).into_response();
124    }
125
126    let content = state.generate_lorem_content(response_length);
127
128    let prompt_text = payload
129        .messages
130        .as_ref()
131        .map(|msgs| serde_json::to_string(msgs).unwrap_or_default())
132        .unwrap_or_default();
133
134    let prompt_tokens = state.count_tokens(&prompt_text).unwrap_or(0);
135    let completion_tokens = state.count_tokens(&content).unwrap_or(0);
136    let total_tokens = prompt_tokens + completion_tokens;
137
138    if state.check_token_limit_exceeded(total_tokens) {
139        let headers = state.get_rate_limit_headers();
140        let error_body = json!({
141            "error": {
142                "message": "You have exceeded your token quota.",
143                "type": "rate_limit_error",
144                "code": "rate_limit_exceeded"
145            }
146        });
147        return (StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)).into_response();
148    }
149    state.add_token_usage(total_tokens);
150
151    let stream_response = payload.stream.unwrap_or(false);
152    if stream_response {
153        let id = format!("chatcmpl-{}", rand::thread_rng().gen::<u32>());
154        let created = SystemTime::now()
155            .duration_since(UNIX_EPOCH)
156            .expect("should be able to get duration")
157            .as_secs();
158        let model = payload
159            .model
160            .clone()
161            .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
162        let words = content
163            .split_whitespace()
164            .map(|s| s.to_string())
165            .collect::<Vec<_>>();
166
167        let mut events = vec![];
168
169        // 1. First chunk with role
170        let first_chunk = ChatCompletionChunk {
171            id: id.clone(),
172            object: "chat.completion.chunk".to_string(),
173            created,
174            model: model.clone(),
175            choices: vec![ChunkChoice {
176                index: 0,
177                delta: ChoiceDelta {
178                    role: Some("assistant".to_string()),
179                    content: None,
180                },
181                finish_reason: None,
182            }],
183            usage: None,
184        };
185        events.push(Ok::<_, Infallible>(
186            Event::default().data(serde_json::to_string(&first_chunk).unwrap()),
187        ));
188
189        // 2. Content chunks
190        for word in words {
191            let chunk = ChatCompletionChunk {
192                id: id.clone(),
193                object: "chat.completion.chunk".to_string(),
194                created,
195                model: model.clone(),
196                choices: vec![ChunkChoice {
197                    index: 0,
198                    delta: ChoiceDelta {
199                        role: None,
200                        content: Some(format!("{} ", word)),
201                    },
202                    finish_reason: None,
203                }],
204                usage: None,
205            };
206            events.push(Ok(
207                Event::default().data(serde_json::to_string(&chunk).unwrap())
208            ));
209        }
210
211        // 3. Final chunk with finish_reason
212        let final_chunk = ChatCompletionChunk {
213            id: id.clone(),
214            object: "chat.completion.chunk".to_string(),
215            created,
216            model: model.clone(),
217            choices: vec![ChunkChoice {
218                index: 0,
219                delta: Default::default(),
220                finish_reason: Some("stop".to_string()),
221            }],
222            usage: Some(Usage {
223                prompt_tokens,
224                completion_tokens,
225                total_tokens,
226            }),
227        };
228        events.push(Ok(
229            Event::default().data(serde_json::to_string(&final_chunk).unwrap())
230        ));
231
232        // 4. Done message
233        events.push(Ok(Event::default().data("[DONE]")));
234
235        let stream = stream::iter(events);
236
237        return Sse::new(stream).into_response();
238    }
239
240    let response = ChatCompletionResponse {
241        id: format!("chatcmpl-{}", rand::thread_rng().gen::<u32>()),
242        object: "chat.completion".to_string(),
243        created: SystemTime::now()
244            .duration_since(UNIX_EPOCH)
245            .expect("should be able to get duration")
246            .as_secs(),
247        model: payload.model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
248        choices: vec![Choice {
249            index: 0,
250            message: Message {
251                role: "assistant".to_string(),
252                content,
253            },
254            finish_reason: "stop".to_string(),
255        }],
256        usage: Usage {
257            prompt_tokens,
258            completion_tokens,
259            total_tokens,
260        },
261    };
262
263    let headers = state.get_rate_limit_headers();
264    (headers, Json(json!(response))).into_response()
265}