use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::{EntityId, RelationType, Triple};
/// A binary fact (`subject --predicate--> object`) annotated with key/value
/// qualifiers.
///
/// Qualifier keys are unique: attaching the same key twice keeps the last value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HyperTriple {
    /// The underlying subject/predicate/object triple.
    pub core: Triple,
    /// Extra key/value annotations attached to `core`.
    pub qualifiers: HashMap<RelationType, EntityId>,
}

impl HyperTriple {
    /// Wraps `core` with an empty qualifier map.
    pub fn new(core: Triple) -> Self {
        let qualifiers = HashMap::new();
        Self { core, qualifiers }
    }

    /// Builds the core triple from its three parts; no qualifiers are attached.
    pub fn from_parts(
        subject: impl Into<EntityId>,
        predicate: impl Into<RelationType>,
        object: impl Into<EntityId>,
    ) -> Self {
        let core = Triple::new(subject, predicate, object);
        Self::new(core)
    }

    /// Builder: attaches one qualifier, replacing any existing value for `key`.
    pub fn with_qualifier(
        mut self,
        key: impl Into<RelationType>,
        value: impl Into<EntityId>,
    ) -> Self {
        let (k, v) = (key.into(), value.into());
        self.qualifiers.insert(k, v);
        self
    }

    /// Builder: attaches every `(key, value)` pair in iteration order; a later
    /// pair with an already-present key overwrites the earlier value.
    pub fn with_qualifiers(
        mut self,
        qualifiers: impl IntoIterator<Item = (impl Into<RelationType>, impl Into<EntityId>)>,
    ) -> Self {
        qualifiers.into_iter().for_each(|(k, v)| {
            self.qualifiers.insert(k.into(), v.into());
        });
        self
    }

    /// Returns the qualifier value stored under `key`, if present.
    pub fn qualifier(&self, key: &str) -> Option<&EntityId> {
        let wanted = RelationType::from(key);
        self.qualifiers.get(&wanted)
    }

    /// Number of entity slots: subject and object plus one per qualifier.
    pub fn arity(&self) -> usize {
        self.qualifiers.len() + 2
    }

    /// Iterates the subject, then the object, then every qualifier value
    /// (qualifiers in unspecified `HashMap` order).
    pub fn entities(&self) -> impl Iterator<Item = &EntityId> {
        let subject = std::iter::once(self.core.subject());
        let object = std::iter::once(self.core.object());
        subject.chain(object).chain(self.qualifiers.values())
    }
}
/// One participant of a [`HyperEdge`]: the named `role` it plays and the
/// entity filling that role.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct RoleBinding {
    /// Role name, e.g. `"buyer"`.
    pub role: String,
    /// Entity bound to `role`.
    pub entity: EntityId,
}

impl RoleBinding {
    /// Creates a binding from anything convertible into a role name and an entity id.
    pub fn new(role: impl Into<String>, entity: impl Into<EntityId>) -> Self {
        let role: String = role.into();
        let entity: EntityId = entity.into();
        Self { role, entity }
    }
}
/// An n-ary relation: a relation name, an ordered list of role bindings, and
/// an optional confidence score.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HyperEdge {
    /// Name of the relation, e.g. `"purchase"`.
    pub relation: RelationType,
    /// Participants in insertion order. Roles are not required to be unique.
    pub bindings: Vec<RoleBinding>,
    /// Confidence score; kept within `[0.0, 1.0]` when set through
    /// [`HyperEdge::with_confidence`].
    pub confidence: Option<f32>,
}

impl HyperEdge {
    /// Creates an edge for `relation` with no bindings and no confidence.
    pub fn new(relation: impl Into<RelationType>) -> Self {
        Self {
            relation: relation.into(),
            bindings: Vec::new(),
            confidence: None,
        }
    }

    /// Builder: appends one `(role, entity)` binding.
    pub fn with_binding(mut self, role: impl Into<String>, entity: impl Into<EntityId>) -> Self {
        self.bindings.push(RoleBinding::new(role, entity));
        self
    }

