systemprompt_models/profile/
gateway.rs1use crate::services::ai::ModelPricing;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::path::PathBuf;
7use thiserror::Error;
8
9#[derive(Debug, Error)]
10pub enum GatewayProfileError {
11 #[error("Failed to read gateway catalog {path}: {source}")]
12 CatalogRead {
13 path: PathBuf,
14 #[source]
15 source: std::io::Error,
16 },
17
18 #[error("Failed to parse gateway catalog {path}: {source}")]
19 CatalogParse {
20 path: PathBuf,
21 #[source]
22 source: serde_yaml::Error,
23 },
24
25 #[error("Invalid gateway catalog {path}: {source}")]
26 CatalogInvalid {
27 path: PathBuf,
28 #[source]
29 source: Box<Self>,
30 },
31
32 #[error("gateway catalog model has empty id")]
33 ModelEmptyId,
34
35 #[error("gateway catalog model '{model}' references unknown provider '{provider}'")]
36 UnknownProvider { model: String, provider: String },
37
38 #[error("gateway catalog provider has empty name")]
39 ProviderEmptyName,
40
41 #[error("gateway catalog provider '{name}' has empty endpoint")]
42 ProviderEmptyEndpoint { name: String },
43}
44
45pub type GatewayResult<T> = Result<T, GatewayProfileError>;
46
47#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
48#[serde(deny_unknown_fields)]
49pub struct GatewayConfig {
50 #[serde(default)]
51 pub enabled: bool,
52 #[serde(default)]
53 pub routes: Vec<GatewayRoute>,
54 #[serde(default, skip_serializing_if = "Option::is_none")]
55 pub catalog_path: Option<PathBuf>,
56 #[serde(default, skip)]
57 pub catalog: Option<GatewayCatalog>,
58 #[serde(default = "default_auth_scheme")]
59 pub auth_scheme: String,
60 #[serde(default = "default_inference_path_prefix")]
61 pub inference_path_prefix: String,
62}
63
64impl Default for GatewayConfig {
65 fn default() -> Self {
66 Self {
67 enabled: false,
68 routes: Vec::new(),
69 catalog_path: None,
70 catalog: None,
71 auth_scheme: default_auth_scheme(),
72 inference_path_prefix: default_inference_path_prefix(),
73 }
74 }
75}
76
/// Default for `GatewayConfig::auth_scheme`, used by both serde and `Default`.
fn default_auth_scheme() -> String {
    String::from("bearer")
}
80
/// Default for `GatewayConfig::inference_path_prefix`, used by both serde and `Default`.
fn default_inference_path_prefix() -> String {
    String::from("/v1")
}
84
85impl GatewayConfig {
86 pub fn find_route(&self, model: &str) -> Option<&GatewayRoute> {
87 self.routes.iter().find(|route| route.matches(model))
88 }
89}
90
91#[derive(Debug, Clone, Default, Serialize, Deserialize, schemars::JsonSchema)]
92#[serde(deny_unknown_fields)]
93pub struct GatewayCatalog {
94 #[serde(default)]
95 pub providers: Vec<GatewayProvider>,
96 #[serde(default)]
97 pub models: Vec<GatewayModel>,
98}
99
100impl GatewayCatalog {
101 pub fn validate(&self) -> GatewayResult<()> {
102 for model in &self.models {
103 if model.id.is_empty() {
104 return Err(GatewayProfileError::ModelEmptyId);
105 }
106 if !self.providers.iter().any(|p| p.name == model.provider) {
107 return Err(GatewayProfileError::UnknownProvider {
108 model: model.id.clone(),
109 provider: model.provider.clone(),
110 });
111 }
112 }
113 for provider in &self.providers {
114 if provider.name.is_empty() {
115 return Err(GatewayProfileError::ProviderEmptyName);
116 }
117 if provider.endpoint.is_empty() {
118 return Err(GatewayProfileError::ProviderEmptyEndpoint {
119 name: provider.name.clone(),
120 });
121 }
122 }
123 Ok(())
124 }
125
126 pub fn find_provider(&self, name: &str) -> Option<&GatewayProvider> {
127 self.providers.iter().find(|p| p.name == name)
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
132#[serde(deny_unknown_fields)]
133pub struct GatewayProvider {
134 pub name: String,
135 pub endpoint: String,
136 pub api_key_secret: String,
137 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
138 pub extra_headers: HashMap<String, String>,
139}
140
141#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
142#[serde(deny_unknown_fields)]
143pub struct GatewayModel {
144 pub id: String,
145 pub provider: String,
146 #[serde(default, skip_serializing_if = "Option::is_none")]
147 pub display_name: Option<String>,
148 #[serde(default, skip_serializing_if = "Option::is_none")]
149 pub upstream_model: Option<String>,
150 #[serde(default, skip_serializing_if = "Option::is_none")]
151 pub pricing: Option<ModelPricing>,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
155#[serde(deny_unknown_fields)]
156pub struct GatewayRoute {
157 #[serde(default)]
158 pub id: String,
159 pub model_pattern: String,
160 pub provider: String,
161 pub endpoint: String,
162 pub api_key_secret: String,
163 #[serde(default, skip_serializing_if = "Option::is_none")]
164 pub upstream_model: Option<String>,
165 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
166 pub extra_headers: HashMap<String, String>,
167 #[serde(default, skip_serializing_if = "Option::is_none")]
168 pub pricing: Option<ModelPricing>,
169}
170
171impl GatewayRoute {
172 pub fn matches(&self, model: &str) -> bool {
173 match_pattern(&self.model_pattern, model)
174 }
175
176 pub fn effective_upstream_model<'a>(&'a self, requested: &'a str) -> &'a str {
177 self.upstream_model.as_deref().unwrap_or(requested)
178 }
179
180 pub fn ensure_id(&mut self) {
181 if self.id.trim().is_empty() {
182 self.id = synthesize_route_id(&self.model_pattern, &self.provider, &self.endpoint);
183 }
184 }
185}
186
/// Converts a model pattern into a lowercase, dash-separated slug.
///
/// * `*` becomes the literal text `star`.
/// * ASCII letters and digits are kept, lowercased.
/// * Every other character (including non-ASCII) acts as a separator: a run of
///   separators collapses into a single `-`, and separators at the start or end
///   are dropped.
/// * If nothing remains, the slug is `route`.
///
/// For example `"gpt-*"` yields `"gpt-star"` and `"Claude_3.5 Sonnet"` yields
/// `"claude-3-5-sonnet"`.
#[must_use]
pub fn slugify_pattern(pattern: &str) -> String {
    let mut slug = String::with_capacity(pattern.len());
    for ch in pattern.chars() {
        if ch == '*' {
            slug.push_str("star");
        } else if ch.is_ascii_alphanumeric() {
            slug.push(ch.to_ascii_lowercase());
        } else if !slug.is_empty() && !slug.ends_with('-') {
            // A dash is only written after real content and never twice in a
            // row, so the slug can never start with one.
            slug.push('-');
        }
    }
    // At most one trailing dash can exist because runs were collapsed above.
    if slug.ends_with('-') {
        slug.pop();
    }
    if slug.is_empty() {
        slug.push_str("route");
    }
    slug
}
222
223#[must_use]
227pub fn synthesize_route_id(model_pattern: &str, provider: &str, endpoint: &str) -> String {
228 let mut hasher = DefaultHasher::new();
229 model_pattern.hash(&mut hasher);
230 provider.hash(&mut hasher);
231 endpoint.hash(&mut hasher);
232 let h = hasher.finish();
233 let hash6: String = format!("{h:016x}").chars().take(6).collect();
234 format!("{}-{}", slugify_pattern(model_pattern), hash6)
235}
236
/// Tests `model` against a glob-style `pattern` in which `*` matches any run of
/// characters (including none) and every other character matches itself.
///
/// * No `*`: the pattern must equal the model exactly.
/// * `prefix*`, `*suffix`, `*`: prefix, suffix and match-all, as before.
/// * Several wildcards (`*claude*`, `gpt-*-mini`): the literal pieces must occur
///   in order, the first anchored at the start and the last at the end.
///   Previously such patterns were compared literally and could not match real
///   model names.
fn match_pattern(pattern: &str, model: &str) -> bool {
    let (head, rest) = match pattern.split_once('*') {
        Some(parts) => parts,
        None => return pattern == model,
    };

    // Text before the first wildcard is anchored at the start of the model.
    let mut remaining = match model.strip_prefix(head) {
        Some(after_head) => after_head,
        None => return false,
    };

    // Text after the last wildcard is anchored at the end; anything between the
    // first and last wildcard is a sequence of floating pieces.
    let (middle, tail) = match rest.rsplit_once('*') {
        Some(parts) => parts,
        None => ("", rest),
    };

    // Taking the leftmost occurrence of each piece leaves the most room for the
    // pieces that follow, which is sufficient for `*`-only globs.
    for piece in middle.split('*').filter(|piece| !piece.is_empty()) {
        match remaining.find(piece) {
            Some(at) => remaining = &remaining[at + piece.len()..],
            None => return false,
        }
    }

    remaining.ends_with(tail)
}