use axum::middleware;
use axum::routing::{any, delete, get, post};
use axum::Router;
use clap::{Parser, Subcommand};
use llmtrace_core::{ProxyConfig, SecurityAnalyzer};
use llmtrace_proxy::alerts::AlertEngine;
use llmtrace_proxy::circuit_breaker::CircuitBreaker;
use llmtrace_proxy::config;
use llmtrace_proxy::cost::CostEstimator;
use llmtrace_proxy::cost_caps::CostTracker;
use llmtrace_proxy::openapi::ApiDoc;
use llmtrace_proxy::proxy::{health_handler, proxy_handler, AppState, MlModelStatus};
use llmtrace_proxy::shutdown::{self, ShutdownCoordinator};
use llmtrace_security::RegexSecurityAnalyzer;
use llmtrace_storage::StorageProfile;
use std::path::PathBuf;
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tracing::info;
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
/// Command-line interface of the LLMTrace proxy.
///
/// Every global option can also be supplied through the environment variable named in its
/// `env = …` attribute; an explicit flag wins over the environment.
#[derive(Parser)]
#[command(name = "llmtrace-proxy", version, about, long_about = None)]
struct Cli {
    /// Path to a YAML configuration file; built-in defaults are used when omitted
    #[arg(short, long, global = true, env = "LLMTRACE_CONFIG")]
    config: Option<PathBuf>,
    /// Log level / filter directive such as `info` or `debug` (overrides the config file;
    /// `RUST_LOG`, when set, takes precedence)
    #[arg(long, global = true, env = "LLMTRACE_LOG_LEVEL")]
    log_level: Option<String>,
    /// Log output format: `json` for structured logs, any other value for plain text
    #[arg(long, global = true, env = "LLMTRACE_LOG_FORMAT")]
    log_format: Option<String>,
    // Subcommand to run; when absent the proxy server is started.
    #[command(subcommand)]
    command: Option<Commands>,
}
// Optional subcommands of `llmtrace-proxy`; with none given the proxy server starts.
#[derive(Subcommand)]
enum Commands {
    /// Validate the configuration and print the fully resolved result as YAML
    Validate,
    /// Apply database migrations for the configured storage profile
    Migrate,
}
/// Binary entry point.
///
/// Parses the CLI, resolves the effective configuration, then either runs the requested
/// subcommand or, when none is given, validates the configuration and serves traffic.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let cli = Cli::parse();
    let config = load_and_merge_config(&cli)?;

    let Some(command) = cli.command else {
        // Default mode: bring logging up first, then refuse to start on an invalid config.
        init_logging(&config)?;
        config::validate_config(&config)?;
        return run_proxy(config).await;
    };

    match command {
        // `validate` only prints to stdout, so it never installs a tracing subscriber.
        Commands::Validate => run_validate(&config),
        Commands::Migrate => {
            init_logging(&config)?;
            run_migrate(&config).await
        }
    }
}
/// Resolves the effective proxy configuration.
///
/// Precedence, lowest to highest: the file named by `--config` / `LLMTRACE_CONFIG` (or
/// `ProxyConfig::default()` when none is given), then `config::apply_env_overrides`, then the
/// `--log-level` / `--log-format` flags. Progress is reported on stderr because `main` calls
/// this before logging is initialised.
///
/// # Errors
///
/// Propagates any error from `config::load_config`.
fn load_and_merge_config(cli: &Cli) -> anyhow::Result<ProxyConfig> {
    let mut config = if let Some(path) = &cli.config {
        eprintln!("Loading configuration from {}", path.display());
        config::load_config(path)?
    } else {
        eprintln!("No config file specified, using defaults");
        ProxyConfig::default()
    };
    config::apply_env_overrides(&mut config);

    // CLI flags are applied last so they beat both the file and the environment.
    let logging = &mut config.logging;
    if let Some(level) = cli.log_level.as_ref() {
        logging.level.clone_from(level);
    }
    if let Some(format) = cli.log_format.as_ref() {
        logging.format.clone_from(format);
    }

    Ok(config)
}
/// `validate` subcommand: checks the configuration and echoes the resolved result as YAML.
///
/// # Errors
///
/// Returns the validation error, or a YAML serialisation error, if either step fails.
fn run_validate(config: &ProxyConfig) -> anyhow::Result<()> {
    config::validate_config(config)?;

    for header_line in ["✓ Configuration is valid.\n", "Resolved configuration:"] {
        println!("{header_line}");
    }
    let resolved_yaml = serde_yaml::to_string(config)?;
    println!("{resolved_yaml}");

    Ok(())
}
async fn run_migrate(config: &ProxyConfig) -> anyhow::Result<()> {
info!(
profile = %config.storage.profile,
"Running database migrations"
);
match config.storage.profile.as_str() {
"lite" => {
let url = format!("sqlite:{}", config.storage.database_path);
let pool = llmtrace_storage::migration::open_sqlite_pool(&url).await?;
llmtrace_storage::migration::run_sqlite_migrations(&pool).await?;
info!("SQLite migrations complete");
}
"memory" => {
info!("Memory profile uses no persistent storage — nothing to migrate");
}
"production" => {
if let Some(ref pg_url) = config.storage.postgres_url {
let pool = llmtrace_storage::migration::open_pg_pool(pg_url).await?;
llmtrace_storage::migration::run_pg_migrations(&pool).await?;
info!("PostgreSQL migrations complete");
} else {
anyhow::bail!("storage.postgres_url is required for production profile migrations");
}
}
other => {
anyhow::bail!("Unknown storage profile: {other}");
}
}
println!("✓ All migrations applied successfully.");
Ok(())
}
/// Default mode: builds the shared state and serves HTTP on `listen_addr` until a shutdown
/// signal arrives.
///
/// When `grpc.enabled` is set, the gRPC server is started on a detached task as well. After
/// the HTTP server stops, this waits on the [`ShutdownCoordinator`] for in-flight background
/// tasks before returning.
///
/// # Errors
///
/// Fails if state construction, binding the listener, or serving fails.
async fn run_proxy(config: ProxyConfig) -> anyhow::Result<()> {
    info!(
        listen_addr = %config.listen_addr,
        upstream_url = %config.upstream_url,
        storage_profile = %config.storage.profile,
        shutdown_timeout_seconds = config.shutdown.timeout_seconds,
        "Starting LLMTrace proxy server"
    );

    let listen_addr = config.listen_addr.clone();
    let state = build_app_state(config).await?;
    let coordinator = state.shutdown.clone();

    if state.config.grpc.enabled {
        // Detached: a gRPC failure is logged but does not stop the HTTP proxy.
        // NOTE(review): whether this task observes shutdown is decided inside
        // `run_grpc_server` — confirm.
        let grpc_state = Arc::clone(&state);
        tokio::spawn(async move {
            if let Err(e) = llmtrace_proxy::run_grpc_server(grpc_state).await {
                tracing::error!("gRPC server exited with error: {e}");
            }
        });
    }

    let router = build_router(state);
    let listener = tokio::net::TcpListener::bind(&listen_addr).await?;
    info!(%listen_addr, "Proxy server listening");

    let shutdown_signal = shutdown::shutdown_signal(coordinator.clone());
    axum::serve(listener, router)
        .with_graceful_shutdown(shutdown_signal)
        .await?;

    // The HTTP side is done; drain whatever background work is still tracked.
    let in_flight = coordinator.in_flight_count();
    if in_flight > 0 {
        info!(
            in_flight_tasks = in_flight,
            "Waiting for in-flight background tasks to complete"
        );
    }
    coordinator.wait_for_tasks().await;

    info!("Shutdown complete");
    Ok(())
}
/// Installs the global `tracing` subscriber.
///
/// The filter comes from `RUST_LOG` when that variable is set and parses; otherwise it is
/// `logging.level`. `logging.format == "json"` selects JSON output (with thread ids); any other
/// value selects the plain-text formatter. Both formats include the event target.
///
/// # Errors
///
/// * `logging.level` is not a valid filter directive. Previously such a value was silently
///   dropped.
/// * A global subscriber is already installed. Previously `init()` panicked in that case.
fn init_logging(config: &ProxyConfig) -> anyhow::Result<()> {
    let level = &config.logging.level;
    let filter = match tracing_subscriber::EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => tracing_subscriber::EnvFilter::try_new(level)
            .map_err(|e| anyhow::anyhow!("Invalid log level '{level}': {e}"))?,
    };

    let installed = match config.logging.format.as_str() {
        "json" => tracing_subscriber::fmt()
            .json()
            .with_env_filter(filter)
            .with_target(true)
            .with_thread_ids(true)
            .try_init(),
        _ => tracing_subscriber::fmt()
            .with_env_filter(filter)
            .with_target(true)
            .try_init(),
    };
    installed.map_err(|e| anyhow::anyhow!("Failed to initialize logging: {e}"))
}
async fn build_app_state(config: ProxyConfig) -> anyhow::Result<Arc<AppState>> {
let client = reqwest::Client::builder()
.connect_timeout(std::time::Duration::from_millis(
config.connection_timeout_ms,
))
.timeout(std::time::Duration::from_millis(config.timeout_ms))
.build()?;
let profile =
match config.storage.profile.as_str() {
"memory" => StorageProfile::Memory,
"production" => StorageProfile::Production {
clickhouse_url: config
.storage
.clickhouse_url
.clone()
.unwrap_or_else(|| "http://localhost:8123".to_string()),
clickhouse_database: config
.storage
.clickhouse_database
.clone()
.unwrap_or_else(|| "llmtrace".to_string()),
postgres_url: config.storage.postgres_url.clone().unwrap_or_else(|| {
"postgres://llmtrace:llmtrace@localhost/llmtrace".to_string()
}),
redis_url: config
.storage
.redis_url
.clone()
.unwrap_or_else(|| "redis://127.0.0.1:6379".to_string()),
},
_ => StorageProfile::Lite {
database_path: config.storage.database_path.clone(),
},
};
info!(
profile = %config.storage.profile,
database_path = %config.storage.database_path,
"Initializing storage"
);
let storage = profile
.build()
.await
.map_err(|e| anyhow::anyhow!("Failed to initialize storage: {}", e))?;
let (security, ml_status) = build_security_analyzer(&config).await?;
let fast_analyzer: Arc<dyn llmtrace_core::SecurityAnalyzer> = Arc::new(
llmtrace_security::RegexSecurityAnalyzer::new()
.map_err(|e| anyhow::anyhow!("Failed to build fast analyzer: {e}"))?,
);
let storage_breaker = Arc::new(CircuitBreaker::from_config(&config.circuit_breaker));
let security_breaker = Arc::new(CircuitBreaker::from_config(&config.circuit_breaker));
let cost_estimator = CostEstimator::new(&config.cost_estimation);
let alert_engine = AlertEngine::from_config(&config.alerts, client.clone());
let cost_tracker = CostTracker::new(&config.cost_caps, Arc::clone(&storage.cache));
let anomaly_detector = llmtrace_proxy::anomaly::AnomalyDetector::new(
&config.anomaly_detection,
Arc::clone(&storage.cache),
);
if let Some(ref engine) = alert_engine {
info!(
channels = engine.channel_count(),
"Alert engine enabled with {} channel(s)",
engine.channel_count(),
);
}
if cost_tracker.is_some() {
info!("Cost cap enforcement enabled");
}
if anomaly_detector.is_some() {
info!("Anomaly detection enabled");
}
let report_store = llmtrace_proxy::compliance::new_report_store();
let shutdown = ShutdownCoordinator::new(config.shutdown.timeout_seconds);
let metrics = llmtrace_proxy::metrics::Metrics::new();
let ready = Arc::new(std::sync::atomic::AtomicBool::new(false));
let rate_limiter = if config.rate_limiting.enabled {
info!(
rps = config.rate_limiting.requests_per_second,
burst = config.rate_limiting.burst_size,
overrides = config.rate_limiting.tenant_overrides.len(),
"Per-tenant rate limiting enabled"
);
Some(llmtrace_proxy::RateLimiter::new(
&config.rate_limiting,
Arc::clone(&storage.cache),
))
} else {
None
};
Ok(Arc::new(AppState {
config,
client,
storage,
security,
fast_analyzer,
storage_breaker,
security_breaker,
cost_estimator,
alert_engine,
cost_tracker,
anomaly_detector,
report_store,
rate_limiter,
ml_status,
shutdown,
metrics,
ready,
}))
}
/// Reads a boolean override from the environment variable `name`.
///
/// Returns `None` when the variable is unset or not valid Unicode. Otherwise the value is
/// trimmed and lower-cased:
/// * `0`, `false`, `no` and `off` mean `false`;
/// * any other value, including an empty string, means `true`. This keeps the historical
///   "anything but an explicit negative enables it" rule.
fn env_flag(name: &str) -> Option<bool> {
    let raw = std::env::var(name).ok()?;
    let normalized = raw.trim().to_lowercase();
    Some(!matches!(
        normalized.as_str(),
        "0" | "false" | "no" | "off"
    ))
}

