1use crate::VectorStore;
10use anyhow::{anyhow, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct GraphAwareConfig {
17 pub enable_graph_filtering: bool,
19 pub enable_hierarchical_search: bool,
21 pub enable_cross_graph_similarity: bool,
23 pub default_graph: Option<String>,
25 pub graph_hierarchy: GraphHierarchy,
27 pub cache_graph_metadata: bool,
29}
30
31impl Default for GraphAwareConfig {
32 fn default() -> Self {
33 Self {
34 enable_graph_filtering: true,
35 enable_hierarchical_search: false,
36 enable_cross_graph_similarity: false,
37 default_graph: None,
38 graph_hierarchy: GraphHierarchy::default(),
39 cache_graph_metadata: true,
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, Default)]
46pub struct GraphHierarchy {
47 pub parent_child: HashMap<String, Vec<String>>,
49 pub graph_types: HashMap<String, String>,
51 pub graph_weights: HashMap<String, f32>,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct GraphContext {
58 pub primary_graph: String,
60 pub additional_graphs: Vec<String>,
62 pub scope: GraphSearchScope,
64 pub context_weights: HashMap<String, f32>,
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
70pub enum GraphSearchScope {
71 Exact,
73 IncludeChildren,
75 IncludeParents,
77 FullHierarchy,
79 Related,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct GraphAwareSearchResult {
86 pub resource: String,
88 pub score: f32,
90 pub source_graph: String,
92 pub context_score: f32,
94 pub final_score: f32,
96 pub metadata: HashMap<String, String>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ResourceGraphInfo {
103 pub resource: String,
105 pub graphs: HashSet<String>,
107 pub primary_graph: Option<String>,
109 pub last_updated: std::time::SystemTime,
111}
112
113pub struct GraphAwareSearch {
115 config: GraphAwareConfig,
116 resource_graph_map: HashMap<String, ResourceGraphInfo>,
118 graph_metadata: HashMap<String, GraphMetadata>,
120 graph_sizes: HashMap<String, usize>,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct GraphMetadata {
127 pub graph_uri: String,
129 pub resource_count: usize,
131 pub avg_internal_similarity: f32,
133 pub last_modified: std::time::SystemTime,
135 pub graph_type: Option<String>,
137 pub quality_score: f32,
139}
140
141impl GraphAwareSearch {
142 pub fn new(config: GraphAwareConfig) -> Self {
143 Self {
144 config,
145 resource_graph_map: HashMap::new(),
146 graph_metadata: HashMap::new(),
147 graph_sizes: HashMap::new(),
148 }
149 }
150
151 pub fn register_resource_graph(&mut self, resource: String, graphs: Vec<String>) {
153 let graph_set: HashSet<String> = graphs.iter().cloned().collect();
154 let primary_graph = graphs.first().cloned();
155
156 let info = ResourceGraphInfo {
157 resource: resource.clone(),
158 graphs: graph_set,
159 primary_graph,
160 last_updated: std::time::SystemTime::now(),
161 };
162
163 self.resource_graph_map.insert(resource, info);
164
165 for graph in graphs {
167 *self.graph_sizes.entry(graph).or_insert(0) += 1;
168 }
169 }
170
171 pub fn search_in_graph(
173 &self,
174 vector_store: &VectorStore,
175 query_text: &str,
176 graph_context: &GraphContext,
177 limit: usize,
178 ) -> Result<Vec<GraphAwareSearchResult>> {
179 let target_graphs = self.resolve_search_graphs(graph_context)?;
181
182 let mut all_results = Vec::new();
184
185 for graph_uri in &target_graphs {
186 let graph_results =
187 self.search_single_graph(vector_store, query_text, graph_uri, limit * 2)?;
188 all_results.extend(graph_results);
189 }
190
191 let ranked_results = self.rank_results_by_graph_context(all_results, graph_context)?;
193
194 Ok(ranked_results.into_iter().take(limit).collect())
196 }
197
198 pub fn search_single_graph(
200 &self,
201 vector_store: &VectorStore,
202 query_text: &str,
203 graph_uri: &str,
204 limit: usize,
205 ) -> Result<Vec<GraphAwareSearchResult>> {
206 let vector_results = vector_store.similarity_search(query_text, limit * 3)?; let mut graph_results = Vec::new();
210
211 for (resource, score) in vector_results {
212 if let Some(resource_info) = self.resource_graph_map.get(&resource) {
214 if resource_info.graphs.contains(graph_uri) {
215 let context_score = self.calculate_context_score(&resource, graph_uri)?;
216 let final_score = self.combine_scores(score, context_score, graph_uri);
217
218 graph_results.push(GraphAwareSearchResult {
219 resource,
220 score,
221 source_graph: graph_uri.to_string(),
222 context_score,
223 final_score,
224 metadata: HashMap::new(),
225 });
226 }
227 }
228 }
229
230 graph_results.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
232
233 Ok(graph_results.into_iter().take(limit).collect())
234 }
235
236 fn resolve_search_graphs(&self, context: &GraphContext) -> Result<Vec<String>> {
238 let mut target_graphs = vec![context.primary_graph.clone()];
239
240 match context.scope {
241 GraphSearchScope::Exact => {
242 }
244 GraphSearchScope::IncludeChildren => {
245 if let Some(children) = self
247 .config
248 .graph_hierarchy
249 .parent_child
250 .get(&context.primary_graph)
251 {
252 target_graphs.extend(children.clone());
253 }
254 }
255 GraphSearchScope::IncludeParents => {
256 for (parent, children) in &self.config.graph_hierarchy.parent_child {
258 if children.contains(&context.primary_graph) {
259 target_graphs.push(parent.clone());
260 }
261 }
262 }
263 GraphSearchScope::FullHierarchy => {
264 target_graphs.extend(self.get_hierarchy_branch(&context.primary_graph));
266 }
267 GraphSearchScope::Related => {
268 target_graphs.extend(context.additional_graphs.clone());
270 }
271 }
272
273 target_graphs.extend(context.additional_graphs.clone());
275
276 target_graphs.sort();
278 target_graphs.dedup();
279 Ok(target_graphs)
280 }
281
282 fn get_hierarchy_branch(&self, graph_uri: &str) -> Vec<String> {
284 let mut branch_graphs = Vec::new();
285
286 self.add_children_recursive(graph_uri, &mut branch_graphs);
288
289 self.add_parents_recursive(graph_uri, &mut branch_graphs);
291
292 branch_graphs
293 }
294
295 fn add_children_recursive(&self, graph_uri: &str, result: &mut Vec<String>) {
297 if let Some(children) = self.config.graph_hierarchy.parent_child.get(graph_uri) {
298 for child in children {
299 if !result.contains(child) {
300 result.push(child.clone());
301 self.add_children_recursive(child, result);
302 }
303 }
304 }
305 }
306
307 fn add_parents_recursive(&self, graph_uri: &str, result: &mut Vec<String>) {
309 for (parent, children) in &self.config.graph_hierarchy.parent_child {
310 if children.contains(&graph_uri.to_string()) && !result.contains(parent) {
311 result.push(parent.clone());
312 self.add_parents_recursive(parent, result);
313 }
314 }
315 }
316
317 fn calculate_context_score(&self, resource: &str, graph_uri: &str) -> Result<f32> {
319 let mut context_score = 1.0;
320
321 if let Some(&weight) = self.config.graph_hierarchy.graph_weights.get(graph_uri) {
323 context_score *= weight;
324 }
325
326 if let Some(metadata) = self.graph_metadata.get(graph_uri) {
328 context_score *= metadata.quality_score;
329 }
330
331 if let Some(resource_info) = self.resource_graph_map.get(resource) {
333 if resource_info.primary_graph.as_ref() == Some(&graph_uri.to_string()) {
334 context_score *= 1.2; }
336 }
337
338 Ok(context_score.min(1.0)) }
340
341 fn combine_scores(&self, similarity_score: f32, context_score: f32, graph_uri: &str) -> f32 {
343 let similarity_weight = 0.7;
345 let context_weight = 0.3;
346
347 let graph_boost = self
349 .config
350 .graph_hierarchy
351 .graph_weights
352 .get(graph_uri)
353 .unwrap_or(&1.0);
354
355 (similarity_score * similarity_weight + context_score * context_weight) * graph_boost
356 }
357
358 fn rank_results_by_graph_context(
360 &self,
361 mut results: Vec<GraphAwareSearchResult>,
362 context: &GraphContext,
363 ) -> Result<Vec<GraphAwareSearchResult>> {
364 for result in &mut results {
366 if let Some(&weight) = context.context_weights.get(&result.source_graph) {
367 result.final_score *= weight;
368 }
369
370 if result.source_graph == context.primary_graph {
372 result.final_score *= 1.1;
373 }
374 }
375
376 results.sort_by(|a, b| b.final_score.partial_cmp(&a.final_score).unwrap());
378
379 if self.config.enable_cross_graph_similarity {
381 results = self.apply_diversity_filtering(results);
382 }
383
384 Ok(results)
385 }
386
387 fn apply_diversity_filtering(
389 &self,
390 results: Vec<GraphAwareSearchResult>,
391 ) -> Vec<GraphAwareSearchResult> {
392 let mut filtered_results = Vec::new();
393 let mut graph_counts: HashMap<String, usize> = HashMap::new();
394 let max_per_graph = 3; for result in results {
397 let count = graph_counts.entry(result.source_graph.clone()).or_insert(0);
398 if *count < max_per_graph {
399 filtered_results.push(result);
400 *count += 1;
401 }
402 }
403
404 filtered_results
405 }
406
407 pub fn update_graph_metadata(&mut self, graph_uri: String, metadata: GraphMetadata) {
409 self.graph_metadata.insert(graph_uri, metadata);
410 }
411
412 pub fn get_graph_stats(&self, graph_uri: &str) -> Option<(usize, Option<&GraphMetadata>)> {
414 let size = self.graph_sizes.get(graph_uri).cloned();
415 let metadata = self.graph_metadata.get(graph_uri);
416 size.map(|s| (s, metadata))
417 }
418
419 pub fn clear_caches(&mut self) {
421 self.resource_graph_map.clear();
422 self.graph_metadata.clear();
423 self.graph_sizes.clear();
424 }
425
426 pub fn resource_in_graph(&self, resource: &str, graph_uri: &str) -> bool {
428 self.resource_graph_map
429 .get(resource)
430 .map(|info| info.graphs.contains(graph_uri))
431 .unwrap_or(false)
432 }
433
434 pub fn get_resource_graphs(&self, resource: &str) -> Option<&HashSet<String>> {
436 self.resource_graph_map
437 .get(resource)
438 .map(|info| &info.graphs)
439 }
440
441 pub fn cross_graph_similarity(
443 &self,
444 vector_store: &VectorStore,
445 resource1: &str,
446 graph1: &str,
447 resource2: &str,
448 graph2: &str,
449 ) -> Result<f32> {
450 if !self.config.enable_cross_graph_similarity {
451 return Err(anyhow!("Cross-graph similarity is disabled"));
452 }
453
454 if !self.resource_in_graph(resource1, graph1) || !self.resource_in_graph(resource2, graph2)
456 {
457 return Err(anyhow!("Resources not found in specified graphs"));
458 }
459
460 let base_similarity = vector_store.calculate_similarity(resource1, resource2)?;
462
463 let graph_relationship_factor = self.calculate_graph_relationship_factor(graph1, graph2);
465
466 Ok(base_similarity * graph_relationship_factor)
467 }
468
469 fn calculate_graph_relationship_factor(&self, graph1: &str, graph2: &str) -> f32 {
471 if graph1 == graph2 {
472 return 1.0; }
474
475 if let Some(children) = self.config.graph_hierarchy.parent_child.get(graph1) {
477 if children.contains(&graph2.to_string()) {
478 return 0.9; }
480 }
481
482 if let Some(children) = self.config.graph_hierarchy.parent_child.get(graph2) {
483 if children.contains(&graph1.to_string()) {
484 return 0.9; }
486 }
487
488 if let (Some(type1), Some(type2)) = (
490 self.config.graph_hierarchy.graph_types.get(graph1),
491 self.config.graph_hierarchy.graph_types.get(graph2),
492 ) {
493 if type1 == type2 {
494 return 0.8; }
496 }
497
498 0.7 }
500
501 pub fn set_graph_hierarchy(&mut self, parent_child: HashMap<String, Vec<String>>) {
503 self.config.graph_hierarchy.parent_child = parent_child;
504 }
505
506 pub fn set_graph_weights(&mut self, weights: HashMap<String, f32>) {
508 self.config.graph_hierarchy.graph_weights = weights;
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `http://example.org/<name>` URI for use in the tests below.
    fn uri(name: &str) -> String {
        format!("http://example.org/{}", name)
    }

    /// A context keeps the primary graph and scope it was built with.
    #[test]
    fn test_graph_context_creation() {
        let context = GraphContext {
            primary_graph: uri("graph1"),
            additional_graphs: vec![uri("graph2")],
            scope: GraphSearchScope::IncludeChildren,
            context_weights: HashMap::new(),
        };

        assert_eq!(context.primary_graph, "http://example.org/graph1");
        assert_eq!(context.scope, GraphSearchScope::IncludeChildren);
    }

    /// A registered resource is reported in its own graph and nowhere else.
    #[test]
    fn test_resource_graph_registration() {
        let resource = uri("resource1");
        let mut search = GraphAwareSearch::new(GraphAwareConfig::default());

        search.register_resource_graph(resource.clone(), vec![uri("graph1")]);

        assert!(search.resource_in_graph(&resource, &uri("graph1")));
        assert!(!search.resource_in_graph(&resource, &uri("graph2")));
    }

    /// The hierarchy branch of a parent contains each of its children.
    #[test]
    fn test_graph_hierarchy() {
        let mut config = GraphAwareConfig::default();
        config
            .graph_hierarchy
            .parent_child
            .insert(uri("parent"), vec![uri("child1"), uri("child2")]);

        let search = GraphAwareSearch::new(config);
        let branch = search.get_hierarchy_branch(&uri("parent"));

        for child in [uri("child1"), uri("child2")].iter() {
            assert!(branch.contains(child));
        }
    }

    /// With `Exact` scope and no additional graphs only the primary graph is
    /// searched.
    #[test]
    fn test_graph_search_scope() {
        let context = GraphContext {
            primary_graph: uri("main"),
            additional_graphs: Vec::new(),
            scope: GraphSearchScope::Exact,
            context_weights: HashMap::new(),
        };

        let search = GraphAwareSearch::new(GraphAwareConfig::default());
        let graphs = search.resolve_search_graphs(&context).unwrap();

        assert_eq!(graphs.len(), 1);
        assert_eq!(graphs[0], "http://example.org/main");
    }
}