sentinel_proxy/upstream/adaptive.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use tokio::sync::RwLock;
7use tracing::{debug, info, trace, warn};
8
9use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
10use sentinel_common::errors::{SentinelError, SentinelResult};
11
/// Configuration for adaptive load balancing.
///
/// Once per `adjustment_interval`, every target that received at least `min_requests`
/// requests in the interval has its effective weight multiplied by a factor derived from its
/// smoothed error rate and smoothed average latency. The result is then clamped to
/// `[original * min_weight_ratio, original * max_weight_ratio]`.
#[derive(Debug, Clone)]
pub struct AdaptiveConfig {
    /// How often effective weights are recalculated.
    pub adjustment_interval: Duration,
    /// Lower bound for the effective weight, as a fraction of the original weight
    /// (e.g. `0.1` = 10% of the original).
    pub min_weight_ratio: f64,
    /// Upper bound for the effective weight, as a fraction of the original weight
    /// (e.g. `2.0` = 200% of the original).
    pub max_weight_ratio: f64,
    /// Smoothed error rate (failed / total, `0.0..=1.0`) above which a target is penalized.
    pub error_threshold: f64,
    /// Latency above which a target is penalized.
    ///
    /// This is compared against the EWMA of each interval's *average* latency of successful
    /// requests; it is not a percentile.
    pub latency_threshold: Duration,
    /// Weight given to the newest interval's sample in the EWMAs (`0.0..=1.0`):
    /// `ewma = decay * sample + (1 - decay) * ewma`, so higher values react faster.
    pub ewma_decay: f64,
    /// Multiplier applied to the effective weight in an interval without any penalty
    /// (e.g. `1.1` = +10% per interval).
    pub recovery_rate: f64,
    /// Base multiplier applied when a target is penalized.
    ///
    /// The weight is multiplied by `penalty_rate + (1 - penalty_rate) * adjustment`, where
    /// `adjustment < 1.0` is the combined error/latency factor. A severely degraded target
    /// therefore loses close to `1 - penalty_rate` of its weight per interval.
    pub penalty_rate: f64,
    /// Enable the per-target circuit breaker.
    ///
    /// The circuit opens after 5 consecutive failures and closes after 5 consecutive
    /// successes.
    pub circuit_breaker: bool,
    /// Minimum requests a target must receive within the current interval before its weight
    /// is adjusted. Targets below it keep their weight and keep accumulating counters.
    pub min_requests: u64,
}

