use dashmap::DashMap;
use exo_core::{EntityId, Error, HyperedgeId, SectionId, SheafConsistencyResult};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
/// A finite set of entities over which a [`Section`] is defined.
///
/// Duplicate ids passed to [`Domain::new`] collapse into one member.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Domain {
    entities: HashSet<EntityId>,
}

impl Domain {
    /// Builds a domain from any iterable of entity ids.
    pub fn new(entities: impl IntoIterator<Item = EntityId>) -> Self {
        let entities = entities.into_iter().collect::<HashSet<_>>();
        Self { entities }
    }

    /// Returns `true` when the domain has no entities.
    pub fn is_empty(&self) -> bool {
        self.entities.is_empty()
    }

    /// Returns a new domain holding only the entities present in both
    /// `self` and `other`.
    pub fn intersect(&self, other: &Domain) -> Domain {
        let shared = self.entities.intersection(&other.entities).copied();
        Domain::new(shared)
    }

    /// Returns `true` when `entity` is a member of this domain.
    pub fn contains(&self, entity: &EntityId) -> bool {
        self.entities.contains(entity)
    }
}
/// A JSON payload attached to a [`Domain`], tagged with its own [`SectionId`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Section {
    /// Identifier produced by `SectionId::new()` when built via [`Section::new`].
    pub id: SectionId,
    /// Entities this section is defined over.
    pub domain: Domain,
    /// Arbitrary JSON data carried by the section.
    pub data: serde_json::Value,
}

