/// End-to-end HTTP tests for the serve router built by `create_router`.
///
/// Every test builds its own in-process [`TestServer`] backed by a freshly
/// created temporary `.apr` file. No server state (including metrics
/// counters) is shared between tests.
#[cfg(feature = "inference")]
mod e2e {
    use crate::commands::serve::routes::create_router;
    use crate::commands::serve::types::*;
    use axum::http::{header, StatusCode};
    use axum_test::{TestResponse, TestServer};
    use std::io::Write;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    /// Builds a test server over a temporary fake `.apr` model file.
    ///
    /// When `ready` is true, `ServerState::set_ready` is called before the
    /// router is built. When it is false, the state is left as constructed.
    ///
    /// The returned [`NamedTempFile`] must be kept alive (bind it to `_f`,
    /// not `_`) for as long as the server is used. Dropping it deletes the
    /// backing file.
    fn build_server(ready: bool) -> (TestServer, NamedTempFile) {
        let mut file =
            NamedTempFile::with_suffix(".apr").expect("failed to create temporary .apr file");
        file.write_all(b"fake model data for testing")
            .expect("failed to write fake model data to temporary file");
        let state = ServerState::new(file.path().to_path_buf(), ServerConfig::default())
            .expect("failed to construct ServerState from temporary model file");
        if ready {
            state.set_ready();
        }
        let router = create_router(Arc::new(state));
        let server = TestServer::new(router).expect("failed to start axum_test TestServer");
        (server, file)
    }

    /// Test server whose state has been marked ready.
    fn test_server() -> (TestServer, NamedTempFile) {
        build_server(true)
    }

    /// Test server whose state was never marked ready.
    fn test_server_not_ready() -> (TestServer, NamedTempFile) {
        build_server(false)
    }

    /// Returns the response's `content-type` header as an owned `String`.
    ///
    /// Panics (failing the calling test) if the header value is not valid
    /// visible ASCII.
    fn content_type(resp: &TestResponse) -> String {
        resp.header("content-type")
            .to_str()
            .expect("content-type header should be visible ASCII")
            .to_string()
    }

    #[tokio::test]
    async fn e2e_root_returns_server_info() {
        let (server, _f) = test_server();
        let resp = server.get("/").await;
        resp.assert_status_ok();
        let info: ServerInfo = resp.json();
        assert_eq!(info.name, "apr-serve");
        assert!(!info.version.is_empty());
        assert!(!info.model_id.is_empty());
    }

    #[tokio::test]
    async fn e2e_root_version_is_semver() {
        let (server, _f) = test_server();
        let info: ServerInfo = server.get("/").await.json();
        let parts: Vec<&str> = info.version.split('.').collect();
        assert!(parts.len() >= 2, "version should be semver: {}", info.version);
        assert!(
            parts[0].parse::<u32>().is_ok(),
            "major version should be numeric: {}",
            info.version
        );
    }

    #[tokio::test]
    async fn e2e_health_returns_200_when_ready() {
        let (server, _f) = test_server();
        let resp = server.get("/health").await;
        resp.assert_status_ok();
        let health: HealthResponse = resp.json();
        assert_eq!(health.status, HealthStatus::Healthy);
        assert!(!health.model_id.is_empty());
        assert!(!health.version.is_empty());
    }

    #[tokio::test]
    async fn e2e_health_returns_503_when_not_ready() {
        let (server, _f) = test_server_not_ready();
        let resp = server.get("/health").await;
        resp.assert_status(StatusCode::SERVICE_UNAVAILABLE);
        let health: HealthResponse = resp.json();
        assert_eq!(health.status, HealthStatus::Unhealthy);
    }

    #[tokio::test]
    async fn e2e_health_includes_uptime() {
        let (server, _f) = test_server();
        let health: HealthResponse = server.get("/health").await.json();
        assert!(health.uptime_seconds < 60, "test uptime should be < 60s");
    }

    #[tokio::test]
    async fn e2e_metrics_returns_prometheus_format() {
        let (server, _f) = test_server();
        let resp = server.get("/metrics").await;
        resp.assert_status_ok();
        let body = resp.text();
        assert!(body.contains("# HELP apr_requests_total"));
        assert!(body.contains("# TYPE apr_requests_total counter"));
        assert!(body.contains("apr_uptime_seconds"));
    }

    #[tokio::test]
    async fn e2e_metrics_content_type_is_text() {
        let (server, _f) = test_server();
        let resp = server.get("/metrics").await;
        let ct = content_type(&resp);
        assert!(
            ct.contains("text/plain"),
            "metrics content-type should be text/plain, got: {ct}"
        );
    }

    #[tokio::test]
    async fn e2e_predict_valid_request() {
        let (server, _f) = test_server();
        let resp = server
            .post("/predict")
            .json(&serde_json::json!({"inputs": {"text": "hello"}}))
            .await;
        resp.assert_status_ok();
        let body: serde_json::Value = resp.json();
        assert!(body.get("outputs").is_some());
        assert!(body.get("latency_ms").is_some());
    }

    #[tokio::test]
    async fn e2e_predict_missing_inputs_returns_400() {
        let (server, _f) = test_server();
        let resp = server
            .post("/predict")
            .json(&serde_json::json!({"text": "no inputs field"}))
            .await;
        resp.assert_status(StatusCode::BAD_REQUEST);
        let err: ErrorResponse = resp.json();
        assert_eq!(err.error, "missing_field");
    }

