Skip to main content

systemprompt_models/profile/
gateway.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3
4#[derive(Debug, Clone, Default, Serialize, Deserialize)]
5pub struct GatewayConfig {
6    #[serde(default)]
7    pub enabled: bool,
8    #[serde(default)]
9    pub routes: Vec<GatewayRoute>,
10}
11
12impl GatewayConfig {
13    pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
14        self.routes.iter().find(|route| route.matches(model))
15    }
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct GatewayRoute {
20    pub model_pattern: String,
21    pub provider: String,
22    pub endpoint: String,
23    pub api_key_secret: String,
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub upstream_model: Option<String>,
26    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
27    pub extra_headers: HashMap<String, String>,
28}
29
30impl GatewayRoute {
31    pub fn matches(&self, model: &str) -> bool {
32        match_pattern(&self.model_pattern, model)
33    }
34
35    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
36        self.upstream_model.as_deref().unwrap_or(requested)
37    }
38}
39
/// Returns `true` when `model` matches the glob-style `pattern`.
///
/// `*` matches any run of characters, including none. It may appear any
/// number of times and in any position: `"*"`, `"claude-*"`, `"*-8k"`,
/// `"*sonnet*"` and `"claude-*-4-6"` are all valid. A pattern without `*` must
/// equal `model` exactly. No other characters are special.
fn match_pattern(pattern: &str, model: &str) -> bool {
    // No wildcard at all: plain exact comparison.
    let (head, remainder) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == model,
    };
    // Text after the last `*` must end the model. Everything between the first
    // and the last `*` is a sequence of literals that must appear in order.
    let (middle, tail) = remainder.rsplit_once('*').unwrap_or(("", remainder));

    // Peeling the head and then the tail off the same slice guarantees the two
    // anchored literals do not overlap (e.g. "ab*ba" must not match "aba").
    let inner = match model
        .strip_prefix(head)
        .and_then(|after_head| after_head.strip_suffix(tail))
    {
        Some(inner) => inner,
        None => return false,
    };

    // Taking the leftmost occurrence of each middle literal leaves the most room
    // for the literals after it, so a greedy scan needs no backtracking.
    // Empty literals (from "**" or a lone "*") trivially match.
    middle
        .split('*')
        .try_fold(inner, |rest, literal| {
            rest.split_once(literal).map(|(_, after)| after)
        })
        .is_some()
}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a route with placeholder endpoint and secret values for routing tests.
    fn make_route(pattern: &str, provider: &str, upstream_model: Option<&str>) -> GatewayRoute {
        GatewayRoute {
            model_pattern: pattern.to_string(),
            provider: provider.to_string(),
            endpoint: "https://example.invalid/v1".to_string(),
            api_key_secret: provider.to_string(),
            upstream_model: upstream_model.map(str::to_string),
            extra_headers: HashMap::new(),
        }
    }

    /// Provider of the route `config` selects for `model`, if any.
    fn provider_for<'a>(config: &'a GatewayConfig, model: &str) -> Option<&'a str> {
        config
            .find_route(model)
            .map(|route| route.provider.as_str())
    }

    #[test]
    fn exact_pattern_matches() {
        assert!(match_pattern("claude-sonnet-4-6", "claude-sonnet-4-6"));
        assert!(!match_pattern("claude-sonnet-4-6", "claude-opus-4-7"));
    }

    #[test]
    fn prefix_wildcard_matches() {
        assert!(match_pattern("claude-*", "claude-sonnet-4-6"));
        assert!(!match_pattern("claude-*", "moonshot-v1-8k"));
    }

    #[test]
    fn suffix_wildcard_matches() {
        assert!(match_pattern("*-8k", "moonshot-v1-8k"));
        assert!(!match_pattern("*-8k", "moonshot-v1-32k"));
    }

    #[test]
    fn catch_all_matches() {
        assert!(match_pattern("*", "any-model-name"));
        assert!(match_pattern("*", ""));
    }

    #[test]
    fn route_finds_matching_model() {
        let config = GatewayConfig {
            enabled: true,
            routes: vec![GatewayRoute {
                model_pattern: "kimi-*".to_string(),
                provider: "moonshot".to_string(),
                endpoint: "https://api.moonshot.ai/v1".to_string(),
                api_key_secret: "moonshot".to_string(),
                upstream_model: Some("moonshot-v1-32k".to_string()),
                extra_headers: HashMap::new(),
            }],
        };
        let route = config.find_route("kimi-latest");
        assert_eq!(route.map(|r| r.provider.as_str()), Some("moonshot"));
        assert_eq!(
            route.map(|r| r.effective_upstream_model("kimi-latest")),
            Some("moonshot-v1-32k")
        );
    }

    #[test]
    fn first_matching_route_wins() {
        let config = GatewayConfig {
            enabled: true,
            routes: vec![
                make_route("claude-*", "anthropic", None),
                make_route("*", "fallback", None),
            ],
        };
        assert_eq!(provider_for(&config, "claude-sonnet-4-6"), Some("anthropic"));
        assert_eq!(provider_for(&config, "gpt-4o"), Some("fallback"));
    }

    #[test]
    fn no_route_when_nothing_matches() {
        let config = GatewayConfig {
            enabled: true,
            routes: vec![make_route("claude-*", "anthropic", None)],
        };
        assert_eq!(provider_for(&config, "kimi-latest"), None);
        assert_eq!(provider_for(&GatewayConfig::default(), "kimi-latest"), None);
    }

    #[test]
    fn upstream_model_defaults_to_requested() {
        let passthrough = make_route("kimi-*", "moonshot", None);
        assert_eq!(
            passthrough.effective_upstream_model("kimi-latest"),
            "kimi-latest"
        );
    }
}