roy_cli/
server_state.rs

1// Copyright 2025 Massimiliano Pippi
2// SPDX-License-Identifier: MIT
3
4use axum::http::HeaderMap;
5use humantime;
6use rand::Rng;
7use std::{
8    collections::VecDeque,
9    sync::{Arc, Mutex},
10    time::{Duration, SystemTime},
11};
12use tiktoken_rs::cl100k_base;
13
14use crate::Args;
15
16#[derive(Clone)]
17pub struct ServerState {
18    args: Args,
19    request_timestamps: Arc<Mutex<VecDeque<SystemTime>>>,
20    token_usage_timestamps: Arc<Mutex<VecDeque<(SystemTime, u32)>>>,
21}
22
23impl ServerState {
24    pub fn new(args: Args) -> Self {
25        Self {
26            args,
27            request_timestamps: Arc::new(Mutex::new(VecDeque::new())),
28            token_usage_timestamps: Arc::new(Mutex::new(VecDeque::new())),
29        }
30    }
31
32    pub fn should_return_error(&self) -> Option<u16> {
33        if let (Some(code), Some(rate)) = (self.args.error_code, self.args.error_rate) {
34            let mut rng = rand::thread_rng();
35            if rng.gen_range(0..100) < rate {
36                return Some(code);
37            }
38        }
39        None
40    }
41
42    pub fn get_response_length(&self) -> usize {
43        match &self.args.response_length {
44            Some(length_str) => {
45                if let Some(pos) = length_str.find(':') {
46                    let min: usize = length_str[..pos].parse().unwrap_or(0);
47                    let max: usize = length_str[pos + 1..].parse().unwrap_or(100);
48                    rand::thread_rng().gen_range(min..=max)
49                } else {
50                    length_str.parse().unwrap_or(0)
51                }
52            }
53            None => 0,
54        }
55    }
56
57    pub fn generate_lorem_content(&self, length: usize) -> String {
58        if length == 0 {
59            return String::new();
60        }
61        let word_count = length / 5;
62        let mut content = lipsum::lipsum(word_count);
63        content.truncate(length);
64        content
65    }
66
67    pub fn count_tokens(&self, text: &str) -> anyhow::Result<u32> {
68        let bpe = cl100k_base()?;
69        Ok(bpe.encode_with_special_tokens(text).len() as u32)
70    }
71
72    pub fn check_request_limit_exceeded(&self) -> bool {
73        let mut timestamps = self.request_timestamps.lock().unwrap();
74        let now = SystemTime::now();
75        let sixty_seconds_ago = now - Duration::from_secs(60);
76
77        // Prune old timestamps
78        while let Some(front) = timestamps.front() {
79            if *front < sixty_seconds_ago {
80                timestamps.pop_front();
81            } else {
82                break;
83            }
84        }
85
86        // Check limit
87        timestamps.len() as u32 >= self.args.rpm
88    }
89
90    pub fn check_token_limit_exceeded(&self, new_tokens: u32) -> bool {
91        let mut timestamps = self.token_usage_timestamps.lock().unwrap();
92        let now = SystemTime::now();
93        let sixty_seconds_ago = now - Duration::from_secs(60);
94
95        // Prune old entries
96        while let Some((t, _)) = timestamps.front() {
97            if *t < sixty_seconds_ago {
98                timestamps.pop_front();
99            } else {
100                break;
101            }
102        }
103
104        let current_token_usage: u32 = timestamps.iter().map(|(_, tokens)| tokens).sum();
105
106        // Check limit
107        (current_token_usage + new_tokens) > self.args.tpm
108    }
109
110    pub fn increment_request_count(&self) {
111        let mut timestamps = self.request_timestamps.lock().unwrap();
112        let now = SystemTime::now();
113        let sixty_seconds_ago = now - Duration::from_secs(60);
114        while let Some(front) = timestamps.front() {
115            if *front < sixty_seconds_ago {
116                timestamps.pop_front();
117            } else {
118                break;
119            }
120        }
121        timestamps.push_back(now);
122    }
123
124    pub fn add_token_usage(&self, tokens: u32) {
125        let mut timestamps = self.token_usage_timestamps.lock().unwrap();
126        let now = SystemTime::now();
127        let sixty_seconds_ago = now - Duration::from_secs(60);
128        while let Some((t, _)) = timestamps.front() {
129            if *t < sixty_seconds_ago {
130                timestamps.pop_front();
131            } else {
132                break;
133            }
134        }
135        timestamps.push_back((now, tokens));
136    }
137
138    pub fn get_rate_limit_headers(&self) -> HeaderMap {
139        let mut headers = HeaderMap::new();
140        let now = SystemTime::now();
141
142        // Requests logic
143        let mut timestamps = self.request_timestamps.lock().unwrap();
144        let sixty_seconds_ago = now - Duration::from_secs(60);
145
146        while let Some(front) = timestamps.front() {
147            if *front < sixty_seconds_ago {
148                timestamps.pop_front();
149            } else {
150                break;
151            }
152        }
153
154        let request_count = timestamps.len() as u32;
155        let limit = self.args.rpm;
156        let remaining = limit.saturating_sub(request_count);
157
158        let reset_duration = if request_count < limit {
159            Duration::ZERO
160        } else {
161            if let Some(oldest) = timestamps.front() {
162                (*oldest + Duration::from_secs(60))
163                    .duration_since(now)
164                    .unwrap_or(Duration::ZERO)
165            } else {
166                Duration::ZERO
167            }
168        };
169        let reset_duration_rounded = Duration::from_secs(reset_duration.as_secs());
170
171        headers.insert(
172            "x-ratelimit-limit-requests",
173            limit.to_string().parse().unwrap(),
174        );
175        headers.insert(
176            "x-ratelimit-remaining-requests",
177            remaining.to_string().parse().unwrap(),
178        );
179        headers.insert(
180            "x-ratelimit-reset-requests",
181            humantime::format_duration(reset_duration_rounded)
182                .to_string()
183                .parse()
184                .expect("x-ratelimit-reset-requests must be a valid header value"),
185        );
186
187        // Tokens logic
188        let mut token_timestamps = self.token_usage_timestamps.lock().unwrap();
189        while let Some((t, _)) = token_timestamps.front() {
190            if *t < sixty_seconds_ago {
191                token_timestamps.pop_front();
192            } else {
193                break;
194            }
195        }
196
197        let current_token_usage: u32 = token_timestamps.iter().map(|(_, tokens)| tokens).sum();
198        let token_limit = self.args.tpm;
199        let remaining_tokens = token_limit.saturating_sub(current_token_usage);
200
201        let token_reset_duration = if current_token_usage < token_limit {
202            Duration::ZERO
203        } else {
204            if let Some((oldest_ts, _)) = token_timestamps.front() {
205                (*oldest_ts + Duration::from_secs(60))
206                    .duration_since(now)
207                    .unwrap_or(Duration::ZERO)
208            } else {
209                Duration::ZERO
210            }
211        };
212        let token_reset_duration_rounded = Duration::from_secs(token_reset_duration.as_secs());
213
214        headers.insert(
215            "x-ratelimit-limit-tokens",
216            token_limit.to_string().parse().unwrap(),
217        );
218        headers.insert(
219            "x-ratelimit-remaining-tokens",
220            remaining_tokens.to_string().parse().unwrap(),
221        );
222        headers.insert(
223            "x-ratelimit-reset-tokens",
224            humantime::format_duration(token_reset_duration_rounded)
225                .to_string()
226                .parse()
227                .expect("x-ratelimit-reset-tokens must be a valid header value"),
228        );
229
230        headers
231    }
232}