1use hyperinfer_core::types::{Config, Provider};
2use tracing::warn;
3
/// Maps a requested model name to a concrete `(model, provider)` pair.
///
/// `Router::resolve` consults, in order: the alias table, provider inference
/// from the model-name prefix, and the optional default provider.
pub struct Router {
    /// Routing rules supplied to `Router::new`.
    // The field is stored by `Router::new` but never read in this file, so
    // the allow suppresses the dead-code warning while keeping it in place.
    // NOTE(review): presumably reserved for rule-based routing — confirm.
    #[allow(dead_code)]
    rules: Vec<hyperinfer_core::types::RoutingRule>,
    /// Alias name -> (target model, provider named explicitly in the target, if any).
    model_aliases: std::collections::HashMap<String, (String, Option<Provider>)>,
    /// Fallback provider used when neither an alias nor the model name determines one.
    default_provider: Option<Provider>,
}
10
11impl Router {
12 pub fn new(rules: Vec<hyperinfer_core::types::RoutingRule>) -> Self {
13 Self {
14 rules,
15 model_aliases: std::collections::HashMap::new(),
16 default_provider: None,
17 }
18 }
19
20 pub fn with_aliases(mut self, aliases: std::collections::HashMap<String, String>) -> Self {
21 self.model_aliases = aliases
22 .into_iter()
23 .filter_map(|(alias, target)| match Self::parse_target_model(&target) {
24 Ok((model, provider)) => Some((alias, (model, provider))),
25 Err(err) => {
26 warn!("Invalid alias '{}': {}", alias, err);
27 None
28 }
29 })
30 .collect();
31 self
32 }
33
34 pub fn with_default_provider(mut self, provider: Option<Provider>) -> Self {
35 self.default_provider = provider;
36 self
37 }
38
39 fn parse_target_model(target: &str) -> Result<(String, Option<Provider>), String> {
40 if let Some(slash_pos) = target.find('/') {
41 let provider_str = &target[..slash_pos];
42 let model = target[slash_pos + 1..].to_string();
43 let provider = match provider_str.to_lowercase().as_str() {
44 "openai" => Some(Provider::OpenAI),
45 "anthropic" => Some(Provider::Anthropic),
46 unknown => return Err(format!("Unknown provider: '{}'", unknown)),
47 };
48 Ok((model, provider))
49 } else {
50 Ok((target.to_string(), None))
51 }
52 }
53
54 fn infer_provider(model: &str) -> Option<Provider> {
55 if model.starts_with("gpt-") || model.starts_with("o1-") || model.starts_with("o3-") {
56 Some(Provider::OpenAI)
57 } else if model.starts_with("claude-") {
58 Some(Provider::Anthropic)
59 } else {
60 None
61 }
62 }
63
64 fn resolve_provider(&self, explicit: Option<Provider>, model: &str) -> Option<Provider> {
65 if let Some(provider) = explicit {
66 return Some(provider);
67 }
68 Self::infer_provider(model).or(self.default_provider.clone())
69 }
70
71 pub fn resolve(&self, model: &str, _config: &Config) -> Option<(String, Provider)> {
72 if let Some((target_model, explicit_provider)) = self.model_aliases.get(model) {
73 let provider = self.resolve_provider(explicit_provider.clone(), target_model)?;
74 return Some((target_model.clone(), provider));
75 }
76
77 let provider = self.resolve_provider(None, model)?;
78 Some((model.to_string(), provider))
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an empty `Config`. `Router::resolve` takes one but does not read it.
    fn create_test_config() -> Config {
        Config {
            api_keys: HashMap::new(),
            routing_rules: vec![],
            quotas: HashMap::new(),
            model_aliases: HashMap::new(),
            default_provider: None,
        }
    }

    /// Builds a router whose alias table is populated from `(alias, target)` pairs.
    fn router_with_aliases(pairs: &[(&str, &str)]) -> Router {
        let aliases = pairs
            .iter()
            .map(|&(alias, target)| (alias.to_string(), target.to_string()))
            .collect();
        Router::new(vec![]).with_aliases(aliases)
    }

    /// Resolves `model` against an empty config, panicking with the model name on `None`.
    fn resolve_ok(router: &Router, model: &str) -> (String, Provider) {
        router
            .resolve(model, &create_test_config())
            .unwrap_or_else(|| panic!("expected '{}' to resolve", model))
    }

    #[test]
    fn test_router_new() {
        let router = Router::new(vec![]);
        assert!(router.model_aliases.is_empty());
        assert_eq!(router.default_provider, None);
    }

    #[test]
    fn test_router_with_default_provider() {
        let router = Router::new(vec![]).with_default_provider(Some(Provider::OpenAI));
        assert_eq!(router.default_provider, Some(Provider::OpenAI));
    }

    #[test]
    fn test_parse_target_model_with_provider() {
        assert_eq!(
            Router::parse_target_model("openai/gpt-4"),
            Ok(("gpt-4".to_string(), Some(Provider::OpenAI)))
        );
        assert_eq!(
            Router::parse_target_model("anthropic/claude-3"),
            Ok(("claude-3".to_string(), Some(Provider::Anthropic)))
        );
    }

    #[test]
    fn test_parse_target_model_provider_case_insensitive() {
        assert_eq!(
            Router::parse_target_model("OpenAI/gpt-4"),
            Ok(("gpt-4".to_string(), Some(Provider::OpenAI)))
        );
        assert_eq!(
            Router::parse_target_model("ANTHROPIC/claude-3"),
            Ok(("claude-3".to_string(), Some(Provider::Anthropic)))
        );
    }

    #[test]
    fn test_parse_target_model_splits_on_first_slash() {
        assert_eq!(
            Router::parse_target_model("openai/ft/custom"),
            Ok(("ft/custom".to_string(), Some(Provider::OpenAI)))
        );
    }

    #[test]
    fn test_parse_target_model_without_provider() {
        assert_eq!(
            Router::parse_target_model("gpt-4"),
            Ok(("gpt-4".to_string(), None))
        );
    }

    #[test]
    fn test_parse_target_model_unknown_provider() {
        let err = Router::parse_target_model("unknown/model").unwrap_err();
        assert!(err.contains("Unknown provider"), "unexpected error: {}", err);
    }

    #[test]
    fn test_infer_provider_gpt() {
        for model in ["gpt-4", "gpt-3.5-turbo"].iter() {
            assert_eq!(Router::infer_provider(model), Some(Provider::OpenAI), "{}", model);
        }
    }

    #[test]
    fn test_infer_provider_o1() {
        for model in ["o1-preview", "o1-mini"].iter() {
            assert_eq!(Router::infer_provider(model), Some(Provider::OpenAI), "{}", model);
        }
    }

    #[test]
    fn test_infer_provider_o3() {
        assert_eq!(Router::infer_provider("o3-mini"), Some(Provider::OpenAI));
    }

    #[test]
    fn test_infer_provider_claude() {
        for model in ["claude-3-opus", "claude-2"].iter() {
            assert_eq!(Router::infer_provider(model), Some(Provider::Anthropic), "{}", model);
        }
    }

    #[test]
    fn test_infer_provider_unknown() {
        for model in ["unknown-model", "llama-2"].iter() {
            assert_eq!(Router::infer_provider(model), None, "{}", model);
        }
    }

    #[test]
    fn test_with_aliases_valid() {
        let router = router_with_aliases(&[
            ("my-gpt", "openai/gpt-4"),
            ("my-claude", "anthropic/claude-3"),
        ]);
        assert_eq!(router.model_aliases.len(), 2);
    }

    #[test]
    fn test_with_aliases_invalid_skipped() {
        let router = router_with_aliases(&[("valid", "openai/gpt-4"), ("invalid", "unknown/model")]);
        assert_eq!(router.model_aliases.len(), 1);
        assert!(router.model_aliases.contains_key("valid"));
        assert!(!router.model_aliases.contains_key("invalid"));
    }

    #[test]
    fn test_with_aliases_replaces_previous_table() {
        let mut first = HashMap::new();
        first.insert("old".to_string(), "gpt-4".to_string());
        let mut second = HashMap::new();
        second.insert("new".to_string(), "claude-3".to_string());

        let router = Router::new(vec![]).with_aliases(first).with_aliases(second);
        assert!(!router.model_aliases.contains_key("old"));
        assert!(router.model_aliases.contains_key("new"));
    }

    #[test]
    fn test_resolve_with_alias() {
        let router = router_with_aliases(&[("my-model", "openai/gpt-4")]);
        assert_eq!(
            resolve_ok(&router, "my-model"),
            ("gpt-4".to_string(), Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_with_inference() {
        let router = Router::new(vec![]);
        assert_eq!(
            resolve_ok(&router, "gpt-4"),
            ("gpt-4".to_string(), Provider::OpenAI)
        );
        assert_eq!(
            resolve_ok(&router, "claude-3"),
            ("claude-3".to_string(), Provider::Anthropic)
        );
    }

    #[test]
    fn test_resolve_with_default_provider() {
        let router = Router::new(vec![]).with_default_provider(Some(Provider::OpenAI));
        assert_eq!(
            resolve_ok(&router, "unknown-model"),
            ("unknown-model".to_string(), Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_inference_over_default() {
        let router = Router::new(vec![]).with_default_provider(Some(Provider::Anthropic));
        assert_eq!(
            resolve_ok(&router, "gpt-4"),
            ("gpt-4".to_string(), Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_no_match() {
        let router = Router::new(vec![]);
        assert_eq!(router.resolve("unknown-model", &create_test_config()), None);
    }

    #[test]
    fn test_resolve_alias_without_any_provider_is_none() {
        let router = router_with_aliases(&[("my-model", "custom-model")]);
        assert_eq!(router.resolve("my-model", &create_test_config()), None);
    }

    #[test]
    fn test_resolve_alias_without_explicit_provider() {
        let router = router_with_aliases(&[("my-gpt", "gpt-4")]);
        assert_eq!(
            resolve_ok(&router, "my-gpt"),
            ("gpt-4".to_string(), Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_alias_with_default_provider() {
        let router = router_with_aliases(&[("my-model", "custom-model")])
            .with_default_provider(Some(Provider::Anthropic));
        assert_eq!(
            resolve_ok(&router, "my-model"),
            ("custom-model".to_string(), Provider::Anthropic)
        );
    }

    #[test]
    fn test_resolve_explicit_over_default() {
        let router = router_with_aliases(&[("my-model", "openai/custom-model")])
            .with_default_provider(Some(Provider::Anthropic));
        assert_eq!(
            resolve_ok(&router, "my-model"),
            ("custom-model".to_string(), Provider::OpenAI)
        );
    }

    #[test]
    fn test_resolve_priority_explicit_over_inference() {
        // The alias name looks like an OpenAI model, but the explicit target wins.
        let router = router_with_aliases(&[("gpt-custom", "anthropic/claude-3")]);
        assert_eq!(
            resolve_ok(&router, "gpt-custom"),
            ("claude-3".to_string(), Provider::Anthropic)
        );
    }
}