Skip to main content

systemprompt_models/profile/
gateway.rs

1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use thiserror::Error;
8
/// Errors describing a gateway catalog that could not be read, parsed, or
/// validated.
///
/// The `Catalog*` variants carry the path of the catalog file involved; the
/// remaining variants are the individual rule violations returned by
/// `GatewayCatalog::validate`.
#[derive(Debug, Error)]
pub enum GatewayProfileError {
    /// An I/O error occurred while reading the catalog file at `path`.
    #[error("Failed to read gateway catalog {path}: {source}")]
    CatalogRead {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },

    /// The catalog file at `path` could not be parsed as YAML.
    #[error("Failed to parse gateway catalog {path}: {source}")]
    CatalogParse {
        path: PathBuf,
        #[source]
        source: serde_yaml::Error,
    },

    /// The catalog at `path` is invalid; `source` holds the underlying
    /// error, boxed because the enum refers to itself here.
    #[error("Invalid gateway catalog {path}: {source}")]
    CatalogInvalid {
        path: PathBuf,
        #[source]
        source: Box<Self>,
    },

    /// A catalog model entry has an empty `id`.
    #[error("gateway catalog model has empty id")]
    ModelEmptyId,

    /// Model `model` names a `provider` that is not declared in the catalog.
    #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
    UnknownProvider { model: String, provider: String },

    /// A catalog provider entry has an empty `name`.
    #[error("gateway catalog provider has empty name")]
    ProviderEmptyName,

    /// Provider `name` has an empty `endpoint`.
    #[error("gateway catalog provider '{name}' has empty endpoint")]
    ProviderEmptyEndpoint { name: String },
}
44
/// `Result` alias whose error type is [`GatewayProfileError`].
pub type GatewayResult<T> = Result<T, GatewayProfileError>;
46
// Gateway section of a profile: an on/off switch, an ordered route list, an
// optional catalog, and two string settings with defaults.
//
// Unknown keys are rejected on deserialization (`deny_unknown_fields`).
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the generated JSON Schema, which would change the derived schema.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayConfig {
    // Gateway on/off switch; `false` when the key is omitted.
    #[serde(default)]
    pub enabled: bool,
    // Routes in declaration order; `find_route` returns the first match.
    // Empty when the key is omitted.
    #[serde(default)]
    pub routes: Vec<GatewayRoute>,
    // Optional path to a catalog file; left out of serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub catalog_path: Option<PathBuf>,
    // Never serialized or deserialized (`skip`), so it is always `None` right
    // after deserialization.
    // NOTE(review): presumably filled in by a loader that reads `catalog_path`;
    // that code is not in this file - confirm.
    #[serde(default, skip)]
    pub catalog: Option<GatewayCatalog>,
    // Defaults to "bearer" when omitted.
    #[serde(default = "default_auth_scheme")]
    pub auth_scheme: String,
    // Defaults to "/v1" when omitted.
    #[serde(default = "default_inference_path_prefix")]
    pub inference_path_prefix: String,
}
63
64impl Default for GatewayConfig {
65    fn default() -> Self {
66        Self {
67            enabled: false,
68            routes: Vec::new(),
69            catalog_path: None,
70            catalog: None,
71            auth_scheme: default_auth_scheme(),
72            inference_path_prefix: default_inference_path_prefix(),
73        }
74    }
75}
76
/// Serde default for `GatewayConfig::auth_scheme`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
80
/// Serde default for `GatewayConfig::inference_path_prefix`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
84
85impl GatewayConfig {
86    pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
87        self.routes.iter().find(|route| route.matches(model))
88    }
89}
90
// Catalog of upstream providers and the models served through them.
//
// Both lists default to empty when omitted; unknown keys are rejected.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the generated JSON Schema, which would change the derived schema.
#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayCatalog {
    // Providers that models may reference by `name`.
    #[serde(default)]
    pub providers: Vec<GatewayProvider>,
    // Models; each one's `provider` must name an entry in `providers`
    // (checked by `validate`).
    #[serde(default)]
    pub models: Vec<GatewayModel>,
}
99
100impl GatewayCatalog {
101    pub fn validate(&self) -> GatewayResult<()> {
102        for model in &self.models {
103            if model.id.is_empty() {
104                return Err(GatewayProfileError::ModelEmptyId);
105            }
106            if !self.providers.iter().any(|p| p.name == model.provider) {
107                return Err(GatewayProfileError::UnknownProvider {
108                    model: model.id.clone(),
109                    provider: model.provider.clone(),
110                });
111            }
112        }
113        for provider in &self.providers {
114            if provider.name.is_empty() {
115                return Err(GatewayProfileError::ProviderEmptyName);
116            }
117            if provider.endpoint.is_empty() {
118                return Err(GatewayProfileError::ProviderEmptyEndpoint {
119                    name: provider.name.clone(),
120                });
121            }
122        }
123        Ok(())
124    }
125
126    pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
127        self.providers.iter().find(|p| p.name == name)
128    }
129}
130
// One upstream provider entry in the catalog.
//
// Unknown keys are rejected on deserialization.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the generated JSON Schema, which would change the derived schema.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayProvider {
    // Name that models use to reference this provider; must be non-empty
    // (checked by `GatewayCatalog::validate`).
    pub name: String,
    // Provider endpoint; must be non-empty (checked by `GatewayCatalog::validate`).
    pub endpoint: String,
    // NOTE(review): presumably the name of a secret holding the API key rather
    // than the key itself - confirm against the code that resolves it.
    pub api_key_secret: String,
    // Additional headers; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
