use axum::body::Body;
use axum::http::{Request, StatusCode};
use std::sync::Arc;
use tower::ServiceExt;
use oxibonsai_core::config::Qwen3Config;
use oxibonsai_runtime::engine::InferenceEngine;
use oxibonsai_runtime::metrics::InferenceMetrics;
use oxibonsai_runtime::sampling::SamplingParams;
use oxibonsai_runtime::server::{create_router, create_router_with_metrics};
fn test_router() -> axum::Router {
let config = Qwen3Config::tiny_test();
let params = SamplingParams::default();
let engine = InferenceEngine::new(config, params, 42);
create_router(engine, None)
}
#[tokio::test]
async fn health_returns_200() {
let app = test_router();
let req = Request::get("/health")
.body(Body::empty())
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
assert_eq!(&body[..], b"ok");
}
#[tokio::test]
async fn models_returns_json_with_model_list() {
let app = test_router();
let req = Request::get("/v1/models")
.body(Body::empty())
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
let json: serde_json::Value = serde_json::from_slice(&body).expect("parse json");
assert_eq!(json["object"], "list");
let data = json["data"].as_array().expect("data should be array");
assert!(!data.is_empty(), "should have at least one model");
assert_eq!(data[0]["id"], "bonsai-8b");
assert_eq!(data[0]["object"], "model");
assert_eq!(data[0]["owned_by"], "oxibonsai");
}
#[tokio::test]
async fn chat_completions_valid_request() {
let app = test_router();
let body = serde_json::json!({
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 5,
"temperature": 0.0
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
let json: serde_json::Value = serde_json::from_slice(&bytes).expect("parse json");
assert!(json["id"].is_string(), "should have id");
assert_eq!(json["object"], "chat.completion");
assert!(json["choices"].is_array(), "should have choices");
assert!(json["usage"].is_object(), "should have usage");
}
#[tokio::test]
async fn chat_completions_response_has_usage_field() {
let app = test_router();
let body = serde_json::json!({
"messages": [
{"role": "user", "content": "Hi"}
],
"max_tokens": 3
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
let json: serde_json::Value = serde_json::from_slice(&bytes).expect("parse json");
let usage = &json["usage"];
assert!(
usage["prompt_tokens"].is_number(),
"should have prompt_tokens"
);
assert!(
usage["completion_tokens"].is_number(),
"should have completion_tokens"
);
assert!(
usage["total_tokens"].is_number(),
"should have total_tokens"
);
}
#[tokio::test]
async fn chat_completions_response_has_id_format() {
let app = test_router();
let body = serde_json::json!({
"messages": [
{"role": "user", "content": "Test"}
],
"max_tokens": 1
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
let json: serde_json::Value = serde_json::from_slice(&bytes).expect("parse json");
let id = json["id"].as_str().expect("id should be string");
assert!(
id.starts_with("chatcmpl-"),
"id should start with chatcmpl-, got: {id}"
);
assert!(id.len() > 9, "id should have content after prefix");
}
#[tokio::test]
async fn chat_completions_choice_has_message() {
let app = test_router();
let body = serde_json::json!({
"messages": [
{"role": "user", "content": "Test"}
],
"max_tokens": 1
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
let json: serde_json::Value = serde_json::from_slice(&bytes).expect("parse json");
let choices = json["choices"].as_array().expect("choices");
assert!(!choices.is_empty(), "should have at least one choice");
let choice = &choices[0];
assert_eq!(choice["index"], 0);
assert_eq!(choice["message"]["role"], "assistant");
assert!(choice["message"]["content"].is_string());
assert_eq!(choice["finish_reason"], "stop");
}
#[tokio::test]
async fn multi_turn_conversation_via_api() {
let app = test_router();
let body = serde_json::json!({
"messages": [
{"role": "system", "content": "You are a calculator."},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "And 3+3?"}
],
"max_tokens": 1
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn default_max_tokens_applied() {
let raw = serde_json::json!({
"messages": [
{"role": "user", "content": "Hello"}
]
});
let parsed: oxibonsai_runtime::server::ChatCompletionRequest =
serde_json::from_value(raw).expect("deserialize request");
assert_eq!(parsed.max_tokens, 256, "default max_tokens should be 256");
let app = test_router();
let body = serde_json::json!({
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 1
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
}
#[tokio::test]
async fn default_temperature_applied() {
let app = test_router();
let body = serde_json::json!({
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 1
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
}
#[test]
fn create_router_without_tokenizer() {
let config = Qwen3Config::tiny_test();
let params = SamplingParams::default();
let engine = InferenceEngine::new(config, params, 42);
let _router = create_router(engine, None);
}
#[tokio::test]
async fn metrics_endpoint_returns_prometheus_format() {
let app = test_router();
let req = Request::get("/metrics")
.body(Body::empty())
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
let text = std::str::from_utf8(&body).expect("utf8");
assert!(
text.contains("# HELP"),
"metrics body should contain '# HELP' lines, got: {text}"
);
assert!(
text.contains("# TYPE"),
"metrics body should contain '# TYPE' lines, got: {text}"
);
assert!(
text.contains("oxibonsai_requests_total"),
"metrics should include requests_total"
);
assert!(
text.contains("oxibonsai_tokens_generated_total"),
"metrics should include tokens_generated_total"
);
}
#[tokio::test]
async fn metrics_content_type_is_prometheus() {
let app = test_router();
let req = Request::get("/metrics")
.body(Body::empty())
.expect("request");
let resp = app.oneshot(req).await.expect("response");
let ct = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
assert!(
ct.contains("text/plain"),
"content-type should be text/plain, got: {ct}"
);
}
#[tokio::test]
async fn metrics_track_request_count() {
let config = Qwen3Config::tiny_test();
let params = SamplingParams::default();
let engine = InferenceEngine::new(config, params, 42);
let metrics = Arc::new(InferenceMetrics::new());
let initial_count = metrics.requests_total.get();
let app = create_router_with_metrics(engine, None, Arc::clone(&metrics));
let body = serde_json::json!({
"messages": [{"role": "user", "content": "ping"}],
"max_tokens": 1
});
let req = Request::post("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(serde_json::to_string(&body).expect("serialize")))
.expect("request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
assert!(
metrics.requests_total.get() > initial_count,
"requests_total should have incremented from {initial_count}, got {}",
metrics.requests_total.get()
);
}
#[tokio::test]
async fn metrics_endpoint_shows_incremented_counter_after_request() {
let config = Qwen3Config::tiny_test();
let params = SamplingParams::default();
let engine = InferenceEngine::new(config, params, 42);
let metrics = Arc::new(InferenceMetrics::new());
let app = create_router_with_metrics(engine, None, Arc::clone(&metrics));
let _chat_body = serde_json::json!({
"messages": [{"role": "user", "content": "hello"}],
"max_tokens": 1
});
metrics.requests_total.inc_by(7);
metrics.tokens_generated_total.inc_by(42);
let req = Request::get("/metrics")
.body(Body::empty())
.expect("request");
let resp = app.oneshot(req).await.expect("response");
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.expect("body");
let text = std::str::from_utf8(&body).expect("utf8");
assert!(
text.contains("oxibonsai_requests_total 7"),
"should see incremented counter in Prometheus output, got:\n{text}"
);
assert!(
text.contains("oxibonsai_tokens_generated_total 42"),
"should see token counter in Prometheus output"
);
}
#[test]
fn create_router_delegates_to_metrics_variant() {
let config = Qwen3Config::tiny_test();
let params = SamplingParams::default();
let engine = InferenceEngine::new(config, params, 42);
let _router = create_router(engine, None);
}