#[cfg(feature = "rag")]
mod tests {
use axum::body::Body;
use axum::http::{Method, Request, StatusCode};
use tower::ServiceExt;
use oxibonsai_core::config::Qwen3Config;
use oxibonsai_runtime::engine::InferenceEngine;
use oxibonsai_runtime::rag_server::create_rag_router;
use oxibonsai_runtime::sampling::SamplingParams;
/// Builds the router under test by handing a freshly constructed engine to
/// `create_rag_router`.
///
/// The engine is created from `Qwen3Config::tiny_test()`, default
/// `SamplingParams`, and the constant `42` (presumably an RNG seed — confirm
/// against `InferenceEngine::new`).
fn test_rag_router() -> axum::Router {
    create_rag_router(InferenceEngine::new(
        Qwen3Config::tiny_test(),
        SamplingParams::default(),
        42,
    ))
}
/// Serializes `value` to a JSON string and wraps it in a request [`Body`].
///
/// # Panics
///
/// Panics with the message `serialize` if `serde_json::to_string` fails.
fn json_body(value: serde_json::Value) -> Body {
    let serialized = serde_json::to_string(&value).expect("serialize");
    Body::from(serialized)
}
/// Builds a request with the given `method` and `path` whose body is `body`
/// serialized as JSON, with a `content-type: application/json` header.
///
/// # Panics
///
/// Panics with the message `build request` if the request cannot be built.
fn json_request(method: Method, path: &str, body: serde_json::Value) -> Request<Body> {
    // Head of the request: method, target and content type.
    let builder = Request::builder()
        .method(method)
        .uri(path)
        .header("content-type", "application/json");

    // Attach the serialized payload last, which finalizes the request.
    let payload = json_body(body);
    builder.body(payload).expect("build request")
}
/// Reads the whole body of `resp` and parses it as JSON.
///
/// # Panics
///
/// Panics with `read body` if the body cannot be collected, or with
/// `parse JSON` if the collected bytes are not valid JSON.
async fn body_json(resp: axum::response::Response) -> serde_json::Value {
    let body = resp.into_body();
    // `usize::MAX` is passed as the size limit, so the read is effectively uncapped.
    let raw = axum::body::to_bytes(body, usize::MAX)
        .await
        .expect("read body");
    serde_json::from_slice(&raw).expect("parse JSON")
}
/// Posting three documents to `POST /rag/index` must succeed.
///
/// Checks the 200 status, `indexed == 3`, a `chunks` count of at least one,
/// and a `document_ids` array with three entries.
#[tokio::test]
async fn test_index_documents_returns_200() {
    let router = test_rag_router();
    let payload = serde_json::json!({
        "documents": [
            "Rust is a systems programming language emphasising memory safety.",
            "Axum is an ergonomic and modular web framework built on Tokio.",
            "RAG combines retrieval with language model generation."
        ]
    });
    let request = json_request(Method::POST, "/rag/index", payload);

    let response = router.oneshot(request).await.expect("response");
    let status = response.status();
    assert_eq!(status, StatusCode::OK, "expected 200 OK");

    let json = body_json(response).await;
    assert_eq!(
        json["indexed"], 3,
        "three documents should be indexed; got {json}"
    );

    // A missing or non-integer `chunks` field is treated as zero and fails the check.
    let chunk_count = json["chunks"].as_u64().unwrap_or(0);
    assert!(
        chunk_count >= 1,
        "should produce at least one chunk; got {json}"
    );

    let document_ids = json["document_ids"]
        .as_array()
        .expect("document_ids array");
    assert_eq!(document_ids.len(), 3, "should have one id per document");
}
/// An empty `documents` list sent to `POST /rag/index` must be rejected with
/// 400 and a JSON body carrying a string `error` field.
#[tokio::test]
async fn test_index_empty_documents_handled() {
    let router = test_rag_router();
    let payload = serde_json::json!({ "documents": [] });
    let request = json_request(Method::POST, "/rag/index", payload);

    let response = router.oneshot(request).await.expect("response");
    let status = response.status();
    assert_eq!(
        status,
        StatusCode::BAD_REQUEST,
        "empty documents list should return 400"
    );

    let json = body_json(response).await;
    let has_error_text = json["error"].is_string();
    assert!(
        has_error_text,
        "response should contain an error field; got {json}"
    );
}
#[tokio::test]
async fn test_rag_stats_returns_json() {
let app = test_rag_router();
let req = Request::get("/rag/stats")
.body(Body::empty())
.expect("build request");
let resp = app.oneshot(req).await.expect("response");
assert_eq!(resp.status(), StatusCode::OK);
let json = body_json(resp).await;
assert!(
json["documents_indexed"].is_number(),
"stats must contain documents_indexed; got {json}"
);
assert!(
json["chunks_indexed"].is_number(),
"stats must contain chunks_indexed; got {json}"
);
assert!(
json["embedding_dim"].is_number(),
"stats must contain embedding_dim; got {json}"
);
assert!(
json["store_memory_bytes"].is_number(),
"stats must contain store_memory_bytes; got {json}"
);
assert!(
json["store_memory_human"].is_string(),
"stats must contain store_memory_human; got {json}"
);
}
/// `POST /rag/query` against a router with nothing indexed must still return
/// 200 with `answer`, `prompt_used` and a `usage` object of numeric counters.
#[tokio::test]
async fn test_rag_query_returns_answer() {
    let router = test_rag_router();
    let payload = serde_json::json!({
        "query": "What is Rust?",
        "max_tokens": 5
    });
    let request = json_request(Method::POST, "/rag/query", payload);

    let response = router.oneshot(request).await.expect("response");
    let status = response.status();
    assert_eq!(
        status,
        StatusCode::OK,
        "rag query should return 200 even with empty index"
    );

    let json = body_json(response).await;

    let has_answer = json["answer"].is_string();
    assert!(has_answer, "response must have answer field; got {json}");

    let has_prompt = json["prompt_used"].is_string();
    assert!(has_prompt, "response must have prompt_used; got {json}");

    let has_usage = json["usage"].is_object();
    assert!(has_usage, "response must have usage; got {json}");

    // Every usage counter must be present and numeric.
    let usage = &json["usage"];
    assert!(usage["completion_tokens"].is_number());
    assert!(usage["prompt_tokens"].is_number());
    assert!(usage["chunks_retrieved"].is_number());
    assert!(usage["documents_searched"].is_number());
}
#[tokio::test]
async fn test_clear_index_resets_stats() {
let config = Qwen3Config::tiny_test();
let engine = InferenceEngine::new(config, SamplingParams::default(), 42);
let router = create_rag_router(engine);
let index_req = json_request(
Method::POST,
"/rag/index",
serde_json::json!({
"documents": [
"Rust memory safety ensures freedom from data races.",
"Axum is built on top of Hyper and Tokio."
]
}),
);
let resp = router
.clone()
.oneshot(index_req)
.await
.expect("index response");
assert_eq!(resp.status(), StatusCode::OK, "indexing should succeed");
let stats_resp = router
.clone()
.oneshot(Request::get("/rag/stats").body(Body::empty()).expect("req"))
.await
.expect("stats response");
let stats = body_json(stats_resp).await;
assert_eq!(
stats["documents_indexed"], 2,
"after indexing 2 docs, stats should show 2; got {stats}"
);
let clear_req = Request::builder()
.method(Method::DELETE)
.uri("/rag/index")
.body(Body::empty())
.expect("delete request");
let clear_resp = router
.clone()
.oneshot(clear_req)
.await
.expect("clear response");
assert_eq!(clear_resp.status(), StatusCode::OK);
let clear_json = body_json(clear_resp).await;
assert_eq!(
clear_json["status"], "cleared",
"clear response should contain status:cleared; got {clear_json}"
);
let stats_after_resp = router
.clone()
.oneshot(Request::get("/rag/stats").body(Body::empty()).expect("req"))
.await
.expect("stats after clear");
let stats_after = body_json(stats_after_resp).await;
assert_eq!(
stats_after["documents_indexed"], 0,
"after clear, documents_indexed should be 0; got {stats_after}"
);
assert_eq!(
stats_after["chunks_indexed"], 0,
"after clear, chunks_indexed should be 0; got {stats_after}"
);
}
/// With `include_context: true`, `POST /rag/query` must return 200 and a
/// `retrieved_chunks` array in the response body.
#[tokio::test]
async fn test_rag_query_includes_context_when_requested() {
    let router = test_rag_router();
    let payload = serde_json::json!({
        "query": "Tell me about memory safety.",
        "max_tokens": 3,
        "include_context": true
    });
    let request = json_request(Method::POST, "/rag/query", payload);

    let response = router.oneshot(request).await.expect("response");
    assert_eq!(response.status(), StatusCode::OK);

    let json = body_json(response).await;
    let chunks_is_array = json["retrieved_chunks"].is_array();
    assert!(
        chunks_is_array,
        "retrieved_chunks should be an array when include_context=true; got {json}"
    );
}
/// Without `include_context`, `POST /rag/query` must return 200 and the
/// `retrieved_chunks` lookup on the response must be null.
#[tokio::test]
async fn test_rag_query_omits_context_by_default() {
    let router = test_rag_router();
    let payload = serde_json::json!({
        "query": "What is Axum?",
        "max_tokens": 3
    });
    let request = json_request(Method::POST, "/rag/query", payload);

    let response = router.oneshot(request).await.expect("response");
    assert_eq!(response.status(), StatusCode::OK);

    let json = body_json(response).await;
    // Indexing a missing key on a JSON value yields null, so this holds both
    // for an absent field and for an explicit `null`.
    let chunks_missing = json["retrieved_chunks"].is_null();
    assert!(
        chunks_missing,
        "retrieved_chunks should be absent when include_context is not set; got {json}"
    );
}
#[tokio::test]
async fn test_rag_router_wires_all_routes() {
{
let app = test_rag_router();
let req = Request::get("/rag/stats").body(Body::empty()).expect("req");
let resp = app.oneshot(req).await.expect("resp");
assert_ne!(
resp.status(),
StatusCode::NOT_FOUND,
"GET /rag/stats should be registered"
);
}
{
let app = test_rag_router();
let req = json_request(
Method::POST,
"/rag/index",
serde_json::json!({ "documents": ["hello world this is a test document"] }),
);
let resp = app.oneshot(req).await.expect("resp");
assert_ne!(
resp.status(),
StatusCode::NOT_FOUND,
"POST /rag/index should be registered"
);
}
{
let app = test_rag_router();
let req = Request::builder()
.method(Method::DELETE)
.uri("/rag/index")
.body(Body::empty())
.expect("req");
let resp = app.oneshot(req).await.expect("resp");
assert_ne!(
resp.status(),
StatusCode::NOT_FOUND,
"DELETE /rag/index should be registered"
);
assert_eq!(
resp.status(),
StatusCode::OK,
"DELETE /rag/index should return 200"
);
}
{
let app = test_rag_router();
let req = json_request(
Method::POST,
"/rag/query",
serde_json::json!({ "query": "test query", "max_tokens": 1 }),
);
let resp = app.oneshot(req).await.expect("resp");
assert_ne!(
resp.status(),
StatusCode::NOT_FOUND,
"POST /rag/query should be registered"
);
}
}
/// A whitespace-only `query` sent to `POST /rag/query` must be rejected with
/// 400 and a JSON body carrying a string `error` field.
#[tokio::test]
async fn test_rag_query_rejects_empty_query() {
    let router = test_rag_router();
    let payload = serde_json::json!({ "query": "   " });
    let request = json_request(Method::POST, "/rag/query", payload);

    let response = router.oneshot(request).await.expect("response");
    let status = response.status();
    assert_eq!(
        status,
        StatusCode::BAD_REQUEST,
        "blank query should return 400"
    );

    let json = body_json(response).await;
    let has_error_text = json["error"].is_string();
    assert!(
        has_error_text,
        "error field should be present; got {json}"
    );
}
}