use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Token pricing, expressed as a cost per one million tokens.
///
/// Consumed by `cost`, which multiplies token counts by these rates and
/// divides by 1,000,000. NOTE(review): the currency is not fixed by this
/// type — presumably USD; confirm against the pricing source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PricingConfig {
/// Cost of one million prompt tokens.
pub prompt_cost_per_million: f64,
/// Cost of one million completion tokens.
pub completion_cost_per_million: f64,
}
/// Returns the cost of a request with the given prompt and completion token
/// counts, using the per-million-token rates in `pricing`.
///
/// The result is in the same currency unit as the rates in `pricing`.
pub fn cost(pricing: &PricingConfig, prompt_tokens: u32, completion_tokens: u32) -> f64 {
    const TOKENS_PER_PRICE_UNIT: f64 = 1_000_000.0;
    // `f64::from(u32)` is lossless, identical to the `as f64` widening.
    let prompt_part = f64::from(prompt_tokens) * pricing.prompt_cost_per_million;
    let completion_part = f64::from(completion_tokens) * pricing.completion_cost_per_million;
    (prompt_part + completion_part) / TOKENS_PER_PRICE_UNIT
}
/// Top-level gateway configuration, as deserialized from the config file.
///
/// Only `listen` and `providers` are required; every other field has a
/// serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
/// Address the gateway listens on.
/// NOTE(review): presumably a `host:port` socket address — confirm with the server setup.
pub listen: String,
/// Upstream provider definitions.
/// NOTE(review): keys look like provider names (the `provider_name` passed to
/// `ProviderConfig::validate`) — confirm with the caller.
pub providers: HashMap<String, ProviderConfig>,
/// Key entries (see `KeyConfig`); empty when omitted.
#[serde(default)]
pub keys: Vec<KeyConfig>,
/// Extension settings kept as an untyped `serde_json::Value`; `None` when omitted.
#[serde(default)]
pub extensions: Option<serde_json::Value>,
/// Storage backend settings; `None` when omitted.
#[serde(default)]
pub storage: Option<StorageConfig>,
/// String-to-string alias table; empty when omitted.
/// NOTE(review): presumably maps a model alias to a concrete model name — confirm.
#[serde(default)]
pub aliases: HashMap<String, String>,
/// Pricing table; empty when omitted.
/// NOTE(review): presumably keyed by model name — confirm.
#[serde(default)]
pub pricing: HashMap<String, PricingConfig>,
/// Optional admin token. Omitted from serialized output when `None`, but
/// serialized in plain text when set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub admin_token: Option<String>,
/// Shutdown timeout; defaults to 30 via `default_shutdown_timeout`.
/// NOTE(review): the unit is not visible here — presumably seconds; confirm.
#[serde(default = "default_shutdown_timeout")]
pub shutdown_timeout: u64,
}
/// Settings for one upstream provider.
///
/// Which fields matter depends on `kind`; `ProviderConfig::validate` checks the
/// per-kind requirements. Every field has a serde default, and `Default` is
/// derived.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
/// Provider family. Defaults to `ProviderKind::Openai` and is omitted from
/// serialized output when it equals that default.
/// NOTE(review): `alias = "standard"` is a *field* alias, so a key named
/// `standard` also populates `kind`; confirm this is intended.
#[serde(
default,
alias = "standard",
skip_serializing_if = "ProviderKind::is_default"
)]
pub kind: ProviderKind,
/// API key; serialized in plain text when set.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
/// Base URL override. A URL containing "anthropic" also makes
/// `effective_kind` report `ProviderKind::Anthropic`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub base_url: Option<String>,
/// Models served by this provider; must be non-empty for `validate` to pass.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub models: Vec<String>,
/// NOTE(review): presumably a load-balancing weight — confirm with the router.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub weight: Option<u16>,
/// NOTE(review): presumably a per-request retry limit — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_retries: Option<u32>,
/// API version string. NOTE(review): presumably used by the Azure backend — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub api_version: Option<String>,
/// Timeout override. NOTE(review): the unit is not visible here — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timeout: Option<u64>,
/// Region; required when `kind` is `Bedrock`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub region: Option<String>,
/// Access key; required when `kind` is `Bedrock`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub access_key: Option<String>,
/// Secret key; required when `kind` is `Bedrock`. Never serialized
/// (`skip_serializing`), so a serialize/deserialize round trip drops it.
#[serde(default, skip_serializing)]
pub secret_key: Option<String>,
/// Path to a local model file; required when `kind` is `LlamaCpp`, and it
/// must exist on disk.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_path: Option<String>,
/// NOTE(review): presumably the number of layers offloaded to GPU by the
/// LlamaCpp backend — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub n_gpu_layers: Option<u32>,
/// NOTE(review): presumably the LlamaCpp context size — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub n_ctx: Option<u32>,
/// NOTE(review): presumably the LlamaCpp thread count — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub n_threads: Option<u32>,
}
/// Serde default for `GatewayConfig::shutdown_timeout`, used when the key is
/// absent from the config.
///
/// NOTE(review): the unit is decided by whoever consumes the value —
/// presumably seconds; confirm.
fn default_shutdown_timeout() -> u64 {
    const DEFAULT_SHUTDOWN_TIMEOUT: u64 = 30;
    DEFAULT_SHUTDOWN_TIMEOUT
}
/// Backend family a provider speaks.
///
/// Serialized in snake_case: `openai`, `anthropic`, `google`, `bedrock`,
/// `ollama`, `azure`, `llama_cpp`. `Openai` is the default variant.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ProviderKind {
    #[default]
    Openai,
    Anthropic,
    Google,
    Bedrock,
    Ollama,
    Azure,
    // `rename_all = "snake_case"` already names this variant `llama_cpp`, so an
    // alias of `llama_cpp` would be a no-op. The alias adds the unseparated
    // spelling `llamacpp`, which is the name used in validation error messages.
    #[serde(alias = "llamacpp")]
    LlamaCpp,
}
impl ProviderKind {
    /// Returns `true` only for the default kind, [`ProviderKind::Openai`].
    ///
    /// Serves as the `skip_serializing_if` predicate for
    /// `ProviderConfig::kind`, so the default kind is left out of serialized
    /// output.
    pub fn is_default(&self) -> bool {
        matches!(self, Self::Openai)
    }
}
impl ProviderConfig {
    /// Returns the protocol family to use for this provider.
    ///
    /// A `base_url` containing the substring `"anthropic"` forces
    /// [`ProviderKind::Anthropic`], whatever `kind` says. Otherwise the
    /// configured `kind` is returned unchanged, so an explicit `Anthropic`
    /// kind is honoured either way.
    pub fn effective_kind(&self) -> ProviderKind {
        let url_is_anthropic = self
            .base_url
            .as_deref()
            .is_some_and(|url| url.contains("anthropic"));
        if url_is_anthropic {
            ProviderKind::Anthropic
        } else {
            self.kind
        }
    }

