mod common;
use lc::config::Config;
use lc::provider::EmbeddingRequest;
use lc::vector_db::VectorDatabase;
use std::collections::HashMap;
#[cfg(test)]
mod rag_context_retrieval_tests {
    use super::*;

    /// Creates the database `db_name` from scratch and seeds it with a small
    /// six-dimensional knowledge base about AI topics and Python.
    ///
    /// Any database with the same name left behind by an earlier run is deleted
    /// first, so the seeded contents are exactly the six entries below.
    fn setup_rag_test_database(db_name: &str) -> VectorDatabase {
        // Result ignored on purpose: there may be nothing to delete yet.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let knowledge_base = vec![
            (
                "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed.",
                vec![0.8, 0.6, 0.2, 0.1, 0.0, 0.1],
            ),
            (
                "Deep learning is a subset of machine learning that uses neural networks with multiple layers to model and understand complex patterns in data.",
                vec![0.7, 0.7, 0.3, 0.1, 0.0, 0.1],
            ),
            (
                "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes that process information.",
                vec![0.6, 0.8, 0.4, 0.1, 0.0, 0.1],
            ),
            (
                "Natural language processing (NLP) is a branch of AI that helps computers understand, interpret and manipulate human language.",
                vec![0.5, 0.4, 0.8, 0.2, 0.0, 0.1],
            ),
            (
                "Computer vision is a field of AI that trains computers to interpret and understand visual information from the world.",
                vec![0.4, 0.3, 0.2, 0.8, 0.2, 0.1],
            ),
            (
                "Python is a high-level programming language known for its simplicity and readability. It's widely used in data science and AI.",
                vec![0.2, 0.1, 0.3, 0.1, 0.8, 0.5],
            ),
        ];
        for (text, vector) in knowledge_base {
            db.add_vector(text, &vector, model, provider).unwrap();
        }
        db
    }

    /// A machine-learning query retrieves ML-related context, and results are
    /// ranked by non-increasing similarity.
    #[test]
    fn test_rag_context_retrieval_basic() {
        let db_name = "rag_context_basic";
        let db = setup_rag_test_database(db_name);
        let ml_query = vec![0.75, 0.65, 0.25, 0.1, 0.0, 0.1];
        let similar = db.find_similar(&ml_query, 3).unwrap();
        assert_eq!(similar.len(), 3);
        let context_texts: Vec<&str> = similar
            .iter()
            .map(|(entry, _)| entry.text.as_str())
            .collect();
        assert!(context_texts
            .iter()
            .any(|text| text.contains("machine learning") || text.contains("Machine learning")));
        // Results must be ordered best-first.
        assert!(similar[0].1 >= similar[1].1);
        assert!(similar[1].1 >= similar[2].1);
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Querying with a vector identical to a stored one yields a very close
    /// top match, and at least one result clears the 0.3 context threshold.
    #[test]
    fn test_rag_context_filtering_by_similarity() {
        let db_name = "rag_context_filtering";
        let db = setup_rag_test_database(db_name);
        // Same vector as the stored "Machine learning" entry.
        let specific_query = vec![0.8, 0.6, 0.2, 0.1, 0.0, 0.1];
        let similar = db.find_similar(&specific_query, 6).unwrap();
        assert_eq!(similar.len(), 6);
        let high_similarity_results: Vec<_> = similar
            .iter()
            .filter(|(_, similarity)| *similarity > 0.3)
            .collect();
        assert!(!high_similarity_results.is_empty());
        assert!(
            similar[0].1 > 0.8,
            "First result similarity: {}",
            similar[0].1
        );
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Each topic-specific query returns the matching topic as its top hit.
    #[test]
    fn test_rag_context_with_different_topics() {
        let db_name = "rag_context_topics";
        let db = setup_rag_test_database(db_name);
        // Each query vector equals the stored vector of the expected topic.
        let test_queries = vec![
            (vec![0.7, 0.7, 0.3, 0.1, 0.0, 0.1], "deep learning"),
            (vec![0.5, 0.4, 0.8, 0.2, 0.0, 0.1], "natural language"),
            (vec![0.4, 0.3, 0.2, 0.8, 0.2, 0.1], "computer vision"),
            (vec![0.2, 0.1, 0.3, 0.1, 0.8, 0.5], "python"),
        ];
        for (query_vector, expected_topic) in test_queries {
            let similar = db.find_similar(&query_vector, 2).unwrap();
            assert!(!similar.is_empty());
            // Own the lowercased text instead of borrowing a temporary.
            let top_result = similar[0].0.text.to_lowercase();
            assert!(
                top_result.contains(expected_topic),
                "Query for '{}' didn't return relevant content. Got: '{}'",
                expected_topic,
                similar[0].0.text
            );
        }
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Searching a database with no entries yields no results.
    #[test]
    fn test_rag_context_empty_database() {
        let db_name = "empty_rag_test";
        // Start from a clean slate, like `setup_rag_test_database`, so that
        // leftovers from an aborted run cannot make this database non-empty.
        let _ = VectorDatabase::delete_database(db_name);
        let empty_db = VectorDatabase::new(db_name).unwrap();
        let query = vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
        let similar = empty_db.find_similar(&query, 3).unwrap();
        assert!(similar.is_empty());
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Results above the 0.3 threshold are rendered as "- text" bullet
    /// lines, each terminated by a newline.
    #[test]
    fn test_rag_context_formatting() {
        let db_name = "rag_context_formatting";
        let db = setup_rag_test_database(db_name);
        let query = vec![0.8, 0.6, 0.2, 0.1, 0.0, 0.1];
        let similar = db.find_similar(&query, 3).unwrap();
        let mut context = String::new();
        let mut included_count = 0;
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
                included_count += 1;
            }
        }
        assert!(!context.is_empty());
        assert!(included_count > 0);
        assert!(context.contains("- "));
        assert!(context.ends_with('\n'));
        let context_lower = context.to_lowercase();
        assert!(
            context_lower.contains("machine learning")
                || context_lower.contains("artificial intelligence")
        );
        VectorDatabase::delete_database(db_name).unwrap();
    }
}
#[cfg(test)]
mod rag_model_consistency_tests {
    use super::*;

    /// The embedding model/provider used when storing a vector is reported
    /// back by `get_model_info`.
    #[test]
    fn test_rag_model_info_retrieval() {
        let db_name = "rag_model_test";
        // Clear leftovers from an aborted run before relying on the contents.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let stored_model = "text-embedding-3-small";
        let stored_provider = "openai";
        let vector = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        db.add_vector("Test content", &vector, stored_model, stored_provider)
            .unwrap();
        let (db_model, db_provider) = db
            .get_model_info()
            .unwrap()
            .expect("model info should be present after adding a vector");
        assert_eq!(db_model, stored_model);
        assert_eq!(db_provider, stored_provider);
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// An `EmbeddingRequest` built for a RAG query keeps the given fields.
    #[test]
    fn test_rag_embedding_request_creation() {
        let model = "text-embedding-3-small";
        let query = "What is machine learning?";
        let embedding_request = EmbeddingRequest {
            model: model.to_string(),
            input: query.to_string(),
            encoding_format: Some("float".to_string()),
        };
        assert_eq!(embedding_request.model, model);
        assert_eq!(embedding_request.input, query);
        assert_eq!(embedding_request.encoding_format, Some("float".to_string()));
    }

    /// A query whose dimension matches the stored vectors is searchable.
    /// A mismatched query may be rejected, but if it is scored the scores
    /// must be finite.
    #[test]
    fn test_rag_dimension_consistency_check() {
        let db_name = "rag_dimension_test";
        // Clear leftovers so exactly one vector is stored below.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let stored_vector = vec![0.1; 1536];
        db.add_vector("Stored content", &stored_vector, model, provider)
            .unwrap();

        let matching_query = vec![0.2; 1536];
        let similar = db
            .find_similar(&matching_query, 1)
            .expect("same-dimension query should succeed");
        assert_eq!(similar.len(), 1);

        // Previously this result was discarded unchecked. Either an error or
        // a finite score is acceptable; NaN/infinite scores are not.
        let mismatched_query = vec![0.2; 1024];
        if let Ok(similar) = db.find_similar(&mismatched_query, 1) {
            for (_, similarity) in similar {
                assert!(
                    similarity.is_finite(),
                    "Mismatched-dimension similarity must be finite: {}",
                    similarity
                );
            }
        }
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Chat resolution falls back to the configured defaults, while explicit
    /// provider/model arguments (as used for embeddings) take precedence.
    #[test]
    fn test_rag_provider_client_separation() {
        let mut config = Config {
            providers: HashMap::new(),
            default_provider: Some("venice".to_string()),
            default_model: Some("llama-3.3-70b".to_string()),
            aliases: HashMap::new(),
            system_prompt: None,
            templates: HashMap::new(),
            max_tokens: None,
            temperature: None,
            stream: None,
        };
        config.providers.insert(
            "venice".to_string(),
            lc::config::ProviderConfig {
                endpoint: "https://api.venice.ai/api/v1".to_string(),
                api_key: Some("venice-key".to_string()),
                models: vec!["llama-3.3-70b".to_string()],
                models_path: "/models".to_string(),
                chat_path: "/chat/completions".to_string(),
                headers: HashMap::new(),
                token_url: None,
                cached_token: None,
                auth_type: None,
                vars: HashMap::new(),
                images_path: Some("/images/generations".to_string()),
                embeddings_path: Some("/embeddings".to_string()),
                chat_templates: None,
                images_templates: None,
                embeddings_templates: None,
                models_templates: None,
                audio_path: None,
                speech_path: None,
                audio_templates: None,
                speech_templates: None,
            },
        );
        config.providers.insert(
            "openai".to_string(),
            lc::config::ProviderConfig {
                endpoint: "https://api.openai.com/v1".to_string(),
                api_key: Some("openai-key".to_string()),
                models: vec!["text-embedding-3-small".to_string()],
                models_path: "/v1/models".to_string(),
                chat_path: "/v1/chat/completions".to_string(),
                headers: HashMap::new(),
                token_url: None,
                cached_token: None,
                auth_type: None,
                vars: HashMap::new(),
                images_path: Some("/images/generations".to_string()),
                embeddings_path: Some("/embeddings".to_string()),
                chat_templates: None,
                images_templates: None,
                embeddings_templates: None,
                models_templates: None,
                audio_path: None,
                speech_path: None,
                audio_templates: None,
                speech_templates: None,
            },
        );

        // No explicit provider/model: the configured defaults are used.
        let (chat_provider, chat_model) =
            lc::utils::resolve_model_and_provider(&config, None, None)
                .expect("default chat provider/model should resolve");
        assert_eq!(chat_provider, "venice");
        assert_eq!(chat_model, "llama-3.3-70b");

        // Explicit provider/model override the defaults.
        let (embed_provider, embed_model) = lc::utils::resolve_model_and_provider(
            &config,
            Some("openai".to_string()),
            Some("text-embedding-3-small".to_string()),
        )
        .expect("explicit embedding provider/model should resolve");
        assert_eq!(embed_provider, "openai");
        assert_eq!(embed_model, "text-embedding-3-small");
    }
}
#[cfg(test)]
mod rag_integration_tests {
    use super::*;

    /// End-to-end simulation: seed a knowledge base, retrieve context for a
    /// query vector, build the enhanced prompt, and check the stored model info.
    #[test]
    fn test_rag_workflow_simulation() {
        let db_name = "rag_workflow_test";
        let embedding_model = "text-embedding-3-small";
        let embedding_provider = "openai";
        // Clear leftovers from an aborted run; duplicated rows would skew the
        // top-k results asserted on below.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let knowledge_entries = vec![
            (
                "Rust is a systems programming language focused on safety and performance.",
                vec![0.8, 0.2, 0.1, 0.0],
            ),
            (
                "Python is popular for data science and machine learning applications.",
                vec![0.2, 0.8, 0.1, 0.0],
            ),
            (
                "JavaScript is the language of the web, running in browsers and servers.",
                vec![0.1, 0.2, 0.8, 0.0],
            ),
            (
                "Machine learning models require large datasets for training.",
                vec![0.1, 0.9, 0.0, 0.1],
            ),
        ];
        for (text, vector) in &knowledge_entries {
            db.add_vector(text, vector, embedding_model, embedding_provider)
                .unwrap();
        }
        let user_query = "Tell me about Python programming";
        let query_vector = vec![0.25, 0.75, 0.15, 0.05];
        let similar = db.find_similar(&query_vector, 3).unwrap();
        assert!(!similar.is_empty());
        let mut context = String::new();
        let mut relevant_count = 0;
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
                relevant_count += 1;
            }
        }
        assert!(!context.is_empty());
        assert!(relevant_count > 0);
        assert!(context.to_lowercase().contains("python"));
        let enhanced_prompt = format!(
            "Context from knowledge base:\n{}\n\nUser question: {}",
            context, user_query
        );
        assert!(enhanced_prompt.contains("Context from knowledge base:"));
        assert!(enhanced_prompt.contains("User question:"));
        assert!(enhanced_prompt.contains(user_query));
        assert!(enhanced_prompt.to_lowercase().contains("python"));
        let model_info = db.get_model_info().unwrap().unwrap();
        assert_eq!(model_info.0, embedding_model);
        assert_eq!(model_info.1, embedding_provider);
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// A query close to a cluster of related entries returns several highly
    /// similar results, and they all make it into the context.
    #[test]
    fn test_rag_with_multiple_relevant_contexts() {
        let db_name = "multi_context_test";
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let related_content = vec![
            (
                "Deep learning is a subset of machine learning using neural networks.",
                vec![0.9, 0.8, 0.1, 0.0],
            ),
            (
                "Neural networks consist of layers of interconnected nodes.",
                vec![0.8, 0.9, 0.1, 0.0],
            ),
            (
                "Backpropagation is the algorithm used to train neural networks.",
                vec![0.7, 0.8, 0.2, 0.0],
            ),
            (
                "Convolutional neural networks are used for image processing.",
                vec![0.6, 0.7, 0.3, 0.1],
            ),
            (
                "Recurrent neural networks handle sequential data.",
                vec![0.5, 0.6, 0.4, 0.1],
            ),
        ];
        for (text, vector) in &related_content {
            db.add_vector(text, vector, model, provider).unwrap();
        }
        let nn_query = vec![0.8, 0.85, 0.15, 0.05];
        let similar = db.find_similar(&nn_query, 5).unwrap();
        assert_eq!(similar.len(), 5);
        let high_similarity_count = similar.iter().filter(|(_, sim)| *sim > 0.5).count();
        assert!(
            high_similarity_count >= 3,
            "Should have multiple highly relevant results"
        );
        let mut context = String::new();
        for (entry, similarity) in &similar {
            if *similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
            }
        }
        let context_lower = context.to_lowercase();
        assert!(context_lower.contains("neural"));
        assert!(
            context_lower.contains("deep learning") || context_lower.contains("backpropagation")
        );
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// A query orthogonal to every stored entry produces low similarities,
    /// so little or nothing passes the 0.3 context threshold.
    #[test]
    fn test_rag_with_no_relevant_context() {
        let db_name = "no_context_test";
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let unrelated_content = vec![
            (
                "Cooking pasta requires boiling water and salt.",
                vec![0.0, 0.0, 0.0, 1.0],
            ),
            (
                "Gardening tips for growing tomatoes in summer.",
                vec![0.0, 0.0, 0.1, 0.9],
            ),
            (
                "Travel guide to visiting Paris museums.",
                vec![0.0, 0.1, 0.0, 0.8],
            ),
        ];
        for (text, vector) in &unrelated_content {
            db.add_vector(text, vector, model, provider).unwrap();
        }
        let programming_query = vec![1.0, 0.0, 0.0, 0.0];
        let similar = db.find_similar(&programming_query, 3).unwrap();
        assert_eq!(similar.len(), 3);
        for (_, similarity) in &similar {
            assert!(
                *similarity < 0.5,
                "Similarity should be low for unrelated content: {}",
                similarity
            );
        }
        let mut context = String::new();
        let mut included_count = 0;
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
                included_count += 1;
            }
        }
        assert!(
            included_count <= 1,
            "Should include few or no results due to low similarity"
        );
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Entries of very different lengths (short, medium, long) are all
    /// included when relevant.
    #[test]
    fn test_rag_context_length_management() {
        let db_name = "context_length_test";
        // Without this reset, leftover duplicates of the short/medium entries
        // from an aborted run would push the long entry out of the top 3.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let content_entries = vec![
            ("Short text.", vec![0.8, 0.2, 0.0, 0.0]),
            ("This is a medium-length text that contains more information about the topic and provides additional context.", vec![0.7, 0.3, 0.0, 0.0]),
            ("This is a very long text entry that contains extensive information about the topic, providing detailed explanations, examples, and comprehensive coverage of various aspects. It includes multiple sentences and covers different subtopics within the main subject area. This type of content might be typical of documentation, articles, or detailed explanations that would be stored in a knowledge base for retrieval-augmented generation systems.", vec![0.6, 0.4, 0.0, 0.0]),
        ];
        for (text, vector) in &content_entries {
            db.add_vector(text, vector, model, provider).unwrap();
        }
        let query = vec![0.75, 0.25, 0.0, 0.0];
        let similar = db.find_similar(&query, 3).unwrap();
        let mut context = String::new();
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
            }
        }
        assert!(!context.is_empty());
        assert!(context.contains("Short text"));
        assert!(context.contains("medium-length"));
        assert!(context.contains("very long text"));
        VectorDatabase::delete_database(db_name).unwrap();
    }
}
#[cfg(test)]
mod rag_error_handling_tests {
    use super::*;

