use axum::{
Json, Router,
extract::{Query, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
};
use clap::{Parser, Subcommand};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tracing::{info, warn};
/// Command-line arguments for the mock server.
///
/// The flat flags are the primary interface; a vLLM-style `serve`
/// subcommand (see [`Commands`]) is also accepted so test harnesses can
/// invoke this binary the same way they would invoke `vllm serve`.
#[derive(Parser, Debug)]
#[command(name = "mock-vllm")]
#[command(about = "Mock vLLM server for testing")]
struct Args {
    /// Optional vLLM-style subcommand; when present it overrides `model`/`port`.
    #[command(subcommand)]
    command: Option<Commands>,
    /// TCP port to listen on (a `serve --port` value takes precedence).
    #[arg(short, long, default_value = "8001", global = true)]
    port: u16,
    /// Model name reported by /v1/models and echoed in chat responses.
    #[arg(short, long, default_value = "test-model")]
    model: String,
    /// Artificial per-request latency applied to /v1/chat/completions.
    #[arg(long, default_value = "50", global = true)]
    latency_ms: u64,
    /// Delay before the server binds its socket, to simulate slow startup.
    #[arg(long, default_value = "0", global = true)]
    startup_delay_ms: u64,
}
/// vLLM-compatible subcommands.
#[derive(Subcommand, Debug)]
enum Commands {
    /// Mirrors `vllm serve <model>`. Only `model` and `--port` affect the
    /// mock; the remaining flags are parsed for CLI compatibility and then
    /// ignored (`main` destructures this variant with `..`).
    Serve {
        model: String,
        #[arg(long)]
        port: Option<u16>,
        #[arg(long, default_value = "0.9")]
        gpu_memory_utilization: f32,
        #[arg(long, default_value = "1")]
        tensor_parallel_size: usize,
        #[arg(long, default_value = "auto")]
        dtype: String,
        #[arg(long)]
        enable_sleep_mode: bool,
        #[arg(long)]
        max_model_len: Option<usize>,
    },
}
/// Shared mutable server state, held behind an `Arc` by every handler.
///
/// Each field carries its own `RwLock` so handlers only contend on the
/// data they actually touch. The `fail_*` and `sleep_delay_ms` fields are
/// test controls toggled through the /control/* endpoints.
#[derive(Debug)]
struct MockState {
    /// Model name reported to clients; fixed at startup.
    model: String,
    /// True while the mock model is "asleep" (chat requests get 503).
    sleeping: RwLock<bool>,
    /// Level passed to the most recent /sleep call (starts at 0).
    sleep_level: RwLock<u8>,
    /// Artificial latency applied to each chat completion.
    latency: RwLock<Duration>,
    /// Total chat completions processed since startup.
    request_count: RwLock<u64>,
    /// When true, /sleep responds 500.
    fail_sleep: RwLock<bool>,
    /// When true, /wake_up responds 500.
    fail_wake: RwLock<bool>,
    /// Artificial delay (ms) applied inside /sleep before state flips.
    sleep_delay_ms: RwLock<u64>,
}
/// Entry point: parse the CLI, optionally simulate a slow startup, build
/// the axum router, bind the listener, announce readiness on stdout, and
/// serve until the process is killed.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter("mock_vllm=debug,tower_http=debug")
        .init();
    let args = Args::parse();
    // A vLLM-style `serve` subcommand overrides the flat flags; its --port
    // falls back to the global --port when omitted.
    let (model, port) = match args.command {
        Some(Commands::Serve { model, port, .. }) => (model, port.unwrap_or(args.port)),
        None => (args.model, args.port),
    };
    if args.startup_delay_ms > 0 {
        info!(delay_ms = args.startup_delay_ms, "Simulating startup delay");
        tokio::time::sleep(Duration::from_millis(args.startup_delay_ms)).await;
    }
    let state = Arc::new(MockState {
        model: model.clone(),
        sleeping: RwLock::new(false),
        sleep_level: RwLock::new(0),
        latency: RwLock::new(Duration::from_millis(args.latency_ms)),
        request_count: RwLock::new(0),
        fail_sleep: RwLock::new(false),
        fail_wake: RwLock::new(false),
        sleep_delay_ms: RwLock::new(0),
    });
    // vLLM-compatible endpoints plus /control/* hooks used by tests.
    let router = Router::new()
        .route("/health", get(health))
        .route("/sleep", post(sleep))
        .route("/wake_up", post(wake_up))
        .route("/collective_rpc", post(collective_rpc))
        .route("/reset_prefix_cache", post(reset_prefix_cache))
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/stats", get(stats))
        .route("/control/fail-sleep", post(control_fail_sleep))
        .route("/control/fail-wake", post(control_fail_wake))
        .route("/control/sleep-delay", post(control_sleep_delay))
        .route("/control/latency", post(control_latency))
        .with_state(state);
    let listener = TcpListener::bind(format!("0.0.0.0:{}", port)).await?;
    // Re-read the bound port in case port 0 (ephemeral) was requested.
    let actual_port = listener.local_addr()?.port();
    info!(model = %model, port = actual_port, "Mock vLLM server listening");
    // Harnesses wait for this stdout line to know the server is up.
    println!("READY {}", actual_port);
    axum::serve(listener, router).await?;
    Ok(())
}
/// GET /health — liveness probe.
///
/// Always reports 200: a sleeping model is still a live process. Logs
/// when a probe arrives while the model is asleep.
async fn health(State(state): State<Arc<MockState>>) -> impl IntoResponse {
    if *state.sleeping.read().await {
        info!("Health check: sleeping");
    }
    StatusCode::OK
}
/// Query parameters accepted by POST /sleep.
#[derive(Deserialize)]
struct SleepQuery {
    /// Sleep level; the handler defaults it to 1 when absent.
    level: Option<u8>,
}
/// POST /sleep?level=N — put the mock model to sleep.
///
/// Honors the test controls: responds 500 when `fail_sleep` is enabled,
/// and waits `sleep_delay_ms` before flipping state when a delay is set.
async fn sleep(
    State(state): State<Arc<MockState>>,
    Query(query): Query<SleepQuery>,
) -> impl IntoResponse {
    let requested_level = query.level.unwrap_or(1);
    info!(level = requested_level, "Putting model to sleep");
    let forced_failure = *state.fail_sleep.read().await;
    if forced_failure {
        warn!("Sleep forced to fail via /control/fail-sleep");
        return StatusCode::INTERNAL_SERVER_ERROR;
    }
    let artificial_delay = *state.sleep_delay_ms.read().await;
    if artificial_delay > 0 {
        info!(delay_ms = artificial_delay, "Applying artificial sleep delay");
        tokio::time::sleep(Duration::from_millis(artificial_delay)).await;
    }
    *state.sleeping.write().await = true;
    *state.sleep_level.write().await = requested_level;
    StatusCode::OK
}
/// POST /wake_up — wake the mock model.
///
/// Responds 500 when the `fail_wake` test control is enabled. On success
/// clears both `sleeping` and `sleep_level`, so /stats reports a fully
/// awake state (previously the last sleep level lingered after waking).
async fn wake_up(State(state): State<Arc<MockState>>) -> impl IntoResponse {
    if *state.fail_wake.read().await {
        warn!("Wake forced to fail via /control/fail-wake");
        return StatusCode::INTERNAL_SERVER_ERROR;
    }
    info!("Waking up model");
    *state.sleeping.write().await = false;
    // Reset to the startup baseline (0 = awake, matching main's init).
    *state.sleep_level.write().await = 0;
    StatusCode::OK
}
/// Body accepted by POST /collective_rpc.
#[derive(Deserialize)]
struct CollectiveRpcRequest {
    /// RPC method name; only "reload_weights" gets special treatment.
    method: String,
}
/// POST /collective_rpc — mock of vLLM's collective RPC endpoint.
///
/// Simulates a 100 ms weight reload for the "reload_weights" method;
/// every other method is acknowledged immediately with 200.
async fn collective_rpc(
    State(_state): State<Arc<MockState>>,
    Json(request): Json<CollectiveRpcRequest>,
) -> impl IntoResponse {
    info!(method = %request.method, "Collective RPC call");
    match request.method.as_str() {
        "reload_weights" => tokio::time::sleep(Duration::from_millis(100)).await,
        _ => {}
    }
    StatusCode::OK
}
/// POST /reset_prefix_cache — no-op acknowledgment; the mock keeps no cache.
async fn reset_prefix_cache() -> impl IntoResponse {
    info!("Resetting prefix cache");
    StatusCode::OK
}
/// Subset of the OpenAI chat-completion request body that the mock parses.
#[derive(Deserialize)]
struct ChatCompletionRequest {
    model: String,
    messages: Vec<Message>,
    // Clients may request streaming; the handler only warns and replies
    // with a regular (non-streaming) JSON body.
    #[serde(default)]
    stream: bool,
    // Parsed for schema compatibility; the mock ignores the token budget.
    #[serde(default = "default_max_tokens")]
    #[allow(dead_code)] max_tokens: u32,
    // temperature == 0.0 together with a seed triggers the deterministic
    // reply path in chat_completions.
    #[serde(default)]
    temperature: Option<f64>,
    #[serde(default)]
    seed: Option<u64>,
}
/// Serde default for `max_tokens` when the client omits it.
fn default_max_tokens() -> u32 {
    100
}
/// Single chat message; used for both incoming requests and the reply.
#[derive(Deserialize, Serialize)]
struct Message {
    role: String,
    content: String,
}
/// OpenAI-compatible chat completion response envelope.
#[derive(Serialize)]
struct ChatCompletionResponse {
    id: String,
    object: String,
    // Unix timestamp (seconds) of response creation.
    created: u64,
    model: String,
    choices: Vec<Choice>,
    usage: Usage,
}
/// One generated completion choice.
#[derive(Serialize)]
struct Choice {
    index: u32,
    message: Message,
    finish_reason: String,
}
/// Token accounting; the mock reports fixed values (see chat_completions).
#[derive(Serialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
/// POST /v1/chat/completions — mock OpenAI-compatible chat endpoint.
///
/// Behavior:
/// * Returns 503 while the model is sleeping.
/// * Sleeps for the configured latency (settable via /control/latency).
/// * Replies "4" when the request is deterministic (temperature == 0.0
///   and a seed is supplied); otherwise echoes the last message.
/// * Streaming is not implemented: `stream: true` is acknowledged with a
///   warning but a regular JSON body is returned.
async fn chat_completions(
    State(state): State<Arc<MockState>>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    if *state.sleeping.read().await {
        warn!(model = %request.model, "Request received while model is sleeping");
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            "Model is sleeping".to_string(),
        ));
    }
    tokio::time::sleep(*state.latency.read().await).await;
    // Increment and read under a single write lock so each concurrent
    // request observes a unique count. The previous write-then-reread
    // (separate locks) had a window where two requests could see the same
    // value and emit duplicate completion ids.
    let count = {
        let mut guard = state.request_count.write().await;
        *guard += 1;
        *guard
    };
    info!(
        model = %request.model,
        messages = request.messages.len(),
        stream = request.stream,
        request_num = count,
        "Processing chat completion"
    );
    if request.stream {
        warn!("Streaming requested but returning non-streaming response");
    }
    let deterministic = request.temperature == Some(0.0) && request.seed.is_some();
    let response_content = if deterministic {
        "4".to_string()
    } else {
        format!(
            "Mock response from {} (request #{}): You said \"{}\"",
            state.model,
            count,
            request
                .messages
                .last()
                .map(|m| m.content.as_str())
                .unwrap_or("")
        )
    };
    let response = ChatCompletionResponse {
        id: format!("chatcmpl-mock-{}", count),
        object: "chat.completion".to_string(),
        // Fall back to 0 on a pre-epoch system clock instead of panicking.
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs())
            .unwrap_or(0),
        model: state.model.clone(),
        choices: vec![Choice {
            index: 0,
            message: Message {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        // Fixed token accounting — the mock does not tokenize.
        usage: Usage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
        },
    };
    Ok(Json(response))
}
/// OpenAI-compatible /v1/models list envelope.
#[derive(Serialize)]
struct ModelsResponse {
    object: String,
    data: Vec<ModelInfo>,
}
/// A single entry in the /v1/models list.
#[derive(Serialize)]
struct ModelInfo {
    id: String,
    object: String,
    owned_by: String,
}
async fn list_models(State(state): State<Arc<MockState>>) -> impl IntoResponse {
let response = ModelsResponse {
object: "list".to_string(),
data: vec![ModelInfo {
id: state.model.clone(),
object: "model".to_string(),
owned_by: "mock-vllm".to_string(),
}],
};
Json(response)
}
/// Snapshot of server state returned by GET /stats (test introspection).
#[derive(Serialize)]
struct StatsResponse {
    model: String,
    sleeping: bool,
    sleep_level: u8,
    request_count: u64,
}
async fn stats(State(state): State<Arc<MockState>>) -> impl IntoResponse {
let response = StatsResponse {
model: state.model.clone(),
sleeping: *state.sleeping.read().await,
sleep_level: *state.sleep_level.read().await,
request_count: *state.request_count.read().await,
};
Json(response)
}
/// Body for POST /control/fail-sleep.
#[derive(Deserialize)]
struct ControlFailSleep {
    /// When true, subsequent /sleep calls respond 500.
    enabled: bool,
}
/// POST /control/fail-sleep — test hook: force /sleep to fail with 500.
async fn control_fail_sleep(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlFailSleep>,
) -> impl IntoResponse {
    info!(enabled = body.enabled, "Setting fail_sleep");
    let mut flag = state.fail_sleep.write().await;
    *flag = body.enabled;
    StatusCode::OK
}
/// Body for POST /control/sleep-delay.
#[derive(Deserialize)]
struct ControlSleepDelay {
    /// Artificial delay (ms) applied inside /sleep before state changes.
    delay_ms: u64,
}
/// Body for POST /control/fail-wake.
#[derive(Deserialize)]
struct ControlFailWake {
    /// When true, subsequent /wake_up calls respond 500.
    enabled: bool,
}
/// POST /control/fail-wake — test hook: force /wake_up to fail with 500.
async fn control_fail_wake(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlFailWake>,
) -> impl IntoResponse {
    info!(enabled = body.enabled, "Setting fail_wake");
    let mut flag = state.fail_wake.write().await;
    *flag = body.enabled;
    StatusCode::OK
}
/// POST /control/sleep-delay — test hook: set the artificial delay (ms)
/// the /sleep handler waits before flipping state.
async fn control_sleep_delay(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlSleepDelay>,
) -> impl IntoResponse {
    info!(delay_ms = body.delay_ms, "Setting sleep_delay_ms");
    let mut delay = state.sleep_delay_ms.write().await;
    *delay = body.delay_ms;
    StatusCode::OK
}
/// Body for POST /control/latency.
#[derive(Deserialize)]
struct ControlLatency {
    /// New per-request latency (ms) for /v1/chat/completions.
    latency_ms: u64,
}
/// POST /control/latency — test hook: change the simulated per-request
/// latency applied by the chat-completions handler.
async fn control_latency(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlLatency>,
) -> impl IntoResponse {
    info!(latency_ms = body.latency_ms, "Setting latency");
    let new_latency = Duration::from_millis(body.latency_ms);
    *state.latency.write().await = new_latency;
    StatusCode::OK
}