// llmux 0.7.7 — zero-reload model switching for vLLM
// (manages multiple models on a shared GPU; header recovered from scraped
// crate-page text — the original line-number gutter residue was removed)
//! Mock vLLM server for testing llmux
//!
//! Supports two modes:
//! 1. Direct: `mock-vllm --port 8001 --model test-model`
//! 2. vLLM-compatible: `mock-vllm serve model-name --port 8001 --gpu-memory-utilization 0.9 ...`
//!
//! Simulates vLLM sleep mode API endpoints for integration testing.

use axum::{
    Json, Router,
    extract::{Query, State},
    http::StatusCode,
    response::IntoResponse,
    routing::{get, post},
};
use clap::{Parser, Subcommand};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tracing::{info, warn};

// Command-line arguments. Top-level flags drive "direct" mode; the optional
// `serve` subcommand provides a vLLM-compatible invocation (see `Commands`).
// NOTE: the `///` field comments below are clap help text — do not edit them
// casually, they are user-visible CLI output.
#[derive(Parser, Debug)]
#[command(name = "mock-vllm")]
#[command(about = "Mock vLLM server for testing")]
struct Args {
    #[command(subcommand)]
    command: Option<Commands>,

    /// Port to listen on (direct mode)
    // global = true so the flag is also accepted after the `serve` subcommand.
    #[arg(short, long, default_value = "8001", global = true)]
    port: u16,

    /// Model name to serve (direct mode)
    // Not global: in serve mode the model comes from the positional arg instead.
    #[arg(short, long, default_value = "test-model")]
    model: String,

    /// Artificial latency for responses (ms)
    #[arg(long, default_value = "50", global = true)]
    latency_ms: u64,

    /// Artificial startup delay (ms)
    // Applied once in main() before binding the listener.
    #[arg(long, default_value = "0", global = true)]
    startup_delay_ms: u64,
}

// Subcommands. Only `serve` exists; it mimics `vllm serve <model> ...` so the
// mock binary can be dropped in wherever real vLLM would be launched.
// NOTE: the `///` comments below are clap help text — user-visible output.
#[derive(Subcommand, Debug)]
enum Commands {
    /// vLLM-compatible serve mode.
    /// Accepts and ignores vLLM-specific flags so it can stand in for real vLLM.
    Serve {
        /// Model to serve (positional, like vllm)
        model: String,

        /// Port (vLLM-style)
        // When present, overrides the global top-level --port (see main()).
        #[arg(long)]
        port: Option<u16>,

        /// GPU memory utilization (ignored in mock, accepted for compatibility)
        #[arg(long, default_value = "0.9")]
        gpu_memory_utilization: f32,

        /// Tensor parallel size (ignored in mock)
        #[arg(long, default_value = "1")]
        tensor_parallel_size: usize,

        /// Data type (ignored in mock)
        #[arg(long, default_value = "auto")]
        dtype: String,

        /// Enable sleep mode (ignored in mock, always enabled)
        #[arg(long)]
        enable_sleep_mode: bool,

        /// Max model length (ignored in mock)
        #[arg(long)]
        max_model_len: Option<usize>,
    },
}

/// Server state shared across all request handlers via `Arc`.
///
/// Each field sits behind its own `tokio::sync::RwLock` so the independent
/// controls (sleep flags, latency, counters) never contend with one another.
#[derive(Debug)]
struct MockState {
    /// Model name reported by /v1/models and echoed in chat responses.
    model: String,
    /// True while the mock is "asleep"; chat requests then get 503.
    sleeping: RwLock<bool>,
    /// Last level passed to /sleep (0 until the first sleep call).
    sleep_level: RwLock<u8>,
    /// Artificial per-request latency applied in chat_completions.
    latency: RwLock<Duration>,
    /// Count of chat-completion requests served; used for response ids.
    request_count: RwLock<u64>,
    /// When true, /sleep returns 500 (for testing L3 fallback)
    fail_sleep: RwLock<bool>,
    /// When true, /wake_up returns 500 (for testing wake failure cleanup)
    fail_wake: RwLock<bool>,
    /// Artificial sleep delay in milliseconds (for testing timeouts)
    sleep_delay_ms: RwLock<u64>,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter("mock_vllm=debug,tower_http=debug")
        .init();

    let args = Args::parse();