/// Builds the regex-only analyzer used whenever no ML ensemble is active.
///
/// # Errors
///
/// Fails if `RegexSecurityAnalyzer::new` fails.
fn regex_analyzer() -> anyhow::Result<Arc<dyn SecurityAnalyzer>> {
    let regex = RegexSecurityAnalyzer::new()
        .map_err(|e| anyhow::anyhow!("Failed to initialize security analyzer: {}", e))?;
    let analyzer: Arc<dyn SecurityAnalyzer> = Arc::new(regex);
    Ok(analyzer)
}

/// Builds the primary security analyzer and reports the resulting ML model status.
///
/// `ml_enabled` / `ml_preload` come from `config.security_analysis`. They can be overridden by
/// the `LLMTRACE_ML_ENABLED` / `LLMTRACE_ML_PRELOAD` environment variables (parsed by
/// `env_flag`).
///
/// * With the `ml` feature and both flags true, the ML ensemble is loaded now.
/// * In every other case the regex analyzer is returned with [`MlModelStatus::Disabled`].
///
/// # Errors
///
/// Fails only if the regex analyzer itself cannot be constructed; ML failures degrade
/// gracefully.
async fn build_security_analyzer(
    config: &ProxyConfig,
) -> anyhow::Result<(Arc<dyn SecurityAnalyzer>, MlModelStatus)> {
    let ml_enabled = match env_flag("LLMTRACE_ML_ENABLED") {
        Some(false) => {
            tracing::info!("LLMTRACE_ML_ENABLED disabled ML security analysis");
            false
        }
        Some(true) => true,
        None => config.security_analysis.ml_enabled,
    };
    let ml_preload = match env_flag("LLMTRACE_ML_PRELOAD") {
        Some(false) => {
            tracing::info!("LLMTRACE_ML_PRELOAD disabled ML model preloading");
            false
        }
        Some(true) => true,
        None => config.security_analysis.ml_preload,
    };

    #[cfg(feature = "ml")]
    {
        if ml_enabled && ml_preload {
            return preload_ml_analyzer(config).await;
        }
        if ml_enabled {
            // NOTE(review): nothing in this function arranges lazy loading — confirm the
            // request path really loads models on demand, otherwise this message misleads.
            info!("ML enabled but ml_preload=false — models will load on first request");
        }
    }
    #[cfg(not(feature = "ml"))]
    {
        // Without ML support the overrides only influence the log lines above; reading the
        // flags here keeps non-ml builds free of unused-variable warnings.
        let _ = (ml_enabled, ml_preload);
    }

    Ok((regex_analyzer()?, MlModelStatus::Disabled))
}

