use dashmap::DashMap;
use exo_core::{EntityId, HyperedgeId, Relation, RelationType, SubstrateTime};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// One hyperedge: a single relation that links an arbitrary number of entities at once.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct Hyperedge {
    /// Identifier of this hyperedge; [`Hyperedge::new`] assigns a fresh `HyperedgeId::new()`.
    pub id: HyperedgeId,
    /// Participating entities, kept in the order they were supplied (repeats are preserved).
    pub entities: Vec<EntityId>,
    /// The relation this hyperedge expresses.
    pub relation: Relation,
    /// Edge weight; [`Hyperedge::new`] starts every edge at `1.0`.
    pub weight: f32,
    /// Creation timestamp; [`Hyperedge::new`] fills it from `SubstrateTime::now()`.
    pub created_at: SubstrateTime,
}
impl Hyperedge {
    /// Builds a hyperedge over `entities` labelled with `relation`.
    ///
    /// A fresh id is generated first, then the current substrate time is captured;
    /// the weight always starts at `1.0`.
    pub fn new(entities: Vec<EntityId>, relation: Relation) -> Self {
        let id = HyperedgeId::new();
        let weight = 1.0;
        let created_at = SubstrateTime::now();
        Self {
            id,
            entities,
            relation,
            weight,
            created_at,
        }
    }

    /// Number of entity slots in this hyperedge (repeated entities are counted each time).
    pub fn arity(&self) -> usize {
        let members = &self.entities;
        members.len()
    }

    /// Returns `true` if `entity` appears anywhere in this hyperedge.
    pub fn contains_entity(&self, entity: &EntityId) -> bool {
        self.entities.iter().any(|candidate| candidate == entity)
    }
}
/// In-memory store of hyperedges with two secondary indexes: by entity and by relation type.
///
/// Each map sits behind an `Arc<DashMap<..>>`, so the methods on this type mutate through
/// `&self`.
/// NOTE(review): the `Arc` wrappers suggest the maps are meant to be shared between handles,
/// but this type does not derive `Clone` here — confirm the intent.
pub struct HyperedgeIndex {
/// Primary storage: hyperedge id -> hyperedge.
edges: Arc<DashMap<HyperedgeId, Hyperedge>>,
/// Entity -> ids of the hyperedges that contain that entity.
entity_index: Arc<DashMap<EntityId, Vec<HyperedgeId>>>,
/// Relation type -> ids of the hyperedges carrying that relation type.
relation_index: Arc<DashMap<RelationType, Vec<HyperedgeId>>>,
}
impl HyperedgeIndex {
    /// Creates an empty index.
    pub fn new() -> Self {
        Self {
            edges: Arc::new(DashMap::new()),
            entity_index: Arc::new(DashMap::new()),
            relation_index: Arc::new(DashMap::new()),
        }
    }

    /// Creates a hyperedge over `entities` with a clone of `relation`, stores it, and
    /// registers it in the entity and relation indexes.
    ///
    /// The stored hyperedge keeps `entities` exactly as given, so [`Hyperedge::arity`]
    /// counts repeats. Each *distinct* entity is indexed only once, however, so
    /// [`get_by_entity`](Self::get_by_entity) never reports the same edge twice.
    ///
    /// The edge is placed in primary storage before the indexes are updated. A concurrent
    /// reader may briefly see it through [`get`](Self::get) or [`all`](Self::all) before
    /// it appears in the indexes.
    ///
    /// Returns the freshly generated id of the new hyperedge.
    pub fn insert(&self, entities: &[EntityId], relation: &Relation) -> HyperedgeId {
        let hyperedge = Hyperedge::new(entities.to_vec(), relation.clone());
        let hyperedge_id = hyperedge.id;
        self.edges.insert(hyperedge_id, hyperedge);
        for (i, entity) in entities.iter().enumerate() {
            // An entity listed more than once must still get a single posting; otherwise
            // lookups by entity (and `find_connecting`) return duplicate ids. The prefix
            // scan is O(arity²), which is acceptable for small arities.
            if entities[..i].contains(entity) {
                continue;
            }
            self.entity_index
                .entry(*entity)
                .or_insert_with(Vec::new)
                .push(hyperedge_id);
        }
        self.relation_index
            .entry(relation.relation_type.clone())
            .or_insert_with(Vec::new)
            .push(hyperedge_id);
        hyperedge_id
    }

