1use super::{algorithms::GraphAlgorithms, KnowledgeGraph};
6use crate::RragResult;
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
11#[async_trait]
13pub trait QueryExpander: Send + Sync {
14 async fn expand_query(
16 &self,
17 query: &str,
18 options: &ExpansionOptions,
19 ) -> RragResult<ExpansionResult>;
20
21 async fn expand_terms(
23 &self,
24 terms: &[String],
25 options: &ExpansionOptions,
26 ) -> RragResult<Vec<String>>;
27
28 async fn find_related_entities(
30 &self,
31 entities: &[String],
32 options: &ExpansionOptions,
33 ) -> RragResult<Vec<String>>;
34
35 async fn get_suggestions(&self, query: &str, max_suggestions: usize)
37 -> RragResult<Vec<String>>;
38}
39
40pub struct GraphQueryExpander {
42 graph: KnowledgeGraph,
44
45 config: ExpansionConfig,
47
48 expansion_cache: tokio::sync::RwLock<HashMap<String, Vec<String>>>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ExpansionConfig {
55 pub max_expansion_depth: usize,
57
58 pub max_expansion_terms: usize,
60
61 pub min_similarity_threshold: f32,
63
64 pub strategy_weights: HashMap<ExpansionStrategy, f32>,
66
67 pub enable_semantic_expansion: bool,
69
70 pub enable_structural_expansion: bool,
72
73 pub enable_statistical_expansion: bool,
75
76 pub enable_caching: bool,
78
79 pub stop_words: HashSet<String>,
81}
82
83#[derive(Debug, Clone)]
85pub struct ExpansionOptions {
86 pub strategies: Vec<ExpansionStrategy>,
88
89 pub max_terms: Option<usize>,
91
92 pub min_confidence: f32,
94
95 pub focus_entities: Vec<String>,
97
98 pub context: Option<String>,
100
101 pub include_original: bool,
103}
104
105#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
107pub enum ExpansionStrategy {
108 Semantic,
110
111 Hierarchical,
113
114 Similarity,
116
117 CoOccurrence,
119
120 Synonym,
122
123 EntityType,
125
126 PathBased,
128
129 PageRank,
131
132 Custom(String),
134}
135
136#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ExpansionResult {
139 pub original_query: String,
141
142 pub expanded_terms: Vec<ExpandedTerm>,
144
145 pub stats: ExpansionStats,
147
148 pub strategies_used: Vec<ExpansionStrategy>,
150
151 pub confidence: f32,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct ExpandedTerm {
158 pub term: String,
160
161 pub strategy: ExpansionStrategy,
163
164 pub confidence: f32,
166
167 pub weight: f32,
169
170 pub source_entities: Vec<String>,
172
173 pub relationship: Option<String>,
175}
176
177#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ExpansionStats {
180 pub entities_found: usize,
182
183 pub terms_per_strategy: HashMap<String, usize>,
185
186 pub expansion_time_ms: u64,
188
189 pub nodes_examined: usize,
191
192 pub edges_examined: usize,
194}
195
196impl Default for ExpansionConfig {
197 fn default() -> Self {
198 let mut strategy_weights = HashMap::new();
199 strategy_weights.insert(ExpansionStrategy::Semantic, 1.0);
200 strategy_weights.insert(ExpansionStrategy::Hierarchical, 0.8);
201 strategy_weights.insert(ExpansionStrategy::Similarity, 0.7);
202 strategy_weights.insert(ExpansionStrategy::CoOccurrence, 0.6);
203 strategy_weights.insert(ExpansionStrategy::EntityType, 0.5);
204 strategy_weights.insert(ExpansionStrategy::PathBased, 0.4);
205
206 let mut stop_words = HashSet::new();
207 stop_words.extend(
208 vec![
209 "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with",
210 "by", "from", "up", "about", "into", "through", "during", "before", "after",
211 "above", "below", "between", "among", "this", "that",
212 ]
213 .into_iter()
214 .map(|s| s.to_string()),
215 );
216
217 Self {
218 max_expansion_depth: 2,
219 max_expansion_terms: 20,
220 min_similarity_threshold: 0.3,
221 strategy_weights,
222 enable_semantic_expansion: true,
223 enable_structural_expansion: true,
224 enable_statistical_expansion: true,
225 enable_caching: true,
226 stop_words,
227 }
228 }
229}
230
231impl Default for ExpansionOptions {
232 fn default() -> Self {
233 Self {
234 strategies: vec![
235 ExpansionStrategy::Semantic,
236 ExpansionStrategy::Similarity,
237 ExpansionStrategy::CoOccurrence,
238 ],
239 max_terms: Some(10),
240 min_confidence: 0.3,
241 focus_entities: Vec::new(),
242 context: None,
243 include_original: true,
244 }
245 }
246}
247
248impl GraphQueryExpander {
249 pub fn new(graph: KnowledgeGraph, config: ExpansionConfig) -> Self {
251 Self {
252 graph,
253 config,
254 expansion_cache: tokio::sync::RwLock::new(HashMap::new()),
255 }
256 }
257
258 pub async fn update_graph(&mut self, graph: KnowledgeGraph) {
260 self.graph = graph;
261 if self.config.enable_caching {
263 self.expansion_cache.write().await.clear();
264 }
265 }
266
267 async fn extract_query_entities(&self, query: &str) -> Vec<String> {
269 let mut entities = Vec::new();
270 let query_lower = query.to_lowercase();
271
272 for (_, node) in &self.graph.nodes {
274 let label_lower = node.label.to_lowercase();
275 if query_lower.contains(&label_lower) && !self.config.stop_words.contains(&label_lower)
276 {
277 entities.push(node.id.clone());
278 }
279 }
280
281 entities
282 }
283
284 async fn semantic_expansion(
286 &self,
287 entity_ids: &[String],
288 options: &ExpansionOptions,
289 ) -> RragResult<Vec<ExpandedTerm>> {
290 let mut expanded_terms = Vec::new();
291 let strategy_weight = self
292 .config
293 .strategy_weights
294 .get(&ExpansionStrategy::Semantic)
295 .copied()
296 .unwrap_or(1.0);
297
298 for entity_id in entity_ids {
299 if let Some(_entity_node) = self.graph.get_node(entity_id) {
300 let semantic_edges: Vec<_> = self
302 .graph
303 .edges
304 .values()
305 .filter(|edge| {
306 (edge.source_id == *entity_id || edge.target_id == *entity_id)
307 && matches!(edge.edge_type, super::EdgeType::Semantic(_))
308 })
309 .collect();
310
311 for edge in semantic_edges {
312 let related_node_id = if edge.source_id == *entity_id {
313 &edge.target_id
314 } else {
315 &edge.source_id
316 };
317
318 if let Some(related_node) = self.graph.get_node(related_node_id) {
319 let confidence = edge.confidence * strategy_weight;
320 if confidence >= options.min_confidence {
321 let expanded_term = ExpandedTerm {
322 term: related_node.label.clone(),
323 strategy: ExpansionStrategy::Semantic,
324 confidence,
325 weight: edge.weight * strategy_weight,
326 source_entities: vec![entity_id.clone()],
327 relationship: Some(edge.label.clone()),
328 };
329 expanded_terms.push(expanded_term);
330 }
331 }
332 }
333 }
334 }
335
336 Ok(expanded_terms)
337 }
338
339 async fn hierarchical_expansion(
341 &self,
342 entity_ids: &[String],
343 options: &ExpansionOptions,
344 ) -> RragResult<Vec<ExpandedTerm>> {
345 let mut expanded_terms = Vec::new();
346 let strategy_weight = self
347 .config
348 .strategy_weights
349 .get(&ExpansionStrategy::Hierarchical)
350 .copied()
351 .unwrap_or(0.8);
352
353 for entity_id in entity_ids {
354 let hierarchical_edges: Vec<_> = self
356 .graph
357 .edges
358 .values()
359 .filter(|edge| {
360 (edge.source_id == *entity_id || edge.target_id == *entity_id)
361 && matches!(edge.edge_type, super::EdgeType::Hierarchical)
362 })
363 .collect();
364
365 for edge in hierarchical_edges {
366 let related_node_id = if edge.source_id == *entity_id {
367 &edge.target_id
368 } else {
369 &edge.source_id
370 };
371
372 if let Some(related_node) = self.graph.get_node(related_node_id) {
373 let confidence = edge.confidence * strategy_weight;
374 if confidence >= options.min_confidence {
375 let expanded_term = ExpandedTerm {
376 term: related_node.label.clone(),
377 strategy: ExpansionStrategy::Hierarchical,
378 confidence,
379 weight: edge.weight * strategy_weight,
380 source_entities: vec![entity_id.clone()],
381 relationship: Some(if edge.source_id == *entity_id {
382 "parent".to_string()
383 } else {
384 "child".to_string()
385 }),
386 };
387 expanded_terms.push(expanded_term);
388 }
389 }
390 }
391 }
392
393 Ok(expanded_terms)
394 }
395
396 async fn similarity_expansion(
398 &self,
399 entity_ids: &[String],
400 options: &ExpansionOptions,
401 ) -> RragResult<Vec<ExpandedTerm>> {
402 let mut expanded_terms = Vec::new();
403 let strategy_weight = self
404 .config
405 .strategy_weights
406 .get(&ExpansionStrategy::Similarity)
407 .copied()
408 .unwrap_or(0.7);
409
410 for entity_id in entity_ids {
411 if let Some(entity_node) = self.graph.get_node(entity_id) {
412 if let Some(entity_embedding) = &entity_node.embedding {
414 for (other_id, other_node) in &self.graph.nodes {
415 if other_id == entity_id {
416 continue;
417 }
418
419 if let Some(other_embedding) = &other_node.embedding {
420 if let Ok(similarity) =
421 entity_embedding.cosine_similarity(other_embedding)
422 {
423 if similarity >= self.config.min_similarity_threshold {
424 let confidence = similarity * strategy_weight;
425 if confidence >= options.min_confidence {
426 let expanded_term = ExpandedTerm {
427 term: other_node.label.clone(),
428 strategy: ExpansionStrategy::Similarity,
429 confidence,
430 weight: similarity * strategy_weight,
431 source_entities: vec![entity_id.clone()],
432 relationship: Some(format!(
433 "similarity:{:.2}",
434 similarity
435 )),
436 };
437 expanded_terms.push(expanded_term);
438 }
439 }
440 }
441 }
442 }
443 }
444
445 let similarity_edges: Vec<_> = self
447 .graph
448 .edges
449 .values()
450 .filter(|edge| {
451 (edge.source_id == *entity_id || edge.target_id == *entity_id)
452 && matches!(edge.edge_type, super::EdgeType::Similar)
453 })
454 .collect();
455
456 for edge in similarity_edges {
457 let related_node_id = if edge.source_id == *entity_id {
458 &edge.target_id
459 } else {
460 &edge.source_id
461 };
462
463 if let Some(related_node) = self.graph.get_node(related_node_id) {
464 let confidence = edge.confidence * strategy_weight;
465 if confidence >= options.min_confidence {
466 let expanded_term = ExpandedTerm {
467 term: related_node.label.clone(),
468 strategy: ExpansionStrategy::Similarity,
469 confidence,
470 weight: edge.weight * strategy_weight,
471 source_entities: vec![entity_id.clone()],
472 relationship: Some("explicit_similarity".to_string()),
473 };
474 expanded_terms.push(expanded_term);
475 }
476 }
477 }
478 }
479 }
480
481 Ok(expanded_terms)
482 }
483
484 async fn cooccurrence_expansion(
486 &self,
487 entity_ids: &[String],
488 options: &ExpansionOptions,
489 ) -> RragResult<Vec<ExpandedTerm>> {
490 let mut expanded_terms = Vec::new();
491 let strategy_weight = self
492 .config
493 .strategy_weights
494 .get(&ExpansionStrategy::CoOccurrence)
495 .copied()
496 .unwrap_or(0.6);
497
498 for entity_id in entity_ids {
499 let cooccurrence_edges: Vec<_> = self
501 .graph
502 .edges
503 .values()
504 .filter(|edge| {
505 (edge.source_id == *entity_id || edge.target_id == *entity_id)
506 && matches!(edge.edge_type, super::EdgeType::CoOccurs)
507 })
508 .collect();
509
510 for edge in cooccurrence_edges {
511 let related_node_id = if edge.source_id == *entity_id {
512 &edge.target_id
513 } else {
514 &edge.source_id
515 };
516
517 if let Some(related_node) = self.graph.get_node(related_node_id) {
518 let confidence = edge.confidence * strategy_weight;
519 if confidence >= options.min_confidence {
520 let expanded_term = ExpandedTerm {
521 term: related_node.label.clone(),
522 strategy: ExpansionStrategy::CoOccurrence,
523 confidence,
524 weight: edge.weight * strategy_weight,
525 source_entities: vec![entity_id.clone()],
526 relationship: Some("co_occurrence".to_string()),
527 };
528 expanded_terms.push(expanded_term);
529 }
530 }
531 }
532 }
533
534 Ok(expanded_terms)
535 }
536
537 async fn entity_type_expansion(
539 &self,
540 entity_ids: &[String],
541 options: &ExpansionOptions,
542 ) -> RragResult<Vec<ExpandedTerm>> {
543 let mut expanded_terms = Vec::new();
544 let strategy_weight = self
545 .config
546 .strategy_weights
547 .get(&ExpansionStrategy::EntityType)
548 .copied()
549 .unwrap_or(0.5);
550
551 let mut entities_by_type: HashMap<String, Vec<String>> = HashMap::new();
553 for entity_id in entity_ids {
554 if let Some(entity_node) = self.graph.get_node(entity_id) {
555 let type_key = match &entity_node.node_type {
556 super::NodeType::Entity(entity_type) => entity_type.clone(),
557 super::NodeType::Concept => "Concept".to_string(),
558 super::NodeType::Document => "Document".to_string(),
559 super::NodeType::DocumentChunk => "DocumentChunk".to_string(),
560 super::NodeType::Keyword => "Keyword".to_string(),
561 super::NodeType::Custom(custom) => custom.clone(),
562 };
563
564 entities_by_type
565 .entry(type_key)
566 .or_default()
567 .push(entity_id.clone());
568 }
569 }
570
571 for (entity_type, type_entities) in entities_by_type {
573 let similar_type_nodes: Vec<_> = self
574 .graph
575 .nodes
576 .values()
577 .filter(|node| {
578 let node_type_key = match &node.node_type {
579 super::NodeType::Entity(et) => et.clone(),
580 super::NodeType::Concept => "Concept".to_string(),
581 super::NodeType::Document => "Document".to_string(),
582 super::NodeType::DocumentChunk => "DocumentChunk".to_string(),
583 super::NodeType::Keyword => "Keyword".to_string(),
584 super::NodeType::Custom(custom) => custom.clone(),
585 };
586 node_type_key == entity_type && !type_entities.contains(&node.id)
587 })
588 .collect();
589
590 for node in similar_type_nodes.into_iter().take(5) {
591 let confidence = strategy_weight * 0.5; if confidence >= options.min_confidence {
594 let expanded_term = ExpandedTerm {
595 term: node.label.clone(),
596 strategy: ExpansionStrategy::EntityType,
597 confidence,
598 weight: strategy_weight * 0.5,
599 source_entities: type_entities.clone(),
600 relationship: Some(format!("same_type:{}", entity_type)),
601 };
602 expanded_terms.push(expanded_term);
603 }
604 }
605 }
606
607 Ok(expanded_terms)
608 }
609
610 async fn path_based_expansion(
612 &self,
613 entity_ids: &[String],
614 options: &ExpansionOptions,
615 ) -> RragResult<Vec<ExpandedTerm>> {
616 let mut expanded_terms = Vec::new();
617 let strategy_weight = self
618 .config
619 .strategy_weights
620 .get(&ExpansionStrategy::PathBased)
621 .copied()
622 .unwrap_or(0.4);
623
624 for entity_id in entity_ids {
626 let traversal_config = super::algorithms::TraversalConfig {
627 max_depth: self.config.max_expansion_depth,
628 max_nodes: 50, ..Default::default()
630 };
631
632 if let Ok(visited_nodes) =
633 GraphAlgorithms::bfs_search(&self.graph, entity_id, &traversal_config)
634 {
635 for visited_node_id in visited_nodes.iter().skip(1) {
636 if let Some(visited_node) = self.graph.get_node(visited_node_id) {
638 let distance = visited_nodes
640 .iter()
641 .position(|id| id == visited_node_id)
642 .unwrap_or(0);
643 let distance_factor = 1.0 / (distance as f32 + 1.0);
644 let confidence = strategy_weight * distance_factor;
645
646 if confidence >= options.min_confidence {
647 let expanded_term = ExpandedTerm {
648 term: visited_node.label.clone(),
649 strategy: ExpansionStrategy::PathBased,
650 confidence,
651 weight: confidence,
652 source_entities: vec![entity_id.clone()],
653 relationship: Some(format!("path_distance:{}", distance)),
654 };
655 expanded_terms.push(expanded_term);
656 }
657 }
658 }
659 }
660 }
661
662 Ok(expanded_terms)
663 }
664
665 fn apply_focus_boosting(&self, terms: &mut [ExpandedTerm], focus_entities: &[String]) {
667 if focus_entities.is_empty() {
668 return;
669 }
670
671 for term in terms {
672 let is_related = term
674 .source_entities
675 .iter()
676 .any(|source| focus_entities.contains(source));
677
678 if is_related {
679 term.confidence *= 1.5;
680 term.weight *= 1.5;
681 }
682 }
683 }
684
685 fn deduplicate_and_rank(&self, terms: &mut Vec<ExpandedTerm>, max_terms: Option<usize>) {
687 let mut seen_terms: HashMap<String, usize> = HashMap::new();
689 let mut unique_terms: Vec<ExpandedTerm> = Vec::new();
690
691 for term in terms.drain(..) {
692 match seen_terms.get(&term.term) {
693 Some(&existing_index) => {
694 if term.confidence > unique_terms[existing_index].confidence {
695 unique_terms[existing_index] = term;
696 }
697 }
698 None => {
699 seen_terms.insert(term.term.clone(), unique_terms.len());
700 unique_terms.push(term);
701 }
702 }
703 }
704
705 unique_terms.sort_by(|a, b| {
707 b.weight
708 .partial_cmp(&a.weight)
709 .unwrap_or(std::cmp::Ordering::Equal)
710 .then_with(|| {
711 b.confidence
712 .partial_cmp(&a.confidence)
713 .unwrap_or(std::cmp::Ordering::Equal)
714 })
715 });
716
717 if let Some(limit) = max_terms {
719 unique_terms.truncate(limit);
720 }
721
722 *terms = unique_terms;
723 }
724}
725
726#[async_trait]
727impl QueryExpander for GraphQueryExpander {
728 async fn expand_query(
729 &self,
730 query: &str,
731 options: &ExpansionOptions,
732 ) -> RragResult<ExpansionResult> {
733 let start_time = std::time::Instant::now();
734
735 if self.config.enable_caching {
737 let cache_key = format!("{}:{:?}", query, options.strategies);
738 if let Some(cached_terms) = self.expansion_cache.read().await.get(&cache_key) {
739 let result = ExpansionResult {
740 original_query: query.to_string(),
741 expanded_terms: cached_terms
742 .iter()
743 .map(|term| ExpandedTerm {
744 term: term.clone(),
745 strategy: ExpansionStrategy::Semantic, confidence: 0.8,
747 weight: 0.8,
748 source_entities: Vec::new(),
749 relationship: None,
750 })
751 .collect(),
752 stats: ExpansionStats {
753 entities_found: 0,
754 terms_per_strategy: HashMap::new(),
755 expansion_time_ms: start_time.elapsed().as_millis() as u64,
756 nodes_examined: 0,
757 edges_examined: 0,
758 },
759 strategies_used: options.strategies.clone(),
760 confidence: 0.8,
761 };
762 return Ok(result);
763 }
764 }
765
766 let entity_ids = self.extract_query_entities(query).await;
768 let mut expanded_terms = Vec::new();
769 let mut terms_per_strategy = HashMap::new();
770 let mut nodes_examined = 0;
771 let mut edges_examined = 0;
772
773 for strategy in &options.strategies {
775 let strategy_terms = match strategy {
776 ExpansionStrategy::Semantic if self.config.enable_semantic_expansion => {
777 self.semantic_expansion(&entity_ids, options).await?
778 }
779 ExpansionStrategy::Hierarchical if self.config.enable_structural_expansion => {
780 self.hierarchical_expansion(&entity_ids, options).await?
781 }
782 ExpansionStrategy::Similarity => {
783 self.similarity_expansion(&entity_ids, options).await?
784 }
785 ExpansionStrategy::CoOccurrence if self.config.enable_statistical_expansion => {
786 self.cooccurrence_expansion(&entity_ids, options).await?
787 }
788 ExpansionStrategy::EntityType => {
789 self.entity_type_expansion(&entity_ids, options).await?
790 }
791 ExpansionStrategy::PathBased if self.config.enable_structural_expansion => {
792 self.path_based_expansion(&entity_ids, options).await?
793 }
794 _ => Vec::new(), };
796
797 terms_per_strategy.insert(strategy.to_string(), strategy_terms.len());
798 expanded_terms.extend(strategy_terms);
799
800 nodes_examined += entity_ids.len();
802 edges_examined += entity_ids.len() * 5; }
804
805 self.apply_focus_boosting(&mut expanded_terms, &options.focus_entities);
807
808 self.deduplicate_and_rank(&mut expanded_terms, options.max_terms);
810
811 if options.include_original {
813 let original_terms: Vec<_> = query
814 .split_whitespace()
815 .filter(|term| !self.config.stop_words.contains(&term.to_lowercase()))
816 .map(|term| ExpandedTerm {
817 term: term.to_string(),
818 strategy: ExpansionStrategy::Custom("original".to_string()),
819 confidence: 1.0,
820 weight: 1.0,
821 source_entities: Vec::new(),
822 relationship: Some("original_query".to_string()),
823 })
824 .collect();
825
826 expanded_terms.splice(0..0, original_terms);
827 }
828
829 let confidence = if !expanded_terms.is_empty() {
831 expanded_terms.iter().map(|t| t.confidence).sum::<f32>() / expanded_terms.len() as f32
832 } else {
833 0.0
834 };
835
836 if self.config.enable_caching {
838 let cache_key = format!("{}:{:?}", query, options.strategies);
839 let cache_terms: Vec<_> = expanded_terms.iter().map(|t| t.term.clone()).collect();
840 self.expansion_cache
841 .write()
842 .await
843 .insert(cache_key, cache_terms);
844 }
845
846 let expansion_time_ms = start_time.elapsed().as_millis() as u64;
847
848 Ok(ExpansionResult {
849 original_query: query.to_string(),
850 expanded_terms,
851 stats: ExpansionStats {
852 entities_found: entity_ids.len(),
853 terms_per_strategy,
854 expansion_time_ms,
855 nodes_examined,
856 edges_examined,
857 },
858 strategies_used: options.strategies.clone(),
859 confidence,
860 })
861 }
862
863 async fn expand_terms(
864 &self,
865 terms: &[String],
866 options: &ExpansionOptions,
867 ) -> RragResult<Vec<String>> {
868 let combined_query = terms.join(" ");
869 let expansion_result = self.expand_query(&combined_query, options).await?;
870 Ok(expansion_result
871 .expanded_terms
872 .into_iter()
873 .map(|t| t.term)
874 .collect())
875 }
876
877 async fn find_related_entities(
878 &self,
879 entities: &[String],
880 options: &ExpansionOptions,
881 ) -> RragResult<Vec<String>> {
882 let entity_ids: Vec<_> = entities
884 .iter()
885 .filter_map(|entity_name| {
886 self.graph
887 .nodes
888 .values()
889 .find(|node| node.label.eq_ignore_ascii_case(entity_name))
890 .map(|node| node.id.clone())
891 })
892 .collect();
893
894 if entity_ids.is_empty() {
895 return Ok(Vec::new());
896 }
897
898 let expanded_terms = self.semantic_expansion(&entity_ids, options).await?;
900 Ok(expanded_terms.into_iter().map(|t| t.term).collect())
901 }
902
903 async fn get_suggestions(
904 &self,
905 query: &str,
906 max_suggestions: usize,
907 ) -> RragResult<Vec<String>> {
908 let options = ExpansionOptions {
909 strategies: vec![ExpansionStrategy::Semantic, ExpansionStrategy::Similarity],
910 max_terms: Some(max_suggestions),
911 min_confidence: 0.2, ..Default::default()
913 };
914
915 let expansion_result = self.expand_query(query, &options).await?;
916 Ok(expansion_result
917 .expanded_terms
918 .into_iter()
919 .map(|t| t.term)
920 .collect())
921 }
922}
923
924impl std::fmt::Display for ExpansionStrategy {
925 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
926 match self {
927 ExpansionStrategy::Semantic => write!(f, "semantic"),
928 ExpansionStrategy::Hierarchical => write!(f, "hierarchical"),
929 ExpansionStrategy::Similarity => write!(f, "similarity"),
930 ExpansionStrategy::CoOccurrence => write!(f, "co_occurrence"),
931 ExpansionStrategy::Synonym => write!(f, "synonym"),
932 ExpansionStrategy::EntityType => write!(f, "entity_type"),
933 ExpansionStrategy::PathBased => write!(f, "path_based"),
934 ExpansionStrategy::PageRank => write!(f, "pagerank"),
935 ExpansionStrategy::Custom(name) => write!(f, "custom_{}", name),
936 }
937 }
938}
939
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{EdgeType, GraphEdge, GraphNode, NodeType};

    /// Builds a four-concept graph connected by three semantic edges:
    /// deep learning -is_a-> machine learning -part_of-> artificial
    /// intelligence, and neural networks -used_in-> deep learning.
    fn create_test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();

        let machine_learning = GraphNode::new("machine learning", NodeType::Concept);
        let artificial_intelligence = GraphNode::new("artificial intelligence", NodeType::Concept);
        let deep_learning = GraphNode::new("deep learning", NodeType::Concept);
        let neural_networks = GraphNode::new("neural networks", NodeType::Concept);

        let ml_id = machine_learning.id.clone();
        let ai_id = artificial_intelligence.id.clone();
        let dl_id = deep_learning.id.clone();
        let nn_id = neural_networks.id.clone();

        for node in vec![
            machine_learning,
            artificial_intelligence,
            deep_learning,
            neural_networks,
        ] {
            graph.add_node(node).unwrap();
        }

        // (source, target, relation, confidence, weight)
        let semantic_links = vec![
            (dl_id.clone(), ml_id.clone(), "is_a", 0.9, 0.9),
            (ml_id, ai_id, "part_of", 0.8, 0.8),
            (nn_id, dl_id, "used_in", 0.7, 0.7),
        ];

        for (source, target, relation, confidence, weight) in semantic_links {
            graph
                .add_edge(
                    GraphEdge::new(
                        source,
                        target,
                        relation,
                        EdgeType::Semantic(relation.to_string()),
                    )
                    .with_confidence(confidence)
                    .with_weight(weight),
                )
                .unwrap();
        }

        graph
    }

    /// A query naming a graph entity yields terms, entities and confidence.
    #[tokio::test]
    async fn test_query_expansion() {
        let expander = GraphQueryExpander::new(create_test_graph(), ExpansionConfig::default());

        let options = ExpansionOptions {
            strategies: vec![ExpansionStrategy::Semantic],
            max_terms: Some(5),
            min_confidence: 0.3,
            ..Default::default()
        };

        let result = expander
            .expand_query("machine learning", &options)
            .await
            .unwrap();

        assert!(!result.expanded_terms.is_empty());
        assert!(result.stats.entities_found > 0);
        assert!(result.confidence > 0.0);
    }

    /// Semantic expansion from "machine learning" reaches one of its
    /// semantic neighbours.
    #[tokio::test]
    async fn test_semantic_expansion() {
        let graph = create_test_graph();
        let expander = GraphQueryExpander::new(graph.clone(), ExpansionConfig::default());

        let ml_node_id = graph
            .nodes
            .values()
            .find(|node| node.label == "machine learning")
            .unwrap()
            .id
            .clone();

        let options = ExpansionOptions::default();
        let expanded_terms = expander
            .semantic_expansion(&[ml_node_id], &options)
            .await
            .unwrap();

        assert!(!expanded_terms.is_empty());

        let found_neighbour = expanded_terms
            .iter()
            .any(|term| term.term == "artificial intelligence" || term.term == "deep learning");
        assert!(found_neighbour);
    }

    /// Expanding a term list returns at least one AI-related term.
    #[tokio::test]
    async fn test_term_expansion() {
        let expander = GraphQueryExpander::new(create_test_graph(), ExpansionConfig::default());

        let options = ExpansionOptions::default();
        let expanded_terms = expander
            .expand_terms(&["machine learning".to_string()], &options)
            .await
            .unwrap();

        assert!(!expanded_terms.is_empty());

        let has_ai_terms = expanded_terms.iter().any(|term| {
            ["artificial", "deep", "neural"]
                .iter()
                .any(|needle| term.contains(needle))
        });
        assert!(has_ai_terms);
    }

    /// Suggestions are non-empty and respect the requested maximum.
    #[tokio::test]
    async fn test_get_suggestions() {
        let expander = GraphQueryExpander::new(create_test_graph(), ExpansionConfig::default());

        let suggestions = expander.get_suggestions("machine", 3).await.unwrap();

        assert!(!suggestions.is_empty());
        assert!(suggestions.len() <= 3);
    }
}