use crate::VectorStore;
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Configuration for [`GraphAwareSearch`]: feature flags plus the graph hierarchy and
/// per-graph weights used when expanding and scoring searches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAwareConfig {
    /// NOTE(review): not read by any `GraphAwareSearch` method in this module — confirm
    /// where (or whether) it is consumed.
    pub enable_graph_filtering: bool,
    /// NOTE(review): not read by any `GraphAwareSearch` method in this module; scope
    /// expansion is driven by `GraphContext::scope` instead — confirm intended use.
    pub enable_hierarchical_search: bool,
    /// Gates `GraphAwareSearch::cross_graph_similarity` (which returns an error when this
    /// is `false`) and, when ranking search results, enables a cap of at most three hits
    /// per source graph.
    pub enable_cross_graph_similarity: bool,
    /// NOTE(review): not read by any `GraphAwareSearch` method in this module.
    pub default_graph: Option<String>,
    /// Parent/child relations, graph type labels and graph weights used for scope
    /// expansion and scoring.
    pub graph_hierarchy: GraphHierarchy,
    /// NOTE(review): not read by any `GraphAwareSearch` method in this module; metadata
    /// is always stored by `update_graph_metadata` regardless of this flag.
    pub cache_graph_metadata: bool,
}
impl Default for GraphAwareConfig {
fn default() -> Self {
Self {
enable_graph_filtering: true,
enable_hierarchical_search: false,
enable_cross_graph_similarity: false,
default_graph: None,
graph_hierarchy: GraphHierarchy::default(),
cache_graph_metadata: true,
}
}
}
/// Static description of how named graphs relate to each other.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GraphHierarchy {
    /// Maps a parent graph identifier to the identifiers of its direct children.
    pub parent_child: HashMap<String, Vec<String>>,
    /// Maps a graph identifier to a type label. Two graphs with equal labels get a
    /// relationship factor of 0.8 in cross-graph similarity.
    pub graph_types: HashMap<String, String>,
    /// Maps a graph identifier to a score multiplier; graphs without an entry behave as
    /// if their weight were 1.0.
    pub graph_weights: HashMap<String, f32>,
}
/// Per-query description of which graphs to search and how to weight their hits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphContext {
    /// Graph that is always searched; its hits receive a 1.1x boost when ranking.
    pub primary_graph: String,
    /// Extra graphs that are searched regardless of `scope`.
    pub additional_graphs: Vec<String>,
    /// How far `primary_graph` is expanded through the configured hierarchy.
    pub scope: GraphSearchScope,
    /// Per-graph multipliers applied to `final_score` when ranking; graphs without an
    /// entry are left unchanged.
    pub context_weights: HashMap<String, f32>,
}
/// How far [`GraphAwareSearch`] expands the primary graph through the configured
/// parent/child hierarchy when deciding which graphs to search.
///
/// `GraphContext::additional_graphs` are searched in every scope. `Eq` and `Hash` are
/// derived so a scope can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum GraphSearchScope {
    /// Only the primary graph.
    Exact,
    /// The primary graph plus its direct children (one level down).
    IncludeChildren,
    /// The primary graph plus its direct parents (one level up).
    IncludeParents,
    /// The primary graph plus all of its descendants and all of its ancestors.
    FullHierarchy,
    /// The primary graph plus the context's additional graphs.
    /// NOTE(review): additional graphs are included for every scope, so this currently
    /// resolves exactly like `Exact`.
    Related,
}
/// A single hit produced by [`GraphAwareSearch`] searches.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphAwareSearchResult {
    /// Identifier of the matched resource, as returned by the vector store.
    pub resource: String,
    /// Raw similarity score reported by the vector store.
    pub score: f32,
    /// Graph in which the hit was found.
    pub source_graph: String,
    /// Graph-context score of `resource` within `source_graph` (clamped to at most 1.0).
    pub context_score: f32,
    /// Score used for ranking: similarity blended with the context score, then re-weighted
    /// by graph and per-query weights.
    pub final_score: f32,
    /// Extra key/value data; `search_single_graph` always leaves it empty.
    pub metadata: HashMap<String, String>,
}
/// Graph membership recorded for one resource by
/// `GraphAwareSearch::register_resource_graph`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceGraphInfo {
    /// The resource identifier (also the key this record is stored under).
    pub resource: String,
    /// Every graph the resource was registered in.
    pub graphs: HashSet<String>,
    /// First graph in the registration list, or `None` if that list was empty.
    pub primary_graph: Option<String>,
    /// Wall-clock time at which the registration was recorded.
    pub last_updated: std::time::SystemTime,
}
/// Graph-aware layer over a `VectorStore`: tracks which named graphs each resource
/// belongs to and uses that to restrict, score and rank similarity-search results.
pub struct GraphAwareSearch {
    /// Hierarchy, weights and feature flags.
    config: GraphAwareConfig,
    /// Graph membership per registered resource, keyed by resource identifier.
    resource_graph_map: HashMap<String, ResourceGraphInfo>,
    /// Caller-supplied metadata per graph; `quality_score` feeds the context score.
    graph_metadata: HashMap<String, GraphMetadata>,
    /// Per-graph resource counts maintained by `register_resource_graph`.
    graph_sizes: HashMap<String, usize>,
}
/// Caller-supplied descriptive data for one named graph.
///
/// It is stored under the key passed to `GraphAwareSearch::update_graph_metadata`,
/// which is not checked against `graph_uri`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphMetadata {
    /// Identifier of the described graph.
    pub graph_uri: String,
    /// NOTE(review): not read in this module; presumably the number of resources in the
    /// graph — confirm.
    pub resource_count: usize,
    /// NOTE(review): not read in this module; presumably a mean similarity between the
    /// graph's resources — confirm.
    pub avg_internal_similarity: f32,
    /// NOTE(review): not read in this module.
    pub last_modified: std::time::SystemTime,
    /// NOTE(review): not read in this module; type labels used for scoring come from
    /// `GraphHierarchy::graph_types` instead.
    pub graph_type: Option<String>,
    /// Multiplied into the context score of every hit found in this graph.
    pub quality_score: f32,
}
impl GraphAwareSearch {
    /// Creates an empty index that takes hierarchy, weights and feature flags from `config`.
    pub fn new(config: GraphAwareConfig) -> Self {
        Self {
            config,
            resource_graph_map: HashMap::new(),
            graph_metadata: HashMap::new(),
            graph_sizes: HashMap::new(),
        }
    }

