pub mod extractor;
pub mod memory;
pub mod traits;
pub mod traversal;
pub mod types;
pub use extractor::{
MockEntityExtractor, MockRelationshipExtractor, PatternEntityExtractor,
PatternRelationshipExtractor,
};
pub use memory::InMemoryGraphStore;
pub use traits::{EntityExtractor, Graph, GraphStore, RelationshipExtractor};
pub use traversal::{bfs_traverse, find_entities_within_hops, find_shortest_path};
pub use types::{
Direction, EntityId, EntityType, GraphEntity, GraphPath, GraphQuery, GraphRelationship,
HybridSearchResult, RelationshipType,
};
use async_trait::async_trait;
use crate::error::GraphError;
use crate::types::Document;
/// A graph indexing/query component assembled from three parts: an entity
/// extractor, a relationship extractor, and the graph store both write into.
pub struct GraphLayer<E, R, S>
where
    E: EntityExtractor,
    R: RelationshipExtractor,
    S: GraphStore,
{
    /// Pulls entities out of document text.
    entity_extractor: E,
    /// Pulls relationships out of document text, given the stored entities.
    relationship_extractor: R,
    /// Backing storage for entities and relationships.
    store: S,
}
impl<E, R, S> GraphLayer<E, R, S>
where
    E: EntityExtractor,
    R: RelationshipExtractor,
    S: GraphStore,
{
    /// Builds a layer directly from its three components.
    #[must_use]
    pub fn new(entity_extractor: E, relationship_extractor: R, store: S) -> Self {
        Self {
            store,
            relationship_extractor,
            entity_extractor,
        }
    }

    /// Shared access to the entity extractor.
    #[must_use]
    pub fn entity_extractor(&self) -> &E {
        let extractor = &self.entity_extractor;
        extractor
    }

    /// Shared access to the relationship extractor.
    #[must_use]
    pub fn relationship_extractor(&self) -> &R {
        let extractor = &self.relationship_extractor;
        extractor
    }

    /// Shared access to the underlying graph store.
    #[must_use]
    pub fn store(&self) -> &S {
        let backing = &self.store;
        backing
    }

    /// Exclusive access to the underlying graph store, e.g. to insert data
    /// without going through the extractors.
    pub fn store_mut(&mut self) -> &mut S {
        let backing = &mut self.store;
        backing
    }
}
#[async_trait]
impl<E: EntityExtractor, R: RelationshipExtractor, S: GraphStore> Graph for GraphLayer<E, R, S> {
    /// Extracts entities and relationships from `document` and writes them to the store.
    ///
    /// Every extracted entity gets its `source_doc_id` set to `document.id` before
    /// insertion. Relationships are then extracted against the entities read back
    /// from the store (by the ids `add_entities` returned), not against the
    /// pre-insert values; ids for which `get_entity` yields `None` are skipped.
    ///
    /// # Errors
    ///
    /// Propagates the first error from either extractor or any store call. No
    /// rollback is attempted, so entities added before a later failure stay stored.
    async fn index_document(&mut self, document: &Document) -> Result<(), GraphError> {
        let mut entities = self
            .entity_extractor
            .extract_entities(&document.content)
            .await?;
        for entity in &mut entities {
            entity.source_doc_id = Some(document.id.clone());
        }

        // `entities` is not used after this call, so move it into the store
        // instead of cloning the whole vector.
        let entity_ids = self.store.add_entities(entities).await?;

        let mut stored_entities: Vec<GraphEntity> = Vec::with_capacity(entity_ids.len());
        for id in &entity_ids {
            if let Some(entity) = self.store.get_entity(id).await? {
                stored_entities.push(entity);
            }
        }

        let relationships = self
            .relationship_extractor
            .extract_relationships(&document.content, &stored_entities)
            .await?;
        self.store.add_relationships(relationships).await?;
        Ok(())
    }

    /// Indexes `documents` one after another via [`Graph::index_document`].
    ///
    /// # Errors
    ///
    /// Stops at the first failing document and returns its error; documents
    /// indexed before it are not rolled back by this method.
    async fn index_documents(&mut self, documents: &[Document]) -> Result<(), GraphError> {
        for document in documents {
            self.index_document(document).await?;
        }
        Ok(())
    }

    /// Runs `query` directly against the store's `traverse`.
    async fn query(&self, query: &GraphQuery) -> Result<Vec<GraphPath>, GraphError> {
        self.store.traverse(query).await
    }

    /// Traverses up to `max_hops` from every entity the store reports for
    /// `entity_name` (via `find_entities_by_name`).
    ///
    /// Returns an empty list, without traversing, when no entity matches.
    async fn find_related(
        &self,
        entity_name: &str,
        max_hops: usize,
    ) -> Result<Vec<GraphPath>, GraphError> {
        let entities = self.store.find_entities_by_name(entity_name).await?;
        if entities.is_empty() {
            return Ok(Vec::new());
        }
        let start_ids: Vec<EntityId> = entities.into_iter().map(|e| e.id).collect();
        let query = GraphQuery::new(start_ids).with_max_hops(max_hops);
        self.store.traverse(&query).await
    }

