Skip to main content

bitrouter_runtime/
router.rs

1use std::collections::HashMap;
2
3use bitrouter_config::{ApiProtocol, ProviderConfig};
4use bitrouter_core::{
5    errors::{BitrouterError, Result},
6    models::language::language_model::DynLanguageModel,
7    routers::{model_router::LanguageModelRouter, routing_table::RoutingTarget},
8};
9use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
10
11/// A model router backed by `reqwest` that instantiates concrete provider
12/// model objects on demand from [`ProviderConfig`] entries.
13pub struct Router {
14    client: reqwest::Client,
15    providers: HashMap<String, ProviderConfig>,
16}
17
18impl Router {
19    pub fn new(client: reqwest::Client, providers: HashMap<String, ProviderConfig>) -> Self {
20        Self { client, providers }
21    }
22
23    fn build_openai_config(&self, provider: &ProviderConfig) -> Result<OpenAiConfig> {
24        let api_key = provider.api_key.clone().unwrap_or_default();
25        let base_url = provider
26            .api_base
27            .clone()
28            .unwrap_or_else(|| "https://api.openai.com/v1".into());
29
30        let default_headers = parse_headers(provider.default_headers.as_ref())?;
31
32        Ok(OpenAiConfig {
33            api_key,
34            base_url,
35            organization: None,
36            project: None,
37            default_headers,
38        })
39    }
40
41    fn build_anthropic_config(&self, provider: &ProviderConfig) -> Result<AnthropicConfig> {
42        let api_key = provider.api_key.clone().unwrap_or_default();
43        let base_url = provider
44            .api_base
45            .clone()
46            .unwrap_or_else(|| "https://api.anthropic.com".into());
47
48        let default_headers = parse_headers(provider.default_headers.as_ref())?;
49
50        Ok(AnthropicConfig {
51            api_key,
52            base_url,
53            api_version: "2023-06-01".into(),
54            default_headers,
55        })
56    }
57
58    fn build_google_config(&self, provider: &ProviderConfig) -> Result<GoogleConfig> {
59        let api_key = provider.api_key.clone().unwrap_or_default();
60        let base_url = provider
61            .api_base
62            .clone()
63            .unwrap_or_else(|| "https://generativelanguage.googleapis.com".into());
64
65        let default_headers = parse_headers(provider.default_headers.as_ref())?;
66
67        Ok(GoogleConfig {
68            api_key,
69            base_url,
70            default_headers,
71        })
72    }
73}
74
75impl LanguageModelRouter for Router {
76    async fn route_model(&self, target: RoutingTarget) -> Result<Box<DynLanguageModel<'static>>> {
77        let provider = self.providers.get(&target.provider_name).ok_or_else(|| {
78            BitrouterError::invalid_request(
79                None,
80                format!("unknown provider: {}", target.provider_name),
81                None,
82            )
83        })?;
84
85        let protocol = provider.api_protocol.as_ref().ok_or_else(|| {
86            BitrouterError::invalid_request(
87                Some(&target.provider_name),
88                format!(
89                    "provider '{}' has no api_protocol configured",
90                    target.provider_name
91                ),
92                None,
93            )
94        })?;
95
96        match protocol {
97            ApiProtocol::Openai => {
98                let config = self.build_openai_config(provider)?;
99                let model = OpenAiChatCompletionsModel::with_client(
100                    target.model_id,
101                    self.client.clone(),
102                    config,
103                );
104                Ok(DynLanguageModel::new_box(model))
105            }
106            ApiProtocol::Anthropic => {
107                let config = self.build_anthropic_config(provider)?;
108                let model = AnthropicMessagesModel::with_client(
109                    target.model_id,
110                    self.client.clone(),
111                    config,
112                );
113                Ok(DynLanguageModel::new_box(model))
114            }
115            ApiProtocol::Google => {
116                let config = self.build_google_config(provider)?;
117                let model = GoogleGenerativeAiModel::with_client(
118                    target.model_id,
119                    self.client.clone(),
120                    config,
121                );
122                Ok(DynLanguageModel::new_box(model))
123            }
124        }
125    }
126}
127
128fn parse_headers(headers: Option<&HashMap<String, String>>) -> Result<HeaderMap> {
129    let mut map = HeaderMap::new();
130    if let Some(h) = headers {
131        for (k, v) in h {
132            let name = HeaderName::from_bytes(k.as_bytes()).map_err(|e| {
133                BitrouterError::invalid_request(
134                    None,
135                    format!("invalid header name '{k}': {e}"),
136                    None,
137                )
138            })?;
139            let value = HeaderValue::from_str(v).map_err(|e| {
140                BitrouterError::invalid_request(
141                    None,
142                    format!("invalid header value for '{k}': {e}"),
143                    None,
144                )
145            })?;
146            map.insert(name, value);
147        }
148    }
149    Ok(map)
150}
151
// Provider model and config types used by `Router`. These are private imports
// kept under their original names, not re-exports or aliases.
153use bitrouter_anthropic::messages::provider::{AnthropicConfig, AnthropicMessagesModel};
154use bitrouter_google::generate_content::provider::{GoogleConfig, GoogleGenerativeAiModel};
155use bitrouter_openai::chat::provider::{OpenAiChatCompletionsModel, OpenAiConfig};