use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use crate::error::{Error, Result};
use crate::types::CompletionRequest;
/// Strategy the [`SmartRouter`] uses when scoring candidate providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum Optimization {
/// Prefer providers with the lowest observed latency (the default).
#[default]
Latency,
/// Prefer providers with the lowest per-token cost.
Cost,
/// Prefer providers with the lowest error rate.
Reliability,
}
/// Rolling per-provider statistics maintained by the router and used for
/// scoring. Updated on every call to `SmartRouter::record_request`.
#[derive(Debug, Clone)]
pub struct ProviderMetrics {
/// Exponentially weighted moving average of request latency, in milliseconds.
pub ewma_latency_ms: f64,
/// `error_count / request_count`; always in `[0.0, 1.0]`.
pub error_rate: f64,
/// Cost per 1000 tokens (currency unit not specified here — presumably USD; confirm against provider configs).
pub cost_per_1k_tokens: f64,
/// Total requests recorded for this provider.
pub request_count: u64,
/// Number of those requests that were reported as failures.
pub error_count: u64,
/// Wall-clock time of the most recent update.
pub last_updated: SystemTime,
}
impl Default for ProviderMetrics {
fn default() -> Self {
Self {
ewma_latency_ms: 100.0, error_rate: 0.0,
cost_per_1k_tokens: 0.01, request_count: 0,
error_count: 0,
last_updated: SystemTime::now(),
}
}
}
/// Static configuration for a single routable provider.
#[derive(Debug, Clone)]
pub struct RouterProviderConfig {
/// Provider name; used as the key in the router's provider map.
pub name: String,
/// Cost per 1000 tokens; seeds the provider's metrics on first request.
pub cost_per_1k_tokens: f64,
/// NOTE(review): currently never read — `calculate_score` ignores its
/// `_config` parameter. Wire this into scoring or remove the field.
pub reliability_weight: f64,
}
/// Result of a routing pass: the chosen provider plus predictions and fallbacks.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
/// Name of the highest-scoring provider.
pub provider: String,
/// The provider's current EWMA latency, in milliseconds.
pub predicted_latency_ms: f64,
/// Estimated cost of serving the request on this provider.
pub predicted_cost: f64,
/// Up to two runner-up providers, best first, excluding `provider`.
pub fallback_chain: Vec<String>,
}
/// Routes completion requests across providers using live metrics and a
/// configurable optimization strategy. Construct via [`SmartRouter::builder`].
pub struct SmartRouter {
/// Registered provider configs, keyed by provider name.
providers: HashMap<String, RouterProviderConfig>,
/// Observed per-provider metrics; entries are created lazily by `record_request`.
metrics: Arc<RwLock<HashMap<String, ProviderMetrics>>>,
/// Scoring strategy applied in `calculate_score`.
optimization: Optimization,
/// EWMA smoothing factor in `(0.0, 1.0]`; larger values react faster.
ewma_alpha: f64,
/// Total requests recorded across all providers.
request_counter: Arc<AtomicU64>,
}
impl SmartRouter {
    /// Token count assumed per request when estimating cost.
    /// NOTE(review): `estimate_request_cost` does not inspect the request
    /// body yet — replace with a real token estimate when one is available.
    const ESTIMATED_TOKENS_PER_REQUEST: f64 = 100.0;

    /// Returns a builder for assembling a configured `SmartRouter`.
    pub fn builder() -> SmartRouterBuilder {
        SmartRouterBuilder::default()
    }

