1use crate::core::{ChunkId, Entity, EntityId};
44use indexmap::{IndexMap, IndexSet};
45use serde::{Deserialize, Serialize};
46use std::collections::HashMap;
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct BidirectionalIndex {
55 entity_to_chunks: IndexMap<EntityId, IndexSet<ChunkId>>,
57
58 chunk_to_entities: IndexMap<ChunkId, IndexSet<EntityId>>,
60
61 mapping_count: usize,
63}
64
65impl BidirectionalIndex {
66 pub fn new() -> Self {
68 Self {
69 entity_to_chunks: IndexMap::new(),
70 chunk_to_entities: IndexMap::new(),
71 mapping_count: 0,
72 }
73 }
74
75 pub fn from_entities(entities: &[Entity]) -> Self {
79 let mut index = Self::new();
80
81 for entity in entities {
82 for mention in &entity.mentions {
83 index.add_mapping(&entity.id, &mention.chunk_id);
84 }
85 }
86
87 index
88 }
89
90 pub fn add_mapping(&mut self, entity_id: &EntityId, chunk_id: &ChunkId) {
94 let chunks = self.entity_to_chunks.entry(entity_id.clone()).or_default();
96 let was_new = chunks.insert(chunk_id.clone());
97
98 let entities = self.chunk_to_entities.entry(chunk_id.clone()).or_default();
100 entities.insert(entity_id.clone());
101
102 if was_new {
104 self.mapping_count += 1;
105 }
106 }
107
108 pub fn add_entity_mappings(&mut self, entity_id: &EntityId, chunk_ids: &[ChunkId]) {
110 for chunk_id in chunk_ids {
111 self.add_mapping(entity_id, chunk_id);
112 }
113 }
114
115 pub fn add_chunk_mappings(&mut self, chunk_id: &ChunkId, entity_ids: &[EntityId]) {
117 for entity_id in entity_ids {
118 self.add_mapping(entity_id, chunk_id);
119 }
120 }
121
122 pub fn remove_mapping(&mut self, entity_id: &EntityId, chunk_id: &ChunkId) -> bool {
126 let mut removed = false;
127
128 if let Some(chunks) = self.entity_to_chunks.get_mut(entity_id) {
130 if chunks.swap_remove(chunk_id) {
131 removed = true;
132
133 if chunks.is_empty() {
135 self.entity_to_chunks.swap_remove(entity_id);
136 }
137 }
138 }
139
140 if let Some(entities) = self.chunk_to_entities.get_mut(chunk_id) {
142 entities.swap_remove(entity_id);
143
144 if entities.is_empty() {
146 self.chunk_to_entities.swap_remove(chunk_id);
147 }
148 }
149
150 if removed {
151 self.mapping_count = self.mapping_count.saturating_sub(1);
152 }
153
154 removed
155 }
156
157 pub fn remove_entity(&mut self, entity_id: &EntityId) -> usize {
161 let mut removed_count = 0;
162
163 if let Some(chunks) = self.entity_to_chunks.swap_remove(entity_id) {
164 removed_count = chunks.len();
165
166 for chunk_id in chunks {
168 if let Some(entities) = self.chunk_to_entities.get_mut(&chunk_id) {
169 entities.swap_remove(entity_id);
170
171 if entities.is_empty() {
172 self.chunk_to_entities.swap_remove(&chunk_id);
173 }
174 }
175 }
176 }
177
178 self.mapping_count = self.mapping_count.saturating_sub(removed_count);
179 removed_count
180 }
181
182 pub fn remove_chunk(&mut self, chunk_id: &ChunkId) -> usize {
186 let mut removed_count = 0;
187
188 if let Some(entities) = self.chunk_to_entities.swap_remove(chunk_id) {
189 removed_count = entities.len();
190
191 for entity_id in entities {
193 if let Some(chunks) = self.entity_to_chunks.get_mut(&entity_id) {
194 chunks.swap_remove(chunk_id);
195
196 if chunks.is_empty() {
197 self.entity_to_chunks.swap_remove(&entity_id);
198 }
199 }
200 }
201 }
202
203 self.mapping_count = self.mapping_count.saturating_sub(removed_count);
204 removed_count
205 }
206
207 pub fn get_chunks_for_entity(&self, entity_id: &EntityId) -> Vec<ChunkId> {
211 self.entity_to_chunks
212 .get(entity_id)
213 .map(|chunks| chunks.iter().cloned().collect())
214 .unwrap_or_default()
215 }
216
217 pub fn get_entities_for_chunk(&self, chunk_id: &ChunkId) -> Vec<EntityId> {
221 self.chunk_to_entities
222 .get(chunk_id)
223 .map(|entities| entities.iter().cloned().collect())
224 .unwrap_or_default()
225 }
226
227 pub fn has_mapping(&self, entity_id: &EntityId, chunk_id: &ChunkId) -> bool {
229 self.entity_to_chunks
230 .get(entity_id)
231 .map(|chunks| chunks.contains(chunk_id))
232 .unwrap_or(false)
233 }
234
235 pub fn get_entity_chunk_count(&self, entity_id: &EntityId) -> usize {
237 self.entity_to_chunks
238 .get(entity_id)
239 .map(|chunks| chunks.len())
240 .unwrap_or(0)
241 }
242
243 pub fn get_chunk_entity_count(&self, chunk_id: &ChunkId) -> usize {
245 self.chunk_to_entities
246 .get(chunk_id)
247 .map(|entities| entities.len())
248 .unwrap_or(0)
249 }
250
251 pub fn get_all_entities(&self) -> Vec<EntityId> {
253 self.entity_to_chunks.keys().cloned().collect()
254 }
255
256 pub fn get_all_chunks(&self) -> Vec<ChunkId> {
258 self.chunk_to_entities.keys().cloned().collect()
259 }
260
261 pub fn entity_count(&self) -> usize {
263 self.entity_to_chunks.len()
264 }
265
266 pub fn chunk_count(&self) -> usize {
268 self.chunk_to_entities.len()
269 }
270
271 pub fn mapping_count(&self) -> usize {
273 self.mapping_count
274 }
275
276 pub fn is_empty(&self) -> bool {
278 self.mapping_count == 0
279 }
280
281 pub fn clear(&mut self) {
283 self.entity_to_chunks.clear();
284 self.chunk_to_entities.clear();
285 self.mapping_count = 0;
286 }
287
288 pub fn get_co_occurring_entities(&self, entity_id: &EntityId) -> HashMap<EntityId, usize> {
292 let mut co_occurrences: HashMap<EntityId, usize> = HashMap::new();
293
294 if let Some(chunks) = self.entity_to_chunks.get(entity_id) {
296 for chunk_id in chunks {
298 if let Some(entities) = self.chunk_to_entities.get(chunk_id) {
299 for other_entity_id in entities {
300 if other_entity_id != entity_id {
302 *co_occurrences.entry(other_entity_id.clone()).or_insert(0) += 1;
303 }
304 }
305 }
306 }
307 }
308
309 co_occurrences
310 }
311
312 pub fn get_common_entities(&self, min_chunk_count: usize) -> Vec<(EntityId, usize)> {
316 let mut common_entities: Vec<_> = self
317 .entity_to_chunks
318 .iter()
319 .filter_map(|(entity_id, chunks)| {
320 if chunks.len() >= min_chunk_count {
321 Some((entity_id.clone(), chunks.len()))
322 } else {
323 None
324 }
325 })
326 .collect();
327
328 common_entities.sort_by(|a, b| b.1.cmp(&a.1));
330
331 common_entities
332 }
333
334 pub fn get_dense_chunks(&self, min_entity_count: usize) -> Vec<(ChunkId, usize)> {
338 let mut dense_chunks: Vec<_> = self
339 .chunk_to_entities
340 .iter()
341 .filter_map(|(chunk_id, entities)| {
342 if entities.len() >= min_entity_count {
343 Some((chunk_id.clone(), entities.len()))
344 } else {
345 None
346 }
347 })
348 .collect();
349
350 dense_chunks.sort_by(|a, b| b.1.cmp(&a.1));
352
353 dense_chunks
354 }
355
356 pub fn merge(&mut self, other: &BidirectionalIndex) {
360 for (entity_id, chunks) in &other.entity_to_chunks {
361 for chunk_id in chunks {
362 self.add_mapping(entity_id, chunk_id);
363 }
364 }
365 }
366
367 pub fn get_statistics(&self) -> IndexStatistics {
369 let avg_chunks_per_entity = if self.entity_count() > 0 {
370 self.mapping_count as f64 / self.entity_count() as f64
371 } else {
372 0.0
373 };
374
375 let avg_entities_per_chunk = if self.chunk_count() > 0 {
376 self.mapping_count as f64 / self.chunk_count() as f64
377 } else {
378 0.0
379 };
380
381 IndexStatistics {
382 total_entities: self.entity_count(),
383 total_chunks: self.chunk_count(),
384 total_mappings: self.mapping_count(),
385 avg_chunks_per_entity,
386 avg_entities_per_chunk,
387 }
388 }
389}
390
391impl Default for BidirectionalIndex {
392 fn default() -> Self {
393 Self::new()
394 }
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize)]
399pub struct IndexStatistics {
400 pub total_entities: usize,
402
403 pub total_chunks: usize,
405
406 pub total_mappings: usize,
408
409 pub avg_chunks_per_entity: f64,
411
412 pub avg_entities_per_chunk: f64,
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    // `DocumentId` and `TextChunk` were imported but never used by any test.
    use crate::core::{ChunkId, EntityId, EntityMention};

    /// Adding mappings makes them visible from both directions and updates
    /// all three counters.
    #[test]
    fn test_basic_operations() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        // Entity -> chunks direction.
        let chunks = index.get_chunks_for_entity(&entity1);
        assert_eq!(chunks.len(), 2);
        assert!(chunks.contains(&chunk1));
        assert!(chunks.contains(&chunk2));

        // Chunk -> entities direction.
        let entities = index.get_entities_for_chunk(&chunk1);
        assert_eq!(entities.len(), 2);
        assert!(entities.contains(&entity1));
        assert!(entities.contains(&entity2));

        assert_eq!(index.entity_count(), 2);
        assert_eq!(index.chunk_count(), 2);
        assert_eq!(index.mapping_count(), 3);
    }

    /// Adding the same pair repeatedly must count as a single mapping.
    #[test]
    fn test_idempotent_add() {
        let mut index = BidirectionalIndex::new();

        let entity = EntityId::new("entity_1".to_string());
        let chunk = ChunkId::new("chunk_1".to_string());

        index.add_mapping(&entity, &chunk);
        index.add_mapping(&entity, &chunk);
        index.add_mapping(&entity, &chunk);

        assert_eq!(index.mapping_count(), 1);
        assert_eq!(index.get_chunks_for_entity(&entity).len(), 1);
    }

    /// Removing a single mapping and then a whole entity keeps the counters
    /// consistent.
    #[test]
    fn test_removal() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        // Remove one specific pair.
        assert!(index.remove_mapping(&entity1, &chunk1));
        assert_eq!(index.mapping_count(), 2);

        // entity1 now has only chunk2 left, so one mapping goes with it.
        let removed = index.remove_entity(&entity1);
        assert_eq!(removed, 1);
        assert_eq!(index.mapping_count(), 1);

        // Only entity2 <-> chunk1 remains.
        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 1);
    }

    /// Removing a chunk drops all of its mappings and prunes entities that
    /// are left without any chunk; `clear` empties the index.
    #[test]
    fn test_remove_chunk_and_clear() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        // chunk1 is mapped to both entities.
        assert_eq!(index.remove_chunk(&chunk1), 2);
        assert_eq!(index.mapping_count(), 1);

        // entity2 had no other chunk, so it disappears with chunk1.
        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 1);
        assert!(index.has_mapping(&entity1, &chunk2));
        assert!(!index.has_mapping(&entity1, &chunk1));
        assert!(index.get_entities_for_chunk(&chunk1).is_empty());

        // Removing an unknown chunk is a no-op.
        assert_eq!(index.remove_chunk(&chunk1), 0);
        assert_eq!(index.mapping_count(), 1);

        index.clear();
        assert!(index.is_empty());
        assert_eq!(index.entity_count(), 0);
        assert_eq!(index.chunk_count(), 0);
    }

    /// `from_entities` registers one mapping per mention.
    #[test]
    fn test_from_entities() {
        let entity1 = Entity::new(
            EntityId::new("entity_1".to_string()),
            "Entity 1".to_string(),
            "PERSON".to_string(),
            0.9,
        )
        .with_mentions(vec![
            EntityMention {
                chunk_id: ChunkId::new("chunk_1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            },
            EntityMention {
                chunk_id: ChunkId::new("chunk_2".to_string()),
                start_offset: 10,
                end_offset: 18,
                confidence: 0.9,
            },
        ]);

        let index = BidirectionalIndex::from_entities(&[entity1]);

        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 2);
        assert_eq!(index.mapping_count(), 2);
    }

    /// Co-occurrence counts equal the number of chunks two entities share.
    #[test]
    fn test_co_occurrence() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let entity3 = EntityId::new("entity_3".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        // entity1 and entity2 appear together in both chunks.
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);
        index.add_mapping(&entity2, &chunk2);

        // entity3 only appears in chunk1.
        index.add_mapping(&entity3, &chunk1);

        let co_occurrences = index.get_co_occurring_entities(&entity1);
        assert_eq!(co_occurrences.get(&entity2), Some(&2));
        assert_eq!(co_occurrences.get(&entity3), Some(&1));
    }

    /// Only entities reaching the chunk-count threshold are reported.
    #[test]
    fn test_common_entities() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        let chunk3 = ChunkId::new("chunk_3".to_string());

        // entity1 is mapped to three chunks.
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity1, &chunk3);

        // entity2 is mapped to a single chunk.
        index.add_mapping(&entity2, &chunk1);

        let common = index.get_common_entities(2);
        assert_eq!(common.len(), 1);
        assert_eq!(common[0].0, entity1);
        assert_eq!(common[0].1, 3);
    }

    /// Only chunks reaching the entity-count threshold are reported.
    #[test]
    fn test_dense_chunks() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        // chunk1 holds two entities, chunk2 holds one.
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity2, &chunk1);
        index.add_mapping(&entity1, &chunk2);

        let dense = index.get_dense_chunks(2);
        assert_eq!(dense.len(), 1);
        assert_eq!(dense[0].0, chunk1);
        assert_eq!(dense[0].1, 2);
    }

    /// Merging two disjoint indexes yields the union of their mappings.
    #[test]
    fn test_merge() {
        let mut index1 = BidirectionalIndex::new();
        let mut index2 = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index1.add_mapping(&entity1, &chunk1);
        index2.add_mapping(&entity2, &chunk2);

        index1.merge(&index2);

        assert_eq!(index1.entity_count(), 2);
        assert_eq!(index1.chunk_count(), 2);
        assert_eq!(index1.mapping_count(), 2);
    }

    /// Statistics report totals and mapping-per-key averages.
    #[test]
    fn test_statistics() {
        let mut index = BidirectionalIndex::new();

        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());

        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);

        let stats = index.get_statistics();
        assert_eq!(stats.total_entities, 2);
        assert_eq!(stats.total_chunks, 2);
        assert_eq!(stats.total_mappings, 3);
        // 3 mappings / 2 keys is exactly representable, so exact float
        // comparison is safe here.
        assert_eq!(stats.avg_chunks_per_entity, 1.5);
        assert_eq!(stats.avg_entities_per_chunk, 1.5);
    }
}