140
// One model entry in the catalog.
//
// Unknown keys are rejected on deserialization; the optional fields are
// omitted from serialized output when `None`.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the generated JSON Schema, which would change the derived schema.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayModel {
    // Model identifier; must be non-empty (checked by `GatewayCatalog::validate`).
    pub id: String,
    // Name of the catalog provider serving this model; must match a declared
    // provider (checked by `GatewayCatalog::validate`).
    pub provider: String,
    // Optional display name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    // NOTE(review): presumably the model name sent upstream when it differs
    // from `id` - confirm against the code that consumes the catalog.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    // Optional pricing information for this model.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pricing: Option<ModelPricing>,
}
153
// One routing rule mapping a model-name pattern to an upstream provider.
//
// Unknown keys are rejected on deserialization.
// Plain `//` comments are used on purpose: schemars copies `///` doc comments
// into the generated JSON Schema, which would change the derived schema.
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(deny_unknown_fields)]
pub struct GatewayRoute {
    // Route identifier; empty when omitted. `ensure_id` replaces a blank id
    // with a synthesized one.
    #[serde(default)]
    pub id: String,
    // Pattern tested against the requested model name by `matches`; `*` acts
    // as a wildcard.
    pub model_pattern: String,
    // Provider name; also an input to the synthesized route id.
    pub provider: String,
    // Upstream endpoint; also an input to the synthesized route id.
    pub endpoint: String,
    // NOTE(review): presumably the name of a secret holding the API key rather
    // than the key itself - confirm against the code that resolves it.
    pub api_key_secret: String,
    // When set, `effective_upstream_model` returns this instead of the
    // requested model name. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    // Additional headers; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
    // Optional pricing information; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub pricing: Option<ModelPricing>,
}
170
171impl GatewayRoute {
172    pub fn matches(&self, model: &str) -> bool {
173        match_pattern(&self.model_pattern, model)
174    }
175
176    pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
177        self.upstream_model.as_deref().unwrap_or(requested)
178    }
179
180    pub fn ensure_id(&mut self) {
181        if self.id.trim().is_empty() {
182            self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
183        }
184    }
185}
186
/// Turn a model pattern into a slug suitable for embedding in a stable id.
///
/// Kept in step with the template's historical implementation in
/// `extensions/web/admin/.../gateway.rs`:
///
/// - `*` is spelled out as `star`;
/// - ASCII letters and digits are kept, lower-cased;
/// - any run of other characters becomes a single `-`, except at the very
///   start or end of the slug where it is dropped;
/// - a slug that would otherwise be empty becomes `route`.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut slug = String::with_capacity(pattern.len());

    for ch in pattern.chars() {
        match ch {
            '*' => slug.push_str("star"),
            c if c.is_ascii_alphanumeric() => slug.push(c.to_ascii_lowercase()),
            _ => {
                // A separator is only written after real content and never
                // twice in a row, so the slug cannot start with `-` and can
                // end with at most one.
                if !slug.is_empty() && !slug.ends_with('-') {
                    slug.push('-');
                }
            }
        }
    }

    if slug.ends_with('-') {
        slug.pop();
    }

    if slug.is_empty() {
        String::from("route")
    } else {
        slug
    }
}
222
223// Format: <slug>-<6 hex chars> where the hex digest is the first 6 chars of
224// DefaultHasher over (model_pattern, provider, endpoint). Mirrors the template
225// logic so ids stay identical across the core/template seam.
226#[must_use]
227pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
228    let mut hasher = DefaultHasher::new();
229    model_pattern.hash(&mut hasher);
230    provider.hash(&mut hasher);
231    endpoint.hash(&mut hasher);
232    let h = hasher.finish();
233    let hash6: String = format!("{h:016x}").chars().take(6).collect();
234    format!("{}-{}", slugify_pattern(model_pattern), hash6)
235}
236
/// Test `model` against a route pattern in which `*` matches any run of
/// characters (including none) and every other character matches itself.
///
/// - no `*`: the pattern must equal `model` exactly;
/// - text before the first `*` must be a prefix of `model`;
/// - text after the last `*` must be a suffix of what remains after that prefix;
/// - text between two `*`s must occur, in order, in the part left between
///   the prefix and the suffix.
///
/// The earlier version only understood `*`, `prefix*` and `*suffix`. Patterns
/// such as `*sonnet*` or `claude-*-latest` were compared with a literal `*`
/// still in them and so never matched a real model name. Every model the
/// earlier version accepted is still accepted.
fn match_pattern(pattern: &str, model: &str) -> bool {
    // Without a wildcard only an exact match counts.
    let (head, rest) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == model,
    };

    // `tail` is whatever follows the last `*`; `middle` holds any segments
    // between the first and last `*` (empty when there is only one).
    let (middle, tail) = rest.rsplit_once('*').unwrap_or(("", rest));

    // Anchor both ends. The suffix is stripped from what is left after the
    // prefix, so the two can never overlap (`ab*ba` must not match `aba`).
    let mut remaining = match model
        .strip_prefix(head)
        .and_then(|after_head| after_head.strip_suffix(tail))
    {
        Some(inner) => inner,
        None => return false,
    };

    // Interior segments must appear left to right. Taking the leftmost
    // occurrence each time leaves the most room for later segments.
    // Empty segments come from consecutive `*`s and match trivially.
    for segment in middle.split('*').filter(|s| !s.is_empty()) {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }
    true
}