/// Loads the ML ensemble within `ml_download_timeout_seconds`.
///
/// The ensemble is the prompt-injection model plus the optional NER, InjecGuard and PIGuard
/// models, each gated by its `*_enabled` flag.
///
/// A load error or a timeout is not fatal. Both fall back to the regex analyzer and report
/// [`MlModelStatus::Failed`] with a description.
///
/// NOTE(review): `tokio::time::timeout` can only cancel at `.await` points. If
/// `EnsembleSecurityAnalyzer::with_piguard` does long blocking work without yielding, the
/// timeout cannot interrupt it — confirm.
#[cfg(feature = "ml")]
async fn preload_ml_analyzer(
    config: &ProxyConfig,
) -> anyhow::Result<(Arc<dyn SecurityAnalyzer>, MlModelStatus)> {
    let sa = &config.security_analysis;
    info!("Pre-loading ML models at startup (ml_preload=true)");
    let load_start = std::time::Instant::now();

    let ml_config = llmtrace_security::MLSecurityConfig {
        model_id: sa.ml_model.clone(),
        threshold: sa.ml_threshold,
        cache_dir: Some(sa.ml_cache_dir.clone()),
    };
    let ner_config = sa.ner_enabled.then(|| llmtrace_security::NerConfig {
        model_id: sa.ner_model.clone(),
        cache_dir: Some(sa.ml_cache_dir.clone()),
    });
    let ig_config = sa
        .injecguard_enabled
        .then(|| llmtrace_security::InjecGuardConfig {
            model_id: sa.injecguard_model.clone(),
            threshold: sa.injecguard_threshold,
            cache_dir: Some(sa.ml_cache_dir.clone()),
        });
    let pg_config = sa.piguard_enabled.then(|| llmtrace_security::PIGuardConfig {
        model_id: sa.piguard_model.clone(),
        threshold: sa.piguard_threshold,
        cache_dir: Some(sa.ml_cache_dir.clone()),
    });

    let timeout_seconds = sa.ml_download_timeout_seconds;
    let outcome = tokio::time::timeout(
        std::time::Duration::from_secs(timeout_seconds),
        llmtrace_security::EnsembleSecurityAnalyzer::with_piguard(
            &ml_config,
            ner_config.as_ref(),
            ig_config.as_ref(),
            pg_config.as_ref(),
        ),
    )
    .await;

    let ensemble = match outcome {
        Ok(Ok(ensemble)) => ensemble,
        Ok(Err(e)) => {
            let error = format!("{e}");
            tracing::warn!(
                error = %e,
                "ML model loading failed at startup — falling back to regex analyzer"
            );
            return Ok((regex_analyzer()?, MlModelStatus::Failed { error }));
        }
        Err(_) => {
            let error = format!("ML model download timed out after {}s", timeout_seconds);
            tracing::warn!(
                timeout_seconds = timeout_seconds,
                "ML model download timed out — falling back to regex analyzer"
            );
            return Ok((regex_analyzer()?, MlModelStatus::Failed { error }));
        }
    };

    // Saturate instead of truncating the u128 millisecond count.
    let load_time_ms = u64::try_from(load_start.elapsed().as_millis()).unwrap_or(u64::MAX);
    let pi_loaded = ensemble.is_ml_active();
    let ner_loaded = ensemble.is_ner_active();
    let ig_loaded = ensemble.is_injecguard_active();
    let pg_loaded = ensemble.is_piguard_active();

    // Rough memory estimate: ~500 MB per classifier counted here, ~400 MB for NER.
    let deberta_count = [pi_loaded, ig_loaded, pg_loaded]
        .iter()
        .filter(|&&loaded| loaded)
        .count();
    let ner_mb = if ner_loaded { 400 } else { 0 };
    let estimated_mb = deberta_count * 500 + ner_mb;
    let total_models = deberta_count + usize::from(ner_loaded);
    info!(
        prompt_injection_loaded = pi_loaded,
        ner_loaded = ner_loaded,
        injecguard_loaded = ig_loaded,
        piguard_loaded = pg_loaded,
        load_time_ms = load_time_ms,
        estimated_memory_mb = estimated_mb,
        "ML models pre-loaded at startup (~{estimated_mb} MB for {total_models} model(s))"
    );

    let status = MlModelStatus::Loaded {
        prompt_injection: pi_loaded,
        ner: ner_loaded,
        injecguard: ig_loaded,
        piguard: pg_loaded,
        load_time_ms,
    };
    let ensemble = ensemble
        .with_operating_point(to_security_operating_point(&sa.operating_point))
        .with_over_defence(sa.over_defence);
    let analyzer: Arc<dyn SecurityAnalyzer> = Arc::new(ensemble);
    Ok((analyzer, status))
}

