use symbi_runtime::context::manager::{
ContextManager, ContextManagerConfig, StandardContextManager,
};
use symbi_runtime::context::types::{Knowledge, KnowledgeFact, KnowledgeId, KnowledgeSource};
use symbi_runtime::rag::engine::{RAGEngine, StandardRAGEngine};
use symbi_runtime::rag::types::*;
use symbi_runtime::types::AgentId;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
async fn create_test_context_manager() -> Arc<dyn ContextManager> {
let config = ContextManagerConfig {
enable_vector_db: false, enable_persistence: false, ..Default::default()
};
let manager = StandardContextManager::new(config, "test-agent")
.await
.expect("Failed to create context manager");
Arc::new(manager)
}
fn create_test_rag_request(agent_id: AgentId, query: &str) -> RAGRequest {
RAGRequest {
agent_id,
query: query.to_string(),
preferences: QueryPreferences {
response_length: ResponseLength::Standard,
include_citations: true,
preferred_sources: vec!["technical".to_string()],
response_format: ResponseFormat::Text,
language: "en".to_string(),
},
constraints: QueryConstraints {
max_documents: 5,
time_limit: Duration::from_millis(500), security_level: AccessLevel::Public,
allowed_sources: vec!["public".to_string()],
excluded_sources: vec![],
},
}
}
async fn populate_test_knowledge(
context_manager: &Arc<dyn ContextManager>,
agent_id: AgentId,
) -> Result<(), Box<dyn std::error::Error>> {
let facts = vec![
KnowledgeFact {
id: KnowledgeId::new(),
subject: "machine learning".to_string(),
predicate: "is".to_string(),
object: "a subset of artificial intelligence that focuses on algorithms that learn from data".to_string(),
confidence: 0.9,
source: KnowledgeSource::UserProvided,
created_at: SystemTime::now(),
verified: true,
},
KnowledgeFact {
id: KnowledgeId::new(),
subject: "neural networks".to_string(),
predicate: "are".to_string(),
object: "computing systems inspired by biological neural networks".to_string(),
confidence: 0.85,
source: KnowledgeSource::UserProvided,
created_at: SystemTime::now(),
verified: true,
},
KnowledgeFact {
id: KnowledgeId::new(),
subject: "deep learning".to_string(),
predicate: "uses".to_string(),
object: "multiple layers of neural networks to model complex patterns".to_string(),
confidence: 0.88,
source: KnowledgeSource::UserProvided,
created_at: SystemTime::now(),
verified: true,
},
];
for fact in facts {
context_manager
.add_knowledge(agent_id, Knowledge::Fact(fact))
.await?;
}
Ok(())
}
#[tokio::test]
async fn test_rag_engine_initialization() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let config = RAGConfig {
embedding_model: EmbeddingModelConfig {
model_name: "mock-model".to_string(),
model_type: EmbeddingModelType::Local,
dimension: 384,
max_tokens: 512,
batch_size: 32,
},
retrieval_config: RetrievalConfig {
max_documents: 10,
similarity_threshold: 0.7,
context_window: 2048,
enable_hybrid_search: true,
reranking_enabled: true,
},
ranking_config: RankingConfig {
ranking_algorithm: RankingAlgorithm::Hybrid,
relevance_weight: 0.4,
recency_weight: 0.2,
authority_weight: 0.2,
diversity_weight: 0.2,
},
generation_config: GenerationConfig {
max_response_length: 1000,
temperature: 0.7,
top_p: 0.9,
enable_citations: true,
response_format: ResponseFormat::Text,
},
validation_config: ValidationConfig {
enable_policy_check: true,
enable_content_filter: true,
enable_fact_check: false, confidence_threshold: 0.7,
},
};
let result = rag_engine.initialize(config).await;
assert!(result.is_ok(), "RAG engine initialization should succeed");
}
#[tokio::test]
async fn test_rag_pipeline_performance() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager.clone());
let agent_id = AgentId::new();
populate_test_knowledge(&context_manager, agent_id)
.await
.unwrap();
let request = create_test_rag_request(agent_id, "What is machine learning?");
let start_time = std::time::Instant::now();
let result = rag_engine.process_query(request).await;
let processing_time = start_time.elapsed();
assert!(result.is_ok(), "RAG query processing should succeed");
let response = result.unwrap();
assert!(
processing_time < Duration::from_millis(500),
"Processing time should be under 500ms, got: {:?}",
processing_time
);
assert!(
!response.response.content.is_empty(),
"Response content should not be empty"
);
assert!(
response.confidence_score > 0.0,
"Confidence score should be positive"
);
assert!(
!response.sources_used.is_empty(),
"Should have sources used"
);
assert!(
!response.follow_up_suggestions.is_empty(),
"Should have follow-up suggestions"
);
}
#[tokio::test]
async fn test_rag_query_analysis_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let test_queries = vec![
("How do neural networks work?", QueryIntent::Factual),
(
"What is the difference between supervised and unsupervised learning?",
QueryIntent::Factual,
),
("Explain deep learning concepts", QueryIntent::Factual),
(
"Analyze the performance of this ML model",
QueryIntent::Analytical,
),
(
"Create a new classification algorithm",
QueryIntent::Creative,
),
(
"Fix this training error in my model",
QueryIntent::Troubleshooting,
),
];
for (query, expected_intent) in test_queries {
let result = rag_engine.analyze_query(query, None).await;
assert!(
result.is_ok(),
"Query analysis should succeed for: {}",
query
);
let analyzed = result.unwrap();
assert_eq!(
analyzed.intent, expected_intent,
"Intent classification failed for: {}",
query
);
assert!(
!analyzed.keywords.is_empty(),
"Keywords should be extracted for: {}",
query
);
assert!(
!analyzed.embeddings.is_empty(),
"Embeddings should be generated for: {}",
query
);
assert_eq!(
analyzed.original_query, query,
"Original query should be preserved"
);
}
}
#[tokio::test]
async fn test_rag_document_retrieval_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager.clone());
let agent_id = AgentId::new();
populate_test_knowledge(&context_manager, agent_id)
.await
.unwrap();
let analyzed_query = AnalyzedQuery {
original_query: "machine learning algorithms".to_string(),
expanded_terms: vec![
"machine".to_string(),
"learning".to_string(),
"algorithms".to_string(),
"ML".to_string(),
],
intent: QueryIntent::Factual,
entities: vec![Entity {
text: "machine learning".to_string(),
entity_type: EntityType::Technology,
confidence: 0.9,
}],
keywords: vec![
"machine".to_string(),
"learning".to_string(),
"algorithms".to_string(),
],
embeddings: vec![0.1; 384], context_keywords: vec!["machine".to_string(), "learning".to_string()],
};
let result = rag_engine.retrieve_documents(&analyzed_query).await;
assert!(result.is_ok(), "Document retrieval should succeed");
let documents = result.unwrap();
assert!(
!documents.is_empty(),
"Should retrieve at least one document"
);
for doc in &documents {
assert!(!doc.title.is_empty(), "Document should have a title");
assert!(!doc.content.is_empty(), "Document should have content");
assert!(
!doc.embeddings.is_empty(),
"Document should have embeddings"
);
assert!(
doc.metadata.created_at <= SystemTime::now(),
"Document creation time should be valid"
);
}
}
#[tokio::test]
async fn test_rag_ranking_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let analyzed_query = AnalyzedQuery {
original_query: "neural networks".to_string(),
expanded_terms: vec![
"neural".to_string(),
"networks".to_string(),
"deep".to_string(),
],
intent: QueryIntent::Factual,
entities: vec![],
keywords: vec!["neural".to_string(), "networks".to_string()],
embeddings: vec![0.2; 384],
context_keywords: vec!["neural".to_string(), "networks".to_string()],
};
let documents = rag_engine
.retrieve_documents(&analyzed_query)
.await
.unwrap();
let result = rag_engine.rank_documents(documents, &analyzed_query).await;
assert!(result.is_ok(), "Document ranking should succeed");
let ranked_docs = result.unwrap();
assert!(!ranked_docs.is_empty(), "Should have ranked documents");
for i in 1..ranked_docs.len() {
assert!(
ranked_docs[i - 1].relevance_score >= ranked_docs[i].relevance_score,
"Documents should be sorted by relevance score"
);
}
for doc in &ranked_docs {
assert!(
doc.ranking_factors.semantic_similarity >= 0.0
&& doc.ranking_factors.semantic_similarity <= 1.0,
"Semantic similarity should be between 0 and 1"
);
assert!(
doc.ranking_factors.keyword_match >= 0.0 && doc.ranking_factors.keyword_match <= 1.0,
"Keyword match should be between 0 and 1"
);
assert!(
doc.ranking_factors.recency_score >= 0.0 && doc.ranking_factors.recency_score <= 1.0,
"Recency score should be between 0 and 1"
);
assert!(
doc.ranking_factors.authority_score >= 0.0
&& doc.ranking_factors.authority_score <= 1.0,
"Authority score should be between 0 and 1"
);
}
}
#[tokio::test]
async fn test_rag_context_augmentation_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let analyzed_query = AnalyzedQuery {
original_query: "deep learning".to_string(),
expanded_terms: vec![
"deep".to_string(),
"learning".to_string(),
"neural".to_string(),
],
intent: QueryIntent::Factual,
entities: vec![],
keywords: vec!["deep".to_string(), "learning".to_string()],
embeddings: vec![0.3; 384],
context_keywords: vec!["deep".to_string(), "learning".to_string()],
};
let documents = rag_engine
.retrieve_documents(&analyzed_query)
.await
.unwrap();
let ranked_docs = rag_engine
.rank_documents(documents, &analyzed_query)
.await
.unwrap();
let result = rag_engine
.augment_context(&analyzed_query, ranked_docs.clone())
.await;
assert!(result.is_ok(), "Context augmentation should succeed");
let augmented = result.unwrap();
assert_eq!(augmented.original_query, analyzed_query.original_query);
assert!(
!augmented.context_summary.is_empty(),
"Context summary should not be empty"
);
assert_eq!(
augmented.citations.len(),
ranked_docs.len(),
"Should have citations for all documents"
);
assert_eq!(
augmented.retrieved_documents.len(),
ranked_docs.len(),
"Should preserve all ranked documents"
);
for citation in &augmented.citations {
assert!(!citation.title.is_empty(), "Citation should have a title");
assert!(
citation.relevance_score >= 0.0,
"Citation should have a valid relevance score"
);
}
}
#[tokio::test]
async fn test_rag_response_generation_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let analyzed_query = AnalyzedQuery {
original_query: "What are neural networks?".to_string(),
expanded_terms: vec!["neural".to_string(), "networks".to_string()],
intent: QueryIntent::Factual,
entities: vec![],
keywords: vec!["neural".to_string(), "networks".to_string()],
embeddings: vec![0.4; 384],
context_keywords: vec!["neural".to_string(), "networks".to_string()],
};
let documents = rag_engine
.retrieve_documents(&analyzed_query)
.await
.unwrap();
let ranked_docs = rag_engine
.rank_documents(documents, &analyzed_query)
.await
.unwrap();
let augmented_context = rag_engine
.augment_context(&analyzed_query, ranked_docs)
.await
.unwrap();
let result = rag_engine
.generate_response(augmented_context.clone())
.await;
assert!(result.is_ok(), "Response generation should succeed");
let response = result.unwrap();
assert!(
!response.content.is_empty(),
"Generated response should have content"
);
assert!(
response.confidence > 0.0 && response.confidence <= 1.0,
"Confidence should be between 0 and 1"
);
assert_eq!(
response.citations.len(),
augmented_context.citations.len(),
"Should preserve citations"
);
assert!(
response.metadata.tokens_used > 0,
"Should report token usage"
);
assert!(
response.metadata.sources_consulted > 0,
"Should report sources consulted"
);
assert_eq!(
response.validation_status,
ValidationStatus::Pending,
"Initial validation status should be pending"
);
}
#[tokio::test]
async fn test_rag_response_validation_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let test_response = GeneratedResponse {
content: "Neural networks are computing systems inspired by biological neural networks that constitute animal brains.".to_string(),
confidence: 0.85,
citations: vec![
Citation {
document_id: DocumentId::new(),
title: "Introduction to Neural Networks".to_string(),
author: Some("AI Expert".to_string()),
url: Some("https://example.com/neural-networks".to_string()),
relevance_score: 0.9,
}
],
metadata: ResponseMetadata {
generation_time: Duration::from_millis(150),
tokens_used: 75,
sources_consulted: 1,
model_version: "test-v1.0".to_string(),
},
validation_status: ValidationStatus::Pending,
};
let agent_id = AgentId::new();
let result = rag_engine.validate_response(&test_response, agent_id).await;
assert!(result.is_ok(), "Response validation should succeed");
let validation = result.unwrap();
assert!(validation.is_valid, "Response should be valid");
assert!(
validation.policy_violations.is_empty(),
"Should have no policy violations"
);
assert!(
validation.content_issues.is_empty(),
"Should have no content issues"
);
assert!(
validation.confidence_score > 0.0,
"Validation confidence should be positive"
);
}
#[tokio::test]
async fn test_rag_end_to_end_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager.clone());
let agent_id = AgentId::new();
populate_test_knowledge(&context_manager, agent_id)
.await
.unwrap();
let test_queries = vec![
"What is machine learning?",
"How do neural networks work?",
"Explain deep learning",
];
for query in test_queries {
let request = create_test_rag_request(agent_id, query);
let start_time = std::time::Instant::now();
let result = rag_engine.process_query(request).await;
let processing_time = start_time.elapsed();
assert!(
result.is_ok(),
"End-to-end RAG processing should succeed for query: {}",
query
);
let response = result.unwrap();
assert!(
processing_time < Duration::from_millis(500),
"Processing should be under 500ms for query: {}, got: {:?}",
query,
processing_time
);
assert!(
!response.response.content.is_empty(),
"Response should have content for query: {}",
query
);
assert!(
response.confidence_score > 0.0,
"Should have positive confidence for query: {}",
query
);
assert!(
response.processing_time > Duration::from_millis(0),
"Should report processing time for query: {}",
query
);
if !response.sources_used.is_empty() {
for citation in &response.sources_used {
assert!(
!citation.title.is_empty(),
"Citation should have title for query: {}",
query
);
assert!(
citation.relevance_score >= 0.0,
"Citation should have valid relevance score for query: {}",
query
);
}
}
assert!(
!response.follow_up_suggestions.is_empty(),
"Should have follow-up suggestions for query: {}",
query
);
}
}
#[tokio::test]
async fn test_rag_stats_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let result = rag_engine.get_stats().await;
assert!(result.is_ok(), "Getting RAG stats should succeed");
let stats = result.unwrap();
assert!(
stats.total_documents == stats.total_documents,
"Total documents should be accessible"
);
assert!(
stats.total_queries == stats.total_queries,
"Total queries should be accessible"
);
assert!(
stats.cache_hit_rate >= 0.0 && stats.cache_hit_rate <= 1.0,
"Cache hit rate should be between 0 and 1"
);
assert!(
stats.validation_pass_rate >= 0.0 && stats.validation_pass_rate <= 1.0,
"Validation pass rate should be between 0 and 1"
);
}
#[tokio::test]
async fn test_rag_error_handling_integration() {
let context_manager = create_test_context_manager().await;
let rag_engine = StandardRAGEngine::new(context_manager);
let request = RAGRequest {
agent_id: AgentId::new(),
query: "Test query".to_string(),
preferences: QueryPreferences {
response_length: ResponseLength::Standard,
include_citations: true,
preferred_sources: vec![],
response_format: ResponseFormat::Text,
language: "en".to_string(),
},
constraints: QueryConstraints {
max_documents: 5,
time_limit: Duration::from_nanos(1), security_level: AccessLevel::Public,
allowed_sources: vec![],
excluded_sources: vec![],
},
};
let result = rag_engine.process_query(request).await;
if result.is_err() {
match result.unwrap_err() {
RAGError::Timeout(_) => {
}
other => {
println!("Got error (acceptable): {:?}", other);
}
}
} else {
println!("Query succeeded despite short timeout (acceptable for mock)");
}
}