impl Section {
    /// Creates a section over `domain` carrying `data`, with an id obtained
    /// from `SectionId::new()`.
    pub fn new(domain: Domain, data: serde_json::Value) -> Self {
        let id = SectionId::new();
        Section { id, domain, data }
    }
}
pub struct SheafStructure {
sections: Arc<DashMap<SectionId, Section>>,
restriction_maps: Arc<DashMap<String, serde_json::Value>>,
hyperedge_sections: Arc<DashMap<HyperedgeId, Vec<SectionId>>>,
}
impl SheafStructure {
pub fn new() -> Self {
Self {
sections: Arc::new(DashMap::new()),
restriction_maps: Arc::new(DashMap::new()),
hyperedge_sections: Arc::new(DashMap::new()),
}
}
pub fn add_section(&self, section: Section) -> SectionId {
let id = section.id;
self.sections.insert(id, section);
id
}
pub fn get_section(&self, id: &SectionId) -> Option<Section> {
self.sections.get(id).map(|entry| entry.clone())
}
pub fn restrict(&self, section: &Section, subdomain: &Domain) -> serde_json::Value {
let cache_key = format!("{:?}-{:?}", section.id, subdomain.entities);
if let Some(cached) = self.restriction_maps.get(&cache_key) {
return cached.clone();
}
let restricted = self.compute_restriction(§ion.data, subdomain);
self.restriction_maps.insert(cache_key, restricted.clone());
restricted
}
fn compute_restriction(
&self,
data: &serde_json::Value,
_subdomain: &Domain,
) -> serde_json::Value {
data.clone()
}
pub fn update_sections(
&mut self,
hyperedge_id: HyperedgeId,
entities: &[EntityId],
) -> Result<(), Error> {
let domain = Domain::new(entities.iter().copied());
let section = Section::new(domain, serde_json::json!({}));
let section_id = self.add_section(section);
self.hyperedge_sections
.entry(hyperedge_id)
.or_insert_with(Vec::new)
.push(section_id);
Ok(())
}
pub fn check_consistency(&self, section_ids: &[SectionId]) -> SheafConsistencyResult {
let mut inconsistencies = Vec::new();
let sections: Vec<_> = section_ids
.iter()
.filter_map(|id| self.get_section(id))
.collect();
for i in 0..sections.len() {
for j in (i + 1)..sections.len() {
let section_a = §ions[i];
let section_b = §ions[j];
let overlap = section_a.domain.intersect(§ion_b.domain);
if overlap.is_empty() {
continue;
}
let restricted_a = self.restrict(section_a, &overlap);
let restricted_b = self.restrict(section_b, &overlap);
if !approximately_equal(&restricted_a, &restricted_b, 1e-6) {
let discrepancy = compute_discrepancy(&restricted_a, &restricted_b);
inconsistencies.push(format!(
"Sections {} and {} disagree on overlap (discrepancy: {:.6})",
section_a.id.0, section_b.id.0, discrepancy
));
}
}
}
if inconsistencies.is_empty() {
SheafConsistencyResult::Consistent
} else {
SheafConsistencyResult::Inconsistent(inconsistencies)
}
}
pub fn get_hyperedge_sections(&self, hyperedge_id: &HyperedgeId) -> Vec<SectionId> {
self.hyperedge_sections
.get(hyperedge_id)
.map(|entry| entry.clone())
.unwrap_or_default()
}
}
impl Default for SheafStructure {
fn default() -> Self {
Self::new()
}
}
/// Structured record of a disagreement between two sections.
///
/// NOTE(review): `SheafStructure::check_consistency` reports disagreements as
/// plain strings and never builds this type; confirm whether it is used
/// elsewhere.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SheafInconsistency {
    /// Pair of section ids — presumably the two sections that disagree (confirm).
    pub sections: (SectionId, SectionId),
    /// Presumably the intersection of the two sections' domains (confirm).
    pub overlap: Domain,
    /// Size of the disagreement — presumably as returned by
    /// `compute_discrepancy` (confirm).
    pub discrepancy: f64,
}
/// Structurally compares two JSON values, allowing numbers to differ by less
/// than `epsilon`.
///
/// - Arrays must have the same length, and objects the same key set. Their
///   elements or members are compared recursively.
/// - Any other pair (strings, booleans, null, or mismatched variants) must be
///   exactly equal.
/// - Exactly equal numbers always match, so identical values compare equal
///   even when `epsilon` is zero.
/// - If either number cannot be converted by `as_f64`, the pair matches only
///   when exactly equal, instead of both being treated as `0.0`.
fn approximately_equal(a: &serde_json::Value, b: &serde_json::Value, epsilon: f64) -> bool {
    match (a, b) {
        (serde_json::Value::Number(na), serde_json::Value::Number(nb)) => {
            if na == nb {
                return true;
            }
            match (na.as_f64(), nb.as_f64()) {
                (Some(a_f64), Some(b_f64)) => (a_f64 - b_f64).abs() < epsilon,
                _ => false,
            }
        }
        (serde_json::Value::Array(aa), serde_json::Value::Array(ab)) => {
            if aa.len() != ab.len() {
                return false;
            }
            aa.iter()
                .zip(ab.iter())
                .all(|(x, y)| approximately_equal(x, y, epsilon))
        }
        (serde_json::Value::Object(oa), serde_json::Value::Object(ob)) => {
            // Equal sizes plus every key of `oa` present in `ob` implies equal key sets.
            if oa.len() != ob.len() {
                return false;
            }
            oa.iter().all(|(k, va)| {
                ob.get(k)
                    .map(|vb| approximately_equal(va, vb, epsilon))
                    .unwrap_or(false)
            })
        }
        _ => a == b,
    }
}
/// Measures how far apart two JSON values are; `0.0` means no difference.
///
/// - Numbers: absolute difference of their `f64` values. If either cannot be
///   converted, the result is `0.0` when exactly equal and `1.0` otherwise.
/// - Arrays: mean discrepancy over every position up to the longer length. A
///   position present on only one side contributes `1.0`.
/// - Objects: mean discrepancy over the union of keys. A key present on only
///   one side contributes `1.0`.
/// - Anything else: `0.0` if exactly equal, otherwise `1.0`.
///
/// Two empty arrays, or two empty objects, have discrepancy `0.0`.
fn compute_discrepancy(a: &serde_json::Value, b: &serde_json::Value) -> f64 {
    match (a, b) {
        (serde_json::Value::Number(na), serde_json::Value::Number(nb)) => {
            match (na.as_f64(), nb.as_f64()) {
                (Some(a_f64), Some(b_f64)) => (a_f64 - b_f64).abs(),
                _ => {
                    if na == nb {
                        0.0
                    } else {
                        1.0
                    }
                }
            }
        }
        (serde_json::Value::Array(aa), serde_json::Value::Array(ab)) => {
            // Cover the longer array so extra or missing elements count, rather
            // than being dropped by a truncating `zip`.
            let len = aa.len().max(ab.len());
            let total: f64 = (0..len)
                .map(|i| match (aa.get(i), ab.get(i)) {
                    (Some(x), Some(y)) => compute_discrepancy(x, y),
                    _ => 1.0,
                })
                .sum();
            total / len.max(1) as f64
        }
        (serde_json::Value::Object(oa), serde_json::Value::Object(ob)) => {
            // First the keys of `oa` (shared or `oa`-only), then the keys only
            // in `ob`. Iterating the maps directly keeps the summation order
            // deterministic.
            let from_a = oa
                .iter()
                .map(|(key, va)| ob.get(key).map_or(1.0, |vb| compute_discrepancy(va, vb)));
            let only_in_b = ob
                .keys()
                .filter(|key| !oa.contains_key(*key))
                .map(|_| 1.0);
            let (total, count) = from_a
                .chain(only_in_b)
                .fold((0.0_f64, 0_usize), |(sum, n), d| (sum + d, n + 1));
            total / count.max(1) as f64
        }
        _ => {
            if a == b {
                0.0
            } else {
                1.0
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Overlapping domains intersect to exactly their shared entities.
    #[test]
    fn test_domain_intersection() {
        let (only_left, shared, only_right) = (EntityId::new(), EntityId::new(), EntityId::new());
        let left = Domain::new(vec![only_left, shared]);
        let right = Domain::new(vec![shared, only_right]);

        let common = left.intersect(&right);

        assert!(!common.is_empty());
        assert!(common.contains(&shared));
        assert!(!common.contains(&only_left));
    }

    /// Two sections with identical payloads agree on their overlap.
    #[test]
    fn test_sheaf_consistency() {
        let sheaf = SheafStructure::new();
        let (e1, e2) = (EntityId::new(), EntityId::new());
        let payload = || serde_json::json!({"value": 42});

        let wide = Section::new(Domain::new(vec![e1, e2]), payload());
        let narrow = Section::new(Domain::new(vec![e2]), payload());
        let ids = [sheaf.add_section(wide), sheaf.add_section(narrow)];

        let verdict = sheaf.check_consistency(&ids);
        assert!(matches!(verdict, SheafConsistencyResult::Consistent));
    }

    /// The tolerance decides whether a 1e-7 difference counts as equal.
    #[test]
    fn test_approximately_equal() {
        let (base, nudged) = (serde_json::json!(1.0), serde_json::json!(1.0000001));
        assert!(approximately_equal(&base, &nudged, 1e-6));
        assert!(!approximately_equal(&base, &nudged, 1e-8));
    }
}