    /// Selects the best provider for `request` under the configured
    /// [`Optimization`] strategy.
    ///
    /// Only providers that already have metrics (i.e. have been seen by
    /// [`Self::record_request`]) are eligible; a never-recorded provider is
    /// excluded. NOTE(review): this means cold providers are unreachable
    /// until something records a request for them — confirm intentional.
    /// The ranked runners-up supply up to two fallbacks.
    ///
    /// # Errors
    /// Returns [`Error::Configuration`] when no provider has metrics yet.
    pub async fn route(&self, request: &CompletionRequest) -> Result<RoutingDecision> {
        let metrics = self.metrics.read().await;
        // Score every provider with observed metrics in a single pass
        // (avoids the intermediate filtered Vec and panicking re-indexing).
        let mut scored: Vec<(String, f64)> = self
            .providers
            .iter()
            .filter_map(|(name, config)| {
                let metric = metrics.get(name)?;
                Some((name.clone(), self.calculate_score(metric, config)))
            })
            .collect();
        if scored.is_empty() {
            return Err(Error::Configuration(
                "No providers available for routing".to_string(),
            ));
        }
        // Highest score first; NaN comparisons fall back to Equal instead of panicking.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        let primary = scored[0].0.clone();
        // Up to two runners-up form the fallback chain.
        let fallback_chain: Vec<String> = scored
            .iter()
            .skip(1)
            .take(2)
            .map(|(name, _)| name.clone())
            .collect();
        // Infallible: `primary` was drawn from both maps above.
        let primary_metric = &metrics[&primary];
        let primary_config = &self.providers[&primary];
        Ok(RoutingDecision {
            provider: primary,
            predicted_latency_ms: primary_metric.ewma_latency_ms,
            predicted_cost: self.estimate_request_cost(request, primary_config),
            fallback_chain,
        })
    }

    /// Records the outcome of one request against `provider`, updating its
    /// EWMA latency, error rate, counters, and timestamp.
    ///
    /// Creates the metrics entry on first sight, seeding the cost from the
    /// provider's config when one is registered. `_tokens` is accepted for
    /// forward compatibility but not used yet.
    pub async fn record_request(
        &self,
        provider: &str,
        latency_ms: f64,
        success: bool,
        _tokens: u32,
    ) -> Result<()> {
        let mut metrics = self.metrics.write().await;
        let metric = metrics.entry(provider.to_string()).or_insert_with(|| {
            let mut m = ProviderMetrics::default();
            if let Some(config) = self.providers.get(provider) {
                m.cost_per_1k_tokens = config.cost_per_1k_tokens;
            }
            m
        });
        // Standard EWMA update: alpha weights the newest observation.
        metric.ewma_latency_ms =
            (1.0 - self.ewma_alpha) * metric.ewma_latency_ms + self.ewma_alpha * latency_ms;
        metric.request_count += 1;
        if !success {
            metric.error_count += 1;
        }
        // Division is safe: request_count was just incremented, so it is >= 1.
        metric.error_rate = metric.error_count as f64 / metric.request_count as f64;
        metric.last_updated = SystemTime::now();
        self.request_counter.fetch_add(1, Ordering::Relaxed);
        Ok(())
    }

    /// Returns a snapshot of the metrics for `provider`, if any exist.
    pub async fn get_metrics(&self, provider: &str) -> Option<ProviderMetrics> {
        self.metrics.read().await.get(provider).cloned()
    }

    /// Returns aggregate statistics across all tracked providers.
    pub async fn stats(&self) -> RouterStats {
        let metrics = self.metrics.read().await;
        let total_requests = self.request_counter.load(Ordering::Relaxed);
        // Unweighted mean of per-provider EWMA latencies; 0.0 when no provider
        // has metrics yet (avoids dividing by zero).
        let avg_latency = if metrics.is_empty() {
            0.0
        } else {
            metrics.values().map(|m| m.ewma_latency_ms).sum::<f64>() / metrics.len() as f64
        };
        RouterStats {
            total_requests,
            provider_count: metrics.len(),
            average_latency_ms: avg_latency,
            optimization_strategy: self.optimization,
        }
    }

    /// Scores a provider; higher is better. Each strategy blends normalized
    /// latency, reliability, and cost sub-scores with fixed weights.
    ///
    /// NOTE(review): `_config.reliability_weight` is never consulted —
    /// either fold it into the reliability sub-score or drop the field.
    fn calculate_score(&self, metric: &ProviderMetrics, _config: &RouterProviderConfig) -> f64 {
        // 1.0 at 0 ms, 0.5 at 100 ms, tending to 0 as latency grows.
        let latency_score = 1.0 / (1.0 + metric.ewma_latency_ms / 100.0);
        let reliability_score = 1.0 - metric.error_rate;
        // 1.0 at zero cost, 0.5 at 0.01 per 1k tokens.
        let cost_score = 1.0 / (1.0 + metric.cost_per_1k_tokens * 100.0);
        match self.optimization {
            Optimization::Latency => latency_score * 0.7 + reliability_score * 0.3,
            Optimization::Cost => cost_score * 0.6 + reliability_score * 0.4,
            Optimization::Reliability => reliability_score * 0.7 + latency_score * 0.3,
        }
    }

