use simple_agent_type::prelude::{
CompletionChunk, CompletionRequest, CompletionResponse, Provider, ProviderHealth, Result,
SimpleAgentsError,
};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Tuning parameters for [`LatencyRouter`].
#[derive(Debug, Clone)]
pub struct LatencyRouterConfig {
    /// Weight of the newest sample in the exponentially weighted moving average
    /// (`alpha * sample + (1 - alpha) * previous`).
    pub alpha: f64,
    /// Average latency at or above which a provider is reported as `Degraded`.
    pub slow_threshold: Duration,
}

impl Default for LatencyRouterConfig {
    /// `alpha = 0.2`, `slow_threshold = 2s`.
    fn default() -> Self {
        let slow_threshold = Duration::from_millis(2_000);
        LatencyRouterConfig {
            slow_threshold,
            alpha: 0.2,
        }
    }
}
/// Running latency statistics for a single provider.
#[derive(Clone, Copy, Debug)]
struct LatencyStats {
    /// Exponentially weighted moving average of observed latency, in milliseconds.
    avg_latency_ms: f64,
    /// Number of samples folded into `avg_latency_ms` (saturates instead of overflowing).
    samples: u64,
    /// `Degraded` when `avg_latency_ms` is at or above the slow threshold, otherwise `Healthy`.
    health: ProviderHealth,
}

impl LatencyStats {
    /// Empty statistics: no samples, a zero average, and `Healthy`.
    fn new() -> Self {
        LatencyStats {
            health: ProviderHealth::Healthy,
            samples: 0,
            avg_latency_ms: 0.0,
        }
    }

    /// Folds one latency observation into the moving average and refreshes `health`.
    ///
    /// The first sample seeds the average directly. Later samples are blended as
    /// `alpha * sample + (1 - alpha) * previous`. NOTE(review): `alpha` is used as given and
    /// is not validated here; values outside `[0, 1]` would extrapolate rather than average.
    fn record(&mut self, latency: Duration, alpha: f64, slow_threshold: Duration) {
        let sample_ms = latency.as_secs_f64() * 1000.0;
        self.avg_latency_ms = match self.samples {
            0 => sample_ms,
            _ => alpha * sample_ms + (1.0 - alpha) * self.avg_latency_ms,
        };
        self.samples = self.samples.saturating_add(1);

        let slow_ms = slow_threshold.as_secs_f64() * 1000.0;
        let is_slow = self.avg_latency_ms >= slow_ms;
        self.health = if is_slow {
            ProviderHealth::Degraded
        } else {
            ProviderHealth::Healthy
        };
    }
}
/// Routes completion requests to the provider with the lowest observed average latency.
///
/// Selection policy:
/// 1. Providers that have not recorded any latency sample yet are tried first, in
///    round-robin order, so every provider gets measured at least once.
/// 2. Once every provider has a sample, the lowest average latency wins. The choice is
///    restricted to `Healthy` providers whenever at least one exists.
///
/// NOTE(review): a `Degraded` provider is not selected again while any `Healthy` provider
/// exists, so its statistics (and health) cannot recover on their own.
pub struct LatencyRouter {
    providers: Vec<Arc<dyn Provider>>,
    /// Per-provider statistics, indexed in parallel with `providers`.
    stats: Mutex<Vec<LatencyStats>>,
    /// Round-robin cursor used for unsampled providers and as a last-resort fallback.
    counter: AtomicUsize,
    config: LatencyRouterConfig,
}

impl LatencyRouter {
    /// Creates a router using [`LatencyRouterConfig::default`].
    ///
    /// # Errors
    /// Returns `SimpleAgentsError::Routing` if `providers` is empty.
    pub fn new(providers: Vec<Arc<dyn Provider>>) -> Result<Self> {
        Self::with_config(providers, LatencyRouterConfig::default())
    }