    /// Records the named graphs that `resource` belongs to.
    ///
    /// The first entry of `graphs` becomes the resource's primary graph, which gets a
    /// boost when scoring context. Duplicate entries in `graphs` are counted once.
    /// Registering an already-known resource replaces its previous membership. The
    /// per-graph counts reported by [`get_graph_stats`](Self::get_graph_stats) are
    /// adjusted accordingly, not incremented a second time.
    pub fn register_resource_graph(&mut self, resource: String, graphs: Vec<String>) {
        let primary_graph = graphs.first().cloned();
        let graph_set: HashSet<String> = graphs.into_iter().collect();

        // Count each distinct graph once, even if the caller listed it several times.
        for graph in &graph_set {
            *self.graph_sizes.entry(graph.clone()).or_insert(0) += 1;
        }

        let info = ResourceGraphInfo {
            resource: resource.clone(),
            graphs: graph_set,
            primary_graph,
            last_updated: std::time::SystemTime::now(),
        };

        // On re-registration, retract the counts contributed by the old membership.
        if let Some(previous) = self.resource_graph_map.insert(resource, info) {
            for graph in &previous.graphs {
                self.decrement_graph_size(graph);
            }
        }
    }

    /// Decrements the resource count of `graph_uri`. At zero the entry is dropped, so the
    /// graph reads like one that never had resources registered.
    fn decrement_graph_size(&mut self, graph_uri: &str) {
        if let Some(count) = self.graph_sizes.get_mut(graph_uri) {
            *count = count.saturating_sub(1);
            if *count == 0 {
                self.graph_sizes.remove(graph_uri);
            }
        }
    }

