Skip to main content

graphrag_core/entity/
bidirectional_index.rs

//! Bidirectional Entity-Chunk Index
//!
//! This module provides efficient bidirectional lookups between entities and chunks,
//! essential for LazyGraphRAG and E2GraphRAG query refinement and concept expansion.
//!
//! ## Key Features
//!
//! - **Fast lookups**: average O(1) hash lookup in both directions
//! - **Many-to-many relationships**: One entity can appear in multiple chunks, and one chunk can contain multiple entities
//! - **Incremental updates**: Add/remove mappings without rebuilding the entire index
//! - **Deterministic iteration**: Uses `IndexMap`/`IndexSet`, so results come back in insertion order
//!   (removals use `swap_remove`, which moves the last element into the vacated slot)
//!
//! ## Use Cases
//!
//! 1. **Query Expansion**: Given entities in a query, find all relevant chunks
//! 2. **Context Retrieval**: Given a chunk, find all related entities
//! 3. **Concept Graph Building**: Track concept co-occurrence across chunks
//! 4. **Iterative Deepening**: Expand search by traversing entity-chunk relationships
//!
//! ## Example
//!
//! ```rust
//! use graphrag_core::entity::bidirectional_index::BidirectionalIndex;
//! use graphrag_core::core::{EntityId, ChunkId};
//!
//! let mut index = BidirectionalIndex::new();
//!
//! let entity_id = EntityId::new("entity_1".to_string());
//! let chunk_id = ChunkId::new("chunk_1".to_string());
//!
//! // Add mapping
//! index.add_mapping(&entity_id, &chunk_id);
//!
//! // Query by entity
//! let chunks = index.get_chunks_for_entity(&entity_id);
//! assert_eq!(chunks.len(), 1);
//!
//! // Query by chunk
//! let entities = index.get_entities_for_chunk(&chunk_id);
//! assert_eq!(entities.len(), 1);
//! ```
42
43use crate::core::{ChunkId, Entity, EntityId};
44use indexmap::{IndexMap, IndexSet};
45use serde::{Deserialize, Serialize};
46use std::collections::HashMap;
47
48/// Bidirectional index for fast entity-chunk lookups
49///
50/// This structure maintains two indexes:
51/// 1. Entity → Chunks: Given an entity, find all chunks it appears in
52/// 2. Chunk → Entities: Given a chunk, find all entities it contains
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct BidirectionalIndex {
55    /// Maps entity IDs to the chunks they appear in
56    entity_to_chunks: IndexMap<EntityId, IndexSet<ChunkId>>,
57
58    /// Maps chunk IDs to the entities they contain
59    chunk_to_entities: IndexMap<ChunkId, IndexSet<EntityId>>,
60
61    /// Total number of entity-chunk mappings
62    mapping_count: usize,
63}
64
65impl BidirectionalIndex {
66    /// Create a new empty bidirectional index
67    pub fn new() -> Self {
68        Self {
69            entity_to_chunks: IndexMap::new(),
70            chunk_to_entities: IndexMap::new(),
71            mapping_count: 0,
72        }
73    }
74
75    /// Create a bidirectional index from a collection of entities
76    ///
77    /// This is useful for building the index from extracted entities
78    pub fn from_entities(entities: &[Entity]) -> Self {
79        let mut index = Self::new();
80
81        for entity in entities {
82            for mention in &entity.mentions {
83                index.add_mapping(&entity.id, &mention.chunk_id);
84            }
85        }
86
87        index
88    }
89
90    /// Add a mapping between an entity and a chunk
91    ///
92    /// This is idempotent - adding the same mapping multiple times has no effect
93    pub fn add_mapping(&mut self, entity_id: &EntityId, chunk_id: &ChunkId) {
94        // Add to entity → chunks index
95        let chunks = self.entity_to_chunks.entry(entity_id.clone()).or_default();
96        let was_new = chunks.insert(chunk_id.clone());
97
98        // Add to chunk → entities index
99        let entities = self.chunk_to_entities.entry(chunk_id.clone()).or_default();
100        entities.insert(entity_id.clone());
101
102        // Update mapping count only if it was a new mapping
103        if was_new {
104            self.mapping_count += 1;
105        }
106    }
107
108    /// Add multiple mappings for an entity
109    pub fn add_entity_mappings(&mut self, entity_id: &EntityId, chunk_ids: &[ChunkId]) {
110        for chunk_id in chunk_ids {
111            self.add_mapping(entity_id, chunk_id);
112        }
113    }
114
115    /// Add multiple mappings for a chunk
116    pub fn add_chunk_mappings(&mut self, chunk_id: &ChunkId, entity_ids: &[EntityId]) {
117        for entity_id in entity_ids {
118            self.add_mapping(entity_id, chunk_id);
119        }
120    }
121
122    /// Remove a specific mapping between an entity and a chunk
123    ///
124    /// Returns true if the mapping existed and was removed
125    pub fn remove_mapping(&mut self, entity_id: &EntityId, chunk_id: &ChunkId) -> bool {
126        let mut removed = false;
127
128        // Remove from entity → chunks index
129        if let Some(chunks) = self.entity_to_chunks.get_mut(entity_id) {
130            if chunks.swap_remove(chunk_id) {
131                removed = true;
132
133                // Clean up empty entries
134                if chunks.is_empty() {
135                    self.entity_to_chunks.swap_remove(entity_id);
136                }
137            }
138        }
139
140        // Remove from chunk → entities index
141        if let Some(entities) = self.chunk_to_entities.get_mut(chunk_id) {
142            entities.swap_remove(entity_id);
143
144            // Clean up empty entries
145            if entities.is_empty() {
146                self.chunk_to_entities.swap_remove(chunk_id);
147            }
148        }
149
150        if removed {
151            self.mapping_count = self.mapping_count.saturating_sub(1);
152        }
153
154        removed
155    }
156
157    /// Remove all mappings for an entity
158    ///
159    /// Returns the number of mappings removed
160    pub fn remove_entity(&mut self, entity_id: &EntityId) -> usize {
161        let mut removed_count = 0;
162
163        if let Some(chunks) = self.entity_to_chunks.swap_remove(entity_id) {
164            removed_count = chunks.len();
165
166            // Remove from chunk → entities index
167            for chunk_id in chunks {
168                if let Some(entities) = self.chunk_to_entities.get_mut(&chunk_id) {
169                    entities.swap_remove(entity_id);
170
171                    if entities.is_empty() {
172                        self.chunk_to_entities.swap_remove(&chunk_id);
173                    }
174                }
175            }
176        }
177
178        self.mapping_count = self.mapping_count.saturating_sub(removed_count);
179        removed_count
180    }
181
182    /// Remove all mappings for a chunk
183    ///
184    /// Returns the number of mappings removed
185    pub fn remove_chunk(&mut self, chunk_id: &ChunkId) -> usize {
186        let mut removed_count = 0;
187
188        if let Some(entities) = self.chunk_to_entities.swap_remove(chunk_id) {
189            removed_count = entities.len();
190
191            // Remove from entity → chunks index
192            for entity_id in entities {
193                if let Some(chunks) = self.entity_to_chunks.get_mut(&entity_id) {
194                    chunks.swap_remove(chunk_id);
195
196                    if chunks.is_empty() {
197                        self.entity_to_chunks.swap_remove(&entity_id);
198                    }
199                }
200            }
201        }
202
203        self.mapping_count = self.mapping_count.saturating_sub(removed_count);
204        removed_count
205    }
206
207    /// Get all chunks that contain a specific entity
208    ///
209    /// Returns an empty vector if the entity is not found
210    pub fn get_chunks_for_entity(&self, entity_id: &EntityId) -> Vec<ChunkId> {
211        self.entity_to_chunks
212            .get(entity_id)
213            .map(|chunks| chunks.iter().cloned().collect())
214            .unwrap_or_default()
215    }
216
217    /// Get all entities in a specific chunk
218    ///
219    /// Returns an empty vector if the chunk is not found
220    pub fn get_entities_for_chunk(&self, chunk_id: &ChunkId) -> Vec<EntityId> {
221        self.chunk_to_entities
222            .get(chunk_id)
223            .map(|entities| entities.iter().cloned().collect())
224            .unwrap_or_default()
225    }
226
227    /// Check if a specific entity-chunk mapping exists
228    pub fn has_mapping(&self, entity_id: &EntityId, chunk_id: &ChunkId) -> bool {
229        self.entity_to_chunks
230            .get(entity_id)
231            .map(|chunks| chunks.contains(chunk_id))
232            .unwrap_or(false)
233    }
234
235    /// Get the number of chunks an entity appears in
236    pub fn get_entity_chunk_count(&self, entity_id: &EntityId) -> usize {
237        self.entity_to_chunks
238            .get(entity_id)
239            .map(|chunks| chunks.len())
240            .unwrap_or(0)
241    }
242
243    /// Get the number of entities in a chunk
244    pub fn get_chunk_entity_count(&self, chunk_id: &ChunkId) -> usize {
245        self.chunk_to_entities
246            .get(chunk_id)
247            .map(|entities| entities.len())
248            .unwrap_or(0)
249    }
250
251    /// Get all entity IDs in the index
252    pub fn get_all_entities(&self) -> Vec<EntityId> {
253        self.entity_to_chunks.keys().cloned().collect()
254    }
255
256    /// Get all chunk IDs in the index
257    pub fn get_all_chunks(&self) -> Vec<ChunkId> {
258        self.chunk_to_entities.keys().cloned().collect()
259    }
260
261    /// Get the total number of unique entities
262    pub fn entity_count(&self) -> usize {
263        self.entity_to_chunks.len()
264    }
265
266    /// Get the total number of unique chunks
267    pub fn chunk_count(&self) -> usize {
268        self.chunk_to_entities.len()
269    }
270
271    /// Get the total number of entity-chunk mappings
272    pub fn mapping_count(&self) -> usize {
273        self.mapping_count
274    }
275
276    /// Check if the index is empty
277    pub fn is_empty(&self) -> bool {
278        self.mapping_count == 0
279    }
280
281    /// Clear all mappings from the index
282    pub fn clear(&mut self) {
283        self.entity_to_chunks.clear();
284        self.chunk_to_entities.clear();
285        self.mapping_count = 0;
286    }
287
288    /// Get co-occurring entities for a given entity
289    ///
290    /// Returns entities that appear in the same chunks, along with their co-occurrence count
291    pub fn get_co_occurring_entities(&self, entity_id: &EntityId) -> HashMap<EntityId, usize> {
292        let mut co_occurrences: HashMap<EntityId, usize> = HashMap::new();
293
294        // Get all chunks this entity appears in
295        if let Some(chunks) = self.entity_to_chunks.get(entity_id) {
296            // For each chunk, get all entities in that chunk
297            for chunk_id in chunks {
298                if let Some(entities) = self.chunk_to_entities.get(chunk_id) {
299                    for other_entity_id in entities {
300                        // Skip the entity itself
301                        if other_entity_id != entity_id {
302                            *co_occurrences.entry(other_entity_id.clone()).or_insert(0) += 1;
303                        }
304                    }
305                }
306            }
307        }
308
309        co_occurrences
310    }
311
312    /// Get entities that appear in multiple chunks (common entities)
313    ///
314    /// Returns entities sorted by the number of chunks they appear in (descending)
315    pub fn get_common_entities(&self, min_chunk_count: usize) -> Vec<(EntityId, usize)> {
316        let mut common_entities: Vec<_> = self
317            .entity_to_chunks
318            .iter()
319            .filter_map(|(entity_id, chunks)| {
320                if chunks.len() >= min_chunk_count {
321                    Some((entity_id.clone(), chunks.len()))
322                } else {
323                    None
324                }
325            })
326            .collect();
327
328        // Sort by chunk count descending
329        common_entities.sort_by(|a, b| b.1.cmp(&a.1));
330
331        common_entities
332    }
333
334    /// Get chunks that contain multiple entities (dense chunks)
335    ///
336    /// Returns chunks sorted by the number of entities they contain (descending)
337    pub fn get_dense_chunks(&self, min_entity_count: usize) -> Vec<(ChunkId, usize)> {
338        let mut dense_chunks: Vec<_> = self
339            .chunk_to_entities
340            .iter()
341            .filter_map(|(chunk_id, entities)| {
342                if entities.len() >= min_entity_count {
343                    Some((chunk_id.clone(), entities.len()))
344                } else {
345                    None
346                }
347            })
348            .collect();
349
350        // Sort by entity count descending
351        dense_chunks.sort_by(|a, b| b.1.cmp(&a.1));
352
353        dense_chunks
354    }
355
356    /// Merge another index into this one
357    ///
358    /// Useful for combining indices from multiple documents
359    pub fn merge(&mut self, other: &BidirectionalIndex) {
360        for (entity_id, chunks) in &other.entity_to_chunks {
361            for chunk_id in chunks {
362                self.add_mapping(entity_id, chunk_id);
363            }
364        }
365    }
366
367    /// Get statistics about the index
368    pub fn get_statistics(&self) -> IndexStatistics {
369        let avg_chunks_per_entity = if self.entity_count() > 0 {
370            self.mapping_count as f64 / self.entity_count() as f64
371        } else {
372            0.0
373        };
374
375        let avg_entities_per_chunk = if self.chunk_count() > 0 {
376            self.mapping_count as f64 / self.chunk_count() as f64
377        } else {
378            0.0
379        };
380
381        IndexStatistics {
382            total_entities: self.entity_count(),
383            total_chunks: self.chunk_count(),
384            total_mappings: self.mapping_count(),
385            avg_chunks_per_entity,
386            avg_entities_per_chunk,
387        }
388    }
389}
390
391impl Default for BidirectionalIndex {
392    fn default() -> Self {
393        Self::new()
394    }
395}
396
397/// Statistics about the bidirectional index
398#[derive(Debug, Clone, Serialize, Deserialize)]
399pub struct IndexStatistics {
400    /// Total number of unique entities
401    pub total_entities: usize,
402
403    /// Total number of unique chunks
404    pub total_chunks: usize,
405
406    /// Total number of entity-chunk mappings
407    pub total_mappings: usize,
408
409    /// Average number of chunks per entity
410    pub avg_chunks_per_entity: f64,
411
412    /// Average number of entities per chunk
413    pub avg_entities_per_chunk: f64,
414}
415
/// Unit tests for `BidirectionalIndex`: insertion, lookup, removal, bulk
/// construction, co-occurrence queries, merging, and statistics.
#[cfg(test)]
mod tests {
    use super::*;
    // Only the types the tests actually construct are imported; the previously
    // listed `DocumentId` and `TextChunk` were never used in this module.
    use crate::core::{ChunkId, EntityId, EntityMention};