    /// Estimates the cost of serving `request` on `config`'s provider,
    /// assuming [`Self::ESTIMATED_TOKENS_PER_REQUEST`] tokens per request.
    fn estimate_request_cost(
        &self,
        _request: &CompletionRequest,
        config: &RouterProviderConfig,
    ) -> f64 {
        (config.cost_per_1k_tokens / 1000.0) * Self::ESTIMATED_TOKENS_PER_REQUEST
    }
}
/// Aggregate router statistics returned by `SmartRouter::stats`.
#[derive(Debug, Clone, Copy)]
pub struct RouterStats {
/// Total requests recorded across all providers.
pub total_requests: u64,
/// Number of providers that currently have metrics.
pub provider_count: usize,
/// Unweighted mean of per-provider EWMA latencies, in milliseconds.
pub average_latency_ms: f64,
/// The strategy the router is configured to optimize for.
pub optimization_strategy: Optimization,
}
/// Builder for [`SmartRouter`]; obtain one via `SmartRouter::builder()`.
#[derive(Default)]
pub struct SmartRouterBuilder {
/// Providers to register, keyed by name.
providers: HashMap<String, RouterProviderConfig>,
/// Scoring strategy; defaults to `Optimization::Latency` via `#[default]`.
optimization: Optimization,
/// EWMA smoothing factor; the derived default of 0.0 means "unset" and
/// `build()` substitutes 0.1 in that case.
ewma_alpha: f64,
}
impl SmartRouterBuilder {
    /// Registers a single provider, replacing any prior entry with the same name.
    pub fn add_provider(mut self, config: RouterProviderConfig) -> Self {
        let key = config.name.clone();
        self.providers.insert(key, config);
        self
    }

    /// Registers several providers at once; later duplicates win.
    pub fn with_providers(mut self, configs: Vec<RouterProviderConfig>) -> Self {
        self.providers
            .extend(configs.into_iter().map(|config| (config.name.clone(), config)));
        self
    }

    /// Sets the optimization strategy used when scoring providers.
    pub fn optimize_for(mut self, optimization: Optimization) -> Self {
        self.optimization = optimization;
        self
    }

    /// Sets the EWMA smoothing factor, clamped into `[0.01, 1.0]`.
    pub fn with_ewma_alpha(mut self, alpha: f64) -> Self {
        self.ewma_alpha = alpha.clamp(0.01, 1.0);
        self
    }