    /// Searches every graph selected by `graph_context` and returns the best `limit` hits.
    ///
    /// Each resolved graph is searched with twice the requested limit, so re-ranking by
    /// graph context has spare candidates. A resource registered in several searched
    /// graphs can appear once per graph; each hit carries its `source_graph`.
    ///
    /// # Errors
    /// Propagates errors from the vector store.
    pub fn search_in_graph(
        &self,
        vector_store: &VectorStore,
        query_text: &str,
        graph_context: &GraphContext,
        limit: usize,
    ) -> Result<Vec<GraphAwareSearchResult>> {
        let target_graphs = self.resolve_search_graphs(graph_context)?;
        // Saturate so a huge `limit` cannot overflow (a panic in debug builds).
        let per_graph_limit = limit.saturating_mul(2);

        let mut all_results = Vec::new();
        for graph_uri in &target_graphs {
            all_results.extend(self.search_single_graph(
                vector_store,
                query_text,
                graph_uri,
                per_graph_limit,
            )?);
        }

        let mut ranked = self.rank_results_by_graph_context(all_results, graph_context)?;
        ranked.truncate(limit);
        Ok(ranked)
    }

    /// Runs a similarity search and keeps only the hits registered in `graph_uri`.
    ///
    /// The vector store is asked for three times `limit` candidates, because hits outside
    /// the graph (or never registered) are discarded. Survivors are scored with the graph
    /// context, sorted by `final_score` descending (NaN last) and truncated to `limit`.
    ///
    /// # Errors
    /// Propagates errors from `VectorStore::similarity_search`.
    pub fn search_single_graph(
        &self,
        vector_store: &VectorStore,
        query_text: &str,
        graph_uri: &str,
        limit: usize,
    ) -> Result<Vec<GraphAwareSearchResult>> {
        let candidates = vector_store.similarity_search(query_text, limit.saturating_mul(3))?;

        let mut graph_results = Vec::new();
        for (resource, score) in candidates {
            if !self.resource_in_graph(&resource, graph_uri) {
                continue;
            }
            let context_score = self.calculate_context_score(&resource, graph_uri)?;
            let final_score = self.combine_scores(score, context_score, graph_uri);
            graph_results.push(GraphAwareSearchResult {
                resource,
                score,
                source_graph: graph_uri.to_string(),
                context_score,
                final_score,
                metadata: HashMap::new(),
            });
        }

        Self::sort_by_final_score_desc(&mut graph_results);
        graph_results.truncate(limit);
        Ok(graph_results)
    }

    /// Sorts `results` by `final_score`, highest first, with NaN scores placed last.
    ///
    /// NaN is mapped to negative infinity before comparing, so the comparator is a total
    /// order. `partial_cmp(..).unwrap_or(Equal)` alone is not a total order once a NaN
    /// appears, and `slice::sort_by` may panic or mis-order when given such a comparator.
    fn sort_by_final_score_desc(results: &mut [GraphAwareSearchResult]) {
        fn key(score: f32) -> f32 {
            if score.is_nan() {
                f32::NEG_INFINITY
            } else {
                score
            }
        }
        results.sort_by(|a, b| {
            key(b.final_score)
                .partial_cmp(&key(a.final_score))
                .unwrap_or(std::cmp::Ordering::Equal)
        });
    }

    /// Returns the sorted, de-duplicated list of graph identifiers to search for `context`.
    ///
    /// The primary graph and `context.additional_graphs` are always included. `scope`
    /// decides which hierarchy neighbours are added on top:
    /// - `Exact` / `Related`: nothing else.
    /// - `IncludeChildren`: direct children of the primary graph.
    /// - `IncludeParents`: direct parents of the primary graph.
    /// - `FullHierarchy`: all descendants and all ancestors.
    ///
    /// Currently infallible; it returns `Result` to keep the signature stable.
    fn resolve_search_graphs(&self, context: &GraphContext) -> Result<Vec<String>> {
        let parent_child = &self.config.graph_hierarchy.parent_child;
        let primary = &context.primary_graph;
        let mut target_graphs = vec![primary.clone()];

        match context.scope {
            // Additional graphs are appended below for every scope, so `Related` needs
            // no extra work here.
            GraphSearchScope::Exact | GraphSearchScope::Related => {}
            GraphSearchScope::IncludeChildren => {
                if let Some(children) = parent_child.get(primary) {
                    target_graphs.extend(children.iter().cloned());
                }
            }
            GraphSearchScope::IncludeParents => {
                target_graphs.extend(
                    parent_child
                        .iter()
                        .filter(|(_, children)| children.contains(primary))
                        .map(|(parent, _)| parent.clone()),
                );
            }
            GraphSearchScope::FullHierarchy => {
                target_graphs.extend(self.get_hierarchy_branch(primary));
            }
        }

        target_graphs.extend(context.additional_graphs.iter().cloned());
        target_graphs.sort();
        target_graphs.dedup();
        Ok(target_graphs)
    }