    /// Creates a router with an explicit configuration.
    ///
    /// # Errors
    /// Returns `SimpleAgentsError::Routing` if `providers` is empty.
    pub fn with_config(
        providers: Vec<Arc<dyn Provider>>,
        config: LatencyRouterConfig,
    ) -> Result<Self> {
        if providers.is_empty() {
            return Err(SimpleAgentsError::Routing(
                "no providers configured".to_string(),
            ));
        }
        let stats = vec![LatencyStats::new(); providers.len()];
        Ok(Self {
            providers,
            stats: Mutex::new(stats),
            counter: AtomicUsize::new(0),
            config,
        })
    }

    /// Number of providers this router chooses between.
    pub fn provider_count(&self) -> usize {
        self.providers.len()
    }

    /// Sends `request` to the selected provider and returns the transformed response.
    ///
    /// On success, the time from request transformation through response transformation is
    /// recorded for the chosen provider. If `execute` or `transform_response` fails, a sample
    /// of at least `slow_threshold` is recorded instead. This stops a failing provider from
    /// staying selected indefinitely. A failure in `transform_request` records nothing,
    /// because no provider call was made.
    ///
    /// # Errors
    /// Propagates errors from provider selection, request/response transformation, and
    /// execution.
    pub async fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse> {
        // The stats lock is taken and released inside these helpers, so it is never held
        // across the `.await` below.
        let index = self.select_provider_index()?;
        let provider = &self.providers[index];
        let start = Instant::now();
        let provider_request = provider.transform_request(request)?;
        let outcome = provider
            .execute(provider_request)
            .await
            .and_then(|provider_response| provider.transform_response(provider_response));
        let elapsed = start.elapsed();
        match outcome {
            Ok(response) => {
                self.record_latency(index, elapsed);
                Ok(response)
            }
            Err(error) => {
                // Count the failure as (at least) a slow sample so it raises the average.
                self.record_latency(index, elapsed.max(self.config.slow_threshold));
                Err(error)
            }
        }
    }

    /// Opens a streaming completion on the selected provider.
    ///
    /// Latency is not recorded for streaming calls.
    ///
    /// # Errors
    /// Propagates errors from provider selection, request transformation, and stream setup.
    pub async fn stream(
        &self,
        request: &CompletionRequest,
    ) -> Result<Box<dyn futures_core::Stream<Item = Result<CompletionChunk>> + Send + Unpin>> {
        let index = self.select_provider_index()?;
        let provider = &self.providers[index];
        let provider_request = provider.transform_request(request)?;
        provider.execute_stream(provider_request).await
    }

    /// Picks the provider index for the next request (see the type-level docs).
    ///
    /// Falls back to plain round-robin over all providers when no candidate has an average
    /// strictly below `f64::MAX` (for example, if every average is NaN).
    fn select_provider_index(&self) -> Result<usize> {
        let len = self.providers.len();
        if len == 0 {
            // The constructors reject an empty list; kept as a defensive guard.
            return Err(SimpleAgentsError::Routing(
                "no providers configured".to_string(),
            ));
        }
        let stats = self
            .stats
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        // Cold start: without this, the first provider to record a sample would win every
        // later selection and the remaining providers would never be measured.
        let unsampled: Vec<usize> = stats
            .iter()
            .enumerate()
            .filter(|(_, stat)| stat.samples == 0)
            .map(|(index, _)| index)
            .collect();
        if !unsampled.is_empty() {
            let turn = self.counter.fetch_add(1, Ordering::Relaxed);
            return Ok(unsampled[turn % unsampled.len()]);
        }

        let has_healthy = stats
            .iter()
            .any(|stat| stat.health == ProviderHealth::Healthy);
        let mut best_index: Option<usize> = None;
        let mut best_latency = f64::MAX;
        for (index, stat) in stats.iter().enumerate() {
            if has_healthy && stat.health != ProviderHealth::Healthy {
                continue;
            }
            // Strict `<` keeps the earliest index on ties and never selects a NaN average.
            if stat.avg_latency_ms < best_latency {
                best_latency = stat.avg_latency_ms;
                best_index = Some(index);
            }
        }
        if let Some(index) = best_index {
            return Ok(index);
        }
        let index = self.counter.fetch_add(1, Ordering::Relaxed);
        Ok(index % len)
    }

