use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::agent::driver::{CompletionRequest, CompletionResponse, LlmDriver};
use crate::agent::result::AgentError;
use crate::serve::backends::PrivacyTier;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingStrategy {
PrimaryWithFallback,
PrimaryOnly,
FallbackOnly,
}
#[derive(Debug)]
pub struct RoutingMetrics {
primary_successes: AtomicU64,
primary_failures: AtomicU64,
spillovers: AtomicU64,
fallback_successes: AtomicU64,
fallback_failures: AtomicU64,
}
impl RoutingMetrics {
fn new() -> Self {
Self {
primary_successes: AtomicU64::new(0),
primary_failures: AtomicU64::new(0),
spillovers: AtomicU64::new(0),
fallback_successes: AtomicU64::new(0),
fallback_failures: AtomicU64::new(0),
}
}
pub fn primary_attempts(&self) -> u64 {
self.primary_successes.load(Ordering::Relaxed)
+ self.primary_failures.load(Ordering::Relaxed)
}
pub fn spillover_count(&self) -> u64 {
self.spillovers.load(Ordering::Relaxed)
}
pub fn fallback_success_rate(&self) -> f64 {
let successes = self.fallback_successes.load(Ordering::Relaxed);
let failures = self.fallback_failures.load(Ordering::Relaxed);
let total = successes + failures;
if total == 0 {
0.0
} else {
#[allow(clippy::cast_precision_loss)]
{
successes as f64 / total as f64
}
}
}
}
pub struct RoutingDriver {
primary: Box<dyn LlmDriver>,
fallback: Option<Box<dyn LlmDriver>>,
strategy: RoutingStrategy,
metrics: Arc<RoutingMetrics>,
}
impl RoutingDriver {
pub fn new(primary: Box<dyn LlmDriver>, fallback: Box<dyn LlmDriver>) -> Self {
Self {
primary,
fallback: Some(fallback),
strategy: RoutingStrategy::PrimaryWithFallback,
metrics: Arc::new(RoutingMetrics::new()),
}
}
pub fn primary_only(primary: Box<dyn LlmDriver>) -> Self {
Self {
primary,
fallback: None,
strategy: RoutingStrategy::PrimaryOnly,
metrics: Arc::new(RoutingMetrics::new()),
}
}
#[must_use]
pub fn with_strategy(mut self, strategy: RoutingStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn metrics(&self) -> &RoutingMetrics {
&self.metrics
}
fn should_fallback(error: &AgentError) -> bool {
use crate::agent::result::DriverError;
match error {
AgentError::Driver(driver_err) => {
matches!(
driver_err,
DriverError::InferenceFailed(_)
| DriverError::ModelNotFound(_)
| DriverError::Network(_)
)
}
_ => false,
}
}
fn record_primary(&self, result: &Result<CompletionResponse, AgentError>) {
match result {
Ok(_) => {
self.metrics.primary_successes.fetch_add(1, Ordering::Relaxed);
}
Err(_) => {
self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
}
}
}
fn record_fallback(&self, result: &Result<CompletionResponse, AgentError>) {
match result {
Ok(_) => {
self.metrics.fallback_successes.fetch_add(1, Ordering::Relaxed);
}
Err(_) => {
self.metrics.fallback_failures.fetch_add(1, Ordering::Relaxed);
}
}
}
async fn complete_with_fallback(
&self,
request: CompletionRequest,
) -> Result<CompletionResponse, AgentError> {
let primary_result = self.primary.complete(request.clone()).await;
match primary_result {
Ok(response) => {
self.metrics.primary_successes.fetch_add(1, Ordering::Relaxed);
Ok(response)
}
Err(ref e) if Self::should_fallback(e) && self.fallback.is_some() => {
self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
self.metrics.spillovers.fetch_add(1, Ordering::Relaxed);
self.run_fallback(request).await
}
Err(e) => {
self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
Err(e)
}
}
}
async fn run_fallback(
&self,
request: CompletionRequest,
) -> Result<CompletionResponse, AgentError> {
if let Some(ref fallback) = self.fallback {
let result = fallback.complete(request).await;
self.record_fallback(&result);
return result;
}
Err(AgentError::Driver(crate::agent::result::DriverError::InferenceFailed(
"No fallback driver configured".into(),
)))
}
}
#[async_trait]
impl LlmDriver for RoutingDriver {
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
match self.strategy {
RoutingStrategy::FallbackOnly => self.run_fallback(request).await,
RoutingStrategy::PrimaryOnly => {
let result = self.primary.complete(request).await;
self.record_primary(&result);
result
}
RoutingStrategy::PrimaryWithFallback => self.complete_with_fallback(request).await,
}
}
fn context_window(&self) -> usize {
match self.strategy {
RoutingStrategy::FallbackOnly => {
self.fallback.as_ref().map_or(self.primary.context_window(), |f| f.context_window())
}
_ => self.primary.context_window(),
}
}
fn privacy_tier(&self) -> PrivacyTier {
let primary_tier = self.primary.privacy_tier();
let fallback_tier = self.fallback.as_ref().map_or(primary_tier, |f| f.privacy_tier());
match (&primary_tier, &fallback_tier) {
(PrivacyTier::Standard, _) | (_, PrivacyTier::Standard) => PrivacyTier::Standard,
(PrivacyTier::Private, _) | (_, PrivacyTier::Private) => PrivacyTier::Private,
_ => PrivacyTier::Sovereign,
}
}
}
#[cfg(test)]
#[path = "router_tests.rs"]
mod tests;