use std::collections::HashMap;
use std::sync::Arc;
use axum::extract::{Path, State};
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::Json;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use stolas::{
ContextItem, Document, Embedder, InMemoryStore, MockEmbedder, RagPipeline, RetrievalConfig,
VectorStore,
};
use crate::error_response::{api_error, ErrorCode};
/// Shared state behind the RAG HTTP endpoints.
///
/// Handlers access it through `Arc<RwLock<RagState>>`; the retrieval pipeline is
/// built lazily by [`RagState::initialize`].
pub struct RagState {
    /// Retrieval pipeline; `None` until [`RagState::initialize`] has run.
    pub pipeline: Option<RagPipeline>,
    /// Metadata of every tracked document, keyed by document id.
    pub documents: HashMap<String, DocumentMeta>,
    /// Set once [`RagState::initialize`] has built the pipeline.
    pub initialized: bool,
}
impl Default for RagState {
fn default() -> Self {
Self::new()
}
}
impl RagState {
    /// Creates an empty, uninitialized state with no pipeline and no documents.
    pub fn new() -> Self {
        let documents = HashMap::new();
        Self {
            initialized: false,
            pipeline: None,
            documents,
        }
    }

    /// Builds the retrieval pipeline from a 384-dimensional `MockEmbedder`, an
    /// `InMemoryStore` and the default `RetrievalConfig`.
    ///
    /// Calling this again after a successful initialization does nothing.
    pub fn initialize(&mut self) {
        if !self.initialized {
            let embedding_backend: Arc<dyn Embedder> = Arc::new(MockEmbedder::new(384));
            let vector_store: Arc<dyn VectorStore> = Arc::new(InMemoryStore::new());
            let retrieval_config = RetrievalConfig::default();
            self.pipeline = Some(RagPipeline::new(
                embedding_backend,
                vector_store,
                retrieval_config,
            ));
            self.initialized = true;
        }
    }

    /// Borrows the pipeline, or returns `None` if it has not been built yet.
    pub fn pipeline(&self) -> Option<&RagPipeline> {
        Option::as_ref(&self.pipeline)
    }
}
/// Bookkeeping record for one indexed document.
///
/// Keys are serialized in snake_case (no `rename_all` is applied).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DocumentMeta {
    /// Document identifier (`doc_<uuid>` when created by `index_document`).
    pub id: String,
    /// Client-supplied display name.
    pub name: String,
    /// Number of chunks reported by the pipeline when the document was ingested.
    pub chunk_count: usize,
    /// Indexing time in milliseconds since the Unix epoch.
    pub indexed_at: u64,
    /// Caller-supplied metadata; deserializes to an empty map when absent.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