    /// Opening a database that does not exist yet either fails, or yields an
    /// empty database whose context is empty.
    #[test]
    fn test_rag_with_nonexistent_database() {
        let db_name = "nonexistent_rag_db";
        // Guarantee the precondition the test name promises.
        let _ = VectorDatabase::delete_database(db_name);
        // Failing to open is an acceptable outcome here; only a successful
        // open has observable behavior to check.
        if let Ok(db) = VectorDatabase::new(db_name) {
            let query = vec![0.5, 0.5, 0.5];
            let similar = db.find_similar(&query, 3).unwrap();
            assert!(similar.is_empty());
            let mut context = String::new();
            for (entry, similarity) in similar {
                if similarity > 0.3 {
                    context.push_str(&format!("- {}\n", entry.text));
                }
            }
            assert!(context.is_empty());
            VectorDatabase::delete_database(db_name).unwrap();
        }
    }

    /// An empty database produces no results and therefore an empty context.
    #[test]
    fn test_rag_with_empty_database() {
        let db_name = "empty_rag_db";
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let query = vec![0.5, 0.5, 0.5];
        let similar = db.find_similar(&query, 3).unwrap();
        assert!(similar.is_empty());
        let mut context = String::new();
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
            }
        }
        assert!(context.is_empty());
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// A database with no stored vectors reports no model info.
    #[test]
    fn test_rag_with_invalid_model_info() {
        let db_name = "invalid_model_rag_test";
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model_info = db.get_model_info().unwrap();
        assert!(model_info.is_none());
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// A query shorter than the stored vectors is either rejected or scored
    /// with finite similarities; it must never produce NaN/infinite scores.
    #[test]
    fn test_rag_with_dimension_mismatch() {
        let db_name = "dimension_mismatch_rag_test";
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let stored_vector = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        db.add_vector("Test content", &stored_vector, model, provider)
            .unwrap();
        let mismatched_query = vec![0.1, 0.2, 0.3];
        // An Err is an acceptable way to reject the mismatch.
        if let Ok(similar) = db.find_similar(&mismatched_query, 1) {
            for (_, similarity) in similar {
                assert!(similarity.is_finite());
            }
        }
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Unusual entry texts (empty, embedded newlines, punctuation) survive
    /// the "- text" context formatting.
    #[test]
    fn test_rag_context_formatting_edge_cases() {
        let db_name = "formatting_edge_test";
        // A clean database guarantees the top 4 are exactly the entries below.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let edge_cases = vec![
            ("", vec![0.8, 0.2, 0.0]),
            ("Single word", vec![0.7, 0.3, 0.0]),
            ("Text\nwith\nnewlines", vec![0.6, 0.4, 0.0]),
            ("Text with special chars: !@#$%^&*()", vec![0.5, 0.5, 0.0]),
        ];
        for (text, vector) in &edge_cases {
            db.add_vector(text, vector, model, provider).unwrap();
        }
        let query = vec![0.75, 0.25, 0.0];
        let similar = db.find_similar(&query, 4).unwrap();
        let mut context = String::new();
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
            }
        }
        // Every entry points in the same quadrant as the query, so all of them
        // clear the threshold. Special characters must come through unchanged.
        assert!(context.contains("- "));
        assert!(context.contains("Text with special chars: !@#$%^&*()"));
        VectorDatabase::delete_database(db_name).unwrap();
    }
}
#[cfg(test)]
mod rag_performance_tests {
    use super::*;

