Skip to main content

devsper_providers/
router.rs

1use devsper_core::{LlmProvider, LlmRequest, LlmResponse};
2use anyhow::{anyhow, Result};
3use async_trait::async_trait;
4use std::sync::Arc;
5use tracing::debug;
6
7/// Routes LLM requests to the correct provider based on model prefix.
8/// claude-*      → Anthropic
9/// gpt-*, o1*, o3* → OpenAI
10/// ollama:*      → Ollama
11/// zai:*, glm-*  → ZAI
12/// mock*         → Mock
13pub struct ModelRouter {
14    providers: Vec<Arc<dyn LlmProvider>>,
15}
16
17impl ModelRouter {
18    pub fn new() -> Self {
19        Self { providers: vec![] }
20    }
21
22    pub fn with_provider(mut self, provider: Arc<dyn LlmProvider>) -> Self {
23        self.providers.push(provider);
24        self
25    }
26
27    pub fn add_provider(&mut self, provider: Arc<dyn LlmProvider>) {
28        self.providers.push(provider);
29    }
30
31    fn route(&self, model: &str) -> Option<&Arc<dyn LlmProvider>> {
32        self.providers.iter().find(|p| p.supports_model(model))
33    }
34}
35
36impl Default for ModelRouter {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42#[async_trait]
43impl LlmProvider for ModelRouter {
44    async fn generate(&self, req: LlmRequest) -> Result<LlmResponse> {
45        let provider = self
46            .route(&req.model)
47            .ok_or_else(|| anyhow!("No provider found for model: {}", req.model))?;
48        debug!(model = %req.model, provider = %provider.name(), "Routing request");
49        provider.generate(req).await
50    }
51
52    fn name(&self) -> &str {
53        "router"
54    }
55
56    fn supports_model(&self, model: &str) -> bool {
57        self.route(model).is_some()
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::MockProvider;
    use devsper_core::{LlmMessage, LlmRole};

    /// Builds a minimal single-user-message request targeting `model`.
    fn make_req(model: &str) -> LlmRequest {
        LlmRequest {
            model: model.to_string(),
            messages: vec![LlmMessage {
                role: LlmRole::User,
                content: "test".to_string(),
            }],
            tools: vec![],
            max_tokens: None,
            temperature: None,
            system: None,
        }
    }

    #[tokio::test]
    async fn routes_to_mock() {
        let router = ModelRouter::new().with_provider(Arc::new(MockProvider::new("mocked")));

        let res = router.generate(make_req("mock")).await.unwrap();
        assert_eq!(res.content, "mocked");
    }

    #[tokio::test]
    async fn first_registered_provider_wins() {
        // Both providers claim "mock"; routing must pick the earlier registration.
        let router = ModelRouter::new()
            .with_provider(Arc::new(MockProvider::new("first")))
            .with_provider(Arc::new(MockProvider::new("second")));

        let res = router.generate(make_req("mock")).await.unwrap();
        assert_eq!(res.content, "first");
    }

    #[tokio::test]
    async fn add_provider_registers_in_place() {
        let mut router = ModelRouter::new();
        router.add_provider(Arc::new(MockProvider::new("added")));

        let res = router.generate(make_req("mock")).await.unwrap();
        assert_eq!(res.content, "added");
    }

    #[tokio::test]
    async fn unknown_model_returns_error() {
        let router = ModelRouter::new();
        let err = router
            .generate(make_req("unknown-model"))
            .await
            .err()
            .expect("an empty router must not route any model");
        assert!(err.to_string().contains("unknown-model"));
    }

    // `supports_model` is synchronous, so no async runtime is needed.
    #[test]
    fn router_supports_model_delegates() {
        let router = ModelRouter::new().with_provider(Arc::new(MockProvider::new("")));
        assert!(router.supports_model("mock"));
        assert!(!router.supports_model("claude-opus-4-6"));
    }

    #[test]
    fn default_router_is_empty() {
        let router = ModelRouter::default();
        assert!(!router.supports_model("mock"));
        assert_eq!(router.name(), "router");
    }
}