use crate::config::ProviderConfig;
use crate::error::Result;
use crate::request::CompletionRequest;
use async_trait::async_trait;
use std::time::Duration;
/// Pluggable policy for choosing a provider for a completion request.
///
/// Only [`select_provider`](RoutingStrategy::select_provider) must be
/// implemented; the reporting hooks and [`name`](RoutingStrategy::name) come
/// with default bodies that implementors may override.
#[async_trait]
pub trait RoutingStrategy: Send + Sync {
    /// Picks a provider for `request` out of `providers`.
    ///
    /// On success yields a `usize` — presumably an index into `providers`
    /// (NOTE(review): confirm against implementors and callers).
    ///
    /// # Errors
    ///
    /// Whatever error the implementing strategy chooses to return.
    async fn select_provider(
        &self,
        providers: &[ProviderConfig],
        request: &CompletionRequest,
    ) -> Result<usize>;

    /// Feedback hook carrying a provider index and an observed latency.
    ///
    /// The default implementation ignores both arguments.
    async fn report_success(&self, provider_index: usize, latency: Duration) {
        let _ = provider_index;
        let _ = latency;
    }

    /// Feedback hook carrying the index of a provider that failed.
    ///
    /// The default implementation ignores the argument.
    async fn report_failure(&self, provider_index: usize) {
        let _ = provider_index;
    }

    /// Name of this strategy; defaults to `"routing-strategy"`.
    fn name(&self) -> &str {
        "routing-strategy"
    }
}
/// Identifies one of the available provider-routing policies.
///
/// The per-variant summaries below mirror the text returned by
/// [`RoutingMode::description`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoutingMode {
    /// Try providers in priority order.
    Priority,
    /// Distribute requests evenly across providers.
    RoundRobin,
    /// Route to the provider with the lowest average latency.
    LatencyBased,
    /// Randomly select a provider.
    Random,
}

impl RoutingMode {
    /// Returns a short, human-readable description of this mode.
    ///
    /// Every description is a string literal, so the return type is spelled
    /// `&'static str`. With the elided `&str` the result would be tied to the
    /// borrow of `self`, preventing callers from keeping the text after a
    /// local `RoutingMode` value goes out of scope.
    pub fn description(&self) -> &'static str {
        match self {
            Self::Priority => "Try providers in priority order",
            Self::RoundRobin => "Distribute requests evenly across providers",
            Self::LatencyBased => "Route to provider with lowest average latency",
            Self::Random => "Randomly select provider",
        }
    }
}
/// Coarse health state attached to a provider.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ProviderHealth {
    /// Counted as available by [`ProviderHealth::is_available`].
    Healthy,
    /// Also counted as available by [`ProviderHealth::is_available`].
    Degraded,
    /// The only state reported as not available.
    Unavailable,
}

impl ProviderHealth {
    /// Returns `true` for [`Healthy`](Self::Healthy) and
    /// [`Degraded`](Self::Degraded), `false` for
    /// [`Unavailable`](Self::Unavailable).
    pub fn is_available(&self) -> bool {
        // Exhaustive on purpose: a newly added variant must be classified here.
        match self {
            Self::Healthy | Self::Degraded => true,
            Self::Unavailable => false,
        }
    }
}
/// Request counters, latency and health recorded for a single provider.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ProviderMetrics {
    /// Number of requests counted in total; the denominator of the rates.
    pub total_requests: u64,
    /// Number of requests counted as successful.
    pub successful_requests: u64,
    /// Number of requests counted as failed (not consulted by the rate
    /// helpers, which derive the failure rate from the success ratio).
    pub failed_requests: u64,
    /// Average latency value stored for the provider.
    pub avg_latency: Duration,
    /// Current health state of the provider.
    pub health: ProviderHealth,
}