    /// Returns a clone of the hyperedge with `id`, or `None` if it is not stored.
    pub fn get(&self, id: &HyperedgeId) -> Option<Hyperedge> {
        self.edges.get(id).map(|entry| entry.value().clone())
    }

    /// Returns a snapshot of the ids of hyperedges containing `entity`, in the order they
    /// were indexed. Returns an empty vector when the entity is unknown.
    pub fn get_by_entity(&self, entity: &EntityId) -> Vec<HyperedgeId> {
        self.entity_index
            .get(entity)
            .map(|entry| entry.value().clone())
            .unwrap_or_default()
    }

    /// Returns a snapshot of the ids of hyperedges whose relation type is `relation_type`,
    /// in the order they were indexed. Returns an empty vector when none match.
    pub fn get_by_relation(&self, relation_type: &RelationType) -> Vec<HyperedgeId> {
        self.relation_index
            .get(relation_type)
            .map(|entry| entry.value().clone())
            .unwrap_or_default()
    }

    /// Number of stored hyperedges.
    pub fn len(&self) -> usize {
        self.edges.len()
    }

    /// Returns `true` if no hyperedges are stored.
    pub fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }

    /// Largest [`Hyperedge::arity`] among the stored hyperedges, or `0` if the index is empty.
    /// Scans every stored edge.
    pub fn max_size(&self) -> usize {
        self.edges
            .iter()
            .map(|entry| entry.value().arity())
            .max()
            .unwrap_or(0)
    }

    /// Removes the hyperedge with `id` and unlinks it from both indexes.
    ///
    /// Index entries left empty by the removal are deleted, so the indexes do not
    /// accumulate empty lists for entities or relation types that no longer have edges.
    ///
    /// Returns the removed hyperedge, or `None` if `id` was not stored.
    pub fn remove(&self, id: &HyperedgeId) -> Option<Hyperedge> {
        let (_, hyperedge) = self.edges.remove(id)?;

        for entity in &hyperedge.entities {
            if let Some(mut ids) = self.entity_index.get_mut(entity) {
                ids.retain(|he_id| he_id != id);
            }
            // The write guard above is released at the end of the `if let`, so this
            // does not self-deadlock. `remove_if` re-checks emptiness under the shard
            // lock, so an id pushed concurrently by `insert` is never discarded.
            self.entity_index.remove_if(entity, |_, ids| ids.is_empty());
        }

        let relation_type = &hyperedge.relation.relation_type;
        if let Some(mut ids) = self.relation_index.get_mut(relation_type) {
            ids.retain(|he_id| he_id != id);
        }
        self.relation_index
            .remove_if(relation_type, |_, ids| ids.is_empty());

        Some(hyperedge)
    }

    /// Returns clones of every stored hyperedge, in no particular order.
    pub fn all(&self) -> Vec<Hyperedge> {
        self.edges
            .iter()
            .map(|entry| entry.value().clone())
            .collect()
    }

    /// Returns the ids of the hyperedges that contain *every* entity in `entities`.
    ///
    /// Candidates come from the first entity's postings and are filtered against the
    /// rest. An empty `entities` slice yields an empty result.
    pub fn find_connecting(&self, entities: &[EntityId]) -> Vec<HyperedgeId> {
        if entities.is_empty() {
            return Vec::new();
        }
        let mut candidates = self.get_by_entity(&entities[0]);
        // Inspect each stored edge through the map guard instead of cloning the whole
        // hyperedge (relation and properties included) just to test membership.
        // Ids whose edge has vanished (concurrent removal) are dropped.
        candidates.retain(|he_id| match self.edges.get(he_id) {
            Some(he) => entities.iter().all(|e| he.contains_entity(e)),
            None => false,
        });
        candidates
    }
}
impl Default for HyperedgeIndex {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // `RelationType`, `EntityId` and `Relation` reach this module through the parent's
    // imports via the glob below, so no separate `exo_core` import is needed.
    use super::*;

