use super::*;
/// Graph-based retrieval helper that walks triples around resolved entity IRIs.
///
/// Construct it with [`GraphTraversal::new`], passing a shared store handle.
pub struct GraphTraversal {
    /// Shared handle to the backing store.
    store: Arc<dyn Store>,
}
impl GraphTraversal {
    /// Score weight for triples found by the breadth-first walk.
    const TRAVERSAL_WEIGHT: f32 = 0.8;
    /// Score weight for the entity's own `rdf:type` triples.
    const TYPE_WEIGHT: f32 = 0.9;
    /// Score weight for triples about other entities sharing a type.
    const SAME_TYPE_WEIGHT: f32 = 0.6;
    /// Score weight for property-context triples.
    const PROPERTY_WEIGHT: f32 = 0.7;
    /// Per-type limit passed to `find_entities_of_type`.
    const SAME_TYPE_LIMIT: usize = 5;
    /// How many of an entity's types are followed when looking for same-type entities.
    const MAX_TYPES_FOLLOWED: usize = 2;
    /// Full IRI of the `rdf:type` predicate.
    const RDF_TYPE: &'static str = "http://www.w3.org/1999/02/22-rdf-syntax-ns#type";

    /// Creates a traversal helper backed by `store`.
    pub fn new(store: Arc<dyn Store>) -> Self {
        Self { store }
    }

    /// Collects and ranks triples surrounding the IRIs of `entities`.
    ///
    /// Entities without an IRI are skipped, and each distinct IRI is processed
    /// once (the first entity carrying it supplies the confidence). For each IRI
    /// this gathers:
    /// * triples reached by a breadth-first walk of at most `max_depth` hops,
    ///   scored `confidence * 0.8`;
    /// * context triples from [`Self::expand_entity_context`].
    ///
    /// The combined list is de-duplicated by (subject, predicate, object). Each
    /// triple keeps its highest score, and the list is sorted by descending score.
    /// `_query` is currently unused.
    ///
    /// # Errors
    /// Propagates any error returned by `find_entity_triples` or
    /// `expand_entity_context`.
    pub async fn perform_graph_search(
        &self,
        _query: &str,
        entities: &[ExtractedEntity],
        max_depth: usize,
    ) -> Result<Vec<RagSearchResult>> {
        let mut results = Vec::new();
        // Borrow the IRIs straight from `entities`; no clone needed for the dedup set.
        let mut visited_entities: HashSet<&str> = HashSet::new();
        for entity in entities {
            let iri = match entity.iri.as_deref() {
                Some(iri) => iri,
                None => continue,
            };
            // `insert` returns false when this IRI has already been processed.
            if !visited_entities.insert(iri) {
                continue;
            }
            let entity_triples = self.find_entity_triples(iri, max_depth).await?;
            results.extend(Self::scored(
                entity_triples,
                entity.confidence * Self::TRAVERSAL_WEIGHT,
            ));
            results.extend(self.expand_entity_context(iri, entity.confidence).await?);
        }
        self.deduplicate_and_rank_graph_results(results)
    }

    /// Breadth-first walk of the graph starting at `entity_iri`.
    ///
    /// Each hop collects every triple in which a frontier entity is the subject or
    /// the object. The opposite end of each such triple becomes the next frontier.
    /// At most `max_depth` hops are taken (`max_depth == 0` yields nothing), and
    /// every entity is expanded at most once. Lookup errors are ignored so that one
    /// failing node does not abort the walk. The same triple can be returned more
    /// than once; the caller de-duplicates.
    async fn find_entity_triples(&self, entity_iri: &str, max_depth: usize) -> Result<Vec<Triple>> {
        let mut triples = Vec::new();
        let mut visited: HashSet<String> = HashSet::new();
        let mut frontier = vec![entity_iri.to_string()];
        for _depth in 0..max_depth {
            let mut next_frontier = Vec::new();
            for entity in frontier {
                if visited.contains(&entity) {
                    continue;
                }
                // Best effort: a failed lookup only skips that direction for this node.
                if let Ok(subject_triples) = self.find_triples_with_subject(&entity).await {
                    for triple in subject_triples {
                        // NOTE(review): literal objects are queued too. Lookups on them
                        // are expected to find nothing — confirm against the store.
                        let neighbour = triple.object().to_string();
                        next_frontier.push(Self::strip_iri_brackets(&neighbour).to_string());
                        triples.push(triple);
                    }
                }
                if let Ok(object_triples) = self.find_triples_with_object(&entity).await {
                    for triple in object_triples {
                        let neighbour = triple.subject().to_string();
                        next_frontier.push(Self::strip_iri_brackets(&neighbour).to_string());
                        triples.push(triple);
                    }
                }
                visited.insert(entity);
            }
            // Drop neighbours that were already expanded (including self-loops).
            next_frontier.retain(|candidate| !visited.contains(candidate));
            frontier = next_frontier;
            if frontier.is_empty() {
                break;
            }
        }
        Ok(triples)
    }

