Skip to main content

rolter_core/
config.rs

1use serde::{Deserialize, Serialize};
2use std::path::Path;
3
4use crate::error::Result;
5
6/// Root bootstrap configuration loaded from a TOML file or the database.
7#[derive(Debug, Clone, Default, Deserialize, Serialize)]
8pub struct GatewayConfig {
9    #[serde(default)]
10    pub server: ServerConfig,
11    #[serde(default)]
12    pub providers: Vec<ProviderConfig>,
13    #[serde(default)]
14    pub routes: Vec<ModelRoute>,
15    #[serde(default)]
16    pub virtual_keys: Vec<VirtualKeyConfig>,
17    #[serde(default)]
18    pub logging: LoggingConfig,
19}
20
21/// Listener configuration for a rolter process.
22#[derive(Debug, Clone, Deserialize, Serialize)]
23pub struct ServerConfig {
24    #[serde(default = "default_host")]
25    pub host: String,
26    #[serde(default = "default_port")]
27    pub port: u16,
28}
29
30impl Default for ServerConfig {
31    fn default() -> Self {
32        Self {
33            host: default_host(),
34            port: default_port(),
35        }
36    }
37}
38
/// Default listen host: the IPv4 wildcard address.
fn default_host() -> String {
    String::from("0.0.0.0")
}
42
/// Default listen port used when the config omits `server.port`.
fn default_port() -> u16 {
    const PORT: u16 = 4000;
    PORT
}
46
47/// The wire protocol a provider speaks.
48#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize)]
49#[serde(rename_all = "snake_case")]
50pub enum ProviderKind {
51    /// native openai chat/completions api
52    Openai,
53    /// native anthropic messages api
54    Anthropic,
55    /// any openai-compatible endpoint such as vllm, tgi or ollama
56    OpenaiCompatible,
57}
58
59/// An upstream provider rolter can forward to.
60#[derive(Debug, Clone, Deserialize, Serialize)]
61pub struct ProviderConfig {
62    pub name: String,
63    pub kind: ProviderKind,
64    /// base url without a trailing slash, e.g. `https://api.openai.com`
65    pub api_base: String,
66    /// inline api key; prefer `api_key_env` so secrets stay out of config files
67    #[serde(default)]
68    pub api_key: Option<String>,
69    /// name of an environment variable to read the api key from
70    #[serde(default)]
71    pub api_key_env: Option<String>,
72    /// optional outbound egress proxy url (http/https/socks5)
73    #[serde(default)]
74    pub egress_proxy: Option<String>,
75}
76
77impl ProviderConfig {
78    /// Resolve the effective api key, preferring the inline value then the env var.
79    pub fn resolve_api_key(&self) -> Option<String> {
80        if let Some(k) = &self.api_key {
81            return Some(k.clone());
82        }
83        self.api_key_env
84            .as_ref()
85            .and_then(|e| std::env::var(e).ok())
86    }
87}
88
89/// Load-balancing strategy applied to a route's targets.
90#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize, Serialize)]
91#[serde(rename_all = "snake_case")]
92pub enum BalancingStrategy {
93    #[default]
94    RoundRobin,
95    Random,
96    PowerOfTwo,
97    ConsistentHash,
98    CacheAware,
99}
100
101/// A single upstream target within a route.
102#[derive(Debug, Clone, Deserialize, Serialize)]
103pub struct Target {
104    /// name of the [`ProviderConfig`] this target forwards to
105    pub provider: String,
106    /// upstream model id; if absent the requested model name is forwarded as-is
107    #[serde(default)]
108    pub model: Option<String>,
109    #[serde(default = "default_weight")]
110    pub weight: u32,
111}
112
/// Weight assigned to a [`Target`] whose config omits `weight`.
fn default_weight() -> u32 {
    1_u32
}
116
117/// Maps a public model name to one or more upstream targets plus a strategy.
118#[derive(Debug, Clone, Deserialize, Serialize)]
119pub struct ModelRoute {
120    /// public model name clients request, e.g. `gpt-4o`
121    pub model: String,
122    #[serde(default)]
123    pub strategy: BalancingStrategy,
124    pub targets: Vec<Target>,
125}
126
127/// A virtual api key that clients present to the gateway.
128#[derive(Debug, Clone, Deserialize, Serialize)]
129pub struct VirtualKeyConfig {
130    pub key: String,
131    #[serde(default)]
132    pub name: Option<String>,
133    /// allowed public model names; empty means all models are allowed
134    #[serde(default)]
135    pub models: Vec<String>,
136}
137
138/// Where request and cost logs are written.
139#[derive(Debug, Clone, Default, Deserialize, Serialize)]
140pub struct LoggingConfig {
141    #[serde(default)]
142    pub clickhouse_url: Option<String>,
143}
144
145impl GatewayConfig {
146    /// Parse a configuration from a TOML string.
147    pub fn from_toml_str(s: &str) -> Result<Self> {
148        Ok(toml::from_str(s)?)
149    }
150
151    /// Load a configuration from a TOML file on disk.
152    pub fn load(path: &Path) -> Result<Self> {
153        let raw = std::fs::read_to_string(path)?;
154        Self::from_toml_str(&raw)
155    }
156
157    /// Find a provider by name.
158    pub fn resolve_provider(&self, name: &str) -> Option<&ProviderConfig> {
159        self.providers.iter().find(|p| p.name == name)
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_minimal_config() {
        let cfg = GatewayConfig::from_toml_str(
            r#"
            [[providers]]
            name = "openai"
            kind = "openai"
            api_base = "https://api.openai.com"

            [[routes]]
            model = "gpt-4o"
            strategy = "round_robin"
            [[routes.targets]]
            provider = "openai"
            "#,
        )
        .unwrap();
        assert_eq!(cfg.server.port, 4000);
        assert_eq!(cfg.providers.len(), 1);
        assert_eq!(cfg.routes[0].strategy, BalancingStrategy::RoundRobin);
        assert_eq!(cfg.routes[0].targets[0].weight, 1);
    }

    #[test]
    fn empty_document_yields_defaults() {
        let cfg = GatewayConfig::from_toml_str("").unwrap();
        assert_eq!(cfg.server.host, "0.0.0.0");
        assert_eq!(cfg.server.port, 4000);
        assert!(cfg.providers.is_empty());
        assert!(cfg.routes.is_empty());
        assert!(cfg.virtual_keys.is_empty());
        assert!(cfg.logging.clickhouse_url.is_none());
    }

    #[test]
    fn parses_snake_case_enums_and_explicit_fields() {
        let cfg = GatewayConfig::from_toml_str(
            r#"
            [server]
            port = 8080

            [[providers]]
            name = "local"
            kind = "openai_compatible"
            api_base = "http://localhost:8000"
            api_key_env = "LOCAL_KEY"

            [[routes]]
            model = "llama"
            strategy = "power_of_two"
            [[routes.targets]]
            provider = "local"
            model = "meta-llama/Llama-3-8B"
            weight = 3

            [[virtual_keys]]
            key = "sk-test"
            models = ["llama"]
            "#,
        )
        .unwrap();
        // A partially specified [server] keeps the default for the missing field.
        assert_eq!(cfg.server.host, "0.0.0.0");
        assert_eq!(cfg.server.port, 8080);
        assert_eq!(cfg.providers[0].kind, ProviderKind::OpenaiCompatible);
        assert_eq!(cfg.providers[0].api_key_env.as_deref(), Some("LOCAL_KEY"));
        let route = &cfg.routes[0];
        assert_eq!(route.strategy, BalancingStrategy::PowerOfTwo);
        assert_eq!(route.targets[0].model.as_deref(), Some("meta-llama/Llama-3-8B"));
        assert_eq!(route.targets[0].weight, 3);
        assert_eq!(cfg.virtual_keys[0].models, vec!["llama".to_string()]);
        assert!(cfg.virtual_keys[0].name.is_none());
    }

    #[test]
    fn resolve_provider_matches_by_name() {
        let cfg = GatewayConfig::from_toml_str(
            r#"
            [[providers]]
            name = "primary"
            kind = "openai"
            api_base = "https://api.openai.com"

            [[providers]]
            name = "claude"
            kind = "anthropic"
            api_base = "https://api.anthropic.com"
            "#,
        )
        .unwrap();
        assert_eq!(
            cfg.resolve_provider("claude").map(|p| p.kind),
            Some(ProviderKind::Anthropic)
        );
        assert!(cfg.resolve_provider("missing").is_none());
    }

    #[test]
    fn rejects_unknown_provider_kind() {
        let parsed = GatewayConfig::from_toml_str(
            r#"
            [[providers]]
            name = "x"
            kind = "azure"
            api_base = "https://example.com"
            "#,
        );
        assert!(parsed.is_err());
    }

    #[test]
    fn inline_api_key_takes_precedence_and_unset_env_yields_none() {
        // The env var name is deliberately one that is never set, so the test
        // does not need to mutate the process environment.
        let cfg = GatewayConfig::from_toml_str(
            r#"
            [[providers]]
            name = "inline"
            kind = "openai"
            api_base = "https://api.openai.com"
            api_key = "sk-inline"
            api_key_env = "ROLTER_CORE_TEST_UNSET_VARIABLE"

            [[providers]]
            name = "env_only"
            kind = "openai"
            api_base = "https://api.openai.com"
            api_key_env = "ROLTER_CORE_TEST_UNSET_VARIABLE"
            "#,
        )
        .unwrap();
        assert_eq!(cfg.providers[0].resolve_api_key().as_deref(), Some("sk-inline"));
        assert_eq!(cfg.providers[1].resolve_api_key(), None);
    }
}