Skip to main content

batuta/agent/driver/
router.rs

1//! Routing driver for local-first, remote fallback inference.
2//!
3//! Wraps a primary (typically sovereign/local) and fallback
4//! (typically remote/cloud) `LlmDriver`. Attempts the primary
5//! driver first; on failure, spills over to the fallback.
6//!
7//! Phase 2: Implements `RoutingDriver` from the agent spec.
8//!
9//! Privacy tier: inherits the more permissive of the two
10//! underlying drivers (if fallback is Standard, routing is
11//! Standard — data *may* leave the machine on spillover).
12
13use async_trait::async_trait;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16
17use crate::agent::driver::{CompletionRequest, CompletionResponse, LlmDriver};
18use crate::agent::result::AgentError;
19use crate::serve::backends::PrivacyTier;
20
/// Policy that decides which wrapped driver serves a completion request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// Send the request to the primary driver first and retry it on the
    /// fallback driver when the primary fails with an error the router
    /// treats as retryable.
    PrimaryWithFallback,
    /// Send the request to the primary driver and never consult the
    /// fallback. Behaves like calling the primary driver directly while
    /// keeping the `RoutingDriver` interface, so configuration stays
    /// uniform.
    PrimaryOnly,
    /// Send the request straight to the fallback driver, bypassing the
    /// primary. Handy in tests or when local inference is unavailable.
    FallbackOnly,
}
34
/// Counters describing how completion requests were routed.
///
/// Every field is an `AtomicU64`, so the counters can be bumped through a
/// shared reference.
#[derive(Debug)]
pub struct RoutingMetrics {
    /// Primary completions that returned `Ok`.
    primary_successes: AtomicU64,
    /// Primary completions that returned `Err`, whatever the cause.
    primary_failures: AtomicU64,
    /// Requests handed to the fallback driver after a primary failure.
    spillovers: AtomicU64,
    /// Fallback completions that returned `Ok`.
    fallback_successes: AtomicU64,
    /// Fallback completions that returned `Err`.
    fallback_failures: AtomicU64,
}

impl RoutingMetrics {
    /// Build a metrics set with every counter at zero.
    fn new() -> Self {
        let zeroed = || AtomicU64::new(0);
        Self {
            primary_successes: zeroed(),
            primary_failures: zeroed(),
            spillovers: zeroed(),
            fallback_successes: zeroed(),
            fallback_failures: zeroed(),
        }
    }

    /// Number of primary completions recorded so far, successful or not.
    pub fn primary_attempts(&self) -> u64 {
        let succeeded = self.primary_successes.load(Ordering::Relaxed);
        let failed = self.primary_failures.load(Ordering::Relaxed);
        succeeded + failed
    }

    /// Number of requests that spilled over from the primary to the fallback.
    pub fn spillover_count(&self) -> u64 {
        let Self { spillovers, .. } = self;
        spillovers.load(Ordering::Relaxed)
    }

