use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Json, Response};
use axum::Router;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use tokio::sync::Mutex as TokioMutex;
use oxibonsai_rag::embedding::TfIdfEmbedder;
use oxibonsai_rag::pipeline::{RagConfig, RagPipeline};
use crate::engine::InferenceEngine;
/// Small fixed set of sentences used to fit the TF-IDF embedder whenever an
/// empty pipeline is created (see `RagState::new` and `clear_index`).
const BOOTSTRAP_CORPUS: &[&str] = &[
"The quick brown fox jumps over the lazy dog.",
"Artificial intelligence and machine learning are transforming software.",
"Rust is a systems programming language focused on safety performance and concurrency.",
"Retrieval-augmented generation combines search with language model generation.",
"Vector embeddings represent semantic meaning in high-dimensional space.",
];
/// Second argument passed to every `TfIdfEmbedder::fit` call in this file.
// NOTE(review): presumably the cap on the TF-IDF vocabulary / embedding width —
// confirm against the embedder's documentation.
const DEFAULT_MAX_FEATURES: usize = 512;
/// JSON request body accepted by `index_documents` (`POST /rag/index`).
#[derive(Debug, Deserialize)]
pub struct IndexDocumentRequest {
/// Document texts to index; the handler rejects an empty list with `400`.
pub documents: Vec<String>,
/// When present, overrides the default `ChunkConfig::chunk_size`.
// NOTE(review): the unit (characters, words, tokens) is not visible here — confirm.
pub chunk_size: Option<usize>,
/// When present, overrides the default `ChunkConfig::overlap`.
pub chunk_overlap: Option<usize>,
}
/// JSON response body returned by `index_documents` on success.
#[derive(Debug, Serialize)]
pub struct IndexDocumentResponse {
/// Number of documents that were indexed.
pub indexed: usize,
/// Sum of the chunk counts reported for each indexed document.
pub chunks: usize,
/// Zero-based positions of the indexed documents within the request's
/// `documents` list.
pub document_ids: Vec<usize>,
}
/// JSON request body accepted by `rag_query` (`POST /rag/query`).
#[derive(Debug, Deserialize)]
pub struct RagQueryRequest {
/// Query text; the handler rejects a blank (whitespace-only) query with `400`.
pub query: String,
/// Generation limit passed to the engine; the handler defaults it to 256.
pub max_tokens: Option<usize>,
/// Maximum number of retrieved chunk texts kept; the handler defaults it to 3.
pub top_k: Option<usize>,
/// Sampling temperature.
// NOTE(review): accepted in the payload but not read by `rag_query` in this file.
pub temperature: Option<f32>,
/// When `true`, the retrieved chunk texts are included in the response;
/// the handler defaults it to `false`.
pub include_context: Option<bool>,
}
/// JSON response body returned by `rag_query` on success.
#[derive(Debug, Serialize)]
pub struct RagQueryResponse {
/// Generated output; `rag_query` fills this by joining the generated token
/// values with single spaces.
pub answer: String,
/// Retrieved chunk texts, present only when the request set `include_context`.
pub retrieved_chunks: Option<Vec<String>>,
/// The prompt string produced by the pipeline's `build_prompt` for this query.
pub prompt_used: String,
/// Counters describing the retrieval and generation work for this request.
pub usage: RagUsage,
}
/// Usage counters embedded in [`RagQueryResponse`].
#[derive(Debug, Serialize)]
pub struct RagUsage {
/// Number of documents the pipeline reported as indexed at query time.
pub documents_searched: usize,
/// Number of retrieved chunk texts kept after the `top_k` cut.
pub chunks_retrieved: usize,
/// Whitespace-separated word count of the prompt (an approximation, not a
/// tokenizer count).
pub prompt_tokens: usize,
/// Number of tokens returned by the engine's `generate` call.
pub completion_tokens: usize,
}
/// JSON response body returned by `rag_stats`; the numeric fields are copied
/// from the pipeline's `stats()` result.
#[derive(Debug, Serialize)]
pub struct RagStatsResponse {
/// Number of documents currently indexed.
pub documents_indexed: usize,
/// Number of chunks currently indexed.
pub chunks_indexed: usize,
/// Embedding dimension reported by the pipeline.
pub embedding_dim: usize,
/// Store memory usage in bytes as reported by the pipeline.
pub store_memory_bytes: usize,
/// `store_memory_bytes` rendered by `human_bytes` (e.g. `"1.00 KiB"`).
pub store_memory_human: String,
}
fn error_response(status: StatusCode, message: impl Into<String>) -> Response {
let body = serde_json::json!({ "error": message.into() });
(status, Json(body)).into_response()
}
/// Formats a byte count using binary units.
///
/// Values of at least 1 KiB are shown with two decimals in the largest unit
/// that fits (`KiB`, `MiB`, `GiB`); smaller values are shown as whole bytes
/// (e.g. `"512 B"`).
fn human_bytes(bytes: usize) -> String {
    // Largest unit first so the first match is the best fit.
    const UNITS: [(usize, &str); 3] = [
        (1024 * 1024 * 1024, "GiB"),
        (1024 * 1024, "MiB"),
        (1024, "KiB"),
    ];

    match UNITS.iter().find(|(threshold, _)| bytes >= *threshold) {
        Some((threshold, suffix)) => {
            format!("{:.2} {suffix}", bytes as f64 / *threshold as f64)
        }
        None => format!("{bytes} B"),
    }
}
/// Approximates a token count as the number of whitespace-separated words.
///
/// This is a word count, not a tokenizer count; an empty or all-whitespace
/// string yields 0.
fn rough_token_count(text: &str) -> usize {
    text.split(char::is_whitespace)
        .filter(|piece| !piece.is_empty())
        .count()
}
/// Shared state for the RAG HTTP handlers, passed to them as `Arc<RagState>`.
pub struct RagState {
/// Current retrieval pipeline. Guarded by a synchronous mutex; `index_documents`
/// and `clear_index` replace the whole pipeline under this lock.
pipeline: Mutex<RagPipeline<TfIdfEmbedder>>,
/// Inference engine used by `rag_query`; guarded by an async (Tokio) mutex.
engine: Arc<TokioMutex<InferenceEngine<'static>>>,
}
impl RagState {
    /// Creates the handler state around an existing shared engine.
    ///
    /// The pipeline starts with an embedder fitted on `BOOTSTRAP_CORPUS` and
    /// the default `RagConfig`.
    pub fn new(engine: Arc<TokioMutex<InferenceEngine<'static>>>) -> Self {
        let pipeline = Mutex::new(RagPipeline::new(
            TfIdfEmbedder::fit(BOOTSTRAP_CORPUS, DEFAULT_MAX_FEATURES),
            RagConfig::default(),
        ));
        Self { pipeline, engine }
    }
}
/// Handles `POST /rag/index`: builds a new pipeline from the submitted
/// documents and swaps it in for the current one.
///
/// The embedder is fitted on the submitted documents only, and the previous
/// pipeline is replaced wholesale — indexing is not incremental.
///
/// Responses:
/// - `400` if `documents` is empty or any document fails to index.
/// - `500` if the pipeline lock is poisoned.
/// - `200` with an [`IndexDocumentResponse`] otherwise.
// NOTE(review): `chunk_size` / `chunk_overlap` are forwarded without validation
// here (e.g. zero size, overlap >= size); confirm the chunker rejects or
// tolerates such values.
pub async fn index_documents(
    State(state): State<Arc<RagState>>,
    Json(req): Json<IndexDocumentRequest>,
) -> impl IntoResponse {
    if req.documents.is_empty() {
        return error_response(StatusCode::BAD_REQUEST, "documents list must not be empty");
    }

    // Fit a fresh embedder on exactly the documents in this request.
    let corpus: Vec<&str> = req.documents.iter().map(|doc| doc.as_str()).collect();
    let embedder = TfIdfEmbedder::fit(&corpus, DEFAULT_MAX_FEATURES);

    // Start from the default chunking and apply any per-request overrides.
    let mut chunking = oxibonsai_rag::chunker::ChunkConfig::default();
    chunking.chunk_size = req.chunk_size.unwrap_or(chunking.chunk_size);
    chunking.overlap = req.chunk_overlap.unwrap_or(chunking.overlap);

    let mut staged = RagPipeline::new(
        embedder,
        RagConfig {
            chunk_config: chunking,
            ..Default::default()
        },
    );

    // Index into the staged pipeline first so a failure leaves the live one untouched.
    let mut ids: Vec<usize> = Vec::with_capacity(req.documents.len());
    let mut chunk_total = 0usize;
    for (position, text) in req.documents.iter().enumerate() {
        let added = match staged.index_document(text) {
            Ok(count) => count,
            Err(err) => {
                return error_response(
                    StatusCode::BAD_REQUEST,
                    format!("failed to index document {position}: {err}"),
                );
            }
        };
        ids.push(position);
        chunk_total += added;
    }

    // Publish the staged pipeline in one assignment under the lock.
    match state.pipeline.lock() {
        Ok(mut live) => *live = staged,
        Err(poisoned) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("pipeline lock poisoned: {poisoned}"),
            );
        }
    }

    let summary = IndexDocumentResponse {
        indexed: ids.len(),
        chunks: chunk_total,
        document_ids: ids,
    };
    (StatusCode::OK, Json(summary)).into_response()
}
pub async fn rag_query(
State(state): State<Arc<RagState>>,
Json(req): Json<RagQueryRequest>,
) -> impl IntoResponse {
if req.query.trim().is_empty() {
return error_response(StatusCode::BAD_REQUEST, "query must not be empty");
}
let max_tokens = req.max_tokens.unwrap_or(256);
let top_k = req.top_k.unwrap_or(3);
let include_context = req.include_context.unwrap_or(false);
let (prompt, retrieved_chunks, docs_searched, chunks_retrieved) = {
let pipeline_guard = match state.pipeline.lock() {
Ok(g) => g,
Err(e) => {
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
format!("pipeline lock poisoned: {e}"),
);
}
};
let stats = pipeline_guard.stats();
let docs_searched = stats.documents_indexed;
let retrieved_texts: Vec<String> = if stats.chunks_indexed == 0 {
Vec::new()
} else {
match pipeline_guard.retriever().retrieve_text(&req.query) {
Ok(texts) => texts.into_iter().take(top_k).collect(),
Err(oxibonsai_rag::RagError::NoDocumentsIndexed) => Vec::new(),
Err(e) => {
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
format!("retrieval failed: {e}"),
);
}
}
};
let chunks_retrieved = retrieved_texts.len();
let prompt = match pipeline_guard.build_prompt(&req.query) {
Ok(p) => p,
Err(e) => {
return error_response(
StatusCode::BAD_REQUEST,
format!("prompt build failed: {e}"),
);
}
};
(prompt, retrieved_texts, docs_searched, chunks_retrieved)
};
let prompt_tokens_count = rough_token_count(&prompt);
let input_tokens: Vec<u32> = vec![151644];
let output_tokens = {
let mut engine = state.engine.lock().await;
match engine.generate(&input_tokens, max_tokens) {
Ok(tokens) => tokens,
Err(e) => {
return error_response(
StatusCode::INTERNAL_SERVER_ERROR,
format!("generation failed: {e}"),
);
}
}
};
let completion_tokens = output_tokens.len();
let answer = output_tokens
.iter()
.map(|t| t.to_string())
.collect::<Vec<_>>()
.join(" ");
let resp = RagQueryResponse {
answer,
retrieved_chunks: if include_context {
Some(retrieved_chunks)
} else {
None
},
prompt_used: prompt,
usage: RagUsage {
documents_searched: docs_searched,
chunks_retrieved,
prompt_tokens: prompt_tokens_count,
completion_tokens,
},
};
(StatusCode::OK, Json(resp)).into_response()
}
/// Handles `GET /rag/stats`: reports the current pipeline's statistics.
///
/// Returns `500` if the pipeline lock is poisoned, otherwise `200` with a
/// [`RagStatsResponse`].
pub async fn rag_stats(State(state): State<Arc<RagState>>) -> impl IntoResponse {
    // Take a snapshot and release the lock before building the response.
    let snapshot = match state.pipeline.lock() {
        Err(poisoned) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("pipeline lock poisoned: {poisoned}"),
            );
        }
        Ok(pipeline) => pipeline.stats(),
    };

    let memory_bytes = snapshot.store_memory_bytes;
    let report = RagStatsResponse {
        documents_indexed: snapshot.documents_indexed,
        chunks_indexed: snapshot.chunks_indexed,
        embedding_dim: snapshot.embedding_dim,
        store_memory_bytes: memory_bytes,
        store_memory_human: human_bytes(memory_bytes),
    };
    (StatusCode::OK, Json(report)).into_response()
}
/// Handles `DELETE /rag/index`: replaces the current pipeline with an empty
/// one whose embedder is fitted on `BOOTSTRAP_CORPUS`.
///
/// Returns `500` if the pipeline lock is poisoned, otherwise `200` with
/// `{"status": "cleared"}`.
pub async fn clear_index(State(state): State<Arc<RagState>>) -> impl IntoResponse {
    // Build the replacement before taking the lock to keep the critical section short.
    let blank = RagPipeline::new(
        TfIdfEmbedder::fit(BOOTSTRAP_CORPUS, DEFAULT_MAX_FEATURES),
        RagConfig::default(),
    );

    let mut live = match state.pipeline.lock() {
        Ok(guard) => guard,
        Err(poisoned) => {
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                format!("pipeline lock poisoned: {poisoned}"),
            );
        }
    };
    *live = blank;
    drop(live);

    (StatusCode::OK, Json(serde_json::json!({ "status": "cleared" }))).into_response()
}
/// Builds the router exposing the RAG endpoints, taking ownership of the engine.
///
/// Routes:
/// - `POST /rag/index` → `index_documents`
/// - `DELETE /rag/index` → `clear_index`
/// - `POST /rag/query` → `rag_query`
/// - `GET /rag/stats` → `rag_stats`
pub fn create_rag_router(engine: InferenceEngine<'static>) -> Router {
    let shared = Arc::new(RagState::new(Arc::new(TokioMutex::new(engine))));

    // Both methods for `/rag/index` are attached to a single method router.
    let index_routes = axum::routing::post(index_documents).delete(clear_index);

    Router::new()
        .route("/rag/index", index_routes)
        .route("/rag/query", axum::routing::post(rag_query))
        .route("/rag/stats", axum::routing::get(rag_stats))
        .with_state(shared)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks one representative value per unit, including the exact unit boundaries.
    #[test]
    fn human_bytes_formatting() {
        let cases = [
            (0usize, "0 B"),
            (512, "512 B"),
            (1024, "1.00 KiB"),
            (1024 * 1024, "1.00 MiB"),
            (1024 * 1024 * 1024, "1.00 GiB"),
        ];
        for (input, expected) in cases {
            assert_eq!(human_bytes(input), expected);
        }
    }

    /// Checks empty input, plain words, and surrounding/repeated whitespace.
    #[test]
    fn rough_token_count_basic() {
        let cases = [("", 0usize), ("one two three", 3), (" spaces everywhere ", 2)];
        for (input, expected) in cases {
            assert_eq!(rough_token_count(input), expected);
        }
    }

    /// Smoke test: constructing the state around a tiny test engine must not panic.
    #[test]
    fn rag_state_creates_without_panic() {
        use crate::sampling::SamplingParams;
        use oxibonsai_core::config::Qwen3Config;

        let engine = InferenceEngine::new(Qwen3Config::tiny_test(), SamplingParams::default(), 42);
        let _state = RagState::new(Arc::new(TokioMutex::new(engine)));
    }
}