1use grapsus_config::{InferenceProvider, ModelRoutingConfig, ModelUpstreamMapping};
7
8#[derive(Debug, Clone)]
10pub struct ModelRoutingResult {
11 pub upstream: String,
13 pub provider: Option<InferenceProvider>,
15 pub is_default: bool,
17}
18
19pub fn find_upstream_for_model(
31 config: &ModelRoutingConfig,
32 model: &str,
33) -> Option<ModelRoutingResult> {
34 for mapping in &config.mappings {
36 if matches_model_pattern(&mapping.model_pattern, model) {
37 return Some(ModelRoutingResult {
38 upstream: mapping.upstream.clone(),
39 provider: mapping.provider,
40 is_default: false,
41 });
42 }
43 }
44
45 config
47 .default_upstream
48 .as_ref()
49 .map(|upstream| ModelRoutingResult {
50 upstream: upstream.clone(),
51 provider: None,
52 is_default: true,
53 })
54}
55
56fn matches_model_pattern(pattern: &str, model: &str) -> bool {
62 if pattern == model {
64 return true;
65 }
66
67 glob_match(pattern, model)
69}
70
71fn glob_match(pattern: &str, text: &str) -> bool {
82 let pattern_chars: Vec<char> = pattern.chars().collect();
83 let text_chars: Vec<char> = text.chars().collect();
84
85 glob_match_recursive(&pattern_chars, &text_chars, 0, 0)
86}
87
/// Returns `true` if `text[t_idx..]` matches `pattern[p_idx..]`, where `*` in the
/// pattern matches any (possibly empty) run of characters and every other
/// character matches itself literally.
///
/// The name is kept for compatibility, but the implementation is iterative. It uses
/// the classic two-pointer wildcard algorithm with a single backtrack point, the
/// most recent `*`. That bounds the work to O(pattern.len() * text.len()) in the
/// worst case.
///
/// The previous recursive version re-explored every split point for every `*`.
/// That is exponential for patterns with several stars (e.g. `*a*a*a*a*b`) against
/// long non-matching text, and the text here can come from request headers.
///
/// Callers are expected to pass `t_idx <= text.len()`. A `p_idx` past the end of
/// `pattern` behaves like an empty pattern.
fn glob_match_recursive(pattern: &[char], text: &[char], p_idx: usize, t_idx: usize) -> bool {
    let (mut p, mut t) = (p_idx, t_idx);
    // (index of the last `*` seen, text position that `*` currently extends to).
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p) {
            Some('*') => {
                // Tentatively let this star match the empty string.
                backtrack = Some((p, t));
                p += 1;
            }
            Some(&c) if c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                // Mismatch: let the most recent star absorb one more character
                // and retry the rest of the pattern from just after it.
                Some((star_p, star_t)) => {
                    p = star_p + 1;
                    t = star_t + 1;
                    backtrack = Some((star_p, star_t + 1));
                }
                None => return false,
            },
        }
    }

    // Text is exhausted. Whatever remains of the pattern must consist only of stars,
    // which can each match the empty string.
    pattern
        .get(p..)
        .map_or(true, |rest| rest.iter().all(|&c| c == '*'))
}
112
113pub fn extract_model_from_headers(headers: &http::HeaderMap) -> Option<String> {
122 let header_names = ["x-model", "x-model-id"];
124
125 for name in header_names {
126 if let Some(value) = headers.get(name) {
127 if let Ok(model) = value.to_str() {
128 let model = model.trim();
129 if !model.is_empty() {
130 return Some(model.to_string());
131 }
132 }
133 }
134 }
135
136 None
137}
138
// Unit tests for model-to-upstream routing, glob pattern matching, and
// model-name extraction from request headers.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a routing config whose mappings are deliberately ordered so that the
    // exact "gpt-4" entry precedes the broader "gpt-4*" glob; several tests rely on
    // first-match-wins ordering. The default upstream is "openai-primary".
    fn create_test_config() -> ModelRoutingConfig {
        ModelRoutingConfig {
            mappings: vec![
                ModelUpstreamMapping {
                    model_pattern: "gpt-4".to_string(),
                    upstream: "openai-gpt4".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "gpt-4*".to_string(),
                    upstream: "openai-primary".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "gpt-3.5*".to_string(),
                    upstream: "openai-secondary".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "claude-*".to_string(),
                    upstream: "anthropic-backend".to_string(),
                    provider: Some(InferenceProvider::Anthropic),
                },
                ModelUpstreamMapping {
                    model_pattern: "llama-*".to_string(),
                    upstream: "local-gpu".to_string(),
                    provider: Some(InferenceProvider::Generic),
                },
            ],
            default_upstream: Some("openai-primary".to_string()),
        }
    }

    // An exact pattern match resolves to that mapping and is not flagged as default.
    #[test]
    fn test_exact_match() {
        let config = create_test_config();

        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "openai-gpt4");
        assert!(!result.is_default);
    }

    // Models that only match the trailing-star glob "gpt-4*" route to its upstream.
    #[test]
    fn test_glob_suffix_match() {
        let config = create_test_config();

        let result = find_upstream_for_model(&config, "gpt-4-turbo").unwrap();
        assert_eq!(result.upstream, "openai-primary");
        assert!(!result.is_default);

        let result = find_upstream_for_model(&config, "gpt-4o").unwrap();
        assert_eq!(result.upstream, "openai-primary");
    }

    // "claude-*" captures Claude variants, including ones containing dots, and
    // carries the mapping's provider through to the result.
    #[test]
    fn test_claude_models() {
        let config = create_test_config();

        let result = find_upstream_for_model(&config, "claude-3-opus").unwrap();
        assert_eq!(result.upstream, "anthropic-backend");
        assert_eq!(result.provider, Some(InferenceProvider::Anthropic));

        let result = find_upstream_for_model(&config, "claude-3.5-sonnet").unwrap();
        assert_eq!(result.upstream, "anthropic-backend");
    }

    // Unmatched models fall back to the default upstream, which has no provider.
    #[test]
    fn test_default_upstream() {
        let config = create_test_config();

        let result = find_upstream_for_model(&config, "unknown-model").unwrap();
        assert_eq!(result.upstream, "openai-primary");
        assert!(result.is_default);
        assert!(result.provider.is_none());
    }

    // With no matching mapping and no default configured, routing yields None.
    #[test]
    fn test_no_match_no_default() {
        let config = ModelRoutingConfig {
            mappings: vec![ModelUpstreamMapping {
                model_pattern: "gpt-4".to_string(),
                upstream: "openai".to_string(),
                provider: None,
            }],
            default_upstream: None,
        };

        let result = find_upstream_for_model(&config, "claude-3-opus");
        assert!(result.is_none());
    }

    // "gpt-4" matches both the exact entry and "gpt-4*"; the earlier exact entry wins.
    #[test]
    fn test_first_match_wins() {
        let config = create_test_config();

        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "openai-gpt4");
    }

    // Direct coverage of the glob matcher: trailing, leading, and infix stars,
    // plus a bare "*" matching both non-empty and empty text.
    #[test]
    fn test_glob_match_patterns() {
        assert!(glob_match("gpt-4*", "gpt-4"));
        assert!(glob_match("gpt-4*", "gpt-4-turbo"));
        assert!(glob_match("gpt-4*", "gpt-4o"));
        assert!(!glob_match("gpt-4*", "gpt-3.5-turbo"));

        assert!(glob_match("*-turbo", "gpt-4-turbo"));
        assert!(glob_match("*-turbo", "gpt-3.5-turbo"));
        assert!(!glob_match("*-turbo", "gpt-4"));

        assert!(glob_match("claude-*-sonnet", "claude-3-sonnet"));
        assert!(glob_match("claude-*-sonnet", "claude-3.5-sonnet"));
        assert!(!glob_match("claude-*-sonnet", "claude-3-opus"));

        assert!(glob_match("*", "anything"));
        assert!(glob_match("*", ""));
    }

    // Header extraction: absent headers, each header on its own, priority when
    // both are present, and rejection of empty / whitespace-only values.
    // Note: the map is intentionally NOT cleared before the priority check, so
    // "x-model-id" is still present when "x-model" is added.
    #[test]
    fn test_extract_model_from_headers() {
        let mut headers = http::HeaderMap::new();

        assert!(extract_model_from_headers(&headers).is_none());

        headers.insert("x-model", "gpt-4".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        headers.clear();
        headers.insert("x-model-id", "claude-3-opus".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("claude-3-opus".to_string())
        );

        // Both headers present now: "x-model" takes precedence over "x-model-id".
        headers.insert("x-model", "gpt-4".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        headers.clear();
        headers.insert("x-model", "".parse().unwrap());
        assert!(extract_model_from_headers(&headers).is_none());

        headers.clear();
        headers.insert("x-model", "   ".parse().unwrap());
        assert!(extract_model_from_headers(&headers).is_none());
    }
}