    /// Folds `latency` into the statistics of provider `index`; out-of-range indices are ignored.
    fn record_latency(&self, index: usize, latency: Duration) {
        let mut stats = self
            .stats
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(stat) = stats.get_mut(index) {
            stat.record(latency, self.config.alpha, self.config.slow_threshold);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use simple_agent_type::prelude::*;

    /// Provider stub that always succeeds with a canned single-choice response.
    struct MockProvider {
        name: &'static str,
    }

    impl MockProvider {
        fn new(name: &'static str) -> Self {
            MockProvider { name }
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn name(&self) -> &str {
            self.name
        }

        fn transform_request(&self, _req: &CompletionRequest) -> Result<ProviderRequest> {
            Ok(ProviderRequest::new("http://example.com"))
        }

        async fn execute(&self, _req: ProviderRequest) -> Result<ProviderResponse> {
            Ok(ProviderResponse::new(200, serde_json::Value::Null))
        }

        fn transform_response(&self, _resp: ProviderResponse) -> Result<CompletionResponse> {
            let choice = CompletionChoice {
                index: 0,
                message: Message::assistant("ok"),
                finish_reason: FinishReason::Stop,
                logprobs: None,
            };
            Ok(CompletionResponse {
                id: "resp_test".to_string(),
                model: "test-model".to_string(),
                choices: vec![choice],
                usage: Usage::new(1, 1),
                created: None,
                provider: Some(self.name().to_string()),
                healing_metadata: None,
            })
        }
    }

    /// One `MockProvider` per name, as trait objects, in the given order.
    fn mock_providers(names: &[&'static str]) -> Vec<Arc<dyn Provider>> {
        names
            .iter()
            .map(|&name| Arc::new(MockProvider::new(name)) as Arc<dyn Provider>)
            .collect()
    }

    /// Minimal valid request used by the async test.
    fn build_request() -> CompletionRequest {
        CompletionRequest::builder()
            .model("test-model")
            .message(Message::user("hello"))
            .build()
            .unwrap()
    }

    #[test]
    fn empty_router_returns_error() {
        match LatencyRouter::new(Vec::new()) {
            Err(SimpleAgentsError::Routing(message)) => {
                assert_eq!(message, "no providers configured");
            }
            Err(_) => panic!("unexpected error type"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[test]
    fn selects_lowest_latency_provider() {
        let router = LatencyRouter::new(mock_providers(&["p1", "p2"])).unwrap();
        router.record_latency(0, Duration::from_millis(250));
        router.record_latency(1, Duration::from_millis(50));
        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn prefers_healthy_over_degraded() {
        let router = LatencyRouter::with_config(
            mock_providers(&["p1", "p2"]),
            LatencyRouterConfig {
                alpha: 1.0,
                slow_threshold: Duration::from_millis(100),
            },
        )
        .unwrap();
        router.record_latency(0, Duration::from_millis(400));
        router.record_latency(1, Duration::from_millis(80));
        assert_eq!(router.select_provider_index().unwrap(), 1);
    }

    #[test]
    fn round_robin_when_no_metrics() {
        let router = LatencyRouter::new(mock_providers(&["p1", "p2"])).unwrap();
        let picks = [
            router.select_provider_index().unwrap(),
            router.select_provider_index().unwrap(),
        ];
        assert_eq!(picks, [0, 1]);
    }

    #[tokio::test]
    async fn records_latency_on_success() {
        let router = LatencyRouter::new(mock_providers(&["p1"])).unwrap();
        let _response = router.complete(&build_request()).await.unwrap();
        let stats = router.stats.lock().expect("latency stats lock poisoned");
        assert_eq!(stats[0].samples, 1);
    }
}