/// JSON body returned by the RAG health endpoint (camelCase keys).
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct RagHealthResponse {
    /// Number of tracked documents; reported as 0 before initialization.
    pub document_count: usize,
    /// Total chunks across tracked documents; reported as 0 before initialization.
    pub chunk_count: usize,
    /// Name of the embedding model once the pipeline exists.
    pub embedding_model: Option<String>,
    /// Timestamp of the most recently indexed document, when one exists.
    pub last_updated: Option<String>,
    /// Whether the pipeline has been built.
    pub initialized: bool,
}
/// JSON body accepted by the document-indexing endpoint.
#[derive(Deserialize, Debug)]
pub struct IndexDocumentRequest {
    /// Display name for the document; the handler rejects an empty value.
    pub name: String,
    /// Text to ingest; the handler rejects an empty value.
    pub content: String,
    /// Extra metadata forwarded to the pipeline and stored on the
    /// resulting `DocumentMeta`; defaults to an empty map.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
/// JSON body returned when listing documents.
#[derive(Serialize, Debug)]
pub struct DocumentListResponse {
    /// Metadata of every tracked document.
    pub documents: Vec<DocumentMeta>,
}
/// JSON body returned by the document-count endpoint.
#[derive(Serialize, Debug)]
pub struct DocumentCountResponse {
    /// Number of tracked documents.
    pub count: usize,
}
/// JSON body for the search endpoint; keys are camelCase (`topK`, `minScore`).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct SearchRequest {
    /// Free-text query; the handler rejects an empty value.
    pub query: String,
    /// Upper bound on the number of returned results (defaults to 5).
    #[serde(default = "default_top_k")]
    pub top_k: usize,
    /// Optional minimum score a result must reach to be returned.
    #[serde(default)]
    pub min_score: Option<f32>,
    /// Rerank flag, `false` when omitted.
    /// NOTE(review): the `search` handler does not read this field.
    #[serde(default)]
    pub rerank: bool,
}
/// Serde default for `SearchRequest::top_k`: five results unless the client asks otherwise.
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 5;
    DEFAULT_TOP_K
}
/// One search hit as returned to clients (camelCase keys).
///
/// Only `Serialize` is derived, so no deserialization defaults are declared.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct SearchResultItem {
    /// Text of the matching chunk.
    pub content: String,
    /// Identifier of the document the chunk came from.
    pub source_id: String,
    /// Position of the chunk within its source document.
    pub chunk_index: usize,
    /// Score reported by the retrieval pipeline.
    pub score: f32,
    /// Metadata carried over from the retrieved item.
    pub metadata: HashMap<String, serde_json::Value>,
}
impl From<ContextItem> for SearchResultItem {
fn from(item: ContextItem) -> Self {
Self {
content: item.content,
source_id: item.source_id,
chunk_index: item.chunk_index,
score: item.score,
metadata: item.metadata,
}
}
}
/// JSON body returned by the search endpoint.
#[derive(Serialize, Debug)]
pub struct SearchResponse {
    /// Selected search hits.
    pub results: Vec<SearchResultItem>,
    /// Number of entries in `results`.
    pub total: usize,
}
/// JSON body returned after deleting a document.
#[derive(Serialize, Debug)]
pub struct DeleteResponse {
    /// Chunk count that was recorded for the removed document.
    pub deleted: usize,
}
/// Health handler reporting RAG index statistics. It never fails.
///
/// Reports:
/// - document and chunk totals (0 before initialization);
/// - the embedding model name (only once initialized);
/// - the most recent indexing time as an RFC 3339 string;
/// - the initialization flag.
pub async fn rag_health(State(rag): State<Arc<RwLock<RagState>>>) -> impl IntoResponse {
    let state = rag.read().await;
    let (document_count, chunk_count) = if state.initialized {
        (
            state.documents.len(),
            state.documents.values().map(|d| d.chunk_count).sum(),
        )
    } else {
        (0, 0)
    };
    // `indexed_at` holds Unix milliseconds. A value that does not fit in `i64`,
    // or that chrono cannot represent, yields `None` instead of a wrapped
    // timestamp or an empty string.
    let last_updated = state
        .documents
        .values()
        .map(|d| d.indexed_at)
        .max()
        .and_then(|ts| i64::try_from(ts).ok())
        .and_then(chrono::DateTime::from_timestamp_millis)
        .map(|dt| dt.to_rfc3339());
    Json(RagHealthResponse {
        document_count,
        chunk_count,
        embedding_model: state
            .initialized
            .then(|| "mock-embedder-384".to_string()),
        last_updated,
        initialized: state.initialized,
    })
}
/// Handler listing the metadata of every tracked document.
///
/// Responds `503 Service Unavailable` until the pipeline has been initialized.
/// Documents are ordered by indexing time (oldest first), with ties broken by
/// id, so repeated calls return the same order.
pub async fn list_documents(State(rag): State<Arc<RwLock<RagState>>>) -> impl IntoResponse {
    let mut documents: Vec<DocumentMeta> = {
        let state = rag.read().await;
        if !state.initialized {
            return (
                StatusCode::SERVICE_UNAVAILABLE,
                Json(api_error(
                    ErrorCode::ServiceUnavailable,
                    "RAG not initialized",
                )),
            )
                .into_response();
        }
        state.documents.values().cloned().collect()
    };
    // `HashMap` iteration order is unspecified and varies between processes,
    // so sort (after releasing the lock) to give clients a deterministic listing.
    documents.sort_by(|a, b| {
        a.indexed_at
            .cmp(&b.indexed_at)
            .then_with(|| a.id.cmp(&b.id))
    });
    Json(DocumentListResponse { documents }).into_response()
}
/// Handler returning how many documents are tracked.
///
/// Unlike `list_documents`, this does not check `initialized`. It reports the
/// current size of the `documents` map, which is empty for a fresh state.
pub async fn document_count(State(rag): State<Arc<RwLock<RagState>>>) -> impl IntoResponse {
    let count = rag.read().await.documents.len();
    Json(DocumentCountResponse { count })
}
/// Handler that ingests a new document into the RAG pipeline.
///
/// Builds the pipeline on first use and assigns a fresh `doc_<uuid>` id. It
/// forwards the caller's metadata plus a `name` entry to the pipeline, then
/// records a [`DocumentMeta`] for the new document.
///
/// # Responses
/// - `201 Created` with the new [`DocumentMeta`] on success.
/// - `400 Bad Request` if `name` or `content` is empty or whitespace-only.
/// - `503 Service Unavailable` if no pipeline exists after initialization.
/// - `500 Internal Server Error` if ingestion fails.
pub async fn index_document(
    State(rag): State<Arc<RwLock<RagState>>>,
    Json(req): Json<IndexDocumentRequest>,
) -> impl IntoResponse {
    // Whitespace-only values carry no usable content; treat them as missing.
    if req.name.trim().is_empty() {
        return (
            StatusCode::BAD_REQUEST,
            Json(api_error(
                ErrorCode::InvalidRequest,
                "Document name is required",
            )),
        )
            .into_response();
    }
    if req.content.trim().is_empty() {
        return (
            StatusCode::BAD_REQUEST,
            Json(api_error(
                ErrorCode::InvalidRequest,
                "Document content is required",
            )),
        )
            .into_response();
    }

    // Exclusive access is only needed to (lazily) build the pipeline. The guard
    // is then downgraded to a shared one, so the potentially slow ingest step
    // does not block concurrent health/list/search requests. `ingest` is only
    // given a `&RagPipeline`.
    let mut writer = rag.write().await;
    writer.initialize(); // no-op when already initialized
    let state = writer.downgrade();
    let Some(pipeline) = state.pipeline() else {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(api_error(
                ErrorCode::ServiceUnavailable,
                "RAG pipeline not available",
            )),
        )
            .into_response();
    };

    let doc_id = format!("doc_{}", uuid::Uuid::new_v4().simple());
    let mut doc = Document::new(&doc_id, &req.content);
    for (key, value) in &req.metadata {
        doc = doc.with_metadata(key, value.clone());
    }
    // Added after the caller's metadata. Assuming `with_metadata` overwrites an
    // existing key (TODO confirm), the request's `name` wins over a `name` entry.
    doc = doc.with_metadata("name", serde_json::json!(req.name));

    let chunk_count = match pipeline.ingest(doc).await {
        Ok(count) => count,
        Err(e) => {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(api_error(
                    ErrorCode::InternalError,
                    &format!("Failed to index document: {}", e),
                )),
            )
                .into_response();
        },
    };
    // Release the shared guard before re-acquiring exclusively to record metadata.
    drop(state);

    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_millis() as u64)
        .unwrap_or(0);
    let meta = DocumentMeta {
        id: doc_id.clone(),
        name: req.name,
        chunk_count,
        indexed_at: now,
        metadata: req.metadata,
    };
    rag.write().await.documents.insert(doc_id, meta.clone());
    (StatusCode::CREATED, Json(meta)).into_response()
}
/// Handler that stops tracking a document.
///
/// Responds `503` before initialization and `404` for an unknown id. Otherwise
/// it returns the chunk count that was recorded for the removed document.
///
/// NOTE(review): only the `documents` map is updated here. Nothing in this
/// handler removes the document's chunks from the pipeline's vector store, so
/// they may still surface in `search`. Confirm whether `RagPipeline` exposes a
/// delete API, and call it here if so.
pub async fn delete_document(
    State(rag): State<Arc<RwLock<RagState>>>,
    Path(doc_id): Path<String>,
) -> impl IntoResponse {
    let mut guard = rag.write().await;
    if !guard.initialized {
        return (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(api_error(
                ErrorCode::ServiceUnavailable,
                "RAG not initialized",
            )),
        )
            .into_response();
    }
    let removed = guard.documents.remove(&doc_id);
    match removed {
        Some(entry) => Json(DeleteResponse {
            deleted: entry.chunk_count,
        })
        .into_response(),
        None => (
            StatusCode::NOT_FOUND,
            Json(api_error(ErrorCode::NotFound, "Document not found")),
        )
            .into_response(),
    }
}
pub async fn search(
State(rag): State<Arc<RwLock<RagState>>>,
Json(req): Json<SearchRequest>,
) -> impl IntoResponse {
if req.query.is_empty() {
return (
StatusCode::BAD_REQUEST,
Json(api_error(ErrorCode::InvalidRequest, "Query is required")),
)
.into_response();
}
let state = rag.read().await;
if !state.initialized {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(api_error(
ErrorCode::ServiceUnavailable,
"RAG not initialized",
)),
)
.into_response();
}
let pipeline = match state.pipeline.as_ref() {
Some(p) => p,
None => {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(api_error(
ErrorCode::ServiceUnavailable,
"RAG pipeline not available",
)),
)
.into_response();
},
};
let results = match pipeline.retrieve(&req.query).await {
Ok(items) => items,
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(api_error(
ErrorCode::InternalError,
&format!("Search failed: {}", e),
)),
)
.into_response();
},
};
let results: Vec<SearchResultItem> = results
.into_iter()
.take(req.top_k)
.filter(|r| req.min_score.map(|m| r.score >= m).unwrap_or(true))
.map(SearchResultItem::from)
.collect();
let total = results.len();
Json(SearchResponse { results, total }).into_response()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rag_state_new() {
        let state = RagState::new();
        assert!(!state.initialized);
        assert!(state.pipeline.is_none());
        assert!(state.documents.is_empty());
    }

    #[test]
    fn test_rag_state_default_matches_new() {
        let state = RagState::default();
        assert!(!state.initialized);
        assert!(state.pipeline.is_none());
        assert!(state.documents.is_empty());
    }

    #[test]
    fn test_rag_state_initialize() {
        let mut state = RagState::new();
        state.initialize();
        assert!(state.initialized);
        assert!(state.pipeline.is_some());
        assert!(state.pipeline().is_some());
    }

    #[test]
    fn test_rag_state_double_initialize() {
        let mut state = RagState::new();
        state.initialize();
        // A second call must be a harmless no-op.
        state.initialize();
        assert!(state.initialized);
        assert!(state.pipeline.is_some());
        assert!(state.documents.is_empty());
    }

    #[test]
    fn test_document_meta_serialization() {
        let meta = DocumentMeta {
            id: "doc_123".to_string(),
            name: "test.txt".to_string(),
            chunk_count: 5,
            indexed_at: 1234567890000,
            metadata: HashMap::new(),
        };
        let json = serde_json::to_string(&meta).unwrap();
        assert!(json.contains("doc_123"));
        assert!(json.contains("test.txt"));
    }

    #[test]
    fn test_document_meta_metadata_defaults_to_empty() {
        let json = r#"{"id":"a","name":"n","chunk_count":1,"indexed_at":2}"#;
        let meta: DocumentMeta = serde_json::from_str(json).unwrap();
        assert_eq!(meta.id, "a");
        assert_eq!(meta.chunk_count, 1);
        assert_eq!(meta.indexed_at, 2);
        assert!(meta.metadata.is_empty());
    }

    #[test]
    fn test_search_request_defaults() {
        let json = r#"{"query": "test"}"#;
        let req: SearchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.query, "test");
        assert_eq!(req.top_k, 5);
        assert!(req.min_score.is_none());
        assert!(!req.rerank);
    }

    #[test]
    fn test_search_request_camel_case_keys() {
        let json = r#"{"query":"q","topK":3,"minScore":0.5,"rerank":true}"#;
        let req: SearchRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.top_k, 3);
        assert_eq!(req.min_score, Some(0.5));
        assert!(req.rerank);
    }

    #[test]
    fn test_context_item_to_search_result() {
        let item = ContextItem {
            content: "test content".to_string(),
            source_id: "doc_1".to_string(),
            chunk_index: 2,
            score: 0.85,
            metadata: HashMap::new(),
        };
        let result: SearchResultItem = item.into();
        assert_eq!(result.content, "test content");
        assert_eq!(result.source_id, "doc_1");
        assert_eq!(result.chunk_index, 2);
        assert!((result.score - 0.85).abs() < 0.001);
    }

    #[test]
    fn test_search_result_item_serializes_camel_case() {
        let item = SearchResultItem {
            content: "c".to_string(),
            source_id: "doc_1".to_string(),
            chunk_index: 0,
            score: 1.0,
            metadata: HashMap::new(),
        };
        let json = serde_json::to_value(&item).unwrap();
        assert_eq!(json["sourceId"], serde_json::json!("doc_1"));
        assert_eq!(json["chunkIndex"], serde_json::json!(0));
    }
}