    /// Checks that this provider has the settings its configured `kind` needs.
    ///
    /// `provider_name` is only used to build the error message.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first problem found:
    /// * no `models` configured (checked for every kind);
    /// * `Bedrock`: missing `region`, `access_key` or `secret_key` (reported in
    ///   that order);
    /// * `LlamaCpp`: missing `model_path`, or a `model_path` that does not
    ///   exist on disk;
    /// * `Openai`, `Anthropic`, `Google`, `Azure`: neither `api_key` nor
    ///   `base_url` set.
    ///
    /// `Ollama` has no kind-specific requirements here. Only presence is
    /// checked (empty strings pass), and the configured `kind` is used rather
    /// than [`Self::effective_kind`].
    pub fn validate(&self, provider_name: &str) -> Result<(), String> {
        if self.models.is_empty() {
            return Err(format!("provider '{provider_name}' has no models"));
        }
        // Every variant is listed explicitly (no `_` arm), so adding a new
        // `ProviderKind` forces a decision about its requirements here.
        match self.kind {
            ProviderKind::Bedrock => {
                for (field, present) in [
                    ("region", self.region.is_some()),
                    ("access_key", self.access_key.is_some()),
                    ("secret_key", self.secret_key.is_some()),
                ] {
                    if !present {
                        return Err(format!(
                            "provider '{provider_name}' (bedrock) requires {field}"
                        ));
                    }
                }
            }
            ProviderKind::Ollama => {}
            ProviderKind::LlamaCpp => {
                let Some(path) = self.model_path.as_deref() else {
                    return Err(format!(
                        "provider '{provider_name}' (llamacpp) requires model_path"
                    ));
                };
                // `Path::exists` also returns false when the path cannot be
                // inspected (e.g. permission denied); that is reported the same way.
                if !std::path::Path::new(path).exists() {
                    return Err(format!(
                        "provider '{provider_name}' (llamacpp): model_path '{path}' does not exist"
                    ));
                }
            }
            ProviderKind::Openai
            | ProviderKind::Anthropic
            | ProviderKind::Google
            | ProviderKind::Azure => {
                if self.api_key.is_none() && self.base_url.is_none() {
                    return Err(format!(
                        "provider '{provider_name}' requires api_key or base_url"
                    ));
                }
            }
        }
        Ok(())
    }
}
/// One entry of the gateway config's `keys` list.
///
/// All three fields are required when deserializing (no serde defaults).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyConfig {
/// Label for the key.
/// NOTE(review): presumably used only for identification/logging — confirm.
pub name: String,
/// The key value itself; serialized in plain text.
pub key: String,
/// NOTE(review): presumably the models this key may access — confirm, including
/// what an empty list means.
pub models: Vec<String>,
}
/// Storage backend selection for the gateway.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageConfig {
/// Backend kind as a free-form string; defaults to `"memory"` when omitted.
/// NOTE(review): accepted values are decided by the consumer and are not checked here.
#[serde(default = "StorageConfig::default_kind")]
pub kind: String,
/// Optional storage path; `None` when omitted.
/// NOTE(review): presumably only meaningful for file-backed kinds — confirm.
#[serde(default)]
pub path: Option<String>,
}
impl StorageConfig {
    /// Serde default for [`StorageConfig::kind`]: the `"memory"` backend.
    fn default_kind() -> String {
        String::from("memory")
    }
}
impl GatewayConfig {
    /// Reads and parses a gateway configuration from the TOML file at `path`.
    ///
    /// `${VAR}` placeholders in the file text are expanded from the process
    /// environment (via `expand_env_vars`) before parsing. Provider settings
    /// are not validated here; call `ProviderConfig::validate` separately.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be read as UTF-8 text, or the
    /// TOML deserialization error if the expanded text is not a valid config.
    #[cfg(feature = "gateway")]
    pub fn from_file(path: &std::path::Path) -> Result<Self, Box<dyn std::error::Error>> {
        let text = expand_env_vars(&std::fs::read_to_string(path)?);
        Ok(toml::from_str::<Self>(&text)?)
    }
}
/// Replaces every `${NAME}` placeholder in `input` with the value of the
/// environment variable `NAME`.
///
/// * A placeholder whose variable is unset (or whose value is not valid
///   Unicode) expands to the empty string.
/// * A `$` that is not followed by `{` is copied through unchanged.
/// * An unterminated `${...` (no closing `}`) is copied through verbatim
///   instead of swallowing the rest of the input.
///
/// Substituted values are inserted as-is, with no escaping.
#[cfg(feature = "gateway")]
fn expand_env_vars(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut rest = input;
    while let Some(start) = rest.find("${") {
        result.push_str(&rest[..start]);
        // `${` and `}` are ASCII, so these byte offsets are always char boundaries.
        let after_open = &rest[start + 2..];
        match after_open.find('}') {
            Some(end) => {
                if let Ok(val) = std::env::var(&after_open[..end]) {
                    result.push_str(&val);
                }
                rest = &after_open[end + 1..];
            }
            None => {
                // No closing brace: keep the text literally rather than dropping it.
                result.push_str(&rest[start..]);
                rest = "";
            }
        }
    }
    result.push_str(rest);
    result
}