1use dashmap::DashMap;
7use exo_core::{EntityId, Error, HyperedgeId, SectionId, SheafConsistencyResult};
8use serde::{Deserialize, Serialize};
9use std::collections::HashSet;
10use std::sync::Arc;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Domain {
15 entities: HashSet<EntityId>,
16}
17
18impl Domain {
19 pub fn new(entities: impl IntoIterator<Item = EntityId>) -> Self {
21 Self {
22 entities: entities.into_iter().collect(),
23 }
24 }
25
26 pub fn is_empty(&self) -> bool {
28 self.entities.is_empty()
29 }
30
31 pub fn intersect(&self, other: &Domain) -> Domain {
33 let intersection = self
34 .entities
35 .intersection(&other.entities)
36 .copied()
37 .collect();
38 Domain {
39 entities: intersection,
40 }
41 }
42
43 pub fn contains(&self, entity: &EntityId) -> bool {
45 self.entities.contains(entity)
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct Section {
52 pub id: SectionId,
53 pub domain: Domain,
54 pub data: serde_json::Value,
55}
56
57impl Section {
58 pub fn new(domain: Domain, data: serde_json::Value) -> Self {
60 Self {
61 id: SectionId::new(),
62 domain,
63 data,
64 }
65 }
66}
67
68pub struct SheafStructure {
72 sections: Arc<DashMap<SectionId, Section>>,
74 restriction_maps: Arc<DashMap<String, serde_json::Value>>,
77 hyperedge_sections: Arc<DashMap<HyperedgeId, Vec<SectionId>>>,
79}
80
81impl SheafStructure {
82 pub fn new() -> Self {
84 Self {
85 sections: Arc::new(DashMap::new()),
86 restriction_maps: Arc::new(DashMap::new()),
87 hyperedge_sections: Arc::new(DashMap::new()),
88 }
89 }
90
91 pub fn add_section(&self, section: Section) -> SectionId {
93 let id = section.id;
94 self.sections.insert(id, section);
95 id
96 }
97
98 pub fn get_section(&self, id: &SectionId) -> Option<Section> {
100 self.sections.get(id).map(|entry| entry.clone())
101 }
102
103 pub fn restrict(&self, section: &Section, subdomain: &Domain) -> serde_json::Value {
107 let cache_key = format!("{:?}-{:?}", section.id, subdomain.entities);
109 if let Some(cached) = self.restriction_maps.get(&cache_key) {
110 return cached.clone();
111 }
112
113 let restricted = self.compute_restriction(§ion.data, subdomain);
115
116 self.restriction_maps.insert(cache_key, restricted.clone());
118
119 restricted
120 }
121
122 fn compute_restriction(
124 &self,
125 data: &serde_json::Value,
126 _subdomain: &Domain,
127 ) -> serde_json::Value {
128 data.clone()
131 }
132
133 pub fn update_sections(
135 &mut self,
136 hyperedge_id: HyperedgeId,
137 entities: &[EntityId],
138 ) -> Result<(), Error> {
139 let domain = Domain::new(entities.iter().copied());
141 let section = Section::new(domain, serde_json::json!({}));
142 let section_id = self.add_section(section);
143
144 self.hyperedge_sections
146 .entry(hyperedge_id)
147 .or_insert_with(Vec::new)
148 .push(section_id);
149
150 Ok(())
151 }
152
153 pub fn check_consistency(&self, section_ids: &[SectionId]) -> SheafConsistencyResult {
158 let mut inconsistencies = Vec::new();
159
160 let sections: Vec<_> = section_ids
162 .iter()
163 .filter_map(|id| self.get_section(id))
164 .collect();
165
166 for i in 0..sections.len() {
168 for j in (i + 1)..sections.len() {
169 let section_a = §ions[i];
170 let section_b = §ions[j];
171
172 let overlap = section_a.domain.intersect(§ion_b.domain);
173
174 if overlap.is_empty() {
175 continue;
176 }
177
178 let restricted_a = self.restrict(section_a, &overlap);
180 let restricted_b = self.restrict(section_b, &overlap);
181
182 if !approximately_equal(&restricted_a, &restricted_b, 1e-6) {
184 let discrepancy = compute_discrepancy(&restricted_a, &restricted_b);
185 inconsistencies.push(format!(
186 "Sections {} and {} disagree on overlap (discrepancy: {:.6})",
187 section_a.id.0, section_b.id.0, discrepancy
188 ));
189 }
190 }
191 }
192
193 if inconsistencies.is_empty() {
194 SheafConsistencyResult::Consistent
195 } else {
196 SheafConsistencyResult::Inconsistent(inconsistencies)
197 }
198 }
199
200 pub fn get_hyperedge_sections(&self, hyperedge_id: &HyperedgeId) -> Vec<SectionId> {
202 self.hyperedge_sections
203 .get(hyperedge_id)
204 .map(|entry| entry.clone())
205 .unwrap_or_default()
206 }
207}
208
209impl Default for SheafStructure {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct SheafInconsistency {
218 pub sections: (SectionId, SectionId),
219 pub overlap: Domain,
220 pub discrepancy: f64,
221}
222
223fn approximately_equal(a: &serde_json::Value, b: &serde_json::Value, epsilon: f64) -> bool {
225 match (a, b) {
226 (serde_json::Value::Number(na), serde_json::Value::Number(nb)) => {
227 let a_f64 = na.as_f64().unwrap_or(0.0);
228 let b_f64 = nb.as_f64().unwrap_or(0.0);
229 (a_f64 - b_f64).abs() < epsilon
230 }
231 (serde_json::Value::Array(aa), serde_json::Value::Array(ab)) => {
232 if aa.len() != ab.len() {
233 return false;
234 }
235 aa.iter()
236 .zip(ab.iter())
237 .all(|(x, y)| approximately_equal(x, y, epsilon))
238 }
239 (serde_json::Value::Object(oa), serde_json::Value::Object(ob)) => {
240 if oa.len() != ob.len() {
241 return false;
242 }
243 oa.iter().all(|(k, va)| {
244 ob.get(k)
245 .map(|vb| approximately_equal(va, vb, epsilon))
246 .unwrap_or(false)
247 })
248 }
249 _ => a == b,
250 }
251}
252
253fn compute_discrepancy(a: &serde_json::Value, b: &serde_json::Value) -> f64 {
255 match (a, b) {
256 (serde_json::Value::Number(na), serde_json::Value::Number(nb)) => {
257 let a_f64 = na.as_f64().unwrap_or(0.0);
258 let b_f64 = nb.as_f64().unwrap_or(0.0);
259 (a_f64 - b_f64).abs()
260 }
261 (serde_json::Value::Array(aa), serde_json::Value::Array(ab)) => {
262 let diffs: Vec<f64> = aa
263 .iter()
264 .zip(ab.iter())
265 .map(|(x, y)| compute_discrepancy(x, y))
266 .collect();
267 diffs.iter().sum::<f64>() / diffs.len().max(1) as f64
268 }
269 _ => {
270 if a == b {
271 0.0
272 } else {
273 1.0
274 }
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_domain_intersection() {
        let (e1, e2, e3) = (EntityId::new(), EntityId::new(), EntityId::new());
        let left = Domain::new(vec![e1, e2]);
        let right = Domain::new(vec![e2, e3]);

        let overlap = left.intersect(&right);

        // Only the shared entity e2 should survive.
        assert!(!overlap.is_empty());
        assert!(overlap.contains(&e2));
        assert!(!overlap.contains(&e1));
    }

    #[test]
    fn test_sheaf_consistency() {
        let sheaf = SheafStructure::new();
        let (e1, e2) = (EntityId::new(), EntityId::new());
        let payload = serde_json::json!({"value": 42});

        let wide = Section::new(Domain::new(vec![e1, e2]), payload.clone());
        let narrow = Section::new(Domain::new(vec![e2]), payload);
        let ids = [sheaf.add_section(wide), sheaf.add_section(narrow)];

        // Identical data on the overlap {e2} must be reported as consistent.
        let result = sheaf.check_consistency(&ids);
        assert!(matches!(result, SheafConsistencyResult::Consistent));
    }

    #[test]
    fn test_approximately_equal() {
        let exact = serde_json::json!(1.0);
        let nearby = serde_json::json!(1.0000001);

        // The values are ~1e-7 apart: inside a 1e-6 tolerance, outside a 1e-8 one.
        assert!(approximately_equal(&exact, &nearby, 1e-6));
        assert!(!approximately_equal(&exact, &nearby, 1e-8));
    }
}