    /// Builds a relation of type `name` with empty JSON properties.
    fn test_relation(name: &'static str) -> Relation {
        Relation {
            relation_type: RelationType::new(name),
            properties: serde_json::json!({}),
        }
    }

    #[test]
    fn test_hyperedge_creation() {
        let entities = vec![EntityId::new(), EntityId::new(), EntityId::new()];
        let he = Hyperedge::new(entities.clone(), test_relation("test"));
        assert_eq!(he.arity(), 3);
        assert!(he.contains_entity(&entities[0]));
        assert!(!he.contains_entity(&EntityId::new()));
        assert_eq!(he.weight, 1.0);
    }

    #[test]
    fn test_hyperedge_index() {
        let index = HyperedgeIndex::new();
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        let he_id = index.insert(&[e1, e2, e3], &test_relation("test"));
        assert!(index.get(&he_id).is_some());
        assert_eq!(index.get_by_entity(&e1).len(), 1);
        assert_eq!(index.get_by_entity(&e2).len(), 1);
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }

    #[test]
    fn test_find_connecting() {
        let index = HyperedgeIndex::new();
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        let e4 = EntityId::new();
        let relation = test_relation("test");
        index.insert(&[e1, e2], &relation);
        let he2 = index.insert(&[e1, e2, e3], &relation);
        index.insert(&[e1, e4], &relation);
        let connecting = index.find_connecting(&[e1, e2, e3]);
        assert_eq!(connecting.len(), 1);
        assert_eq!(connecting[0], he2);
    }

    #[test]
    fn test_find_connecting_empty_input() {
        let index = HyperedgeIndex::new();
        index.insert(&[EntityId::new()], &test_relation("test"));
        assert!(index.find_connecting(&[]).is_empty());
    }

    #[test]
    fn test_get_by_relation() {
        let index = HyperedgeIndex::new();
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        let knows = test_relation("knows");
        let likes = test_relation("likes");
        let a = index.insert(&[e1, e2], &knows);
        let b = index.insert(&[e2, e3], &knows);
        let c = index.insert(&[e1, e3], &likes);
        assert_eq!(index.get_by_relation(&knows.relation_type), vec![a, b]);
        assert_eq!(index.get_by_relation(&likes.relation_type), vec![c]);
        assert!(index
            .get_by_relation(&RelationType::new("missing"))
            .is_empty());
    }

    #[test]
    fn test_remove_clears_indexes() {
        let index = HyperedgeIndex::new();
        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let relation = test_relation("test");
        let he_id = index.insert(&[e1, e2], &relation);

        let removed = index.remove(&he_id).expect("edge was just inserted");
        assert_eq!(removed.id, he_id);
        assert!(index.get(&he_id).is_none());
        assert!(index.get_by_entity(&e1).is_empty());
        assert!(index.get_by_entity(&e2).is_empty());
        assert!(index.get_by_relation(&relation.relation_type).is_empty());
        assert!(index.is_empty());
        // A second removal of the same id is a no-op.
        assert!(index.remove(&he_id).is_none());
    }

    #[test]
    fn test_max_size() {
        let index = HyperedgeIndex::new();
        assert_eq!(index.max_size(), 0);
        let relation = test_relation("test");
        index.insert(&[EntityId::new(), EntityId::new()], &relation);
        index.insert(
            &[EntityId::new(), EntityId::new(), EntityId::new()],
            &relation,
        );
        assert_eq!(index.max_size(), 3);
    }
}