    #[test]
    fn test_basic_operations() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        // Add mappings
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        // Test entity → chunks lookup
        let chunks = index.get_chunks_for_entity(&entity1);
        assert_eq!(chunks.len(), 2);
        assert!(chunks.contains(&chunk1));
        assert!(chunks.contains(&chunk2));

        // Test chunk → entities lookup
        let entities = index.get_entities_for_chunk(&chunk1);
        assert_eq!(entities.len(), 2);
        assert!(entities.contains(&entity1));
        assert!(entities.contains(&entity2));

        // Test counts
        assert_eq!(index.entity_count(), 2);
        assert_eq!(index.chunk_count(), 2);
        assert_eq!(index.mapping_count(), 3);
    }

    #[test]
    fn test_idempotent_add() {
        let mut index = BidirectionalIndex::new();

        let entity = EntityId::new("entity_1".to_string());
        let chunk = ChunkId::new("chunk_1".to_string());

        index.add_mapping(&entity, &chunk);
        index.add_mapping(&entity, &chunk);
        index.add_mapping(&entity, &chunk);

        assert_eq!(index.mapping_count(), 1);
        assert_eq!(index.get_chunks_for_entity(&entity).len(), 1);
    }

    #[test]
    fn test_removal() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        // Remove specific mapping
        assert!(index.remove_mapping(&entity1, &chunk1));
        assert_eq!(index.mapping_count(), 2);

        // Remove entity
        let removed = index.remove_entity(&entity1);
        assert_eq!(removed, 1);
        assert_eq!(index.mapping_count(), 1);

        // Only entity2 → chunk1 should remain
        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 1);
    }

    #[test]
    fn test_remove_chunk() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        // chunk1 carried two mappings; entity2 has no other chunk and disappears
        assert_eq!(index.remove_chunk(&chunk1), 2);
        assert_eq!(index.mapping_count(), 1);
        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 1);
        assert!(index.has_mapping(&entity1, &chunk2));
        assert!(!index.has_mapping(&entity1, &chunk1));
        assert!(index.get_chunks_for_entity(&entity2).is_empty());
    }

    #[test]
    fn test_remove_missing_is_noop() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);

        // None of these target an existing mapping
        assert!(!index.remove_mapping(&entity1, &chunk2));
        assert!(!index.remove_mapping(&entity2, &chunk1));
        assert_eq!(index.remove_entity(&entity2), 0);
        assert_eq!(index.remove_chunk(&chunk2), 0);

        // The original mapping is untouched
        assert_eq!(index.mapping_count(), 1);
        assert!(index.has_mapping(&entity1, &chunk1));
    }

    #[test]
    fn test_from_entities() {
        let entity1 = Entity::new(
            EntityId::new("entity_1".to_string()),
            "Entity 1".to_string(),
            "PERSON".to_string(),
            0.9,
        )
        .with_mentions(vec![
            EntityMention {
                chunk_id: ChunkId::new("chunk_1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            },
            EntityMention {
                chunk_id: ChunkId::new("chunk_2".to_string()),
                start_offset: 10,
                end_offset: 18,
                confidence: 0.9,
            },
        ]);

        let index = BidirectionalIndex::from_entities(&[entity1]);

        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 2);
        assert_eq!(index.mapping_count(), 2);
    }

    #[test]
    fn test_co_occurrence() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let entity3 = EntityId::new("entity_3".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        // entity1 and entity2 co-occur in both chunks
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);
        index.add_mapping(&entity2, &chunk2);

        // entity3 co-occurs with entity1 only in chunk1
        index.add_mapping(&entity3, &chunk1);

        let co_occurrences = index.get_co_occurring_entities(&entity1);
        assert_eq!(co_occurrences.get(&entity2), Some(&2)); // co-occurs in 2 chunks
        assert_eq!(co_occurrences.get(&entity3), Some(&1)); // co-occurs in 1 chunk
        assert_eq!(co_occurrences.get(&entity1), None); // never counts itself
    }

    #[test]
    fn test_common_entities() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        let chunk3 = ChunkId::new("chunk_3".to_string());

        // entity1 appears in 3 chunks
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity1, &chunk3);

        // entity2 appears in 1 chunk
        index.add_mapping(&entity2, &chunk1);

        let common = index.get_common_entities(2);
        assert_eq!(common.len(), 1);
        assert_eq!(common[0].0, entity1);
        assert_eq!(common[0].1, 3);
    }

    #[test]
    fn test_dense_chunks_and_clear() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        // chunk1 holds two entities, chunk2 holds one
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity2, &chunk1);
        index.add_mapping(&entity1, &chunk2);

        let dense = index.get_dense_chunks(2);
        assert_eq!(dense.len(), 1);
        assert_eq!(dense[0].0, chunk1);
        assert_eq!(dense[0].1, 2);

        index.clear();
        assert!(index.is_empty());
        assert_eq!(index.entity_count(), 0);
        assert_eq!(index.chunk_count(), 0);
        assert_eq!(index.mapping_count(), 0);
    }

    #[test]
    fn test_merge() {
        let mut index1 = BidirectionalIndex::new();
        let mut index2 = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index1.add_mapping(&entity1, &chunk1);
        index2.add_mapping(&entity2, &chunk2);

        index1.merge(&index2);

        assert_eq!(index1.entity_count(), 2);
        assert_eq!(index1.chunk_count(), 2);
        assert_eq!(index1.mapping_count(), 2);
    }

    #[test]
    fn test_statistics() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        let stats = index.get_statistics();
        assert_eq!(stats.total_entities, 2);
        assert_eq!(stats.total_chunks, 2);
        assert_eq!(stats.total_mappings, 3);
        // 3 mappings over 2 keys is exactly representable, so exact equality is safe
        assert_eq!(stats.avg_chunks_per_entity, 1.5);
        assert_eq!(stats.avg_entities_per_chunk, 1.5);
    }
}