use crate::{
thinktool::{LlmClient, LlmRequest, UnifiedLlmClient},
Document, Error, Result, RetrievalConfig,
};
pub mod chunking;
pub mod hyde;
pub mod performance;
use crate::rag::hyde::HyDEExpander;
#[cfg(feature = "memory")]
use reasonkit_mem::{
indexing::IndexManager,
retrieval::{HybridResult, HybridRetriever, RetrievalStats},
storage::Storage,
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use uuid::Uuid;
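/// Configuration for the RAG pipeline: retrieval breadth, score filtering,
/// context token budget, and prompt/expansion behavior.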
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagConfig {
pub top_k: usize,
pub min_score: f32,
pub max_context_tokens: usize,
pub include_sources: bool,
pub system_prompt: String,
pub sparse_only: bool,
pub hybrid_alpha: f32,
pub hyde_enabled: bool,
}
impl Default for RagConfig {
fn default() -> Self {
Self {
top_k: 5,
min_score: 0.1,
max_context_tokens: 2000,
include_sources: true,
system_prompt: DEFAULT_RAG_PROMPT.to_string(),
            sparse_only: true,
            hybrid_alpha: 0.3,
hyde_enabled: false,
}
}
}
impl RagConfig {
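    /// Preset for fast, low-cost queries: fewer chunks, a tighter score
    /// cutoff, sparse-only retrieval, and no source attributions.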
pub fn quick() -> Self {
Self {
top_k: 3,
min_score: 0.2,
max_context_tokens: 1000,
include_sources: false,
sparse_only: true,
..Default::default()
}
}
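    /// Preset for exhaustive queries: more chunks, a looser score cutoff,
    /// hybrid retrieval with balanced weighting, and HyDE expansion enabled.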
pub fn thorough() -> Self {
Self {
top_k: 10,
min_score: 0.05,
max_context_tokens: 4000,
include_sources: true,
sparse_only: false,
hybrid_alpha: 0.5,
hyde_enabled: true,
..Default::default()
}
}
}
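/// Default system prompt; `{context}` is replaced with the assembled chunk
/// context at query time.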
const DEFAULT_RAG_PROMPT: &str = r#"You are a helpful assistant answering questions based on the provided context.
INSTRUCTIONS:
1. Answer the question using ONLY the provided context
2. If the context doesn't contain the answer, say "I don't have enough information to answer this"
3. Be concise but comprehensive
4. When citing information, reference the source section
CONTEXT:
{context}
Answer the question based on the context above. Be accurate and cite your sources."#;
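/// Result of a RAG query: the generated (or fallback) answer plus provenance
/// and retrieval metrics.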
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagResponse {
pub answer: String,
pub sources: Vec<RagSource>,
pub retrieval_stats: RagRetrievalStats,
pub tokens_used: Option<u32>,
pub query: String,
}
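/// A retrieved chunk cited in the answer, with its relevance score and a
/// truncated text preview.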
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RagSource {
pub doc_id: Uuid,
pub chunk_id: Uuid,
pub text: String,
pub score: f32,
pub section: Option<String>,
}
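/// Metrics describing a single retrieval pass.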
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RagRetrievalStats {
pub chunks_retrieved: usize,
pub chunks_used: usize,
pub context_tokens: usize,
pub retrieval_time_ms: u64,
}
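/// End-to-end RAG engine: retrieves relevant chunks and, when an LLM client
/// is attached, generates a grounded answer from them.
///
/// Usage sketch (in-memory, no LLM attached; `doc` stands for any prepared
/// [`Document`]):
///
/// ```ignore
/// let engine = RagEngine::in_memory()?.with_config(RagConfig::quick());
/// engine.add_document(&doc).await?;
/// let answer = engine.query("what does the document say?").await?.answer;
/// ```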
pub struct RagEngine {
retriever: HybridRetriever,
llm_client: Option<UnifiedLlmClient>,
config: RagConfig,
}
impl RagEngine {
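    /// Creates an engine backed by in-memory storage; nothing is persisted.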
pub fn in_memory() -> Result<Self> {
Ok(Self {
retriever: HybridRetriever::in_memory()?,
llm_client: None,
config: RagConfig::default(),
})
}
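    /// Creates an engine with file-backed storage and index under `base_path`.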
pub async fn persistent(base_path: PathBuf) -> Result<Self> {
let storage_path = base_path.join("storage");
let index_path = base_path.join("index");
std::fs::create_dir_all(&storage_path)
.map_err(|e| Error::io(format!("Failed to create storage dir: {}", e)))?;
std::fs::create_dir_all(&index_path)
.map_err(|e| Error::io(format!("Failed to create index dir: {}", e)))?;
let storage = Storage::file(storage_path).await?;
let index = IndexManager::open(index_path)?;
Ok(Self {
retriever: HybridRetriever::new(storage, index),
llm_client: None,
config: RagConfig::default(),
})
}
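    /// Attaches an LLM client so `query` generates answers instead of
    /// returning a raw retrieval summary.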
pub fn with_llm(mut self, client: UnifiedLlmClient) -> Self {
self.llm_client = Some(client);
self
}
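    /// Replaces the default configuration.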
pub fn with_config(mut self, config: RagConfig) -> Self {
self.config = config;
self
}
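    /// Indexes a single document for retrieval.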
pub async fn add_document(&self, doc: &Document) -> Result<()> {
let mem_doc: reasonkit_mem::Document = doc.clone().into();
self.retriever.add_document(&mem_doc).await?;
Ok(())
}
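    /// Indexes a batch of documents, returning how many were added.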
pub async fn add_documents(&self, docs: &[Document]) -> Result<usize> {
let mut count = 0;
for doc in docs {
let mem_doc: reasonkit_mem::Document = doc.clone().into();
self.retriever.add_document(&mem_doc).await?;
count += 1;
}
Ok(count)
}
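    /// Runs the full RAG pipeline: optional HyDE expansion, retrieval, score
    /// filtering, context assembly, and (if an LLM client is set) generation.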
pub async fn query(&self, query: &str) -> Result<RagResponse> {
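        // Optionally rewrite the query via HyDE; fall back to the raw query
        // when no LLM client is available to expand it.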
let effective_query = if self.config.hyde_enabled {
if let Some(ref client) = self.llm_client {
let expander = HyDEExpander::new(client.clone());
expander.expand_query(query).await?
} else {
query.to_string()
}
} else {
query.to_string()
};
let start = std::time::Instant::now();
let results = if self.config.sparse_only {
self.retriever
.search_sparse(&effective_query, self.config.top_k)
.await?
} else {
let retrieval_config = RetrievalConfig {
top_k: self.config.top_k,
alpha: self.config.hybrid_alpha,
..Default::default()
};
self.retriever
.search_hybrid(&effective_query, None, &retrieval_config)
.await?
};
let retrieval_time_ms = start.elapsed().as_millis() as u64;
        // Record how many chunks came back before score filtering, so the
        // stats reflect actual retrieval rather than the configured top_k.
        let chunks_retrieved = results.len();
        let filtered_results: Vec<_> = results
.into_iter()
.filter(|r| r.score >= self.config.min_score)
.collect();
let (context, context_tokens) = self.build_context(&filtered_results);
let sources: Vec<RagSource> = filtered_results
.iter()
.map(|r| RagSource {
doc_id: r.doc_id,
chunk_id: r.chunk_id,
text: truncate_text(&r.text, 200),
score: r.score,
                section: None,
            })
.collect();
let retrieval_stats = RagRetrievalStats {
            chunks_retrieved,
chunks_used: filtered_results.len(),
context_tokens,
retrieval_time_ms,
};
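        // Generate with the LLM when available; otherwise return a plain
        // retrieval summary so the engine is still usable without a client.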
let (answer, tokens_used) = if let Some(ref client) = self.llm_client {
let system_prompt = self.config.system_prompt.replace("{context}", &context);
let request = LlmRequest::new(query)
.with_system(&system_prompt)
.with_max_tokens(1000);
let response = client
.complete(request)
.await
.map_err(|e| Error::network(format!("LLM generation failed: {}", e)))?;
let tokens = Some(response.usage.total_tokens);
(response.content, tokens)
} else {
let answer = format!(
"Retrieved {} relevant chunks for query: \"{}\"\n\nTop results:\n{}",
filtered_results.len(),
query,
filtered_results
.iter()
.take(3)
.enumerate()
.map(|(i, r)| format!(
"{}. [score: {:.3}] {}",
i + 1,
r.score,
truncate_text(&r.text, 150)
))
.collect::<Vec<_>>()
.join("\n")
);
(answer, None)
};
Ok(RagResponse {
answer,
sources,
retrieval_stats,
tokens_used,
query: query.to_string(),
})
}
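    /// Sparse-only (BM25) retrieval without context assembly or generation.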
pub async fn retrieve(&self, query: &str, top_k: usize) -> Result<Vec<HybridResult>> {
self.retriever
.search_sparse(query, top_k)
.await
.map_err(Error::from)
}
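    /// Returns document and chunk counts from the underlying retriever.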
pub async fn stats(&self) -> Result<RetrievalStats> {
self.retriever.stats().await.map_err(Error::from)
}
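    /// Removes a document (and its indexed chunks) from the engine.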
pub async fn delete_document(&self, doc_id: &Uuid) -> Result<()> {
self.retriever
.delete_document(doc_id)
.await
.map_err(Error::from)
}
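    /// Assembles retrieved chunks into a numbered context block, stopping
    /// once the configured token budget would be exceeded. Returns the
    /// context string and its estimated token count.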
fn build_context(&self, results: &[HybridResult]) -> (String, usize) {
let mut context_parts = Vec::new();
let mut total_tokens = 0;
for (i, result) in results.iter().enumerate() {
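            // Rough estimate: ~4 characters per token for English text.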
let chunk_tokens = result.text.len() / 4;
if total_tokens + chunk_tokens > self.config.max_context_tokens {
break;
}
context_parts.push(format!(
"[Source {}] (relevance: {:.2})\n{}",
i + 1,
result.score,
result.text
));
total_tokens += chunk_tokens;
}
(context_parts.join("\n\n---\n\n"), total_tokens)
}
}
/// Truncates `text` to at most `max_len` bytes, appending "..." when cut.
/// Steps back to a char boundary so multi-byte UTF-8 input cannot panic.
fn truncate_text(text: &str, max_len: usize) -> String {
    if text.len() <= max_len {
        text.to_string()
    } else {
        let mut cut = max_len.saturating_sub(3);
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...", &text[..cut])
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{Chunk, DocumentType, EmbeddingIds, Source, SourceType};
use chrono::Utc;
fn create_test_document(content: &str, title: &str) -> Document {
let source = Source {
source_type: SourceType::Local,
url: None,
path: Some(format!("/test/{}.md", title)),
arxiv_id: None,
github_repo: None,
retrieved_at: Utc::now(),
version: None,
};
let mut doc = Document::new(DocumentType::Note, source).with_content(content.to_string());
doc.chunks = vec![Chunk {
id: Uuid::new_v4(),
text: content.to_string(),
index: 0,
start_char: 0,
end_char: content.len(),
token_count: Some(content.len() / 4),
section: Some(title.to_string()),
page: None,
embedding_ids: EmbeddingIds::default(),
}];
doc
}
fn create_multi_chunk_document(chunks: &[&str], title: &str) -> Document {
let source = Source {
source_type: SourceType::Local,
url: None,
path: Some(format!("/test/{}.md", title)),
arxiv_id: None,
github_repo: None,
retrieved_at: Utc::now(),
version: None,
};
let full_content = chunks.join("\n\n");
let mut doc = Document::new(DocumentType::Note, source).with_content(full_content.clone());
let mut char_offset = 0;
doc.chunks = chunks
.iter()
.enumerate()
.map(|(i, text)| {
let chunk = Chunk {
id: Uuid::new_v4(),
text: text.to_string(),
index: i,
start_char: char_offset,
end_char: char_offset + text.len(),
token_count: Some(text.len() / 4),
section: Some(format!("Section {}", i + 1)),
page: Some(i / 2 + 1),
embedding_ids: EmbeddingIds::default(),
};
                char_offset += text.len() + 2; // skip the "\n\n" join separator
                chunk
})
.collect();
doc
}
#[test]
fn test_rag_config_default() {
let config = RagConfig::default();
assert_eq!(config.top_k, 5);
assert_eq!(config.min_score, 0.1);
assert_eq!(config.max_context_tokens, 2000);
assert!(config.include_sources);
assert!(config.sparse_only);
assert_eq!(config.hybrid_alpha, 0.3);
assert!(config.system_prompt.contains("CONTEXT"));
}
#[test]
fn test_rag_config_quick() {
let config = RagConfig::quick();
assert_eq!(config.top_k, 3);
assert_eq!(config.min_score, 0.2);
assert_eq!(config.max_context_tokens, 1000);
assert!(!config.include_sources);
assert!(config.sparse_only);
}
#[test]
fn test_rag_config_thorough() {
let config = RagConfig::thorough();
assert_eq!(config.top_k, 10);
assert_eq!(config.min_score, 0.05);
assert_eq!(config.max_context_tokens, 4000);
assert!(config.include_sources);
assert!(!config.sparse_only);
assert_eq!(config.hybrid_alpha, 0.5);
}
#[test]
fn test_rag_config_serialization() {
let config = RagConfig::default();
let json = serde_json::to_string(&config).expect("Serialization failed");
let deserialized: RagConfig = serde_json::from_str(&json).expect("Deserialization failed");
assert_eq!(config.top_k, deserialized.top_k);
assert_eq!(config.min_score, deserialized.min_score);
assert_eq!(config.max_context_tokens, deserialized.max_context_tokens);
}
#[test]
fn test_truncate_text_short() {
let text = "Short text";
let result = truncate_text(text, 50);
assert_eq!(result, "Short text");
}
#[test]
fn test_truncate_text_exact_length() {
let text = "Exactly ten";
let result = truncate_text(text, 11);
assert_eq!(result, "Exactly ten");
}
#[test]
fn test_truncate_text_long() {
let text = "This is a very long text that needs to be truncated";
let result = truncate_text(text, 20);
assert_eq!(result.len(), 20);
assert!(result.ends_with("..."));
assert_eq!(result, "This is a very lo...");
}
#[test]
fn test_truncate_text_empty() {
let text = "";
let result = truncate_text(text, 10);
assert_eq!(result, "");
}
#[test]
fn test_truncate_text_zero_max() {
let text = "Some text";
let result = truncate_text(text, 0);
assert_eq!(result, "...");
}
#[test]
fn test_truncate_text_very_small_max() {
let text = "Hello world";
let result = truncate_text(text, 3);
assert_eq!(result, "...");
}
#[test]
fn test_single_chunk_document() {
let doc = create_test_document("Simple content", "simple");
assert_eq!(doc.chunks.len(), 1);
assert_eq!(doc.chunks[0].text, "Simple content");
assert_eq!(doc.chunks[0].index, 0);
assert_eq!(doc.chunks[0].start_char, 0);
assert_eq!(doc.chunks[0].end_char, 14);
}
#[test]
fn test_multi_chunk_document() {
let chunks = [
"First paragraph about machine learning.",
"Second paragraph about neural networks.",
"Third paragraph about deep learning.",
];
let doc = create_multi_chunk_document(&chunks, "ml-doc");
assert_eq!(doc.chunks.len(), 3);
for (i, chunk) in doc.chunks.iter().enumerate() {
assert_eq!(chunk.index, i);
assert!(chunk
.section
.as_ref()
.unwrap()
.contains(&format!("{}", i + 1)));
}
let mut prev_end = 0;
for chunk in &doc.chunks {
assert!(chunk.start_char >= prev_end);
assert!(chunk.end_char > chunk.start_char);
prev_end = chunk.end_char;
}
}
#[test]
fn test_chunk_token_count_estimation() {
        let content = "Sample content for token count estimation.";
        let doc = create_test_document(content, "token-test");
let expected_tokens = content.len() / 4;
assert_eq!(doc.chunks[0].token_count, Some(expected_tokens));
}
#[test]
fn test_chunk_page_assignment() {
let chunks = [
"Page 1 content A",
"Page 1 content B",
"Page 2 content A",
"Page 2 content B",
];
let doc = create_multi_chunk_document(&chunks, "paged-doc");
assert_eq!(doc.chunks[0].page, Some(1));
assert_eq!(doc.chunks[1].page, Some(1));
assert_eq!(doc.chunks[2].page, Some(2));
assert_eq!(doc.chunks[3].page, Some(2));
}
#[tokio::test]
async fn test_bm25_search_basic() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document(
"Machine learning algorithms process data to make predictions.",
"ml-basics",
);
engine.add_document(&doc).await.expect("Failed to add doc");
let results = engine
.retrieve("machine learning predictions", 5)
.await
.expect("Retrieval failed");
assert!(!results.is_empty());
assert!(results[0].text.contains("Machine learning"));
}
#[tokio::test]
async fn test_bm25_search_multiple_documents() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let docs = vec![
create_test_document(
"Python is a popular programming language for data science.",
"python",
),
create_test_document(
"Rust provides memory safety without garbage collection.",
"rust",
),
create_test_document("JavaScript runs in web browsers and Node.js.", "javascript"),
];
for doc in &docs {
engine.add_document(doc).await.expect("Failed to add doc");
}
let results = engine
.retrieve("memory safety rust", 5)
.await
.expect("Retrieval failed");
assert!(!results.is_empty());
assert!(results[0].text.contains("Rust"));
}
#[tokio::test]
async fn test_bm25_search_no_match() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Cats and dogs are common pets.", "pets");
engine.add_document(&doc).await.expect("Failed to add doc");
let results = engine
.retrieve("quantum physics relativity", 5)
.await
.expect("Retrieval failed");
if !results.is_empty() {
assert!(results[0].score < 5.0);
}
}
#[tokio::test]
async fn test_bm25_search_ranking() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc1 = create_test_document(
"Neural networks and deep learning are subsets of machine learning.",
"high-relevance",
);
let doc2 = create_test_document(
"The weather today is sunny with clear skies.",
"low-relevance",
);
let doc3 = create_test_document(
"Machine learning uses algorithms to learn from data.",
"medium-relevance",
);
engine.add_document(&doc1).await.unwrap();
engine.add_document(&doc2).await.unwrap();
engine.add_document(&doc3).await.unwrap();
let results = engine
.retrieve("machine learning neural networks", 5)
.await
.expect("Retrieval failed");
assert!(results.len() >= 2);
assert!(results[0].score >= results[results.len() - 1].score);
}
#[tokio::test]
async fn test_rag_engine_basic() {
let engine = RagEngine::in_memory().expect("Failed to create in-memory engine");
let doc1 = create_test_document(
"Chain-of-thought prompting enables complex reasoning by breaking problems into steps.",
"cot-basics",
);
let doc2 = create_test_document(
"Self-consistency improves reasoning by sampling multiple paths and selecting the most common answer.",
"self-consistency",
);
engine
.add_document(&doc1)
.await
.expect("Failed to add doc1");
engine
.add_document(&doc2)
.await
.expect("Failed to add doc2");
let response = engine
.query("How does chain of thought work?")
.await
.expect("Query failed");
assert!(!response.sources.is_empty());
assert!(response.answer.contains("Retrieved"));
assert!(response.retrieval_stats.chunks_used > 0);
}
#[tokio::test]
async fn test_rag_engine_with_custom_config() {
let config = RagConfig {
top_k: 2,
min_score: 0.0,
max_context_tokens: 500,
include_sources: true,
sparse_only: true,
..Default::default()
};
let engine = RagEngine::in_memory()
.expect("Failed to create engine")
.with_config(config);
let doc = create_test_document("Test content for RAG engine.", "test");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("test content").await.expect("Query failed");
assert!(response.retrieval_stats.chunks_retrieved <= 2);
}
#[tokio::test]
async fn test_rag_engine_add_multiple_documents() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let docs = vec![
create_test_document("Document one content.", "doc1"),
create_test_document("Document two content.", "doc2"),
create_test_document("Document three content.", "doc3"),
];
let count = engine
.add_documents(&docs)
.await
.expect("Failed to add docs");
assert_eq!(count, 3);
let stats = engine.stats().await.expect("Failed to get stats");
assert_eq!(stats.document_count, 3);
}
#[tokio::test]
async fn test_rag_engine_stats() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let stats = engine.stats().await.expect("Failed to get stats");
assert_eq!(stats.document_count, 0);
let doc = create_multi_chunk_document(&["Chunk 1", "Chunk 2", "Chunk 3"], "multi-chunk");
engine.add_document(&doc).await.expect("Failed to add doc");
let stats = engine.stats().await.expect("Failed to get stats");
assert_eq!(stats.document_count, 1);
assert_eq!(stats.chunk_count, 3);
}
#[tokio::test]
async fn test_rag_engine_delete_document() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Content to delete.", "delete-me");
let doc_id = doc.id;
engine.add_document(&doc).await.expect("Failed to add doc");
let stats = engine.stats().await.expect("Failed to get stats");
assert_eq!(stats.document_count, 1);
engine
.delete_document(&doc_id)
.await
.expect("Failed to delete doc");
let stats = engine.stats().await.expect("Failed to get stats");
assert_eq!(stats.document_count, 0);
}
#[tokio::test]
async fn test_min_score_filtering() {
let config = RagConfig {
            min_score: 5.0,
            ..Default::default()
};
let engine = RagEngine::in_memory()
.expect("Failed to create engine")
.with_config(config);
let doc = create_test_document("Some content about cats and dogs.", "pets");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine
.query("quantum computing algorithms")
.await
.expect("Query failed");
for source in &response.sources {
assert!(source.score >= 5.0);
}
}
#[tokio::test]
async fn test_min_score_zero() {
let config = RagConfig {
min_score: 0.0,
..Default::default()
};
let engine = RagEngine::in_memory()
.expect("Failed to create engine")
.with_config(config);
let doc = create_test_document("Any content here.", "test");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("test query").await.expect("Query failed");
        // With min_score == 0.0, no retrieved chunk is filtered out.
        assert_eq!(
            response.retrieval_stats.chunks_used,
            response.retrieval_stats.chunks_retrieved
        );
}
#[tokio::test]
async fn test_context_token_limit() {
let config = RagConfig {
            max_context_tokens: 10,
            min_score: 0.0,
..Default::default()
};
let engine = RagEngine::in_memory()
.expect("Failed to create engine")
.with_config(config);
let doc = create_test_document(
"This is a very long document that contains many words and should exceed the token limit when assembled into context.",
"long-doc",
);
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("document").await.expect("Query failed");
assert!(response.retrieval_stats.context_tokens <= 10);
}
#[tokio::test]
async fn test_context_assembly_format() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Test content for context assembly.", "test");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("test content").await.expect("Query failed");
assert!(response.answer.contains("Retrieved"));
assert!(response.answer.contains("score:"));
}
#[tokio::test]
async fn test_empty_query() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Some content here.", "test");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("").await.expect("Query failed");
        // An empty query should still produce a well-formed response.
        assert_eq!(response.query, "");
    }
#[tokio::test]
async fn test_query_with_special_characters() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("C++ and C# are programming languages.", "langs");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("C++ programming").await.expect("Query failed");
assert!(!response.answer.is_empty());
}
#[tokio::test]
async fn test_query_with_unicode() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document(
"Machine learning is used in Tokyo for traffic optimization.",
"japan",
);
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("Tokyo traffic").await.expect("Query failed");
assert!(!response.answer.is_empty());
}
#[tokio::test]
async fn test_very_long_query() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Short content.", "short");
engine.add_document(&doc).await.expect("Failed to add doc");
let long_query = "word ".repeat(1000);
let response = engine.query(&long_query).await.expect("Query failed");
assert!(!response.answer.is_empty());
}
#[tokio::test]
async fn test_no_documents_query() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let response = engine.query("any query").await.expect("Query failed");
assert_eq!(response.sources.len(), 0);
assert_eq!(response.retrieval_stats.chunks_used, 0);
}
#[tokio::test]
async fn test_rag_response_structure() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Complete content for response test.", "response");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("response test").await.expect("Query failed");
assert!(!response.answer.is_empty());
assert_eq!(response.query, "response test");
        // No LLM client is attached, so no token usage should be reported.
        assert!(response.tokens_used.is_none());
}
#[tokio::test]
async fn test_rag_source_structure() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Source structure test content.", "source");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine
.query("source structure")
.await
.expect("Query failed");
for source in &response.sources {
assert!(!source.chunk_id.is_nil());
assert!(!source.text.is_empty());
assert!(source.score >= 0.0);
            // `truncate_text` caps source previews at 200 bytes, ellipsis included.
            assert!(source.text.len() <= 200);
        }
}
#[tokio::test]
async fn test_rag_stats_structure() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Stats test content.", "stats");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("stats").await.expect("Query failed");
let stats = &response.retrieval_stats;
assert!(stats.chunks_retrieved > 0 || stats.chunks_used == 0);
        assert!(stats.retrieval_time_ms < 10_000);
    }
#[tokio::test]
async fn test_rag_response_serialization() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document("Serialization test.", "serial");
engine.add_document(&doc).await.expect("Failed to add doc");
let response = engine.query("serialization").await.expect("Query failed");
let json = serde_json::to_string(&response).expect("Serialization failed");
assert!(json.contains("answer"));
assert!(json.contains("sources"));
assert!(json.contains("retrieval_stats"));
let deserialized: RagResponse =
serde_json::from_str(&json).expect("Deserialization failed");
assert_eq!(response.query, deserialized.query);
}
#[test]
fn test_rag_source_serialization() {
let source = RagSource {
doc_id: Uuid::new_v4(),
chunk_id: Uuid::new_v4(),
text: "Test text".to_string(),
score: 0.95,
section: Some("Introduction".to_string()),
};
let json = serde_json::to_string(&source).expect("Serialization failed");
let deserialized: RagSource = serde_json::from_str(&json).expect("Deserialization failed");
assert_eq!(source.text, deserialized.text);
assert_eq!(source.score, deserialized.score);
assert_eq!(source.section, deserialized.section);
}
#[test]
fn test_rag_retrieval_stats_serialization() {
let stats = RagRetrievalStats {
chunks_retrieved: 5,
chunks_used: 3,
context_tokens: 150,
retrieval_time_ms: 42,
};
let json = serde_json::to_string(&stats).expect("Serialization failed");
let deserialized: RagRetrievalStats =
serde_json::from_str(&json).expect("Deserialization failed");
assert_eq!(stats.chunks_retrieved, deserialized.chunks_retrieved);
assert_eq!(stats.chunks_used, deserialized.chunks_used);
assert_eq!(stats.context_tokens, deserialized.context_tokens);
assert_eq!(stats.retrieval_time_ms, deserialized.retrieval_time_ms);
}
#[tokio::test]
async fn test_concurrent_queries() {
let engine = std::sync::Arc::new(RagEngine::in_memory().expect("Failed to create engine"));
let doc = create_test_document("Concurrent access test document.", "concurrent");
engine.add_document(&doc).await.expect("Failed to add doc");
let mut handles = vec![];
for i in 0..5 {
let engine_clone = engine.clone();
let handle = tokio::spawn(async move {
let query = format!("query {}", i);
engine_clone.query(&query).await
});
handles.push(handle);
}
for handle in handles {
let result = handle.await.expect("Task panicked");
assert!(result.is_ok());
}
}
#[tokio::test]
async fn test_retrieve_only() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
let doc = create_test_document(
"Vector databases store embeddings for semantic search.",
"vector-db",
);
engine.add_document(&doc).await.expect("Failed to add doc");
let results = engine
.retrieve("semantic search embeddings", 5)
.await
.expect("Retrieval failed");
assert!(!results.is_empty());
assert!(results[0].text.contains("embeddings"));
for result in &results {
assert!(!result.chunk_id.is_nil());
assert!(!result.text.is_empty());
}
}
#[tokio::test]
async fn test_retrieve_top_k_limit() {
let engine = RagEngine::in_memory().expect("Failed to create engine");
for i in 0..10 {
let doc = create_test_document(
&format!("Document {} about testing retrieval limits.", i),
&format!("doc-{}", i),
);
engine.add_document(&doc).await.expect("Failed to add doc");
}
let results = engine
.retrieve("testing retrieval", 3)
.await
.expect("Retrieval failed");
assert!(results.len() <= 3);
}
#[test]
fn test_engine_builder_pattern() {
let config = RagConfig::quick();
let _engine = RagEngine::in_memory()
.expect("Failed to create engine")
.with_config(config);
}
#[test]
fn test_config_builder_pattern() {
let mut config = RagConfig::default();
config.top_k = 20;
config.min_score = 0.5;
assert_eq!(config.top_k, 20);
assert_eq!(config.min_score, 0.5);
}
}