// roy_cli/chat_completion.rs

// Copyright 2025 Massimiliano Pippi
// SPDX-License-Identifier: MIT
3
4use axum::{
5    extract::State,
6    http::{HeaderMap, StatusCode},
7    response::Json,
8};
9use rand::Rng;
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::time::{SystemTime, UNIX_EPOCH};
13
14use crate::server_state::ServerState;
15
/// Incoming request body for the mock `/v1/chat/completions` endpoint.
///
/// Only the fields this mock server inspects are modeled; any other
/// OpenAI-compatible fields are captured untyped in `_other`.
#[derive(Deserialize)]
pub struct ChatCompletionRequest {
    /// Conversation history; serialized back to a JSON string solely to
    /// estimate prompt token usage.
    pub messages: Option<Vec<Value>>,
    /// Model name echoed back in the response; defaults to
    /// "gpt-3.5-turbo" when absent.
    pub model: Option<String>,
    /// Catch-all for unrecognized request fields (e.g. `temperature`);
    /// accepted but otherwise ignored.
    #[serde(flatten)]
    pub _other: Value,
}
23
/// OpenAI-compatible response body for a (non-streaming) chat completion.
#[derive(Serialize)]
pub struct ChatCompletionResponse {
    /// Completion id, formatted as `chatcmpl-<random u32>`.
    pub id: String,
    /// Object type marker; this handler always sets "chat.completion".
    pub object: String,
    /// Unix timestamp (seconds) when the response was created.
    pub created: u64,
    /// Model name echoed from the request, or the default.
    pub model: String,
    /// Generated choices; this mock always returns exactly one.
    pub choices: Vec<Choice>,
    /// Token-usage accounting for the request/response pair.
    pub usage: Usage,
}
33
/// A single completion choice within a [`ChatCompletionResponse`].
#[derive(Serialize)]
pub struct Choice {
    /// Zero-based position of this choice in the `choices` array.
    pub index: u32,
    /// The assistant message produced for this choice.
    pub message: Message,
    /// Why generation stopped; this mock always uses "stop".
    pub finish_reason: String,
}
40
/// A chat message as returned inside a completion choice.
#[derive(Serialize)]
pub struct Message {
    /// Speaker role; this handler always emits "assistant".
    pub role: String,
    /// Generated message text (lorem-ipsum content from the server state).
    pub content: String,
}
46
/// Token-usage statistics reported with each completion.
#[derive(Serialize)]
pub struct Usage {
    /// Tokens counted from the JSON-serialized request messages.
    pub prompt_tokens: u32,
    /// Tokens counted from the generated completion content.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
53
54pub async fn chat_completions(
55    state: State<ServerState>,
56    Json(payload): Json<ChatCompletionRequest>,
57) -> Result<(HeaderMap, Json<Value>), (StatusCode, HeaderMap, Json<Value>)> {
58    if state.check_request_limit_exceeded() {
59        let headers = state.get_rate_limit_headers();
60        let error_body = json!({
61            "error": {
62                "message": "Too many requests",
63                "type": "rate_limit_error",
64                "code": "rate_limit_exceeded"
65            }
66        });
67        return Err((StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)));
68    }
69    state.increment_request_count();
70
71    if let Some(error_code) = state.should_return_error() {
72        let headers = state.get_rate_limit_headers();
73        let status_code =
74            StatusCode::from_u16(error_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
75
76        let error_body = json!({
77            "error": {
78                "message": format!("Simulated error with code {}", error_code),
79                "type": "api_error",
80                "code": error_code.to_string()
81            }
82        });
83
84        return Err((status_code, headers, Json(error_body)));
85    }
86
87    let response_length = state.get_response_length();
88
89    if response_length == 0 {
90        let headers = state.get_rate_limit_headers();
91        return Err((StatusCode::NO_CONTENT, headers, Json(json!({}))));
92    }
93
94    let content = state.generate_lorem_content(response_length);
95
96    let prompt_text = payload
97        .messages
98        .as_ref()
99        .map(|msgs| serde_json::to_string(msgs).unwrap_or_default())
100        .unwrap_or_default();
101
102    let prompt_tokens = state.count_tokens(&prompt_text).unwrap_or(0);
103    let completion_tokens = state.count_tokens(&content).unwrap_or(0);
104    let total_tokens = prompt_tokens + completion_tokens;
105
106    if state.check_token_limit_exceeded(total_tokens) {
107        let headers = state.get_rate_limit_headers();
108        let error_body = json!({
109            "error": {
110                "message": "You have exceeded your token quota.",
111                "type": "rate_limit_error",
112                "code": "rate_limit_exceeded"
113            }
114        });
115        return Err((StatusCode::TOO_MANY_REQUESTS, headers, Json(error_body)));
116    }
117    state.add_token_usage(total_tokens);
118
119    let response = ChatCompletionResponse {
120        id: format!("chatcmpl-{}", rand::thread_rng().gen::<u32>()),
121        object: "chat.completion".to_string(),
122        created: SystemTime::now()
123            .duration_since(UNIX_EPOCH)
124            .expect("should be able to get duration")
125            .as_secs(),
126        model: payload.model.unwrap_or_else(|| "gpt-3.5-turbo".to_string()),
127        choices: vec![Choice {
128            index: 0,
129            message: Message {
130                role: "assistant".to_string(),
131                content,
132            },
133            finish_reason: "stop".to_string(),
134        }],
135        usage: Usage {
136            prompt_tokens,
137            completion_tokens,
138            total_tokens,
139        },
140    };
141
142    let headers = state.get_rate_limit_headers();
143    Ok((headers, Json(json!(response))))
144}