use oxirs_chat::rag::{RAGConfig, RAGSystem, SimpleEmbeddingModel};
use oxirs_core::{ConcreteStore, Literal, NamedNode, Triple};
use oxirs_vec::{
index::AdvancedVectorIndex,
index::{DistanceMetric, IndexConfig, IndexType},
};
use std::collections::HashMap;
use std::sync::Arc;
#[tokio::test]
#[ignore] async fn test_vector_index_creation() {
let dimension = 128;
let config = IndexConfig {
index_type: IndexType::Hnsw,
max_connections: 16,
ef_construction: 200,
ef_search: 100,
distance_metric: DistanceMetric::Cosine,
..Default::default()
};
let mut vector_index = AdvancedVectorIndex::new(config);
assert_eq!(vector_index.len(), 0);
assert!(vector_index.is_empty());
let subject = NamedNode::new("http://example.org/person/alice").unwrap();
let predicate = NamedNode::new("http://xmlns.com/foaf/0.1/name").unwrap();
let object = Literal::new_simple_literal("Alice Smith");
let triple = Triple::new(subject, predicate, object);
let vector = vec![0.1; dimension];
let metadata = HashMap::new();
let result = vector_index.add("test_id".to_string(), vector.clone(), triple, metadata);
assert!(result.is_ok());
assert_eq!(vector_index.len(), 1);
assert!(!vector_index.is_empty());
let search_results = vector_index.search(&vector, 1).unwrap();
assert_eq!(search_results.len(), 1);
assert!(search_results[0].score > 0.9);
println!("✅ Vector index creation and search test passed!");
}
#[tokio::test]
async fn test_embedding_model() {
let dimension = 64;
let embedding_model = SimpleEmbeddingModel::new(dimension);
let texts = vec![
"Alice works for ACME Corporation".to_string(),
"Bob is a software engineer".to_string(),
"The quick brown fox jumps over the lazy dog".to_string(),
];
let mut embeddings = Vec::new();
for text in &texts {
let embedding = embedding_model.embed(text).unwrap();
embeddings.push(embedding);
}
assert_eq!(embeddings.len(), 3);
for embedding in &embeddings {
assert_eq!(embedding.len(), dimension);
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(norm - 1.0).abs() < 0.1,
"Embedding should be approximately normalized, got norm: {norm}"
);
}
assert_ne!(embeddings[0], embeddings[1]);
assert_ne!(embeddings[1], embeddings[2]);
println!("✅ Embedding model test passed!");
}
#[tokio::test]
#[ignore] async fn test_vector_search_similarity() {
let dimension = 128;
let config = IndexConfig {
index_type: IndexType::Hnsw,
max_connections: 16,
ef_construction: 200,
ef_search: 100,
distance_metric: DistanceMetric::Cosine,
..Default::default()
};
let mut vector_index = AdvancedVectorIndex::new(config);
let triples_and_texts = [
(
"Alice works for ACME",
create_test_triple("alice", "worksFor", "acme"),
),
(
"Bob is a software engineer",
create_test_triple("bob", "hasJob", "engineer"),
),
(
"ACME is a technology company",
create_test_triple("acme", "type", "company"),
),
(
"Alice develops software",
create_test_triple("alice", "activity", "development"),
),
];
let embedding_model = SimpleEmbeddingModel::new(dimension);
for (i, (text, triple)) in triples_and_texts.iter().enumerate() {
let embedding = embedding_model.embed(text).unwrap();
let metadata = HashMap::new();
vector_index
.add(
format!("id_{i}"),
embedding.clone(),
triple.clone(),
metadata,
)
.unwrap();
}
let query = "employment job work";
let query_embedding = embedding_model.embed(query).unwrap();
let results = vector_index.search(&query_embedding, 3).unwrap();
assert!(!results.is_empty());
assert!(results.len() <= 3);
for i in 1..results.len() {
assert!(results[i - 1].score >= results[i].score);
}
println!("✅ Vector search similarity test passed!");
}
#[tokio::test]
async fn test_rag_system_with_vector_index() {
let store = ConcreteStore::new().expect("Failed to create store");
let test_triples = vec![
create_test_triple("alice", "worksFor", "acme"),
create_test_triple("bob", "worksFor", "techcorp"),
create_test_triple("alice", "hasSkill", "programming"),
create_test_triple("acme", "type", "company"),
create_test_triple("techcorp", "type", "company"),
];
for triple in &test_triples {
let quad = oxirs_core::Quad::new(
triple.subject().clone(),
triple.predicate().clone(),
triple.object().clone(),
oxirs_core::GraphName::DefaultGraph,
);
ConcreteStore::insert_quad(&store, quad).expect("Failed to insert quad");
}
let store_arc = Arc::new(store);
let config = RAGConfig::default();
let rag_system = RAGSystem::with_vector_index(config, store_arc, 64).await;
match rag_system {
Ok(mut system) => {
println!("✅ RAG system with vector index created successfully!");
let query = "Who works for companies?";
match system.retrieve(query).await {
Ok(knowledge) => {
let triple_count = knowledge
.retrieved_triples
.as_ref()
.map(|t| t.len())
.unwrap_or(0);
println!("✅ Knowledge retrieval succeeded, found {triple_count} triples");
}
Err(e) => {
println!("⚠️ Knowledge retrieval failed: {e}");
}
}
}
Err(e) => {
println!("⚠️ RAG system creation failed: {e}");
}
}
}
#[tokio::test]
#[ignore] async fn test_cosine_similarity() {
let _dimension = 4;
let config = IndexConfig {
index_type: IndexType::Hnsw,
max_connections: 16,
ef_construction: 200,
ef_search: 100,
distance_metric: DistanceMetric::Cosine,
..Default::default()
};
let mut index = AdvancedVectorIndex::new(config);
let vec1 = vec![1.0, 0.0, 0.0, 0.0];
let vec2 = vec![0.0, 1.0, 0.0, 0.0];
let vec3 = vec![1.0, 0.0, 0.0, 0.0];
let triple = create_test_triple("test", "test", "test");
let metadata = HashMap::new();
index
.add(
"id1".to_string(),
vec1.clone(),
triple.clone(),
metadata.clone(),
)
.unwrap();
index
.add(
"id2".to_string(),
vec2.clone(),
triple.clone(),
metadata.clone(),
)
.unwrap();
index
.add("id3".to_string(), vec3.clone(), triple.clone(), metadata)
.unwrap();
let results = index.search(&vec1, 3).unwrap();
assert_eq!(results.len(), 3);
assert!(results[0].score > 0.9);
let vec2_result = results.iter().find(|r| {
r.score < 0.1
});
assert!(vec2_result.is_some());
println!("✅ Cosine similarity test passed!");
}
fn create_test_triple(subject: &str, predicate: &str, object: &str) -> Triple {
let s = NamedNode::new(format!("http://example.org/{subject}")).unwrap();
let p = NamedNode::new(format!("http://example.org/{predicate}")).unwrap();
let o = if object.chars().all(|c| c.is_alphabetic()) {
oxirs_core::model::Object::from(
NamedNode::new(format!("http://example.org/{object}")).unwrap(),
)
} else {
oxirs_core::model::Object::from(Literal::new_simple_literal(object))
};
Triple::new(s, p, o)
}