    /// Collects every descendant and then every ancestor of `graph_uri` (siblings excluded).
    ///
    /// The two walks keep separate visited lists, so in a cyclic hierarchy an ancestor
    /// already reached as a descendant still has its own ancestors explored. The merged
    /// list contains no duplicates. Ancestor order is not stable across runs because
    /// parents are found by iterating a `HashMap`.
    fn get_hierarchy_branch(&self, graph_uri: &str) -> Vec<String> {
        let mut branch_graphs = Vec::new();
        self.add_children_recursive(graph_uri, &mut branch_graphs);

        let mut ancestors = Vec::new();
        self.add_parents_recursive(graph_uri, &mut ancestors);
        for ancestor in ancestors {
            if !branch_graphs.contains(&ancestor) {
                branch_graphs.push(ancestor);
            }
        }
        branch_graphs
    }

    /// Depth-first walk down `parent_child`, appending unseen descendants to `result`.
    /// Already-collected graphs are skipped, which also stops cycles.
    fn add_children_recursive(&self, graph_uri: &str, result: &mut Vec<String>) {
        if let Some(children) = self.config.graph_hierarchy.parent_child.get(graph_uri) {
            for child in children {
                if !result.contains(child) {
                    result.push(child.clone());
                    self.add_children_recursive(child, result);
                }
            }
        }
    }

    /// Walk up `parent_child`, appending unseen ancestors to `result`.
    /// Already-collected graphs are skipped, which also stops cycles.
    fn add_parents_recursive(&self, graph_uri: &str, result: &mut Vec<String>) {
        for (parent, children) in &self.config.graph_hierarchy.parent_child {
            let is_parent = children.iter().any(|child| child == graph_uri);
            if is_parent && !result.contains(parent) {
                result.push(parent.clone());
                self.add_parents_recursive(parent, result);
            }
        }
    }

    /// Scores how well `graph_uri` fits as context for `resource`, clamped to at most 1.0.
    ///
    /// The score starts at 1.0. It is multiplied by the configured graph weight and by the
    /// graph's `quality_score` (when metadata was supplied). A 1.2x boost applies when
    /// `graph_uri` is the resource's primary graph. Because of the final clamp, the boost
    /// only matters when the other factors pulled the score below 1.0. A NaN product also
    /// comes out as 1.0, because `f32::min` ignores a NaN operand.
    ///
    /// Currently infallible; it returns `Result` so callers can use `?`.
    fn calculate_context_score(&self, resource: &str, graph_uri: &str) -> Result<f32> {
        let mut context_score: f32 = 1.0;
        if let Some(&weight) = self.config.graph_hierarchy.graph_weights.get(graph_uri) {
            context_score *= weight;
        }
        if let Some(metadata) = self.graph_metadata.get(graph_uri) {
            context_score *= metadata.quality_score;
        }
        let primary_graph = self
            .resource_graph_map
            .get(resource)
            .and_then(|info| info.primary_graph.as_deref());
        if primary_graph == Some(graph_uri) {
            context_score *= 1.2;
        }
        Ok(context_score.min(1.0))
    }

    /// Blends similarity (70%) with context (30%), then scales by the graph's configured
    /// weight (1.0 when none is configured).
    ///
    /// NOTE(review): the graph weight is also multiplied into `context_score` by
    /// `calculate_context_score`, so it affects the context term twice — confirm intended.
    fn combine_scores(&self, similarity_score: f32, context_score: f32, graph_uri: &str) -> f32 {
        const SIMILARITY_WEIGHT: f32 = 0.7;
        const CONTEXT_WEIGHT: f32 = 0.3;
        let graph_boost = self
            .config
            .graph_hierarchy
            .graph_weights
            .get(graph_uri)
            .copied()
            .unwrap_or(1.0);
        (similarity_score * SIMILARITY_WEIGHT + context_score * CONTEXT_WEIGHT) * graph_boost
    }

