Skip to main content

hyperinfer_client/
router.rs

1use hyperinfer_core::types::{Config, Provider};
2use tracing::warn;
3
4pub struct Router {
5    #[allow(dead_code)]
6    rules: Vec<hyperinfer_core::types::RoutingRule>,
7    model_aliases: std::collections::HashMap<String, (String, Option<Provider>)>,
8    default_provider: Option<Provider>,
9}
10
11impl Router {
12    pub fn new(rules: Vec<hyperinfer_core::types::RoutingRule>) -> Self {
13        Self {
14            rules,
15            model_aliases: std::collections::HashMap::new(),
16            default_provider: None,
17        }
18    }
19
20    pub fn with_aliases(mut self, aliases: std::collections::HashMap<String, String>) -> Self {
21        self.model_aliases = aliases
22            .into_iter()
23            .filter_map(|(alias, target)| match Self::parse_target_model(&target) {
24                Ok((model, provider)) => Some((alias, (model, provider))),
25                Err(err) => {
26                    warn!("Invalid alias '{}': {}", alias, err);
27                    None
28                }
29            })
30            .collect();
31        self
32    }
33
34    pub fn with_default_provider(mut self, provider: Option<Provider>) -> Self {
35        self.default_provider = provider;
36        self
37    }
38
39    fn parse_target_model(target: &str) -> Result<(String, Option<Provider>), String> {
40        if let Some(slash_pos) = target.find('/') {
41            let provider_str = &target[..slash_pos];
42            let model = target[slash_pos + 1..].to_string();
43            let provider = match provider_str.to_lowercase().as_str() {
44                "openai" => Some(Provider::OpenAI),
45                "anthropic" => Some(Provider::Anthropic),
46                unknown => return Err(format!("Unknown provider: '{}'", unknown)),
47            };
48            Ok((model, provider))
49        } else {
50            Ok((target.to_string(), None))
51        }
52    }
53
54    fn infer_provider(model: &str) -> Option<Provider> {
55        if model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") {
56            Some(Provider::OpenAI)
57        } else if model.starts_with("claude-") {
58            Some(Provider::Anthropic)
59        } else {
60            None
61        }
62    }
63
64    fn resolve_provider(&self, explicit: Option<Provider>, model: &str) -> Option<Provider> {
65        if let Some(provider) = explicit {
66            return Some(provider);
67        }
68        Self::infer_provider(model).or(self.default_provider.clone())
69    }
70
71    pub fn resolve(&self, model: &str, _config: &Config) -> Option<(String, Provider)> {
72        if let Some((target_model, explicit_provider)) = self.model_aliases.get(model) {
73            let provider = self.resolve_provider(explicit_provider.clone(), target_model)?;
74            return Some((target_model.clone(), provider));
75        }
76
77        let provider = self.resolve_provider(None, model)?;
78        Some((model.to_string(), provider))
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Config` with every collection empty and no default provider.
    fn create_test_config() -> Config {
        Config {
            api_keys: HashMap::new(),
            routing_rules: Vec::new(),
            quotas: HashMap::new(),
            model_aliases: HashMap::new(),
            default_provider: None,
        }
    }

    /// Turns `(alias, target)` string pairs into the owned map `with_aliases` takes.
    fn alias_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(alias, target)| (alias.to_string(), target.to_string()))
            .collect()
    }

    /// Shorthand for the `Some((model, provider))` value expected from `resolve`.
    fn resolved(model: &str, provider: Provider) -> Option<(String, Provider)> {
        Some((model.to_string(), provider))
    }

    #[test]
    fn test_router_new() {
        let router = Router::new(Vec::new());
        assert!(router.model_aliases.is_empty());
        assert_eq!(router.default_provider, None);
    }

    #[test]
    fn test_router_with_default_provider() {
        let router = Router::new(Vec::new()).with_default_provider(Some(Provider::OpenAI));
        assert_eq!(router.default_provider, Some(Provider::OpenAI));
    }

    #[test]
    fn test_parse_target_model_with_provider() {
        let cases = [
            ("openai/gpt-4", "gpt-4", Provider::OpenAI),
            ("anthropic/claude-3", "claude-3", Provider::Anthropic),
        ];

        for (target, expected_model, expected_provider) in cases {
            let (model, provider) = Router::parse_target_model(target).unwrap();
            assert_eq!(model, expected_model);
            assert_eq!(provider, Some(expected_provider));
        }
    }

    #[test]
    fn test_parse_target_model_without_provider() {
        let (model, provider) = Router::parse_target_model("gpt-4").unwrap();
        assert_eq!(model, "gpt-4");
        assert_eq!(provider, None);
    }

    #[test]
    fn test_parse_target_model_unknown_provider() {
        let message = Router::parse_target_model("unknown/model")
            .expect_err("an unrecognised provider prefix must be rejected");
        assert!(message.contains("Unknown provider"));
    }

    #[test]
    fn test_infer_provider_gpt() {
        for name in ["gpt-4", "gpt-3.5-turbo"] {
            assert_eq!(Router::infer_provider(name), Some(Provider::OpenAI));
        }
    }

    #[test]
    fn test_infer_provider_o1() {
        for name in ["o1-preview", "o1-mini"] {
            assert_eq!(Router::infer_provider(name), Some(Provider::OpenAI));
        }
    }

    #[test]
    fn test_infer_provider_o3() {
        assert_eq!(Router::infer_provider("o3-mini"), Some(Provider::OpenAI));
    }

    #[test]
    fn test_infer_provider_claude() {
        for name in ["claude-3-opus", "claude-2"] {
            assert_eq!(Router::infer_provider(name), Some(Provider::Anthropic));
        }
    }

    #[test]
    fn test_infer_provider_unknown() {
        for name in ["unknown-model", "llama-2"] {
            assert_eq!(Router::infer_provider(name), None);
        }
    }

    #[test]
    fn test_with_aliases_valid() {
        let router = Router::new(Vec::new()).with_aliases(alias_map(&[
            ("my-gpt", "openai/gpt-4"),
            ("my-claude", "anthropic/claude-3"),
        ]));
        assert_eq!(router.model_aliases.len(), 2);
    }

    #[test]
    fn test_with_aliases_invalid_skipped() {
        let router = Router::new(Vec::new()).with_aliases(alias_map(&[
            ("valid", "openai/gpt-4"),
            ("invalid", "unknown/model"),
        ]));

        assert_eq!(router.model_aliases.len(), 1);
        assert!(router.model_aliases.contains_key("valid"));
        assert!(!router.model_aliases.contains_key("invalid"));
    }

    #[test]
    fn test_resolve_with_alias() {
        let router =
            Router::new(Vec::new()).with_aliases(alias_map(&[("my-model", "openai/gpt-4")]));
        let config = create_test_config();

        assert_eq!(
            router.resolve("my-model", &config),
            resolved("gpt-4", Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_with_inference() {
        let router = Router::new(Vec::new());
        let config = create_test_config();

        assert_eq!(
            router.resolve("gpt-4", &config),
            resolved("gpt-4", Provider::OpenAI)
        );
        assert_eq!(
            router.resolve("claude-3", &config),
            resolved("claude-3", Provider::Anthropic)
        );
    }

    #[test]
    fn test_resolve_with_default_provider() {
        let router = Router::new(Vec::new()).with_default_provider(Some(Provider::OpenAI));
        let config = create_test_config();

        assert_eq!(
            router.resolve("unknown-model", &config),
            resolved("unknown-model", Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_no_match() {
        let router = Router::new(Vec::new());
        let config = create_test_config();

        assert_eq!(router.resolve("unknown-model", &config), None);
    }

    #[test]
    fn test_resolve_alias_without_explicit_provider() {
        let router = Router::new(Vec::new()).with_aliases(alias_map(&[("my-gpt", "gpt-4")]));
        let config = create_test_config();

        assert_eq!(
            router.resolve("my-gpt", &config),
            resolved("gpt-4", Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_alias_with_default_provider() {
        let router = Router::new(Vec::new())
            .with_aliases(alias_map(&[("my-model", "custom-model")]))
            .with_default_provider(Some(Provider::Anthropic));
        let config = create_test_config();

        assert_eq!(
            router.resolve("my-model", &config),
            resolved("custom-model", Provider::Anthropic)
        );
    }

    #[test]
    fn test_resolve_priority_explicit_over_inference() {
        // Map a gpt-like name to anthropic explicitly
        let router = Router::new(Vec::new())
            .with_aliases(alias_map(&[("gpt-custom", "anthropic/claude-3")]));
        let config = create_test_config();

        assert_eq!(
            router.resolve("gpt-custom", &config),
            resolved("claude-3", Provider::Anthropic)
        );
    }
}
296}