use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;

use cheetah_string::CheetahString;
use parking_lot::RwLock;
/// Shared registry of per-address latency and error statistics.
///
/// The map lives behind an `Arc<RwLock<..>>`, so every clone of a tracker
/// refers to the same underlying data: a sample recorded through one handle
/// is observed through all of them.
#[derive(Clone)]
pub struct LatencyTracker {
    /// Statistics keyed by address, guarded by a read/write lock.
    metrics: Arc<RwLock<HashMap<CheetahString, LatencyMetrics>>>,
}
/// Rolling latency and error statistics for a single address.
#[derive(Debug, Clone)]
struct LatencyMetrics {
    /// Most recent successful latencies in microseconds, oldest at the front.
    recent_latencies: VecDeque<u64>,
    /// Maximum number of samples kept in `recent_latencies` (0 keeps none).
    max_samples: usize,
    /// Median of the retained samples (upper median when the count is even).
    p50: Duration,
    /// 99th percentile of the retained samples (nearest-rank method).
    p99: Duration,
    /// Errors recorded since the last success; reset by `record_success`.
    consecutive_errors: u32,
    /// Successes plus errors recorded so far (saturating).
    total_requests: u64,
    /// Time of the most recent `record_success` / `record_error` call.
    last_update: Instant,
}

impl LatencyMetrics {
    /// Consecutive errors at which an address stops being considered healthy.
    const UNHEALTHY_ERROR_THRESHOLD: u32 = 3;
    /// Score penalty added for each consecutive error.
    const ERROR_PENALTY: u64 = 100;

    /// Creates empty metrics that retain at most `max_samples` latencies.
    fn new(max_samples: usize) -> Self {
        Self {
            recent_latencies: VecDeque::with_capacity(max_samples),
            max_samples,
            p50: Duration::ZERO,
            p99: Duration::ZERO,
            consecutive_errors: 0,
            total_requests: 0,
            last_update: Instant::now(),
        }
    }

    /// Records a successful request that took `latency`.
    ///
    /// The sample enters a sliding window (oldest evicted first), percentiles
    /// are recomputed, and the consecutive-error streak is reset.
    fn record_success(&mut self, latency: Duration) {
        // `as_micros` yields u128; saturate rather than silently truncate.
        let micros = u64::try_from(latency.as_micros()).unwrap_or(u64::MAX);
        if self.max_samples > 0 {
            // Make room first so the window never exceeds `max_samples`.
            while self.recent_latencies.len() >= self.max_samples {
                self.recent_latencies.pop_front();
            }
            self.recent_latencies.push_back(micros);
        }
        self.update_percentiles();
        self.consecutive_errors = 0;
        self.total_requests = self.total_requests.saturating_add(1);
        self.last_update = Instant::now();
    }

    /// Records a failed request, extending the consecutive-error streak.
    fn record_error(&mut self) {
        // Saturate: a plain `+= 1` would panic on overflow in debug builds.
        self.consecutive_errors = self.consecutive_errors.saturating_add(1);
        self.total_requests = self.total_requests.saturating_add(1);
        self.last_update = Instant::now();
    }

    /// Recomputes `p50` and `p99` from the retained samples.
    ///
    /// Leaves both untouched when no samples are retained.
    fn update_percentiles(&mut self) {
        if self.recent_latencies.is_empty() {
            return;
        }
        let mut sorted: Vec<u64> = self.recent_latencies.iter().copied().collect();
        sorted.sort_unstable();
        let len = sorted.len();
        self.p50 = Duration::from_micros(sorted[len / 2]);
        // Nearest-rank p99: rank = ceil(0.99 * len) = len - floor(len / 100),
        // computed exactly in integers; rank is in 1..=len for len >= 1.
        let p99_rank = len - len / 100;
        self.p99 = Duration::from_micros(sorted[p99_rank - 1]);
    }

    /// Lower is better: p99 latency in milliseconds plus a per-error penalty.
    fn score(&self) -> u64 {
        let latency_penalty = u64::try_from(self.p99.as_millis()).unwrap_or(u64::MAX);
        let error_penalty = u64::from(self.consecutive_errors) * Self::ERROR_PENALTY;
        latency_penalty.saturating_add(error_penalty)
    }