    /// Re-weights hits by the query context and sorts them by `final_score` descending.
    ///
    /// Per-graph `context_weights` are multiplied in, and hits from the primary graph get a
    /// 1.1x boost. When `enable_cross_graph_similarity` is set, the sorted list is then
    /// capped at three hits per source graph.
    /// NOTE(review): tying the diversity cap to the cross-graph flag looks incidental.
    ///
    /// Currently infallible; it returns `Result` so callers can use `?`.
    fn rank_results_by_graph_context(
        &self,
        mut results: Vec<GraphAwareSearchResult>,
        context: &GraphContext,
    ) -> Result<Vec<GraphAwareSearchResult>> {
        for result in &mut results {
            if let Some(&weight) = context.context_weights.get(&result.source_graph) {
                result.final_score *= weight;
            }
            if result.source_graph == context.primary_graph {
                result.final_score *= 1.1;
            }
        }
        Self::sort_by_final_score_desc(&mut results);
        if self.config.enable_cross_graph_similarity {
            results = self.apply_diversity_filtering(results);
        }
        Ok(results)
    }

    /// Keeps at most three results per `source_graph`, preserving the input order.
    fn apply_diversity_filtering(
        &self,
        results: Vec<GraphAwareSearchResult>,
    ) -> Vec<GraphAwareSearchResult> {
        const MAX_PER_GRAPH: usize = 3;
        let mut graph_counts: HashMap<String, usize> = HashMap::new();
        results
            .into_iter()
            .filter(|result| {
                let seen = graph_counts.entry(result.source_graph.clone()).or_insert(0);
                *seen += 1;
                *seen <= MAX_PER_GRAPH
            })
            .collect()
    }

    /// Stores (or replaces) `metadata` for `graph_uri`. Its `quality_score` is multiplied
    /// into the context score of hits found in that graph.
    pub fn update_graph_metadata(&mut self, graph_uri: String, metadata: GraphMetadata) {
        self.graph_metadata.insert(graph_uri, metadata);
    }

    /// Returns the number of resources registered in `graph_uri` together with its
    /// metadata. Returns `None` when no resource is registered in the graph, even if
    /// metadata exists for it.
    pub fn get_graph_stats(&self, graph_uri: &str) -> Option<(usize, Option<&GraphMetadata>)> {
        let size = *self.graph_sizes.get(graph_uri)?;
        Some((size, self.graph_metadata.get(graph_uri)))
    }

    /// Forgets all resource registrations, graph metadata and per-graph counts.
    /// The configuration (hierarchy, weights, flags) is kept.
    pub fn clear_caches(&mut self) {
        self.resource_graph_map.clear();
        self.graph_metadata.clear();
        self.graph_sizes.clear();
    }

    /// Returns `true` if `resource` was registered as a member of `graph_uri`.
    pub fn resource_in_graph(&self, resource: &str, graph_uri: &str) -> bool {
        self.resource_graph_map
            .get(resource)
            .map(|info| info.graphs.contains(graph_uri))
            .unwrap_or(false)
    }

    /// Returns the graphs `resource` was registered in, or `None` if it is unknown.
    pub fn get_resource_graphs(&self, resource: &str) -> Option<&HashSet<String>> {
        self.resource_graph_map
            .get(resource)
            .map(|info| &info.graphs)
    }

    /// Similarity between two resources, discounted by how their graphs relate.
    ///
    /// The factor is 1.0 for the same graph and 0.9 for a direct parent/child pair. It is
    /// 0.8 for graphs with equal type labels and 0.7 otherwise.
    ///
    /// # Errors
    /// Fails when cross-graph similarity is disabled in the config, or when either
    /// resource is not registered in the given graph. Errors from the vector store are
    /// propagated.
    pub fn cross_graph_similarity(
        &self,
        vector_store: &VectorStore,
        resource1: &str,
        graph1: &str,
        resource2: &str,
        graph2: &str,
    ) -> Result<f32> {
        if !self.config.enable_cross_graph_similarity {
            return Err(anyhow!("Cross-graph similarity is disabled"));
        }
        if !self.resource_in_graph(resource1, graph1) || !self.resource_in_graph(resource2, graph2)
        {
            return Err(anyhow!("Resources not found in specified graphs"));
        }
        let base_similarity = vector_store.calculate_similarity(resource1, resource2)?;
        let graph_relationship_factor = self.calculate_graph_relationship_factor(graph1, graph2);
        Ok(base_similarity * graph_relationship_factor)
    }