    #[tokio::test]
    async fn e2e_predict_invalid_json_returns_400() {
        let (server, _f) = test_server();
        let resp = server.post("/predict").text("not json").await;
        resp.assert_status(StatusCode::BAD_REQUEST);
        let err: ErrorResponse = resp.json();
        assert_eq!(err.error, "invalid_json");
    }

    #[tokio::test]
    async fn e2e_generate_non_streaming() {
        let (server, _f) = test_server();
        let resp = server
            .post("/generate")
            .json(&serde_json::json!({
                "prompt": "Hello world",
                "max_tokens": 10,
                "stream": false
            }))
            .await;
        resp.assert_status_ok();
        let gen: GenerateResponse = resp.json();
        assert_eq!(gen.finish_reason, "stop");
    }

    #[tokio::test]
    async fn e2e_generate_streaming_returns_sse() {
        let (server, _f) = test_server();
        let resp = server
            .post("/generate")
            .json(&serde_json::json!({
                "prompt": "Hello world",
                "max_tokens": 10,
                "stream": true
            }))
            .await;
        resp.assert_status_ok();
        let ct = content_type(&resp);
        assert!(
            ct.contains("text/event-stream"),
            "streaming should return SSE content-type, got: {ct}"
        );
    }

    #[tokio::test]
    async fn e2e_generate_empty_prompt_returns_400() {
        let (server, _f) = test_server();
        let resp = server
            .post("/generate")
            .json(&serde_json::json!({
                "prompt": "",
                "max_tokens": 10
            }))
            .await;
        resp.assert_status(StatusCode::BAD_REQUEST);
        let err: ErrorResponse = resp.json();
        assert_eq!(err.error, "empty_prompt");
    }

    #[tokio::test]
    async fn e2e_generate_invalid_json_returns_400() {
        let (server, _f) = test_server();
        let resp = server.post("/generate").text("garbage").await;
        resp.assert_status(StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn e2e_transcribe_returns_response() {
        let (server, _f) = test_server();
        let resp = server.post("/transcribe").text("fake audio bytes").await;
        resp.assert_status_ok();
        let tr: TranscribeResponse = resp.json();
        assert_eq!(tr.language, "en");
    }

    #[tokio::test]
    async fn e2e_get_predict_returns_405() {
        let (server, _f) = test_server();
        let resp = server.get("/predict").await;
        resp.assert_status(StatusCode::METHOD_NOT_ALLOWED);
        let err: ErrorResponse = resp.json();
        assert_eq!(err.error, "method_not_allowed");
    }

    #[tokio::test]
    async fn e2e_get_generate_returns_405() {
        let (server, _f) = test_server();
        let resp = server.get("/generate").await;
        resp.assert_status(StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn e2e_get_transcribe_returns_405() {
        let (server, _f) = test_server();
        let resp = server.get("/transcribe").await;
        resp.assert_status(StatusCode::METHOD_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn e2e_unknown_endpoint_returns_404() {
        let (server, _f) = test_server();
        let resp = server.get("/nonexistent").await;
        resp.assert_status(StatusCode::NOT_FOUND);
        let err: ErrorResponse = resp.json();
        assert_eq!(err.error, "not_found");
    }

    #[tokio::test]
    async fn e2e_post_unknown_endpoint_returns_404() {
        let (server, _f) = test_server();
        let resp = server
            .post("/v99/imaginary")
            .json(&serde_json::json!({}))
            .await;
        resp.assert_status(StatusCode::NOT_FOUND);
    }

    #[tokio::test]
    async fn e2e_metrics_increment_after_predict() {
        let (server, _f) = test_server();
        server
            .post("/predict")
            .json(&serde_json::json!({"inputs": {"text": "test"}}))
            .await
            .assert_status_ok();
        let body = server.get("/metrics").await.text();
        assert!(
            body.contains("apr_requests_total 1"),
            "should have 1 total request, got:\n{body}"
        );
        assert!(
            body.contains("apr_requests_success 1"),
            "should have 1 success, got:\n{body}"
        );
    }

    #[tokio::test]
    async fn e2e_metrics_increment_client_errors() {
        let (server, _f) = test_server();
        // Body lacks `inputs`, so this is a client error (same shape as
        // `e2e_predict_missing_inputs_returns_400`). Assert it so the metrics
        // check below cannot pass while the request went down another path.
        server
            .post("/predict")
            .json(&serde_json::json!({"bad": true}))
            .await
            .assert_status(StatusCode::BAD_REQUEST);
        let body = server.get("/metrics").await.text();
        assert!(
            body.contains("apr_requests_client_error 1"),
            "should have 1 client error, got:\n{body}"
        );
    }

    #[tokio::test]
    async fn e2e_oversized_content_length_rejected() {
        let (server, _f) = test_server();
        // Advertise a ~20 MB body while sending only a few bytes. The server
        // is expected to reject based on the declared Content-Length alone.
        // NOTE(review): assumes the server's body limit is below 20 MB — confirm.
        let resp = server
            .post("/predict")
            .add_header(header::CONTENT_LENGTH, "20000000")
            .text("small body")
            .await;
        resp.assert_status(StatusCode::PAYLOAD_TOO_LARGE);
    }
}