    /// Finalizes the router. An unset (zero) alpha falls back to 0.1.
    pub fn build(self) -> SmartRouter {
        let alpha = if self.ewma_alpha > 0.0 {
            self.ewma_alpha
        } else {
            0.1
        };
        SmartRouter {
            providers: self.providers,
            metrics: Arc::new(RwLock::new(HashMap::new())),
            optimization: self.optimization,
            ewma_alpha: alpha,
            request_counter: Arc::new(AtomicU64::new(0)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Message;

    /// Latency-optimized router over three providers with distinct costs.
    fn create_test_router() -> SmartRouter {
        let configs = vec![
            RouterProviderConfig {
                name: "openai".to_string(),
                cost_per_1k_tokens: 0.01,
                reliability_weight: 0.9,
            },
            RouterProviderConfig {
                name: "anthropic".to_string(),
                cost_per_1k_tokens: 0.008,
                reliability_weight: 0.95,
            },
            RouterProviderConfig {
                name: "groq".to_string(),
                cost_per_1k_tokens: 0.0001,
                reliability_weight: 0.85,
            },
        ];
        SmartRouter::builder()
            .with_providers(configs)
            .optimize_for(Optimization::Latency)
            .with_ewma_alpha(0.2)
            .build()
    }

    /// Records one 100-token observation, panicking on failure.
    async fn record(router: &SmartRouter, provider: &str, latency_ms: f64, success: bool) {
        router
            .record_request(provider, latency_ms, success, 100)
            .await
            .unwrap();
    }

    #[test]
    fn test_router_builder() {
        let router = create_test_router();
        assert_eq!(router.providers.len(), 3);
        assert_eq!(router.optimization, Optimization::Latency);
    }

    #[tokio::test]
    async fn test_route_decision() {
        let router = create_test_router();
        let request = CompletionRequest::new("openai/gpt-4", vec![Message::user("test")]);
        record(&router, "openai", 50.0, true).await;
        record(&router, "anthropic", 30.0, true).await;
        record(&router, "groq", 20.0, true).await;
        let decision = router.route(&request).await.unwrap();
        // Fastest provider wins under latency optimization.
        assert_eq!(decision.provider, "groq");
        assert!(decision.fallback_chain.len() <= 2);
    }

    #[tokio::test]
    async fn test_cost_optimization() {
        let router = SmartRouter::builder()
            .add_provider(RouterProviderConfig {
                name: "expensive".to_string(),
                cost_per_1k_tokens: 0.1,
                reliability_weight: 0.95,
            })
            .add_provider(RouterProviderConfig {
                name: "cheap".to_string(),
                cost_per_1k_tokens: 0.001,
                reliability_weight: 0.90,
            })
            .optimize_for(Optimization::Cost)
            .build();
        let request = CompletionRequest::new("openai/gpt-4", vec![Message::user("test")]);
        record(&router, "expensive", 100.0, true).await;
        record(&router, "cheap", 150.0, true).await;
        let decision = router.route(&request).await.unwrap();
        // The cheap provider wins despite its higher latency.
        assert_eq!(decision.provider, "cheap");
    }

    #[tokio::test]
    async fn test_reliability_optimization() {
        let router = SmartRouter::builder()
            .add_provider(RouterProviderConfig {
                name: "stable".to_string(),
                cost_per_1k_tokens: 0.01,
                reliability_weight: 1.0,
            })
            .add_provider(RouterProviderConfig {
                name: "flaky".to_string(),
                cost_per_1k_tokens: 0.01,
                reliability_weight: 0.1,
            })
            .optimize_for(Optimization::Reliability)
            .build();
        let request = CompletionRequest::new("openai/gpt-4", vec![Message::user("test")]);
        for i in 0..5 {
            record(&router, "stable", 50.0, true).await;
            // "flaky" fails on the first iteration only (1 error in 5).
            record(&router, "flaky", 40.0, i % 5 != 0).await;
        }
        let decision = router.route(&request).await.unwrap();
        assert_eq!(decision.provider, "stable");
    }

    #[tokio::test]
    async fn test_ewma_learning() {
        let router = create_test_router();
        record(&router, "openai", 80.0, true).await;
        let first = router.get_metrics("openai").await.unwrap();
        // Seeded at 100 ms, so a faster observation pulls the EWMA down.
        assert!(first.ewma_latency_ms < 100.0);
        record(&router, "openai", 50.0, true).await;
        let second = router.get_metrics("openai").await.unwrap();
        assert!(second.ewma_latency_ms < first.ewma_latency_ms);
        assert_eq!(second.request_count, 2);
    }

    #[tokio::test]
    async fn test_router_stats() {
        let router = create_test_router();
        record(&router, "openai", 50.0, true).await;
        record(&router, "anthropic", 60.0, true).await;
        let stats = router.stats().await;
        assert_eq!(stats.total_requests, 2);
        assert_eq!(stats.optimization_strategy, Optimization::Latency);
    }

    #[tokio::test]
    async fn test_fallback_chain() {
        let router = create_test_router();
        let request = CompletionRequest::new("openai/gpt-4", vec![Message::user("test")]);
        record(&router, "openai", 100.0, true).await;
        record(&router, "anthropic", 50.0, true).await;
        record(&router, "groq", 30.0, true).await;
        let decision = router.route(&request).await.unwrap();
        assert_eq!(decision.fallback_chain.len(), 2);
        assert!(!decision.fallback_chain.contains(&decision.provider));
    }
}