Skip to main content

grapsus_proxy/proxy/
model_routing.rs

//! Model-based routing for inference requests.
//!
//! Routes each inference request to an upstream chosen by its model name.
//! Mapping patterns are either exact names or globs using the `*` wildcard
//! (e.g., `gpt-4*`, `claude-*`); mappings are checked in order and the first match wins.
5
6use grapsus_config::{InferenceProvider, ModelRoutingConfig, ModelUpstreamMapping};
7
8/// Result of model-based routing lookup.
9#[derive(Debug, Clone)]
10pub struct ModelRoutingResult {
11    /// Target upstream for this model
12    pub upstream: String,
13    /// Provider override if specified (for cross-provider routing)
14    pub provider: Option<InferenceProvider>,
15    /// Whether this was a default routing (no specific mapping matched)
16    pub is_default: bool,
17}
18
19/// Find the upstream for a given model name.
20///
21/// Checks mappings in order (first match wins). If no mapping matches,
22/// returns the default upstream if configured, otherwise None.
23///
24/// # Arguments
25/// * `config` - Model routing configuration
26/// * `model` - Model name to route
27///
28/// # Returns
29/// `Some(ModelRoutingResult)` if a matching upstream was found, `None` otherwise.
30pub fn find_upstream_for_model(
31    config: &ModelRoutingConfig,
32    model: &str,
33) -> Option<ModelRoutingResult> {
34    // Check mappings in order (first match wins)
35    for mapping in &config.mappings {
36        if matches_model_pattern(&mapping.model_pattern, model) {
37            return Some(ModelRoutingResult {
38                upstream: mapping.upstream.clone(),
39                provider: mapping.provider,
40                is_default: false,
41            });
42        }
43    }
44
45    // No mapping matched - use default if configured
46    config
47        .default_upstream
48        .as_ref()
49        .map(|upstream| ModelRoutingResult {
50            upstream: upstream.clone(),
51            provider: None,
52            is_default: true,
53        })
54}
55
56/// Check if a model name matches a pattern.
57///
58/// Supports:
59/// - Exact match: `"gpt-4"` matches `"gpt-4"`
60/// - Glob patterns with `*` wildcard: `"gpt-4*"` matches `"gpt-4"`, `"gpt-4-turbo"`, `"gpt-4o"`
61fn matches_model_pattern(pattern: &str, model: &str) -> bool {
62    // Exact match (fast path)
63    if pattern == model {
64        return true;
65    }
66
67    // Glob pattern matching
68    glob_match(pattern, model)
69}
70
/// Simple glob pattern matching for model names.
///
/// Supports:
/// - `*` matches any sequence of characters (including empty)
/// - All other characters match literally (comparison is per `char` and
///   case-sensitive)
///
/// The whole of `text` must be covered by `pattern`; a pattern without `*`
/// therefore only matches an identical string.
///
/// The matcher is iterative: it remembers only the most recent `*` and, on a
/// mismatch, lets that `*` absorb one more character before retrying. This
/// bounds the work by `pattern.len() * text.len()` and uses constant stack,
/// so neither many wildcards nor a very long model name (which may come from
/// a request header) can cause exponential backtracking or deep recursion.
///
/// # Examples
/// - `gpt-4*` matches `gpt-4`, `gpt-4-turbo`, `gpt-4o`
/// - `claude-*-sonnet` matches `claude-3-sonnet`, `claude-3.5-sonnet`
/// - `*-turbo` matches `gpt-4-turbo`, `gpt-3.5-turbo`
fn glob_match(pattern: &str, text: &str) -> bool {
    let pat: Vec<char> = pattern.chars().collect();
    let txt: Vec<char> = text.chars().collect();

    let mut p = 0;
    let mut t = 0;
    // Resume point for the most recent `*`: (pattern index right after the
    // star, number of text characters the star has absorbed up to).
    let mut resume: Option<(usize, usize)> = None;

    while t < txt.len() {
        match pat.get(p) {
            Some('*') => {
                // Start by letting the star match the empty string.
                resume = Some((p + 1, t));
                p += 1;
            }
            Some(&c) if c == txt[t] => {
                p += 1;
                t += 1;
            }
            _ => match resume {
                // Mismatch (or pattern exhausted): widen the last star by one
                // character and retry from just after it. Earlier stars never
                // need revisiting, because the last star can absorb anything
                // an earlier one could have.
                Some((after_star, absorbed_to)) => {
                    p = after_star;
                    t = absorbed_to + 1;
                    resume = Some((after_star, absorbed_to + 1));
                }
                // No star to fall back on: the literal prefix does not match.
                None => return false,
            },
        }
    }

    // Text is consumed; any pattern remainder must be stars (each matching "").
    pat[p..].iter().all(|&c| c == '*')
}
112
113/// Extract model name from request headers.
114///
115/// Checks common model headers in order of precedence:
116/// 1. `x-model` - Explicit model header
117/// 2. `x-model-id` - Alternative model header
118///
119/// # Returns
120/// `Some(model_name)` if found in headers, `None` otherwise.
121pub fn extract_model_from_headers(headers: &http::HeaderMap) -> Option<String> {
122    // Check common model headers
123    let header_names = ["x-model", "x-model-id"];
124
125    for name in header_names {
126        if let Some(value) = headers.get(name) {
127            if let Ok(model) = value.to_str() {
128                let model = model.trim();
129                if !model.is_empty() {
130                    return Some(model.to_string());
131                }
132            }
133        }
134    }
135
136    None
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: an exact `gpt-4` mapping listed before the broader
    /// `gpt-4*` glob, a few provider-specific globs, and a default upstream.
    fn create_test_config() -> ModelRoutingConfig {
        ModelRoutingConfig {
            mappings: vec![
                ModelUpstreamMapping {
                    model_pattern: "gpt-4".to_string(),
                    upstream: "openai-gpt4".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "gpt-4*".to_string(),
                    upstream: "openai-primary".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "gpt-3.5*".to_string(),
                    upstream: "openai-secondary".to_string(),
                    provider: Some(InferenceProvider::OpenAi),
                },
                ModelUpstreamMapping {
                    model_pattern: "claude-*".to_string(),
                    upstream: "anthropic-backend".to_string(),
                    provider: Some(InferenceProvider::Anthropic),
                },
                ModelUpstreamMapping {
                    model_pattern: "llama-*".to_string(),
                    upstream: "local-gpu".to_string(),
                    provider: Some(InferenceProvider::Generic),
                },
            ],
            default_upstream: Some("openai-primary".to_string()),
        }
    }

    #[test]
    fn test_exact_match() {
        let config = create_test_config();

        // Exact match for "gpt-4" should match first (more specific)
        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "openai-gpt4");
        assert!(!result.is_default);
    }

    #[test]
    fn test_glob_suffix_match() {
        let config = create_test_config();

        // "gpt-4-turbo" should match "gpt-4*" pattern
        let result = find_upstream_for_model(&config, "gpt-4-turbo").unwrap();
        assert_eq!(result.upstream, "openai-primary");
        assert!(!result.is_default);

        // "gpt-4o" should match "gpt-4*" pattern
        let result = find_upstream_for_model(&config, "gpt-4o").unwrap();
        assert_eq!(result.upstream, "openai-primary");
    }

    #[test]
    fn test_claude_models() {
        let config = create_test_config();

        // All claude models should route to anthropic
        let result = find_upstream_for_model(&config, "claude-3-opus").unwrap();
        assert_eq!(result.upstream, "anthropic-backend");
        assert_eq!(result.provider, Some(InferenceProvider::Anthropic));

        let result = find_upstream_for_model(&config, "claude-3.5-sonnet").unwrap();
        assert_eq!(result.upstream, "anthropic-backend");
    }

    #[test]
    fn test_default_upstream() {
        let config = create_test_config();

        // Unknown model should fall back to default
        let result = find_upstream_for_model(&config, "unknown-model").unwrap();
        assert_eq!(result.upstream, "openai-primary");
        assert!(result.is_default);
        assert!(result.provider.is_none());
    }

    #[test]
    fn test_default_upstream_with_no_mappings() {
        let config = ModelRoutingConfig {
            mappings: vec![],
            default_upstream: Some("fallback".to_string()),
        };

        // With nothing to match against, every model takes the default
        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "fallback");
        assert!(result.is_default);
        assert!(result.provider.is_none());
    }

    #[test]
    fn test_no_match_no_default() {
        let config = ModelRoutingConfig {
            mappings: vec![ModelUpstreamMapping {
                model_pattern: "gpt-4".to_string(),
                upstream: "openai".to_string(),
                provider: None,
            }],
            default_upstream: None,
        };

        // No match and no default should return None
        let result = find_upstream_for_model(&config, "claude-3-opus");
        assert!(result.is_none());
    }

    #[test]
    fn test_first_match_wins() {
        let config = create_test_config();

        // "gpt-4" exact match is listed first, so it wins over the "gpt-4*" glob
        let result = find_upstream_for_model(&config, "gpt-4").unwrap();
        assert_eq!(result.upstream, "openai-gpt4");

        // Order, not specificity, decides: with the glob listed first it
        // shadows the exact mapping that follows it.
        let reversed = ModelRoutingConfig {
            mappings: vec![
                ModelUpstreamMapping {
                    model_pattern: "gpt-4*".to_string(),
                    upstream: "glob-first".to_string(),
                    provider: None,
                },
                ModelUpstreamMapping {
                    model_pattern: "gpt-4".to_string(),
                    upstream: "exact-second".to_string(),
                    provider: None,
                },
            ],
            default_upstream: None,
        };
        let result = find_upstream_for_model(&reversed, "gpt-4").unwrap();
        assert_eq!(result.upstream, "glob-first");
        assert!(!result.is_default);
    }

    #[test]
    fn test_glob_match_patterns() {
        // Test various glob patterns
        assert!(glob_match("gpt-4*", "gpt-4"));
        assert!(glob_match("gpt-4*", "gpt-4-turbo"));
        assert!(glob_match("gpt-4*", "gpt-4o"));
        assert!(!glob_match("gpt-4*", "gpt-3.5-turbo"));

        assert!(glob_match("*-turbo", "gpt-4-turbo"));
        assert!(glob_match("*-turbo", "gpt-3.5-turbo"));
        assert!(!glob_match("*-turbo", "gpt-4"));

        assert!(glob_match("claude-*-sonnet", "claude-3-sonnet"));
        assert!(glob_match("claude-*-sonnet", "claude-3.5-sonnet"));
        assert!(!glob_match("claude-*-sonnet", "claude-3-opus"));

        assert!(glob_match("*", "anything"));
        assert!(glob_match("*", ""));
    }

    #[test]
    fn test_glob_match_edge_cases() {
        // Empty pattern only matches empty text
        assert!(glob_match("", ""));
        assert!(!glob_match("", "a"));
        assert!(!glob_match("a", ""));

        // Consecutive and multiple wildcards
        assert!(glob_match("**", "abc"));
        assert!(glob_match("a*b*c", "aXbYbZc"));
        assert!(!glob_match("a*b*c", "aXbYbZ"));

        // A literal pattern must cover the whole text
        assert!(!glob_match("gpt-4", "gpt-4o"));

        // Matching is case-sensitive
        assert!(!glob_match("GPT-4*", "gpt-4"));

        // Non-ASCII characters are compared as whole characters
        assert!(glob_match("模型-*", "模型-大"));
    }

    #[test]
    fn test_extract_model_from_headers() {
        let mut headers = http::HeaderMap::new();

        // No headers
        assert!(extract_model_from_headers(&headers).is_none());

        // x-model header
        headers.insert("x-model", "gpt-4".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        // x-model-id header (lower precedence)
        headers.clear();
        headers.insert("x-model-id", "claude-3-opus".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("claude-3-opus".to_string())
        );

        // Both headers - x-model takes precedence
        headers.insert("x-model", "gpt-4".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        // Empty header value
        headers.clear();
        headers.insert("x-model", "".parse().unwrap());
        assert!(extract_model_from_headers(&headers).is_none());

        // Whitespace-only header value
        headers.clear();
        headers.insert("x-model", "   ".parse().unwrap());
        assert!(extract_model_from_headers(&headers).is_none());
    }

    #[test]
    fn test_extract_model_header_fallbacks() {
        let mut headers = http::HeaderMap::new();

        // Surrounding whitespace is trimmed from the returned name
        headers.insert("x-model", " gpt-4 ".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("gpt-4".to_string())
        );

        // A blank x-model does not hide a usable x-model-id
        headers.clear();
        headers.insert("x-model", "   ".parse().unwrap());
        headers.insert("x-model-id", "claude-3-opus".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("claude-3-opus".to_string())
        );

        // A value that cannot be read as a string is skipped as well
        headers.clear();
        headers.insert(
            "x-model",
            http::HeaderValue::from_bytes(b"\xff\xfe").unwrap(),
        );
        headers.insert("x-model-id", "llama-3".parse().unwrap());
        assert_eq!(
            extract_model_from_headers(&headers),
            Some("llama-3".to_string())
        );
    }
}