pub mod cli;
pub mod loader;
pub mod paths;
pub mod validation;
use crate::error::{ProxyError, Result};
use crate::provider::{AuthStrategy, LlmProviderBackend, LlmProviderConfig};
use serde::{Deserialize, Serialize};
/// Top-level ModelMux configuration.
///
/// The `server`, `auth`, and `streaming` sections fall back to their `Default`
/// implementations when they are absent from the input, so a partial config
/// file (for example one containing only `[auth]`) still deserializes.
/// `llm_provider` is never read from or written to serialized config; it is
/// filled in at runtime by [`Config::load`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// HTTP server settings (`[server]`).
    #[serde(default)]
    pub server: ServerConfig,
    /// Service-account credential settings (`[auth]`).
    #[serde(default)]
    pub auth: AuthConfig,
    /// Streaming behaviour (`[streaming]`).
    #[serde(default)]
    pub streaming: StreamingConfig,
    /// Optional Vertex AI provider section (`[vertex]`).
    #[serde(default)]
    pub vertex: Option<VertexConfig>,
    /// Resolved LLM provider; skipped by serde and populated by [`Config::load`].
    #[serde(skip)]
    pub llm_provider: Option<LlmProviderConfig>,
}
/// Vertex AI provider settings read from the optional `[vertex]` section.
///
/// These values are handed to `LlmProviderConfig::from_config_or_env_with_key`,
/// which decides how they combine with environment variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VertexConfig {
    /// GCP project, e.g. `your-gcp-project`; `project_id` is accepted as an alias.
    #[serde(alias = "project_id")]
    pub project: Option<String>,
    /// GCP region, e.g. `europe-west1`.
    #[serde(default)]
    pub region: Option<String>,
    /// Vertex location, e.g. `europe-west1`.
    /// NOTE(review): how this relates to `region` is decided by the provider code — confirm.
    #[serde(default)]
    pub location: Option<String>,
    /// Model publisher, e.g. `anthropic`.
    #[serde(default)]
    pub publisher: Option<String>,
    /// Model identifier, e.g. `claude-3-5-sonnet@20241022`; `model_id` is accepted as an alias.
    #[serde(alias = "model_id")]
    pub model: Option<String>,
    /// Full endpoint URL, usable instead of the individual fields above.
    #[serde(default)]
    pub url: Option<String>,
}
/// HTTP server settings from the `[server]` section; every field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    /// Port the HTTP server listens on (default: 3000).
    #[serde(default = "default_port")]
    pub port: u16,
    /// Logging verbosity (default: `info`).
    #[serde(default = "default_log_level")]
    pub log_level: LogLevel,
    /// Whether quota/rate-limit errors are retried automatically (default: `true`).
    #[serde(default = "default_enable_retries")]
    pub enable_retries: bool,
    /// Maximum number of retry attempts (default: 3).
    #[serde(default = "default_max_retry_attempts")]
    pub max_retry_attempts: u32,
}
/// Credential settings from the `[auth]` section.
///
/// When both sources are set, `service_account_json` takes precedence over
/// `service_account_file` (see [`Config::load_service_account_key_from_auth`]).
///
/// `Debug` is implemented by hand so that inline key material held in
/// `service_account_json` is never written to logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Path to a service-account JSON key file; `~` is expanded when the key is loaded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_account_file: Option<String>,
    /// Inline service-account JSON (useful in containers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub service_account_json: Option<String>,
    /// Runtime auth strategy; never (de)serialized, so it starts out as the
    /// placeholder returned by [`default_auth_strategy`].
    #[serde(skip, default = "default_auth_strategy")]
    pub strategy: AuthStrategy,
}

impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether inline JSON is present, never its contents.
        let inline_json = self.service_account_json.as_ref().map(|_| "<redacted>");
        // NOTE(review): `strategy` is printed via `AuthStrategy`'s own `Debug`,
        // which is defined elsewhere — confirm it does not expose key material.
        f.debug_struct("AuthConfig")
            .field("service_account_file", &self.service_account_file)
            .field("service_account_json", &inline_json)
            .field("strategy", &self.strategy)
            .finish()
    }
}
/// Streaming settings from the `[streaming]` section; every field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    /// How responses are streamed to clients (default: `auto`).
    #[serde(default = "default_streaming_mode")]
    pub mode: StreamingMode,
    /// Buffer size in bytes for buffered streaming (default: 65536).
    #[serde(default = "default_buffer_size")]
    pub buffer_size: usize,
    /// Timeout for streaming chunks, in milliseconds (default: 5000).
    #[serde(default = "default_chunk_timeout")]
    pub chunk_timeout_ms: u64,
}
/// Streaming strategy, written in lowercase in config files (e.g. `"auto"`).
///
/// Serde accepts only the lowercase variant names, while
/// [`StreamingMode::from_str`] additionally accepts spellings such as
/// `true`/`false`, `normal`, and `buffer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum StreamingMode {
    /// Detect the client and choose an appropriate mode.
    Auto,
    /// Disable streaming for all clients.
    Never,
    /// Word-by-word streaming.
    Standard,
    /// Chunked streaming for better client compatibility.
    Buffered,
    /// Force streaming for all clients.
    Always,
}
/// Logging verbosity, written in lowercase in config files (e.g. `"info"`).
///
/// `"warning"` is also accepted for [`LogLevel::Warn`], matching what
/// [`LogLevel::from_str`] accepts. (`rename_all = "lowercase"` already maps
/// each variant to its lowercase name, so no alias is needed for those.)
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum LogLevel {
    Trace,
    Debug,
    Info,
    #[serde(alias = "warning")]
    Warn,
    Error,
}
/// A Google Cloud service-account key, mirroring the fields of the key JSON.
///
/// `Debug` is implemented by hand and redacts `private_key`, so the key can
/// be formatted with `{:?}` (directly or inside other structs) without
/// leaking the secret.
#[derive(Clone, Serialize, Deserialize)]
pub struct ServiceAccountKey {
    /// Key type, stored under the JSON field `type` (e.g. `service_account`).
    #[serde(rename = "type")]
    pub account_type: String,
    pub project_id: String,
    pub private_key_id: String,
    /// Secret key material; redacted in `Debug` output.
    pub private_key: String,
    pub client_email: String,
    pub client_id: String,
    pub auth_uri: String,
    pub token_uri: String,
    pub auth_provider_x509_cert_url: String,
    pub client_x509_cert_url: String,
    /// Optional; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub universe_domain: Option<String>,
}

