// exo_hypergraph/hyperedge.rs
use dashmap::DashMap;
use exo_core::{EntityId, HyperedgeId, Relation, RelationType, SubstrateTime};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;

11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Hyperedge {
14 pub id: HyperedgeId,
16 pub entities: Vec<EntityId>,
18 pub relation: Relation,
20 pub weight: f32,
22 pub created_at: SubstrateTime,
24}
25
26impl Hyperedge {
27 pub fn new(entities: Vec<EntityId>, relation: Relation) -> Self {
29 Self {
30 id: HyperedgeId::new(),
31 entities,
32 relation,
33 weight: 1.0,
34 created_at: SubstrateTime::now(),
35 }
36 }
37
38 pub fn arity(&self) -> usize {
40 self.entities.len()
41 }
42
43 pub fn contains_entity(&self, entity: &EntityId) -> bool {
45 self.entities.contains(entity)
46 }
47}
48
49pub struct HyperedgeIndex {
53 edges: Arc<DashMap<HyperedgeId, Hyperedge>>,
55 entity_index: Arc<DashMap<EntityId, Vec<HyperedgeId>>>,
57 relation_index: Arc<DashMap<RelationType, Vec<HyperedgeId>>>,
59}
60
61impl HyperedgeIndex {
62 pub fn new() -> Self {
64 Self {
65 edges: Arc::new(DashMap::new()),
66 entity_index: Arc::new(DashMap::new()),
67 relation_index: Arc::new(DashMap::new()),
68 }
69 }
70
71 pub fn insert(&self, entities: &[EntityId], relation: &Relation) -> HyperedgeId {
75 let hyperedge = Hyperedge::new(entities.to_vec(), relation.clone());
76 let hyperedge_id = hyperedge.id;
77
78 self.edges.insert(hyperedge_id, hyperedge);
80
81 for entity in entities {
83 self.entity_index
84 .entry(*entity)
85 .or_insert_with(Vec::new)
86 .push(hyperedge_id);
87 }
88
89 self.relation_index
91 .entry(relation.relation_type.clone())
92 .or_insert_with(Vec::new)
93 .push(hyperedge_id);
94
95 hyperedge_id
96 }
97
98 pub fn get(&self, id: &HyperedgeId) -> Option<Hyperedge> {
100 self.edges.get(id).map(|entry| entry.clone())
101 }
102
103 pub fn get_by_entity(&self, entity: &EntityId) -> Vec<HyperedgeId> {
105 self.entity_index
106 .get(entity)
107 .map(|entry| entry.clone())
108 .unwrap_or_default()
109 }
110
111 pub fn get_by_relation(&self, relation_type: &RelationType) -> Vec<HyperedgeId> {
113 self.relation_index
114 .get(relation_type)
115 .map(|entry| entry.clone())
116 .unwrap_or_default()
117 }
118
119 pub fn len(&self) -> usize {
121 self.edges.len()
122 }
123
124 pub fn is_empty(&self) -> bool {
126 self.edges.is_empty()
127 }
128
129 pub fn max_size(&self) -> usize {
131 self.edges
132 .iter()
133 .map(|entry| entry.value().arity())
134 .max()
135 .unwrap_or(0)
136 }
137
138 pub fn remove(&self, id: &HyperedgeId) -> Option<Hyperedge> {
140 if let Some((_, hyperedge)) = self.edges.remove(id) {
141 for entity in &hyperedge.entities {
143 if let Some(mut entry) = self.entity_index.get_mut(entity) {
144 entry.retain(|he_id| he_id != id);
145 }
146 }
147
148 if let Some(mut entry) = self.relation_index.get_mut(&hyperedge.relation.relation_type)
150 {
151 entry.retain(|he_id| he_id != id);
152 }
153
154 Some(hyperedge)
155 } else {
156 None
157 }
158 }
159
160 pub fn all(&self) -> Vec<Hyperedge> {
162 self.edges.iter().map(|entry| entry.clone()).collect()
163 }
164
165 pub fn find_connecting(&self, entities: &[EntityId]) -> Vec<HyperedgeId> {
169 if entities.is_empty() {
170 return Vec::new();
171 }
172
173 let mut candidates = self.get_by_entity(&entities[0]);
175
176 candidates.retain(|he_id| {
178 if let Some(he) = self.get(he_id) {
179 entities.iter().all(|e| he.contains_entity(e))
180 } else {
181 false
182 }
183 });
184
185 candidates
186 }
187}
188
189impl Default for HyperedgeIndex {
190 fn default() -> Self {
191 Self::new()
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use exo_core::RelationType;

    /// Builds a relation of the given type with empty properties.
    fn relation(name: &'static str) -> Relation {
        Relation {
            relation_type: RelationType::new(name),
            properties: serde_json::json!({}),
        }
    }

    #[test]
    fn test_hyperedge_creation() {
        let entities = vec![EntityId::new(), EntityId::new(), EntityId::new()];

        let he = Hyperedge::new(entities.clone(), relation("test"));

        assert_eq!(he.arity(), 3);
        assert!(he.contains_entity(&entities[0]));
        assert_eq!(he.weight, 1.0);
    }

    #[test]
    fn test_hyperedge_index() {
        let index = HyperedgeIndex::new();

        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();

        let relation = relation("test");

        let he_id = index.insert(&[e1, e2, e3], &relation);

        assert!(index.get(&he_id).is_some());
        assert_eq!(index.get_by_entity(&e1).len(), 1);
        assert_eq!(index.get_by_entity(&e2).len(), 1);
        assert_eq!(index.get_by_relation(&relation.relation_type), vec![he_id]);
        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
    }

    #[test]
    fn test_find_connecting() {
        let index = HyperedgeIndex::new();

        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();
        let e4 = EntityId::new();

        let relation = relation("test");

        index.insert(&[e1, e2], &relation);
        let he2 = index.insert(&[e1, e2, e3], &relation);
        index.insert(&[e1, e4], &relation);

        let connecting = index.find_connecting(&[e1, e2, e3]);

        assert_eq!(connecting.len(), 1);
        assert_eq!(connecting[0], he2);
    }

    #[test]
    fn test_find_connecting_empty_query() {
        let index = HyperedgeIndex::new();
        index.insert(&[EntityId::new(), EntityId::new()], &relation("test"));

        assert!(index.find_connecting(&[]).is_empty());
    }

    #[test]
    fn test_remove() {
        let index = HyperedgeIndex::new();

        let e1 = EntityId::new();
        let e2 = EntityId::new();
        let e3 = EntityId::new();

        let relation = relation("test");

        let he1 = index.insert(&[e1, e2], &relation);
        let he2 = index.insert(&[e1, e3], &relation);

        let removed = index.remove(&he1).expect("he1 was just inserted");
        assert_eq!(removed.id, he1);

        assert!(index.get(&he1).is_none());
        assert_eq!(index.get_by_entity(&e1), vec![he2]);
        assert!(index.get_by_entity(&e2).is_empty());
        assert_eq!(index.get_by_relation(&relation.relation_type), vec![he2]);
        assert_eq!(index.len(), 1);

        // Removing the same id again is a no-op.
        assert!(index.remove(&he1).is_none());

        index.remove(&he2).expect("he2 is still stored");
        assert!(index.is_empty());
        assert!(index.get_by_entity(&e1).is_empty());
        assert!(index.get_by_relation(&relation.relation_type).is_empty());
    }

    #[test]
    fn test_max_size() {
        let index = HyperedgeIndex::new();
        assert_eq!(index.max_size(), 0);

        let relation = relation("test");
        index.insert(&[EntityId::new(), EntityId::new()], &relation);
        index.insert(&[EntityId::new(), EntityId::new(), EntityId::new()], &relation);

        assert_eq!(index.max_size(), 3);
        assert_eq!(index.all().len(), 2);
    }
}