use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Pricing rates for a model, split into a prompt rate and a completion rate.
///
/// NOTE(review): the "per million" unit is presumably tokens and the currency
/// is not stated anywhere in this file — confirm with the code that consumes
/// these values.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct PricingConfig {
/// Cost charged per million units of prompt usage.
pub prompt_cost_per_million: f64,
/// Cost charged per million units of completion usage.
pub completion_cost_per_million: f64,
}
/// Top-level gateway configuration.
///
/// Every field has a serde default, so an empty document deserializes
/// successfully. `GatewayConfig::from_file` (behind the `gateway` feature)
/// reads this structure from a TOML file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
/// Listen address; falls back to `default_listen()` (`"127.0.0.1:5632"`)
/// when omitted.
#[serde(default = "default_listen")]
pub listen: String,
/// Provider definitions keyed by a string name; empty when omitted.
#[serde(default)]
pub providers: HashMap<String, ProviderConfig>,
/// Configured keys; empty when omitted.
#[serde(default)]
pub keys: Vec<KeyConfig>,
/// Arbitrary JSON value that this file stores without interpreting.
#[serde(default)]
pub extensions: Option<serde_json::Value>,
/// Optional storage settings.
#[serde(default)]
pub storage: Option<StorageConfig>,
/// String-to-string alias table.
/// NOTE(review): presumably maps an alias to a model name — confirm.
#[serde(default)]
pub aliases: HashMap<String, String>,
/// Model metadata keyed by model name. Entries loaded from the
/// `cloud_models` file are added only for names not already present here.
#[serde(default)]
pub models: HashMap<String, crate::ModelInfo>,
/// Path (joined onto the config file's directory) of a TOML file with
/// additional model metadata; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub cloud_models: Option<String>,
/// Path (joined onto a caller-supplied directory) of a TOML file read by
/// `load_local_models`; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub local_models: Option<String>,
/// Optional admin token; omitted from serialized output when `None`.
/// NOTE(review): how the token is checked is not visible in this file.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub admin_token: Option<String>,
/// Shutdown timeout; falls back to `default_shutdown_timeout()` (30).
/// NOTE(review): unit is presumably seconds — confirm at the use site.
#[serde(default = "default_shutdown_timeout")]
pub shutdown_timeout: u64,
/// OpenAPI flag; falls back to `default_openapi()` (`true`).
#[serde(default = "default_openapi")]
pub openapi: bool,
}
/// Settings for a single provider.
///
/// All fields are optional on input. Which of them are required depends on
/// `kind` and is enforced by `ProviderConfig::validate`.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
/// Provider kind; defaults to `ProviderKind::Openai` and is omitted from
/// serialized output while it has that default value.
#[serde(default, skip_serializing_if = "ProviderKind::is_default")]
pub kind: ProviderKind,
/// API key. For kinds other than bedrock, ollama and custom, `validate`
/// requires this or `base_url`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
/// Base URL. Required by `validate` for custom kinds; a URL containing
/// `"anthropic"` makes `effective_kind` report `ProviderKind::Anthropic`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub base_url: Option<String>,
/// Model names for this provider; `validate` rejects an empty list.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub models: Vec<String>,
/// Optional weight.
/// NOTE(review): presumably used for weighted routing — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub weight: Option<u16>,
/// Optional retry limit.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_retries: Option<u32>,
/// Optional API version string.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub api_version: Option<String>,
/// Optional timeout.
/// NOTE(review): unit is not stated in this file — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timeout: Option<u64>,
/// Region; required by `validate` for the bedrock kind.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub region: Option<String>,
/// Access key; required by `validate` for the bedrock kind.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub access_key: Option<String>,
/// Secret key; required by `validate` for the bedrock kind. It is read on
/// deserialization but never written on serialization.
#[serde(default, skip_serializing)]
pub secret_key: Option<String>,
}
/// Serde default for `GatewayConfig::shutdown_timeout`.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_SHUTDOWN_TIMEOUT: u64 = 30;
    DEFAULT_SHUTDOWN_TIMEOUT
}
/// Serde default for `GatewayConfig::openapi`: enabled unless the config says otherwise.
fn default_openapi() -> bool {
    const OPENAPI_ENABLED_BY_DEFAULT: bool = true;
    OPENAPI_ENABLED_BY_DEFAULT
}
/// Serde default for `GatewayConfig::listen`.
fn default_listen() -> String {
    String::from("127.0.0.1:5632")
}
/// Identifies the kind of a configured provider.
///
/// The serde impls in this file read and write it as a plain string (see
/// `ProviderKind::as_str`); any string that is not one of the built-in names
/// deserializes to `Custom`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum ProviderKind {
/// String form `"openai"`; this is the default kind.
#[default]
Openai,
/// String form `"anthropic"`.
Anthropic,
/// String form `"google"`.
Google,
/// String form `"bedrock"`.
Bedrock,
/// String form `"ollama"`.
Ollama,
/// String form `"azure"`.
Azure,
/// Any other kind; carries the string exactly as it was configured.
Custom(String),
}
impl ProviderKind {
    /// Returns the string form of this kind.
    ///
    /// Built-in kinds map to a fixed lowercase name; `Custom` yields the
    /// string it carries, unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            Self::Custom(name) => name.as_str(),
            Self::Openai => "openai",
            Self::Anthropic => "anthropic",
            Self::Google => "google",
            Self::Bedrock => "bedrock",
            Self::Ollama => "ollama",
            Self::Azure => "azure",
        }
    }

    /// Returns `true` only for `Openai`, the default kind.
    ///
    /// Used as a `skip_serializing_if` predicate so that the default kind is
    /// left out of serialized provider configs.
    pub fn is_default(&self) -> bool {
        *self == Self::Openai
    }
}
/// Displays a kind as its string form (see `ProviderKind::as_str`).
impl std::fmt::Display for ProviderKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatter width/fill flags are intentionally not applied.
        write!(f, "{}", self.as_str())
    }
}
/// Serializes a kind as a plain string, using `ProviderKind::as_str`.
impl serde::Serialize for ProviderKind {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let name = self.as_str();
        serializer.serialize_str(name)
    }
}
/// Deserializes a kind from a string.
///
/// The built-in names are matched exactly (case-sensitive); every other
/// string is kept verbatim inside `ProviderKind::Custom`.
impl<'de> serde::Deserialize<'de> for ProviderKind {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let raw = String::deserialize(deserializer)?;
        let kind = match raw.as_str() {
            "openai" => Self::Openai,
            "anthropic" => Self::Anthropic,
            "google" => Self::Google,
            "bedrock" => Self::Bedrock,
            "ollama" => Self::Ollama,
            "azure" => Self::Azure,
            // Unknown names are preserved so they round-trip unchanged.
            _ => Self::Custom(raw),
        };
        Ok(kind)
    }
}
impl ProviderConfig {
    /// Returns the kind this provider should be treated as.
    ///
    /// A `base_url` containing the substring `"anthropic"` makes the provider
    /// count as `ProviderKind::Anthropic` regardless of the configured
    /// `kind`; otherwise the configured `kind` is returned as-is.
    pub fn effective_kind(&self) -> ProviderKind {
        let targets_anthropic = self
            .base_url
            .as_deref()
            .is_some_and(|url| url.contains("anthropic"));
        if targets_anthropic {
            ProviderKind::Anthropic
        } else {
            // An explicit `kind = "anthropic"` is covered here as well.
            self.kind.clone()
        }
    }