impl std::fmt::Debug for ServiceAccountKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServiceAccountKey")
            .field("account_type", &self.account_type)
            .field("project_id", &self.project_id)
            .field("private_key_id", &self.private_key_id)
            // Never print the secret itself.
            .field("private_key", &"<redacted>")
            .field("client_email", &self.client_email)
            .field("client_id", &self.client_id)
            .field("auth_uri", &self.auth_uri)
            .field("token_uri", &self.token_uri)
            .field("auth_provider_x509_cert_url", &self.auth_provider_x509_cert_url)
            .field("client_x509_cert_url", &self.client_x509_cert_url)
            .field("universe_domain", &self.universe_domain)
            .finish()
    }
}
/// Serde default for `server.port`: the HTTP server listens on port 3000.
fn default_port() -> u16 {
    const DEFAULT_PORT: u16 = 3000;
    DEFAULT_PORT
}
/// Serde default for `server.log_level`: [`LogLevel::Info`].
fn default_log_level() -> LogLevel {
    LogLevel::Info
}
/// Serde default for `server.enable_retries`: automatic retries are on.
fn default_enable_retries() -> bool {
    const RETRIES_ENABLED: bool = true;
    RETRIES_ENABLED
}
/// Serde default for `server.max_retry_attempts`: at most three retries.
fn default_max_retry_attempts() -> u32 {
    const MAX_ATTEMPTS: u32 = 3;
    MAX_ATTEMPTS
}
/// Returns the placeholder [`AuthStrategy`] used before real credentials are loaded.
///
/// The wrapped [`ServiceAccountKey`] uses the literal string `"placeholder"` for
/// its project, key id, private key, and client id, and
/// `placeholder@placeholder.com` as its email. Only the Google OAuth endpoint
/// URLs are real values. This gives `AuthConfig` a value for its serde-skipped
/// `strategy` field and lets it be default-constructed.
///
/// NOTE(review): `"placeholder"` is not a valid private key, so this strategy
/// presumably cannot obtain tokens — confirm callers always replace it.
pub fn default_auth_strategy() -> AuthStrategy {
    use crate::config::ServiceAccountKey;

    let placeholder = || "placeholder".to_string();

    AuthStrategy::GcpOAuth2(ServiceAccountKey {
        account_type: String::from("service_account"),
        project_id: placeholder(),
        private_key_id: placeholder(),
        private_key: placeholder(),
        client_email: String::from("placeholder@placeholder.com"),
        client_id: placeholder(),
        auth_uri: String::from("https://accounts.google.com/o/oauth2/auth"),
        token_uri: String::from("https://oauth2.googleapis.com/token"),
        auth_provider_x509_cert_url: String::from("https://www.googleapis.com/oauth2/v1/certs"),
        client_x509_cert_url: String::new(),
        universe_domain: None,
    })
}
/// Serde default for `streaming.mode`: [`StreamingMode::Auto`].
fn default_streaming_mode() -> StreamingMode {
    StreamingMode::Auto
}
/// Serde default for `streaming.buffer_size`: 64 KiB (65536 bytes).
fn default_buffer_size() -> usize {
    64 * 1024
}
/// Serde default for `streaming.chunk_timeout_ms`: 5 seconds, expressed in milliseconds.
fn default_chunk_timeout() -> u64 {
    5_000
}
impl Default for Config {
fn default() -> Self {
Self {
server: ServerConfig::default(),
auth: AuthConfig::default(),
streaming: StreamingConfig::default(),
vertex: None,
llm_provider: None,
}
}
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
port: default_port(),
log_level: default_log_level(),
enable_retries: default_enable_retries(),
max_retry_attempts: default_max_retry_attempts(),
}
}
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
service_account_file: None,
service_account_json: None,
strategy: default_auth_strategy(),
}
}
}
impl Default for StreamingConfig {
fn default() -> Self {
Self {
mode: default_streaming_mode(),
buffer_size: default_buffer_size(),
chunk_timeout_ms: default_chunk_timeout(),
}
}
}
impl Config {
pub fn load() -> Result<Self> {
let mut base_config = loader::ConfigLoader::new()
.with_defaults()
.with_system_config()?
.with_user_config()?
.with_env_vars()?
.build_base()?;
let service_account_key = Self::load_service_account_key_from_auth(&base_config.auth)?;
base_config.llm_provider = Some(LlmProviderConfig::from_config_or_env_with_key(
service_account_key,
base_config.vertex.as_ref(),
)?);
Ok(base_config)
}
pub fn build_predict_url(&self, is_streaming: bool) -> String {
self.llm_provider
.as_ref()
.map(|p| p.build_request_url(is_streaming))
.unwrap_or_else(|| "http://localhost:3000/unknown".to_string())
}
pub fn llm_model(&self) -> &str {
self.llm_provider.as_ref().map(|p| p.display_model_name()).unwrap_or("unknown")
}
#[allow(dead_code)]
pub fn load_service_account_key_standalone() -> Result<ServiceAccountKey> {
let auth_config =
loader::ConfigLoader::new().with_defaults().with_env_vars()?.build_base()?.auth;
Self::load_service_account_key_from_auth(&auth_config)
}
pub fn load_service_account_key_from_auth(auth: &AuthConfig) -> Result<ServiceAccountKey> {
if let Some(ref json_str) = auth.service_account_json {
serde_json::from_str(json_str).map_err(|e| {
ProxyError::Config(format!(
"Failed to parse inline service account JSON: {}\n\
\n\
The JSON appears to be malformed. Please verify:\n\
1. All required fields are present\n\
2. JSON syntax is valid\n\
3. No extra commas or missing quotes\n\
\n\
Run 'modelmux config validate' for more details.",
e
))
})
} else if let Some(ref file_path) = auth.service_account_file {
let expanded_path = paths::expand_path(file_path)?;
let file_contents = std::fs::read_to_string(&expanded_path).map_err(|e| {
ProxyError::Config(format!(
"Failed to read service account file '{}': {}\n\
\n\
To fix this:\n\
1. Verify the file exists and is readable\n\
2. Check file permissions (should be 600 or similar)\n\
3. Ensure the path is correct\n\
\n\
Example:\n\
ls -la '{}'\n\
chmod 600 '{}'",
expanded_path.display(),
e,
expanded_path.display(),
expanded_path.display()
))
})?;
serde_json::from_str(&file_contents).map_err(|e| {
ProxyError::Config(format!(
"Failed to parse service account file '{}': {}\n\
\n\
The file appears to contain invalid JSON. Please verify:\n\
1. The file was downloaded correctly from Google Cloud\n\
2. No extra characters or modifications were made\n\
3. The file is a valid service account key JSON\n\
\n\
Run 'modelmux config validate' for more details.",
expanded_path.display(),
e
))
})
} else {
Err(ProxyError::Config(
"No service account configuration found.\n\
\n\
Please configure either:\n\
1. auth.service_account_file = \"/path/to/service-account.json\"\n\
2. auth.service_account_json = \"{ ... }\" (inline JSON)\n\
\n\
Run 'modelmux config init' for interactive setup."
.to_string(),
))
}
}
pub fn validate(&self) -> Result<()> {
validation::ConfigValidator::new(self).validate()
}
pub fn load_service_account_key(&self) -> Result<ServiceAccountKey> {
Self::load_service_account_key_from_auth(&self.auth)
}
pub fn example_toml() -> &'static str {
r#"# ModelMux Configuration
# This file should be placed at:
# Linux/Unix: ~/.config/modelmux/config.toml
# macOS: ~/Library/Application Support/modelmux/config.toml
# Windows: %APPDATA%/modelmux/config.toml
[server]
# HTTP server port (default: 3000)
port = 3000
# Logging level: trace, debug, info, warn, error (default: info)
log_level = "info"
# Enable automatic retries for quota/rate limit errors (default: true)
enable_retries = true
# Maximum number of retry attempts (default: 3)
max_retry_attempts = 3
[auth]
# Path to Google Cloud service account JSON file (recommended)
# Supports tilde (~) expansion
service_account_file = "~/.config/modelmux/service-account.json"
# Alternative: Inline service account JSON (for containers)
# service_account_json = '{"type": "service_account", ...}'
[streaming]
# Streaming mode: auto, never, standard, buffered, always (default: auto)
# - auto: detect client and choose appropriate mode
# - never: disable streaming for all clients
# - standard: word-by-word streaming
# - buffered: chunk streaming for better compatibility
# - always: force streaming for all clients
mode = "auto"
# Buffer size for buffered streaming in bytes (default: 65536)
buffer_size = 65536
# Timeout for streaming chunks in milliseconds (default: 5000)
chunk_timeout_ms = 5000
# Vertex AI provider (optional - can also use env vars or .env)
[vertex]
project = "your-gcp-project"
region = "europe-west1"
location = "europe-west1"
publisher = "anthropic"
model = "claude-3-5-sonnet@20241022"
# Or use full URL override instead:
# url = "https://europe-west1-aiplatform.googleapis.com/v1/projects/MY_PROJECT/locations/europe-west1/publishers/anthropic/models/claude-3-5-sonnet@20241022"
# Alternative: use environment variables (including from .env file):
# LLM_PROVIDER=vertex
# VERTEX_PROJECT=your-gcp-project
# VERTEX_REGION=europe-west1
# VERTEX_LOCATION=europe-west1
# VERTEX_PUBLISHER=anthropic
# VERTEX_MODEL_ID=claude-3-5-sonnet@20241022
"#
}
}
impl LogLevel {
pub fn to_tracing_level(&self) -> tracing::Level {
match self {
LogLevel::Trace => tracing::Level::TRACE,
LogLevel::Debug => tracing::Level::DEBUG,
LogLevel::Info => tracing::Level::INFO,
LogLevel::Warn => tracing::Level::WARN,
LogLevel::Error => tracing::Level::ERROR,
}
}
pub fn is_trace_enabled(self) -> bool {
matches!(self, LogLevel::Trace | LogLevel::Debug)
}
pub fn from_str(s: &str) -> Result<Self> {
match s.to_lowercase().as_str() {
"trace" => Ok(LogLevel::Trace),
"debug" => Ok(LogLevel::Debug),
"info" => Ok(LogLevel::Info),
"warn" | "warning" => Ok(LogLevel::Warn),
"error" => Ok(LogLevel::Error),
_ => Err(ProxyError::Config(format!(
"Invalid log level '{}'. Valid levels are: trace, debug, info, warn, error",
s
))),
}
}
}
impl StreamingMode {
    /// Parses a streaming mode case-insensitively.
    ///
    /// Besides the canonical names, these spellings are accepted: `false`/`no`
    /// for `Never`, `normal` for `Standard`, `buffer` for `Buffered`, and
    /// `true`/`yes` for `Always`.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Config` listing the valid modes when `s` matches none.
    pub fn from_str(s: &str) -> Result<Self> {
        let mode = match s.to_lowercase().as_str() {
            "auto" => Some(Self::Auto),
            "never" | "false" | "no" => Some(Self::Never),
            "standard" | "normal" => Some(Self::Standard),
            "buffered" | "buffer" => Some(Self::Buffered),
            "always" | "true" | "yes" => Some(Self::Always),
            _ => None,
        };
        mode.ok_or_else(|| {
            ProxyError::Config(format!(
                "Invalid streaming mode '{}'. Valid modes are: auto, never, standard, buffered, always",
                s
            ))
        })
    }

    /// Returns `true` for every mode except `Never` (so `Auto` counts as streaming).
    // Allowed: presumably not called anywhere yet — confirm before removing.
    #[allow(dead_code)]
    pub fn is_streaming(&self) -> bool {
        *self != StreamingMode::Never
    }

    /// Returns `true` only for `Auto`.
    // Allowed: presumably not called anywhere yet — confirm before removing.
    #[allow(dead_code)]
    pub fn is_auto_detect(&self) -> bool {
        *self == StreamingMode::Auto
    }

    /// Returns `true` only for `Never`; the exact negation of `is_streaming`.
    // Allowed: presumably not called anywhere yet — confirm before removing.
    #[allow(dead_code)]
    pub fn is_non_streaming(&self) -> bool {
        *self == StreamingMode::Never
    }
}