impl Default for AdaptiveConfig {
    fn default() -> Self {
        Self {
            adjustment_interval: Duration::from_secs(10),
            min_weight_ratio: 0.1, // Can go down to 10% of original weight
            max_weight_ratio: 2.0, // Can go up to 200% of original weight
            error_threshold: 0.05, // 5% error rate triggers penalty
            latency_threshold: Duration::from_millis(500),
            ewma_decay: 0.8,    // Newest interval weighted at 80%
            recovery_rate: 1.1, // 10% recovery per interval
            penalty_rate: 0.7,  // Up to 30% penalty per interval
            circuit_breaker: true,
            min_requests: 100,
        }
    }
}
53
54/// Performance metrics for a target with EWMA smoothing
55#[derive(Debug, Clone)]
56struct PerformanceMetrics {
57    /// Total requests
58    total_requests: Arc<AtomicU64>,
59    /// Failed requests
60    failed_requests: Arc<AtomicU64>,
61    /// Sum of latencies in microseconds
62    total_latency_us: Arc<AtomicU64>,
63    /// Success count for latency calculation
64    success_count: Arc<AtomicU64>,
65    /// Active connections
66    active_connections: Arc<AtomicU64>,
67    /// Current effective weight
68    effective_weight: Arc<RwLock<f64>>,
69    /// EWMA error rate
70    ewma_error_rate: Arc<RwLock<f64>>,
71    /// EWMA latency in microseconds
72    ewma_latency: Arc<RwLock<f64>>,
73    /// Last adjustment time
74    last_adjustment: Arc<RwLock<Instant>>,
75    /// Consecutive successes
76    consecutive_successes: Arc<AtomicU64>,
77    /// Consecutive failures
78    consecutive_failures: Arc<AtomicU64>,
79    /// Circuit breaker state
80    circuit_open: Arc<RwLock<bool>>,
81    /// Last error time
82    last_error: Arc<RwLock<Option<Instant>>>,
83}
84
85impl PerformanceMetrics {
86    fn new(initial_weight: f64) -> Self {
87        Self {
88            total_requests: Arc::new(AtomicU64::new(0)),
89            failed_requests: Arc::new(AtomicU64::new(0)),
90            total_latency_us: Arc::new(AtomicU64::new(0)),
91            success_count: Arc::new(AtomicU64::new(0)),
92            active_connections: Arc::new(AtomicU64::new(0)),
93            effective_weight: Arc::new(RwLock::new(initial_weight)),
94            ewma_error_rate: Arc::new(RwLock::new(0.0)),
95            ewma_latency: Arc::new(RwLock::new(0.0)),
96            last_adjustment: Arc::new(RwLock::new(Instant::now())),
97            consecutive_successes: Arc::new(AtomicU64::new(0)),
98            consecutive_failures: Arc::new(AtomicU64::new(0)),
99            circuit_open: Arc::new(RwLock::new(false)),
100            last_error: Arc::new(RwLock::new(None)),
101        }
102    }
103
104    /// Update EWMA values with new sample
105    async fn update_ewma(&self, error_rate: f64, latency_us: f64, decay: f64) {
106        let mut ewma_error = self.ewma_error_rate.write().await;
107        *ewma_error = decay * error_rate + (1.0 - decay) * (*ewma_error);
108
109        let mut ewma_lat = self.ewma_latency.write().await;
110        *ewma_lat = decay * latency_us + (1.0 - decay) * (*ewma_lat);
111    }
112
113    /// Record a request result
114    async fn record_result(
115        &self,
116        success: bool,
117        latency: Option<Duration>,
118        config: &AdaptiveConfig,
119    ) {
120        self.total_requests.fetch_add(1, Ordering::Relaxed);
121
122        if success {
123            self.consecutive_successes.fetch_add(1, Ordering::Relaxed);
124            self.consecutive_failures.store(0, Ordering::Relaxed);
125
126            if let Some(lat) = latency {
127                let lat_us = lat.as_micros() as u64;
128                self.total_latency_us.fetch_add(lat_us, Ordering::Relaxed);
129                self.success_count.fetch_add(1, Ordering::Relaxed);
130            }
131
132            // Check for circuit breaker recovery
133            if config.circuit_breaker {
134                let successes = self.consecutive_successes.load(Ordering::Relaxed);
135                if successes >= 5 && *self.circuit_open.read().await {
136                    *self.circuit_open.write().await = false;
137                    info!(
138                        "Circuit breaker closed after {} consecutive successes",
139                        successes
140                    );
141                }
142            }
143        } else {
144            self.failed_requests.fetch_add(1, Ordering::Relaxed);
145            self.consecutive_failures.fetch_add(1, Ordering::Relaxed);
146            self.consecutive_successes.store(0, Ordering::Relaxed);
147            *self.last_error.write().await = Some(Instant::now());
148
149            // Check for circuit breaker trip
150            if config.circuit_breaker {
151                let failures = self.consecutive_failures.load(Ordering::Relaxed);
152                if failures >= 5 && !*self.circuit_open.read().await {
153                    *self.circuit_open.write().await = true;
154                    warn!(
155                        "Circuit breaker opened after {} consecutive failures",
156                        failures
157                    );
158                }
159            }
160        }
161    }
162
163    /// Calculate current error rate
164    fn current_error_rate(&self) -> f64 {
165        let total = self.total_requests.load(Ordering::Relaxed);
166        if total == 0 {
167            return 0.0;
168        }
169        let failed = self.failed_requests.load(Ordering::Relaxed);
170        failed as f64 / total as f64
171    }
172
173    /// Calculate average latency
174    fn average_latency(&self) -> Duration {
175        let count = self.success_count.load(Ordering::Relaxed);
176        if count == 0 {
177            return Duration::ZERO;
178        }
179        let total_us = self.total_latency_us.load(Ordering::Relaxed);
180        Duration::from_micros(total_us / count)
181    }
182
183    /// Reset interval metrics
184    fn reset_interval_metrics(&self) {
185        self.total_requests.store(0, Ordering::Relaxed);
186        self.failed_requests.store(0, Ordering::Relaxed);
187        self.total_latency_us.store(0, Ordering::Relaxed);
188        self.success_count.store(0, Ordering::Relaxed);
189    }
190}
191
/// Selection score for one candidate target, produced while ranking targets for a request.
#[derive(Clone, Debug)]
struct TargetScore {
    /// Position of the target in the balancer's target list
    index: usize,
    /// Relative selection score; higher values are picked more often
    score: f64,
    /// Effective weight the score was derived from
    weight: f64,
}
199
200/// Adaptive load balancer that adjusts weights based on performance
201pub struct AdaptiveBalancer {
202    /// Configuration
203    config: AdaptiveConfig,
204    /// All upstream targets
205    targets: Vec<UpstreamTarget>,
206    /// Original weights (for ratio calculation)
207    original_weights: Vec<f64>,
208    /// Performance metrics per target
209    metrics: Vec<PerformanceMetrics>,
210    /// Target health status
211    health_status: Arc<RwLock<HashMap<String, bool>>>,
212    /// Last global adjustment time
213    last_global_adjustment: Arc<RwLock<Instant>>,
214}
215
216impl AdaptiveBalancer {
217    pub fn new(targets: Vec<UpstreamTarget>, config: AdaptiveConfig) -> Self {
218        trace!(
219            target_count = targets.len(),
220            adjustment_interval_secs = config.adjustment_interval.as_secs(),
221            min_weight_ratio = config.min_weight_ratio,
222            max_weight_ratio = config.max_weight_ratio,
223            error_threshold = config.error_threshold,
224            latency_threshold_ms = config.latency_threshold.as_millis() as u64,
225            ewma_decay = config.ewma_decay,
226            circuit_breaker = config.circuit_breaker,
227            min_requests = config.min_requests,
228            "Creating adaptive balancer"
229        );
230
231        let original_weights: Vec<f64> = targets.iter().map(|t| t.weight as f64).collect();
232        let metrics = original_weights
233            .iter()
234            .map(|&w| PerformanceMetrics::new(w))
235            .collect();
236
237        debug!(
238            target_count = targets.len(),
239            total_weight = original_weights.iter().sum::<f64>(),
240            "Adaptive balancer initialized"
241        );
242
243        Self {
244            config,
245            targets,
246            original_weights,
247            metrics,
248            health_status: Arc::new(RwLock::new(HashMap::new())),
249            last_global_adjustment: Arc::new(RwLock::new(Instant::now())),
250        }
251    }
252
253    /// Adjust weights based on recent performance
254    async fn adjust_weights(&self) {
255        let mut last_adjustment = self.last_global_adjustment.write().await;
256
257        let elapsed = last_adjustment.elapsed();
258        if elapsed < self.config.adjustment_interval {
259            trace!(
260                elapsed_secs = elapsed.as_secs(),
261                interval_secs = self.config.adjustment_interval.as_secs(),
262                "Skipping weight adjustment (interval not reached)"
263            );
264            return;
265        }
266
267        debug!(
268            elapsed_secs = elapsed.as_secs(),
269            target_count = self.targets.len(),
270            "Adjusting weights based on performance metrics"
271        );
272
273        for (i, metric) in self.metrics.iter().enumerate() {
274            let requests = metric.total_requests.load(Ordering::Relaxed);
275
276            // Skip if insufficient data
277            if requests < self.config.min_requests {
278                continue;
279            }
280
281            // Calculate current metrics
282            let error_rate = metric.current_error_rate();
283            let avg_latency = metric.average_latency();
284            let latency_us = avg_latency.as_micros() as f64;
285
286            // Update EWMA
287            metric
288                .update_ewma(error_rate, latency_us, self.config.ewma_decay)
289                .await;
290
291            // Get smoothed metrics
292            let ewma_error = *metric.ewma_error_rate.read().await;
293            let ewma_latency_us = *metric.ewma_latency.read().await;
294            let ewma_latency = Duration::from_micros(ewma_latency_us as u64);
295
296            // Calculate weight adjustment factor
297            let mut adjustment = 1.0;
298
299            // Penalize high error rates
300            if ewma_error > self.config.error_threshold {
301                let error_factor =
302                    1.0 - ((ewma_error - self.config.error_threshold) * 10.0).min(0.9);
303                adjustment *= error_factor;
304                debug!(
305                    "Target {} error rate {:.2}% exceeds threshold, factor: {:.2}",
306                    i,
307                    ewma_error * 100.0,
308                    error_factor
309                );
310            }
311
312            // Penalize high latencies
313            if ewma_latency > self.config.latency_threshold {
314                let latency_ratio =
315                    self.config.latency_threshold.as_micros() as f64 / ewma_latency_us;
316                adjustment *= latency_ratio.max(0.1);
317                debug!(
318                    "Target {} latency {:?} exceeds threshold, factor: {:.2}",
319                    i, ewma_latency, latency_ratio
320                );
321            }
322
323            // Apply adjustment with damping
324            let mut current_weight = *metric.effective_weight.read().await;
325            let original = self.original_weights[i];
326
327            if adjustment < 1.0 {
328                // Degrade weight
329                current_weight *=
330                    self.config.penalty_rate + (1.0 - self.config.penalty_rate) * adjustment;
331            } else {
332                // Recover weight
333                current_weight *= self.config.recovery_rate;
334            }
335
336            // Apply bounds
337            let min_weight = original * self.config.min_weight_ratio;
338            let max_weight = original * self.config.max_weight_ratio;
339            current_weight = current_weight.max(min_weight).min(max_weight);
340
341            *metric.effective_weight.write().await = current_weight;
342
343            info!(
344                "Adjusted weight for target {}: {:.2} (original: {:.2}, error: {:.2}%, latency: {:.0}ms)",
345                i,
346                current_weight,
347                original,
348                ewma_error * 100.0,
349                ewma_latency.as_millis()
350            );
351
352            // Reset interval metrics
353            metric.reset_interval_metrics();
354        }
355
356        *last_adjustment = Instant::now();
357    }
358
359    /// Calculate scores for all healthy targets
360    async fn calculate_scores(&self) -> Vec<TargetScore> {
361        trace!(
362            target_count = self.targets.len(),
363            "Calculating scores for all targets"
364        );
365
366        let health = self.health_status.read().await;
367        let mut scores = Vec::new();
368
369        for (i, target) in self.targets.iter().enumerate() {
370            let target_id = format!("{}:{}", target.address, target.port);
371            let is_healthy = health.get(&target_id).copied().unwrap_or(true);
372            let circuit_open = *self.metrics[i].circuit_open.read().await;
373
374            // Skip unhealthy or circuit-broken targets
375            if !is_healthy || circuit_open {
376                trace!(
377                    target_index = i,
378                    target_id = %target_id,
379                    is_healthy = is_healthy,
380                    circuit_open = circuit_open,
381                    "Skipping target from scoring"
382                );
383                continue;
384            }
385
386            let weight = *self.metrics[i].effective_weight.read().await;
387            let connections = self.metrics[i].active_connections.load(Ordering::Relaxed) as f64;
388            let ewma_error = *self.metrics[i].ewma_error_rate.read().await;
389            let ewma_latency = *self.metrics[i].ewma_latency.read().await / 1000.0; // Convert to ms
390
391            // Score formula: weight / (1 + connections + error_penalty + latency_penalty)
392            let error_penalty = ewma_error * 100.0; // Scale error rate
393            let latency_penalty = (ewma_latency / 10.0).max(0.0); // Normalize latency
394            let score = weight / (1.0 + connections + error_penalty + latency_penalty);
395
396            trace!(
397                target_index = i,
398                target_id = %target_id,
399                weight = weight,
400                connections = connections,
401                ewma_error = ewma_error,
402                ewma_latency_ms = ewma_latency,
403                error_penalty = error_penalty,
404                latency_penalty = latency_penalty,
405                score = score,
406                "Calculated target score"
407            );
408
409            scores.push(TargetScore {
410                index: i,
411                score,
412                weight,
413            });
414        }
415
416        // Sort by score (highest first)
417        scores.sort_by(|a, b| {
418            b.score
419                .partial_cmp(&a.score)
420                .unwrap_or(std::cmp::Ordering::Equal)
421        });
422
423        trace!(
424            scored_count = scores.len(),
425            top_score = scores.first().map(|s| s.score).unwrap_or(0.0),
426            "Scores calculated and sorted"
427        );
428
429        scores
430    }
431
432    /// Select target using weighted random selection based on scores
433    async fn weighted_select(&self, scores: &[TargetScore]) -> Option<usize> {
434        if scores.is_empty() {
435            trace!("No scores provided for weighted selection");
436            return None;
437        }
438
439        // Calculate total score
440        let total_score: f64 = scores.iter().map(|s| s.score).sum();
441        if total_score <= 0.0 {
442            trace!(
443                fallback_index = scores[0].index,
444                "Total score is zero, using fallback"
445            );
446            return Some(scores[0].index); // Fallback to first
447        }
448
449        // Weighted random selection
450        use rand::prelude::*;
451        let mut rng = thread_rng();
452        let threshold = rng.gen::<f64>() * total_score;
453
454        trace!(
455            total_score = total_score,
456            threshold = threshold,
457            candidate_count = scores.len(),
458            "Performing weighted random selection"
459        );
460
461        let mut cumulative = 0.0;
462        for score in scores {
463            cumulative += score.score;
464            if cumulative >= threshold {
465                trace!(
466                    selected_index = score.index,
467                    selected_score = score.score,
468                    cumulative = cumulative,
469                    "Selected target via weighted random"
470                );
471                return Some(score.index);
472            }
473        }
474
475        // Fallback for floating point edge case - scores is guaranteed non-empty here
476        let fallback = scores.last().map(|s| s.index);
477        trace!(
478            fallback_index = ?fallback,
479            "Using fallback selection (floating point edge case)"
480        );
481        fallback
482    }
483}
484
485#[async_trait]
486impl LoadBalancer for AdaptiveBalancer {
487    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
488        trace!("Adaptive select started");
489
490        // Periodically adjust weights
491        self.adjust_weights().await;
492
493        // Calculate scores for all targets
494        let scores = self.calculate_scores().await;
495
496        if scores.is_empty() {
497            warn!("Adaptive: No healthy targets available");
498            return Err(SentinelError::NoHealthyUpstream);
499        }
500
501        // Select target based on scores
502        let target_index = self
503            .weighted_select(&scores)
504            .await
505            .ok_or_else(|| {
506                warn!("Adaptive: Failed to select from scores");
507                SentinelError::NoHealthyUpstream
508            })?;
509
510        let target = &self.targets[target_index];
511        let metrics = &self.metrics[target_index];
512
513        // Track connection
514        let connections = metrics.active_connections.fetch_add(1, Ordering::Relaxed) + 1;
515
516        let effective_weight = *metrics.effective_weight.read().await;
517        let ewma_error = *metrics.ewma_error_rate.read().await;
518        let ewma_latency = Duration::from_micros(*metrics.ewma_latency.read().await as u64);
519
520        let score = scores
521            .iter()
522            .find(|s| s.index == target_index)
523            .map(|s| s.score)
524            .unwrap_or(0.0);
525
526        debug!(
527            target = %format!("{}:{}", target.address, target.port),
528            target_index = target_index,
529            score = score,
530            effective_weight = effective_weight,
531            original_weight = self.original_weights[target_index],
532            error_rate = ewma_error,
533            latency_ms = ewma_latency.as_millis() as u64,
534            connections = connections,
535            "Adaptive selected target"
536        );
537
538        Ok(TargetSelection {
539            address: format!("{}:{}", target.address, target.port),
540            weight: target.weight,
541            metadata: {
542                let mut meta = HashMap::new();
543                meta.insert("algorithm".to_string(), "adaptive".to_string());
544                meta.insert("target_index".to_string(), target_index.to_string());
545                meta.insert(
546                    "effective_weight".to_string(),
547                    format!("{:.2}", effective_weight),
548                );
549                meta.insert(
550                    "original_weight".to_string(),
551                    self.original_weights[target_index].to_string(),
552                );
553                meta.insert("error_rate".to_string(), format!("{:.4}", ewma_error));
554                meta.insert(
555                    "latency_ms".to_string(),
556                    format!("{:.2}", ewma_latency.as_millis()),
557                );
558                meta.insert("connections".to_string(), connections.to_string());
559                meta
560            },
561        })
562    }
563
564    async fn report_health(&self, address: &str, healthy: bool) {
565        trace!(
566            address = %address,
567            healthy = healthy,
568            "Adaptive reporting target health"
569        );
570
571        let mut health = self.health_status.write().await;
572        let previous = health.insert(address.to_string(), healthy);
573
574        if previous != Some(healthy) {
575            info!(
576                address = %address,
577                previous = ?previous,
578                healthy = healthy,
579                "Adaptive target health changed"
580            );
581
582            // Find target index and reset its weight on health change
583            for (i, target) in self.targets.iter().enumerate() {
584                let target_id = format!("{}:{}", target.address, target.port);
585                if target_id == address {
586                    if healthy {
587                        // Reset to original weight on recovery
588                        let original = self.original_weights[i];
589                        *self.metrics[i].effective_weight.write().await = original;
590                        *self.metrics[i].circuit_open.write().await = false;
591                        self.metrics[i]
592                            .consecutive_failures
593                            .store(0, Ordering::Relaxed);
594                        info!(
595                            target_index = i,
596                            original_weight = original,
597                            "Reset target to original weight on recovery"
598                        );
599                    }
600                    break;
601                }
602            }
603        }
604    }
605
606    async fn healthy_targets(&self) -> Vec<String> {
607        let health = self.health_status.read().await;
608        let mut targets = Vec::new();
609
610        for (i, target) in self.targets.iter().enumerate() {
611            let target_id = format!("{}:{}", target.address, target.port);
612            let is_healthy = health.get(&target_id).copied().unwrap_or(true);
613            let circuit_open = *self.metrics[i].circuit_open.read().await;
614
615            if is_healthy && !circuit_open {
616                targets.push(target_id);
617            }
618        }
619
620        trace!(
621            total = self.targets.len(),
622            healthy = targets.len(),
623            "Adaptive healthy targets"
624        );
625
626        targets
627    }
628
629    async fn release(&self, selection: &TargetSelection) {
630        if let Some(index_str) = selection.metadata.get("target_index") {
631            if let Ok(index) = index_str.parse::<usize>() {
632                let connections = self.metrics[index]
633                    .active_connections
634                    .fetch_sub(1, Ordering::Relaxed) - 1;
635                trace!(
636                    target_index = index,
637                    address = %selection.address,
638                    connections = connections,
639                    "Adaptive released connection"
640                );
641            }
642        }
643    }
644
645    async fn report_result(
646        &self,
647        selection: &TargetSelection,
648        success: bool,
649        latency: Option<Duration>,
650    ) {
651        if let Some(index_str) = selection.metadata.get("target_index") {
652            if let Ok(index) = index_str.parse::<usize>() {
653                trace!(
654                    target_index = index,
655                    address = %selection.address,
656                    success = success,
657                    latency_ms = latency.map(|l| l.as_millis() as u64),
658                    "Adaptive recording result"
659                );
660                self.metrics[index]
661                    .record_result(success, latency, &self.config)
662                    .await;
663            }
664        }
665    }
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` weight-100 targets at `10.0.0.1:8080`, `10.0.0.2:8080`, ...
    fn create_test_targets(count: usize) -> Vec<UpstreamTarget> {
        let mut targets = Vec::with_capacity(count);
        for host in 1..=count {
            targets.push(UpstreamTarget {
                address: format!("10.0.0.{}", host),
                port: 8080,
                weight: 100,
            });
        }
        targets
    }

    /// Config that permits an adjustment after 10ms and a single request.
    fn fast_config() -> AdaptiveConfig {
        AdaptiveConfig {
            adjustment_interval: Duration::from_millis(10),
            min_requests: 1,
            ..Default::default()
        }
    }

    /// Records `times` identical results directly on target `index`'s metrics.
    async fn record_many(
        balancer: &AdaptiveBalancer,
        index: usize,
        times: usize,
        success: bool,
        latency: Option<Duration>,
    ) {
        let metrics = &balancer.metrics[index];
        for _ in 0..times {
            metrics.record_result(success, latency, &balancer.config).await;
        }
    }

    /// Current effective weight of target `index`.
    async fn weight_of(balancer: &AdaptiveBalancer, index: usize) -> f64 {
        *balancer.metrics[index].effective_weight.read().await
    }

    /// Waits past the 10ms test interval, then runs one adjustment pass.
    async fn adjust_after_interval(balancer: &AdaptiveBalancer) {
        tokio::time::sleep(Duration::from_millis(15)).await;
        balancer.adjust_weights().await;
    }

    #[tokio::test]
    async fn test_weight_degradation() {
        let balancer = AdaptiveBalancer::new(create_test_targets(3), fast_config());

        // Target 0 only fails, target 1 only succeeds quickly, target 2 sees no traffic.
        record_many(&balancer, 0, 10, false, None).await;
        record_many(&balancer, 1, 10, true, Some(Duration::from_millis(10))).await;

        adjust_after_interval(&balancer).await;

        assert!(
            weight_of(&balancer, 0).await < 100.0,
            "Target 0 weight should be degraded"
        );
        assert!(
            weight_of(&balancer, 1).await >= 100.0,
            "Target 1 weight should not be degraded"
        );
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let balancer = AdaptiveBalancer::new(create_test_targets(2), AdaptiveConfig::default());

        // Enough consecutive failures trip the breaker...
        record_many(&balancer, 0, 5, false, None).await;
        assert!(*balancer.metrics[0].circuit_open.read().await);

        // ...and a tripped target is not offered for selection.
        let scores = balancer.calculate_scores().await;
        assert!(scores.iter().all(|s| s.index != 0));

        // Enough consecutive successes close it again.
        record_many(&balancer, 0, 5, true, Some(Duration::from_millis(10))).await;
        assert!(!*balancer.metrics[0].circuit_open.read().await);
    }

    #[tokio::test]
    async fn test_latency_penalty() {
        let config = AdaptiveConfig {
            latency_threshold: Duration::from_millis(100),
            ..fast_config()
        };
        let balancer = AdaptiveBalancer::new(create_test_targets(2), config);

        // Target 0 is well above the latency threshold, target 1 well below it.
        record_many(&balancer, 0, 10, true, Some(Duration::from_millis(500))).await;
        record_many(&balancer, 1, 10, true, Some(Duration::from_millis(50))).await;

        adjust_after_interval(&balancer).await;

        let slow = weight_of(&balancer, 0).await;
        let fast = weight_of(&balancer, 1).await;
        assert!(slow < fast, "High latency target should have lower weight");
    }
}