Skip to main content

roy_cli/
server_state.rs

1// Copyright 2025 Massimiliano Pippi
2// SPDX-License-Identifier: MIT
3
4use axum::http::HeaderMap;
5use humantime;
6use rand::Rng;
7use std::{
8    collections::VecDeque,
9    sync::{Arc, Mutex},
10    time::{Duration, SystemTime},
11};
12use tiktoken_rs::cl100k_base;
13
14use crate::Args;
15
/// State shared by the server's request handlers.
///
/// The type is `Clone`, and both histories live behind `Arc<Mutex<..>>`, so every
/// clone observes and updates the same rate-limit bookkeeping.
#[derive(Clone)]
pub struct ServerState {
    /// Run configuration (`crate::Args`) consulted by the methods of this type.
    args: Args,
    /// Time stamps of recorded requests, appended at the back and pruned from the
    /// front once they are more than 60 seconds old.
    request_timestamps: Arc<Mutex<VecDeque<SystemTime>>>,
    /// `(time stamp, token count)` pairs of recorded token usage, appended at the
    /// back and pruned from the front once they are more than 60 seconds old.
    token_usage_timestamps: Arc<Mutex<VecDeque<(SystemTime, u32)>>>,
}
22
23impl ServerState {
24    pub fn new(args: Args) -> Self {
25        Self {
26            args,
27            request_timestamps: Arc::new(Mutex::new(VecDeque::new())),
28            token_usage_timestamps: Arc::new(Mutex::new(VecDeque::new())),
29        }
30    }
31
32    pub fn should_return_error(&self) -> Option<u16> {
33        if let (Some(code), Some(rate)) = (self.args.error_code, self.args.error_rate) {
34            let mut rng = rand::thread_rng();
35            if rng.gen_range(0..100) < rate {
36                return Some(code);
37            }
38        }
39        None
40    }
41
42    pub fn get_response_length(&self) -> usize {
43        match &self.args.response_length {
44            Some(length_str) => {
45                if let Some(pos) = length_str.find(':') {
46                    let min: usize = length_str[..pos].parse().unwrap_or(0);
47                    let max: usize = length_str[pos + 1..].parse().unwrap_or(100);
48                    rand::thread_rng().gen_range(min..=max)
49                } else {
50                    length_str.parse().unwrap_or(0)
51                }
52            }
53            None => 0,
54        }
55    }
56
57    pub fn get_slodown_ms(&self) -> u64 {
58        match &self.args.slowdown {
59            Some(slowdown_str) => {
60                if let Some(pos) = slowdown_str.find(':') {
61                    let min: u64 = slowdown_str[..pos].parse().unwrap_or(0);
62                    let max: u64 = slowdown_str[pos + 1..].parse().unwrap_or(600000); // 10 minutes
63                    rand::thread_rng().gen_range(min..=max)
64                } else {
65                    slowdown_str.parse().unwrap_or(0)
66                }
67            }
68            None => 0, // default is zero, no slowdown
69        }
70    }
71
72    pub fn generate_lorem_content(&self, length: usize) -> String {
73        if length == 0 {
74            return String::new();
75        }
76        let word_count = length / 5;
77        let mut content = lipsum::lipsum(word_count);
78        content.truncate(length);
79        content
80    }
81
82    pub fn count_tokens(&self, text: &str) -> anyhow::Result<u32> {
83        let bpe = cl100k_base()?;
84        Ok(bpe.encode_with_special_tokens(text).len() as u32)
85    }
86
87    pub fn check_request_limit_exceeded(&self) -> bool {
88        let mut timestamps = self.request_timestamps.lock().unwrap();
89        let now = SystemTime::now();
90        let sixty_seconds_ago = now - Duration::from_secs(60);
91
92        // Prune old timestamps
93        while let Some(front) = timestamps.front() {
94            if *front < sixty_seconds_ago {
95                timestamps.pop_front();
96            } else {
97                break;
98            }
99        }
100
101        // Check limit
102        timestamps.len() as u32 >= self.args.rpm
103    }
104
105    pub fn check_token_limit_exceeded(&self, new_tokens: u32) -> bool {
106        let mut timestamps = self.token_usage_timestamps.lock().unwrap();
107        let now = SystemTime::now();
108        let sixty_seconds_ago = now - Duration::from_secs(60);
109
110        // Prune old entries
111        while let Some((t, _)) = timestamps.front() {
112            if *t < sixty_seconds_ago {
113                timestamps.pop_front();
114            } else {
115                break;
116            }
117        }
118
119        let current_token_usage: u32 = timestamps.iter().map(|(_, tokens)| tokens).sum();
120
121        // Check limit
122        (current_token_usage + new_tokens) > self.args.tpm
123    }
124
125    pub fn increment_request_count(&self) {
126        let mut timestamps = self.request_timestamps.lock().unwrap();
127        let now = SystemTime::now();
128        let sixty_seconds_ago = now - Duration::from_secs(60);
129        while let Some(front) = timestamps.front() {
130            if *front < sixty_seconds_ago {
131                timestamps.pop_front();
132            } else {
133                break;
134            }
135        }
136        timestamps.push_back(now);
137    }
138
139    pub fn add_token_usage(&self, tokens: u32) {
140        let mut timestamps = self.token_usage_timestamps.lock().unwrap();
141        let now = SystemTime::now();
142        let sixty_seconds_ago = now - Duration::from_secs(60);
143        while let Some((t, _)) = timestamps.front() {
144            if *t < sixty_seconds_ago {
145                timestamps.pop_front();
146            } else {
147                break;
148            }
149        }
150        timestamps.push_back((now, tokens));
151    }
152
153    pub fn get_rate_limit_headers(&self) -> HeaderMap {
154        let mut headers = HeaderMap::new();
155        let now = SystemTime::now();
156
157        // Requests logic
158        let mut timestamps = self.request_timestamps.lock().unwrap();
159        let sixty_seconds_ago = now - Duration::from_secs(60);
160
161        while let Some(front) = timestamps.front() {
162            if *front < sixty_seconds_ago {
163                timestamps.pop_front();
164            } else {
165                break;
166            }
167        }
168
169        let request_count = timestamps.len() as u32;
170        let limit = self.args.rpm;
171        let remaining = limit.saturating_sub(request_count);
172
173        let reset_duration = if request_count < limit {
174            Duration::ZERO
175        } else {
176            if let Some(oldest) = timestamps.front() {
177                (*oldest + Duration::from_secs(60))
178                    .duration_since(now)
179                    .unwrap_or(Duration::ZERO)
180            } else {
181                Duration::ZERO
182            }
183        };
184        let reset_duration_rounded = Duration::from_secs(reset_duration.as_secs());
185
186        headers.insert(
187            "x-ratelimit-limit-requests",
188            limit.to_string().parse().unwrap(),
189        );
190        headers.insert(
191            "x-ratelimit-remaining-requests",
192            remaining.to_string().parse().unwrap(),
193        );
194        headers.insert(
195            "x-ratelimit-reset-requests",
196            humantime::format_duration(reset_duration_rounded)
197                .to_string()
198                .parse()
199                .expect("x-ratelimit-reset-requests must be a valid header value"),
200        );
201
202        // Tokens logic
203        let mut token_timestamps = self.token_usage_timestamps.lock().unwrap();
204        while let Some((t, _)) = token_timestamps.front() {
205            if *t < sixty_seconds_ago {
206                token_timestamps.pop_front();
207            } else {
208                break;
209            }
210        }
211
212        let current_token_usage: u32 = token_timestamps.iter().map(|(_, tokens)| tokens).sum();
213        let token_limit = self.args.tpm;
214        let remaining_tokens = token_limit.saturating_sub(current_token_usage);
215
216        let token_reset_duration = if current_token_usage < token_limit {
217            Duration::ZERO
218        } else {
219            if let Some((oldest_ts, _)) = token_timestamps.front() {
220                (*oldest_ts + Duration::from_secs(60))
221                    .duration_since(now)
222                    .unwrap_or(Duration::ZERO)
223            } else {
224                Duration::ZERO
225            }
226        };
227        let token_reset_duration_rounded = Duration::from_secs(token_reset_duration.as_secs());
228
229        headers.insert(
230            "x-ratelimit-limit-tokens",
231            token_limit.to_string().parse().unwrap(),
232        );
233        headers.insert(
234            "x-ratelimit-remaining-tokens",
235            remaining_tokens.to_string().parse().unwrap(),
236        );
237        headers.insert(
238            "x-ratelimit-reset-tokens",
239            humantime::format_duration(token_reset_duration_rounded)
240                .to_string()
241                .parse()
242                .expect("x-ratelimit-reset-tokens must be a valid header value"),
243        );
244
245        headers
246    }
247}