Skip to main content

bob_adapters/
provider_router.rs

1//! # Provider Router
2//!
3//! Routes LLM requests across multiple providers with fallback,
4//! priority ordering, and circuit-breaker-aware health checks.
5//!
6//! ## Strategies
7//!
8//! - [`RoutingStrategy::Priority`]: Try providers in order; fall back to the next on failure.
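//! - [`RoutingStrategy::RoundRobin`]: Distribute requests across healthy providers in round-robin
//!   order.
//!
//! Streaming requests (`complete_stream`) always try providers in registration (priority) order,
//! regardless of the configured strategy.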
11//!
12//! ## Example
13//!
14//! ```rust,ignore
15//! use bob_adapters::provider_router::{ProviderRouter, RoutingStrategy, ProviderEntry};
16//! use std::sync::Arc;
17//!
18//! let router = ProviderRouter::new(RoutingStrategy::Priority)
19//!     .with_provider(ProviderEntry::new("openai", Arc::new(openai_adapter)))
20//!     .with_provider(ProviderEntry::new("anthropic", Arc::new(anthropic_adapter)));
21//!
22//! let response = router.complete(request).await?;
23//! ```
24
25use std::sync::{
26    Arc,
27    atomic::{AtomicUsize, Ordering},
28};
29
30use async_trait::async_trait;
31use bob_core::{
32    error::LlmError,
33    ports::LlmPort,
34    resilience::{CircuitBreaker, CircuitState},
35    types::{LlmCapabilities, LlmRequest, LlmResponse, LlmStream},
36};
37
38// ── Routing Strategy ─────────────────────────────────────────────────
39
/// How the router selects which provider to use for a non-streaming request.
///
/// This is a plain fieldless enum, so it is `Copy` and can be compared or hashed
/// (e.g. to branch on or log the active strategy).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RoutingStrategy {
    /// Try each provider in the order they were added; stop on the first success.
    Priority,
    /// Start at the next provider in rotation among the currently healthy ones,
    /// then fall back to the remaining healthy providers on failure.
    RoundRobin,
}
48
49// ── Provider Entry ───────────────────────────────────────────────────
50
51/// A named LLM provider with an optional circuit breaker.
52pub struct ProviderEntry {
53    /// Human-readable name (e.g. `"openai"`, `"anthropic"`).
54    pub name: String,
55    /// The underlying LLM adapter.
56    pub adapter: Arc<dyn LlmPort>,
57    /// Optional circuit breaker for this provider.
58    pub circuit_breaker: Option<Arc<CircuitBreaker>>,
59}
60
61impl std::fmt::Debug for ProviderEntry {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("ProviderEntry")
64            .field("name", &self.name)
65            .field("has_circuit_breaker", &self.circuit_breaker.is_some())
66            .finish_non_exhaustive()
67    }
68}
69
70impl ProviderEntry {
71    /// Create a new provider entry without a circuit breaker.
72    #[must_use]
73    pub fn new(name: impl Into<String>, adapter: Arc<dyn LlmPort>) -> Self {
74        Self { name: name.into(), adapter, circuit_breaker: None }
75    }
76
77    /// Attach a circuit breaker to this provider entry.
78    #[must_use]
79    pub fn with_circuit_breaker(mut self, cb: Arc<CircuitBreaker>) -> Self {
80        self.circuit_breaker = Some(cb);
81        self
82    }
83
84    /// Returns `true` if the circuit breaker allows a call.
85    fn is_available(&self) -> bool {
86        match &self.circuit_breaker {
87            Some(cb) => cb.state() != CircuitState::Open,
88            None => true,
89        }
90    }
91}
92
93// ── Provider Router ──────────────────────────────────────────────────
94
95/// Routes LLM requests across multiple providers.
96pub struct ProviderRouter {
97    strategy: RoutingStrategy,
98    providers: Vec<ProviderEntry>,
99    round_robin_index: AtomicUsize,
100}
101
102impl std::fmt::Debug for ProviderRouter {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("ProviderRouter")
105            .field("strategy", &self.strategy)
106            .field("provider_count", &self.providers.len())
107            .finish_non_exhaustive()
108    }
109}
110
111impl ProviderRouter {
112    /// Create a new router with the given strategy.
113    #[must_use]
114    pub fn new(strategy: RoutingStrategy) -> Self {
115        Self { strategy, providers: Vec::new(), round_robin_index: AtomicUsize::new(0) }
116    }
117
118    /// Add a provider to the router.
119    #[must_use]
120    pub fn with_provider(mut self, entry: ProviderEntry) -> Self {
121        self.providers.push(entry);
122        self
123    }
124
125    /// Returns the number of registered providers.
126    #[must_use]
127    pub fn provider_count(&self) -> usize {
128        self.providers.len()
129    }
130
131    /// Execute a request using the configured routing strategy.
132    async fn route_request<F, Fut>(&self, make_call: F) -> Result<LlmResponse, LlmError>
133    where
134        F: Fn(Arc<dyn LlmPort>) -> Fut,
135        Fut: std::future::Future<Output = Result<LlmResponse, LlmError>>,
136    {
137        match &self.strategy {
138            RoutingStrategy::Priority => self.route_priority(&make_call).await,
139            RoutingStrategy::RoundRobin => self.route_round_robin(&make_call).await,
140        }
141    }
142
143    async fn route_priority<F, Fut>(&self, make_call: &F) -> Result<LlmResponse, LlmError>
144    where
145        F: Fn(Arc<dyn LlmPort>) -> Fut,
146        Fut: std::future::Future<Output = Result<LlmResponse, LlmError>>,
147    {
148        let mut last_error = None;
149
150        for entry in &self.providers {
151            if !entry.is_available() {
152                continue;
153            }
154
155            let result = if let Some(cb) = &entry.circuit_breaker {
156                cb.call(|| make_call(entry.adapter.clone())).await.map_err(|cb_err| match cb_err {
157                    bob_core::resilience::CircuitBreakerError::CircuitOpen => {
158                        LlmError::Provider(format!("{}: circuit open", entry.name))
159                    }
160                    bob_core::resilience::CircuitBreakerError::Inner(e) => e,
161                })
162            } else {
163                make_call(entry.adapter.clone()).await
164            };
165
166            match result {
167                Ok(resp) => return Ok(resp),
168                Err(err) => {
169                    tracing::warn!(provider = %entry.name, error = %err, "provider failed, trying next");
170                    last_error = Some(err);
171                }
172            }
173        }
174
175        Err(last_error.unwrap_or_else(|| LlmError::Provider("no providers available".into())))
176    }
177
178    async fn route_round_robin<F, Fut>(&self, make_call: &F) -> Result<LlmResponse, LlmError>
179    where
180        F: Fn(Arc<dyn LlmPort>) -> Fut,
181        Fut: std::future::Future<Output = Result<LlmResponse, LlmError>>,
182    {
183        let healthy: Vec<&ProviderEntry> =
184            self.providers.iter().filter(|p| p.is_available()).collect();
185
186        if healthy.is_empty() {
187            return Err(LlmError::Provider("no healthy providers available".into()));
188        }
189
190        let index = self.round_robin_index.fetch_add(1, Ordering::Relaxed) % healthy.len();
191
192        // Try starting from the round-robin index, then wrap around.
193        let mut last_error = None;
194        for offset in 0..healthy.len() {
195            let entry = healthy[(index + offset) % healthy.len()];
196
197            let result = if let Some(cb) = &entry.circuit_breaker {
198                cb.call(|| make_call(entry.adapter.clone())).await.map_err(|cb_err| match cb_err {
199                    bob_core::resilience::CircuitBreakerError::CircuitOpen => {
200                        LlmError::Provider(format!("{}: circuit open", entry.name))
201                    }
202                    bob_core::resilience::CircuitBreakerError::Inner(e) => e,
203                })
204            } else {
205                make_call(entry.adapter.clone()).await
206            };
207
208            match result {
209                Ok(resp) => return Ok(resp),
210                Err(err) => {
211                    tracing::warn!(provider = %entry.name, error = %err, "provider failed in round-robin");
212                    last_error = Some(err);
213                }
214            }
215        }
216
217        Err(last_error.unwrap_or_else(|| LlmError::Provider("all providers failed".into())))
218    }
219}
220
221// ── LlmPort Implementation ───────────────────────────────────────────
222
223#[async_trait]
224impl LlmPort for ProviderRouter {
225    fn capabilities(&self) -> LlmCapabilities {
226        // Union of all provider capabilities.
227        let mut native_tool_calling = false;
228        let mut streaming = false;
229        for entry in &self.providers {
230            let caps = entry.adapter.capabilities();
231            native_tool_calling |= caps.native_tool_calling;
232            streaming |= caps.streaming;
233        }
234        LlmCapabilities { native_tool_calling, streaming }
235    }
236
237    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
238        let req = Arc::new(req);
239        self.route_request(|adapter| {
240            let req = Arc::clone(&req);
241            async move { adapter.complete((*req).clone()).await }
242        })
243        .await
244    }
245
246    async fn complete_stream(&self, req: LlmRequest) -> Result<LlmStream, LlmError> {
247        // For streaming, try each provider in priority order.
248        for entry in &self.providers {
249            if !entry.is_available() {
250                continue;
251            }
252            match entry.adapter.complete_stream(req.clone()).await {
253                Ok(stream) => return Ok(stream),
254                Err(err) => {
255                    tracing::warn!(provider = %entry.name, error = %err, "stream provider failed, trying next");
256                }
257            }
258        }
259        Err(LlmError::Provider("no provider available for streaming".into()))
260    }
261}
262
263// ── Tests ────────────────────────────────────────────────────────────
264
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Scripted [`LlmPort`] double: each `complete` call consumes the next queued result.
    struct MockLlm {
        _name: &'static str,
        responses: Mutex<Vec<Result<LlmResponse, LlmError>>>,
    }

    impl MockLlm {
        /// A mock that answers `content` exactly once, then errors.
        fn succeeds(name: &'static str, content: &'static str) -> Self {
            Self::succeeds_times(name, content, 1)
        }

        /// A mock that answers `content` `times` times, then errors.
        fn succeeds_times(name: &'static str, content: &'static str, times: usize) -> Self {
            let responses: Vec<Result<LlmResponse, LlmError>> = (0..times)
                .map(|_| {
                    Ok(LlmResponse {
                        content: content.into(),
                        usage: bob_core::types::TokenUsage::default(),
                        finish_reason: bob_core::types::FinishReason::Stop,
                        tool_calls: vec![],
                    })
                })
                .collect();
            Self { _name: name, responses: Mutex::new(responses) }
        }

        /// A mock whose single scripted result is a provider error.
        fn always_fails(name: &'static str) -> Self {
            Self {
                _name: name,
                responses: Mutex::new(vec![Err(LlmError::Provider(format!(
                    "{name}: simulated failure"
                )))]),
            }
        }
    }

    #[async_trait]
    impl LlmPort for MockLlm {
        fn capabilities(&self) -> LlmCapabilities {
            LlmCapabilities { native_tool_calling: false, streaming: false }
        }

        async fn complete(&self, _req: LlmRequest) -> Result<LlmResponse, LlmError> {
            let mut responses = match self.responses.lock() {
                Ok(guard) => guard,
                Err(poisoned) => poisoned.into_inner(),
            };
            if responses.is_empty() {
                return Err(LlmError::Provider("no more mock responses".into()));
            }
            responses.remove(0)
        }

        async fn complete_stream(&self, _req: LlmRequest) -> Result<LlmStream, LlmError> {
            Err(LlmError::Provider("streaming not supported in mock".into()))
        }
    }

    fn test_request() -> LlmRequest {
        LlmRequest {
            model: "test-model".into(),
            messages: vec![bob_core::types::Message::text(bob_core::types::Role::User, "hello")],
            tools: vec![],
            output_schema: None,
        }
    }

    #[tokio::test]
    async fn priority_routes_to_first_available() {
        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(ProviderEntry::new(
                "primary",
                Arc::new(MockLlm::succeeds("primary", "ok")),
            ))
            .with_provider(ProviderEntry::new(
                "backup",
                Arc::new(MockLlm::succeeds("backup", "fallback")),
            ));

        let resp = router.complete(test_request()).await.expect("should succeed");
        assert_eq!(resp.content, "ok");
    }

    #[tokio::test]
    async fn priority_falls_back_on_failure() {
        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(ProviderEntry::new(
                "primary",
                Arc::new(MockLlm::always_fails("primary")),
            ))
            .with_provider(ProviderEntry::new(
                "backup",
                Arc::new(MockLlm::succeeds("backup", "fallback")),
            ));

        let resp = router.complete(test_request()).await.expect("should succeed via fallback");
        assert_eq!(resp.content, "fallback");
    }

    #[tokio::test]
    async fn priority_fails_when_all_providers_fail() {
        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(ProviderEntry::new("p1", Arc::new(MockLlm::always_fails("p1"))))
            .with_provider(ProviderEntry::new("p2", Arc::new(MockLlm::always_fails("p2"))));

        // The surfaced error must be the one from the last provider tried.
        match router.complete(test_request()).await {
            Err(LlmError::Provider(msg)) => assert_eq!(msg, "p2: simulated failure"),
            Err(other) => panic!("expected a provider error, got {other}"),
            Ok(_) => panic!("expected every provider to fail"),
        }
    }

    #[tokio::test]
    async fn round_robin_distributes_requests() {
        let router = ProviderRouter::new(RoutingStrategy::RoundRobin)
            .with_provider(ProviderEntry::new(
                "a",
                Arc::new(MockLlm::succeeds_times("a", "from-a", 2)),
            ))
            .with_provider(ProviderEntry::new(
                "b",
                Arc::new(MockLlm::succeeds_times("b", "from-b", 2)),
            ));

        // Each provider can answer twice. Strict alternation only happens if the
        // starting index rotates; sticking to one provider would yield a, a, b, b.
        let mut contents = Vec::new();
        for _ in 0..4 {
            let resp = router.complete(test_request()).await.expect("call should succeed");
            contents.push(resp.content);
        }
        assert_eq!(contents, ["from-a", "from-b", "from-a", "from-b"]);
    }

    #[tokio::test]
    async fn round_robin_falls_back_on_failure() {
        let router = ProviderRouter::new(RoutingStrategy::RoundRobin)
            .with_provider(ProviderEntry::new("a", Arc::new(MockLlm::always_fails("a"))))
            .with_provider(ProviderEntry::new("b", Arc::new(MockLlm::succeeds("b", "from-b"))));

        // The first request starts at "a", which fails, so "b" must serve it.
        let resp = router.complete(test_request()).await.expect("should fall back to b");
        assert_eq!(resp.content, "from-b");
    }

    #[tokio::test]
    async fn round_robin_no_providers_returns_error() {
        let router = ProviderRouter::new(RoutingStrategy::RoundRobin);
        let result = router.complete(test_request()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn circuit_breaker_skips_open_provider() {
        let cb = Arc::new(CircuitBreaker::new(bob_core::resilience::CircuitBreakerConfig {
            failure_threshold: 1,
            success_threshold: 1,
            cooldown: std::time::Duration::from_secs(60),
        }));

        // Trip the circuit breaker by calling it directly.
        let _ = cb.call(|| async { Err::<(), _>("fail") }).await;
        assert_eq!(cb.state(), CircuitState::Open);

        let router = ProviderRouter::new(RoutingStrategy::Priority)
            .with_provider(
                ProviderEntry::new("primary", Arc::new(MockLlm::succeeds("primary", "ok")))
                    .with_circuit_breaker(cb),
            )
            .with_provider(ProviderEntry::new(
                "backup",
                Arc::new(MockLlm::succeeds("backup", "fallback")),
            ));

        let resp = router.complete(test_request()).await.expect("should fall back to backup");
        assert_eq!(resp.content, "fallback");
    }

    #[tokio::test]
    async fn no_providers_returns_error() {
        let router = ProviderRouter::new(RoutingStrategy::Priority);
        let result = router.complete(test_request()).await;
        assert!(result.is_err());
    }
}