// grapsus_proxy/upstream/least_tokens.rs
1//! Least Tokens Queued load balancer for inference workloads
2//!
3//! This load balancer selects upstreams based on the estimated number of tokens
4//! currently being processed, optimized for LLM/AI inference traffic where
5//! request processing time correlates strongly with token count.
6
7use async_trait::async_trait;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::RwLock;
13use tracing::{debug, trace};
14
15use grapsus_common::errors::{GrapsusError, GrapsusResult};
16
17use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
18
/// Configuration for the least tokens queued balancer.
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedConfig {
    /// Smoothing factor for the tokens-per-second EWMA (0.0-1.0).
    /// Higher values weight recent measurements more heavily.
    pub ewma_alpha: f64,
    /// Tokens-per-second estimate assigned to targets before any
    /// completed request has been measured.
    pub default_tps: f64,
    /// Lower bound applied to the TPS estimate so the queue-time
    /// division never uses a zero (or near-zero) denominator.
    pub min_tps: f64,
}
30
31impl Default for LeastTokensQueuedConfig {
32    fn default() -> Self {
33        Self {
34            ewma_alpha: 0.3,
35            default_tps: 100.0, // Conservative default
36            min_tps: 1.0,
37        }
38    }
39}
40
/// Per-target metrics for token-aware load balancing.
///
/// Counters are atomics so the enqueue/dequeue paths never block;
/// only the TPS estimate sits behind a short-lived mutex.
struct TargetMetrics {
    /// Estimated number of tokens currently queued on this target.
    queued_tokens: AtomicU64,
    /// Number of in-flight requests on this target.
    queued_requests: AtomicU64,
    /// Exponentially weighted moving average of observed tokens per second.
    tps_ewma: parking_lot::Mutex<f64>,
    /// Lifetime total of tokens processed (for debugging/metrics only).
    total_tokens: AtomicU64,
    /// Lifetime total of requests processed (for debugging/metrics only).
    total_requests: AtomicU64,
}
54
55impl TargetMetrics {
56    fn new(default_tps: f64) -> Self {
57        Self {
58            queued_tokens: AtomicU64::new(0),
59            queued_requests: AtomicU64::new(0),
60            tps_ewma: parking_lot::Mutex::new(default_tps),
61            total_tokens: AtomicU64::new(0),
62            total_requests: AtomicU64::new(0),
63        }
64    }
65
66    /// Get the estimated queue time: queued_tokens / tokens_per_second
67    fn estimated_queue_time(&self, min_tps: f64) -> f64 {
68        let queued = self.queued_tokens.load(Ordering::Relaxed) as f64;
69        let tps = (*self.tps_ewma.lock()).max(min_tps);
70        queued / tps
71    }
72
73    /// Add tokens to the queue (when request starts)
74    fn enqueue(&self, tokens: u64) {
75        self.queued_tokens.fetch_add(tokens, Ordering::AcqRel);
76        self.queued_requests.fetch_add(1, Ordering::AcqRel);
77    }
78
79    /// Remove tokens from queue and update TPS (when request completes)
80    fn dequeue(&self, tokens: u64, duration: Duration, ewma_alpha: f64) {
81        // Remove from queue
82        self.queued_tokens.fetch_saturating_sub(tokens);
83        self.queued_requests.fetch_saturating_sub(1);
84
85        // Update totals
86        self.total_tokens.fetch_add(tokens, Ordering::Relaxed);
87        self.total_requests.fetch_add(1, Ordering::Relaxed);
88
89        // Update TPS EWMA
90        if duration.as_secs_f64() > 0.0 {
91            let measured_tps = tokens as f64 / duration.as_secs_f64();
92            let mut tps = self.tps_ewma.lock();
93            *tps = ewma_alpha * measured_tps + (1.0 - ewma_alpha) * *tps;
94        }
95    }
96}
97
/// Extension trait for [`AtomicU64`] adding an atomic subtraction that
/// clamps at zero instead of wrapping.
trait AtomicSaturatingSub {
    fn fetch_saturating_sub(&self, val: u64);
}