    /// Gathers scored context triples for one entity.
    ///
    /// Type triples are scored `base_confidence * 0.9`. Triples for same-type
    /// entities are scored `* 0.6`, and property context is scored `* 0.7`.
    ///
    /// # Errors
    /// Propagates errors from `find_entity_types`, `find_same_type_entities`
    /// and `find_property_context`.
    async fn expand_entity_context(
        &self,
        entity_iri: &str,
        base_confidence: f32,
    ) -> Result<Vec<RagSearchResult>> {
        let mut expanded_results = Vec::new();

        let type_triples = self.find_entity_types(entity_iri).await?;
        expanded_results.extend(Self::scored(
            type_triples,
            base_confidence * Self::TYPE_WEIGHT,
        ));

        let same_type_entities = self
            .find_same_type_entities(entity_iri, Self::SAME_TYPE_LIMIT)
            .await?;
        expanded_results.extend(Self::scored(
            same_type_entities,
            base_confidence * Self::SAME_TYPE_WEIGHT,
        ));

        let property_context = self.find_property_context(entity_iri).await?;
        expanded_results.extend(Self::scored(
            property_context,
            base_confidence * Self::PROPERTY_WEIGHT,
        ));

        Ok(expanded_results)
    }

    /// Returns the `rdf:type` triples whose subject is `entity_iri`.
    ///
    /// The predicate must equal the rdf:type IRI exactly (optionally wrapped in
    /// `<...>`), so predicates that merely contain it as a substring are not
    /// matched. Lookup errors are treated as "no types".
    async fn find_entity_types(&self, entity_iri: &str) -> Result<Vec<Triple>> {
        let type_triples: Vec<Triple> = match self.find_triples_with_subject(entity_iri).await {
            Ok(subject_triples) => subject_triples
                .into_iter()
                .filter(|triple| {
                    let predicate = triple.predicate().to_string();
                    Self::strip_iri_brackets(&predicate) == Self::RDF_TYPE
                })
                .collect(),
            Err(_) => Vec::new(),
        };
        Ok(type_triples)
    }

    /// Collects triples about entities that share a type with `entity_iri`.
    ///
    /// Only the first two types are followed, and `limit` applies per type, so up
    /// to `2 * limit` triples may be returned. Per-type lookup errors are ignored.
    ///
    /// # Errors
    /// Propagates errors from `find_entity_types`.
    async fn find_same_type_entities(&self, entity_iri: &str, limit: usize) -> Result<Vec<Triple>> {
        let entity_types = self.find_entity_types(entity_iri).await?;
        let mut same_type_triples = Vec::new();
        for type_triple in entity_types.iter().take(Self::MAX_TYPES_FOLLOWED) {
            let type_term = type_triple.object().to_string();
            if let Ok(entities_of_type) = self
                .find_entities_of_type(Self::strip_iri_brackets(&type_term), limit)
                .await
            {
                same_type_triples.extend(entities_of_type);
            }
        }
        Ok(same_type_triples)
    }

    /// Placeholder: always returns an empty list (the store is not queried yet).
    async fn find_entities_of_type(&self, type_iri: &str, limit: usize) -> Result<Vec<Triple>> {
        let _ = (type_iri, limit);
        Ok(Vec::new())
    }

