#[cfg(test)]
mod rag_tests {
use super::super::*;
use crate::context::manager::{ContextManager, ContextManagerConfig, StandardContextManager};
use crate::types::AgentId;
use std::sync::Arc;
use std::time::Duration;
use tokio;
/// Builds the context manager used by the tests in this module.
///
/// Constructs a [`StandardContextManager`] from the default
/// [`ContextManagerConfig`], passing `"test-agent"` as the second constructor
/// argument, and returns it as a shared [`ContextManager`] trait object.
///
/// # Panics
///
/// Panics if the awaited constructor result cannot be unwrapped.
async fn create_test_context_manager() -> Arc<dyn ContextManager> {
    let standard_manager =
        StandardContextManager::new(ContextManagerConfig::default(), "test-agent")
            .await
            .unwrap();
    Arc::new(standard_manager)
}
/// Builds the [`RAGRequest`] fixture used by the full-pipeline test.
///
/// The request carries a freshly created [`AgentId`], the query
/// "What is machine learning?", preferences for standard-length English text
/// with citations from the "academic" source, and constraints of at most 10
/// documents, a 30 second time limit, public access level, "public" as the only
/// allowed source, and no excluded sources.
fn create_test_rag_request() -> RAGRequest {
    let agent_id = AgentId::new();
    let query = String::from("What is machine learning?");

    let preferences = QueryPreferences {
        response_length: ResponseLength::Standard,
        include_citations: true,
        preferred_sources: vec![String::from("academic")],
        response_format: ResponseFormat::Text,
        language: String::from("en"),
    };

    let constraints = QueryConstraints {
        max_documents: 10,
        time_limit: Duration::from_secs(30),
        security_level: AccessLevel::Public,
        allowed_sources: vec![String::from("public")],
        excluded_sources: Vec::new(),
    };

    RAGRequest {
        agent_id,
        query,
        preferences,
        constraints,
    }
}
/// Smoke test: a [`StandardRAGEngine`] can be constructed from the test context
/// manager without panicking.
#[tokio::test]
async fn test_rag_engine_creation() {
    // The engine is intentionally unused; completing construction is the whole test.
    let _rag_engine = StandardRAGEngine::new(create_test_context_manager().await);
}
/// Checks that `analyze_query` on a "How to ..." question echoes the original
/// query, yields non-empty keywords and embeddings, and classifies the intent
/// as [`QueryIntent::Procedural`].
#[tokio::test]
async fn test_query_analysis() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);
    let query = "How to implement machine learning algorithms?";

    // `expect` surfaces the underlying error on failure, which a bare
    // `assert!(result.is_ok())` would hide.
    let analyzed = rag_engine
        .analyze_query(query, None)
        .await
        .expect("analyze_query should succeed for a plain-text query");

    assert_eq!(analyzed.original_query, query);
    assert!(!analyzed.keywords.is_empty());
    assert_eq!(analyzed.intent, QueryIntent::Procedural);
    assert!(!analyzed.embeddings.is_empty());
}
/// Table-driven check that `analyze_query` assigns the expected
/// [`QueryIntent`] to each sample query.
#[tokio::test]
async fn test_intent_classification() {
    let rag_engine = StandardRAGEngine::new(create_test_context_manager().await);

    // Each entry pairs a query with the intent the analyzer must report for it.
    let expectations = vec![
        ("What is artificial intelligence?", QueryIntent::Factual),
        ("How to train a neural network?", QueryIntent::Procedural),
        (
            "Compare supervised vs unsupervised learning",
            QueryIntent::Analytical,
        ),
        (
            "Analyze the performance of this model",
            QueryIntent::Analytical,
        ),
        (
            "Create a new classification algorithm",
            QueryIntent::Creative,
        ),
        ("Fix this training error", QueryIntent::Troubleshooting),
    ];

    for (query, expected_intent) in expectations {
        let analyzed = rag_engine.analyze_query(query, None).await.unwrap();
        assert_eq!(
            analyzed.intent, expected_intent,
            "Failed for query: {}",
            query
        );
    }
}
/// Checks that `retrieve_documents` returns a non-empty result of at most 10
/// documents for a hand-built "machine learning" query.
#[tokio::test]
async fn test_document_retrieval() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);
    let analyzed_query = AnalyzedQuery {
        original_query: "machine learning".to_string(),
        expanded_terms: vec![
            "machine".to_string(),
            "learning".to_string(),
            "ML".to_string(),
        ],
        intent: QueryIntent::Factual,
        entities: vec![],
        keywords: vec!["machine".to_string(), "learning".to_string()],
        embeddings: vec![0.1, 0.2, 0.3],
        context_keywords: vec!["machine".to_string(), "learning".to_string()],
    };

    // `expect` reports the retrieval error itself instead of a bare
    // "assertion failed: result.is_ok()".
    let documents = rag_engine
        .retrieve_documents(&analyzed_query)
        .await
        .expect("retrieve_documents should succeed for the test query");

    assert!(!documents.is_empty());
    // NOTE(review): 10 is presumably the engine's retrieval limit; confirm against
    // the engine configuration.
    assert!(documents.len() <= 10);
}
/// Checks that `rank_documents` returns a non-empty list ordered by
/// non-increasing `relevance_score`, and that every ranking factor is
/// non-negative.
#[tokio::test]
async fn test_document_ranking() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);
    let analyzed_query = AnalyzedQuery {
        original_query: "machine learning".to_string(),
        expanded_terms: vec!["machine".to_string(), "learning".to_string()],
        intent: QueryIntent::Factual,
        entities: vec![],
        keywords: vec!["machine".to_string(), "learning".to_string()],
        embeddings: vec![0.1, 0.2, 0.3],
        context_keywords: vec!["machine".to_string(), "learning".to_string()],
    };
    let documents = rag_engine
        .retrieve_documents(&analyzed_query)
        .await
        .unwrap();

    // `expect` keeps the ranking error visible when the call fails.
    let ranked_docs = rag_engine
        .rank_documents(documents, &analyzed_query)
        .await
        .expect("rank_documents should succeed for retrieved documents");
    assert!(!ranked_docs.is_empty());

    // Compare each document with its successor; `windows(2)` avoids manual
    // index arithmetic and yields nothing for a single-element list.
    for pair in ranked_docs.windows(2) {
        assert!(
            pair[0].relevance_score >= pair[1].relevance_score,
            "ranked documents must be ordered by non-increasing relevance_score"
        );
    }

    for doc in &ranked_docs {
        assert!(doc.ranking_factors.semantic_similarity >= 0.0);
        assert!(doc.ranking_factors.keyword_match >= 0.0);
        assert!(doc.ranking_factors.recency_score >= 0.0);
        assert!(doc.ranking_factors.authority_score >= 0.0);
    }
}
/// Checks that `augment_context` preserves the original query, produces a
/// non-empty context summary and citations, and emits exactly one citation per
/// retrieved document.
#[tokio::test]
async fn test_context_augmentation() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);
    let analyzed_query = AnalyzedQuery {
        original_query: "machine learning".to_string(),
        expanded_terms: vec!["machine".to_string(), "learning".to_string()],
        intent: QueryIntent::Factual,
        entities: vec![],
        keywords: vec!["machine".to_string(), "learning".to_string()],
        embeddings: vec![0.1, 0.2, 0.3],
        context_keywords: vec!["machine".to_string(), "learning".to_string()],
    };
    let documents = rag_engine
        .retrieve_documents(&analyzed_query)
        .await
        .unwrap();
    let ranked_docs = rag_engine
        .rank_documents(documents, &analyzed_query)
        .await
        .unwrap();

    // `expect` shows the augmentation error on failure rather than only
    // reporting that `is_ok()` was false.
    let augmented = rag_engine
        .augment_context(&analyzed_query, ranked_docs)
        .await
        .expect("augment_context should succeed for ranked documents");

    assert_eq!(augmented.original_query, analyzed_query.original_query);
    assert!(!augmented.context_summary.is_empty());
    assert!(!augmented.citations.is_empty());
    assert_eq!(
        augmented.retrieved_documents.len(),
        augmented.citations.len()
    );
}
/// Runs retrieval, ranking and augmentation, then checks that
/// `generate_response` yields non-empty content and citations, a confidence in
/// the range `(0.0, 1.0]`, and a [`ValidationStatus::Pending`] status.
#[tokio::test]
async fn test_response_generation() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);
    let analyzed_query = AnalyzedQuery {
        original_query: "machine learning".to_string(),
        expanded_terms: vec!["machine".to_string(), "learning".to_string()],
        intent: QueryIntent::Factual,
        entities: vec![],
        keywords: vec!["machine".to_string(), "learning".to_string()],
        embeddings: vec![0.1, 0.2, 0.3],
        context_keywords: vec!["machine".to_string(), "learning".to_string()],
    };
    let documents = rag_engine
        .retrieve_documents(&analyzed_query)
        .await
        .unwrap();
    let ranked_docs = rag_engine
        .rank_documents(documents, &analyzed_query)
        .await
        .unwrap();
    let augmented_context = rag_engine
        .augment_context(&analyzed_query, ranked_docs)
        .await
        .unwrap();

    // `expect` includes the generation error in the panic message on failure.
    let response = rag_engine
        .generate_response(augmented_context)
        .await
        .expect("generate_response should succeed for an augmented context");

    assert!(!response.content.is_empty());
    assert!(response.confidence > 0.0);
    assert!(response.confidence <= 1.0);
    assert!(!response.citations.is_empty());
    assert_eq!(response.validation_status, ValidationStatus::Pending);
}
/// Checks that `validate_response` accepts a hand-built response: the result
/// must be valid, report no policy violations or content issues, and carry a
/// positive confidence score.
#[tokio::test]
async fn test_response_validation() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);
    let response = GeneratedResponse {
        content: "This is a test response about machine learning.".to_string(),
        confidence: 0.8,
        citations: vec![],
        metadata: ResponseMetadata {
            generation_time: Duration::from_millis(100),
            tokens_used: 50,
            sources_consulted: 2,
            model_version: "test-v1.0".to_string(),
        },
        validation_status: ValidationStatus::Pending,
    };

    // `expect` exposes the validation error itself if the call fails.
    let validation = rag_engine
        .validate_response(&response, AgentId::new())
        .await
        .expect("validate_response should succeed for a well-formed response");

    assert!(validation.is_valid);
    assert!(validation.policy_violations.is_empty());
    assert!(validation.content_issues.is_empty());
    assert!(validation.confidence_score > 0.0);
}
/// End-to-end check of `process_query` with the standard test request: the
/// response content and follow-up suggestions must be non-empty, and both the
/// confidence score and the processing time must be positive.
#[tokio::test]
async fn test_full_rag_pipeline() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);
    let request = create_test_rag_request();

    // `expect` reports which pipeline error occurred instead of a bare
    // "assertion failed: result.is_ok()".
    let response = rag_engine
        .process_query(request)
        .await
        .expect("process_query should succeed for the standard test request");

    assert!(!response.response.content.is_empty());
    assert!(response.confidence_score > 0.0);
    assert!(response.processing_time > Duration::from_millis(0));
    assert!(!response.follow_up_suggestions.is_empty());
}
/// Checks that `extract_keywords` returns the lowercase form of every word in
/// the sample sentence that the test expects to be treated as a keyword.
#[tokio::test]
async fn test_keyword_extraction() {
    let rag_engine = StandardRAGEngine::new(create_test_context_manager().await);
    let extracted =
        rag_engine.extract_keywords("Machine learning algorithms for natural language processing");

    // Small lookup helper so each assertion names only the keyword it checks.
    let has_keyword = |word: &str| extracted.contains(&word.to_string());

    assert!(has_keyword("machine"));
    assert!(has_keyword("learning"));
    assert!(has_keyword("algorithms"));
    assert!(has_keyword("natural"));
    assert!(has_keyword("language"));
    assert!(has_keyword("processing"));
}
/// Checks that `extract_entities` finds "OpenAI" tagged as
/// [`EntityType::Concept`] and "175" tagged as [`EntityType::Number`] in the
/// sample sentence.
#[tokio::test]
async fn test_entity_extraction() {
    let rag_engine = StandardRAGEngine::new(create_test_context_manager().await);
    let found = rag_engine.extract_entities("OpenAI developed GPT-3 with 175 billion parameters");

    let has_openai_concept = found
        .iter()
        .any(|entity| entity.text == "OpenAI" && matches!(entity.entity_type, EntityType::Concept));
    assert!(has_openai_concept);

    let has_number_175 = found
        .iter()
        .any(|entity| entity.text == "175" && matches!(entity.entity_type, EntityType::Number));
    assert!(has_number_175);
}
/// Checks `calculate_semantic_similarity` on unit vectors: identical vectors
/// must score 1 and orthogonal vectors must score 0, each within 0.001.
#[tokio::test]
async fn test_semantic_similarity() {
    let rag_engine = StandardRAGEngine::new(create_test_context_manager().await);

    let x_axis = vec![1.0, 0.0, 0.0];
    let x_axis_copy = vec![1.0, 0.0, 0.0];
    let y_axis = vec![0.0, 1.0, 0.0];
    let tolerance = 0.001;

    let identical = rag_engine.calculate_semantic_similarity(&x_axis, &x_axis_copy);
    assert!((identical - 1.0).abs() < tolerance);

    let orthogonal = rag_engine.calculate_semantic_similarity(&x_axis, &y_axis);
    assert!((orthogonal - 0.0).abs() < tolerance);
}
/// Checks that `generate_mock_embeddings` returns a 384-element vector whose
/// Euclidean norm is 1 within 0.001.
#[tokio::test]
async fn test_mock_embeddings_generation() {
    let rag_engine = StandardRAGEngine::new(create_test_context_manager().await);
    let vector = rag_engine.generate_mock_embeddings("test text for embeddings");

    assert_eq!(vector.len(), 384);

    // Euclidean norm: square root of the sum of squared components.
    let sum_of_squares = vector
        .iter()
        .map(|component| component * component)
        .sum::<f32>();
    let norm: f32 = sum_of_squares.sqrt();
    assert!((norm - 1.0).abs() < 0.001);
}
/// Checks that a freshly constructed engine reports zero for the document
/// count, query count, cache hit rate and validation pass rate.
#[tokio::test]
async fn test_rag_stats() {
    let context_manager = create_test_context_manager().await;
    let rag_engine = StandardRAGEngine::new(context_manager);

    // `expect` shows the stats error on failure rather than only reporting
    // that `is_ok()` was false.
    let stats = rag_engine
        .get_stats()
        .await
        .expect("get_stats should succeed on a fresh engine");

    assert_eq!(stats.total_documents, 0);
    assert_eq!(stats.total_queries, 0);
    // Exact float comparison is kept from the original test: both rates are
    // expected to be exactly 0.0 before any query has run.
    assert_eq!(stats.cache_hit_rate, 0.0);
    assert_eq!(stats.validation_pass_rate, 0.0);
}
}