1use std::sync::{
26 Arc,
27 atomic::{AtomicUsize, Ordering},
28};
29
30use async_trait::async_trait;
31use bob_core::{
32 error::LlmError,
33 ports::LlmPort,
34 resilience::{CircuitBreaker, CircuitState},
35 types::{LlmCapabilities, LlmRequest, LlmResponse, LlmStream},
36};
37
/// Policy used by the router to choose which provider handles a request.
///
/// The enum is fieldless, so it also derives `PartialEq`/`Eq`; this lets
/// callers and tests compare strategies directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RoutingStrategy {
    /// Try providers in registration order, moving on when one fails.
    Priority,
    /// Rotate the starting provider on every request, then fall through
    /// the remaining healthy providers on failure.
    RoundRobin,
}
48
49pub struct ProviderEntry {
53 pub name: String,
55 pub adapter: Arc<dyn LlmPort>,
57 pub circuit_breaker: Option<Arc<CircuitBreaker>>,
59}
60
61impl std::fmt::Debug for ProviderEntry {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("ProviderEntry")
64 .field("name", &self.name)
65 .field("has_circuit_breaker", &self.circuit_breaker.is_some())
66 .finish_non_exhaustive()
67 }
68}
69
70impl ProviderEntry {
71 #[must_use]
73 pub fn new(name: impl Into<String>, adapter: Arc<dyn LlmPort>) -> Self {
74 Self { name: name.into(), adapter, circuit_breaker: None }
75 }
76
77 #[must_use]
79 pub fn with_circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
80 self.circuit_breaker = Some(cb);
81 self
82 }
83
84 fn is_available(&self) -> bool {
86 match &self.circuit_breaker {
87 Some(cb) => cb.state() != CircuitState::Open,
88 None => true,
89 }
90 }
91}
92
93pub struct ProviderRouter {
97 strategy: RoutingStrategy,
98 providers: Vec<ProviderEntry>,
99 round_robin_index: AtomicUsize,
100}
101
102impl std::fmt::Debug for ProviderRouter {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("ProviderRouter")
105 .field("strategy", &self.strategy)
106 .field("provider_count", &self.providers.len())
107 .finish_non_exhaustive()
108 }
109}
110
111impl ProviderRouter {
112 #[must_use]
114 pub fn new(strategy: RoutingStrategy) -> Self {
115 Self { strategy, providers: Vec::new(), round_robin_index: AtomicUsize::new(0) }
116 }
117
118 #[must_use]
120 pub fn with_provider(mut self, entry: ProviderEntry) -> Self {
121 self.providers.push(entry);
122 self
123 }
124
125 #[must_use]
127 pub fn provider_count(&self) -> usize {
128 self.providers.len()
129 }
130
131 async fn route_request<F, Fut>(&self, make_call: F) -> Result<LlmResponse, LlmError>
133 where
134 F: Fn(Arc<dyn LlmPort>) -> Fut,
135 Fut: std::future::Future<Output = Result<LlmResponse, LlmError>>,
136 {
137 match &self.strategy {
138 RoutingStrategy::Priority => self.route_priority(&make_call).await,
139 RoutingStrategy::RoundRobin => self.route_round_robin(&make_call).await,
140 }
141 }
142
143 async fn route_priority<F, Fut>(&self, make_call: &F) -> Result<LlmResponse, LlmError>
144 where
145 F: Fn(Arc<dyn LlmPort>) -> Fut,
146 Fut: std::future::Future<Output = Result<LlmResponse, LlmError>>,
147 {
148 let mut last_error = None;
149
150 for entry in &self.providers {
151 if !entry.is_available() {
152 continue;
153 }
154
155 let result = if let Some(cb) = &entry.circuit_breaker {
156 cb.call(|| make_call(entry.adapter.clone())).await.map_err(|cb_err| match cb_err {
157 bob_core::resilience::CircuitBreakerError::CircuitOpen => {
158 LlmError::Provider(format!("{}: circuit open", entry.name))
159 }
160 bob_core::resilience::CircuitBreakerError::Inner(e) => e,
161 })
162 } else {
163 make_call(entry.adapter.clone()).await
164 };
165
166 match result {
167 Ok(resp) => return Ok(resp),
168 Err(err) => {
169 tracing::warn!(provider = %entry.name, error = %err, "provider failed, trying next");
170 last_error = Some(err);
171 }
172 }
173 }
174
175 Err(last_error.unwrap_or_else(|| LlmError::Provider("no providers available".into())))
176 }
177
178 async fn route_round_robin<F, Fut>(&self, make_call: &F) -> Result<LlmResponse, LlmError>
179 where
180 F: Fn(Arc<dyn LlmPort>) -> Fut,
181 Fut: std::future::Future<Output = Result<LlmResponse, LlmError>>,
182 {
183 let healthy: Vec<&ProviderEntry> =
184 self.providers.iter().filter(|p| p.is_available()).collect();
185
186 if healthy.is_empty() {
187 return Err(LlmError::Provider("no healthy providers available".into()));
188 }
189
190 let index = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % healthy.len();
191
192 let mut last_error = None;
194 for offset in 0..healthy.len() {
195 let entry = healthy[(index + offset) % healthy.len()];
196
197 let result = if let Some(cb) = &entry.circuit_breaker {
198 cb.call(|| make_call(entry.adapter.clone())).await.map_err(|cb_err| match cb_err {
199 bob_core::resilience::CircuitBreakerError::CircuitOpen => {
200 LlmError::Provider(format!("{}: circuit open", entry.name))
201 }
202 bob_core::resilience::CircuitBreakerError::Inner(e) => e,
203 })
204 } else {
205 make_call(entry.adapter.clone()).await
206 };
207
208 match result {
209 Ok(resp) => return Ok(resp),
210 Err(err) => {
211 tracing::warn!(provider = %entry.name, error = %err, "provider failed in round-robin");
212 last_error = Some(err);
213 }
214 }
215 }
216
217 Err(last_error.unwrap_or_else(|| LlmError::Provider("all providers failed".into())))
218 }
219}
220
221#[async_trait]
224impl LlmPort for ProviderRouter {
225 fn capabilities(&self) -> LlmCapabilities {
226 let mut native_tool_calling = false;
228 let mut streaming = false;
229 for entry in &self.providers {
230 let caps = entry.adapter.capabilities();
231 native_tool_calling |= caps.native_tool_calling;
232 streaming |= caps.streaming;
233 }
234 LlmCapabilities { native_tool_calling, streaming }
235 }
236
237 async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
238 let req = Arc::new(req);
239 self.route_request(|adapter| {
240 let req = Arc::clone(&req);
241 async move { adapter.complete((*req).clone()).await }
242 })
243 .await
244 }
245
246 async fn complete_stream(&self, req: LlmRequest) -> Result<LlmStream, LlmError> {
247 for entry in &self.providers {
249 if !entry.is_available() {
250 continue;
251 }
252 match entry.adapter.complete_stream(req.clone()).await {
253 Ok(stream) => return Ok(stream),
254 Err(err) => {
255 tracing::warn!(provider = %entry.name, error = %err, "stream provider failed, trying next");
256 }
257 }
258 }
259 Err(LlmError::Provider("no provider available for streaming".into()))
260 }
261}
262
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Scripted `LlmPort` that replays a fixed queue of results, one per
    /// `complete` call, and errors once the queue is exhausted.
    struct MockLlm {
        _name: &'static str,
        responses: Mutex<Vec<Result<LlmResponse, LlmError>>>,
    }

    impl MockLlm {
        /// Mock whose first call returns a response with `content`; any
        /// further call fails with "no more mock responses".
        fn succeeds(name: &'static str, content: &'static str) -> Self {
            Self {
                _name: name,
                responses: Mutex::new(vec![Ok(LlmResponse {
                    content: content.into(),
                    usage: bob_core::types::TokenUsage::default(),
                    finish_reason: bob_core::types::FinishReason::Stop,
                    tool_calls: vec![],
                })]),
            }
        }

        /// Mock that never succeeds: the first call yields a simulated
        /// failure and later calls yield the exhausted-queue error.
        fn always_fails(name: &'static str) -> Self {
            Self {
                _name: name,
                responses: Mutex::new(vec![Err(LlmError::Provider(format!(
                    "{name}: simulated failure"
                )))]),
            }
        }
    }

    #[async_trait]
    impl LlmPort for MockLlm {
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities { native_tool_calling: false, streaming: false }
        }

        async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            // A poisoned lock only means another test thread panicked; the
            // queued data is still usable.
            let mut responses = match self.responses.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            if responses.is_empty() {
                return Err(LlmError::Provider("no more mock responses".into()));
            }
            responses.remove(0)
        }

        async fn complete_stream(&self, _req: LlmRequest) -> Result<LlmStream, LlmError> {
            Err(LlmError::Provider("streaming not supported in mock".into()))
        }
    }

    /// Minimal single-message request used by every test.
    fn test_request() -> LlmRequest {
        LlmRequest {
            model: "test-model".into(),
            messages: vec![bob_core::types::Message::text(bob_core::types::Role::User, "hello")],
            tools: vec![],
            output_schema: None,
        }
    }

    #[tokio::test]
    async fn priority_routes_to_first_available() {
        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(ProviderEntry::new(
                "primary",
                Arc::new(MockLlm::succeeds("primary", "ok")),
            ))
            .with_provider(ProviderEntry::new(
                "backup",
                Arc::new(MockLlm::succeeds("backup", "fallback")),
            ));

        let resp = router.complete(test_request()).await.expect("should succeed");
        assert_eq!(resp.content, "ok");
    }

    #[tokio::test]
    async fn priority_falls_back_on_failure() {
        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(ProviderEntry::new(
                "primary",
                Arc::new(MockLlm::always_fails("primary")),
            ))
            .with_provider(ProviderEntry::new(
                "backup",
                Arc::new(MockLlm::succeeds("backup", "fallback")),
            ));

        let resp = router.complete(test_request()).await.expect("should succeed via fallback");
        assert_eq!(resp.content, "fallback");
    }

    #[tokio::test]
    async fn priority_fails_when_all_providers_fail() {
        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(ProviderEntry::new("p1", Arc::new(MockLlm::always_fails("p1"))))
            .with_provider(ProviderEntry::new("p2", Arc::new(MockLlm::always_fails("p2"))));

        let result = router.complete(test_request()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn round_robin_distributes_requests() {
        let router = ProviderRouter::new(RoutingStrategy::RoundRobin)
            .with_provider(ProviderEntry::new("a", Arc::new(MockLlm::succeeds("a", "from-a"))))
            .with_provider(ProviderEntry::new("b", Arc::new(MockLlm::succeeds("b", "from-b"))));

        let first = router.complete(test_request()).await.expect("first call should succeed");
        let second = router.complete(test_request()).await.expect("second call should succeed");

        // Each mock can answer only once, so both assertions hold only if the
        // second request really started at the second provider.
        assert_eq!(first.content, "from-a");
        assert_eq!(second.content, "from-b");
    }

    #[tokio::test]
    async fn round_robin_falls_back_on_failure() {
        let router = ProviderRouter::new(RoutingStrategy::RoundRobin)
            .with_provider(ProviderEntry::new("a", Arc::new(MockLlm::always_fails("a"))))
            .with_provider(ProviderEntry::new("b", Arc::new(MockLlm::succeeds("b", "from-b"))));

        let resp = router.complete(test_request()).await.expect("should fall through to b");
        assert_eq!(resp.content, "from-b");
    }

    #[tokio::test]
    async fn round_robin_without_providers_returns_error() {
        let router = ProviderRouter::new(RoutingStrategy::RoundRobin);
        let result = router.complete(test_request()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn circuit_breaker_skips_open_provider() {
        let cb = Arc::new(CircuitBreaker::new(bob_core::resilience::CircuitBreakerConfig {
            failure_threshold: 1,
            success_threshold: 1,
            cooldown: std::time::Duration::from_secs(60),
        }));

        // One failure reaches the threshold of 1 and opens the breaker.
        let _ = cb.call(|| async { Err::<(), _>("fail") }).await;
        assert_eq!(cb.state(), CircuitState::Open);

        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(
                ProviderEntry::new("primary", Arc::new(MockLlm::succeeds("primary", "ok")))
                    .with_circuit_breaker(cb),
            )
            .with_provider(ProviderEntry::new(
                "backup",
                Arc::new(MockLlm::succeeds("backup", "fallback")),
            ));

        let resp = router.complete(test_request()).await.expect("should fall back to backup");
        assert_eq!(resp.content, "fallback");
    }

    #[tokio::test]
    async fn no_providers_returns_error() {
        let router = ProviderRouter::new(RoutingStrategy::Priority);
        let result = router.complete(test_request()).await;
        assert!(result.is_err());
    }
}