// rexis_rag/graph_retrieval/query_expansion.rs

//! # Query Expansion Using Graph Structure
//!
//! Leverages knowledge-graph structure to expand and enhance queries for improved retrieval.
4
5use super::{algorithms::GraphAlgorithms, KnowledgeGraph};
6use crate::RragResult;
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
11/// Query expander trait for different expansion strategies
12#[async_trait]
13pub trait QueryExpander: Send + Sync {
14    /// Expand a text query using the knowledge graph
15    async fn expand_query(
16        &self,
17        query: &str,
18        options: &ExpansionOptions,
19    ) -> RragResult<ExpansionResult>;
20
21    /// Expand query terms using graph structure
22    async fn expand_terms(
23        &self,
24        terms: &[String],
25        options: &ExpansionOptions,
26    ) -> RragResult<Vec<String>>;
27
28    /// Find related entities for query expansion
29    async fn find_related_entities(
30        &self,
31        entities: &[String],
32        options: &ExpansionOptions,
33    ) -> RragResult<Vec<String>>;
34
35    /// Get expansion suggestions for a query
36    async fn get_suggestions(&self, query: &str, max_suggestions: usize)
37        -> RragResult<Vec<String>>;
38}
39
40/// Graph-based query expander
41pub struct GraphQueryExpander {
42    /// Knowledge graph
43    graph: KnowledgeGraph,
44
45    /// Expansion configuration
46    config: ExpansionConfig,
47
48    /// Pre-computed expansion cache
49    expansion_cache: tokio::sync::RwLock<HashMap<String, Vec<String>>>,
50}
51
52/// Query expansion configuration
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct ExpansionConfig {
55    /// Maximum expansion depth in the graph
56    pub max_expansion_depth: usize,
57
58    /// Maximum number of expanded terms
59    pub max_expansion_terms: usize,
60
61    /// Minimum similarity threshold for term expansion
62    pub min_similarity_threshold: f32,
63
64    /// Weights for different expansion strategies
65    pub strategy_weights: HashMap<ExpansionStrategy, f32>,
66
67    /// Enable semantic expansion using embeddings
68    pub enable_semantic_expansion: bool,
69
70    /// Enable structural expansion using graph topology
71    pub enable_structural_expansion: bool,
72
73    /// Enable statistical expansion using co-occurrence
74    pub enable_statistical_expansion: bool,
75
76    /// Cache expansion results
77    pub enable_caching: bool,
78
79    /// Stop words to avoid in expansion
80    pub stop_words: HashSet<String>,
81}
82
83/// Expansion options for individual queries
84#[derive(Debug, Clone)]
85pub struct ExpansionOptions {
86    /// Specific expansion strategies to use
87    pub strategies: Vec<ExpansionStrategy>,
88
89    /// Maximum number of terms to add
90    pub max_terms: Option<usize>,
91
92    /// Minimum confidence for expanded terms
93    pub min_confidence: f32,
94
95    /// Focus entities (boost terms related to these)
96    pub focus_entities: Vec<String>,
97
98    /// Context for expansion (document domain, etc.)
99    pub context: Option<String>,
100
101    /// Whether to include original query terms
102    pub include_original: bool,
103}
104
105/// Query expansion strategies
106#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
107pub enum ExpansionStrategy {
108    /// Semantic expansion using entity relationships
109    Semantic,
110
111    /// Hierarchical expansion (parent/child concepts)
112    Hierarchical,
113
114    /// Similarity-based expansion
115    Similarity,
116
117    /// Co-occurrence based expansion
118    CoOccurrence,
119
120    /// Synonym expansion
121    Synonym,
122
123    /// Entity type expansion
124    EntityType,
125
126    /// Path-based expansion (following graph paths)
127    PathBased,
128
129    /// PageRank-based expansion (importance-weighted)
130    PageRank,
131
132    /// Custom expansion strategy
133    Custom(String),
134}
135
136/// Expansion result
137#[derive(Debug, Clone, Serialize, Deserialize)]
138pub struct ExpansionResult {
139    /// Original query
140    pub original_query: String,
141
142    /// Expanded query terms
143    pub expanded_terms: Vec<ExpandedTerm>,
144
145    /// Expansion statistics
146    pub stats: ExpansionStats,
147
148    /// Used expansion strategies
149    pub strategies_used: Vec<ExpansionStrategy>,
150
151    /// Expansion confidence score
152    pub confidence: f32,
153}
154
155/// Expanded term with metadata
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct ExpandedTerm {
158    /// The expanded term
159    pub term: String,
160
161    /// Expansion strategy that generated this term
162    pub strategy: ExpansionStrategy,
163
164    /// Confidence score
165    pub confidence: f32,
166
167    /// Weight/importance score
168    pub weight: f32,
169
170    /// Source entities that led to this expansion
171    pub source_entities: Vec<String>,
172
173    /// Semantic relationship to original query
174    pub relationship: Option<String>,
175}
176
177/// Expansion statistics
178#[derive(Debug, Clone, Serialize, Deserialize)]
179pub struct ExpansionStats {
180    /// Number of entities found in original query
181    pub entities_found: usize,
182
183    /// Number of terms added per strategy
184    pub terms_per_strategy: HashMap<String, usize>,
185
186    /// Expansion time in milliseconds
187    pub expansion_time_ms: u64,
188
189    /// Graph nodes examined
190    pub nodes_examined: usize,
191
192    /// Graph edges examined  
193    pub edges_examined: usize,
194}
195
196impl Default for ExpansionConfig {
197    fn default() -> Self {
198        let mut strategy_weights = HashMap::new();
199        strategy_weights.insert(ExpansionStrategy::Semantic, 1.0);
200        strategy_weights.insert(ExpansionStrategy::Hierarchical, 0.8);
201        strategy_weights.insert(ExpansionStrategy::Similarity, 0.7);
202        strategy_weights.insert(ExpansionStrategy::CoOccurrence, 0.6);
203        strategy_weights.insert(ExpansionStrategy::EntityType, 0.5);
204        strategy_weights.insert(ExpansionStrategy::PathBased, 0.4);
205
206        let mut stop_words = HashSet::new();
207        stop_words.extend(
208            vec![
209                "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with",
210                "by", "from", "up", "about", "into", "through", "during", "before", "after",
211                "above", "below", "between", "among", "this", "that",
212            ]
213            .into_iter()
214            .map(|s| s.to_string()),
215        );
216
217        Self {
218            max_expansion_depth: 2,
219            max_expansion_terms: 20,
220            min_similarity_threshold: 0.3,
221            strategy_weights,
222            enable_semantic_expansion: true,
223            enable_structural_expansion: true,
224            enable_statistical_expansion: true,
225            enable_caching: true,
226            stop_words,
227        }
228    }
229}
230
231impl Default for ExpansionOptions {
232    fn default() -> Self {
233        Self {
234            strategies: vec![
235                ExpansionStrategy::Semantic,
236                ExpansionStrategy::Similarity,
237                ExpansionStrategy::CoOccurrence,
238            ],
239            max_terms: Some(10),
240            min_confidence: 0.3,
241            focus_entities: Vec::new(),
242            context: None,
243            include_original: true,
244        }
245    }
246}
247
248impl GraphQueryExpander {
249    /// Create a new graph query expander
250    pub fn new(graph: KnowledgeGraph, config: ExpansionConfig) -> Self {
251        Self {
252            graph,
253            config,
254            expansion_cache: tokio::sync::RwLock::new(HashMap::new()),
255        }
256    }
257
258    /// Update the knowledge graph
259    pub async fn update_graph(&mut self, graph: KnowledgeGraph) {
260        self.graph = graph;
261        // Clear cache when graph changes
262        if self.config.enable_caching {
263            self.expansion_cache.write().await.clear();
264        }
265    }
266
267    /// Extract entities from query text using graph nodes
268    async fn extract_query_entities(&self, query: &str) -> Vec<String> {
269        let mut entities = Vec::new();
270        let query_lower = query.to_lowercase();
271
272        // Find matching node labels in the query
273        for (_, node) in &self.graph.nodes {
274            let label_lower = node.label.to_lowercase();
275            if query_lower.contains(&label_lower) && !self.config.stop_words.contains(&label_lower)
276            {
277                entities.push(node.id.clone());
278            }
279        }
280
281        entities
282    }
283
284    /// Expand using semantic relationships
285    async fn semantic_expansion(
286        &self,
287        entity_ids: &[String],
288        options: &ExpansionOptions,
289    ) -> RragResult<Vec<ExpandedTerm>> {
290        let mut expanded_terms = Vec::new();
291        let strategy_weight = self
292            .config
293            .strategy_weights
294            .get(&ExpansionStrategy::Semantic)
295            .copied()
296            .unwrap_or(1.0);
297
298        for entity_id in entity_ids {
299            if let Some(_entity_node) = self.graph.get_node(entity_id) {
300                // Find semantic relationships
301                let semantic_edges: Vec<_> = self
302                    .graph
303                    .edges
304                    .values()
305                    .filter(|edge| {
306                        (edge.source_id == *entity_id || edge.target_id == *entity_id)
307                            && matches!(edge.edge_type, super::EdgeType::Semantic(_))
308                    })
309                    .collect();
310
311                for edge in semantic_edges {
312                    let related_node_id = if edge.source_id == *entity_id {
313                        &edge.target_id
314                    } else {
315                        &edge.source_id
316                    };
317
318                    if let Some(related_node) = self.graph.get_node(related_node_id) {
319                        let confidence = edge.confidence * strategy_weight;
320                        if confidence >= options.min_confidence {
321                            let expanded_term = ExpandedTerm {
322                                term: related_node.label.clone(),
323                                strategy: ExpansionStrategy::Semantic,
324                                confidence,
325                                weight: edge.weight * strategy_weight,
326                                source_entities: vec![entity_id.clone()],
327                                relationship: Some(edge.label.clone()),
328                            };
329                            expanded_terms.push(expanded_term);
330                        }
331                    }
332                }
333            }
334        }
335
336        Ok(expanded_terms)
337    }
338
339    /// Expand using hierarchical relationships
340    async fn hierarchical_expansion(
341        &self,
342        entity_ids: &[String],
343        options: &ExpansionOptions,
344    ) -> RragResult<Vec<ExpandedTerm>> {
345        let mut expanded_terms = Vec::new();
346        let strategy_weight = self
347            .config
348            .strategy_weights
349            .get(&ExpansionStrategy::Hierarchical)
350            .copied()
351            .unwrap_or(0.8);
352
353        for entity_id in entity_ids {
354            // Find hierarchical edges (parent/child relationships)
355            let hierarchical_edges: Vec<_> = self
356                .graph
357                .edges
358                .values()
359                .filter(|edge| {
360                    (edge.source_id == *entity_id || edge.target_id == *entity_id)
361                        && matches!(edge.edge_type, super::EdgeType::Hierarchical)
362                })
363                .collect();
364
365            for edge in hierarchical_edges {
366                let related_node_id = if edge.source_id == *entity_id {
367                    &edge.target_id
368                } else {
369                    &edge.source_id
370                };
371
372                if let Some(related_node) = self.graph.get_node(related_node_id) {
373                    let confidence = edge.confidence * strategy_weight;
374                    if confidence >= options.min_confidence {
375                        let expanded_term = ExpandedTerm {
376                            term: related_node.label.clone(),
377                            strategy: ExpansionStrategy::Hierarchical,
378                            confidence,
379                            weight: edge.weight * strategy_weight,
380                            source_entities: vec![entity_id.clone()],
381                            relationship: Some(if edge.source_id == *entity_id {
382                                "parent".to_string()
383                            } else {
384                                "child".to_string()
385                            }),
386                        };
387                        expanded_terms.push(expanded_term);
388                    }
389                }
390            }
391        }
392
393        Ok(expanded_terms)
394    }
395
396    /// Expand using similarity relationships
397    async fn similarity_expansion(
398        &self,
399        entity_ids: &[String],
400        options: &ExpansionOptions,
401    ) -> RragResult<Vec<ExpandedTerm>> {
402        let mut expanded_terms = Vec::new();
403        let strategy_weight = self
404            .config
405            .strategy_weights
406            .get(&ExpansionStrategy::Similarity)
407            .copied()
408            .unwrap_or(0.7);
409
410        for entity_id in entity_ids {
411            if let Some(entity_node) = self.graph.get_node(entity_id) {
412                // If entity has embedding, find similar nodes by embedding similarity
413                if let Some(entity_embedding) = &entity_node.embedding {
414                    for (other_id, other_node) in &self.graph.nodes {
415                        if other_id == entity_id {
416                            continue;
417                        }
418
419                        if let Some(other_embedding) = &other_node.embedding {
420                            if let Ok(similarity) =
421                                entity_embedding.cosine_similarity(other_embedding)
422                            {
423                                if similarity >= self.config.min_similarity_threshold {
424                                    let confidence = similarity * strategy_weight;
425                                    if confidence >= options.min_confidence {
426                                        let expanded_term = ExpandedTerm {
427                                            term: other_node.label.clone(),
428                                            strategy: ExpansionStrategy::Similarity,
429                                            confidence,
430                                            weight: similarity * strategy_weight,
431                                            source_entities: vec![entity_id.clone()],
432                                            relationship: Some(format!(
433                                                "similarity:{:.2}",
434                                                similarity
435                                            )),
436                                        };
437                                        expanded_terms.push(expanded_term);
438                                    }
439                                }
440                            }
441                        }
442                    }
443                }
444
445                // Also find explicit similarity edges
446                let similarity_edges: Vec<_> = self
447                    .graph
448                    .edges
449                    .values()
450                    .filter(|edge| {
451                        (edge.source_id == *entity_id || edge.target_id == *entity_id)
452                            && matches!(edge.edge_type, super::EdgeType::Similar)
453                    })
454                    .collect();
455
456                for edge in similarity_edges {
457                    let related_node_id = if edge.source_id == *entity_id {
458                        &edge.target_id
459                    } else {
460                        &edge.source_id
461                    };
462
463                    if let Some(related_node) = self.graph.get_node(related_node_id) {
464                        let confidence = edge.confidence * strategy_weight;
465                        if confidence >= options.min_confidence {
466                            let expanded_term = ExpandedTerm {
467                                term: related_node.label.clone(),
468                                strategy: ExpansionStrategy::Similarity,
469                                confidence,
470                                weight: edge.weight * strategy_weight,
471                                source_entities: vec![entity_id.clone()],
472                                relationship: Some("explicit_similarity".to_string()),
473                            };
474                            expanded_terms.push(expanded_term);
475                        }
476                    }
477                }
478            }
479        }
480
481        Ok(expanded_terms)
482    }
483
484    /// Expand using co-occurrence relationships
485    async fn cooccurrence_expansion(
486        &self,
487        entity_ids: &[String],
488        options: &ExpansionOptions,
489    ) -> RragResult<Vec<ExpandedTerm>> {
490        let mut expanded_terms = Vec::new();
491        let strategy_weight = self
492            .config
493            .strategy_weights
494            .get(&ExpansionStrategy::CoOccurrence)
495            .copied()
496            .unwrap_or(0.6);
497
498        for entity_id in entity_ids {
499            // Find co-occurrence edges
500            let cooccurrence_edges: Vec<_> = self
501                .graph
502                .edges
503                .values()
504                .filter(|edge| {
505                    (edge.source_id == *entity_id || edge.target_id == *entity_id)
506                        && matches!(edge.edge_type, super::EdgeType::CoOccurs)
507                })
508                .collect();
509
510            for edge in cooccurrence_edges {
511                let related_node_id = if edge.source_id == *entity_id {
512                    &edge.target_id
513                } else {
514                    &edge.source_id
515                };
516
517                if let Some(related_node) = self.graph.get_node(related_node_id) {
518                    let confidence = edge.confidence * strategy_weight;
519                    if confidence >= options.min_confidence {
520                        let expanded_term = ExpandedTerm {
521                            term: related_node.label.clone(),
522                            strategy: ExpansionStrategy::CoOccurrence,
523                            confidence,
524                            weight: edge.weight * strategy_weight,
525                            source_entities: vec![entity_id.clone()],
526                            relationship: Some("co_occurrence".to_string()),
527                        };
528                        expanded_terms.push(expanded_term);
529                    }
530                }
531            }
532        }
533
534        Ok(expanded_terms)
535    }
536
537    /// Expand using entity type relationships
538    async fn entity_type_expansion(
539        &self,
540        entity_ids: &[String],
541        options: &ExpansionOptions,
542    ) -> RragResult<Vec<ExpandedTerm>> {
543        let mut expanded_terms = Vec::new();
544        let strategy_weight = self
545            .config
546            .strategy_weights
547            .get(&ExpansionStrategy::EntityType)
548            .copied()
549            .unwrap_or(0.5);
550
551        // Group entities by type
552        let mut entities_by_type: HashMap<String, Vec<String>> = HashMap::new();
553        for entity_id in entity_ids {
554            if let Some(entity_node) = self.graph.get_node(entity_id) {
555                let type_key = match &entity_node.node_type {
556                    super::NodeType::Entity(entity_type) => entity_type.clone(),
557                    super::NodeType::Concept => "Concept".to_string(),
558                    super::NodeType::Document => "Document".to_string(),
559                    super::NodeType::DocumentChunk => "DocumentChunk".to_string(),
560                    super::NodeType::Keyword => "Keyword".to_string(),
561                    super::NodeType::Custom(custom) => custom.clone(),
562                };
563
564                entities_by_type
565                    .entry(type_key)
566                    .or_default()
567                    .push(entity_id.clone());
568            }
569        }
570
571        // For each type, find other entities of the same type
572        for (entity_type, type_entities) in entities_by_type {
573            let similar_type_nodes: Vec<_> = self
574                .graph
575                .nodes
576                .values()
577                .filter(|node| {
578                    let node_type_key = match &node.node_type {
579                        super::NodeType::Entity(et) => et.clone(),
580                        super::NodeType::Concept => "Concept".to_string(),
581                        super::NodeType::Document => "Document".to_string(),
582                        super::NodeType::DocumentChunk => "DocumentChunk".to_string(),
583                        super::NodeType::Keyword => "Keyword".to_string(),
584                        super::NodeType::Custom(custom) => custom.clone(),
585                    };
586                    node_type_key == entity_type && !type_entities.contains(&node.id)
587                })
588                .collect();
589
590            for node in similar_type_nodes.into_iter().take(5) {
591                // Limit to avoid too many results
592                let confidence = strategy_weight * 0.5; // Lower confidence for type-based expansion
593                if confidence >= options.min_confidence {
594                    let expanded_term = ExpandedTerm {
595                        term: node.label.clone(),
596                        strategy: ExpansionStrategy::EntityType,
597                        confidence,
598                        weight: strategy_weight * 0.5,
599                        source_entities: type_entities.clone(),
600                        relationship: Some(format!("same_type:{}", entity_type)),
601                    };
602                    expanded_terms.push(expanded_term);
603                }
604            }
605        }
606
607        Ok(expanded_terms)
608    }
609
610    /// Expand using graph paths
611    async fn path_based_expansion(
612        &self,
613        entity_ids: &[String],
614        options: &ExpansionOptions,
615    ) -> RragResult<Vec<ExpandedTerm>> {
616        let mut expanded_terms = Vec::new();
617        let strategy_weight = self
618            .config
619            .strategy_weights
620            .get(&ExpansionStrategy::PathBased)
621            .copied()
622            .unwrap_or(0.4);
623
624        // Use BFS to find nodes within expansion depth
625        for entity_id in entity_ids {
626            let traversal_config = super::algorithms::TraversalConfig {
627                max_depth: self.config.max_expansion_depth,
628                max_nodes: 50, // Limit to avoid performance issues
629                ..Default::default()
630            };
631
632            if let Ok(visited_nodes) =
633                GraphAlgorithms::bfs_search(&self.graph, entity_id, &traversal_config)
634            {
635                for visited_node_id in visited_nodes.iter().skip(1) {
636                    // Skip the source node
637                    if let Some(visited_node) = self.graph.get_node(visited_node_id) {
638                        // Calculate confidence based on distance from source
639                        let distance = visited_nodes
640                            .iter()
641                            .position(|id| id == visited_node_id)
642                            .unwrap_or(0);
643                        let distance_factor = 1.0 / (distance as f32 + 1.0);
644                        let confidence = strategy_weight * distance_factor;
645
646                        if confidence >= options.min_confidence {
647                            let expanded_term = ExpandedTerm {
648                                term: visited_node.label.clone(),
649                                strategy: ExpansionStrategy::PathBased,
650                                confidence,
651                                weight: confidence,
652                                source_entities: vec![entity_id.clone()],
653                                relationship: Some(format!("path_distance:{}", distance)),
654                            };
655                            expanded_terms.push(expanded_term);
656                        }
657                    }
658                }
659            }
660        }
661
662        Ok(expanded_terms)
663    }
664
665    /// Apply focus entity boosting
666    fn apply_focus_boosting(&self, terms: &mut [ExpandedTerm], focus_entities: &[String]) {
667        if focus_entities.is_empty() {
668            return;
669        }
670
671        for term in terms {
672            // Boost terms that are related to focus entities
673            let is_related = term
674                .source_entities
675                .iter()
676                .any(|source| focus_entities.contains(source));
677
678            if is_related {
679                term.confidence *= 1.5;
680                term.weight *= 1.5;
681            }
682        }
683    }
684
685    /// Deduplicate and rank expanded terms
686    fn deduplicate_and_rank(&self, terms: &mut Vec<ExpandedTerm>, max_terms: Option<usize>) {
687        // Remove duplicates by term text, keeping the one with highest confidence
688        let mut seen_terms: HashMap<String, usize> = HashMap::new();
689        let mut unique_terms: Vec<ExpandedTerm> = Vec::new();
690
691        for term in terms.drain(..) {
692            match seen_terms.get(&term.term) {
693                Some(&existing_index) => {
694                    if term.confidence > unique_terms[existing_index].confidence {
695                        unique_terms[existing_index] = term;
696                    }
697                }
698                None => {
699                    seen_terms.insert(term.term.clone(), unique_terms.len());
700                    unique_terms.push(term);
701                }
702            }
703        }
704
705        // Sort by weight (descending) then by confidence
706        unique_terms.sort_by(|a, b| {
707            b.weight
708                .partial_cmp(&a.weight)
709                .unwrap_or(std::cmp::Ordering::Equal)
710                .then_with(|| {
711                    b.confidence
712                        .partial_cmp(&a.confidence)
713                        .unwrap_or(std::cmp::Ordering::Equal)
714                })
715        });
716
717        // Limit results
718        if let Some(limit) = max_terms {
719            unique_terms.truncate(limit);
720        }
721
722        *terms = unique_terms;
723    }
724}
725
726#[async_trait]
727impl QueryExpander for GraphQueryExpander {
728    async fn expand_query(
729        &self,
730        query: &str,
731        options: &ExpansionOptions,
732    ) -> RragResult<ExpansionResult> {
733        let start_time = std::time::Instant::now();
734
735        // Check cache first
736        if self.config.enable_caching {
737            let cache_key = format!("{}:{:?}", query, options.strategies);
738            if let Some(cached_terms) = self.expansion_cache.read().await.get(&cache_key) {
739                let result = ExpansionResult {
740                    original_query: query.to_string(),
741                    expanded_terms: cached_terms
742                        .iter()
743                        .map(|term| ExpandedTerm {
744                            term: term.clone(),
745                            strategy: ExpansionStrategy::Semantic, // Default for cached results
746                            confidence: 0.8,
747                            weight: 0.8,
748                            source_entities: Vec::new(),
749                            relationship: None,
750                        })
751                        .collect(),
752                    stats: ExpansionStats {
753                        entities_found: 0,
754                        terms_per_strategy: HashMap::new(),
755                        expansion_time_ms: start_time.elapsed().as_millis() as u64,
756                        nodes_examined: 0,
757                        edges_examined: 0,
758                    },
759                    strategies_used: options.strategies.clone(),
760                    confidence: 0.8,
761                };
762                return Ok(result);
763            }
764        }
765
766        // Extract entities from query
767        let entity_ids = self.extract_query_entities(query).await;
768        let mut expanded_terms = Vec::new();
769        let mut terms_per_strategy = HashMap::new();
770        let mut nodes_examined = 0;
771        let mut edges_examined = 0;
772
773        // Apply expansion strategies
774        for strategy in &options.strategies {
775            let strategy_terms = match strategy {
776                ExpansionStrategy::Semantic if self.config.enable_semantic_expansion => {
777                    self.semantic_expansion(&entity_ids, options).await?
778                }
779                ExpansionStrategy::Hierarchical if self.config.enable_structural_expansion => {
780                    self.hierarchical_expansion(&entity_ids, options).await?
781                }
782                ExpansionStrategy::Similarity => {
783                    self.similarity_expansion(&entity_ids, options).await?
784                }
785                ExpansionStrategy::CoOccurrence if self.config.enable_statistical_expansion => {
786                    self.cooccurrence_expansion(&entity_ids, options).await?
787                }
788                ExpansionStrategy::EntityType => {
789                    self.entity_type_expansion(&entity_ids, options).await?
790                }
791                ExpansionStrategy::PathBased if self.config.enable_structural_expansion => {
792                    self.path_based_expansion(&entity_ids, options).await?
793                }
794                _ => Vec::new(), // Strategy not enabled or supported
795            };
796
797            terms_per_strategy.insert(strategy.to_string(), strategy_terms.len());
798            expanded_terms.extend(strategy_terms);
799
800            // Update examination counters (simplified)
801            nodes_examined += entity_ids.len();
802            edges_examined += entity_ids.len() * 5; // Rough estimate
803        }
804
805        // Apply focus entity boosting if specified
806        self.apply_focus_boosting(&mut expanded_terms, &options.focus_entities);
807
808        // Deduplicate and rank
809        self.deduplicate_and_rank(&mut expanded_terms, options.max_terms);
810
811        // Add original query terms if requested
812        if options.include_original {
813            let original_terms: Vec<_> = query
814                .split_whitespace()
815                .filter(|term| !self.config.stop_words.contains(&term.to_lowercase()))
816                .map(|term| ExpandedTerm {
817                    term: term.to_string(),
818                    strategy: ExpansionStrategy::Custom("original".to_string()),
819                    confidence: 1.0,
820                    weight: 1.0,
821                    source_entities: Vec::new(),
822                    relationship: Some("original_query".to_string()),
823                })
824                .collect();
825
826            expanded_terms.splice(0..0, original_terms);
827        }
828
829        // Calculate overall confidence
830        let confidence = if !expanded_terms.is_empty() {
831            expanded_terms.iter().map(|t| t.confidence).sum::<f32>() / expanded_terms.len() as f32
832        } else {
833            0.0
834        };
835
836        // Cache results if enabled
837        if self.config.enable_caching {
838            let cache_key = format!("{}:{:?}", query, options.strategies);
839            let cache_terms: Vec<_> = expanded_terms.iter().map(|t| t.term.clone()).collect();
840            self.expansion_cache
841                .write()
842                .await
843                .insert(cache_key, cache_terms);
844        }
845
846        let expansion_time_ms = start_time.elapsed().as_millis() as u64;
847
848        Ok(ExpansionResult {
849            original_query: query.to_string(),
850            expanded_terms,
851            stats: ExpansionStats {
852                entities_found: entity_ids.len(),
853                terms_per_strategy,
854                expansion_time_ms,
855                nodes_examined,
856                edges_examined,
857            },
858            strategies_used: options.strategies.clone(),
859            confidence,
860        })
861    }
862
863    async fn expand_terms(
864        &self,
865        terms: &[String],
866        options: &ExpansionOptions,
867    ) -> RragResult<Vec<String>> {
868        let combined_query = terms.join(" ");
869        let expansion_result = self.expand_query(&combined_query, options).await?;
870        Ok(expansion_result
871            .expanded_terms
872            .into_iter()
873            .map(|t| t.term)
874            .collect())
875    }
876
877    async fn find_related_entities(
878        &self,
879        entities: &[String],
880        options: &ExpansionOptions,
881    ) -> RragResult<Vec<String>> {
882        // Find entity IDs matching the given entity names
883        let entity_ids: Vec<_> = entities
884            .iter()
885            .filter_map(|entity_name| {
886                self.graph
887                    .nodes
888                    .values()
889                    .find(|node| node.label.eq_ignore_ascii_case(entity_name))
890                    .map(|node| node.id.clone())
891            })
892            .collect();
893
894        if entity_ids.is_empty() {
895            return Ok(Vec::new());
896        }
897
898        // Use semantic expansion to find related entities
899        let expanded_terms = self.semantic_expansion(&entity_ids, options).await?;
900        Ok(expanded_terms.into_iter().map(|t| t.term).collect())
901    }
902
903    async fn get_suggestions(
904        &self,
905        query: &str,
906        max_suggestions: usize,
907    ) -> RragResult<Vec<String>> {
908        let options = ExpansionOptions {
909            strategies: vec![ExpansionStrategy::Semantic, ExpansionStrategy::Similarity],
910            max_terms: Some(max_suggestions),
911            min_confidence: 0.2, // Lower threshold for suggestions
912            ..Default::default()
913        };
914
915        let expansion_result = self.expand_query(query, &options).await?;
916        Ok(expansion_result
917            .expanded_terms
918            .into_iter()
919            .map(|t| t.term)
920            .collect())
921    }
922}
923
924impl std::fmt::Display for ExpansionStrategy {
925    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
926        match self {
927            ExpansionStrategy::Semantic => write!(f, "semantic"),
928            ExpansionStrategy::Hierarchical => write!(f, "hierarchical"),
929            ExpansionStrategy::Similarity => write!(f, "similarity"),
930            ExpansionStrategy::CoOccurrence => write!(f, "co_occurrence"),
931            ExpansionStrategy::Synonym => write!(f, "synonym"),
932            ExpansionStrategy::EntityType => write!(f, "entity_type"),
933            ExpansionStrategy::PathBased => write!(f, "path_based"),
934            ExpansionStrategy::PageRank => write!(f, "pagerank"),
935            ExpansionStrategy::Custom(name) => write!(f, "custom_{}", name),
936        }
937    }
938}
939
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph_retrieval::{EdgeType, GraphEdge, GraphNode, NodeType};

    /// Builds a four-concept graph joined by semantic edges:
    /// deep learning -is_a-> machine learning -part_of-> artificial intelligence,
    /// and neural networks -used_in-> deep learning.
    fn create_test_graph() -> KnowledgeGraph {
        let mut graph = KnowledgeGraph::new();

        // The indices in `relations` below refer to positions in this list.
        let labels = [
            "machine learning",
            "artificial intelligence",
            "deep learning",
            "neural networks",
        ];
        let mut ids = Vec::with_capacity(labels.len());
        for label in labels {
            let node = GraphNode::new(label, NodeType::Concept);
            ids.push(node.id.clone());
            graph.add_node(node).unwrap();
        }

        // (source index, target index, relation, confidence and weight)
        let relations = [
            (2, 0, "is_a", 0.9),
            (0, 1, "part_of", 0.8),
            (3, 2, "used_in", 0.7),
        ];
        for (source, target, relation, strength) in relations {
            let edge = GraphEdge::new(
                ids[source].clone(),
                ids[target].clone(),
                relation,
                EdgeType::Semantic(relation.to_string()),
            )
            .with_confidence(strength)
            .with_weight(strength);
            graph.add_edge(edge).unwrap();
        }

        graph
    }

    #[tokio::test]
    async fn test_query_expansion() {
        let expander = GraphQueryExpander::new(create_test_graph(), ExpansionConfig::default());
        let options = ExpansionOptions {
            strategies: vec![ExpansionStrategy::Semantic],
            max_terms: Some(5),
            min_confidence: 0.3,
            ..Default::default()
        };

        let result = expander
            .expand_query("machine learning", &options)
            .await
            .unwrap();

        assert!(!result.expanded_terms.is_empty());
        assert!(result.stats.entities_found > 0);
        assert!(result.confidence > 0.0);
    }

    #[tokio::test]
    async fn test_semantic_expansion() {
        let graph = create_test_graph();

        // Seed the expansion from the "machine learning" node.
        let ml_node_id = graph
            .nodes
            .values()
            .find(|node| node.label == "machine learning")
            .map(|node| node.id.clone())
            .expect("test graph contains a machine learning node");

        let expander = GraphQueryExpander::new(graph, ExpansionConfig::default());
        let options = ExpansionOptions::default();
        let expanded_terms = expander
            .semantic_expansion(&[ml_node_id], &options)
            .await
            .unwrap();

        assert!(!expanded_terms.is_empty());

        // The seed's neighbours are "artificial intelligence" (outgoing part_of)
        // and "deep learning" (incoming is_a); at least one must be reached.
        let reached_neighbour = expanded_terms
            .iter()
            .any(|t| t.term == "artificial intelligence" || t.term == "deep learning");
        assert!(reached_neighbour);
    }

    #[tokio::test]
    async fn test_term_expansion() {
        let expander = GraphQueryExpander::new(create_test_graph(), ExpansionConfig::default());
        let options = ExpansionOptions::default();

        let expanded_terms = expander
            .expand_terms(&["machine learning".to_string()], &options)
            .await
            .unwrap();

        assert!(!expanded_terms.is_empty());

        // At least one related AI concept should appear among the expansions.
        let mentions_related_concept = expanded_terms.iter().any(|term| {
            ["artificial", "deep", "neural"]
                .iter()
                .any(|needle| term.contains(needle))
        });
        assert!(mentions_related_concept);
    }

    #[tokio::test]
    async fn test_get_suggestions() {
        let expander = GraphQueryExpander::new(create_test_graph(), ExpansionConfig::default());

        // A partial query should still yield suggestions, capped at the limit.
        let suggestions = expander.get_suggestions("machine", 3).await.unwrap();

        assert!(!suggestions.is_empty());
        assert!(suggestions.len() <= 3);
    }
}