batuta/agent/driver/
router.rs1use async_trait::async_trait;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::sync::Arc;
16
17use crate::agent::driver::{CompletionRequest, CompletionResponse, LlmDriver};
18use crate::agent::result::AgentError;
19use crate::serve::backends::PrivacyTier;
20
/// Policy that decides which underlying driver serves a completion request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RoutingStrategy {
    /// Try the primary driver first; spill over to the fallback driver when
    /// the primary fails with an error the router treats as recoverable.
    PrimaryWithFallback,
    /// Send every request to the primary driver and never spill over.
    PrimaryOnly,
    /// Skip the primary driver and send every request to the fallback driver.
    FallbackOnly,
}
34
/// Lock-free counters describing how completion requests were routed.
///
/// Every counter is an [`AtomicU64`] accessed with relaxed ordering, so a
/// reader may observe counters that were updated at slightly different
/// moments; the values are statistics, not a consistent snapshot.
#[derive(Debug)]
pub struct RoutingMetrics {
    /// Primary-driver calls that returned `Ok`.
    primary_successes: AtomicU64,
    /// Primary-driver calls that returned `Err`.
    primary_failures: AtomicU64,
    /// Requests handed to the fallback driver after a primary failure.
    spillovers: AtomicU64,
    /// Fallback-driver calls that returned `Ok`.
    fallback_successes: AtomicU64,
    /// Fallback-driver calls that returned `Err`.
    fallback_failures: AtomicU64,
}

impl RoutingMetrics {
    /// Builds a metrics set with every counter at zero.
    fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            primary_successes: zero(),
            primary_failures: zero(),
            spillovers: zero(),
            fallback_successes: zero(),
            fallback_failures: zero(),
        }
    }

    /// Number of primary-driver calls recorded so far (successes + failures).
    pub fn primary_attempts(&self) -> u64 {
        let succeeded = self.primary_successes.load(Ordering::Relaxed);
        let failed = self.primary_failures.load(Ordering::Relaxed);
        succeeded + failed
    }

    /// Number of requests that spilled over to the fallback driver.
    pub fn spillover_count(&self) -> u64 {
        self.spillovers.load(Ordering::Relaxed)
    }

    /// Fraction of fallback calls that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no fallback call has been recorded, which avoids a
    /// division by zero.
    // Counters are converted to `f64` only to form a ratio; losing precision
    // above 2^53 recorded calls is acceptable for a statistic.
    #[allow(clippy::cast_precision_loss)]
    pub fn fallback_success_rate(&self) -> f64 {
        let successes = self.fallback_successes.load(Ordering::Relaxed);
        let failures = self.fallback_failures.load(Ordering::Relaxed);
        match successes + failures {
            0 => 0.0,
            total => successes as f64 / total as f64,
        }
    }
}
88
89pub struct RoutingDriver {
100 primary: Box<dyn LlmDriver>,
101 fallback: Option<Box<dyn LlmDriver>>,
102 strategy: RoutingStrategy,
103 metrics: Arc<RoutingMetrics>,
104}
105
106impl RoutingDriver {
107 pub fn new(primary: Box<dyn LlmDriver>, fallback: Box<dyn LlmDriver>) -> Self {
109 Self {
110 primary,
111 fallback: Some(fallback),
112 strategy: RoutingStrategy::PrimaryWithFallback,
113 metrics: Arc::new(RoutingMetrics::new()),
114 }
115 }
116
117 pub fn primary_only(primary: Box<dyn LlmDriver>) -> Self {
119 Self {
120 primary,
121 fallback: None,
122 strategy: RoutingStrategy::PrimaryOnly,
123 metrics: Arc::new(RoutingMetrics::new()),
124 }
125 }
126
127 #[must_use]
129 pub fn with_strategy(mut self, strategy: RoutingStrategy) -> Self {
130 self.strategy = strategy;
131 self
132 }
133
134 pub fn metrics(&self) -> &RoutingMetrics {
136 &self.metrics
137 }
138
139 fn should_fallback(error: &AgentError) -> bool {
141 use crate::agent::result::DriverError;
142 match error {
143 AgentError::Driver(driver_err) => {
144 matches!(
145 driver_err,
146 DriverError::InferenceFailed(_)
147 | DriverError::ModelNotFound(_)
148 | DriverError::Network(_)
149 )
150 }
151 _ => false,
152 }
153 }
154
155 fn record_primary(&self, result: &Result<CompletionResponse, AgentError>) {
157 match result {
158 Ok(_) => {
159 self.metrics.primary_successes.fetch_add(1, Ordering::Relaxed);
160 }
161 Err(_) => {
162 self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
163 }
164 }
165 }
166
167 fn record_fallback(&self, result: &Result<CompletionResponse, AgentError>) {
169 match result {
170 Ok(_) => {
171 self.metrics.fallback_successes.fetch_add(1, Ordering::Relaxed);
172 }
173 Err(_) => {
174 self.metrics.fallback_failures.fetch_add(1, Ordering::Relaxed);
175 }
176 }
177 }
178
179 async fn complete_with_fallback(
181 &self,
182 request: CompletionRequest,
183 ) -> Result<CompletionResponse, AgentError> {
184 let primary_result = self.primary.complete(request.clone()).await;
185
186 match primary_result {
187 Ok(response) => {
188 self.metrics.primary_successes.fetch_add(1, Ordering::Relaxed);
189 Ok(response)
190 }
191 Err(ref e) if Self::should_fallback(e) && self.fallback.is_some() => {
192 self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
193 self.metrics.spillovers.fetch_add(1, Ordering::Relaxed);
194 self.run_fallback(request).await
195 }
196 Err(e) => {
197 self.metrics.primary_failures.fetch_add(1, Ordering::Relaxed);
198 Err(e)
199 }
200 }
201 }
202
203 async fn run_fallback(
205 &self,
206 request: CompletionRequest,
207 ) -> Result<CompletionResponse, AgentError> {
208 if let Some(ref fallback) = self.fallback {
209 let result = fallback.complete(request).await;
210 self.record_fallback(&result);
211 return result;
212 }
213 Err(AgentError::Driver(crate::agent::result::DriverError::InferenceFailed(
214 "No fallback driver configured".into(),
215 )))
216 }
217}
218
219#[async_trait]
220impl LlmDriver for RoutingDriver {
221 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, AgentError> {
222 match self.strategy {
223 RoutingStrategy::FallbackOnly => self.run_fallback(request).await,
224 RoutingStrategy::PrimaryOnly => {
225 let result = self.primary.complete(request).await;
226 self.record_primary(&result);
227 result
228 }
229 RoutingStrategy::PrimaryWithFallback => self.complete_with_fallback(request).await,
230 }
231 }
232
233 fn context_window(&self) -> usize {
234 match self.strategy {
235 RoutingStrategy::FallbackOnly => {
236 self.fallback.as_ref().map_or(self.primary.context_window(), |f| f.context_window())
237 }
238 _ => self.primary.context_window(),
239 }
240 }
241
242 fn privacy_tier(&self) -> PrivacyTier {
243 let primary_tier = self.primary.privacy_tier();
244 let fallback_tier = self.fallback.as_ref().map_or(primary_tier, |f| f.privacy_tier());
245
246 match (&primary_tier, &fallback_tier) {
248 (PrivacyTier::Standard, _) | (_, PrivacyTier::Standard) => PrivacyTier::Standard,
249 (PrivacyTier::Private, _) | (_, PrivacyTier::Private) => PrivacyTier::Private,
250 _ => PrivacyTier::Sovereign,
251 }
252 }
253}
254
255#[cfg(test)]
256#[path = "router_tests.rs"]
257mod tests;