1use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::RwLock;
13use tracing::{debug, trace};
14
15use grapsus_common::errors::{GrapsusError, GrapsusResult};
16
17use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
18
/// Tuning knobs for the least-tokens-queued balancing algorithm.
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedConfig {
    /// Smoothing factor for the tokens-per-second EWMA; higher values weight
    /// recent measurements more heavily (used in `TargetMetrics::dequeue`).
    pub ewma_alpha: f64,
    /// Initial tokens-per-second estimate seeded into each target's EWMA
    /// before any real measurements arrive.
    pub default_tps: f64,
    /// Floor applied to the TPS estimate when computing queue times, so a
    /// near-zero measured rate cannot blow up the estimated wait.
    pub min_tps: f64,
}
30
31impl Default for LeastTokensQueuedConfig {
32 fn default() -> Self {
33 Self {
34 ewma_alpha: 0.3,
35 default_tps: 100.0, min_tps: 1.0,
37 }
38 }
39}
40
/// Per-target counters tracked by the balancer. All fields use interior
/// mutability so metrics can be updated through a shared reference.
struct TargetMetrics {
    /// Estimated tokens currently queued/in flight on this target.
    queued_tokens: AtomicU64,
    /// Number of requests currently queued/in flight on this target.
    queued_requests: AtomicU64,
    /// Exponentially weighted moving average of observed tokens/second.
    tps_ewma: parking_lot::Mutex<f64>,
    /// Lifetime total of tokens processed across completed requests.
    total_tokens: AtomicU64,
    /// Lifetime total of completed requests.
    total_requests: AtomicU64,
}
54
55impl TargetMetrics {
56 fn new(default_tps: f64) -> Self {
57 Self {
58 queued_tokens: AtomicU64::new(0),
59 queued_requests: AtomicU64::new(0),
60 tps_ewma: parking_lot::Mutex::new(default_tps),
61 total_tokens: AtomicU64::new(0),
62 total_requests: AtomicU64::new(0),
63 }
64 }
65
66 fn estimated_queue_time(&self, min_tps: f64) -> f64 {
68 let queued = self.queued_tokens.load(Ordering::Relaxed) as f64;
69 let tps = (*self.tps_ewma.lock()).max(min_tps);
70 queued / tps
71 }
72
73 fn enqueue(&self, tokens: u64) {
75 self.queued_tokens.fetch_add(tokens, Ordering::AcqRel);
76 self.queued_requests.fetch_add(1, Ordering::AcqRel);
77 }
78
79 fn dequeue(&self, tokens: u64, duration: Duration, ewma_alpha: f64) {
81 self.queued_tokens.fetch_saturating_sub(tokens);
83 self.queued_requests.fetch_saturating_sub(1);
84
85 self.total_tokens.fetch_add(tokens, Ordering::Relaxed);
87 self.total_requests.fetch_add(1, Ordering::Relaxed);
88
89 if duration.as_secs_f64() > 0.0 {
91 let measured_tps = tokens as f64 / duration.as_secs_f64();
92 let mut tps = self.tps_ewma.lock();
93 *tps = ewma_alpha * measured_tps + (1.0 - ewma_alpha) * *tps;
94 }
95 }
96}
97
/// Extension trait adding a saturating subtraction to atomic counters: the
/// stored value is decremented by `val` but never wraps below zero.
trait AtomicSaturatingSub {
    fn fetch_saturating_sub(&self, val: u64);
}