    /// Checks that the fields required by this provider's `kind` are set.
    ///
    /// A field counts as set only when it holds non-blank text. Absent,
    /// empty and whitespace-only values are all treated as missing, because
    /// an unset `${VAR}` placeholder in the config file expands to an empty
    /// string and must not satisfy a requirement.
    ///
    /// Rules:
    /// - every provider needs at least one entry in `models`;
    /// - bedrock needs `region`, `access_key` and `secret_key`;
    /// - ollama needs nothing else;
    /// - custom kinds need `base_url`;
    /// - all other kinds need `api_key` or `base_url`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message naming `provider_name` and the first
    /// requirement that is not met.
    pub fn validate(&self, provider_name: &str) -> Result<(), String> {
        /// True when the optional value holds non-blank text.
        fn is_set(value: Option<&str>) -> bool {
            value.is_some_and(|text| !text.trim().is_empty())
        }

        if self.models.is_empty() {
            return Err(format!("provider '{provider_name}' has no models"));
        }

        // The match is exhaustive on purpose: a newly added kind must state
        // its own requirements instead of silently inheriting a catch-all.
        match &self.kind {
            ProviderKind::Bedrock => {
                let required = [
                    ("region", self.region.as_deref()),
                    ("access_key", self.access_key.as_deref()),
                    ("secret_key", self.secret_key.as_deref()),
                ];
                for (field, value) in required {
                    if !is_set(value) {
                        return Err(format!(
                            "provider '{provider_name}' (bedrock) requires {field}"
                        ));
                    }
                }
            }
            ProviderKind::Ollama => {}
            ProviderKind::Custom(name) => {
                if !is_set(self.base_url.as_deref()) {
                    return Err(format!(
                        "provider '{provider_name}' (custom kind '{name}') requires base_url"
                    ));
                }
            }
            ProviderKind::Openai
            | ProviderKind::Anthropic
            | ProviderKind::Google
            | ProviderKind::Azure => {
                let has_api_key = is_set(self.api_key.as_deref());
                let has_base_url = is_set(self.base_url.as_deref());
                if !has_api_key && !has_base_url {
                    return Err(format!(
                        "provider '{provider_name}' requires api_key or base_url"
                    ));
                }
            }
        }
        Ok(())
    }
}
/// Optional per-key rate limits.
///
/// Both limits are independent and optional; a limit that is `None` is left
/// out of serialized output.
/// NOTE(review): enforcement is not visible in this file — `None` presumably
/// means "unlimited"; confirm at the use site.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct KeyRateLimit {
/// Maximum requests per minute, if limited.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub requests_per_minute: Option<u64>,
/// Maximum tokens per minute, if limited.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tokens_per_minute: Option<u64>,
}
/// One entry of `GatewayConfig::keys`.
///
/// `name`, `key` and `models` have no serde default, so all three must be
/// present when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyConfig {
/// Name of this key entry.
pub name: String,
/// The key value itself.
pub key: String,
/// Model names associated with this key.
/// NOTE(review): presumably the models the key may access, and the meaning
/// of an empty list is not visible here — confirm.
pub models: Vec<String>,
/// Optional rate limit for this key; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rate_limit: Option<KeyRateLimit>,
}
/// Storage backend settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
/// Backend kind as a free-form string; defaults to `"memory"` when omitted.
/// NOTE(review): the set of accepted values is not visible in this file.
#[serde(default = "StorageConfig::default_kind")]
pub kind: String,
/// Optional path.
/// NOTE(review): presumably a file location for persistent backends — confirm.
#[serde(default)]
pub path: Option<String>,
}
impl StorageConfig {
    /// Serde default for `StorageConfig::kind`.
    fn default_kind() -> String {
        String::from("memory")
    }
}
/// One model entry read from the `local_models` file.
///
/// Only `repo_id` is mandatory; the other fields default to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct LocalModelEntry {
/// Repository identifier of the model.
/// NOTE(review): the hosting service this refers to is not stated here.
pub repo_id: String,
/// Model size in megabytes, if given.
#[serde(default)]
pub size_mb: Option<u64>,
/// Vision flag, if given.
/// NOTE(review): presumably marks vision-capable models — confirm.
#[serde(default)]
pub vision: Option<bool>,
/// Architecture name, if given.
#[serde(default)]
pub arch: Option<String>,
}
/// On-disk shape of the `local_models` TOML file.
///
/// Only compiled with the `gateway` feature.
#[cfg(feature = "gateway")]
#[derive(Deserialize)]
struct LocalModelsFile {
/// Entries nested three levels deep: family, then size, then quant.
/// `GatewayConfig::load_local_models` flattens them into
/// `"{family}.{size}.{quant}"` aliases. Empty when the table is omitted.
#[serde(default)]
models: HashMap<String, HashMap<String, HashMap<String, LocalModelEntry>>>,
}
impl GatewayConfig {
    /// Loads a gateway configuration from the TOML file at `path`.
    ///
    /// `${VAR}` placeholders in the file are expanded from the environment
    /// before parsing. If the config names a `cloud_models` file, it is
    /// resolved against the directory containing `path` and merged in.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be read, when it is not valid TOML for
    /// this structure, or when the `cloud_models` file cannot be loaded.
    #[cfg(feature = "gateway")]
    pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
        let text = expand_env_vars(&std::fs::read_to_string(path)?);
        let mut parsed: GatewayConfig = toml::from_str(&text)?;
        // Relative model files are resolved next to the config file itself.
        let base_dir = match path.parent() {
            Some(dir) => dir,
            None => std::path::Path::new("."),
        };
        parsed.load_cloud_models(base_dir)?;
        Ok(parsed)
    }

    /// Merges the `cloud_models` TOML file, if configured, into `self.models`.
    ///
    /// Models already present in `self.models` are kept as they are; the
    /// file only contributes entries for names that are still missing.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be read or parsed; the message is prefixed
    /// with `cloud_models '<path>'`.
    #[cfg(feature = "gateway")]
    fn load_cloud_models(
        &mut self,
        config_dir: &std::path::Path,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let Some(relative) = self.cloud_models.as_deref() else {
            return Ok(());
        };
        let location = config_dir.join(relative);
        let describe = |err: &dyn std::fmt::Display| {
            format!("cloud_models '{}': {err}", location.display())
        };
        let contents = std::fs::read_to_string(&location).map_err(|e| describe(&e))?;
        let catalog: HashMap<String, crate::ModelInfo> =
            toml::from_str(&contents).map_err(|e| describe(&e))?;
        for (name, info) in catalog {
            self.models.entry(name).or_insert(info);
        }
        Ok(())
    }

    /// Reads the `local_models` TOML file, if configured, and flattens it.
    ///
    /// The file nests entries as family → size → quant; each entry is
    /// returned under the alias `"{family}.{size}.{quant}"`. When no
    /// `local_models` path is configured, an empty map is returned.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be read or parsed; the message is prefixed
    /// with `local_models '<path>'`.
    #[cfg(feature = "gateway")]
    pub fn load_local_models(
        &self,
        config_dir: &std::path::Path,
    ) -> Result<HashMap<String, LocalModelEntry>, Box<dyn std::error::Error>> {
        let Some(relative) = self.local_models.as_deref() else {
            return Ok(HashMap::new());
        };
        let location = config_dir.join(relative);
        let describe = |err: &dyn std::fmt::Display| {
            format!("local_models '{}': {err}", location.display())
        };
        let contents = std::fs::read_to_string(&location).map_err(|e| describe(&e))?;
        let parsed: LocalModelsFile = toml::from_str(&contents).map_err(|e| describe(&e))?;

        let mut flattened = HashMap::new();
        for (family, sizes) in parsed.models {
            for (size, quants) in sizes {
                flattened.extend(
                    quants
                        .into_iter()
                        .map(|(quant, entry)| (format!("{family}.{size}.{quant}"), entry)),
                );
            }
        }
        Ok(flattened)
    }
}
#[cfg(feature = "gateway")]
/// Replaces every `${NAME}` placeholder in `input` with the value of the
/// environment variable `NAME`.
///
/// - A placeholder whose variable is unset (or not valid Unicode) expands to
///   the empty string.
/// - A `${` with no closing `}` is not a placeholder: it and everything after
///   it are copied through verbatim, so a stray `${` cannot silently swallow
///   the rest of the configuration text.
/// - A `$` that is not followed by `{` is copied through unchanged.
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut rest = input;

    while let Some(start) = rest.find("${") {
        // Text before the placeholder is literal.
        result.push_str(&rest[..start]);
        let after_open = &rest[start + 2..];

        let Some(end) = after_open.find('}') else {
            // Unterminated placeholder: keep the tail exactly as written.
            result.push_str(&rest[start..]);
            return result;
        };

        // Best effort: a missing variable contributes nothing.
        if let Ok(value) = std::env::var(&after_open[..end]) {
            result.push_str(&value);
        }
        rest = &after_open[end + 1..];
    }

    result.push_str(rest);
    result
}