    /// Builder: appends every `(role, entity)` pair, preserving iteration order.
    pub fn with_bindings(
        mut self,
        bindings: impl IntoIterator<Item = (impl Into<String>, impl Into<EntityId>)>,
    ) -> Self {
        self.bindings.extend(
            bindings
                .into_iter()
                .map(|(role, entity)| RoleBinding::new(role, entity)),
        );
        self
    }

    /// Builder: sets the confidence, clamped into `[0.0, 1.0]`.
    ///
    /// `f32::clamp` passes NaN through unchanged. A NaN is not a meaningful
    /// confidence, and storing it would make this edge compare unequal to
    /// itself under the derived `PartialEq`. A NaN input therefore clears the
    /// confidence (`None`).
    pub fn with_confidence(mut self, confidence: f32) -> Self {
        self.confidence = if confidence.is_nan() {
            None
        } else {
            Some(confidence.clamp(0.0, 1.0))
        };
        self
    }

    /// Number of bindings (participants) on this edge.
    pub fn arity(&self) -> usize {
        self.bindings.len()
    }

    /// Entity of the binding at `position` (insertion order), if in range.
    pub fn entity_at(&self, position: usize) -> Option<&EntityId> {
        self.bindings.get(position).map(|b| &b.entity)
    }

    /// Entity of the first binding whose role equals `role`, if any.
    pub fn entity_by_role(&self, role: &str) -> Option<&EntityId> {
        self.bindings
            .iter()
            .find(|b| b.role == role)
            .map(|b| &b.entity)
    }

    /// Iterates bound entities in binding order.
    pub fn entities(&self) -> impl Iterator<Item = &EntityId> {
        self.bindings.iter().map(|b| &b.entity)
    }

    /// Iterates role names in binding order.
    pub fn roles(&self) -> impl Iterator<Item = &str> {
        self.bindings.iter().map(|b| b.role.as_str())
    }