impl AtomicSaturatingSub for AtomicU64 {
    fn fetch_saturating_sub(&self, val: u64) {
        // `fetch_update` encapsulates the load/compare-exchange retry loop the
        // previous hand-rolled implementation spelled out. The closure always
        // returns `Some`, so the call always succeeds and the Result is moot.
        let _ = self.fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
            Some(current.saturating_sub(val))
        });
    }
}
117
/// Load balancer that routes each request to the healthy target with the
/// smallest estimated queue-drain time (queued tokens / measured throughput).
pub struct LeastTokensQueuedBalancer {
    /// Static set of upstream targets this balancer selects among.
    targets: Vec<UpstreamTarget>,
    /// Per-target counters keyed by `full_address()`. The map itself is
    /// immutable after construction; entries rely on interior mutability.
    metrics: Arc<HashMap<String, TargetMetrics>>,
    /// Health flag per target address; addresses absent from the map are
    /// treated as healthy by readers.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    config: LeastTokensQueuedConfig,
}
128
129impl LeastTokensQueuedBalancer {
130 pub fn new(targets: Vec<UpstreamTarget>, config: LeastTokensQueuedConfig) -> Self {
132 let mut metrics = HashMap::new();
133 let mut health_status = HashMap::new();
134
135 for target in &targets {
136 let addr = target.full_address();
137 metrics.insert(addr.clone(), TargetMetrics::new(config.default_tps));
138 health_status.insert(addr, true);
139 }
140
141 Self {
142 targets,
143 metrics: Arc::new(metrics),
144 health_status: Arc::new(RwLock::new(health_status)),
145 config,
146 }
147 }
148
149 pub fn enqueue_tokens(&self, address: &str, estimated_tokens: u64) {
151 if let Some(metrics) = self.metrics.get(address) {
152 metrics.enqueue(estimated_tokens);
153 trace!(
154 target = address,
155 tokens = estimated_tokens,
156 queued = metrics.queued_tokens.load(Ordering::Relaxed),
157 "Enqueued tokens for target"
158 );
159 }
160 }
161
162 pub fn dequeue_tokens(&self, address: &str, actual_tokens: u64, duration: Duration) {
164 if let Some(metrics) = self.metrics.get(address) {
165 metrics.dequeue(actual_tokens, duration, self.config.ewma_alpha);
166 debug!(
167 target = address,
168 tokens = actual_tokens,
169 duration_ms = duration.as_millis() as u64,
170 queued = metrics.queued_tokens.load(Ordering::Relaxed),
171 tps = *metrics.tps_ewma.lock(),
172 "Dequeued tokens for target"
173 );
174 }
175 }
176
177 pub fn target_metrics(&self, address: &str) -> Option<LeastTokensQueuedTargetStats> {
179 self.metrics
180 .get(address)
181 .map(|m| LeastTokensQueuedTargetStats {
182 queued_tokens: m.queued_tokens.load(Ordering::Relaxed),
183 queued_requests: m.queued_requests.load(Ordering::Relaxed),
184 tokens_per_second: *m.tps_ewma.lock(),
185 total_tokens: m.total_tokens.load(Ordering::Relaxed),
186 total_requests: m.total_requests.load(Ordering::Relaxed),
187 })
188 }
189
190 pub async fn queue_times(&self) -> Vec<(String, f64)> {
192 let health = self.health_status.read().await;
193 self.targets
194 .iter()
195 .filter_map(|t| {
196 let addr = t.full_address();
197 if *health.get(&addr).unwrap_or(&true) {
198 self.metrics
199 .get(&addr)
200 .map(|m| (addr, m.estimated_queue_time(self.config.min_tps)))
201 } else {
202 None
203 }
204 })
205 .collect()
206 }
207}
208
/// Read-only snapshot of a single target's metrics, as returned by
/// `LeastTokensQueuedBalancer::target_metrics`.
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedTargetStats {
    /// Tokens currently estimated to be queued on the target.
    pub queued_tokens: u64,
    /// Requests currently queued/in flight against the target.
    pub queued_requests: u64,
    /// Current EWMA estimate of the target's throughput.
    pub tokens_per_second: f64,
    /// Lifetime total of tokens processed by the target.
    pub total_tokens: u64,
    /// Lifetime total of requests completed by the target.
    pub total_requests: u64,
}
218
219#[async_trait]
220impl LoadBalancer for LeastTokensQueuedBalancer {
221 async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
222 trace!(
223 total_targets = self.targets.len(),
224 algorithm = "least_tokens_queued",
225 "Selecting upstream target"
226 );
227
228 let health = self.health_status.read().await;
229
230 let mut best_target = None;
231 let mut min_queue_time = f64::MAX;
232
233 for target in &self.targets {
234 let addr = target.full_address();
235
236 if !*health.get(&addr).unwrap_or(&true) {
238 trace!(
239 target = %addr,
240 algorithm = "least_tokens_queued",
241 "Skipping unhealthy target"
242 );
243 continue;
244 }
245
246 let queue_time = self
248 .metrics
249 .get(&addr)
250 .map(|m| m.estimated_queue_time(self.config.min_tps))
251 .unwrap_or(0.0);
252
253 trace!(
254 target = %addr,
255 queue_time_secs = queue_time,
256 "Evaluating target queue time"
257 );
258
259 if queue_time < min_queue_time {
260 min_queue_time = queue_time;
261 best_target = Some(target);
262 }
263 }
264
265 match best_target {
266 Some(target) => {
267 debug!(
268 selected_target = %target.full_address(),
269 queue_time_secs = min_queue_time,
270 algorithm = "least_tokens_queued",
271 "Selected target with lowest queue time"
272 );
273 Ok(TargetSelection {
274 address: target.full_address(),
275 weight: target.weight,
276 metadata: HashMap::new(),
277 })
278 }
279 None => {
280 tracing::warn!(
281 total_targets = self.targets.len(),
282 algorithm = "least_tokens_queued",
283 "No healthy upstream targets available"
284 );
285 Err(GrapsusError::NoHealthyUpstream)
286 }
287 }
288 }
289
290 async fn report_health(&self, address: &str, healthy: bool) {
291 trace!(
292 target = %address,
293 healthy = healthy,
294 algorithm = "least_tokens_queued",
295 "Updating target health status"
296 );
297 self.health_status
298 .write()
299 .await
300 .insert(address.to_string(), healthy);
301 }
302
303 async fn healthy_targets(&self) -> Vec<String> {
304 self.health_status
305 .read()
306 .await
307 .iter()
308 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
309 .collect()
310 }
311
312 async fn report_result(
313 &self,
314 selection: &TargetSelection,
315 success: bool,
316 latency: Option<Duration>,
317 ) {
318 self.report_health(&selection.address, success).await;
320
321 }
324
325 async fn report_result_with_latency(
326 &self,
327 address: &str,
328 success: bool,
329 latency: Option<Duration>,
330 ) {
331 self.report_health(address, success).await;
332 }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Three identically weighted targets on port 8080.
    fn test_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("server1", 8080, 100),
            UpstreamTarget::new("server2", 8080, 100),
            UpstreamTarget::new("server3", 8080, 100),
        ]
    }

    #[tokio::test]
    async fn test_basic_selection() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // With nothing queued, any healthy target is a valid pick.
        let selection = balancer.select(None).await.unwrap();
        assert!(!selection.address.is_empty());
    }

    #[tokio::test]
    async fn test_selects_least_queued() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // server3 has nothing queued, so its estimated wait is shortest.
        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.enqueue_tokens("server2:8080", 500);
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server3:8080");
    }

    #[tokio::test]
    async fn test_dequeue_updates_tps() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.dequeue_tokens("server1:8080", 1000, Duration::from_secs(1));

        // assert_eq! (rather than assert!(a == b)) so a failure prints both sides.
        let stats = balancer.target_metrics("server1:8080").unwrap();
        assert_eq!(stats.total_tokens, 1000);
        assert_eq!(stats.total_requests, 1);
    }

    #[tokio::test]
    async fn test_unhealthy_target_skipped() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        balancer.report_health("server3:8080", false).await;

        // server2 is healthy and has the emptiest queue.
        balancer.enqueue_tokens("server1:8080", 1000);

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server2:8080");
    }
}
401}