impl Default for ProviderMetrics {
    /// Zeroed counters, zero latency and [`ProviderHealth::Healthy`].
    fn default() -> Self {
        Self {
            total_requests: 0,
            successful_requests: 0,
            failed_requests: 0,
            avg_latency: Duration::ZERO,
            health: ProviderHealth::Healthy,
        }
    }
}

impl ProviderMetrics {
    /// Ratio `successful_requests / total_requests` in `f64`.
    ///
    /// Returns `1.0` when no requests have been counted. `f64` is used
    /// because `f32` has a 24-bit significand: casting counters above ~16.7M
    /// to `f32` rounds them before the division even happens.
    fn success_ratio(&self) -> f64 {
        if self.total_requests == 0 {
            return 1.0;
        }
        self.successful_requests as f64 / self.total_requests as f64
    }

    /// Fraction of requests that succeeded, or `1.0` when there were none.
    pub fn success_rate(&self) -> f32 {
        // Narrow once, after the division.
        self.success_ratio() as f32
    }

    /// Complement of the success ratio, or `0.0` when there were no requests.
    ///
    /// The subtraction is performed in `f64` before narrowing. Subtracting
    /// the already-rounded `f32` success rate from `1.0` collapses small
    /// failure rates (e.g. 1 failure in 100M requests) to exactly `0.0`.
    pub fn failure_rate(&self) -> f32 {
        (1.0 - self.success_ratio()) as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every routing mode must expose a non-empty description.
    #[test]
    fn test_routing_mode_description() {
        let modes = [
            RoutingMode::Priority,
            RoutingMode::RoundRobin,
            RoutingMode::LatencyBased,
            RoutingMode::Random,
        ];
        for mode in modes {
            assert!(
                !mode.description().is_empty(),
                "empty description for {mode:?}"
            );
        }
    }

    /// Only `Unavailable` is reported as not available.
    #[test]
    fn test_provider_health_is_available() {
        assert!(ProviderHealth::Healthy.is_available());
        assert!(ProviderHealth::Degraded.is_available());
        assert!(!ProviderHealth::Unavailable.is_available());
    }

    /// Checks every field of the default value, including latency and health,
    /// plus the rates derived from it.
    #[test]
    fn test_provider_metrics_default() {
        let metrics = ProviderMetrics::default();
        assert_eq!(metrics.total_requests, 0);
        assert_eq!(metrics.successful_requests, 0);
        assert_eq!(metrics.failed_requests, 0);
        assert_eq!(metrics.avg_latency, Duration::ZERO);
        assert_eq!(metrics.health, ProviderHealth::Healthy);
        assert_eq!(metrics.success_rate(), 1.0);
        assert_eq!(metrics.failure_rate(), 0.0);
    }

    /// 95 successes out of 100 requests: rates are compared with a tolerance
    /// because they are floating-point values.
    #[test]
    fn test_provider_metrics_success_rate() {
        let metrics = ProviderMetrics {
            total_requests: 100,
            successful_requests: 95,
            failed_requests: 5,
            avg_latency: Duration::from_millis(200),
            health: ProviderHealth::Healthy,
        };
        assert!((metrics.success_rate() - 0.95).abs() < 0.001);
        assert!((metrics.failure_rate() - 0.05).abs() < 0.001);
    }

    /// With no requests counted the division is skipped: success is 1.0 and
    /// failure is 0.0, neither is NaN.
    #[test]
    fn test_provider_metrics_zero_requests() {
        let metrics = ProviderMetrics {
            total_requests: 0,
            successful_requests: 0,
            failed_requests: 0,
            avg_latency: Duration::from_millis(0),
            health: ProviderHealth::Healthy,
        };
        assert_eq!(metrics.success_rate(), 1.0);
        assert_eq!(metrics.failure_rate(), 0.0);
    }

    /// Compile-time check: the helper's signature only type-checks while
    /// `RoutingStrategy` can be used as a trait object, so it never needs to
    /// be called.
    #[test]
    fn test_routing_strategy_object_safety() {
        fn _assert_object_safe(_: &dyn RoutingStrategy) {}
    }
}