use axum::body::Body;
use axum::http::{Request, StatusCode};
use axum::routing::post;
use axum::{Json, Router};
use http_body_util::BodyExt;
use llmux::{Config, ModelConfig, PolicyConfig};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use tokio::net::TcpListener;
use tower::ServiceExt;
/// Command strings used as one mocked model's wake / sleep / alive hooks.
///
/// `wake` and `sleep` are `sleep <seconds>` commands that block for the
/// requested number of milliseconds; `alive` is the command `true`.
/// NOTE(review): presumably llmux runs these through a shell (fractional
/// `sleep` arguments rely on GNU/BSD `sleep`) — confirm.
struct MockHooks {
    /// Wake command: `sleep` for the configured wake time.
    wake: String,
    /// Sleep command: `sleep` for the configured sleep time.
    sleep: String,
    /// Liveness command: always `true`.
    alive: String,
}

impl MockHooks {
    /// Creates hooks whose wake/sleep commands take `wake_ms` / `sleep_ms` milliseconds.
    fn new(wake_ms: u64, sleep_ms: u64) -> Self {
        // The `sleep` utility takes seconds, so express the delay fractionally.
        let pause_for = |millis: u64| -> String { format!("sleep {}", millis as f64 / 1000.0) };
        let wake = pause_for(wake_ms);
        let sleep = pause_for(sleep_ms);
        Self {
            wake,
            sleep,
            alive: String::from("true"),
        }
    }
}
async fn spawn_mock_backend(port: u16) -> (SocketAddr, Arc<AtomicUsize>) {
let counter = Arc::new(AtomicUsize::new(0));
let counter_clone = counter.clone();
let app = Router::new().route(
"/v1/chat/completions",
post(move |Json(body): Json<Value>| {
let c = counter_clone.fetch_add(1, Ordering::SeqCst);
let model = body
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
async move {
Json(json!({
"model": model,
"request_number": c,
"choices": [{"message": {"content": "hello"}}]
}))
}
}),
);
let listener = TcpListener::bind(format!("127.0.0.1:{}", port))
.await
.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
tokio::time::sleep(Duration::from_millis(10)).await;
(addr, counter)
}
/// Builds an llmux [`Config`] containing exactly two models, `model-a` and
/// `model-b`, whose backends listen on `model_a_port` / `model_b_port` and
/// whose hook commands come from `hooks_a` / `hooks_b`.
///
/// The policy uses a 30 s request timeout, sets `drain_before_switch`, and
/// uses `min_active_secs = 0`. The proxy's own `port` is `0`.
fn test_config(
    model_a_port: u16,
    model_b_port: u16,
    hooks_a: &MockHooks,
    hooks_b: &MockHooks,
) -> Config {
    // Each model entry copies its hook command strings out of the borrowed hooks.
    let entry = |port: u16, hooks: &MockHooks| ModelConfig {
        port,
        wake: hooks.wake.clone(),
        sleep: hooks.sleep.clone(),
        alive: hooks.alive.clone(),
    };

    let models = HashMap::from([
        ("model-a".to_string(), entry(model_a_port, hooks_a)),
        ("model-b".to_string(), entry(model_b_port, hooks_b)),
    ]);

    let policy = PolicyConfig {
        request_timeout_secs: Some(30),
        drain_before_switch: true,
        min_active_secs: 0,
    };

    Config {
        models,
        policy,
        port: 0,
    }
}
/// Sends `POST /v1/chat/completions` for `model` through `app` and returns the
/// response status together with the decoded body.
///
/// The router is driven directly via `oneshot`; the proxy itself does not
/// listen on a socket. The request carries a single user message, `"hi"`.
/// A response body that is not valid JSON is returned as
/// `{"raw": "<body as lossy UTF-8>"}` so callers can still inspect it.
async fn chat_request(app: &Router, model: &str) -> (StatusCode, Value) {
    let body = json!({
        "model": model,
        "messages": [{"role": "user", "content": "hi"}]
    });
    let req = Request::builder()
        .method("POST")
        .uri("/v1/chat/completions")
        .header("Content-Type", "application/json")
        .body(Body::from(serde_json::to_string(&body).unwrap()))
        .unwrap();
    // `oneshot` consumes the service, so drive a clone of the shared router.
    let response = app.clone().oneshot(req).await.unwrap();
    let status = response.status();
    let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
    // Build the fallback lazily: only allocate the lossy text when parsing fails.
    let json: Value = serde_json::from_slice(&body_bytes).unwrap_or_else(|_| {
        json!({ "raw": String::from_utf8_lossy(&body_bytes).into_owned() })
    });
    (status, json)
}
/// A single request for `model-a` succeeds, echoes the model name, and hits
/// model-a's backend exactly once.
#[tokio::test]
async fn test_single_model_request() {
    let hooks_a = MockHooks::new(0, 0);
    let hooks_b = MockHooks::new(0, 0);
    let (addr_a, counter_a) = spawn_mock_backend(0).await;
    let (addr_b, _counter_b) = spawn_mock_backend(0).await;

    let (app, _switcher) =
        llmux::build_app(test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b))
            .await
            .unwrap();

    let (status, body) = chat_request(&app, "model-a").await;
    assert_eq!(status, StatusCode::OK);
    assert_eq!(body["model"], "model-a");

    let backend_hits = counter_a.load(Ordering::SeqCst);
    assert_eq!(backend_hits, 1);
}
/// A request for `model-a` followed by one for `model-b` both succeed, and
/// each backend receives exactly one request.
#[tokio::test]
async fn test_model_switch() {
    let (hooks_a, hooks_b) = (MockHooks::new(10, 10), MockHooks::new(10, 10));
    let (addr_a, counter_a) = spawn_mock_backend(0).await;
    let (addr_b, counter_b) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, _switcher) = llmux::build_app(config).await.unwrap();

    // Serve a, then b, in order.
    for model in ["model-a", "model-b"] {
        let (status, body) = chat_request(&app, model).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["model"], model);
    }

    assert_eq!(counter_a.load(Ordering::SeqCst), 1);
    assert_eq!(counter_b.load(Ordering::SeqCst), 1);
}
/// Five consecutive requests for `model-a` all reach model-a's backend, and
/// model-b's backend is never contacted.
#[tokio::test]
async fn test_same_model_no_switch() {
    const ROUNDS: usize = 5;

    let (hooks_a, hooks_b) = (MockHooks::new(10, 10), MockHooks::new(10, 10));
    let (addr_a, counter_a) = spawn_mock_backend(0).await;
    let (addr_b, counter_b) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, _switcher) = llmux::build_app(config).await.unwrap();

    for _round in 0..ROUNDS {
        let (status, body) = chat_request(&app, "model-a").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["model"], "model-a");
    }

    assert_eq!(counter_a.load(Ordering::SeqCst), ROUNDS);
    assert_eq!(counter_b.load(Ordering::SeqCst), 0);
}
/// Requesting a model that is not in the config yields `404 Not Found` with an
/// `error.message` string mentioning "not found".
#[tokio::test]
async fn test_unknown_model() {
    let hooks_a = MockHooks::new(0, 0);
    let hooks_b = MockHooks::new(0, 0);
    let (addr_a, _) = spawn_mock_backend(0).await;
    let (addr_b, _) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, _) = llmux::build_app(config).await.unwrap();

    let (status, body) = chat_request(&app, "nonexistent").await;
    assert_eq!(status, StatusCode::NOT_FOUND);

    // Show the whole body when the error shape is unexpected, instead of a
    // context-free `Option::unwrap()` panic.
    let Some(message) = body["error"]["message"].as_str() else {
        panic!("expected a string at error.message, got body: {body}");
    };
    assert!(
        message.contains("not found"),
        "unexpected error message: {message:?}"
    );
}
/// Request latency reflects the hook timings. The first request for `model-a`
/// is expected to pay its 100 ms wake. Switching to `model-b` is expected to
/// pay a 50 ms sleep plus a 100 ms wake. The lower bounds leave slack below
/// those nominal figures.
#[tokio::test]
async fn test_switch_timing() {
    let hooks_a = MockHooks::new(100, 50);
    let hooks_b = MockHooks::new(100, 50);
    let (addr_a, _) = spawn_mock_backend(0).await;
    let (addr_b, _) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, _switcher) = llmux::build_app(config).await.unwrap();

    // First request: model-a.
    let cold_timer = Instant::now();
    let (cold_status, _) = chat_request(&app, "model-a").await;
    let cold_start = cold_timer.elapsed();
    assert_eq!(cold_status, StatusCode::OK);
    assert!(
        cold_start >= Duration::from_millis(80),
        "cold start took {:?}",
        cold_start
    );

    // Second request: model-b.
    let switch_timer = Instant::now();
    let (switch_status, _) = chat_request(&app, "model-b").await;
    let switch_time = switch_timer.elapsed();
    assert_eq!(switch_status, StatusCode::OK);
    assert!(
        switch_time >= Duration::from_millis(120),
        "switch took {:?}, expected >= 120ms (sleep + wake)",
        switch_time
    );
}
/// Ten concurrently spawned requests for `model-a` all succeed, and every one
/// of them reaches model-a's backend.
#[tokio::test]
async fn test_concurrent_same_model() {
    let hooks_a = MockHooks::new(50, 10);
    let hooks_b = MockHooks::new(50, 10);
    let (addr_a, counter_a) = spawn_mock_backend(0).await;
    let (addr_b, _) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, _) = llmux::build_app(config).await.unwrap();

    // Spawn every request before awaiting any of them so they run concurrently.
    let tasks: Vec<_> = (0..10)
        .map(|_| {
            let router = app.clone();
            tokio::spawn(async move { chat_request(&router, "model-a").await })
        })
        .collect();

    for task in tasks {
        let (status, body) = task.await.unwrap();
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["model"], "model-a");
    }

    assert_eq!(counter_a.load(Ordering::SeqCst), 10);
}
/// Six concurrently spawned requests alternating between `model-a` and
/// `model-b` all succeed, and the two backends see six requests between them.
#[tokio::test]
async fn test_concurrent_different_models() {
    let hooks_a = MockHooks::new(50, 10);
    let hooks_b = MockHooks::new(50, 10);
    let (addr_a, counter_a) = spawn_mock_backend(0).await;
    let (addr_b, counter_b) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, _) = llmux::build_app(config).await.unwrap();

    // Request order: a, b, a, b, a, b.
    let tasks: Vec<_> = ["model-a", "model-b"]
        .iter()
        .copied()
        .cycle()
        .take(6)
        .map(|model| {
            let router = app.clone();
            tokio::spawn(async move { chat_request(&router, model).await })
        })
        .collect();

    let mut statuses = Vec::with_capacity(tasks.len());
    for task in tasks {
        statuses.push(task.await.unwrap().0);
    }
    assert!(
        !statuses.iter().any(|s| *s != StatusCode::OK),
        "Some requests failed: {:?}",
        statuses
    );

    let served_a = counter_a.load(Ordering::SeqCst);
    let served_b = counter_b.load(Ordering::SeqCst);
    assert_eq!(served_a + served_b, 6);
}
/// `GET /v1/models` returns a `list` object containing one `model` entry,
/// owned by `llmux`, for each configured model.
#[tokio::test]
async fn test_list_models() {
    let hooks_a = MockHooks::new(0, 0);
    let hooks_b = MockHooks::new(0, 0);
    let (addr_a, _) = spawn_mock_backend(0).await;
    let (addr_b, _) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, _) = llmux::build_app(config).await.unwrap();

    let req = Request::builder()
        .method("GET")
        .uri("/v1/models")
        .body(Body::empty())
        .unwrap();
    // `app` is not used again, so hand it to `oneshot` directly.
    let response = app.oneshot(req).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
    let body: Value = serde_json::from_slice(&body_bytes).unwrap();
    assert_eq!(body["object"], "list");

    let data = body["data"]
        .as_array()
        .expect("`data` should be a JSON array");
    assert_eq!(data.len(), 2);

    let mut ids: Vec<&str> = data
        .iter()
        .map(|m| m["id"].as_str().expect("model `id` should be a string"))
        .collect();
    // The config's `models` is a `HashMap`, whose iteration order is randomized.
    // Unless llmux sorts its listing, the response order is not stable, so
    // compare the ids independently of order.
    ids.sort_unstable();
    assert_eq!(ids, ["model-a", "model-b"]);

    for model in data {
        assert_eq!(model["object"], "model");
        assert_eq!(model["owned_by"], "llmux");
    }
}
/// The switcher records an observed cost per `(from, to)` transition, with
/// `from = None` for a cold start. Nothing is known before the first switch.
///
/// The hooks are asymmetric: model-a wakes in 50 ms and model-b in 80 ms, and
/// both sleep in 20 ms. Assuming a switch runs the outgoing sleep hook and then
/// the incoming wake hook, a→b is nominally about 100 ms and b→a about 70 ms,
/// so a→b is expected to be the more expensive direction.
#[tokio::test]
async fn test_switch_cost_tracking() {
    let hooks_a = MockHooks::new(50, 20);
    let hooks_b = MockHooks::new(80, 20);
    let (addr_a, _) = spawn_mock_backend(0).await;
    let (addr_b, _) = spawn_mock_backend(0).await;
    let config = test_config(addr_a.port(), addr_b.port(), &hooks_a, &hooks_b);
    let (app, switcher) = llmux::build_app(config).await.unwrap();

    // No request has been served yet, so no costs are known.
    assert!(switcher.estimated_switch_cost(None, "model-a").is_none());
    assert!(switcher.estimated_switch_cost(None, "model-b").is_none());

    // First request: cold start of model-a.
    let (status, _) = chat_request(&app, "model-a").await;
    assert_eq!(status, StatusCode::OK);
    let Some(cold_a) = switcher.estimated_switch_cost(None, "model-a") else {
        panic!("cold start cost for model-a not recorded");
    };
    assert!(
        cold_a >= Duration::from_millis(30),
        "cold start cost {cold_a:?} too low"
    );

    // Switch a → b.
    let (status, _) = chat_request(&app, "model-b").await;
    assert_eq!(status, StatusCode::OK);
    let Some(a_to_b) = switcher.estimated_switch_cost(Some("model-a"), "model-b") else {
        panic!("a→b switch cost not recorded");
    };
    assert!(
        a_to_b >= Duration::from_millis(70),
        "a→b cost {a_to_b:?} too low"
    );
    // Costs are tracked per direction: recording a→b must not populate b→a.
    assert!(
        switcher
            .estimated_switch_cost(Some("model-b"), "model-a")
            .is_none(),
        "b→a cost recorded before any b→a switch happened"
    );

    // Switch b → a.
    let (status, _) = chat_request(&app, "model-a").await;
    assert_eq!(status, StatusCode::OK);
    let Some(b_to_a) = switcher.estimated_switch_cost(Some("model-b"), "model-a") else {
        panic!("b→a switch cost not recorded");
    };
    assert!(
        a_to_b > b_to_a,
        "expected a→b ({a_to_b:?}) > b→a ({b_to_a:?}) due to asymmetric wake times"
    );
}