    /// Discount factor for comparing items across two graphs: 1.0 for the same graph,
    /// 0.9 for a direct parent/child pair (either direction), 0.8 for graphs with equal
    /// configured type labels, and 0.7 otherwise.
    fn calculate_graph_relationship_factor(&self, graph1: &str, graph2: &str) -> f32 {
        if graph1 == graph2 {
            return 1.0;
        }
        let hierarchy = &self.config.graph_hierarchy;
        let is_child_of = |parent: &str, child: &str| match hierarchy.parent_child.get(parent) {
            Some(children) => children.iter().any(|c| c == child),
            None => false,
        };
        if is_child_of(graph1, graph2) || is_child_of(graph2, graph1) {
            return 0.9;
        }
        match (
            hierarchy.graph_types.get(graph1),
            hierarchy.graph_types.get(graph2),
        ) {
            (Some(type1), Some(type2)) if type1 == type2 => 0.8,
            _ => 0.7,
        }
    }

    /// Replaces the parent → children map used for scope expansion and relationship factors.
    pub fn set_graph_hierarchy(&mut self, parent_child: HashMap<String, Vec<String>>) {
        self.config.graph_hierarchy.parent_child = parent_child;
    }

    /// Replaces the per-graph score multipliers.
    pub fn set_graph_weights(&mut self, weights: HashMap<String, f32>) {
        self.config.graph_hierarchy.graph_weights = weights;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    /// Builds an owned identifier under `http://example.org/`.
    fn uri(path: &str) -> String {
        format!("http://example.org/{}", path)
    }

    /// A context stores its primary graph and scope as given.
    #[test]
    fn test_graph_context_creation() {
        let ctx = GraphContext {
            primary_graph: uri("graph1"),
            additional_graphs: vec![uri("graph2")],
            scope: GraphSearchScope::IncludeChildren,
            context_weights: HashMap::new(),
        };
        assert_eq!(ctx.primary_graph.as_str(), "http://example.org/graph1");
        assert!(matches!(ctx.scope, GraphSearchScope::IncludeChildren));
    }

    /// A registered resource is a member of its graph and of no other.
    #[test]
    fn test_resource_graph_registration() {
        let mut index = GraphAwareSearch::new(GraphAwareConfig::default());
        let resource = uri("resource1");
        index.register_resource_graph(resource.clone(), vec![uri("graph1")]);

        let in_graph1 = index.resource_in_graph(&resource, "http://example.org/graph1");
        let in_graph2 = index.resource_in_graph(&resource, "http://example.org/graph2");
        assert!(in_graph1);
        assert!(!in_graph2);
    }

    /// The hierarchy branch of a parent contains all of its direct children.
    #[test]
    fn test_graph_hierarchy() {
        let mut config = GraphAwareConfig::default();
        config
            .graph_hierarchy
            .parent_child
            .insert(uri("parent"), vec![uri("child1"), uri("child2")]);

        let index = GraphAwareSearch::new(config);
        let branch = index.get_hierarchy_branch("http://example.org/parent");
        for child in &[uri("child1"), uri("child2")] {
            assert!(branch.contains(child));
        }
    }

    /// An `Exact` scope with no additional graphs resolves to just the primary graph.
    #[test]
    fn test_graph_search_scope() -> Result<()> {
        let ctx = GraphContext {
            primary_graph: uri("main"),
            additional_graphs: Vec::new(),
            scope: GraphSearchScope::Exact,
            context_weights: HashMap::new(),
        };
        let index = GraphAwareSearch::new(GraphAwareConfig::default());
        let resolved = index.resolve_search_graphs(&ctx)?;
        assert_eq!(resolved, vec!["http://example.org/main".to_string()]);
        Ok(())
    }
}