systemprompt_models/profile/
gateway.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::collections::hash_map::DefaultHasher;
4use std::hash::{Hash, Hasher};
5use std::path::PathBuf;
6use thiserror::Error;
7
/// Errors produced while loading or validating a gateway catalog.
#[derive(Debug, Error)]
pub enum GatewayProfileError {
    /// Reading the catalog file at `path` failed with an I/O error.
    #[error("Failed to read gateway catalog {path}: {source}")]
    CatalogRead {
        path: PathBuf,
        #[source]
        source: std::io::Error,
    },

    /// The catalog file at `path` could not be deserialized as YAML.
    #[error("Failed to parse gateway catalog {path}: {source}")]
    CatalogParse {
        path: PathBuf,
        #[source]
        source: serde_yaml::Error,
    },

    // NOTE(review): presumably wraps one of the validation variants below
    // (e.g. from `GatewayCatalog::validate`) — confirm at the call site.
    /// The catalog at `path` was rejected; `source` carries the underlying error.
    /// It is boxed because the variant nests this same error type.
    #[error("Invalid gateway catalog {path}: {source}")]
    CatalogInvalid {
        path: PathBuf,
        #[source]
        source: Box<Self>,
    },

    /// A catalog model has an empty `id`.
    #[error("gateway catalog model has empty id")]
    ModelEmptyId,

    /// Catalog model `model` names a `provider` that no catalog provider defines.
    #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
    UnknownProvider { model: String, provider: String },

    /// A catalog provider has an empty `name`.
    #[error("gateway catalog provider has empty name")]
    ProviderEmptyName,

    /// Catalog provider `name` has an empty `endpoint`.
    #[error("gateway catalog provider '{name}' has empty endpoint")]
    ProviderEmptyEndpoint { name: String },
}
43
/// Shorthand `Result` whose error type is [`GatewayProfileError`].
pub type GatewayResult<T> = Result<T, GatewayProfileError>;
45
/// Gateway routing configuration.
///
/// Every field has a serde default. A document that omits a field gets the same
/// value that `GatewayConfig::default()` produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
    /// Whether the gateway is active; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Explicit routes. `find_route` returns the first one, in this order, that matches.
    #[serde(default)]
    pub routes: Vec<GatewayRoute>,
    /// Optional path to a catalog file; omitted from serialized output when `None`.
    #[serde(default, skip)]
    // NOTE(review): the attribute above belongs to `catalog` below; see next comment.
    pub catalog_path: Option<PathBuf>,
    /// In-memory catalog. It is never read from or written to the serialized form
    /// (`skip`); presumably filled in after loading `catalog_path` — confirm.
    pub catalog: Option<GatewayCatalog>,
    /// Authentication scheme name; defaults to `"bearer"` when absent.
    #[serde(default = "default_auth_scheme")]
    pub auth_scheme: String,
    /// Path prefix for inference requests; defaults to `"/v1"` when absent.
    #[serde(default = "default_inference_path_prefix")]
    pub inference_path_prefix: String,
}
61
62impl Default for GatewayConfig {
63 fn default() -> Self {
64 Self {
65 enabled: false,
66 routes: Vec::new(),
67 catalog_path: None,
68 catalog: None,
69 auth_scheme: default_auth_scheme(),
70 inference_path_prefix: default_inference_path_prefix(),
71 }
72 }
73}
74
/// Serde/`Default` fallback for `GatewayConfig::auth_scheme`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
78
/// Serde/`Default` fallback for `GatewayConfig::inference_path_prefix`.
fn default_inference_path_prefix() -> String {
    "/v1".to_owned()
}
82
83impl GatewayConfig {
84 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
85 self.routes.iter().find(|route| route.matches(model))
86 }
87}
88
/// A catalog of upstream providers and the models they serve.
///
/// Both lists default to empty when absent from the serialized form.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GatewayCatalog {
    /// Known upstream providers, looked up by `name`.
    #[serde(default)]
    pub providers: Vec<GatewayProvider>,
    /// Models; each one names its provider via `GatewayModel::provider`.
    #[serde(default)]
    pub models: Vec<GatewayModel>,
}
96
97impl GatewayCatalog {
98 pub fn validate(&self) -> GatewayResult<()> {
99 for model in &self.models {
100 if model.id.is_empty() {
101 return Err(GatewayProfileError::ModelEmptyId);
102 }
103 if !self.providers.iter().any(|p| p.name == model.provider) {
104 return Err(GatewayProfileError::UnknownProvider {
105 model: model.id.clone(),
106 provider: model.provider.clone(),
107 });
108 }
109 }
110 for provider in &self.providers {
111 if provider.name.is_empty() {
112 return Err(GatewayProfileError::ProviderEmptyName);
113 }
114 if provider.endpoint.is_empty() {
115 return Err(GatewayProfileError::ProviderEmptyEndpoint {
116 name: provider.name.clone(),
117 });
118 }
119 }
120 Ok(())
121 }
122
123 pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
124 self.providers.iter().find(|p| p.name == name)
125 }
126}
127
/// An upstream provider entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayProvider {
    /// Name that `GatewayModel::provider` refers to.
    pub name: String,
    /// Upstream endpoint for this provider.
    pub endpoint: String,
    // NOTE(review): presumably the name/key of a secret holding the API key rather
    // than the key itself — confirm; if it is the raw key, the derived `Debug` exposes it.
    /// API key secret reference for this provider.
    pub api_key_secret: String,
    /// Additional headers for this provider; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
