use arc_swap::ArcSwap;
use bytes::Bytes;
use clap::{Parser, Subcommand};
use crabllm_core::{
AudioSpeechRequest, BoxStream, ChatCompletionChunk, ChatCompletionRequest,
ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse, Error, Extension, GatewayConfig,
ImageRequest, MultipartField, Provider, Storage,
};
use crabllm_provider::{ProviderRegistry, RemoteProvider};
use crabllm_proxy::{
AppState,
ext::{
audit::AuditLogger, budget::Budget, cache::Cache, logging::RequestLogger,
rate_limit::RateLimit, usage::UsageTracker,
},
storage::MemoryStorage,
};
use std::{
collections::HashMap,
path::PathBuf,
sync::{Arc, RwLock},
time::Duration,
};
// Top-level command-line interface of the `crabllm` binary, parsed with clap's derive API.
// Plain `//` comments (not `///`) are used on purpose: clap's derive turns doc comments
// into help text, and these notes must not change the generated `--help` output.
#[derive(Parser)]
#[command(name = "crabllm", about = "High-performance LLM API gateway")]
struct Cli {
// Number of times `-v` / `--verbose` was passed (counted by clap). Declared `global`,
// so it is accepted alongside any subcommand. `main` hands it to `init_tracing`.
#[arg(short, long, action = clap::ArgAction::Count, global = true)]
verbose: u8,
// Selected subcommand. When absent, `main` serves using `crabllm.toml`.
#[command(subcommand)]
command: Option<Commands>,
}
// Subcommands of the `crabllm` binary. `main` dispatches each variant to the
// function of the matching name (`serve`, `init`, `dump_openapi`).
// Plain `//` comments are used so clap's derived help text stays unchanged.
#[derive(Subcommand)]
enum Commands {
// Run the gateway.
Serve {
// Path of the gateway config file to load (`-c` / `--config`).
#[arg(short, long, default_value = "crabllm.toml")]
config: PathBuf,
// Optional listen address (`-b` / `--bind`); when given, `serve` uses it to
// replace the `listen` value from the config file.
#[arg(short, long)]
bind: Option<String>,
},
// Write a starter config file with freshly generated keys.
Init {
// Destination of the generated config file (`-o` / `--out`).
#[arg(short, long, default_value = "crabllm.toml")]
out: PathBuf,
// Allow overwriting an existing file (`-f` / `--force`); without it `init`
// refuses to touch a file that already exists.
#[arg(short, long)]
force: bool,
},
// Print the OpenAPI spec to stdout. Only compiled with the `openapi` feature.
#[cfg(feature = "openapi")]
Openapi {
// Output format selected with `--format`.
#[arg(long, value_enum, default_value = "json")]
format: OpenapiFormat,
},
}
// Output formats accepted by the `openapi` subcommand's `--format` flag.
// Only compiled with the `openapi` feature. Plain `//` comments are used so the
// help text clap derives for the possible values stays unchanged.
#[cfg(feature = "openapi")]
#[derive(Clone, Copy, clap::ValueEnum)]
enum OpenapiFormat {
// Pretty-printed JSON document (see `dump_openapi`).
Json,
// HTML page rendered through `utoipa_scalar::Scalar` (see `dump_openapi`).
Html,
}
#[tokio::main]
async fn main() {
let cli = Cli::parse();
init_tracing(cli.verbose);
match cli.command {
Some(Commands::Serve { config, bind }) => serve(config, bind).await,
Some(Commands::Init { out, force }) => init(out, force),
#[cfg(feature = "openapi")]
Some(Commands::Openapi { format }) => dump_openapi(format),
None => serve(PathBuf::from("crabllm.toml"), None).await,
}
}
/// Prints the gateway's OpenAPI spec to stdout in the requested `format`.
///
/// JSON output is pretty-printed and followed by a newline; HTML output is
/// the Scalar page, written without a trailing newline.
///
/// # Panics
///
/// Panics if the spec cannot be serialized to JSON.
#[cfg(feature = "openapi")]
fn dump_openapi(format: OpenapiFormat) {
    let spec = crabllm_proxy::openapi::spec();
    match format {
        OpenapiFormat::Html => {
            let page = utoipa_scalar::Scalar::new(spec).to_html();
            print!("{page}");
        }
        OpenapiFormat::Json => {
            let rendered = serde_json::to_string_pretty(&spec).expect("serialize openapi spec");
            println!("{rendered}");
        }
    }
}
/// Installs the global tracing subscriber.
///
/// `verbose` is the number of `-v` flags:
/// * `0` — use the filter from the default environment variable (normally
///   `RUST_LOG`) when it parses, otherwise `info` for the crabllm crates and
///   `warn` for everything else;
/// * `1` — `info` for everything;
/// * `2` — `debug` for everything;
/// * `3` or more — `trace` for everything.
fn init_tracing(verbose: u8) {
    use tracing_subscriber::EnvFilter;

    let filter = if verbose == 0 {
        match EnvFilter::try_from_default_env() {
            Ok(from_env) => from_env,
            Err(_) => {
                EnvFilter::new("crabllm=info,crabllm_proxy=info,crabllm_provider=info,warn")
            }
        }
    } else {
        // An explicit -v flag takes precedence over the environment.
        let level = match verbose {
            1 => "info",
            2 => "debug",
            _ => "trace",
        };
        EnvFilter::new(level)
    };

    tracing_subscriber::fmt().with_env_filter(filter).init();
}
fn init(out: PathBuf, force: bool) {
if out.exists() && !force {
eprintln!(
"error: {} already exists; pass --force to overwrite",
out.display()
);
std::process::exit(1);
}
let admin_token = crabllm_proxy::admin::generate_key();
let default_key = crabllm_proxy::admin::generate_key();
let contents = format!(
r#"# crabllm gateway config — generated by `crabllm init`.
# Safe to edit; `crabllm serve` reads this file on startup.
# Address the gateway listens on. Use 0.0.0.0 to accept external traffic.
listen = "127.0.0.1:5632"
# Admin bearer token for the /v1/admin/* API surface. Anyone with this
# token can manage keys, providers, and inspect usage — keep it secret.
admin_token = "{admin_token}"
# OpenAPI docs at /docs (Scalar UI) and /openapi.json. Set to false to
# disable, or rebuild with `--no-default-features` to strip the feature.
# openapi = true
# Default user-facing API key. Clients send this as the
# `Authorization: Bearer <key>` header. Add more via:
# crabctl keys create <name> --models m1,m2
[[keys]]
name = "default"
key = "{default_key}"
models = ["*"]
# ── Storage backend (optional) ──
# Without this block, storage is in-memory — dynamic keys and providers
# created via the admin API will be lost on restart.
#
# [storage]
# kind = "sqlite" # or "redis" / "memory"
# path = "crabllm.db" # file for sqlite, url for redis
# ── Providers (optional) ──
# Configure upstream LLM providers here, or add them at runtime via
# `crabctl providers create`. Examples:
#
# [providers.openai]
# kind = "openai"
# api_key = "${{OPENAI_API_KEY}}" # env var interpolation supported
# models = ["gpt-4o", "gpt-4o-mini"]
#
# [providers.anthropic]
# kind = "anthropic"
# api_key = "${{ANTHROPIC_API_KEY}}"
# models = ["claude-sonnet-4-20250514"]
# ── Extensions (optional) ──
# Opt-in features. Enable by uncommenting and configuring.
#
# [extensions]
# rate_limit = {{ requests_per_minute = 60 }}
# usage = {{}}
# audit = {{}}
# budget = {{ default_budget = 10.0 }}
# cache = {{ ttl_seconds = 300 }}
"#,
);
if let Some(parent) = out.parent()
&& !parent.as_os_str().is_empty()
&& let Err(e) = std::fs::create_dir_all(parent)
{
eprintln!("error: failed to create directory: {e}");
std::process::exit(1);
}
if let Err(e) = std::fs::write(&out, contents) {
eprintln!("error: failed to write {}: {e}", out.display());
std::process::exit(1);
}
eprintln!("wrote {}", out.display());
eprintln!();
eprintln!("admin token: {admin_token}");
eprintln!("default key: {default_key}");
eprintln!();
eprintln!("these values are also in the config file above; keep them safe.");
}
/// Provider type used to parameterize the gateway's `ProviderRegistry` and `AppState`.
///
/// It currently has a single variant; its `Provider` implementation forwards
/// every call to the wrapped value.
enum Dispatch {
/// A provider backed by a [`RemoteProvider`].
Remote(RemoteProvider),
}
/// Forwards every [`Provider`] operation to the provider wrapped by the
/// [`Dispatch`] variant, passing arguments and results through unchanged.
impl Provider for Dispatch {
    /// Non-streaming chat completion.
    async fn chat_completion(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, Error> {
        let Self::Remote(inner) = self;
        inner.chat_completion(request).await
    }

    /// Streaming chat completion.
    async fn chat_completion_stream(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, Error>>, Error> {
        let Self::Remote(inner) = self;
        inner.chat_completion_stream(request).await
    }

    /// Embedding request.
    async fn embedding(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse, Error> {
        let Self::Remote(inner) = self;
        inner.embedding(request).await
    }

    /// Image generation request.
    async fn image_generation(&self, request: &ImageRequest) -> Result<(Bytes, String), Error> {
        let Self::Remote(inner) = self;
        inner.image_generation(request).await
    }

    /// Text-to-speech request.
    async fn audio_speech(&self, request: &AudioSpeechRequest) -> Result<(Bytes, String), Error> {
        let Self::Remote(inner) = self;
        inner.audio_speech(request).await
    }

    /// Audio transcription from multipart form fields.
    async fn audio_transcription(
        &self,
        model: &str,
        fields: &[MultipartField],
    ) -> Result<(Bytes, String), Error> {
        let Self::Remote(inner) = self;
        inner.audio_transcription(model, fields).await
    }

    /// Reports the wrapped provider's OpenAI-compatibility flag.
    fn is_openai_compat(&self) -> bool {
        let Self::Remote(inner) = self;
        inner.is_openai_compat()
    }

    /// Reports the wrapped provider's Anthropic-compatibility flag.
    fn is_anthropic_compat(&self) -> bool {
        let Self::Remote(inner) = self;
        inner.is_anthropic_compat()
    }

    /// Chat completion on a raw request body.
    async fn chat_completion_raw(&self, model: &str, raw_body: Bytes) -> Result<Bytes, Error> {
        let Self::Remote(inner) = self;
        inner.chat_completion_raw(model, raw_body).await
    }

    /// Anthropic messages call on a raw request body.
    async fn anthropic_messages_raw(&self, raw_body: Bytes) -> Result<Bytes, Error> {
        let Self::Remote(inner) = self;
        inner.anthropic_messages_raw(raw_body).await
    }
}
/// Loads the gateway config, opens the configured storage backend and runs
/// the server until it shuts down.
///
/// * If `config_path` does not exist, a starter config is generated there
///   first (see `init`).
/// * `bind`, when given, replaces the `listen` address from the config file.
/// * The storage backend is chosen from `[storage].kind`: `"redis"` and
///   `"sqlite"` require the matching cargo feature; a missing `[storage]`
///   block or `"memory"` selects in-memory storage. Any other value also falls
///   back to in-memory storage, but a warning is logged so that a typo does
///   not silently lose persistence.
///
/// Config load failures and storage open failures are logged and terminate
/// the process with status 1.
async fn serve(config_path: PathBuf, bind: Option<String>) {
    if !config_path.exists() {
        tracing::info!(
            path = %config_path.display(),
            "config not found — generating a starter config"
        );
        init(config_path.clone(), false);
    }

    let mut config = match GatewayConfig::from_file(&config_path) {
        Ok(c) => c,
        Err(e) => {
            tracing::error!("failed to load config: {e}");
            std::process::exit(1);
        }
    };
    if let Some(bind) = bind {
        config.listen = bind;
    }

    let storage_kind = config
        .storage
        .as_ref()
        .map(|s| s.kind.as_str())
        .unwrap_or("memory");

    match storage_kind {
        #[cfg(feature = "storage-redis")]
        "redis" => {
            let url = config
                .storage
                .as_ref()
                .and_then(|s| s.path.as_deref())
                .unwrap_or("redis://127.0.0.1:6379");
            let storage = match crabllm_proxy::storage::RedisStorage::open(url).await {
                Ok(s) => Arc::new(s),
                Err(e) => {
                    tracing::error!("failed to open redis storage: {e}");
                    std::process::exit(1);
                }
            };
            run(config, config_path, storage).await;
        }
        #[cfg(not(feature = "storage-redis"))]
        "redis" => {
            tracing::error!("redis storage requires the 'storage-redis' feature");
            std::process::exit(1);
        }
        #[cfg(feature = "storage-sqlite")]
        "sqlite" => {
            let path = config
                .storage
                .as_ref()
                .and_then(|s| s.path.as_deref())
                .unwrap_or("crabllm.db");
            // `mode=rwc` is part of the connection URL handed to SqliteStorage.
            let url = format!("sqlite:{path}?mode=rwc");
            let storage = match crabllm_proxy::storage::SqliteStorage::open(&url).await {
                Ok(s) => Arc::new(s),
                Err(e) => {
                    tracing::error!("failed to open sqlite storage: {e}");
                    std::process::exit(1);
                }
            };
            run(config, config_path, storage).await;
        }
        #[cfg(not(feature = "storage-sqlite"))]
        "sqlite" => {
            tracing::error!("sqlite storage requires the 'storage-sqlite' feature");
            std::process::exit(1);
        }
        other => {
            // Keep the historical fallback, but make an unrecognised kind
            // visible instead of silently dropping persistence.
            if other != "memory" {
                tracing::warn!(
                    kind = other,
                    "unknown storage kind, falling back to in-memory storage"
                );
            }
            let storage = Arc::new(MemoryStorage::new());
            run(config, config_path, storage).await;
        }
    }
}
async fn run<S: Storage + 'static>(
mut config: GatewayConfig,
config_path: PathBuf,
storage: Arc<S>,
) {
crabllm_proxy::admin_providers::merge_stored_providers(
storage.as_ref() as &dyn Storage,
&mut config,
)
.await;
let registry: ProviderRegistry<Dispatch> =
match ProviderRegistry::from_config(&config, Dispatch::Remote) {
Ok(r) => r,
Err(e) => {
tracing::error!("failed to build provider registry: {e}");
std::process::exit(1);
}
};
let (extensions, mut admin_routes) =
match build_extensions(&config, storage.clone() as Arc<dyn Storage>) {
Ok(result) => result,
Err(e) => {
tracing::error!("failed to build extensions: {e}");
std::process::exit(1);
}
};
let handle = metrics_exporter_prometheus::PrometheusBuilder::new()
.install_recorder()
.expect("failed to install metrics recorder");
admin_routes.push(axum::Router::new().route(
"/metrics",
axum::routing::get(move || async move { handle.render() }),
));
let ext_count = extensions.len();
let addr = config.listen.clone();
let model_count = registry.model_names().count();
let provider_count = registry.provider_count();
let shutdown_timeout = Duration::from_secs(config.shutdown_timeout);
let registry = Arc::new(ArcSwap::from_pointee(registry));
let key_map: HashMap<String, String> = config
.keys
.iter()
.map(|k| (k.key.clone(), k.name.clone()))
.collect();
let key_map = Arc::new(RwLock::new(key_map));
for kc in &config.keys {
let skey = crabllm_core::storage_key(&crabllm_proxy::PREFIX_KEYS, kc.name.as_bytes());
if let Ok(value) = serde_json::to_vec(kc) {
let _ = storage.set(&skey, value).await;
}
}
crabllm_proxy::admin::load_stored_keys(
storage.as_ref() as &dyn crabllm_core::Storage,
&config.keys,
&key_map,
)
.await;
if let Some(ref admin_token) = config.admin_token {
admin_routes.push(crabllm_proxy::admin::key_admin_routes(
storage.clone() as Arc<dyn crabllm_core::Storage>,
key_map.clone(),
admin_token.clone(),
config.keys.clone(),
));
let rebuilder: crabllm_proxy::admin_providers::Rebuilder<Dispatch> =
Arc::new(|config: &GatewayConfig| {
ProviderRegistry::from_config(config, Dispatch::Remote)
});
admin_routes.push(crabllm_proxy::admin_providers::provider_admin_routes(
registry.clone(),
config_path,
admin_token.clone(),
rebuilder,
storage.clone() as Arc<dyn crabllm_core::Storage>,
));
}
#[cfg(feature = "openapi")]
let enable_openapi = config.openapi;
let state: AppState<S, Dispatch> = AppState {
registry,
config,
extensions: Arc::new(extensions),
storage,
key_map,
usage_events: None,
};
#[allow(unused_mut)]
let mut app = crabllm_proxy::router(state, admin_routes);
#[cfg(feature = "openapi")]
if enable_openapi {
use utoipa_scalar::Servable;
let spec = crabllm_proxy::openapi::spec();
app = app
.merge(utoipa_scalar::Scalar::with_url("/docs", spec.clone()))
.route(
"/openapi.json",
axum::routing::get(move || async { axum::Json(spec) }),
);
tracing::info!("openapi docs enabled at /docs");
}
let app = app.layer(axum::middleware::from_fn(crabllm_proxy::log_request));
let listener = match tokio::net::TcpListener::bind(&addr).await {
Ok(l) => l,
Err(e) => {
tracing::error!("failed to bind to {addr}: {e}");
std::process::exit(1);
}
};
tracing::info!(
addr = %addr,
models = model_count,
providers = provider_count,
extensions = ext_count,
"crabllm listening"
);
let server = axum::serve(NoDelayListener(listener), app)
.with_graceful_shutdown(shutdown_signal(shutdown_timeout));
if let Err(e) = server.await {
tracing::error!("server failed: {e}");
}
}
/// Resolves once a shutdown signal arrives, and arms a forced-exit timer.
///
/// On Unix the future completes on Ctrl-C or SIGTERM, whichever comes first;
/// elsewhere only Ctrl-C is awaited. After the signal, a background task is
/// spawned that sleeps for `drain_timeout` and then terminates the process
/// with status 0.
///
/// # Panics
///
/// On Unix, panics if the SIGTERM handler cannot be installed.
async fn shutdown_signal(drain_timeout: Duration) {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{SignalKind, signal};

        let mut terminate =
            signal(SignalKind::terminate()).expect("failed to install SIGTERM handler");
        tokio::select! {
            _ = tokio::signal::ctrl_c() => {}
            _ = terminate.recv() => {}
        }
    }
    #[cfg(not(unix))]
    {
        tokio::signal::ctrl_c().await.ok();
    }

    tracing::info!(
        timeout_secs = drain_timeout.as_secs(),
        "shutdown signal received, draining connections"
    );

    // Watchdog: force the process down if draining outlasts the timeout.
    let watchdog = async move {
        tokio::time::sleep(drain_timeout).await;
        tracing::warn!("drain timeout exceeded, forcing exit");
        std::process::exit(0);
    };
    tokio::spawn(watchdog);
}
struct NoDelayListener(tokio::net::TcpListener);
impl axum::serve::Listener for NoDelayListener {
type Io = tokio::net::TcpStream;
type Addr = std::net::SocketAddr;
async fn accept(&mut self) -> (Self::Io, Self::Addr) {
loop {
match self.0.accept().await {
Ok((stream, addr)) => {
let _ = stream.set_nodelay(true);
return (stream, addr);
}
Err(_) => continue,
}
}
}
fn local_addr(&self) -> std::io::Result<Self::Addr> {
self.0.local_addr()
}
}
/// Result of [`build_extensions`]: the instantiated extensions plus the admin
/// routers some of them contribute.
type Extensions = (Vec<Box<dyn Extension>>, Vec<axum::Router>);

/// Instantiates every extension listed in the `[extensions]` config table.
///
/// Returns empty lists when the table is absent. Extensions are created in
/// the order the table iterates its entries; `usage`, `cache`, `budget` and
/// `audit` also contribute an admin router.
///
/// # Errors
///
/// Returns a message when `[extensions]` is not a table, when an entry names
/// an unknown extension, or when an extension's constructor fails.
fn build_extensions(
    config: &GatewayConfig,
    storage: Arc<dyn Storage>,
) -> Result<Extensions, String> {
    let mut active: Vec<Box<dyn Extension>> = Vec::new();
    let mut routers: Vec<axum::Router> = Vec::new();

    let Some(section) = config.extensions.as_ref() else {
        return Ok((active, routers));
    };
    let serde_json::Value::Object(table) = section else {
        return Err("[extensions] must be a table".to_string());
    };

    for (name, settings) in table {
        // Each arm yields the boxed extension and, if it has one, its admin router.
        let (extension, router): (Box<dyn Extension>, Option<axum::Router>) = match name.as_str() {
            "rate_limit" => {
                let limiter = RateLimit::new(settings, storage.clone())?;
                (Box::new(limiter) as Box<dyn Extension>, None)
            }
            "usage" => {
                let tracker = UsageTracker::new(settings, storage.clone())?;
                let router = tracker.admin_routes();
                (Box::new(tracker) as Box<dyn Extension>, Some(router))
            }
            "cache" => {
                let cache = Cache::new(settings, storage.clone())?;
                let router = cache.admin_routes();
                (Box::new(cache) as Box<dyn Extension>, Some(router))
            }
            "budget" => {
                let budget = Budget::new(settings, storage.clone(), config.models.clone())?;
                let router = budget.admin_routes();
                (Box::new(budget) as Box<dyn Extension>, Some(router))
            }
            "logging" => {
                let logger = RequestLogger::new(settings)?;
                (Box::new(logger) as Box<dyn Extension>, None)
            }
            "audit" => {
                let auditor = AuditLogger::new(settings, storage.clone(), config.models.clone())?;
                let router = auditor.admin_routes();
                (Box::new(auditor) as Box<dyn Extension>, Some(router))
            }
            unknown => {
                return Err(format!(
                    "unknown extension '{unknown}'. valid extensions: rate_limit, usage, cache, budget, logging, audit"
                ));
            }
        };
        routers.extend(router);
        active.push(extension);
    }

    Ok((active, routers))
}