devsper_providers/
router.rs1use devsper_core::{LlmProvider, LlmRequest, LlmResponse};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use std::sync::Arc;
5use tracing::debug;
6
7pub struct ModelRouter {
14 providers: Vec<Arc<dyn LlmProvider>>,
15}
16
17impl ModelRouter {
18 pub fn new() -> Self {
19 Self { providers: vec![] }
20 }
21
22 pub fn with_provider(mut self, provider: Arc<dyn LlmProvider>) -> Self {
23 self.providers.push(provider);
24 self
25 }
26
27 pub fn add_provider(&mut self, provider: Arc<dyn LlmProvider>) {
28 self.providers.push(provider);
29 }
30
31 fn route(&self, model: &str) -> Option<&Arc<dyn LlmProvider>> {
32 self.providers.iter().find(|p| p.supports_model(model))
33 }
34}
35
36impl Default for ModelRouter {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42#[async_trait]
43impl LlmProvider for ModelRouter {
44 async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
45 let provider = self
46 .route(&req.model)
47 .ok_or_else(|| anyhow!("No provider found for model: {}", req.model))?;
48 debug!(model = %req.model, provider = %provider.name(), "Routing request");
49 provider.generate(req).await
50 }
51
52 fn name(&self) -> &str {
53 "router"
54 }
55
56 fn supports_model(&self, model: &str) -> bool {
57 self.route(model).is_some()
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::MockProvider;
    use devsper_core::{LlmMessage, LlmRole};

    /// Builds a minimal request for `model`: one user message, no tools, and
    /// every optional setting left unset.
    fn make_req(model: &str) -> LlmRequest {
        LlmRequest {
            model: model.to_string(),
            messages: vec![LlmMessage {
                role: LlmRole::User,
                content: "test".to_string(),
            }],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            system: None,
        }
    }

    #[tokio::test]
    async fn routes_to_mock() {
        let router = ModelRouter::new().with_provider(Arc::new(MockProvider::new("mocked")));

        let res = router.generate(make_req("mock")).await.unwrap();
        assert_eq!(res.content, "mocked");
    }

    #[tokio::test]
    async fn unknown_model_returns_error() {
        let router = ModelRouter::new();
        // `.err().expect(..)` avoids requiring `LlmResponse: Debug`, which
        // `unwrap_err` would need.
        let err = router
            .generate(make_req("unknown-model"))
            .await
            .err()
            .expect("a router with no providers must reject every model");

        // Pin the message so an unrelated failure cannot satisfy this test.
        let message = err.to_string();
        assert!(
            message.contains("No provider found for model"),
            "unexpected error message: {message}"
        );
        assert!(
            message.contains("unknown-model"),
            "error should name the requested model: {message}"
        );
    }

    // Purely synchronous: nothing is awaited, so no async runtime is needed.
    #[test]
    fn router_supports_model_delegates() {
        let router = ModelRouter::new().with_provider(Arc::new(MockProvider::new("")));
        assert!(router.supports_model("mock"));
        assert!(!router.supports_model("claude-opus-4-6"));
    }

    #[test]
    fn empty_router_supports_nothing() {
        let router = ModelRouter::default();
        assert!(!router.supports_model("mock"));
    }

    #[test]
    fn add_provider_registers_on_existing_router() {
        let mut router = ModelRouter::new();
        assert!(!router.supports_model("mock"));

        router.add_provider(Arc::new(MockProvider::new("")));
        assert!(router.supports_model("mock"));
    }

    #[tokio::test]
    async fn first_registered_provider_wins() {
        // Both mocks accept the "mock" model; the router must pick the one
        // that was registered first.
        let router = ModelRouter::new()
            .with_provider(Arc::new(MockProvider::new("first")))
            .with_provider(Arc::new(MockProvider::new("second")));

        let res = router.generate(make_req("mock")).await.unwrap();
        assert_eq!(res.content, "first");
    }

    #[test]
    fn router_reports_its_name() {
        assert_eq!(ModelRouter::new().name(), "router");
    }
}