use anyhow::{Context, Result};
use clap::Parser;
use llmux::Config;
use std::path::PathBuf;
use tokio::net::TcpListener;
use tracing::info;
use tracing_subscriber::EnvFilter;
// Command-line arguments. Unless one of the one-shot flags (`--validate`,
// `--checkpoint`, `--restore-detached`) is given, `main` starts the proxy server.
// Field `///` comments below are rendered by clap as the `--help` text.
#[derive(Parser, Debug)]
#[command(name = "llmux")]
#[command(about = "Zero-reload model switching for vLLM")]
struct Args {
    /// Path to the configuration file
    #[arg(short, long, default_value = "config.json")]
    config: PathBuf,

    /// Port for the main listener; overrides the port set in the config file
    #[arg(short, long)]
    port: Option<u16>,

    /// Enable debug logging for llmux, onwards, tower_http and vllm (RUST_LOG is ignored)
    #[arg(short, long)]
    verbose: bool,

    /// Run validation for MODEL and exit (exit status 1 on failure)
    #[arg(long, value_name = "MODEL", conflicts_with_all = ["checkpoint", "restore_detached"])]
    validate: Option<String>,

    /// Comma-separated list of policies to pass to --validate
    #[arg(
        long,
        value_name = "POLICIES",
        value_delimiter = ',',
        requires = "validate"
    )]
    policies: Vec<String>,

    /// Run a checkpoint for MODEL and exit (exit status 1 on failure)
    #[arg(long, value_name = "MODEL", conflicts_with_all = ["validate", "restore_detached"])]
    checkpoint: Option<String>,

    /// Run a detached restore for MODEL and exit (exit status 1 on failure)
    #[arg(long, value_name = "MODEL", conflicts_with_all = ["validate", "checkpoint"])]
    restore_detached: Option<String>,

    /// Eviction policy used by --checkpoint [default: discard+checkpoint]
    #[arg(long, value_name = "POLICY", requires = "checkpoint")]
    eviction: Option<String>,

    /// Skip the warmup step when running --checkpoint
    #[arg(long, requires = "checkpoint")]
    no_warmup: bool,
}
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
let filter = if args.verbose {
EnvFilter::new("llmux=debug,onwards=debug,tower_http=debug,vllm=debug")
} else {
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"))
};
tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(true)
.init();
info!("Starting llmux");
let mut config = Config::from_file(&args.config)
.await
.with_context(|| format!("Failed to load config from {}", args.config.display()))?;
if let Some(port) = args.port {
config.port = port;
}
config.validate();
info!(
models = ?config.models.keys().collect::<Vec<_>>(),
port = config.port,
"Configuration loaded"
);
if let Some(model_name) = args.validate {
let policies = if args.policies.is_empty() {
None
} else {
Some(args.policies)
};
let success =
llmux::validate::run_validation(&config, &model_name, policies.as_deref()).await?;
std::process::exit(if success { 0 } else { 1 });
}
if let Some(model_name) = args.checkpoint {
let eviction = args.eviction.as_deref().unwrap_or("discard+checkpoint");
let success = llmux::validate::run_checkpoint(
&config,
&model_name,
eviction,
!args.no_warmup,
)
.await?;
std::process::exit(if success { 0 } else { 1 });
}
if let Some(model_name) = args.restore_detached {
let success = llmux::validate::run_restore(&config, &model_name).await?;
std::process::exit(if success { 0 } else { 1 });
}
let (app, metrics_router, control_router, switcher) = llmux::build_app(config.clone())
.await
.context("Failed to build application")?;
if config.warmup {
llmux::run_warmup(&switcher)
.await
.context("Warmup phase failed")?;
}
if let Some(metrics_router) = metrics_router {
let metrics_addr = format!("0.0.0.0:{}", config.metrics_port);
let metrics_listener = TcpListener::bind(&metrics_addr)
.await
.with_context(|| format!("Failed to bind metrics to {}", metrics_addr))?;
info!(addr = %metrics_addr, "Serving metrics");
tokio::spawn(async move {
if let Err(e) = axum::serve(metrics_listener, metrics_router).await {
tracing::error!(error = %e, "Metrics server error");
}
});
}
if let Some(admin_port) = config.admin_port {
let admin_addr = format!("0.0.0.0:{}", admin_port);
let admin_listener = TcpListener::bind(&admin_addr)
.await
.with_context(|| format!("Failed to bind admin API to {}", admin_addr))?;
info!(addr = %admin_addr, "Serving control API");
tokio::spawn(async move {
if let Err(e) = axum::serve(admin_listener, control_router).await {
tracing::error!(error = %e, "Admin server error");
}
});
}
let addr = format!("0.0.0.0:{}", config.port);
let listener = TcpListener::bind(&addr)
.await
.with_context(|| format!("Failed to bind to {}", addr))?;
info!(addr = %addr, "Listening for requests");
axum::serve(listener, app).await.context("Server error")?;
Ok(())
}