use crate::core::traits::*;
use crate::core::{Document, Entity, EntityId, Result, TextChunk};
use std::collections::HashMap;
/// Stores one entity in a default-constructed `T` and reads it back by the returned ID.
///
/// The retrieved entity must match the original on `name`, `entity_type`, and
/// `confidence`.
///
/// # Panics
///
/// Panics if storing or retrieving returns an error, if the lookup yields `None`, or if
/// any of the compared fields differ.
pub fn test_storage_roundtrip<T>()
where
    T: Storage<Entity = Entity, Document = Document, Chunk = TextChunk> + Default,
{
    let mut backend = T::default();
    let original = Entity::new(
        EntityId::new("test_entity".to_string()),
        "Test Entity".to_string(),
        "Person".to_string(),
        0.9,
    );

    // Keep `original` around for the field comparisons below.
    let id = backend.store_entity(original.clone()).unwrap();
    let lookup = backend.retrieve_entity(&id).unwrap();
    let fetched = lookup.unwrap();

    assert_eq!(fetched.name, original.name);
    assert_eq!(fetched.entity_type, original.entity_type);
    assert_eq!(fetched.confidence, original.confidence);
}
/// Looks up an ID that was never stored in a default-constructed `T`.
///
/// Both `Ok(None)` and `Err(_)` are accepted outcomes; the only failure is the storage
/// claiming to have found an entity.
///
/// # Panics
///
/// Panics if the lookup returns `Ok(Some(_))`.
pub fn test_storage_nonexistent_key<T>()
where
    T: Storage<Entity = Entity, Document = Document, Chunk = TextChunk> + Default,
{
    let storage = T::default();

    // A miss may be reported as `Ok(None)` or as an error; neither is checked further.
    if let Ok(Some(_)) = storage.retrieve_entity("nonexistent_id") {
        panic!("Expected None for nonexistent key, got Some");
    }
}
/// Stores three entities in one batch call and verifies each can be read back.
///
/// The batch call must return exactly three IDs. Each returned ID is paired positionally
/// with the entity at the same index in the input, and the retrieved entity's `name` must
/// match that input entity's `name`.
///
/// # Panics
///
/// Panics if the batch store or any retrieval fails, if the number of returned IDs is not
/// three, if a lookup yields `None`, or if a name does not match.
pub fn test_storage_batch_operations<T>()
where
    T: Storage<Entity = Entity, Document = Document, Chunk = TextChunk> + Default,
{
    let mut storage = T::default();

    // (id, name, entity type, confidence) for each fixture entity.
    let entities: Vec<_> = vec![
        ("entity1", "Entity 1", "Person", 0.9),
        ("entity2", "Entity 2", "Organization", 0.8),
        ("entity3", "Entity 3", "Location", 0.7),
    ]
    .into_iter()
    .map(|(id, name, kind, confidence)| {
        Entity::new(
            EntityId::new(id.to_string()),
            name.to_string(),
            kind.to_string(),
            confidence,
        )
    })
    .collect();

    let stored_ids = storage.store_entities_batch(entities.clone()).unwrap();
    assert_eq!(stored_ids.len(), 3);

    for (id, expected) in stored_ids.iter().zip(&entities) {
        let fetched = storage.retrieve_entity(id).unwrap().unwrap();
        assert_eq!(fetched.name, expected.name);
    }
}
/// Stores the same entity twice and prints the two IDs the storage returned.
///
/// No relationship between the two IDs is asserted; the function only requires that both
/// store calls succeed.
///
/// # Panics
///
/// Panics if either store call returns an error.
pub fn test_storage_id_consistency<T>()
where
    T: Storage<Entity = Entity, Document = Document, Chunk = TextChunk> + Default,
{
    let mut backend = T::default();
    let duplicate = Entity::new(
        EntityId::new("consistent_id".to_string()),
        "Test Entity".to_string(),
        "Person".to_string(),
        0.9,
    );

    let id1 = backend.store_entity(duplicate.clone()).unwrap();
    // Last use of the entity, so it can be moved instead of cloned.
    let id2 = backend.store_entity(duplicate).unwrap();

    println!("First ID: {id1}, Second ID: {id2}");
}
/// Exercises add, length, emptiness, and search on a default-constructed vector store.
///
/// Adds `vec1` (with metadata) and `vec2` (without), checks the store reports two entries
/// and is not empty, then searches for `vec1`'s own vector with a limit of 1 and expects
/// `vec1` to be the first result.
///
/// # Panics
///
/// Panics if any store operation fails or any of the checks above does not hold.
pub fn test_vector_store_basic_operations<T>()
where
    T: VectorStore + Default,
{
    let mut store = T::default();
    let query = vec![1.0, 2.0, 3.0];
    let tags = HashMap::from([("type".to_string(), "test".to_string())]);

    // `query` is reused for the search below, so only this insertion needs a copy.
    store
        .add_vector("vec1".to_string(), query.clone(), Some(tags))
        .unwrap();
    store
        .add_vector("vec2".to_string(), vec![4.0, 5.0, 6.0], None)
        .unwrap();

    assert_eq!(store.len(), 2);
    assert!(!store.is_empty());

    let hits = store.search(&query, 1).unwrap();
    assert!(!hits.is_empty());
    assert_eq!(hits[0].id, "vec1");
}
/// Checks that search ranks exact duplicates of the query ahead of a different vector.
///
/// Two copies of `[1.0, 1.0, 1.0]` are stored as `identical1` and `identical2`, together
/// with one `different` vector. Searching for the query with a limit of 3 must return all
/// three entries, with the two duplicates in the first two positions (in either order).
///
/// The `different` vector is deliberately *not* a scaled copy of the query. A scaled copy
/// such as `[10.0, 10.0, 10.0]` points in the same direction, so it ties with the
/// duplicates under cosine similarity (all 1.0) and outranks them under dot product; it
/// only ranks last under Euclidean distance. `[1.0, 0.0, 0.0]` scores strictly worse than
/// the duplicates under cosine similarity (~0.58 vs 1.0), dot product (1 vs 3), and
/// Euclidean distance (~1.41 vs 0), while keeping a positive similarity score.
///
/// # Panics
///
/// Panics if any store operation fails, if fewer or more than three results come back, or
/// if the first two results are not exactly `identical1` and `identical2`.
pub fn test_vector_store_similarity_properties<T>()
where
    T: VectorStore + Default,
{
    let mut store = T::default();
    let vector = vec![1.0, 1.0, 1.0];
    store
        .add_vector("identical1".to_string(), vector.clone(), None)
        .unwrap();
    store
        .add_vector("identical2".to_string(), vector.clone(), None)
        .unwrap();

    // Differs from the query in direction, not just magnitude, so the expected ranking
    // holds regardless of which common similarity metric the store uses.
    let different_vector = vec![1.0, 0.0, 0.0];
    store
        .add_vector("different".to_string(), different_vector, None)
        .unwrap();

    let results = store.search(&vector, 3).unwrap();
    assert_eq!(results.len(), 3);

    // The two duplicates tie with each other, so compare as a set rather than by order.
    let first_two_ids: std::collections::HashSet<_> =
        results.iter().take(2).map(|r| r.id.as_str()).collect();
    assert!(first_two_ids.contains("identical1"));
    assert!(first_two_ids.contains("identical2"));
}
/// Embeds a fixed sentence and checks the basic shape and repeatability of the result.
///
/// If the default-constructed embedder reports it is not ready, a notice is printed and
/// the function returns without checking anything. Otherwise the embedding must be
/// non-empty, its length must equal `dimension()`, and embedding the same text a second
/// time must produce an equal result.
///
/// # Panics
///
/// Panics if either `embed` call fails or any of the checks above does not hold.
pub fn test_embedder_basic_functionality<T>()
where
    T: Embedder + Default,
{
    let embedder = T::default();
    if !embedder.is_ready() {
        println!("Embedder not ready, skipping test");
        return;
    }

    let sample = "This is a test sentence for embedding.";

    let first = embedder.embed(sample).unwrap();
    assert!(!first.is_empty());
    assert_eq!(first.len(), embedder.dimension());

    // Same input, same embedder: the output is expected to be identical.
    let second = embedder.embed(sample).unwrap();
    assert_eq!(first, second);
}
/// Verifies that batch embedding agrees with embedding each text on its own.
///
/// If the default-constructed embedder reports it is not ready, a notice is printed and
/// the function returns without checking anything. Otherwise three sentences are embedded
/// once via `embed_batch` and once individually via `embed`; the batch must contain one
/// embedding per input and each must equal its individually computed counterpart.
///
/// # Panics
///
/// Panics if any embedding call fails, if the batch length differs from the number of
/// inputs, or if a batch embedding differs from the individual one.
pub fn test_embedder_batch_consistency<T>()
where
    T: Embedder + Default,
{
    let embedder = T::default();
    if !embedder.is_ready() {
        println!("Embedder not ready, skipping test");
        return;
    }

    let texts = [
        "First test sentence.",
        "Second test sentence.",
        "Third test sentence.",
    ];

    let text_refs: Vec<&str> = texts.to_vec();
    let batched = embedder.embed_batch(&text_refs).unwrap();
    assert_eq!(batched.len(), texts.len());

    // Collecting into `Result` stops at the first failing `embed` call.
    let one_by_one: Result<Vec<_>> = texts.iter().map(|text| embedder.embed(text)).collect();
    let one_by_one = one_by_one.unwrap();

    batched
        .iter()
        .zip(one_by_one.iter())
        .for_each(|(from_batch, from_single)| assert_eq!(from_batch, from_single));
}
/// Runs a default-constructed extractor over a fixed sentence and expects at least one
/// entity.
///
/// The extracted entities are printed together with their count.
///
/// # Panics
///
/// Panics if extraction fails or returns no entities.
pub fn test_entity_extractor_basic_extraction<T>()
where
    T: EntityExtractor + Default,
    T::Entity: std::fmt::Debug,
{
    let extractor = T::default();
    let sentence = "John Smith works at Microsoft Corporation in Seattle.";

    let found = extractor.extract(sentence).unwrap();
    assert!(!found.is_empty());

    println!("Extracted {} entities: {:?}", found.len(), found);
}
/// Checks that the confidence threshold filters extraction results.
///
/// The same sentence is extracted twice: first with the threshold set to 0.9, then with
/// it set to 0.1. The lower threshold must yield at least as many entities as the higher
/// one, and every entity returned under the 0.9 threshold must carry a confidence of at
/// least 0.9.
///
/// # Panics
///
/// Panics if either extraction fails or any of the checks above does not hold.
pub fn test_entity_extractor_confidence<T>()
where
    T: EntityExtractor + Default,
    T::Entity: std::fmt::Debug,
{
    let mut extractor = T::default();
    let sentence = "John Smith works at Microsoft Corporation.";

    // Strict pass first, then the permissive one; the threshold is mutated in between.
    extractor.set_confidence_threshold(0.9);
    let strict = extractor.extract_with_confidence(sentence).unwrap();

    extractor.set_confidence_threshold(0.1);
    let permissive = extractor.extract_with_confidence(sentence).unwrap();

    assert!(permissive.len() >= strict.len());

    for (_, confidence) in &strict {
        assert!(
            *confidence >= 0.9,
            "High confidence entity has confidence {confidence}"
        );
    }
}
/// Requests a completion for a fixed prompt and expects a non-empty answer.
///
/// If the default-constructed model reports it is not available, a notice is printed and
/// the function returns without checking anything. Otherwise the completion is printed
/// after the non-empty check.
///
/// # Panics
///
/// Panics if the completion call fails or returns an empty result.
pub fn test_language_model_basic_completion<T>()
where
    T: LanguageModel + Default,
{
    let model = T::default();
    if !model.is_available() {
        println!("Language model not available, skipping test");
        return;
    }

    let completion = model
        .complete("Complete this sentence: The capital of France is")
        .unwrap();
    assert!(!completion.is_empty());

    println!("Completion: {completion}");
}
/// Requests a completion with explicit generation parameters and checks its size.
///
/// If the default-constructed model reports it is not available, a notice is printed and
/// the function returns without checking anything. Otherwise the model is called with
/// `max_tokens = 10`, `temperature = 0.1`, `top_p = 0.9`, and `"4"` as the only stop
/// sequence. The completion must be non-empty and contain at most 15
/// whitespace-separated words.
///
/// # Panics
///
/// Panics if the completion call fails, returns an empty result, or exceeds the word
/// bound.
pub fn test_language_model_parameters<T>()
where
    T: LanguageModel + Default,
{
    let model = T::default();
    if !model.is_available() {
        println!("Language model not available, skipping test");
        return;
    }

    let params = GenerationParams {
        max_tokens: Some(10),
        temperature: Some(0.1),
        top_p: Some(0.9),
        stop_sequences: Some(vec!["4".to_string()]),
    };

    let completion = model
        .complete_with_params("Count to three:", params)
        .unwrap();

    assert!(!completion.is_empty());
    // The bound is expressed in words, a looser limit than the 10-token cap requested.
    assert!(completion.split_whitespace().count() <= 15);
}
/// Runs the default configuration through validate, save, and load on a
/// default-constructed provider.
///
/// The configuration returned by `load` must equal the default configuration that was
/// saved.
///
/// # Panics
///
/// Panics if validation, saving, or loading fails, or if the loaded configuration differs
/// from the one that was saved.
pub fn test_config_provider_lifecycle<T>()
where
    T: ConfigProvider + Default,
    T::Config: Clone + PartialEq + std::fmt::Debug,
{
    let provider = T::default();
    let expected = provider.default_config();

    provider.validate(&expected).unwrap();
    provider.save(&expected).unwrap();

    let reloaded = provider.load().unwrap();
    assert_eq!(expected, reloaded);
}
/// Generates the storage conformance tests for a concrete type.
///
/// Expands to four `#[test]` functions. Each one calls the generic helper of the same
/// name under `$crate::core::test_traits`, instantiated with the given type.
#[macro_export]
macro_rules! test_storage_implementation {
    ($ty:ty) => {
        #[test]
        fn test_storage_roundtrip() {
            $crate::core::test_traits::test_storage_roundtrip::<$ty>();
        }

        #[test]
        fn test_storage_nonexistent_key() {
            $crate::core::test_traits::test_storage_nonexistent_key::<$ty>();
        }

        #[test]
        fn test_storage_batch_operations() {
            $crate::core::test_traits::test_storage_batch_operations::<$ty>();
        }

        #[test]
        fn test_storage_id_consistency() {
            $crate::core::test_traits::test_storage_id_consistency::<$ty>();
        }
    };
}
/// Generates the vector-store conformance tests for a concrete type.
///
/// Expands to two `#[test]` functions. Each one calls the generic helper of the same name
/// under `$crate::core::test_traits`, instantiated with the given type.
#[macro_export]
macro_rules! test_vector_store_implementation {
    ($ty:ty) => {
        #[test]
        fn test_vector_store_basic_operations() {
            $crate::core::test_traits::test_vector_store_basic_operations::<$ty>();
        }

        #[test]
        fn test_vector_store_similarity_properties() {
            $crate::core::test_traits::test_vector_store_similarity_properties::<$ty>();
        }
    };
}
/// Generates the embedder conformance tests for a concrete type.
///
/// Expands to two `#[test]` functions. Each one calls the generic helper of the same name
/// under `$crate::core::test_traits`, instantiated with the given type.
#[macro_export]
macro_rules! test_embedder_implementation {
    ($ty:ty) => {
        #[test]
        fn test_embedder_basic_functionality() {
            $crate::core::test_traits::test_embedder_basic_functionality::<$ty>();
        }

        #[test]
        fn test_embedder_batch_consistency() {
            $crate::core::test_traits::test_embedder_batch_consistency::<$ty>();
        }
    };
}
/// Generates the entity-extractor conformance tests for a concrete type.
///
/// Expands to two `#[test]` functions. Each one calls the generic helper of the same name
/// under `$crate::core::test_traits`, instantiated with the given type.
#[macro_export]
macro_rules! test_entity_extractor_implementation {
    ($ty:ty) => {
        #[test]
        fn test_entity_extractor_basic_extraction() {
            $crate::core::test_traits::test_entity_extractor_basic_extraction::<$ty>();
        }

        #[test]
        fn test_entity_extractor_confidence() {
            $crate::core::test_traits::test_entity_extractor_confidence::<$ty>();
        }
    };
}
#[cfg(test)]
mod tests {
    /// Smoke test: it only prints a message, so it passes whenever this module compiles
    /// and the test harness runs it.
    #[test]
    fn test_trait_testing_framework() {
        println!("Trait testing framework initialized successfully");
    }
}