    // Resolve the effective model/port. The `serve` subcommand (vLLM-compatible
    // mode) wins; otherwise fall back to the top-level direct-mode flags.
    let (model, port) = if let Some(Commands::Serve {
        model,
        port: serve_port,
        ..
    }) = args.command
    {
        (model, serve_port.unwrap_or(args.port))
    } else {
        (args.model, args.port)
    };

    // Optionally pretend startup takes a while (lets tests exercise slow boots).
    if args.startup_delay_ms > 0 {
        info!(delay_ms = args.startup_delay_ms, "Simulating startup delay");
        tokio::time::sleep(Duration::from_millis(args.startup_delay_ms)).await;
    }

    let state = Arc::new(MockState {
        model: model.clone(),
        sleeping: RwLock::new(false),
        sleep_level: RwLock::new(0),
        latency: RwLock::new(Duration::from_millis(args.latency_ms)),
        request_count: RwLock::new(0),
        fail_sleep: RwLock::new(false),
        fail_wake: RwLock::new(false),
        sleep_delay_ms: RwLock::new(0),
    });

    // vLLM-compatible API surface plus /control/* test hooks and /stats.
    let app = Router::new()
        .route("/health", get(health))
        .route("/sleep", post(sleep))
        .route("/wake_up", post(wake_up))
        .route("/collective_rpc", post(collective_rpc))
        .route("/reset_prefix_cache", post(reset_prefix_cache))
        .route("/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/stats", get(stats))
        .route("/control/fail-sleep", post(control_fail_sleep))
        .route("/control/fail-wake", post(control_fail_wake))
        .route("/control/sleep-delay", post(control_sleep_delay))
        .route("/control/latency", post(control_latency))
        .with_state(state);

    let listener = TcpListener::bind(format!("0.0.0.0:{}", port)).await?;

    // Read back the bound port — matters when port=0 asked the OS for a free one.
    let actual_port = listener.local_addr()?.port();

    info!(
        model = %model,
        port = actual_port,
        "Mock vLLM server listening"
    );

    // Readiness handshake for the test harness: "READY <port>" on stdout.
    println!("READY {}", actual_port);

    axum::serve(listener, app).await?;
    Ok(())
}

/// Health check endpoint.
///
/// Always answers 200 — the handler mirrors vLLM in that a sleeping model is
/// still considered healthy; a probe arriving during sleep is only logged.
async fn health(State(state): State<Arc<MockState>>) -> impl IntoResponse {
    if *state.sleeping.read().await {
        info!("Health check: sleeping");
    }
    StatusCode::OK
}

/// Query parameters for /sleep.
#[derive(Deserialize)]
struct SleepQuery {
    /// Requested sleep level; the handler defaults to 1 when omitted.
    level: Option<u8>,
}

/// Sleep endpoint - puts the model into simulated sleep.
///
/// Honors two test controls: `fail_sleep` forces a 500 response, and
/// `sleep_delay_ms` inserts an artificial delay before the state flips.
async fn sleep(
    State(state): State<Arc<MockState>>,
    Query(query): Query<SleepQuery>,
) -> impl IntoResponse {
    let level = match query.level {
        Some(requested) => requested,
        None => 1,
    };
    info!(level = level, "Putting model to sleep");

    // Test hook: simulate a backend that refuses to sleep (L3 fallback path).
    let forced_failure = *state.fail_sleep.read().await;
    if forced_failure {
        warn!("Sleep forced to fail via /control/fail-sleep");
        return StatusCode::INTERNAL_SERVER_ERROR;
    }

    // Test hook: simulate a slow sleep (timeout testing).
    let delay_ms = *state.sleep_delay_ms.read().await;
    if delay_ms > 0 {
        info!(delay_ms = delay_ms, "Applying artificial sleep delay");
        tokio::time::sleep(Duration::from_millis(delay_ms)).await;
    }

    {
        let mut sleeping = state.sleeping.write().await;
        *sleeping = true;
    }
    {
        let mut stored_level = state.sleep_level.write().await;
        *stored_level = level;
    }

    StatusCode::OK
}

/// Wake up endpoint.
///
/// Clears the sleeping flag, unless the `fail_wake` test hook is set, in which
/// case it returns 500 without touching state.
async fn wake_up(State(state): State<Arc<MockState>>) -> impl IntoResponse {
    if *state.fail_wake.read().await {
        warn!("Wake forced to fail via /control/fail-wake");
        StatusCode::INTERNAL_SERVER_ERROR
    } else {
        info!("Waking up model");
        *state.sleeping.write().await = false;
        StatusCode::OK
    }
}