impl AtomicSaturatingSub for AtomicU64 {
    /// Atomically subtract `val`, saturating at zero.
    fn fetch_saturating_sub(&self, val: u64) {
        // `fetch_update` encapsulates the load / compare-exchange retry loop
        // the original spelled out by hand, and uses the weak CAS internally,
        // which is the correct choice inside a retry loop. The closure always
        // returns `Some`, so the call can never yield `Err`.
        let _ = self.fetch_update(Ordering::AcqRel, Ordering::Relaxed, |current| {
            Some(current.saturating_sub(val))
        });
    }
}
117
/// Least Tokens Queued load balancer
///
/// Selects the upstream with the lowest estimated queue time,
/// calculated as: queued_tokens / tokens_per_second
pub struct LeastTokensQueuedBalancer {
    /// Candidate upstreams, in configuration order (also the tie-break order).
    targets: Vec<UpstreamTarget>,
    /// Per-target token metrics keyed by the target's full address.
    /// The map itself is immutable after construction; only the atomic
    /// counters and the EWMA inside each entry change.
    metrics: Arc<HashMap<String, TargetMetrics>>,
    /// Health flag per target address; `true` = eligible for selection.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Tunables: EWMA smoothing factor, default and minimum TPS.
    config: LeastTokensQueuedConfig,
}
128
129impl LeastTokensQueuedBalancer {
130    /// Create a new least tokens queued balancer
131    pub fn new(targets: Vec<UpstreamTarget>, config: LeastTokensQueuedConfig) -> Self {
132        let mut metrics = HashMap::new();
133        let mut health_status = HashMap::new();
134
135        for target in &targets {
136            let addr = target.full_address();
137            metrics.insert(addr.clone(), TargetMetrics::new(config.default_tps));
138            health_status.insert(addr, true);
139        }
140
141        Self {
142            targets,
143            metrics: Arc::new(metrics),
144            health_status: Arc::new(RwLock::new(health_status)),
145            config,
146        }
147    }
148
149    /// Enqueue tokens for a target (call when request starts)
150    pub fn enqueue_tokens(&self, address: &str, estimated_tokens: u64) {
151        if let Some(metrics) = self.metrics.get(address) {
152            metrics.enqueue(estimated_tokens);
153            trace!(
154                target = address,
155                tokens = estimated_tokens,
156                queued = metrics.queued_tokens.load(Ordering::Relaxed),
157                "Enqueued tokens for target"
158            );
159        }
160    }
161
162    /// Dequeue tokens for a target (call when request completes)
163    pub fn dequeue_tokens(&self, address: &str, actual_tokens: u64, duration: Duration) {
164        if let Some(metrics) = self.metrics.get(address) {
165            metrics.dequeue(actual_tokens, duration, self.config.ewma_alpha);
166            debug!(
167                target = address,
168                tokens = actual_tokens,
169                duration_ms = duration.as_millis() as u64,
170                queued = metrics.queued_tokens.load(Ordering::Relaxed),
171                tps = *metrics.tps_ewma.lock(),
172                "Dequeued tokens for target"
173            );
174        }
175    }
176
177    /// Get current metrics for a target (for debugging/observability)
178    pub fn target_metrics(&self, address: &str) -> Option<LeastTokensQueuedTargetStats> {
179        self.metrics
180            .get(address)
181            .map(|m| LeastTokensQueuedTargetStats {
182                queued_tokens: m.queued_tokens.load(Ordering::Relaxed),
183                queued_requests: m.queued_requests.load(Ordering::Relaxed),
184                tokens_per_second: *m.tps_ewma.lock(),
185                total_tokens: m.total_tokens.load(Ordering::Relaxed),
186                total_requests: m.total_requests.load(Ordering::Relaxed),
187            })
188    }
189
190    /// Get all targets' current queue times for debugging
191    pub async fn queue_times(&self) -> Vec<(String, f64)> {
192        let health = self.health_status.read().await;
193        self.targets
194            .iter()
195            .filter_map(|t| {
196                let addr = t.full_address();
197                if *health.get(&addr).unwrap_or(&true) {
198                    self.metrics
199                        .get(&addr)
200                        .map(|m| (addr, m.estimated_queue_time(self.config.min_tps)))
201                } else {
202                    None
203                }
204            })
205            .collect()
206    }
207}
208
/// Target statistics for observability.
///
/// A point-in-time snapshot of one target's [`TargetMetrics`] counters;
/// values are read individually and are not mutually consistent.
#[derive(Debug, Clone)]
pub struct LeastTokensQueuedTargetStats {
    /// Estimated tokens currently queued on the target.
    pub queued_tokens: u64,
    /// Requests currently in flight on the target.
    pub queued_requests: u64,
    /// Current tokens-per-second EWMA estimate.
    pub tokens_per_second: f64,
    /// Lifetime total of tokens processed.
    pub total_tokens: u64,
    /// Lifetime total of requests processed.
    pub total_requests: u64,
}
218
219#[async_trait]
220impl LoadBalancer for LeastTokensQueuedBalancer {
221    async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
222        trace!(
223            total_targets = self.targets.len(),
224            algorithm = "least_tokens_queued",
225            "Selecting upstream target"
226        );
227
228        let health = self.health_status.read().await;
229
230        let mut best_target = None;
231        let mut min_queue_time = f64::MAX;
232
233        for target in &self.targets {
234            let addr = target.full_address();
235
236            // Skip unhealthy targets
237            if !*health.get(&addr).unwrap_or(&true) {
238                trace!(
239                    target = %addr,
240                    algorithm = "least_tokens_queued",
241                    "Skipping unhealthy target"
242                );
243                continue;
244            }
245
246            // Calculate estimated queue time
247            let queue_time = self
248                .metrics
249                .get(&addr)
250                .map(|m| m.estimated_queue_time(self.config.min_tps))
251                .unwrap_or(0.0);
252
253            trace!(
254                target = %addr,
255                queue_time_secs = queue_time,
256                "Evaluating target queue time"
257            );
258
259            if queue_time < min_queue_time {
260                min_queue_time = queue_time;
261                best_target = Some(target);
262            }
263        }
264
265        match best_target {
266            Some(target) => {
267                debug!(
268                    selected_target = %target.full_address(),
269                    queue_time_secs = min_queue_time,
270                    algorithm = "least_tokens_queued",
271                    "Selected target with lowest queue time"
272                );
273                Ok(TargetSelection {
274                    address: target.full_address(),
275                    weight: target.weight,
276                    metadata: HashMap::new(),
277                })
278            }
279            None => {
280                tracing::warn!(
281                    total_targets = self.targets.len(),
282                    algorithm = "least_tokens_queued",
283                    "No healthy upstream targets available"
284                );
285                Err(GrapsusError::NoHealthyUpstream)
286            }
287        }
288    }
289
290    async fn report_health(&self, address: &str, healthy: bool) {
291        trace!(
292            target = %address,
293            healthy = healthy,
294            algorithm = "least_tokens_queued",
295            "Updating target health status"
296        );
297        self.health_status
298            .write()
299            .await
300            .insert(address.to_string(), healthy);
301    }
302
303    async fn healthy_targets(&self) -> Vec<String> {
304        self.health_status
305            .read()
306            .await
307            .iter()
308            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
309            .collect()
310    }
311
312    async fn report_result(
313        &self,
314        selection: &TargetSelection,
315        success: bool,
316        latency: Option<Duration>,
317    ) {
318        // Update health based on success
319        self.report_health(&selection.address, success).await;
320
321        // Note: Token dequeuing should be done explicitly via dequeue_tokens()
322        // when the actual token count is known from the response
323    }
324
325    async fn report_result_with_latency(
326        &self,
327        address: &str,
328        success: bool,
329        latency: Option<Duration>,
330    ) {
331        self.report_health(address, success).await;
332    }
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Three equal-weight targets; selection tie-breaks follow this order.
    fn test_targets() -> Vec<UpstreamTarget> {
        vec![
            UpstreamTarget::new("server1", 8080, 100),
            UpstreamTarget::new("server2", 8080, 100),
            UpstreamTarget::new("server3", 8080, 100),
        ]
    }

    /// With no queued tokens anywhere, selection must still succeed.
    #[tokio::test]
    async fn test_basic_selection() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // All targets start with 0 queued tokens, so selection should work
        let selection = balancer.select(None).await.unwrap();
        assert!(!selection.address.is_empty());
    }

    /// The target with the smallest token backlog must win.
    #[tokio::test]
    async fn test_selects_least_queued() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // Add tokens to server1 and server2
        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.enqueue_tokens("server2:8080", 500);
        // server3 has 0 tokens

        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server3:8080");
    }

    /// Dequeuing must drain the queue and update the lifetime totals.
    #[tokio::test]
    async fn test_dequeue_updates_tps() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // Enqueue and then dequeue with timing
        balancer.enqueue_tokens("server1:8080", 1000);
        balancer.dequeue_tokens("server1:8080", 1000, Duration::from_secs(1));

        // Check that TPS was updated
        let stats = balancer.target_metrics("server1:8080").unwrap();
        assert!(stats.total_tokens == 1000);
        assert!(stats.total_requests == 1);
    }

    /// Unhealthy targets must be excluded even when least loaded.
    #[tokio::test]
    async fn test_unhealthy_target_skipped() {
        let balancer =
            LeastTokensQueuedBalancer::new(test_targets(), LeastTokensQueuedConfig::default());

        // Mark server3 as unhealthy
        balancer.report_health("server3:8080", false).await;

        // Add tokens to server1
        balancer.enqueue_tokens("server1:8080", 1000);

        // Should select server2 (healthy and lowest queue)
        let selection = balancer.select(None).await.unwrap();
        assert_eq!(selection.address, "server2:8080");
    }
}