mod common;
use lc::config::Config;
use lc::vector_db::VectorDatabase;
use std::collections::HashMap;
#[cfg(test)]
mod similar_search_tests {
use super::*;
/// Creates a database named `similar_test_<test_name>_<pid>` and loads it with
/// six five-component sample vectors: three AI-themed texts, two
/// programming-themed texts and one cooking-themed text.
///
/// Every entry is stored under model `text-embedding-3-small` and provider
/// `openai`. Panics if the database cannot be created or an insert fails.
fn setup_test_database_with_vectors(test_name: &str) -> VectorDatabase {
    let pid = std::process::id();
    let database = VectorDatabase::new(&format!("similar_test_{}_{}", test_name, pid)).unwrap();

    let samples = vec![
        // AI-themed texts: weight concentrated in the first two components.
        (
            "Machine learning is a subset of artificial intelligence",
            vec![0.8, 0.6, 0.2, 0.1, 0.0],
        ),
        (
            "Deep learning uses neural networks with multiple layers",
            vec![0.7, 0.7, 0.3, 0.1, 0.0],
        ),
        (
            "Artificial intelligence enables computers to think",
            vec![0.9, 0.5, 0.1, 0.1, 0.0],
        ),
        // Programming-themed texts: weight in the third and fourth components.
        (
            "Python is a popular programming language",
            vec![0.1, 0.2, 0.8, 0.7, 0.1],
        ),
        (
            "JavaScript runs in web browsers",
            vec![0.0, 0.1, 0.9, 0.6, 0.2],
        ),
        // Unrelated text: weight in the last component.
        (
            "Cooking pasta requires boiling water",
            vec![0.0, 0.0, 0.1, 0.1, 0.9],
        ),
    ];

    for (text, embedding) in &samples {
        database
            .add_vector(*text, embedding, "text-embedding-3-small", "openai")
            .unwrap();
    }

    database
}
/// A query close to the AI-themed sample vectors must return exactly three
/// results, ordered by non-increasing score, with an AI-themed text on top.
#[test]
fn test_basic_similarity_search() {
    let db = setup_test_database_with_vectors("basic");
    let ai_query = vec![0.85, 0.55, 0.15, 0.1, 0.0];

    // `expect` reports the underlying error, which `assert!(is_ok())` hides.
    let similar = db
        .find_similar(&ai_query, 3)
        .expect("similarity search should succeed");
    assert_eq!(similar.len(), 3);

    // Scores must be sorted from best to worst.
    assert!(similar[0].1 >= similar[1].1);
    assert!(similar[1].1 >= similar[2].1);

    let top_text = similar[0].0.text.to_lowercase();
    assert!(["artificial", "machine", "deep"]
        .iter()
        .any(|&keyword| top_text.contains(keyword)));

    // Remove the database like the other tests in this file do. The name
    // mirrors the one built by `setup_test_database_with_vectors`.
    let db_name = format!("similar_test_basic_{}", std::process::id());
    VectorDatabase::delete_database(&db_name).unwrap();
}
/// For every limit from 1 to 6 (the number of stored sample vectors) the
/// search must return exactly `limit` results.
#[test]
fn test_similarity_search_with_limit() {
    let db = setup_test_database_with_vectors("with_limit");
    let query = vec![0.5, 0.5, 0.5, 0.5, 0.5];

    for limit in 1..=6 {
        // `expect` reports the underlying error, which `assert!(is_ok())` hides.
        let similar = db
            .find_similar(&query, limit)
            .expect("similarity search should succeed");
        // The limit never exceeds the six stored vectors, so no clamping is needed.
        assert_eq!(similar.len(), limit);
    }

    // Remove the database like the other tests in this file do. The name
    // mirrors the one built by `setup_test_database_with_vectors`.
    let db_name = format!("similar_test_with_limit_{}", std::process::id());
    VectorDatabase::delete_database(&db_name).unwrap();
}
/// A limit of zero must succeed and return no results.
#[test]
fn test_similarity_search_with_zero_limit() {
    let db = setup_test_database_with_vectors("zero_limit");
    let query = vec![0.5, 0.5, 0.5, 0.5, 0.5];

    // `expect` reports the underlying error, which `assert!(is_ok())` hides.
    let similar = db
        .find_similar(&query, 0)
        .expect("similarity search should succeed");
    assert_eq!(similar.len(), 0);

    // Remove the database like the other tests in this file do. The name
    // mirrors the one built by `setup_test_database_with_vectors`.
    let db_name = format!("similar_test_zero_limit_{}", std::process::id());
    VectorDatabase::delete_database(&db_name).unwrap();
}
/// Querying with a vector identical to the stored "Machine learning" sample
/// must return that entry with a score above 0.95.
#[test]
fn test_similarity_search_exact_match() {
    let db = setup_test_database_with_vectors("exact_match");
    // Same components as the first sample inserted by the setup helper.
    let exact_query = vec![0.8, 0.6, 0.2, 0.1, 0.0];

    // `expect` reports the underlying error, which `assert!(is_ok())` hides.
    let similar = db
        .find_similar(&exact_query, 1)
        .expect("similarity search should succeed");
    assert_eq!(similar.len(), 1);
    assert!(similar[0].1 > 0.95);
    assert!(similar[0].0.text.contains("Machine learning"));

    // Remove the database like the other tests in this file do. The name
    // mirrors the one built by `setup_test_database_with_vectors`.
    let db_name = format!("similar_test_exact_match_{}", std::process::id());
    VectorDatabase::delete_database(&db_name).unwrap();
}
/// Requests all six sample vectors and checks that scores are in
/// non-increasing order and that the top three texts are AI-themed.
#[test]
fn test_similarity_search_ordering() {
    let db = setup_test_database_with_vectors("ordering");
    let query = vec![0.8, 0.6, 0.2, 0.1, 0.0];

    // `expect` reports the underlying error, which `assert!(is_ok())` hides.
    let similar = db
        .find_similar(&query, 6)
        .expect("similarity search should succeed");
    assert_eq!(similar.len(), 6);

    // Compare each adjacent pair of results.
    for pair in similar.windows(2) {
        assert!(
            pair[0].1 >= pair[1].1,
            "Similarity scores not in descending order: {} >= {}",
            pair[0].1,
            pair[1].1
        );
    }

    let ai_keywords = [
        "machine",
        "learning",
        "artificial",
        "intelligence",
        "deep",
        "neural",
    ];
    for (entry, _) in &similar[0..3] {
        let text_lower = entry.text.to_lowercase();
        let has_ai_keyword = ai_keywords
            .iter()
            .any(|&keyword| text_lower.contains(keyword));
        assert!(
            has_ai_keyword,
            "Top result should contain AI-related keywords: {}",
            entry.text
        );
    }

    // Remove the database like the other tests in this file do. The name
    // mirrors the one built by `setup_test_database_with_vectors`.
    let db_name = format!("similar_test_ordering_{}", std::process::id());
    VectorDatabase::delete_database(&db_name).unwrap();
}
/// Searching a database that holds no vectors must succeed and yield nothing.
#[test]
fn test_similarity_search_empty_database() {
    let name = format!("empty_similar_test_{}", std::process::id());
    let empty_db = VectorDatabase::new(&name).unwrap();

    let probe = vec![1.0, 0.0, 0.0];
    let outcome = empty_db.find_similar(&probe, 5);
    assert!(outcome.is_ok());

    let matches = outcome.unwrap();
    assert!(matches.is_empty());

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
/// With a single stored vector, a search asking for up to five results must
/// return exactly that one entry.
#[test]
fn test_similarity_search_single_vector() {
    let name = format!("single_similar_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let stored = vec![0.5, 0.5, 0.5];
    database
        .add_vector("Single vector", &stored, "text-embedding-3-small", "openai")
        .unwrap();

    let probe = vec![0.6, 0.4, 0.5];
    let outcome = database.find_similar(&probe, 5);
    assert!(outcome.is_ok());

    let matches = outcome.unwrap();
    assert_eq!(matches.len(), 1);
    assert_eq!(matches[0].0.text, "Single vector");

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
}
#[cfg(test)]
mod similar_model_resolution_tests {
use super::*;
/// Builds a `Config` whose defaults are provider `openai` and model
/// `text-embedding-3-small`, with two registered providers: `openai`
/// (two embedding models) and `cohere` (one embedding model).
fn create_test_config_for_similarity() -> Config {
    /// Assembles a provider entry; only the endpoint, API key, model list and
    /// chat path differ between the two providers used here.
    fn provider_entry(
        endpoint: &str,
        api_key: &str,
        models: &[&str],
        chat_path: &str,
    ) -> lc::config::ProviderConfig {
        lc::config::ProviderConfig {
            endpoint: endpoint.to_string(),
            api_key: Some(api_key.to_string()),
            models: models.iter().map(|name| name.to_string()).collect(),
            models_path: "/v1/models".to_string(),
            chat_path: chat_path.to_string(),
            headers: HashMap::new(),
            token_url: None,
            cached_token: None,
            auth_type: None,
            vars: HashMap::new(),
            images_path: Some("/images/generations".to_string()),
            embeddings_path: Some("/embeddings".to_string()),
            chat_templates: None,
            images_templates: None,
            embeddings_templates: None,
            models_templates: None,
            audio_path: None,
            speech_path: None,
            audio_templates: None,
            speech_templates: None,
        }
    }

    let mut providers = HashMap::new();
    providers.insert(
        "openai".to_string(),
        provider_entry(
            "https://api.openai.com/v1",
            "sk-test123",
            &["text-embedding-3-small", "text-embedding-3-large"],
            "/v1/chat/completions",
        ),
    );
    providers.insert(
        "cohere".to_string(),
        provider_entry(
            "https://api.cohere.ai/v1",
            "cohere-test-key",
            &["embed-english-v3.0"],
            "/v1/chat",
        ),
    );

    Config {
        providers,
        default_provider: Some("openai".to_string()),
        default_model: Some("text-embedding-3-small".to_string()),
        aliases: HashMap::new(),
        system_prompt: None,
        templates: HashMap::new(),
        max_tokens: None,
        temperature: None,
        stream: None,
    }
}
/// The model and provider recorded in a database must be readable through
/// `get_model_info` and must resolve unchanged against the test config.
#[test]
fn test_similar_model_resolution_from_database() {
    let stored_model = "text-embedding-3-large";
    let stored_provider = "openai";

    let config = create_test_config_for_similarity();
    let name = format!("model_resolution_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let embedding = vec![0.1, 0.2, 0.3];
    database
        .add_vector("Test text", &embedding, stored_model, stored_provider)
        .unwrap();

    // The database should report the model/provider of the stored vector.
    let info = database.get_model_info().unwrap();
    assert!(info.is_some());
    let (db_model, db_provider) = info.unwrap();
    assert_eq!(db_model, stored_model);
    assert_eq!(db_provider, stored_provider);

    // Feeding those values to the resolver must hand them back as-is.
    let resolution =
        lc::utils::resolve_model_and_provider(&config, Some(db_provider), Some(db_model));
    assert!(resolution.is_ok());
    let (resolved_provider, resolved_model) = resolution.unwrap();
    assert_eq!(resolved_provider, stored_provider);
    assert_eq!(resolved_model, stored_model);

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
/// An explicitly supplied provider/model pair (`cohere` /
/// `embed-english-v3.0`) must resolve to itself, even though the database
/// in this test holds a vector stored under `openai`.
#[test]
fn test_similar_model_override() {
    let config = create_test_config_for_similarity();
    let name = format!("model_override_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let embedding = vec![0.1, 0.2, 0.3];
    database
        .add_vector("Test", &embedding, "text-embedding-3-small", "openai")
        .unwrap();

    let override_provider = Some("cohere".to_string());
    let override_model = Some("embed-english-v3.0".to_string());
    let resolution =
        lc::utils::resolve_model_and_provider(&config, override_provider, override_model);
    assert!(resolution.is_ok());

    let (provider, model) = resolution.unwrap();
    assert_eq!(provider, "cohere");
    assert_eq!(model, "embed-english-v3.0");

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
/// A combined `provider:model` string passed as the model, with no separate
/// provider, must be split into its provider and model parts.
#[test]
fn test_similar_with_provider_model_format() {
    let config = create_test_config_for_similarity();

    let combined = Some("cohere:embed-english-v3.0".to_string());
    let resolution = lc::utils::resolve_model_and_provider(&config, None, combined);
    assert!(resolution.is_ok());

    let (provider, model) = resolution.unwrap();
    assert_eq!(provider, "cohere");
    assert_eq!(model, "embed-english-v3.0");
}
}
#[cfg(test)]
mod similar_parameter_validation_tests {
use super::*;
/// Stores five vectors and checks that, for a range of limits (including
/// zero and values larger than the store), the result count equals the
/// limit capped at five.
#[test]
fn test_similarity_limit_validation() {
    let name = format!("limit_validation_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    for i in 0..5 {
        let label = format!("Text {}", i);
        let embedding = vec![i as f64, 0.0, 0.0];
        database
            .add_vector(&label, &embedding, "text-embedding-3-small", "openai")
            .unwrap();
    }

    let probe = vec![0.0, 0.0, 0.0];
    for &limit in &[0, 1, 3, 5, 10, 100] {
        let outcome = database.find_similar(&probe, limit);
        assert!(outcome.is_ok(), "Failed with limit: {}", limit);

        let matches = outcome.unwrap();
        // Only five vectors exist, so larger limits are capped at five.
        let expected_len = std::cmp::min(limit, 5);
        assert_eq!(
            matches.len(),
            expected_len,
            "Wrong result count for limit: {}",
            limit
        );
    }

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
/// Runs queries of matching, shorter, longer and zero length against a
/// single stored three-component vector. An error is tolerated; a
/// successful search must not return more than the one requested result.
#[test]
fn test_similarity_query_vector_validation() {
    let name = format!("query_validation_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let stored = vec![1.0, 0.0, 0.0];
    database
        .add_vector("Test vector", &stored, "text-embedding-3-small", "openai")
        .unwrap();

    let probes = vec![
        vec![1.0, 0.0, 0.0],      // identical to the stored vector
        vec![0.5, 0.5, 0.5],      // same length, different direction
        vec![1.0, 0.0],           // one component short
        vec![1.0, 0.0, 0.0, 0.0], // one component extra
        vec![],                   // empty
    ];

    for (i, probe) in probes.iter().enumerate() {
        // Errors are deliberately accepted here; only successful results are checked.
        if let Ok(matches) = database.find_similar(probe, 1) {
            assert!(matches.len() <= 1, "Query {} returned too many results", i);
        }
    }

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
/// Queries a one-vector database with unusual inputs (all zeros, all ones,
/// all negative, `f64::MAX` and `f64::MIN` components). An error is
/// tolerated; any score that is returned must be finite.
#[test]
fn test_similarity_with_special_values() {
    let name = format!("special_values_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let stored = vec![0.5, 0.5, 0.5];
    database
        .add_vector("Normal vector", &stored, "text-embedding-3-small", "openai")
        .unwrap();

    let probes = vec![
        vec![0.0, 0.0, 0.0],
        vec![1.0, 1.0, 1.0],
        vec![-1.0, -1.0, -1.0],
        vec![f64::MAX, 0.0, 0.0],
        vec![f64::MIN, 0.0, 0.0],
    ];

    for probe in probes {
        // Errors are deliberately accepted here; only successful results are checked.
        if let Ok(matches) = database.find_similar(&probe, 1) {
            assert!(matches.len() <= 1);
            if let Some(best) = matches.first() {
                assert!(
                    best.1.is_finite(),
                    "Similarity score should be finite for query: {:?}",
                    probe
                );
            }
        }
    }

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
}
#[cfg(test)]
mod similar_error_handling_tests {
use super::*;
/// Opens a database under a name no other test in this file uses. If
/// opening succeeds, a search on it must succeed with no results; if
/// opening fails, the test accepts that outcome without further checks.
#[test]
fn test_similar_with_nonexistent_database() {
    let name = format!("nonexistent_similar_db_{}", std::process::id());

    if let Ok(database) = VectorDatabase::new(&name) {
        let probe = vec![1.0, 0.0, 0.0];
        let outcome = database.find_similar(&probe, 5);
        assert!(outcome.is_ok());

        let matches = outcome.unwrap();
        assert!(matches.is_empty());

        // Clean up the database created for this test.
        VectorDatabase::delete_database(&name).unwrap();
    }
}
/// Stores one well-formed vector and checks that searching with the same
/// components succeeds.
///
/// NOTE(review): despite the name, nothing in this test corrupts the stored
/// data; it only exercises the normal insert-then-search path.
#[test]
fn test_similar_with_corrupted_data() {
    let name = format!("corruption_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let stored = vec![0.1, 0.2, 0.3];
    database
        .add_vector("Normal text", &stored, "text-embedding-3-small", "openai")
        .unwrap();

    let probe = vec![0.1, 0.2, 0.3];
    let outcome = database.find_similar(&probe, 1);
    assert!(outcome.is_ok());

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
/// With a config that has no providers and no defaults, resolution must fail
/// both when nothing is supplied and when an unregistered provider is named.
#[test]
fn test_similar_with_invalid_model_info() {
    let empty_config = Config {
        providers: HashMap::new(),
        default_provider: None,
        default_model: None,
        aliases: HashMap::new(),
        system_prompt: None,
        templates: HashMap::new(),
        max_tokens: None,
        temperature: None,
        stream: None,
    };

    // Nothing supplied and no defaults to fall back on.
    let without_inputs = lc::utils::resolve_model_and_provider(&empty_config, None, None);
    assert!(without_inputs.is_err());

    // A provider name that is not present in the (empty) provider map.
    let unknown_provider = Some("invalid".to_string());
    let some_model = Some("model".to_string());
    let with_unknown =
        lc::utils::resolve_model_and_provider(&empty_config, unknown_provider, some_model);
    assert!(with_unknown.is_err());
}
}
#[cfg(test)]
mod similar_performance_tests {
use super::*;
/// Stores 100 ten-component vectors and checks that a top-10 search returns
/// ten results and completes in under one second.
#[test]
fn test_similarity_search_performance() {
    let name = format!("performance_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let total = 100;
    for idx in 0..total {
        let label = format!("Vector {}", idx);
        let embedding: Vec<f64> = (0..10)
            .map(|offset| (idx * 10 + offset) as f64 * 0.01)
            .collect();
        database
            .add_vector(&label, &embedding, "text-embedding-3-small", "openai")
            .unwrap();
    }

    let probe: Vec<f64> = (0..10).map(|i| i as f64 * 0.01).collect();

    // Time only the search call itself.
    let timer = std::time::Instant::now();
    let outcome = database.find_similar(&probe, 10);
    let elapsed = timer.elapsed();

    assert!(outcome.is_ok());
    let matches = outcome.unwrap();
    assert_eq!(matches.len(), 10);
    assert!(
        elapsed < std::time::Duration::from_secs(1),
        "Search took too long: {:?}",
        elapsed
    );

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
/// Stores ten 1536-component vectors and checks that a top-5 search returns
/// five entries whose stored vectors keep their full length.
#[test]
fn test_similarity_search_with_large_vectors() {
    let name = format!("large_vector_performance_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let dimension = 1536;
    let total = 10;
    for idx in 0..total {
        let label = format!("Large vector {}", idx);
        let embedding: Vec<f64> = (0..dimension)
            .map(|offset| ((idx * dimension + offset) as f64) * 0.0001)
            .collect();
        database
            .add_vector(&label, &embedding, "text-embedding-3-large", "openai")
            .unwrap();
    }

    let probe: Vec<f64> = (0..dimension).map(|i| (i as f64) * 0.0001).collect();
    let outcome = database.find_similar(&probe, 5);
    assert!(outcome.is_ok());

    let matches = outcome.unwrap();
    assert_eq!(matches.len(), 5);
    for (entry, _) in &matches {
        assert_eq!(entry.vector.len(), dimension);
    }

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
}
#[cfg(test)]
mod similar_integration_tests {
use super::*;
/// End-to-end check: store five themed vectors, confirm the recorded
/// model/provider, then verify that AI-themed and web-themed queries rank
/// the expected texts first and that an oversized limit returns every
/// entry in non-increasing score order.
#[test]
fn test_complete_similarity_workflow() {
    let db_name = format!("similarity_workflow_test_{}", std::process::id());
    let model = "text-embedding-3-small";
    let provider = "openai";
    let db = VectorDatabase::new(&db_name).unwrap();

    let test_vectors = vec![
        (
            "Artificial intelligence research",
            vec![0.9, 0.1, 0.0, 0.0, 0.0],
        ),
        ("Machine learning algorithms", vec![0.8, 0.2, 0.0, 0.0, 0.0]),
        (
            "Web development with JavaScript",
            vec![0.0, 0.0, 0.9, 0.1, 0.0],
        ),
        ("Database design principles", vec![0.0, 0.0, 0.1, 0.9, 0.0]),
        ("Cooking Italian cuisine", vec![0.0, 0.0, 0.0, 0.0, 1.0]),
    ];
    for (text, vector) in &test_vectors {
        db.add_vector(text, vector, model, provider).unwrap();
    }

    // The database should report the model/provider used for the inserts.
    let model_info = db.get_model_info().unwrap().unwrap();
    assert_eq!(model_info.0, model);
    assert_eq!(model_info.1, provider);

    // An AI-leaning query should rank the two AI/ML texts first and second.
    let ai_query = vec![0.85, 0.15, 0.0, 0.0, 0.0];
    let similar = db.find_similar(&ai_query, 3).unwrap();
    assert_eq!(similar.len(), 3);
    let first_text = similar[0].0.text.to_lowercase();
    assert!(first_text.contains("artificial") || first_text.contains("intelligence"));
    let second_text = similar[1].0.text.to_lowercase();
    assert!(second_text.contains("machine") || second_text.contains("learning"));

    // A web-leaning query should rank the JavaScript text first.
    let web_query = vec![0.0, 0.0, 0.8, 0.2, 0.0];
    let web_similar = db.find_similar(&web_query, 2).unwrap();
    assert_eq!(web_similar.len(), 2);
    let web_top_text = web_similar[0].0.text.to_lowercase();
    assert!(web_top_text.contains("web") || web_top_text.contains("javascript"));

    // A limit above the number of stored entries returns all five, sorted.
    let all_similar = db.find_similar(&ai_query, 10).unwrap();
    assert_eq!(all_similar.len(), 5);
    for pair in all_similar.windows(2) {
        assert!(pair[0].1 >= pair[1].1);
    }

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&db_name).unwrap();
}
/// Two vectors stored under the same model/provider must be reported that
/// way both by `get_model_info` and on every entry returned by a search.
#[test]
fn test_similarity_with_model_consistency() {
    let model = "text-embedding-3-small";
    let provider = "openai";

    let name = format!("model_consistency_test_{}", std::process::id());
    let database = VectorDatabase::new(&name).unwrap();

    let first = vec![0.1, 0.2, 0.3];
    database
        .add_vector("First text", &first, model, provider)
        .unwrap();
    let second = vec![0.2, 0.3, 0.4];
    database
        .add_vector("Second text", &second, model, provider)
        .unwrap();

    let (recorded_model, recorded_provider) = database.get_model_info().unwrap().unwrap();
    assert_eq!(recorded_model, model);
    assert_eq!(recorded_provider, provider);

    let probe = vec![0.15, 0.25, 0.35];
    let matches = database.find_similar(&probe, 2).unwrap();
    assert_eq!(matches.len(), 2);
    for (entry, _) in &matches {
        assert_eq!(entry.model, model);
        assert_eq!(entry.provider, provider);
    }

    // Clean up the database created for this test.
    VectorDatabase::delete_database(&name).unwrap();
}
}