136
/// A model entry in a [`GatewayCatalog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayModel {
    /// Model identifier; `GatewayCatalog::validate` rejects an empty id.
    pub id: String,
    /// Name of the [`GatewayProvider`] serving this model; it must exist in the same catalog.
    pub provider: String,
    /// Optional human-readable name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
    /// Optional upstream model name, presumably used in place of `id` when
    /// forwarding — confirm. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
}
146
/// A route that sends requests for matching model names to one upstream endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayRoute {
    /// Route identifier. It defaults to empty when absent; `ensure_id` fills a blank id.
    #[serde(default)]
    pub id: String,
    /// Pattern tested against requested model names by `GatewayRoute::matches`.
    pub model_pattern: String,
    /// Provider name for this route.
    pub provider: String,
    /// Upstream endpoint requests are sent to.
    pub endpoint: String,
    /// API key secret reference for this route (presumably a secret name — confirm).
    pub api_key_secret: String,
    /// Optional model name sent upstream in place of the requested one; see
    /// `effective_upstream_model`. Omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub upstream_model: Option<String>,
    /// Additional headers for this route; omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub extra_headers: HashMap<String, String>,
}
160
161impl GatewayRoute {
162 pub fn matches(&self, model: &str) -> bool {
163 match_pattern(&self.model_pattern, model)
164 }
165
166 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
167 self.upstream_model.as_deref().unwrap_or(requested)
168 }
169
170 pub fn ensure_id(&mut self) {
171 if self.id.trim().is_empty() {
172 self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
173 }
174 }
175}
176
/// Turns a model pattern into a lowercase, dash-separated slug for use in ids.
///
/// - Each `*` becomes the word `star`.
/// - ASCII letters and digits are kept and lowercased.
/// - Every run of other characters (punctuation, whitespace, non-ASCII) collapses
///   into a single `-`.
/// - Dashes at either end are dropped.
/// - If nothing is left, the result is `"route"`.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut slug = String::with_capacity(pattern.len());
    // True while the most recently emitted character is a separator dash.
    let mut pending_sep = false;
    for c in pattern.chars() {
        match c {
            '*' => {
                slug.push_str("star");
                pending_sep = false;
            }
            c if c.is_ascii_alphanumeric() => {
                slug.push(c.to_ascii_lowercase());
                pending_sep = false;
            }
            _ => {
                // Never start the slug with a dash, and never emit two in a row.
                if !pending_sep && !slug.is_empty() {
                    slug.push('-');
                    pending_sep = true;
                }
            }
        }
    }
    let trimmed = slug.trim_matches('-');
    if trimmed.is_empty() {
        "route".to_owned()
    } else {
        trimmed.to_owned()
    }
}
212
213#[must_use]
219pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
220 let mut hasher = DefaultHasher::new();
221 model_pattern.hash(&mut hasher);
222 provider.hash(&mut hasher);
223 endpoint.hash(&mut hasher);
224 let h = hasher.finish();
225 let hash6: String = format!("{h:016x}").chars().take(6).collect();
226 format!("{}-{}", slugify_pattern(model_pattern), hash6)
227}
228
/// Matches `model` against a route `pattern` in which `*` stands for any run of
/// characters (possibly empty). Every other character must match literally.
///
/// Examples:
/// - `"*"` matches everything.
/// - `"claude-*"` is a prefix match and `"*-mini"` a suffix match.
/// - `"*sonnet*"` is a substring match.
/// - `"claude-*-sonnet"` requires both the prefix and the suffix, without overlap.
/// - A pattern with no `*` must equal `model` exactly.
fn match_pattern(pattern: &str, model: &str) -> bool {
    let mut segments = pattern.split('*');
    // `split` always yields at least one item, so the fallback is never used.
    let head = segments.next().unwrap_or("");
    let mut rest = match model.strip_prefix(head) {
        Some(rest) => rest,
        None => return false,
    };
    let mut tail = match segments.next() {
        Some(segment) => segment,
        // No `*` at all: the pattern is a literal, so nothing may remain.
        None => return rest.is_empty(),
    };
    for next in segments {
        // Another segment follows, so `tail` is really a middle segment. Because
        // `*` is the only wildcard, consuming its leftmost occurrence is always
        // safe: it leaves the most room for what follows.
        match rest.find(tail) {
            Some(idx) => rest = &rest[idx + tail.len()..],
            None => return false,
        }
        tail = next;
    }
    // The final segment must end the model name. It is checked against what is
    // left, so it can never overlap text already consumed by earlier segments.
    rest.ends_with(tail)
}