use crate::core::{ChunkId, Entity, EntityId};
use indexmap::{IndexMap, IndexSet};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Many-to-many index between entities and the chunks they occur in.
///
/// Both directions are stored explicitly so lookups are cheap from either
/// side; every mutating method updates the two maps together. Iteration order
/// is deterministic (insertion order, perturbed by removals, which use
/// `swap_remove`). Equality compares contents, not iteration order.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct BidirectionalIndex {
    /// Entity → set of chunks in which the entity occurs.
    entity_to_chunks: IndexMap<EntityId, IndexSet<ChunkId>>,
    /// Chunk → set of entities occurring in it (mirror of `entity_to_chunks`).
    chunk_to_entities: IndexMap<ChunkId, IndexSet<EntityId>>,
    /// Number of distinct (entity, chunk) pairs, cached so reading it is O(1).
    mapping_count: usize,
}
impl BidirectionalIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            entity_to_chunks: IndexMap::new(),
            chunk_to_entities: IndexMap::new(),
            mapping_count: 0,
        }
    }

    /// Builds an index with one mapping per distinct `(entity.id, mention.chunk_id)`
    /// pair found in `entities`.
    ///
    /// Several mentions of the same entity in the same chunk collapse into a
    /// single mapping.
    pub fn from_entities(entities: &[Entity]) -> Self {
        let mut index = Self::new();
        for entity in entities {
            for mention in &entity.mentions {
                index.add_mapping(&entity.id, &mention.chunk_id);
            }
        }
        index
    }

    /// Records that `entity_id` occurs in `chunk_id`, updating both directions.
    ///
    /// Idempotent: adding a mapping that already exists changes nothing.
    pub fn add_mapping(&mut self, entity_id: &EntityId, chunk_id: &ChunkId) {
        // Repeated mappings are the common case when indexing mentions (an
        // entity mentioned several times in one chunk); return before cloning
        // any keys.
        if self.has_mapping(entity_id, chunk_id) {
            return;
        }
        self.entity_to_chunks
            .entry(entity_id.clone())
            .or_default()
            .insert(chunk_id.clone());
        self.chunk_to_entities
            .entry(chunk_id.clone())
            .or_default()
            .insert(entity_id.clone());
        self.mapping_count += 1;
    }

    /// Adds a mapping from `entity_id` to every chunk in `chunk_ids`.
    pub fn add_entity_mappings(&mut self, entity_id: &EntityId, chunk_ids: &[ChunkId]) {
        for chunk_id in chunk_ids {
            self.add_mapping(entity_id, chunk_id);
        }
    }

    /// Adds a mapping from every entity in `entity_ids` to `chunk_id`.
    pub fn add_chunk_mappings(&mut self, chunk_id: &ChunkId, entity_ids: &[EntityId]) {
        for entity_id in entity_ids {
            self.add_mapping(entity_id, chunk_id);
        }
    }

    /// Removes the mapping between `entity_id` and `chunk_id`.
    ///
    /// An entity or chunk left without any mapping is dropped from the index.
    /// Returns `true` if the mapping was present.
    pub fn remove_mapping(&mut self, entity_id: &EntityId, chunk_id: &ChunkId) -> bool {
        // Presence is decided by the entity side, which is also the side
        // `mapping_count` is tracked against; the chunk side is cleaned
        // unconditionally.
        let removed = Self::detach(&mut self.entity_to_chunks, entity_id, chunk_id);
        Self::detach(&mut self.chunk_to_entities, chunk_id, entity_id);
        if removed {
            self.mapping_count = self.mapping_count.saturating_sub(1);
        }
        removed
    }

    /// Removes `entity_id` together with all of its mappings.
    ///
    /// Chunks left without any entity are dropped. Returns the number of
    /// mappings removed (0 if the entity was not indexed).
    pub fn remove_entity(&mut self, entity_id: &EntityId) -> usize {
        match self.entity_to_chunks.swap_remove(entity_id) {
            Some(chunks) => {
                let removed = chunks.len();
                for chunk_id in &chunks {
                    Self::detach(&mut self.chunk_to_entities, chunk_id, entity_id);
                }
                self.mapping_count = self.mapping_count.saturating_sub(removed);
                removed
            }
            None => 0,
        }
    }

    /// Removes `chunk_id` together with all of its mappings.
    ///
    /// Entities left without any chunk are dropped. Returns the number of
    /// mappings removed (0 if the chunk was not indexed).
    pub fn remove_chunk(&mut self, chunk_id: &ChunkId) -> usize {
        match self.chunk_to_entities.swap_remove(chunk_id) {
            Some(entities) => {
                let removed = entities.len();
                for entity_id in &entities {
                    Self::detach(&mut self.entity_to_chunks, entity_id, chunk_id);
                }
                self.mapping_count = self.mapping_count.saturating_sub(removed);
                removed
            }
            None => 0,
        }
    }

    /// Returns the chunks `entity_id` occurs in (empty if the entity is unknown).
    pub fn get_chunks_for_entity(&self, entity_id: &EntityId) -> Vec<ChunkId> {
        self.entity_to_chunks
            .get(entity_id)
            .map(|chunks| chunks.iter().cloned().collect())
            .unwrap_or_default()
    }

    /// Returns the entities occurring in `chunk_id` (empty if the chunk is unknown).
    pub fn get_entities_for_chunk(&self, chunk_id: &ChunkId) -> Vec<EntityId> {
        self.chunk_to_entities
            .get(chunk_id)
            .map(|entities| entities.iter().cloned().collect())
            .unwrap_or_default()
    }

    /// Returns `true` if `entity_id` is mapped to `chunk_id`.
    pub fn has_mapping(&self, entity_id: &EntityId, chunk_id: &ChunkId) -> bool {
        matches!(
            self.entity_to_chunks.get(entity_id),
            Some(chunks) if chunks.contains(chunk_id)
        )
    }

    /// Number of chunks `entity_id` occurs in (0 if unknown).
    pub fn get_entity_chunk_count(&self, entity_id: &EntityId) -> usize {
        self.entity_to_chunks
            .get(entity_id)
            .map_or(0, IndexSet::len)
    }

    /// Number of entities occurring in `chunk_id` (0 if unknown).
    pub fn get_chunk_entity_count(&self, chunk_id: &ChunkId) -> usize {
        self.chunk_to_entities
            .get(chunk_id)
            .map_or(0, IndexSet::len)
    }

    /// Returns every entity that has at least one mapping.
    pub fn get_all_entities(&self) -> Vec<EntityId> {
        self.entity_to_chunks.keys().cloned().collect()
    }

    /// Returns every chunk that has at least one mapping.
    pub fn get_all_chunks(&self) -> Vec<ChunkId> {
        self.chunk_to_entities.keys().cloned().collect()
    }

    /// Number of distinct entities with at least one mapping.
    pub fn entity_count(&self) -> usize {
        self.entity_to_chunks.len()
    }

    /// Number of distinct chunks with at least one mapping.
    pub fn chunk_count(&self) -> usize {
        self.chunk_to_entities.len()
    }

    /// Number of distinct (entity, chunk) mappings.
    pub fn mapping_count(&self) -> usize {
        self.mapping_count
    }

    /// Returns `true` if the index holds no mappings.
    pub fn is_empty(&self) -> bool {
        self.mapping_count == 0
    }

    /// Removes every mapping.
    pub fn clear(&mut self) {
        self.entity_to_chunks.clear();
        self.chunk_to_entities.clear();
        self.mapping_count = 0;
    }

    /// Counts, for every other entity, how many chunks it shares with `entity_id`.
    ///
    /// `entity_id` itself is never included; an unknown entity yields an empty map.
    pub fn get_co_occurring_entities(&self, entity_id: &EntityId) -> HashMap<EntityId, usize> {
        let mut co_occurrences: HashMap<EntityId, usize> = HashMap::new();
        let neighbours = self
            .entity_to_chunks
            .get(entity_id)
            .into_iter()
            .flatten()
            .filter_map(|chunk_id| self.chunk_to_entities.get(chunk_id))
            .flatten()
            .filter(|other| *other != entity_id);
        for other in neighbours {
            *co_occurrences.entry(other.clone()).or_default() += 1;
        }
        co_occurrences
    }

    /// Returns `(entity, chunk_count)` for every entity occurring in at least
    /// `min_chunk_count` chunks, sorted by count descending.
    ///
    /// Ties keep the index's iteration order.
    pub fn get_common_entities(&self, min_chunk_count: usize) -> Vec<(EntityId, usize)> {
        Self::ranked_by_size(&self.entity_to_chunks, min_chunk_count)
    }

    /// Returns `(chunk, entity_count)` for every chunk containing at least
    /// `min_entity_count` entities, sorted by count descending.
    ///
    /// Ties keep the index's iteration order.
    pub fn get_dense_chunks(&self, min_entity_count: usize) -> Vec<(ChunkId, usize)> {
        Self::ranked_by_size(&self.chunk_to_entities, min_entity_count)
    }

    /// Adds every mapping of `other` to `self`; mappings already present are kept once.
    pub fn merge(&mut self, other: &BidirectionalIndex) {
        for (entity_id, chunks) in &other.entity_to_chunks {
            for chunk_id in chunks {
                self.add_mapping(entity_id, chunk_id);
            }
        }
    }

    /// Summarises the index's size and average fan-out in each direction.
    ///
    /// Averages are `0.0` when the corresponding side is empty.
    pub fn get_statistics(&self) -> IndexStatistics {
        let total_entities = self.entity_count();
        let total_chunks = self.chunk_count();
        // Mean mappings per key; defined as 0.0 for an empty side to avoid NaN.
        let average = |keys: usize| {
            if keys == 0 {
                0.0
            } else {
                self.mapping_count as f64 / keys as f64
            }
        };
        IndexStatistics {
            total_entities,
            total_chunks,
            total_mappings: self.mapping_count,
            avg_chunks_per_entity: average(total_entities),
            avg_entities_per_chunk: average(total_chunks),
        }
    }

    /// Removes `value` from the set stored under `key`, dropping `key` once its
    /// set is empty. Returns whether `value` was present.
    fn detach<K, V>(map: &mut IndexMap<K, IndexSet<V>>, key: &K, value: &V) -> bool
    where
        K: std::hash::Hash + Eq,
        V: std::hash::Hash + Eq,
    {
        if let Some(set) = map.get_mut(key) {
            let removed = set.swap_remove(value);
            if set.is_empty() {
                map.swap_remove(key);
            }
            removed
        } else {
            false
        }
    }

    /// Collects `(key, set size)` for keys whose set has at least `min_size`
    /// members, sorted by size descending.
    fn ranked_by_size<K, V>(map: &IndexMap<K, IndexSet<V>>, min_size: usize) -> Vec<(K, usize)>
    where
        K: Clone,
    {
        let mut ranked: Vec<(K, usize)> = map
            .iter()
            .filter(|(_, set)| set.len() >= min_size)
            .map(|(key, set)| (key.clone(), set.len()))
            .collect();
        // Stable sort on purpose: equal sizes keep the map's iteration order.
        ranked.sort_by_key(|&(_, size)| std::cmp::Reverse(size));
        ranked
    }
}
impl Default for BidirectionalIndex {
fn default() -> Self {
Self::new()
}
}
/// Summary of a [`BidirectionalIndex`], as returned by its `get_statistics` method.
///
/// `Default` yields the all-zero summary of an empty index.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct IndexStatistics {
    /// Number of distinct entities with at least one mapping.
    pub total_entities: usize,
    /// Number of distinct chunks with at least one mapping.
    pub total_chunks: usize,
    /// Number of distinct (entity, chunk) mappings.
    pub total_mappings: usize,
    /// `total_mappings / total_entities`, or `0.0` when there are no entities.
    pub avg_chunks_per_entity: f64,
    /// `total_mappings / total_chunks`, or `0.0` when there are no chunks.
    pub avg_entities_per_chunk: f64,
}
#[cfg(test)]
mod tests {
    //! Unit tests for `BidirectionalIndex`: insertion, removal from either side,
    //! queries, merging, and statistics.
    use super::*;
    use crate::core::{ChunkId, EntityId, EntityMention};

