1use crate::VectorStore;
10use anyhow::{anyhow, Result};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct GraphAwareConfig {
17 pub enable_graph_filtering: bool,
19 pub enable_hierarchical_search: bool,
21 pub enable_cross_graph_similarity: bool,
23 pub default_graph: Option<String>,
25 pub graph_hierarchy: GraphHierarchy,
27 pub cache_graph_metadata: bool,
29}
30
31impl Default for GraphAwareConfig {
32 fn default() -> Self {
33 Self {
34 enable_graph_filtering: true,
35 enable_hierarchical_search: false,
36 enable_cross_graph_similarity: false,
37 default_graph: None,
38 graph_hierarchy: GraphHierarchy::default(),
39 cache_graph_metadata: true,
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize, Default)]
46pub struct GraphHierarchy {
47 pub parent_child: HashMap<String, Vec<String>>,
49 pub graph_types: HashMap<String, String>,
51 pub graph_weights: HashMap<String, f32>,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct GraphContext {
58 pub primary_graph: String,
60 pub additional_graphs: Vec<String>,
62 pub scope: GraphSearchScope,
64 pub context_weights: HashMap<String, f32>,
66}
67
68#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
70pub enum GraphSearchScope {
71 Exact,
73 IncludeChildren,
75 IncludeParents,
77 FullHierarchy,
79 Related,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct GraphAwareSearchResult {
86 pub resource: String,
88 pub score: f32,
90 pub source_graph: String,
92 pub context_score: f32,
94 pub final_score: f32,
96 pub metadata: HashMap<String, String>,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ResourceGraphInfo {
103 pub resource: String,
105 pub graphs: HashSet<String>,
107 pub primary_graph: Option<String>,
109 pub last_updated: std::time::SystemTime,
111}
112
113pub struct GraphAwareSearch {
115 config: GraphAwareConfig,
116 resource_graph_map: HashMap<String, ResourceGraphInfo>,
118 graph_metadata: HashMap<String, GraphMetadata>,
120 graph_sizes: HashMap<String, usize>,
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct GraphMetadata {
127 pub graph_uri: String,
129 pub resource_count: usize,
131 pub avg_internal_similarity: f32,
133 pub last_modified: std::time::SystemTime,
135 pub graph_type: Option<String>,
137 pub quality_score: f32,
139}
140
141impl GraphAwareSearch {
142 pub fn new(config: GraphAwareConfig) -> Self {
143 Self {
144 config,
145 resource_graph_map: HashMap::new(),
146 graph_metadata: HashMap::new(),
147 graph_sizes: HashMap::new(),
148 }
149 }
150
151 pub fn register_resource_graph(&mut self, resource: String, graphs: Vec<String>) {
153 let graph_set: HashSet<String> = graphs.iter().cloned().collect();
154 let primary_graph = graphs.first().cloned();
155
156 let info = ResourceGraphInfo {
157 resource: resource.clone(),
158 graphs: graph_set,
159 primary_graph,
160 last_updated: std::time::SystemTime::now(),
161 };
162
163 self.resource_graph_map.insert(resource, info);
164
165 for graph in graphs {
167 *self.graph_sizes.entry(graph).or_insert(0) += 1;
168 }
169 }
170
171 pub fn search_in_graph(
173 &self,
174 vector_store: &VectorStore,
175 query_text: &str,
176 graph_context: &GraphContext,
177 limit: usize,
178 ) -> Result<Vec<GraphAwareSearchResult>> {
179 let target_graphs = self.resolve_search_graphs(graph_context)?;
181
182 let mut all_results = Vec::new();
184
185 for graph_uri in &target_graphs {
186 let graph_results =
187 self.search_single_graph(vector_store, query_text, graph_uri, limit * 2)?;
188 all_results.extend(graph_results);
189 }
190
191 let ranked_results = self.rank_results_by_graph_context(all_results, graph_context)?;
193
194 Ok(ranked_results.into_iter().take(limit).collect())
196 }
197
198 pub fn search_single_graph(
200 &self,
201 vector_store: &VectorStore,
202 query_text: &str,
203 graph_uri: &str,
204 limit: usize,
205 ) -> Result<Vec<GraphAwareSearchResult>> {
206 let vector_results = vector_store.similarity_search(query_text, limit * 3)?; let mut graph_results = Vec::new();
210
211 for (resource, score) in vector_results {
212 if let Some(resource_info) = self.resource_graph_map.get(&resource) {
214 if resource_info.graphs.contains(graph_uri) {
215 let context_score = self.calculate_context_score(&resource, graph_uri)?;
216 let final_score = self.combine_scores(score, context_score, graph_uri);
217
218 graph_results.push(GraphAwareSearchResult {
219 resource,
220 score,
221 source_graph: graph_uri.to_string(),
222 context_score,
223 final_score,
224 metadata: HashMap::new(),
225 });
226 }
227 }
228 }
229
230 graph_results.sort_by(|a, b| {
232 b.final_score
233 .partial_cmp(&a.final_score)
234 .unwrap_or(std::cmp::Ordering::Equal)
235 });
236
237 Ok(graph_results.into_iter().take(limit).collect())
238 }
239
240 fn resolve_search_graphs(&self, context: &GraphContext) -> Result<Vec<String>> {
242 let mut target_graphs = vec![context.primary_graph.clone()];
243
244 match context.scope {
245 GraphSearchScope::Exact => {
246 }
248 GraphSearchScope::IncludeChildren => {
249 if let Some(children) = self
251 .config
252 .graph_hierarchy
253 .parent_child
254 .get(&context.primary_graph)
255 {
256 target_graphs.extend(children.clone());
257 }
258 }
259 GraphSearchScope::IncludeParents => {
260 for (parent, children) in &self.config.graph_hierarchy.parent_child {
262 if children.contains(&context.primary_graph) {
263 target_graphs.push(parent.clone());
264 }
265 }
266 }
267 GraphSearchScope::FullHierarchy => {
268 target_graphs.extend(self.get_hierarchy_branch(&context.primary_graph));
270 }
271 GraphSearchScope::Related => {
272 target_graphs.extend(context.additional_graphs.clone());
274 }
275 }
276
277 target_graphs.extend(context.additional_graphs.clone());
279
280 target_graphs.sort();
282 target_graphs.dedup();
283 Ok(target_graphs)
284 }
285
286 fn get_hierarchy_branch(&self, graph_uri: &str) -> Vec<String> {
288 let mut branch_graphs = Vec::new();
289
290 self.add_children_recursive(graph_uri, &mut branch_graphs);
292
293 self.add_parents_recursive(graph_uri, &mut branch_graphs);
295
296 branch_graphs
297 }
298
299 fn add_children_recursive(&self, graph_uri: &str, result: &mut Vec<String>) {
301 if let Some(children) = self.config.graph_hierarchy.parent_child.get(graph_uri) {
302 for child in children {
303 if !result.contains(child) {
304 result.push(child.clone());
305 self.add_children_recursive(child, result);
306 }
307 }
308 }
309 }
310
311 fn add_parents_recursive(&self, graph_uri: &str, result: &mut Vec<String>) {
313 for (parent, children) in &self.config.graph_hierarchy.parent_child {
314 if children.contains(&graph_uri.to_string()) && !result.contains(parent) {
315 result.push(parent.clone());
316 self.add_parents_recursive(parent, result);
317 }
318 }
319 }
320
321 fn calculate_context_score(&self, resource: &str, graph_uri: &str) -> Result<f32> {
323 let mut context_score = 1.0;
324
325 if let Some(&weight) = self.config.graph_hierarchy.graph_weights.get(graph_uri) {
327 context_score *= weight;
328 }
329
330 if let Some(metadata) = self.graph_metadata.get(graph_uri) {
332 context_score *= metadata.quality_score;
333 }
334
335 if let Some(resource_info) = self.resource_graph_map.get(resource) {
337 if resource_info.primary_graph.as_ref() == Some(&graph_uri.to_string()) {
338 context_score *= 1.2; }
340 }
341
342 Ok(context_score.min(1.0)) }
344
345 fn combine_scores(&self, similarity_score: f32, context_score: f32, graph_uri: &str) -> f32 {
347 let similarity_weight = 0.7;
349 let context_weight = 0.3;
350
351 let graph_boost = self
353 .config
354 .graph_hierarchy
355 .graph_weights
356 .get(graph_uri)
357 .unwrap_or(&1.0);
358
359 (similarity_score * similarity_weight + context_score * context_weight) * graph_boost
360 }
361
362 fn rank_results_by_graph_context(
364 &self,
365 mut results: Vec<GraphAwareSearchResult>,
366 context: &GraphContext,
367 ) -> Result<Vec<GraphAwareSearchResult>> {
368 for result in &mut results {
370 if let Some(&weight) = context.context_weights.get(&result.source_graph) {
371 result.final_score *= weight;
372 }
373
374 if result.source_graph == context.primary_graph {
376 result.final_score *= 1.1;
377 }
378 }
379
380 results.sort_by(|a, b| {
382 b.final_score
383 .partial_cmp(&a.final_score)
384 .unwrap_or(std::cmp::Ordering::Equal)
385 });
386
387 if self.config.enable_cross_graph_similarity {
389 results = self.apply_diversity_filtering(results);
390 }
391
392 Ok(results)
393 }
394
395 fn apply_diversity_filtering(
397 &self,
398 results: Vec<GraphAwareSearchResult>,
399 ) -> Vec<GraphAwareSearchResult> {
400 let mut filtered_results = Vec::new();
401 let mut graph_counts: HashMap<String, usize> = HashMap::new();
402 let max_per_graph = 3; for result in results {
405 let count = graph_counts.entry(result.source_graph.clone()).or_insert(0);
406 if *count < max_per_graph {
407 filtered_results.push(result);
408 *count += 1;
409 }
410 }
411
412 filtered_results
413 }
414
415 pub fn update_graph_metadata(&mut self, graph_uri: String, metadata: GraphMetadata) {
417 self.graph_metadata.insert(graph_uri, metadata);
418 }
419
420 pub fn get_graph_stats(&self, graph_uri: &str) -> Option<(usize, Option<&GraphMetadata>)> {
422 let size = self.graph_sizes.get(graph_uri).cloned();
423 let metadata = self.graph_metadata.get(graph_uri);
424 size.map(|s| (s, metadata))
425 }
426
427 pub fn clear_caches(&mut self) {
429 self.resource_graph_map.clear();
430 self.graph_metadata.clear();
431 self.graph_sizes.clear();
432 }
433
434 pub fn resource_in_graph(&self, resource: &str, graph_uri: &str) -> bool {
436 self.resource_graph_map
437 .get(resource)
438 .map(|info| info.graphs.contains(graph_uri))
439 .unwrap_or(false)
440 }
441
442 pub fn get_resource_graphs(&self, resource: &str) -> Option<&HashSet<String>> {
444 self.resource_graph_map
445 .get(resource)
446 .map(|info| &info.graphs)
447 }
448
449 pub fn cross_graph_similarity(
451 &self,
452 vector_store: &VectorStore,
453 resource1: &str,
454 graph1: &str,
455 resource2: &str,
456 graph2: &str,
457 ) -> Result<f32> {
458 if !self.config.enable_cross_graph_similarity {
459 return Err(anyhow!("Cross-graph similarity is disabled"));
460 }
461
462 if !self.resource_in_graph(resource1, graph1) || !self.resource_in_graph(resource2, graph2)
464 {
465 return Err(anyhow!("Resources not found in specified graphs"));
466 }
467
468 let base_similarity = vector_store.calculate_similarity(resource1, resource2)?;
470
471 let graph_relationship_factor = self.calculate_graph_relationship_factor(graph1, graph2);
473
474 Ok(base_similarity * graph_relationship_factor)
475 }
476
477 fn calculate_graph_relationship_factor(&self, graph1: &str, graph2: &str) -> f32 {
479 if graph1 == graph2 {
480 return 1.0; }
482
483 if let Some(children) = self.config.graph_hierarchy.parent_child.get(graph1) {
485 if children.contains(&graph2.to_string()) {
486 return 0.9; }
488 }
489
490 if let Some(children) = self.config.graph_hierarchy.parent_child.get(graph2) {
491 if children.contains(&graph1.to_string()) {
492 return 0.9; }
494 }
495
496 if let (Some(type1), Some(type2)) = (
498 self.config.graph_hierarchy.graph_types.get(graph1),
499 self.config.graph_hierarchy.graph_types.get(graph2),
500 ) {
501 if type1 == type2 {
502 return 0.8; }
504 }
505
506 0.7 }
508
509 pub fn set_graph_hierarchy(&mut self, parent_child: HashMap<String, Vec<String>>) {
511 self.config.graph_hierarchy.parent_child = parent_child;
512 }
513
514 pub fn set_graph_weights(&mut self, weights: HashMap<String, f32>) {
516 self.config.graph_hierarchy.graph_weights = weights;
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `http://example.org/<local>` URI.
    fn uri(local: &str) -> String {
        ["http://example.org/", local].concat()
    }

    #[test]
    fn test_graph_context_creation() {
        let scope = GraphSearchScope::IncludeChildren;
        let context = GraphContext {
            primary_graph: uri("graph1"),
            additional_graphs: vec![uri("graph2")],
            scope,
            context_weights: HashMap::new(),
        };

        assert_eq!(context.primary_graph, "http://example.org/graph1");
        assert_eq!(context.scope, GraphSearchScope::IncludeChildren);
    }

    #[test]
    fn test_resource_graph_registration() {
        let resource = uri("resource1");
        let mut search = GraphAwareSearch::new(GraphAwareConfig::default());
        search.register_resource_graph(resource.clone(), vec![uri("graph1")]);

        assert!(search.resource_in_graph(&resource, &uri("graph1")));
        assert!(!search.resource_in_graph(&resource, &uri("graph2")));
    }

    #[test]
    fn test_graph_hierarchy() {
        let children = vec![uri("child1"), uri("child2")];
        let mut config = GraphAwareConfig::default();
        config
            .graph_hierarchy
            .parent_child
            .insert(uri("parent"), children.clone());

        let branch = GraphAwareSearch::new(config).get_hierarchy_branch(&uri("parent"));

        for child in &children {
            assert!(branch.contains(child));
        }
    }

    #[test]
    fn test_graph_search_scope() {
        let search = GraphAwareSearch::new(GraphAwareConfig::default());
        let context = GraphContext {
            primary_graph: uri("main"),
            additional_graphs: Vec::new(),
            scope: GraphSearchScope::Exact,
            context_weights: HashMap::new(),
        };

        let graphs = search
            .resolve_search_graphs(&context)
            .expect("resolving graphs for an exact scope cannot fail");

        assert_eq!(graphs, ["http://example.org/main"]);
    }
}
587}