    /// Share of recorded fallback completions that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `0.0` while no fallback completion has been recorded, which
    /// also avoids dividing by zero.
    // Precision loss is acceptable: the ratio is only used for metrics display.
    #[allow(clippy::cast_precision_loss)]
    pub fn fallback_success_rate(&self) -> f64 {
        let succeeded = self.fallback_successes.load(Ordering::Relaxed);
        let failed = self.fallback_failures.load(Ordering::Relaxed);
        match succeeded + failed {
            0 => 0.0,
            recorded => succeeded as f64 / recorded as f64,
        }
    }
}
88
89/// Routing driver: local-first with remote fallback.
90///
91/// Wraps two `LlmDriver` implementations. The primary driver
92/// (typically `RealizarDriver` for sovereign inference) is tried
93/// first. On failure, the fallback (typically `RemoteDriver`)
94/// handles the request.
95///
96/// Privacy tier is the more permissive of the two drivers — if
97/// the fallback is `Standard`, data may leave the machine on
98/// spillover, so the routing driver reports `Standard`.
99pub struct RoutingDriver {
100    primary: Box<dyn LlmDriver>,
101    fallback: Option<Box<dyn LlmDriver>>,
102    strategy: RoutingStrategy,
103    metrics: Arc<RoutingMetrics>,
104}
105
106impl RoutingDriver {
107    /// Create a new routing driver with primary and fallback.
108    pub fn new(primary: Box<dyn LlmDriver>, fallback: Box<dyn LlmDriver>) -> Self {
109        Self {
110            primary,
111            fallback: Some(fallback),
112            strategy: RoutingStrategy::PrimaryWithFallback,
113            metrics: Arc::new(RoutingMetrics::new()),
114        }
115    }
116
117    /// Create a routing driver with primary only (no fallback).
118    pub fn primary_only(primary: Box<dyn LlmDriver>) -> Self {
119        Self {
120            primary,
121            fallback: None,
122            strategy: RoutingStrategy::PrimaryOnly,
123            metrics: Arc::new(RoutingMetrics::new()),
124        }
125    }
126
127    /// Set the routing strategy.
128    #[must_use]
129    pub fn with_strategy(mut self, strategy: RoutingStrategy) -> Self {
130        self.strategy = strategy;
131        self
132    }
133
134    /// Get routing metrics.
135    pub fn metrics(&self) -> &RoutingMetrics {
136        &self.metrics
137    }
138
139    /// Check if the error is retryable (should trigger fallback).
140    fn should_fallback(error: &AgentError) -> bool {
141        use crate::agent::result::DriverError;
142        match error {
143            AgentError::Driver(driver_err) => {
144                matches!(
145                    driver_err,
146                    DriverError::InferenceFailed(_)
147                        | DriverError::ModelNotFound(_)
148                        | DriverError::Network(_)
149                )
150            }
151            _ => false,
152        }
153    }
154
155    /// Record primary result in metrics.
156    fn record_primary(&self, result: &Result<CompletionResponse, AgentError>) {
157        match result {
158            Ok(_) => {
159                self.metrics.primary_successes.fetch_add(1, Ordering::Relaxed);
160            }
161            Err(_) => {
162                self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
163            }
164        }
165    }
166
167    /// Record fallback result in metrics.
168    fn record_fallback(&self, result: &Result<CompletionResponse, AgentError>) {
169        match result {
170            Ok(_) => {
171                self.metrics.fallback_successes.fetch_add(1, Ordering::Relaxed);
172            }
173            Err(_) => {
174                self.metrics.fallback_failures.fetch_add(1, Ordering::Relaxed);
175            }
176        }
177    }
178
179    /// Try primary, spillover to fallback on retryable error.
180    async fn complete_with_fallback(
181        &self,
182        request: CompletionRequest,
183    ) -> Result<CompletionResponse, AgentError> {
184        let primary_result = self.primary.complete(request.clone()).await;
185
186        match primary_result {
187            Ok(response) => {
188                self.metrics.primary_successes.fetch_add(1, Ordering::Relaxed);
189                Ok(response)
190            }
191            Err(ref e) if Self::should_fallback(e) && self.fallback.is_some() => {
192                self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
193                self.metrics.spillovers.fetch_add(1, Ordering::Relaxed);
194                self.run_fallback(request).await
195            }
196            Err(e) => {
197                self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
198                Err(e)
199            }
200        }
201    }
202
203    /// Execute on fallback driver and record metrics.
204    async fn run_fallback(
205        &self,
206        request: CompletionRequest,
207    ) -> Result<CompletionResponse, AgentError> {
208        if let Some(ref fallback) = self.fallback {
209            let result = fallback.complete(request).await;
210            self.record_fallback(&result);
211            return result;
212        }
213        Err(AgentError::Driver(crate::agent::result::DriverError::InferenceFailed(
214            "No fallback driver configured".into(),
215        )))
216    }
217}
218
219#[async_trait]
220impl LlmDriver for RoutingDriver {
221    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
222        match self.strategy {
223            RoutingStrategy::FallbackOnly => self.run_fallback(request).await,
224            RoutingStrategy::PrimaryOnly => {
225                let result = self.primary.complete(request).await;
226                self.record_primary(&result);
227                result
228            }
229            RoutingStrategy::PrimaryWithFallback => self.complete_with_fallback(request).await,
230        }
231    }
232
233    fn context_window(&self) -> usize {
234        match self.strategy {
235            RoutingStrategy::FallbackOnly => {
236                self.fallback.as_ref().map_or(self.primary.context_window(), |f| f.context_window())
237            }
238            _ => self.primary.context_window(),
239        }
240    }
241
242    fn privacy_tier(&self) -> PrivacyTier {
243        let primary_tier = self.primary.privacy_tier();
244        let fallback_tier = self.fallback.as_ref().map_or(primary_tier, |f| f.privacy_tier());
245
246        // Most permissive tier wins (Standard > Private > Sovereign)
247        match (&primary_tier, &fallback_tier) {
248            (PrivacyTier::Standard, _) | (_, PrivacyTier::Standard) => PrivacyTier::Standard,
249            (PrivacyTier::Private, _) | (_, PrivacyTier::Private) => PrivacyTier::Private,
250            _ => PrivacyTier::Sovereign,
251        }
252    }
253}
254
255#[cfg(test)]
256#[path = "router_tests.rs"]
257mod tests;