    /// Number of entities reported by the store.
    async fn entity_count(&self) -> usize {
        self.store.entity_count().await
    }

    /// Number of relationships reported by the store.
    async fn relationship_count(&self) -> usize {
        self.store.relationship_count().await
    }

    /// Clears the store; the extractors are left untouched.
    async fn clear(&mut self) -> Result<(), GraphError> {
        self.store.clear().await
    }
}
/// Step-by-step builder for [`GraphLayer`].
///
/// All three components start unset; `build` reports a configuration error
/// for the first one that is still missing.
pub struct GraphLayerBuilder<E, R, S> {
    entity_extractor: Option<E>,
    relationship_extractor: Option<R>,
    store: Option<S>,
}

// Implemented by hand instead of `#[derive(Default)]`: the derive would add
// `E: Default, R: Default, S: Default` bounds, yet an all-`None` builder needs
// none of them, so the derived impl was unusable for most component types.
impl<E, R, S> Default for GraphLayerBuilder<E, R, S> {
    fn default() -> Self {
        Self {
            entity_extractor: None,
            relationship_extractor: None,
            store: None,
        }
    }
}
impl<E, R, S> GraphLayerBuilder<E, R, S>
where
    E: EntityExtractor,
    R: RelationshipExtractor,
    S: GraphStore,
{
    /// Starts a builder in which no component has been supplied yet.
    #[must_use]
    pub fn new() -> Self {
        Self {
            store: None,
            relationship_extractor: None,
            entity_extractor: None,
        }
    }

    /// Sets (or replaces) the entity extractor.
    #[must_use]
    pub fn with_entity_extractor(self, extractor: E) -> Self {
        Self {
            entity_extractor: Some(extractor),
            ..self
        }
    }

    /// Sets (or replaces) the relationship extractor.
    #[must_use]
    pub fn with_relationship_extractor(self, extractor: R) -> Self {
        Self {
            relationship_extractor: Some(extractor),
            ..self
        }
    }

    /// Sets (or replaces) the graph store.
    #[must_use]
    pub fn with_store(self, store: S) -> Self {
        Self {
            store: Some(store),
            ..self
        }
    }

    /// Consumes the builder and assembles a [`GraphLayer`].
    ///
    /// # Errors
    ///
    /// Returns `GraphError::ConfigurationError` for the first missing component,
    /// checked in this order: entity extractor, relationship extractor, store.
    pub fn build(self) -> Result<GraphLayer<E, R, S>, GraphError> {
        let Self {
            entity_extractor,
            relationship_extractor,
            store,
        } = self;

        let entity_extractor = entity_extractor.ok_or_else(|| {
            GraphError::ConfigurationError(String::from("Entity extractor not configured"))
        })?;
        let relationship_extractor = relationship_extractor.ok_or_else(|| {
            GraphError::ConfigurationError(String::from("Relationship extractor not configured"))
        })?;
        let store = store.ok_or_else(|| {
            GraphError::ConfigurationError(String::from("Graph store not configured"))
        })?;

        Ok(GraphLayer::new(entity_extractor, relationship_extractor, store))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Layer with empty mock extractors and an empty in-memory store.
    fn create_test_graph_layer()
    -> GraphLayer<MockEntityExtractor, MockRelationshipExtractor, InMemoryGraphStore> {
        GraphLayer::new(
            MockEntityExtractor::new(),
            MockRelationshipExtractor::new(),
            InMemoryGraphStore::new(),
        )
    }

    #[tokio::test]
    async fn test_graph_layer_creation() {
        let layer = create_test_graph_layer();
        assert_eq!(layer.entity_count().await, 0);
        assert_eq!(layer.relationship_count().await, 0);
    }

    #[tokio::test]
    async fn test_graph_layer_with_mock_extractors() {
        let entities = vec![
            GraphEntity::new("Rust", EntityType::Technology).with_id("rust"),
            GraphEntity::new("LLVM", EntityType::Technology).with_id("llvm"),
        ];
        let relationships = vec![GraphRelationship::new(
            "rust",
            "llvm",
            RelationshipType::Uses,
        )];
        let mut layer = GraphLayer::new(
            MockEntityExtractor::with_entities(entities),
            MockRelationshipExtractor::with_relationships(relationships),
            InMemoryGraphStore::new(),
        );
        let doc = Document::new("Rust uses LLVM for code generation.");
        layer.index_document(&doc).await.unwrap();
        assert_eq!(layer.entity_count().await, 2);
        assert_eq!(layer.relationship_count().await, 1);
    }

    #[tokio::test]
    async fn test_index_document_sets_source_doc_id() {
        let entities = vec![GraphEntity::new("Rust", EntityType::Technology).with_id("rust")];
        let mut layer = GraphLayer::new(
            MockEntityExtractor::with_entities(entities),
            MockRelationshipExtractor::new(),
            InMemoryGraphStore::new(),
        );
        let doc = Document::new("Rust is a language.");
        layer.index_document(&doc).await.unwrap();
        let stored = layer
            .store()
            .get_entity(&"rust".to_string())
            .await
            .unwrap()
            .expect("indexed entity should be stored");
        assert_eq!(stored.source_doc_id, Some(doc.id.clone()));
    }

    #[tokio::test]
    async fn test_graph_layer_query() {
        let entities = vec![
            GraphEntity::new("A", EntityType::Concept).with_id("a"),
            GraphEntity::new("B", EntityType::Concept).with_id("b"),
        ];
        let relationships = vec![GraphRelationship::new(
            "a",
            "b",
            RelationshipType::RelatedTo,
        )];
        let mut layer = GraphLayer::new(
            MockEntityExtractor::with_entities(entities),
            MockRelationshipExtractor::with_relationships(relationships),
            InMemoryGraphStore::new(),
        );
        let doc = Document::new("A relates to B.");
        layer.index_document(&doc).await.unwrap();
        let query = GraphQuery::new(vec!["a".to_string()]).with_max_hops(1);
        let paths = layer.query(&query).await.unwrap();
        assert!(!paths.is_empty());
    }

    #[tokio::test]
    async fn test_graph_layer_find_related() {
        let entities = vec![
            GraphEntity::new("Rust", EntityType::Technology).with_id("rust"),
            GraphEntity::new("Cargo", EntityType::Technology).with_id("cargo"),
        ];
        let relationships = vec![GraphRelationship::new(
            "rust",
            "cargo",
            RelationshipType::Uses,
        )];
        let mut layer = GraphLayer::new(
            MockEntityExtractor::with_entities(entities),
            MockRelationshipExtractor::with_relationships(relationships),
            InMemoryGraphStore::new(),
        );
        let doc = Document::new("Rust uses Cargo.");
        layer.index_document(&doc).await.unwrap();
        let paths = layer.find_related("Rust", 2).await.unwrap();
        assert!(!paths.is_empty());
    }

    #[tokio::test]
    async fn test_find_related_unknown_entity_returns_empty() {
        // An empty store has no matching entity, so `find_related` must take
        // its early-return path and yield no paths.
        let layer = create_test_graph_layer();
        let paths = layer.find_related("Nonexistent", 3).await.unwrap();
        assert!(paths.is_empty());
    }

    #[tokio::test]
    async fn test_graph_layer_clear() {
        let mut layer = create_test_graph_layer();
        layer
            .store_mut()
            .add_entity(GraphEntity::new("Test", EntityType::Concept))
            .await
            .unwrap();
        assert_eq!(layer.entity_count().await, 1);
        layer.clear().await.unwrap();
        assert_eq!(layer.entity_count().await, 0);
    }

    #[tokio::test]
    async fn test_graph_layer_builder() {
        let layer = GraphLayerBuilder::new()
            .with_entity_extractor(MockEntityExtractor::new())
            .with_relationship_extractor(MockRelationshipExtractor::new())
            .with_store(InMemoryGraphStore::new())
            .build()
            .unwrap();
        assert_eq!(layer.entity_count().await, 0);
    }

    // Nothing here is awaited, so a plain `#[test]` suffices.
    #[test]
    fn test_graph_layer_builder_missing_component() {
        let result = GraphLayerBuilder::<
            MockEntityExtractor,
            MockRelationshipExtractor,
            InMemoryGraphStore,
        >::new()
        .with_entity_extractor(MockEntityExtractor::new())
        .build();
        // The entity extractor is set, so the first missing component reported
        // must be the relationship extractor.
        assert!(matches!(
            &result,
            Err(GraphError::ConfigurationError(msg))
                if msg == "Relationship extractor not configured"
        ));
    }

    #[tokio::test]
    async fn test_index_multiple_documents() {
        let mut extractor = MockEntityExtractor::new();
        extractor.add_entity(GraphEntity::new("Entity1", EntityType::Concept).with_id("e1"));
        extractor.add_entity(GraphEntity::new("Entity2", EntityType::Concept).with_id("e2"));
        let mut layer = GraphLayer::new(
            extractor,
            MockRelationshipExtractor::new(),
            InMemoryGraphStore::new(),
        );
        let docs = vec![
            Document::new("First document"),
            Document::new("Second document"),
        ];
        layer.index_documents(&docs).await.unwrap();
        assert_eq!(layer.entity_count().await, 2);
    }

    #[tokio::test]
    async fn test_pattern_based_extraction() {
        let mut layer = GraphLayer::new(
            PatternEntityExtractor::new(),
            PatternRelationshipExtractor::new(),
            InMemoryGraphStore::new(),
        );
        let doc = Document::new("Rust is a programming language. Rust uses LLVM for compilation.");
        layer.index_document(&doc).await.unwrap();
        assert!(layer.entity_count().await > 0);
    }
}