1use axum::http::HeaderMap;
5use humantime;
6use rand::Rng;
7use std::{
8 collections::VecDeque,
9 sync::{Arc, Mutex},
10 time::{Duration, SystemTime},
11};
12use tiktoken_rs::cl100k_base;
13
14use crate::Args;
15
16#[derive(Clone)]
17pub struct ServerState {
18 args: Args,
19 request_timestamps: Arc<Mutex<VecDeque<SystemTime>>>,
20 token_usage_timestamps: Arc<Mutex<VecDeque<(SystemTime, u32)>>>,
21}
22
23impl ServerState {
24 pub fn new(args: Args) -> Self {
25 Self {
26 args,
27 request_timestamps: Arc::new(Mutex::new(VecDeque::new())),
28 token_usage_timestamps: Arc::new(Mutex::new(VecDeque::new())),
29 }
30 }
31
32 pub fn should_return_error(&self) -> Option<u16> {
33 if let (Some(code), Some(rate)) = (self.args.error_code, self.args.error_rate) {
34 let mut rng = rand::thread_rng();
35 if rng.gen_range(0..100) < rate {
36 return Some(code);
37 }
38 }
39 None
40 }
41
42 pub fn get_response_length(&self) -> usize {
43 match &self.args.response_length {
44 Some(length_str) => {
45 if let Some(pos) = length_str.find(':') {
46 let min: usize = length_str[..pos].parse().unwrap_or(0);
47 let max: usize = length_str[pos + 1..].parse().unwrap_or(100);
48 rand::thread_rng().gen_range(min..=max)
49 } else {
50 length_str.parse().unwrap_or(0)
51 }
52 }
53 None => 0,
54 }
55 }
56
57 pub fn generate_lorem_content(&self, length: usize) -> String {
58 if length == 0 {
59 return String::new();
60 }
61 let word_count = length / 5;
62 let mut content = lipsum::lipsum(word_count);
63 content.truncate(length);
64 content
65 }
66
67 pub fn count_tokens(&self, text: &str) -> anyhow::Result<u32> {
68 let bpe = cl100k_base()?;
69 Ok(bpe.encode_with_special_tokens(text).len() as u32)
70 }
71
72 pub fn check_request_limit_exceeded(&self) -> bool {
73 let mut timestamps = self.request_timestamps.lock().unwrap();
74 let now = SystemTime::now();
75 let sixty_seconds_ago = now - Duration::from_secs(60);
76
77 while let Some(front) = timestamps.front() {
79 if *front < sixty_seconds_ago {
80 timestamps.pop_front();
81 } else {
82 break;
83 }
84 }
85
86 timestamps.len() as u32 >= self.args.rpm
88 }
89
90 pub fn check_token_limit_exceeded(&self, new_tokens: u32) -> bool {
91 let mut timestamps = self.token_usage_timestamps.lock().unwrap();
92 let now = SystemTime::now();
93 let sixty_seconds_ago = now - Duration::from_secs(60);
94
95 while let Some((t, _)) = timestamps.front() {
97 if *t < sixty_seconds_ago {
98 timestamps.pop_front();
99 } else {
100 break;
101 }
102 }
103
104 let current_token_usage: u32 = timestamps.iter().map(|(_, tokens)| tokens).sum();
105
106 (current_token_usage + new_tokens) > self.args.tpm
108 }
109
110 pub fn increment_request_count(&self) {
111 let mut timestamps = self.request_timestamps.lock().unwrap();
112 let now = SystemTime::now();
113 let sixty_seconds_ago = now - Duration::from_secs(60);
114 while let Some(front) = timestamps.front() {
115 if *front < sixty_seconds_ago {
116 timestamps.pop_front();
117 } else {
118 break;
119 }
120 }
121 timestamps.push_back(now);
122 }
123
124 pub fn add_token_usage(&self, tokens: u32) {
125 let mut timestamps = self.token_usage_timestamps.lock().unwrap();
126 let now = SystemTime::now();
127 let sixty_seconds_ago = now - Duration::from_secs(60);
128 while let Some((t, _)) = timestamps.front() {
129 if *t < sixty_seconds_ago {
130 timestamps.pop_front();
131 } else {
132 break;
133 }
134 }
135 timestamps.push_back((now, tokens));
136 }
137
138 pub fn get_rate_limit_headers(&self) -> HeaderMap {
139 let mut headers = HeaderMap::new();
140 let now = SystemTime::now();
141
142 let mut timestamps = self.request_timestamps.lock().unwrap();
144 let sixty_seconds_ago = now - Duration::from_secs(60);
145
146 while let Some(front) = timestamps.front() {
147 if *front < sixty_seconds_ago {
148 timestamps.pop_front();
149 } else {
150 break;
151 }
152 }
153
154 let request_count = timestamps.len() as u32;
155 let limit = self.args.rpm;
156 let remaining = limit.saturating_sub(request_count);
157
158 let reset_duration = if request_count < limit {
159 Duration::ZERO
160 } else {
161 if let Some(oldest) = timestamps.front() {
162 (*oldest + Duration::from_secs(60))
163 .duration_since(now)
164 .unwrap_or(Duration::ZERO)
165 } else {
166 Duration::ZERO
167 }
168 };
169 let reset_duration_rounded = Duration::from_secs(reset_duration.as_secs());
170
171 headers.insert(
172 "x-ratelimit-limit-requests",
173 limit.to_string().parse().unwrap(),
174 );
175 headers.insert(
176 "x-ratelimit-remaining-requests",
177 remaining.to_string().parse().unwrap(),
178 );
179 headers.insert(
180 "x-ratelimit-reset-requests",
181 humantime::format_duration(reset_duration_rounded)
182 .to_string()
183 .parse()
184 .expect("x-ratelimit-reset-requests must be a valid header value"),
185 );
186
187 let mut token_timestamps = self.token_usage_timestamps.lock().unwrap();
189 while let Some((t, _)) = token_timestamps.front() {
190 if *t < sixty_seconds_ago {
191 token_timestamps.pop_front();
192 } else {
193 break;
194 }
195 }
196
197 let current_token_usage: u32 = token_timestamps.iter().map(|(_, tokens)| tokens).sum();
198 let token_limit = self.args.tpm;
199 let remaining_tokens = token_limit.saturating_sub(current_token_usage);
200
201 let token_reset_duration = if current_token_usage < token_limit {
202 Duration::ZERO
203 } else {
204 if let Some((oldest_ts, _)) = token_timestamps.front() {
205 (*oldest_ts + Duration::from_secs(60))
206 .duration_since(now)
207 .unwrap_or(Duration::ZERO)
208 } else {
209 Duration::ZERO
210 }
211 };
212 let token_reset_duration_rounded = Duration::from_secs(token_reset_duration.as_secs());
213
214 headers.insert(
215 "x-ratelimit-limit-tokens",
216 token_limit.to_string().parse().unwrap(),
217 );
218 headers.insert(
219 "x-ratelimit-remaining-tokens",
220 remaining_tokens.to_string().parse().unwrap(),
221 );
222 headers.insert(
223 "x-ratelimit-reset-tokens",
224 humantime::format_duration(token_reset_duration_rounded)
225 .to_string()
226 .parse()
227 .expect("x-ratelimit-reset-tokens must be a valid header value"),
228 );
229
230 headers
231 }
232}