1use axum::http::HeaderMap;
5use humantime;
6use rand::Rng;
7use std::{
8 collections::VecDeque,
9 sync::{Arc, Mutex},
10 time::{Duration, SystemTime},
11};
12use tiktoken_rs::cl100k_base;
13
14use crate::Args;
15
16#[derive(Clone)]
17pub struct ServerState {
18 args: Args,
19 request_timestamps: Arc<Mutex<VecDeque<SystemTime>>>,
20 token_usage_timestamps: Arc<Mutex<VecDeque<(SystemTime, u32)>>>,
21}
22
23impl ServerState {
24 pub fn new(args: Args) -> Self {
25 Self {
26 args,
27 request_timestamps: Arc::new(Mutex::new(VecDeque::new())),
28 token_usage_timestamps: Arc::new(Mutex::new(VecDeque::new())),
29 }
30 }
31
32 pub fn should_return_error(&self) -> Option<u16> {
33 if let (Some(code), Some(rate)) = (self.args.error_code, self.args.error_rate) {
34 let mut rng = rand::thread_rng();
35 if rng.gen_range(0..100) < rate {
36 return Some(code);
37 }
38 }
39 None
40 }
41
42 pub fn get_response_length(&self) -> usize {
43 match &self.args.response_length {
44 Some(length_str) => {
45 if let Some(pos) = length_str.find(':') {
46 let min: usize = length_str[..pos].parse().unwrap_or(0);
47 let max: usize = length_str[pos + 1..].parse().unwrap_or(100);
48 rand::thread_rng().gen_range(min..=max)
49 } else {
50 length_str.parse().unwrap_or(0)
51 }
52 }
53 None => 0,
54 }
55 }
56
57 pub fn get_slodown_ms(&self) -> u64 {
58 match &self.args.slowdown {
59 Some(slowdown_str) => {
60 if let Some(pos) = slowdown_str.find(':') {
61 let min: u64 = slowdown_str[..pos].parse().unwrap_or(0);
62 let max: u64 = slowdown_str[pos + 1..].parse().unwrap_or(600000); rand::thread_rng().gen_range(min..=max)
64 } else {
65 slowdown_str.parse().unwrap_or(0)
66 }
67 }
68 None => 0, }
70 }
71
72 pub fn generate_lorem_content(&self, length: usize) -> String {
73 if length == 0 {
74 return String::new();
75 }
76 let word_count = length / 5;
77 let mut content = lipsum::lipsum(word_count);
78 content.truncate(length);
79 content
80 }
81
82 pub fn count_tokens(&self, text: &str) -> anyhow::Result<u32> {
83 let bpe = cl100k_base()?;
84 Ok(bpe.encode_with_special_tokens(text).len() as u32)
85 }
86
87 pub fn check_request_limit_exceeded(&self) -> bool {
88 let mut timestamps = self.request_timestamps.lock().unwrap();
89 let now = SystemTime::now();
90 let sixty_seconds_ago = now - Duration::from_secs(60);
91
92 while let Some(front) = timestamps.front() {
94 if *front < sixty_seconds_ago {
95 timestamps.pop_front();
96 } else {
97 break;
98 }
99 }
100
101 timestamps.len() as u32 >= self.args.rpm
103 }
104
105 pub fn check_token_limit_exceeded(&self, new_tokens: u32) -> bool {
106 let mut timestamps = self.token_usage_timestamps.lock().unwrap();
107 let now = SystemTime::now();
108 let sixty_seconds_ago = now - Duration::from_secs(60);
109
110 while let Some((t, _)) = timestamps.front() {
112 if *t < sixty_seconds_ago {
113 timestamps.pop_front();
114 } else {
115 break;
116 }
117 }
118
119 let current_token_usage: u32 = timestamps.iter().map(|(_, tokens)| tokens).sum();
120
121 (current_token_usage + new_tokens) > self.args.tpm
123 }
124
125 pub fn increment_request_count(&self) {
126 let mut timestamps = self.request_timestamps.lock().unwrap();
127 let now = SystemTime::now();
128 let sixty_seconds_ago = now - Duration::from_secs(60);
129 while let Some(front) = timestamps.front() {
130 if *front < sixty_seconds_ago {
131 timestamps.pop_front();
132 } else {
133 break;
134 }
135 }
136 timestamps.push_back(now);
137 }
138
139 pub fn add_token_usage(&self, tokens: u32) {
140 let mut timestamps = self.token_usage_timestamps.lock().unwrap();
141 let now = SystemTime::now();
142 let sixty_seconds_ago = now - Duration::from_secs(60);
143 while let Some((t, _)) = timestamps.front() {
144 if *t < sixty_seconds_ago {
145 timestamps.pop_front();
146 } else {
147 break;
148 }
149 }
150 timestamps.push_back((now, tokens));
151 }
152
153 pub fn get_rate_limit_headers(&self) -> HeaderMap {
154 let mut headers = HeaderMap::new();
155 let now = SystemTime::now();
156
157 let mut timestamps = self.request_timestamps.lock().unwrap();
159 let sixty_seconds_ago = now - Duration::from_secs(60);
160
161 while let Some(front) = timestamps.front() {
162 if *front < sixty_seconds_ago {
163 timestamps.pop_front();
164 } else {
165 break;
166 }
167 }
168
169 let request_count = timestamps.len() as u32;
170 let limit = self.args.rpm;
171 let remaining = limit.saturating_sub(request_count);
172
173 let reset_duration = if request_count < limit {
174 Duration::ZERO
175 } else {
176 if let Some(oldest) = timestamps.front() {
177 (*oldest + Duration::from_secs(60))
178 .duration_since(now)
179 .unwrap_or(Duration::ZERO)
180 } else {
181 Duration::ZERO
182 }
183 };
184 let reset_duration_rounded = Duration::from_secs(reset_duration.as_secs());
185
186 headers.insert(
187 "x-ratelimit-limit-requests",
188 limit.to_string().parse().unwrap(),
189 );
190 headers.insert(
191 "x-ratelimit-remaining-requests",
192 remaining.to_string().parse().unwrap(),
193 );
194 headers.insert(
195 "x-ratelimit-reset-requests",
196 humantime::format_duration(reset_duration_rounded)
197 .to_string()
198 .parse()
199 .expect("x-ratelimit-reset-requests must be a valid header value"),
200 );
201
202 let mut token_timestamps = self.token_usage_timestamps.lock().unwrap();
204 while let Some((t, _)) = token_timestamps.front() {
205 if *t < sixty_seconds_ago {
206 token_timestamps.pop_front();
207 } else {
208 break;
209 }
210 }
211
212 let current_token_usage: u32 = token_timestamps.iter().map(|(_, tokens)| tokens).sum();
213 let token_limit = self.args.tpm;
214 let remaining_tokens = token_limit.saturating_sub(current_token_usage);
215
216 let token_reset_duration = if current_token_usage < token_limit {
217 Duration::ZERO
218 } else {
219 if let Some((oldest_ts, _)) = token_timestamps.front() {
220 (*oldest_ts + Duration::from_secs(60))
221 .duration_since(now)
222 .unwrap_or(Duration::ZERO)
223 } else {
224 Duration::ZERO
225 }
226 };
227 let token_reset_duration_rounded = Duration::from_secs(token_reset_duration.as_secs());
228
229 headers.insert(
230 "x-ratelimit-limit-tokens",
231 token_limit.to_string().parse().unwrap(),
232 );
233 headers.insert(
234 "x-ratelimit-remaining-tokens",
235 remaining_tokens.to_string().parse().unwrap(),
236 );
237 headers.insert(
238 "x-ratelimit-reset-tokens",
239 humantime::format_duration(token_reset_duration_rounded)
240 .to_string()
241 .parse()
242 .expect("x-ratelimit-reset-tokens must be a valid header value"),
243 );
244
245 headers
246 }
247}