/// Maps the config-level operating point onto the security crate's equivalent variant.
#[cfg(feature = "ml")]
fn to_security_operating_point(
    point: &llmtrace_core::OperatingPoint,
) -> llmtrace_security::OperatingPoint {
    match point {
        llmtrace_core::OperatingPoint::HighRecall => {
            llmtrace_security::OperatingPoint::HighRecall
        }
        llmtrace_core::OperatingPoint::HighPrecision => {
            llmtrace_security::OperatingPoint::HighPrecision
        }
        llmtrace_core::OperatingPoint::Balanced => llmtrace_security::OperatingPoint::Balanced,
    }
}
fn build_router(state: Arc<AppState>) -> Router {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-doc/openapi.json", ApiDoc::openapi()))
.route("/health", get(health_handler))
.route("/metrics", get(llmtrace_proxy::metrics::metrics_handler))
.route(
"/api/v1/auth/keys",
post(llmtrace_proxy::auth::create_api_key)
.get(llmtrace_proxy::auth::list_api_keys),
)
.route(
"/api/v1/auth/keys/:id",
delete(llmtrace_proxy::auth::revoke_api_key),
)
.route("/api/v1/traces", get(llmtrace_proxy::api::list_traces))
.route(
"/api/v1/traces/:trace_id",
get(llmtrace_proxy::api::get_trace),
)
.route("/api/v1/spans", get(llmtrace_proxy::api::list_spans))
.route(
"/api/v1/spans/:span_id",
get(llmtrace_proxy::api::get_span),
)
.route("/api/v1/stats", get(llmtrace_proxy::api::get_stats))
.route(
"/api/v1/security/findings",
get(llmtrace_proxy::api::list_security_findings),
)
.route(
"/api/v1/costs/current",
get(llmtrace_proxy::api::get_current_costs),
)
.route(
"/api/v1/traces/:trace_id/actions",
post(llmtrace_proxy::api::report_action),
)
.route(
"/api/v1/actions/summary",
get(llmtrace_proxy::api::actions_summary),
)
.route(
"/api/v1/tenants",
post(llmtrace_proxy::tenant_api::create_tenant)
.get(llmtrace_proxy::tenant_api::list_tenants),
)
.route(
"/api/v1/tenants/current/token",
get(llmtrace_proxy::tenant_api::get_current_tenant_token),
)
.route(
"/api/v1/tenants/:id/token",
get(llmtrace_proxy::tenant_api::get_tenant_token),
)
.route(
"/api/v1/tenants/:id",
get(llmtrace_proxy::tenant_api::get_tenant)
.put(llmtrace_proxy::tenant_api::update_tenant)
.delete(llmtrace_proxy::tenant_api::delete_tenant),
)
.route(
"/api/v1/tenants/:id/token/reset",
post(llmtrace_proxy::tenant_api::reset_tenant_token),
)
.route(
"/api/v1/reports/generate",
post(llmtrace_proxy::compliance::generate_report),
)
.route(
"/api/v1/reports",
get(llmtrace_proxy::compliance::list_reports),
)
.route(
"/api/v1/reports/:id",
get(llmtrace_proxy::compliance::get_report),
)
.route(
"/v1/traces",
post(llmtrace_proxy::otel::ingest_traces),
)
.fallback(any(proxy_handler))
.layer(middleware::from_fn_with_state(
Arc::clone(&state),
llmtrace_proxy::auth::auth_middleware,
))
.layer(cors)
.with_state(state)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use llmtrace_core::StorageConfig;
    use tower::ServiceExt;

    /// Largest response body these tests are willing to buffer.
    const BODY_LIMIT: usize = 1024 * 1024;

    /// Default configuration switched to the in-memory storage profile.
    fn memory_config() -> ProxyConfig {
        ProxyConfig {
            storage: StorageConfig {
                profile: "memory".to_string(),
                database_path: String::new(),
                ..StorageConfig::default()
            },
            ..ProxyConfig::default()
        }
    }

    /// A CLI invocation with no flags and no subcommand.
    fn bare_cli() -> Cli {
        Cli {
            config: None,
            log_level: None,
            log_format: None,
            command: None,
        }
    }

    /// Router backed by a freshly built in-memory `AppState`.
    async fn test_app() -> Router {
        let state = build_app_state(memory_config())
            .await
            .expect("in-memory app state should build");
        build_router(state)
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let app = test_app().await;
        let req = Request::builder()
            .uri("/health")
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = axum::body::to_bytes(response.into_body(), BODY_LIMIT)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["status"], "healthy");
        assert!(json["storage"]["traces"]["healthy"].as_bool().unwrap());
        assert!(json["security"]["healthy"].as_bool().unwrap());
    }

    #[tokio::test]
    async fn test_proxy_returns_bad_gateway_when_upstream_unreachable() {
        let config = ProxyConfig {
            // Nothing is expected to listen on port 1, so the connection attempt fails fast.
            upstream_url: "http://127.0.0.1:1".to_string(),
            connection_timeout_ms: 100,
            timeout_ms: 500,
            ..memory_config()
        };
        let state = build_app_state(config).await.unwrap();
        let app = build_router(state);

        let body = serde_json::json!({
            "model": "gpt-4",
            "messages": [{"role": "user", "content": "Hello"}]
        });
        let req = Request::builder()
            .method("POST")
            .uri("/v1/chat/completions")
            .header("content-type", "application/json")
            .header("authorization", "Bearer sk-test")
            .body(Body::from(serde_json::to_vec(&body).unwrap()))
            .unwrap();
        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn test_build_app_state_succeeds() {
        // `expect` (rather than `assert!(is_ok())`) surfaces the underlying error on failure.
        let _state = build_app_state(memory_config())
            .await
            .expect("build_app_state should succeed with the memory profile");
    }

    #[test]
    fn test_load_and_merge_config_defaults() {
        let config = load_and_merge_config(&bare_cli()).unwrap();
        assert_eq!(config.listen_addr, "0.0.0.0:8080");
        assert_eq!(config.logging.level, "info");
        assert_eq!(config.logging.format, "text");
    }

    #[test]
    fn test_load_and_merge_config_cli_overrides() {
        let cli = Cli {
            log_level: Some("debug".to_string()),
            log_format: Some("json".to_string()),
            ..bare_cli()
        };
        let config = load_and_merge_config(&cli).unwrap();
        assert_eq!(config.logging.level, "debug");
        assert_eq!(config.logging.format, "json");
    }

    #[test]
    fn test_cli_parses_validate_subcommand() {
        let cli = Cli::try_parse_from(["llmtrace-proxy", "validate"])
            .expect("`validate` should parse");
        assert!(matches!(cli.command, Some(Commands::Validate)));
    }

    #[test]
    fn test_cli_accepts_global_flags_after_subcommand() {
        let cli = Cli::try_parse_from([
            "llmtrace-proxy",
            "migrate",
            "--log-level",
            "debug",
            "--log-format",
            "json",
        ])
        .expect("global flags should be accepted after the subcommand");
        assert!(matches!(cli.command, Some(Commands::Migrate)));
        assert_eq!(cli.log_level.as_deref(), Some("debug"));
        assert_eq!(cli.log_format.as_deref(), Some("json"));
    }

    #[tokio::test]
    async fn test_metrics_endpoint_returns_prometheus_format() {
        let app = test_app().await;
        let req = Request::builder()
            .uri("/metrics")
            .body(Body::empty())
            .unwrap();
        let response = app.oneshot(req).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let ct = response
            .headers()
            .get("content-type")
            .unwrap()
            .to_str()
            .unwrap();
        assert!(
            ct.contains("text/plain"),
            "Expected text/plain content type for Prometheus, got: {ct}"
        );

        let body = axum::body::to_bytes(response.into_body(), BODY_LIMIT)
            .await
            .unwrap();
        let text = String::from_utf8_lossy(&body);
        assert!(
            text.contains("# HELP llmtrace_active_connections"),
            "Missing active_connections metric"
        );
        assert!(
            text.contains("# TYPE llmtrace_active_connections gauge"),
            "Missing active_connections type"
        );
        assert!(
            text.contains("# HELP llmtrace_circuit_breaker_state"),
            "Missing circuit_breaker_state metric"
        );
    }

    #[test]
    fn test_load_and_merge_config_from_file() {
        use std::io::Write;

        // Nested sections must be indented for the YAML to describe nested mappings.
        let yaml = r#"
listen_addr: "127.0.0.1:9999"
upstream_url: "http://localhost:11434"
timeout_ms: 60000
connection_timeout_ms: 5000
max_connections: 500
enable_tls: false
enable_security_analysis: true
enable_trace_storage: true
enable_streaming: true
max_request_size_bytes: 52428800
security_analysis_timeout_ms: 5000
trace_storage_timeout_ms: 10000
rate_limiting:
  enabled: true
  requests_per_second: 100
  burst_size: 200
  window_seconds: 60
circuit_breaker:
  enabled: true
  failure_threshold: 10
  recovery_timeout_ms: 30000
  half_open_max_calls: 3
health_check:
  enabled: true
  path: "/health"
  interval_seconds: 10
  timeout_ms: 5000
  retries: 3
logging:
  level: "warn"
  format: "json"
"#;
        let mut f = tempfile::NamedTempFile::new().unwrap();
        f.write_all(yaml.as_bytes()).unwrap();

        let cli = Cli {
            config: Some(f.path().to_path_buf()),
            ..bare_cli()
        };
        let config = load_and_merge_config(&cli).unwrap();
        assert_eq!(config.listen_addr, "127.0.0.1:9999");
        assert_eq!(config.logging.level, "warn");
        assert_eq!(config.logging.format, "json");
    }
}