/// Request body for /collective_rpc.
#[derive(Deserialize)]
struct CollectiveRpcRequest {
    /// RPC method name; only "reload_weights" gets special (delayed) handling.
    method: String,
}

/// Collective RPC endpoint (for weight reloading).
///
/// Acknowledges every method with 200; a `reload_weights` call additionally
/// sleeps 100ms to simulate the cost of reloading weights.
async fn collective_rpc(
    State(_state): State<Arc<MockState>>,
    Json(request): Json<CollectiveRpcRequest>,
) -> impl IntoResponse {
    info!(method = %request.method, "Collective RPC call");

    // Only weight reloads carry a simulated cost; everything else is instant.
    if matches!(request.method.as_str(), "reload_weights") {
        tokio::time::sleep(Duration::from_millis(100)).await;
    }

    StatusCode::OK
}

/// Reset prefix cache endpoint.
///
/// The mock keeps no actual cache, so this simply logs and acknowledges 200.
async fn reset_prefix_cache() -> impl IntoResponse {
    info!("Resetting prefix cache");
    StatusCode::OK
}

/// Subset of the OpenAI chat-completions request body that the mock parses.
#[derive(Deserialize)]
struct ChatCompletionRequest {
    // Model name as sent by the client; logged, not validated.
    model: String,
    messages: Vec<Message>,
    // Defaults to false when omitted; streaming is acknowledged but not honored.
    #[serde(default)]
    stream: bool,
    #[serde(default = "default_max_tokens")]
    #[allow(dead_code)] // Parsed but not used in mock response
    max_tokens: u32,
    // Sampling controls — only used to detect deterministic mode
    // (temperature == 0.0 together with a set seed).
    #[serde(default)]
    temperature: Option<f64>,
    #[serde(default)]
    seed: Option<u64>,
}

/// Serde default for `max_tokens` when the client omits the field.
fn default_max_tokens() -> u32 {
    100
}

/// One chat message (both request and response use the same shape).
#[derive(Deserialize, Serialize)]
struct Message {
    // e.g. "user" / "assistant"; the mock always replies as "assistant".
    role: String,
    content: String,
}

/// OpenAI-style chat-completion response envelope.
#[derive(Serialize)]
struct ChatCompletionResponse {
    id: String,
    object: String,
    // Unix timestamp (seconds) at response creation.
    created: u64,
    model: String,
    choices: Vec<Choice>,
    usage: Usage,
}

/// One completion choice; the mock always returns exactly one.
#[derive(Serialize)]
struct Choice {
    index: u32,
    message: Message,
    finish_reason: String,
}

/// Token accounting; the mock reports fixed placeholder values.
#[derive(Serialize)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}

/// Chat completions endpoint (OpenAI-compatible, non-streaming).
///
/// Returns 503 while the model is sleeping. Otherwise applies the configured
/// artificial latency, then echoes the last message back in a canned response.
/// Deterministic mode (temperature == 0.0 together with a seed) always answers
/// "4" so integration tests can assert exact output.
async fn chat_completions(
    State(state): State<Arc<MockState>>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    // Sleeping models reject traffic, mirroring a paused backend.
    if *state.sleeping.read().await {
        warn!(model = %request.model, "Request received while model is sleeping");
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            "Model is sleeping".to_string(),
        ));
    }

    // Simulate inference latency.
    tokio::time::sleep(*state.latency.read().await).await;

    // Increment and capture the counter under a single write lock. The previous
    // increment-then-reread-under-a-new-lock pattern allowed two concurrent
    // requests to observe the same count and emit duplicate response ids.
    let count = {
        let mut n = state.request_count.write().await;
        *n += 1;
        *n
    };

    info!(
        model = %request.model,
        messages = request.messages.len(),
        stream = request.stream,
        request_num = count,
        "Processing chat completion"
    );

    if request.stream {
        // For now, return non-streaming response even if stream requested
        // A full implementation would return SSE
        warn!("Streaming requested but returning non-streaming response");
    }

    // Generate mock response
    // Deterministic mode: when temperature == 0.0 and seed is set, always return "4"
    let deterministic = request.temperature == Some(0.0) && request.seed.is_some();
    let response_content = if deterministic {
        "4".to_string()
    } else {
        format!(
            "Mock response from {} (request #{}): You said \"{}\"",
            state.model,
            count,
            request
                .messages
                .last()
                .map(|m| m.content.as_str())
                .unwrap_or("")
        )
    };

    let response = ChatCompletionResponse {
        id: format!("chatcmpl-mock-{}", count),
        object: "chat.completion".to_string(),
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap() // only fails if the clock is set before 1970
            .as_secs(),
        model: state.model.clone(),
        choices: vec![Choice {
            index: 0,
            message: Message {
                role: "assistant".to_string(),
                content: response_content,
            },
            finish_reason: "stop".to_string(),
        }],
        // Fixed placeholder token counts — the mock does no real tokenization.
        usage: Usage {
            prompt_tokens: 10,
            completion_tokens: 20,
            total_tokens: 30,
        },
    };

    Ok(Json(response))
}

