1use dashmap::DashMap;
7use exo_core::{EntityId, Error, HyperedgeId, SectionId, SheafConsistencyResult};
8use serde::{Deserialize, Serialize};
9use std::collections::HashSet;
10use std::sync::Arc;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
14pub struct Domain {
15 entities: HashSet<EntityId>,
16}
17
18impl Domain {
19 pub fn new(entities: impl IntoIterator<Item = EntityId>) -> Self {
21 Self {
22 entities: entities.into_iter().collect(),
23 }
24 }
25
26 pub fn is_empty(&self) -> bool {
28 self.entities.is_empty()
29 }
30
31 pub fn intersect(&self, other: &Domain) -> Domain {
33 let intersection = self
34 .entities
35 .intersection(&other.entities)
36 .copied()
37 .collect();
38 Domain {
39 entities: intersection,
40 }
41 }
42
43 pub fn contains(&self, entity: &EntityId) -> bool {
45 self.entities.contains(entity)
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct Section {
52 pub id: SectionId,
53 pub domain: Domain,
54 pub data: serde_json::Value,
55}
56
57impl Section {
58 pub fn new(domain: Domain, data: serde_json::Value) -> Self {
60 Self {
61 id: SectionId::new(),
62 domain,
63 data,
64 }
65 }
66}
67
68pub struct SheafStructure {
72 sections: Arc<DashMap<SectionId, Section>>,
74 restriction_maps: Arc<DashMap<String, serde_json::Value>>,
77 hyperedge_sections: Arc<DashMap<HyperedgeId, Vec<SectionId>>>,
79}
80
81impl SheafStructure {
82 pub fn new() -> Self {
84 Self {
85 sections: Arc::new(DashMap::new()),
86 restriction_maps: Arc::new(DashMap::new()),
87 hyperedge_sections: Arc::new(DashMap::new()),
88 }
89 }
90
91 pub fn add_section(&self, section: Section) -> SectionId {
93 let id = section.id;
94 self.sections.insert(id, section);
95 id
96 }
97
98 pub fn get_section(&self, id: &SectionId) -> Option<Section> {
100 self.sections.get(id).map(|entry| entry.clone())
101 }
102
103 pub fn restrict(&self, section: &Section, subdomain: &Domain) -> serde_json::Value {
107 let cache_key = format!("{:?}-{:?}", section.id, subdomain.entities);
109 if let Some(cached) = self.restriction_maps.get(&cache_key) {
110 return cached.clone();
111 }
112
113 let restricted = self.compute_restriction(§ion.data, subdomain);
115
116 self.restriction_maps
118 .insert(cache_key, restricted.clone());
119
120 restricted
121 }
122
123 fn compute_restriction(
125 &self,
126 data: &serde_json::Value,
127 _subdomain: &Domain,
128 ) -> serde_json::Value {
129 data.clone()
132 }
133
134 pub fn update_sections(
136 &mut self,
137 hyperedge_id: HyperedgeId,
138 entities: &[EntityId],
139 ) -> Result<(), Error> {
140 let domain = Domain::new(entities.iter().copied());
142 let section = Section::new(domain, serde_json::json!({}));
143 let section_id = self.add_section(section);
144
145 self.hyperedge_sections
147 .entry(hyperedge_id)
148 .or_insert_with(Vec::new)
149 .push(section_id);
150
151 Ok(())
152 }
153
154 pub fn check_consistency(&self, section_ids: &[SectionId]) -> SheafConsistencyResult {
159 let mut inconsistencies = Vec::new();
160
161 let sections: Vec<_> = section_ids
163 .iter()
164 .filter_map(|id| self.get_section(id))
165 .collect();
166
167 for i in 0..sections.len() {
169 for j in (i + 1)..sections.len() {
170 let section_a = §ions[i];
171 let section_b = §ions[j];
172
173 let overlap = section_a.domain.intersect(§ion_b.domain);
174
175 if overlap.is_empty() {
176 continue;
177 }
178
179 let restricted_a = self.restrict(section_a, &overlap);
181 let restricted_b = self.restrict(section_b, &overlap);
182
183 if !approximately_equal(&restricted_a, &restricted_b, 1e-6) {
185 let discrepancy = compute_discrepancy(&restricted_a, &restricted_b);
186 inconsistencies.push(format!(
187 "Sections {} and {} disagree on overlap (discrepancy: {:.6})",
188 section_a.id.0, section_b.id.0, discrepancy
189 ));
190 }
191 }
192 }
193
194 if inconsistencies.is_empty() {
195 SheafConsistencyResult::Consistent
196 } else {
197 SheafConsistencyResult::Inconsistent(inconsistencies)
198 }
199 }
200
201 pub fn get_hyperedge_sections(&self, hyperedge_id: &HyperedgeId) -> Vec<SectionId> {
203 self.hyperedge_sections
204 .get(hyperedge_id)
205 .map(|entry| entry.clone())
206 .unwrap_or_default()
207 }
208}
209
210impl Default for SheafStructure {
211 fn default() -> Self {
212 Self::new()
213 }
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct SheafInconsistency {
219 pub sections: (SectionId, SectionId),
220 pub overlap: Domain,
221 pub discrepancy: f64,
222}
223
224fn approximately_equal(a: &serde_json::Value, b: &serde_json::Value, epsilon: f64) -> bool {
226 match (a, b) {
227 (serde_json::Value::Number(na), serde_json::Value::Number(nb)) => {
228 let a_f64 = na.as_f64().unwrap_or(0.0);
229 let b_f64 = nb.as_f64().unwrap_or(0.0);
230 (a_f64 - b_f64).abs() < epsilon
231 }
232 (serde_json::Value::Array(aa), serde_json::Value::Array(ab)) => {
233 if aa.len() != ab.len() {
234 return false;
235 }
236 aa.iter()
237 .zip(ab.iter())
238 .all(|(x, y)| approximately_equal(x, y, epsilon))
239 }
240 (serde_json::Value::Object(oa), serde_json::Value::Object(ob)) => {
241 if oa.len() != ob.len() {
242 return false;
243 }
244 oa.iter().all(|(k, va)| {
245 ob.get(k)
246 .map(|vb| approximately_equal(va, vb, epsilon))
247 .unwrap_or(false)
248 })
249 }
250 _ => a == b,
251 }
252}
253
254fn compute_discrepancy(a: &serde_json::Value, b: &serde_json::Value) -> f64 {
256 match (a, b) {
257 (serde_json::Value::Number(na), serde_json::Value::Number(nb)) => {
258 let a_f64 = na.as_f64().unwrap_or(0.0);
259 let b_f64 = nb.as_f64().unwrap_or(0.0);
260 (a_f64 - b_f64).abs()
261 }
262 (serde_json::Value::Array(aa), serde_json::Value::Array(ab)) => {
263 let diffs: Vec<f64> = aa
264 .iter()
265 .zip(ab.iter())
266 .map(|(x, y)| compute_discrepancy(x, y))
267 .collect();
268 diffs.iter().sum::<f64>() / diffs.len().max(1) as f64
269 }
270 _ => {
271 if a == b {
272 0.0
273 } else {
274 1.0
275 }
276 }
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the entity shared by both domains survives intersection.
    #[test]
    fn test_domain_intersection() {
        let only_left = EntityId::new();
        let shared = EntityId::new();
        let only_right = EntityId::new();

        let left = Domain::new(vec![only_left, shared]);
        let right = Domain::new(vec![shared, only_right]);
        let common = left.intersect(&right);

        assert!(!common.is_empty());
        assert!(common.contains(&shared));
        assert!(!common.contains(&only_left));
    }

    /// Two sections carrying identical data on an overlapping domain are consistent.
    #[test]
    fn test_sheaf_consistency() {
        let sheaf = SheafStructure::new();

        let first = EntityId::new();
        let second = EntityId::new();

        let wide = Section::new(
            Domain::new(vec![first, second]),
            serde_json::json!({"value": 42}),
        );
        let narrow = Section::new(Domain::new(vec![second]), serde_json::json!({"value": 42}));

        let ids = [sheaf.add_section(wide), sheaf.add_section(narrow)];

        assert!(matches!(
            sheaf.check_consistency(&ids),
            SheafConsistencyResult::Consistent
        ));
    }

    /// A 1e-7 gap is within 1e-6 tolerance but outside 1e-8.
    #[test]
    fn test_approximately_equal() {
        let base = serde_json::json!(1.0);
        let nudged = serde_json::json!(1.0000001);

        assert!(approximately_equal(&base, &nudged, 1e-6));
        assert!(!approximately_equal(&base, &nudged, 1e-8));
    }
}