    /// Reifies the edge into binary triples around the node `intermediate_id`.
    ///
    /// The first triple is `(node, "rdf:type", relation)`. It is followed by
    /// one `(node, role, entity)` triple per binding, in binding order.
    pub fn reify(&self, intermediate_id: impl Into<EntityId>) -> Vec<Triple> {
        let intermediate = intermediate_id.into();
        let mut triples = Vec::with_capacity(self.bindings.len() + 1);
        triples.push(Triple::new(
            intermediate.clone(),
            "rdf:type",
            self.relation.as_str(),
        ));
        for binding in &self.bindings {
            triples.push(Triple::new(
                intermediate.clone(),
                binding.role.clone(),
                binding.entity.clone(),
            ));
        }
        triples
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HyperGraph {
pub triples: Vec<Triple>,
pub hyper_triples: Vec<HyperTriple>,
pub hyperedges: Vec<HyperEdge>,
}
impl HyperGraph {
pub fn new() -> Self {
Self::default()
}
pub fn add_triple(&mut self, triple: Triple) {
self.triples.push(triple);
}
pub fn add_hyper_triple(&mut self, hyper_triple: HyperTriple) {
self.hyper_triples.push(hyper_triple);
}
pub fn add_hyperedge(&mut self, hyperedge: HyperEdge) {
self.hyperedges.push(hyperedge);
}
pub fn fact_count(&self) -> usize {
self.triples.len() + self.hyper_triples.len() + self.hyperedges.len()
}
pub fn entities(&self) -> std::collections::HashSet<&EntityId> {
let mut entities = std::collections::HashSet::new();
for t in &self.triples {
entities.insert(t.subject());
entities.insert(t.object());
}
for ht in &self.hyper_triples {
for e in ht.entities() {
entities.insert(e);
}
}
for he in &self.hyperedges {
for e in he.entities() {
entities.insert(e);
}
}
entities
}
pub fn to_knowledge_graph(&self) -> crate::KnowledgeGraph {
let mut kg = crate::KnowledgeGraph::new();
for triple in self.to_reified_triples() {
kg.add_triple(triple);
}
kg
}
pub fn find_by_qualifier(&self, key: &str, value: &str) -> Vec<&HyperTriple> {
self.hyper_triples
.iter()
.filter(|ht| {
ht.qualifiers
.iter()
.any(|(k, v)| k.as_str() == key && v.as_str() == value)
})
.collect()
}
pub fn find_by_entity(&self, entity: &str) -> Vec<&HyperTriple> {
self.hyper_triples
.iter()
.filter(|ht| {
ht.core.subject().as_str() == entity
|| ht.core.object().as_str() == entity
|| ht.qualifiers.values().any(|v| v.as_str() == entity)
})
.collect()
}
pub fn hyper_triples_for_subject(&self, subject: &str) -> Vec<&HyperTriple> {
self.hyper_triples
.iter()
.filter(|ht| ht.core.subject().as_str() == subject)
.collect()
}
pub fn to_reified_triples(&self) -> Vec<Triple> {
let mut result = self.triples.clone();
for (i, ht) in self.hyper_triples.iter().enumerate() {
result.push(ht.core.clone());
let statement_id = format!("_:stmt_{}", i);
for (key, value) in &ht.qualifiers {
result.push(Triple::new(
statement_id.clone(),
key.clone(),
value.clone(),
));
}
}
for (i, he) in self.hyperedges.iter().enumerate() {
let node_id = format!("_:hyperedge_{}", i);
result.extend(he.reify(node_id));
}
result
}
}
impl From<&crate::KnowledgeGraph> for HyperGraph {
fn from(kg: &crate::KnowledgeGraph) -> Self {
let mut hg = HyperGraph::new();
for triple in kg.triples() {
hg.add_triple(triple.clone());
}
hg
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Qualifiers add to arity and are retrievable by key.
    #[test]
    fn test_hyper_triple_creation() {
        let ht = HyperTriple::from_parts("Einstein", "educated_at", "ETH Zurich")
            .with_qualifier("degree", "PhD")
            .with_qualifier("year", "1905");
        assert_eq!(ht.arity(), 4);
        assert_eq!(ht.qualifiers.len(), 2);
        assert_eq!(ht.qualifier("degree"), Some(&EntityId::from("PhD")));
        assert_eq!(ht.qualifier("missing"), None);

        let bulk = HyperTriple::from_parts("a", "b", "c")
            .with_qualifiers(vec![("k1", "v1"), ("k2", "v2")]);
        assert_eq!(bulk.arity(), 4);
    }

    /// Re-attaching a qualifier key replaces its value instead of adding a slot.
    #[test]
    fn test_qualifier_overwrite() {
        let ht = HyperTriple::from_parts("Einstein", "won", "Nobel Prize")
            .with_qualifier("year", "1921")
            .with_qualifier("year", "1922");
        assert_eq!(ht.arity(), 3);
        assert_eq!(ht.qualifier("year"), Some(&EntityId::from("1922")));
    }

    /// Bindings keep insertion order and can be looked up by role or position.
    #[test]
    fn test_hyper_edge_creation() {
        let he = HyperEdge::new("purchase")
            .with_binding("buyer", "Alice")
            .with_binding("seller", "Amazon")
            .with_binding("item", "Rust Book")
            .with_binding("price", "$50");
        assert_eq!(he.arity(), 4);
        assert_eq!(he.entity_by_role("buyer"), Some(&EntityId::from("Alice")));
        assert_eq!(he.entity_at(0), Some(&EntityId::from("Alice")));
        assert_eq!(he.entity_by_role("missing"), None);
        assert_eq!(he.entity_at(4), None);
        let roles: Vec<&str> = he.roles().collect();
        assert_eq!(roles, vec!["buyer", "seller", "item", "price"]);

        let bulk = HyperEdge::new("r").with_bindings(vec![("x", "1"), ("y", "2")]);
        assert_eq!(bulk.arity(), 2);
        assert_eq!(bulk.entity_by_role("y"), Some(&EntityId::from("2")));
    }

    /// Out-of-range confidences are clamped into [0, 1].
    #[test]
    fn test_hyper_edge_confidence_clamped() {
        assert_eq!(HyperEdge::new("r").confidence, None);
        assert_eq!(HyperEdge::new("r").with_confidence(1.5).confidence, Some(1.0));
        assert_eq!(HyperEdge::new("r").with_confidence(-0.5).confidence, Some(0.0));
        assert_eq!(HyperEdge::new("r").with_confidence(0.25).confidence, Some(0.25));
    }

    /// Reification emits an rdf:type triple, then one triple per binding, all
    /// hanging off the intermediate node.
    #[test]
    fn test_hyper_edge_reification() {
        let he = HyperEdge::new("award")
            .with_binding("recipient", "Einstein")
            .with_binding("prize", "Nobel")
            .with_binding("year", "1921");
        let reified = he.reify("_:award_1");
        assert_eq!(reified.len(), 4);
        assert!(reified.iter().all(|t| t.subject().as_str() == "_:award_1"));
        assert_eq!(reified[0].predicate().as_str(), "rdf:type");
        assert_eq!(reified[0].object().as_str(), "award");
        let rest: Vec<(&str, &str)> = reified[1..]
            .iter()
            .map(|t| (t.predicate().as_str(), t.object().as_str()))
            .collect();
        assert_eq!(
            rest,
            vec![("recipient", "Einstein"), ("prize", "Nobel"), ("year", "1921")]
        );
    }

    /// Facts of all three kinds are counted and their entities deduplicated.
    #[test]
    fn test_hyper_graph_mixed() {
        let mut hg = HyperGraph::new();
        hg.add_triple(Triple::new("Einstein", "born_in", "Ulm"));
        hg.add_hyper_triple(
            HyperTriple::from_parts("Einstein", "won", "Nobel Prize")
                .with_qualifier("year", "1921"),
        );
        hg.add_hyperedge(
            HyperEdge::new("collaboration")
                .with_binding("scientist_1", "Einstein")
                .with_binding("scientist_2", "Bohr")
                .with_binding("topic", "Quantum Mechanics"),
        );
        assert_eq!(hg.fact_count(), 3);
        let entities = hg.entities();
        assert!(entities.contains(&EntityId::from("Einstein")));
        assert!(entities.contains(&EntityId::from("Bohr")));
        // Einstein, Ulm, Nobel Prize, 1921, Bohr, Quantum Mechanics.
        assert_eq!(entities.len(), 6);
    }

    /// Importing a knowledge graph yields only plain triples.
    #[test]
    fn test_from_knowledge_graph() {
        let mut kg = crate::KnowledgeGraph::new();
        kg.add_triple(Triple::new("Einstein", "born_in", "Ulm"));
        kg.add_triple(Triple::new("Einstein", "won", "Nobel Prize"));
        kg.add_triple(Triple::new("Ulm", "located_in", "Germany"));
        let hg = HyperGraph::from(&kg);
        assert_eq!(hg.triples.len(), 3);
        assert!(hg.hyper_triples.is_empty());
        assert!(hg.hyperedges.is_empty());
    }

    /// One plain triple + one hyper-triple with one qualifier = 3 triples.
    #[test]
    fn test_to_knowledge_graph() {
        let mut hg = HyperGraph::new();
        hg.add_triple(Triple::new("Einstein", "born_in", "Ulm"));
        hg.add_hyper_triple(
            HyperTriple::from_parts("Einstein", "won", "Nobel Prize")
                .with_qualifier("year", "1921"),
        );
        let kg = hg.to_knowledge_graph();
        assert_eq!(kg.triple_count(), 3);
    }

    /// Pins the layout of `to_reified_triples`: plain triples, then core +
    /// statement-node qualifiers, then hyperedge reifications.
    #[test]
    fn test_reified_triple_layout() {
        let mut hg = HyperGraph::new();
        hg.add_triple(Triple::new("A", "r", "B"));
        hg.add_hyper_triple(
            HyperTriple::from_parts("Einstein", "won", "Nobel Prize")
                .with_qualifier("year", "1921"),
        );
        hg.add_hyperedge(HyperEdge::new("collab").with_binding("x", "Einstein"));
        let reified = hg.to_reified_triples();
        let spo: Vec<(&str, &str, &str)> = reified
            .iter()
            .map(|t| (t.subject().as_str(), t.predicate().as_str(), t.object().as_str()))
            .collect();
        assert_eq!(
            spo,
            vec![
                ("A", "r", "B"),
                ("Einstein", "won", "Nobel Prize"),
                ("_:stmt_0", "year", "1921"),
                ("_:hyperedge_0", "rdf:type", "collab"),
                ("_:hyperedge_0", "x", "Einstein"),
            ]
        );
    }

    /// A graph of plain triples reifies to exactly those triples.
    #[test]
    fn test_kg_roundtrip_plain_triples() {
        let mut kg = crate::KnowledgeGraph::new();
        kg.add_triple(Triple::new("A", "r1", "B"));
        kg.add_triple(Triple::new("B", "r2", "C"));
        kg.add_triple(Triple::new("C", "r3", "A"));
        let hg = HyperGraph::from(&kg);
        assert_eq!(hg.triples.len(), 3);
        assert_eq!(hg.to_reified_triples(), hg.triples);
        let kg2 = hg.to_knowledge_graph();
        assert_eq!(kg.triple_count(), kg2.triple_count());
    }

    /// Exact key/value match on qualifiers.
    #[test]
    fn test_find_by_qualifier() {
        let mut hg = HyperGraph::new();
        hg.add_hyper_triple(
            HyperTriple::from_parts("Einstein", "won", "Nobel Prize")
                .with_qualifier("year", "1921"),
        );
        hg.add_hyper_triple(
            HyperTriple::from_parts("Curie", "won", "Nobel Prize").with_qualifier("year", "1903"),
        );
        hg.add_hyper_triple(
            HyperTriple::from_parts("Bohr", "won", "Nobel Prize").with_qualifier("year", "1922"),
        );
        let results = hg.find_by_qualifier("year", "1921");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].core.subject().as_str(), "Einstein");
        let empty = hg.find_by_qualifier("year", "2000");
        assert!(empty.is_empty());
    }

    /// An entity matches as subject, object, or qualifier value.
    #[test]
    fn test_find_by_entity() {
        let mut hg = HyperGraph::new();
        hg.add_hyper_triple(
            HyperTriple::from_parts("Einstein", "won", "Nobel Prize")
                .with_qualifier("year", "1921"),
        );
        hg.add_hyper_triple(
            HyperTriple::from_parts("Curie", "won", "Nobel Prize")
                .with_qualifier("field", "Chemistry"),
        );
        let results = hg.find_by_entity("Einstein");
        assert_eq!(results.len(), 1);
        let results = hg.find_by_entity("Nobel Prize");
        assert_eq!(results.len(), 2);
        let results = hg.find_by_entity("Chemistry");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].core.subject().as_str(), "Curie");
    }

    /// Filtering by core subject only.
    #[test]
    fn test_hyper_triples_for_subject() {
        let mut hg = HyperGraph::new();
        hg.add_hyper_triple(
            HyperTriple::from_parts("Einstein", "won", "Nobel Prize")
                .with_qualifier("year", "1921"),
        );
        hg.add_hyper_triple(
            HyperTriple::from_parts("Einstein", "educated_at", "ETH Zurich")
                .with_qualifier("degree", "PhD"),
        );
        hg.add_hyper_triple(
            HyperTriple::from_parts("Curie", "won", "Nobel Prize").with_qualifier("year", "1903"),
        );
        let results = hg.hyper_triples_for_subject("Einstein");
        assert_eq!(results.len(), 2);
        let results = hg.hyper_triples_for_subject("Curie");
        assert_eq!(results.len(), 1);
        let results = hg.hyper_triples_for_subject("Bohr");
        assert!(results.is_empty());
    }
}