/// Response body for /v1/models (OpenAI-style list envelope).
#[derive(Serialize)]
struct ModelsResponse {
    object: String,
    data: Vec<ModelInfo>,
}

/// One entry in the /v1/models listing.
#[derive(Serialize)]
struct ModelInfo {
    id: String,
    object: String,
    owned_by: String,
}

/// List models endpoint.
///
/// Returns a single-entry OpenAI-style model list containing the configured
/// model name.
async fn list_models(State(state): State<Arc<MockState>>) -> impl IntoResponse {
    let entries = vec![ModelInfo {
        id: state.model.clone(),
        object: "model".to_string(),
        owned_by: "mock-vllm".to_string(),
    }];

    Json(ModelsResponse {
        object: "list".to_string(),
        data: entries,
    })
}

/// Snapshot returned by /stats for test-harness inspection.
#[derive(Serialize)]
struct StatsResponse {
    model: String,
    sleeping: bool,
    sleep_level: u8,
    request_count: u64,
}

/// Stats endpoint for testing inspection.
///
/// Reads each state field under its own lock; the values are therefore a
/// near-instant, not perfectly atomic, snapshot — fine for test assertions.
async fn stats(State(state): State<Arc<MockState>>) -> impl IntoResponse {
    let sleeping = *state.sleeping.read().await;
    let sleep_level = *state.sleep_level.read().await;
    let request_count = *state.request_count.read().await;

    Json(StatsResponse {
        model: state.model.clone(),
        sleeping,
        sleep_level,
        request_count,
    })
}

/// Body for /control/fail-sleep.
#[derive(Deserialize)]
struct ControlFailSleep {
    /// When true, subsequent /sleep calls return 500.
    enabled: bool,
}

/// Control endpoint: toggle forced 500 responses from /sleep.
async fn control_fail_sleep(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlFailSleep>,
) -> impl IntoResponse {
    info!(enabled = body.enabled, "Setting fail_sleep");
    let mut flag = state.fail_sleep.write().await;
    *flag = body.enabled;
    StatusCode::OK
}

/// Body for /control/sleep-delay.
#[derive(Deserialize)]
struct ControlSleepDelay {
    /// Artificial delay (ms) the /sleep handler waits before completing.
    delay_ms: u64,
}

/// Body for /control/fail-wake.
#[derive(Deserialize)]
struct ControlFailWake {
    /// When true, subsequent /wake_up calls return 500.
    enabled: bool,
}

/// Control endpoint: toggle forced 500 responses from /wake_up.
async fn control_fail_wake(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlFailWake>,
) -> impl IntoResponse {
    info!(enabled = body.enabled, "Setting fail_wake");
    let mut flag = state.fail_wake.write().await;
    *flag = body.enabled;
    StatusCode::OK
}

/// Control endpoint: set the artificial delay applied inside /sleep.
async fn control_sleep_delay(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlSleepDelay>,
) -> impl IntoResponse {
    info!(delay_ms = body.delay_ms, "Setting sleep_delay_ms");
    let mut delay = state.sleep_delay_ms.write().await;
    *delay = body.delay_ms;
    StatusCode::OK
}

/// Body for /control/latency.
#[derive(Deserialize)]
struct ControlLatency {
    /// New per-request latency (ms) applied to chat completions.
    latency_ms: u64,
}

/// Control endpoint: set the per-request latency for chat completions.
async fn control_latency(
    State(state): State<Arc<MockState>>,
    Json(body): Json<ControlLatency>,
) -> impl IntoResponse {
    info!(latency_ms = body.latency_ms, "Setting latency");
    let updated = Duration::from_millis(body.latency_ms);
    *state.latency.write().await = updated;
    StatusCode::OK
}