    /// Placeholder: always returns an empty list (the store is not queried yet).
    async fn find_property_context(&self, entity_iri: &str) -> Result<Vec<Triple>> {
        let _ = entity_iri;
        Ok(Vec::new())
    }

    /// Placeholder: always returns an empty list (the store is not queried yet).
    async fn find_triples_with_subject(&self, entity_iri: &str) -> Result<Vec<Triple>> {
        let _ = entity_iri;
        Ok(Vec::new())
    }

    /// Placeholder: always returns an empty list (the store is not queried yet).
    async fn find_triples_with_object(&self, entity_iri: &str) -> Result<Vec<Triple>> {
        let _ = entity_iri;
        Ok(Vec::new())
    }

    /// Removes duplicate triples and sorts the rest by descending score.
    ///
    /// When a triple occurs more than once, its highest-scoring occurrence is kept.
    /// Among equal scores, the original insertion order is preserved.
    fn deduplicate_and_rank_graph_results(
        &self,
        mut results: Vec<RagSearchResult>,
    ) -> Result<Vec<RagSearchResult>> {
        // Sort first (stable, descending) so `retain` keeps the best copy of each
        // triple. `total_cmp` is a true total order even for NaN, which `sort_by`
        // requires; a comparator that maps NaN to `Equal` is not one.
        results.sort_by(|a, b| b.score.total_cmp(&a.score));
        let mut seen = HashSet::new();
        results.retain(|result| {
            // Tuple key: a ':'-joined string would be ambiguous because IRIs contain ':'.
            seen.insert((
                result.triple.subject().to_string(),
                result.triple.predicate().to_string(),
                result.triple.object().to_string(),
            ))
        });
        Ok(results)
    }

    /// Wraps each triple in a graph-traversal search result with a fixed `score`.
    fn scored(triples: Vec<Triple>, score: f32) -> impl Iterator<Item = RagSearchResult> {
        triples.into_iter().map(move |triple| RagSearchResult {
            triple,
            score,
            search_type: SearchType::GraphTraversal,
        })
    }

    /// Returns `term` without one surrounding pair of angle brackets, if present.
    ///
    /// Terms are compared through their `Display` output. For IRIs this is assumed
    /// to be either the bare IRI or N-Triples style `<iri>` — NOTE(review): confirm
    /// against the `Triple` term types. Anything not wrapped in `<...>` is returned
    /// unchanged.
    fn strip_iri_brackets(term: &str) -> &str {
        term.strip_prefix('<')
            .and_then(|inner| inner.strip_suffix('>'))
            .unwrap_or(term)
    }
}
/// An entity mention, optionally resolved to an IRI in the knowledge graph.
///
/// `GraphTraversal::perform_graph_search` only uses entities whose `iri` is set,
/// and uses `confidence` as the base for the scores it assigns.
#[derive(Clone, Debug)]
pub struct ExtractedEntity {
    /// The entity's text.
    pub text: String,
    /// Category assigned to the entity.
    pub entity_type: EntityType,
    /// Resolved graph IRI, if any; entities without one are skipped by graph search.
    pub iri: Option<String>,
    /// Confidence value; multiplied by fixed weights to score graph-search results.
    pub confidence: f32,
    /// Alternative names for the entity.
    pub aliases: Vec<String>,
}
/// Category assigned to an [`ExtractedEntity`].
///
/// Fieldless, so it is `Copy` and can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EntityType {
    Person,
    Organization,
    Location,
    Concept,
    Event,
    Other,
}
/// A relationship expressed as subject / predicate / object strings.
///
/// NOTE(review): whether these strings are IRIs or raw text is not visible
/// here — confirm against the extractor that builds them.
#[derive(Clone, Debug)]
pub struct ExtractedRelationship {
    /// Subject of the relationship.
    pub subject: String,
    /// Predicate linking subject and object.
    pub predicate: String,
    /// Object of the relationship.
    pub object: String,
    /// Confidence value for this relationship.
    pub confidence: f32,
    /// Category of the relationship.
    pub relation_type: RelationType,
}
/// Category of an [`ExtractedRelationship`].
///
/// Fieldless, so it is `Copy` and can be used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RelationType {
    CausalRelation,
    TemporalRelation,
    SpatialRelation,
    ConceptualRelation,
    Other,
}