    #[test]
    fn test_basic_operations() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);
        let chunks = index.get_chunks_for_entity(&entity1);
        assert_eq!(chunks.len(), 2);
        assert!(chunks.contains(&chunk1));
        assert!(chunks.contains(&chunk2));
        let entities = index.get_entities_for_chunk(&chunk1);
        assert_eq!(entities.len(), 2);
        assert!(entities.contains(&entity1));
        assert!(entities.contains(&entity2));
        assert_eq!(index.entity_count(), 2);
        assert_eq!(index.chunk_count(), 2);
        assert_eq!(index.mapping_count(), 3);
    }

    #[test]
    fn test_idempotent_add() {
        let mut index = BidirectionalIndex::new();
        let entity = EntityId::new("entity_1".to_string());
        let chunk = ChunkId::new("chunk_1".to_string());
        index.add_mapping(&entity, &chunk);
        index.add_mapping(&entity, &chunk);
        index.add_mapping(&entity, &chunk);
        assert_eq!(index.mapping_count(), 1);
        assert_eq!(index.get_chunks_for_entity(&entity).len(), 1);
        assert_eq!(index.get_entities_for_chunk(&chunk).len(), 1);
    }

    #[test]
    fn test_removal() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);
        assert!(index.remove_mapping(&entity1, &chunk1));
        assert_eq!(index.mapping_count(), 2);
        let removed = index.remove_entity(&entity1);
        assert_eq!(removed, 1);
        assert_eq!(index.mapping_count(), 1);
        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 1);
    }

    #[test]
    fn test_remove_missing_mapping() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        index.add_mapping(&entity1, &chunk1);
        assert!(!index.remove_mapping(&entity1, &chunk2));
        assert!(!index.remove_mapping(&entity2, &chunk1));
        assert_eq!(index.mapping_count(), 1);
        assert!(index.has_mapping(&entity1, &chunk1));
        assert!(!index.has_mapping(&entity1, &chunk2));
        assert_eq!(index.remove_entity(&entity2), 0);
        assert_eq!(index.mapping_count(), 1);
    }

    #[test]
    fn test_remove_chunk() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);
        assert_eq!(index.remove_chunk(&chunk1), 2);
        assert_eq!(index.mapping_count(), 1);
        assert_eq!(index.chunk_count(), 1);
        // entity2 only occurred in chunk1, so it is dropped entirely.
        assert_eq!(index.entity_count(), 1);
        assert!(index.get_chunks_for_entity(&entity2).is_empty());
        assert!(index.get_entities_for_chunk(&chunk1).is_empty());
        assert_eq!(index.get_chunks_for_entity(&entity1), vec![chunk2.clone()]);
        assert_eq!(index.remove_chunk(&chunk1), 0);
        assert_eq!(index.mapping_count(), 1);
    }

    #[test]
    fn test_from_entities() {
        let entity1 = Entity::new(
            EntityId::new("entity_1".to_string()),
            "Entity 1".to_string(),
            "PERSON".to_string(),
            0.9,
        )
        .with_mentions(vec![
            EntityMention {
                chunk_id: ChunkId::new("chunk_1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            },
            EntityMention {
                chunk_id: ChunkId::new("chunk_2".to_string()),
                start_offset: 10,
                end_offset: 18,
                confidence: 0.9,
            },
        ]);
        let index = BidirectionalIndex::from_entities(&[entity1]);
        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 2);
        assert_eq!(index.mapping_count(), 2);
    }

    #[test]
    fn test_from_entities_collapses_repeated_mentions() {
        let entity1 = Entity::new(
            EntityId::new("entity_1".to_string()),
            "Entity 1".to_string(),
            "PERSON".to_string(),
            0.9,
        )
        .with_mentions(vec![
            EntityMention {
                chunk_id: ChunkId::new("chunk_1".to_string()),
                start_offset: 0,
                end_offset: 8,
                confidence: 0.9,
            },
            EntityMention {
                chunk_id: ChunkId::new("chunk_1".to_string()),
                start_offset: 20,
                end_offset: 28,
                confidence: 0.8,
            },
        ]);
        let index = BidirectionalIndex::from_entities(&[entity1]);
        assert_eq!(index.entity_count(), 1);
        assert_eq!(index.chunk_count(), 1);
        assert_eq!(index.mapping_count(), 1);
    }

    #[test]
    fn test_co_occurrence() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let entity3 = EntityId::new("entity_3".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);
        index.add_mapping(&entity2, &chunk2);
        index.add_mapping(&entity3, &chunk1);
        let co_occurrences = index.get_co_occurring_entities(&entity1);
        assert_eq!(co_occurrences.get(&entity2), Some(&2));
        assert_eq!(co_occurrences.get(&entity3), Some(&1));
        // The queried entity never counts as co-occurring with itself.
        assert!(!co_occurrences.contains_key(&entity1));
        assert_eq!(co_occurrences.len(), 2);
        let unknown = EntityId::new("entity_unknown".to_string());
        assert!(index.get_co_occurring_entities(&unknown).is_empty());
    }

    #[test]
    fn test_common_entities() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        let chunk3 = ChunkId::new("chunk_3".to_string());
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity1, &chunk3);
        index.add_mapping(&entity2, &chunk1);
        let common = index.get_common_entities(2);
        assert_eq!(common.len(), 1);
        assert_eq!(common[0].0, entity1);
        assert_eq!(common[0].1, 3);
    }

    #[test]
    fn test_dense_chunks() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let entity3 = EntityId::new("entity_3".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        let chunk3 = ChunkId::new("chunk_3".to_string());
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk2);
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity2, &chunk1);
        index.add_mapping(&entity3, &chunk1);
        index.add_mapping(&entity1, &chunk3);
        let dense = index.get_dense_chunks(2);
        assert_eq!(dense, vec![(chunk1.clone(), 3), (chunk2.clone(), 2)]);
        assert!(index.get_dense_chunks(4).is_empty());
    }

    #[test]
    fn test_merge() {
        let mut index1 = BidirectionalIndex::new();
        let mut index2 = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        index1.add_mapping(&entity1, &chunk1);
        index2.add_mapping(&entity2, &chunk2);
        index1.merge(&index2);
        assert_eq!(index1.entity_count(), 2);
        assert_eq!(index1.chunk_count(), 2);
        assert_eq!(index1.mapping_count(), 2);
    }

    #[test]
    fn test_merge_overlapping() {
        let mut index1 = BidirectionalIndex::new();
        let mut index2 = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        index1.add_mapping(&entity1, &chunk1);
        index2.add_mapping(&entity1, &chunk1);
        index2.add_mapping(&entity2, &chunk1);
        index1.merge(&index2);
        // The shared (entity1, chunk1) mapping is counted once.
        assert_eq!(index1.mapping_count(), 2);
        assert_eq!(index1.entity_count(), 2);
        assert_eq!(index1.chunk_count(), 1);
        assert_eq!(index1.get_chunk_entity_count(&chunk1), 2);
        assert_eq!(index1.get_entity_chunk_count(&entity1), 1);
    }

    #[test]
    fn test_clear_and_is_empty() {
        let mut index = BidirectionalIndex::new();
        assert!(index.is_empty());
        let entity1 = EntityId::new("entity_1".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        index.add_mapping(&entity1, &chunk1);
        assert!(!index.is_empty());
        index.clear();
        assert!(index.is_empty());
        assert_eq!(index.entity_count(), 0);
        assert_eq!(index.chunk_count(), 0);
        assert!(!index.has_mapping(&entity1, &chunk1));
    }

    #[test]
    fn test_statistics() {
        let mut index = BidirectionalIndex::new();
        let entity1 = EntityId::new("entity_1".to_string());
        let entity2 = EntityId::new("entity_2".to_string());
        let chunk1 = ChunkId::new("chunk_1".to_string());
        let chunk2 = ChunkId::new("chunk_2".to_string());
        index.add_mapping(&entity1, &chunk1);
        index.add_mapping(&entity1, &chunk2);
        index.add_mapping(&entity2, &chunk1);
        let stats = index.get_statistics();
        assert_eq!(stats.total_entities, 2);
        assert_eq!(stats.total_chunks, 2);
        assert_eq!(stats.total_mappings, 3);
        assert_eq!(stats.avg_chunks_per_entity, 1.5);
        assert_eq!(stats.avg_entities_per_chunk, 1.5);
    }

    #[test]
    fn test_statistics_empty() {
        let stats = BidirectionalIndex::new().get_statistics();
        assert_eq!(stats.total_entities, 0);
        assert_eq!(stats.total_chunks, 0);
        assert_eq!(stats.total_mappings, 0);
        assert_eq!(stats.avg_chunks_per_entity, 0.0);
        assert_eq!(stats.avg_entities_per_chunk, 0.0);
    }
}