    /// True while the consecutive-error streak is below the unhealthy threshold.
    fn is_healthy(&self) -> bool {
        self.consecutive_errors < Self::UNHEALTHY_ERROR_THRESHOLD
    }
}
impl LatencyTracker {
pub fn new() -> Self {
Self {
metrics: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn record_success(&self, addr: &CheetahString, latency: Duration) {
let mut metrics = self.metrics.write();
metrics
.entry(addr.clone())
.or_insert_with(|| LatencyMetrics::new(100))
.record_success(latency);
}
pub fn record_error(&self, addr: &CheetahString) {
let mut metrics = self.metrics.write();
metrics
.entry(addr.clone())
.or_insert_with(|| LatencyMetrics::new(100))
.record_error();
}
pub fn select_best<'a>(&self, candidates: &'a [CheetahString]) -> Option<&'a CheetahString> {
if candidates.is_empty() {
return None;
}
let metrics = self.metrics.read();
let mut best: Option<(&CheetahString, u64)> = None;
for addr in candidates {
if let Some(m) = metrics.get(addr) {
if !m.is_healthy() {
continue; }
let score = m.score();
match best {
None => best = Some((addr, score)),
Some((_, best_score)) if score < best_score => {
best = Some((addr, score));
}
_ => {}
}
}
}
best.map(|(addr, _)| addr).or_else(|| candidates.first())
}
pub fn get_p99(&self, addr: &CheetahString) -> Option<Duration> {
let metrics = self.metrics.read();
metrics.get(addr).map(|m| m.p99)
}
pub fn get_error_count(&self, addr: &CheetahString) -> u32 {
let metrics = self.metrics.read();
metrics.get(addr).map(|m| m.consecutive_errors).unwrap_or(0)
}
pub fn is_healthy(&self, addr: &CheetahString) -> bool {
let metrics = self.metrics.read();
metrics.get(addr).map(|m| m.is_healthy()).unwrap_or(true) }
pub fn clear(&self, addr: &CheetahString) {
let mut metrics = self.metrics.write();
metrics.remove(addr);
}
pub fn snapshot(&self) -> HashMap<CheetahString, (Duration, Duration, u32)> {
let metrics = self.metrics.read();
metrics
.iter()
.map(|(addr, m)| (addr.clone(), (m.p50, m.p99, m.consecutive_errors)))
.collect()
}
}
impl Default for LatencyTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an address from a static string literal.
    fn addr(s: &'static str) -> CheetahString {
        CheetahString::from_static_str(s)
    }

    #[test]
    fn test_latency_tracker_basic() {
        let tracker = LatencyTracker::new();
        let nameserver = addr("127.0.0.1:9876");
        for ms in [10, 15, 20] {
            tracker.record_success(&nameserver, Duration::from_millis(ms));
        }
        assert!(tracker.get_p99(&nameserver).is_some());
        assert!(tracker.is_healthy(&nameserver));
    }

    #[test]
    fn test_latency_tracker_error_handling() {
        let tracker = LatencyTracker::new();
        let nameserver = addr("127.0.0.1:9876");
        // Two errors keep the address below the unhealthy threshold...
        (0..2).for_each(|_| tracker.record_error(&nameserver));
        assert!(tracker.is_healthy(&nameserver));
        // ...a third one crosses it...
        tracker.record_error(&nameserver);
        assert!(!tracker.is_healthy(&nameserver));
        // ...and a single success restores health.
        tracker.record_success(&nameserver, Duration::from_millis(10));
        assert!(tracker.is_healthy(&nameserver));
    }

    #[test]
    fn test_nameserver_selection() {
        let tracker = LatencyTracker::new();
        let fast = addr("127.0.0.1:9876");
        let slow = addr("127.0.0.1:9877");
        let failing = addr("127.0.0.1:9878");
        let samples = [
            (&fast, 5),
            (&fast, 6),
            (&slow, 50),
            (&slow, 60),
            (&failing, 10),
        ];
        for (nameserver, ms) in samples {
            tracker.record_success(nameserver, Duration::from_millis(ms));
        }
        for _ in 0..3 {
            tracker.record_error(&failing);
        }
        let candidates = [fast.clone(), slow.clone(), failing.clone()];
        assert_eq!(tracker.select_best(&candidates), Some(&fast));
    }
}