    /// Retrieval over 100 ten-dimensional entries and formatting of the
    /// resulting context both complete within generous time budgets.
    #[test]
    fn test_rag_retrieval_performance() {
        let db_name = "rag_performance_test";
        // Leftover rows from an aborted run would inflate the timed search.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let entry_count = 100;
        for i in 0..entry_count {
            let vector: Vec<f64> = (0..10).map(|j| ((i * 10 + j) as f64) * 0.01).collect();
            let text = format!(
                "Knowledge entry {} with information about topic {}",
                i,
                i % 10
            );
            db.add_vector(&text, &vector, model, provider).unwrap();
        }
        let query: Vec<f64> = (0..10).map(|i| (i as f64) * 0.01).collect();

        let start = std::time::Instant::now();
        let similar = db.find_similar(&query, 5).unwrap();
        let retrieval_duration = start.elapsed();
        assert_eq!(similar.len(), 5);

        let start = std::time::Instant::now();
        let mut context = String::new();
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
            }
        }
        let formatting_duration = start.elapsed();

        assert!(
            retrieval_duration.as_millis() < 100,
            "Retrieval too slow: {:?}",
            retrieval_duration
        );
        assert!(
            formatting_duration.as_millis() < 10,
            "Formatting too slow: {:?}",
            formatting_duration
        );
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// Twenty nearly identical relevant entries produce a substantial but
    /// bounded context.
    #[test]
    fn test_rag_with_large_context() {
        let db_name = "large_context_test";
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let base_vector = vec![0.8, 0.2, 0.0, 0.0, 0.0];
        for i in 0..20 {
            let mut vector = base_vector.clone();
            // Tiny perturbation so entries are distinct but all relevant.
            vector[0] += (i as f64) * 0.001;
            let text = format!("This is knowledge entry number {} containing detailed information about the main topic. It includes comprehensive explanations and examples that would be useful for answering user questions.", i);
            db.add_vector(&text, &vector, model, provider).unwrap();
        }
        let query = vec![0.8, 0.2, 0.0, 0.0, 0.0];
        let similar = db.find_similar(&query, 20).unwrap();
        let mut context = String::new();
        let mut included_count = 0;
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
                included_count += 1;
            }
        }
        assert!(
            included_count >= 10,
            "Should include many relevant entries: {}",
            included_count
        );
        assert!(context.len() > 1000, "Context should be substantial");
        assert!(context.len() < 10000, "Context should not be excessive");
        VectorDatabase::delete_database(db_name).unwrap();
    }

    /// The 0.3 threshold keeps high- and moderate-relevance entries and drops
    /// the off-topic ones.
    #[test]
    fn test_rag_similarity_threshold_performance() {
        let db_name = "threshold_performance_test";
        // Duplicated rows from an aborted run would crowd the low-relevance
        // entries out of the top 9 and break the bucket counts below.
        let _ = VectorDatabase::delete_database(db_name);
        let db = VectorDatabase::new(db_name).unwrap();
        let model = "text-embedding-3-small";
        let provider = "openai";
        let mixed_content = vec![
            ("Machine learning algorithms", vec![0.9, 0.1, 0.0, 0.0]),
            ("Deep learning networks", vec![0.85, 0.15, 0.0, 0.0]),
            ("Neural network training", vec![0.8, 0.2, 0.0, 0.0]),
            ("Data science methods", vec![0.6, 0.4, 0.0, 0.0]),
            ("Statistical analysis", vec![0.5, 0.5, 0.0, 0.0]),
            ("Computer algorithms", vec![0.4, 0.6, 0.0, 0.0]),
            ("Cooking recipes", vec![0.1, 0.0, 0.9, 0.0]),
            ("Travel destinations", vec![0.0, 0.1, 0.0, 0.9]),
            ("Sports statistics", vec![0.0, 0.0, 0.1, 0.9]),
        ];
        for (text, vector) in &mixed_content {
            db.add_vector(text, vector, model, provider).unwrap();
        }
        let query = vec![0.9, 0.1, 0.0, 0.0];
        let similar = db.find_similar(&query, 9).unwrap();

        let mut high_relevance_count = 0;
        let mut moderate_relevance_count = 0;
        let mut low_relevance_count = 0;
        for (_, similarity) in &similar {
            if *similarity > 0.8 {
                high_relevance_count += 1;
            } else if *similarity > 0.3 {
                moderate_relevance_count += 1;
            } else {
                low_relevance_count += 1;
            }
        }
        assert!(
            high_relevance_count >= 2,
            "Should have high relevance results"
        );
        assert!(
            low_relevance_count >= 2,
            "Should have low relevance results"
        );

        let mut context = String::new();
        let mut included_count = 0;
        for (entry, similarity) in similar {
            if similarity > 0.3 {
                context.push_str(&format!("- {}\n", entry.text));
                included_count += 1;
            }
        }
        assert_eq!(
            included_count,
            high_relevance_count + moderate_relevance_count
        );
        let context_lower = context.to_lowercase();
        assert!(context_lower.contains("machine learning") || context_lower.contains("neural"));
        assert!(
            !context_lower.contains("cooking")
                && !context_lower.contains("travel")
                && !context_lower.contains("sports")
        );
        VectorDatabase::delete_database(db_name).unwrap();
    }
}