use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// Name of a node type (for example `"user"` or `"item"`) in a [`HeteroGraph`].
///
/// Each node type owns its own [`NodeStore`], so node indices are scoped to a single type.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct NodeType(pub String);
impl NodeType {
    /// Builds a node type from anything convertible into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        NodeType(name)
    }

    /// Borrows the type name.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
impl<S: Into<String>> From<S> for NodeType {
    /// Equivalent to [`NodeType::new`].
    fn from(s: S) -> Self {
        NodeType::new(s)
    }
}
/// A typed relation `(src_type, relation, dst_type)` identifying one edge store of a
/// [`HeteroGraph`].
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct EdgeType {
    /// Node type of every edge's source endpoint.
    pub src_type: NodeType,
    /// Relation name.
    pub relation: String,
    /// Node type of every edge's destination endpoint.
    pub dst_type: NodeType,
}
impl EdgeType {
    /// Builds an edge type from its three components.
    pub fn new(
        src_type: impl Into<NodeType>,
        relation: impl Into<String>,
        dst_type: impl Into<NodeType>,
    ) -> Self {
        let (src_type, relation, dst_type) = (src_type.into(), relation.into(), dst_type.into());
        Self {
            src_type,
            relation,
            dst_type,
        }
    }

    /// Returns the edge type pointing the other way: the endpoint types are swapped and the
    /// relation is prefixed with `rev_` (so reversing twice yields `rev_rev_<relation>`, not
    /// the original).
    pub fn reverse(&self) -> Self {
        let EdgeType {
            src_type,
            relation,
            dst_type,
        } = self;
        Self {
            src_type: dst_type.clone(),
            relation: format!("rev_{}", relation),
            dst_type: src_type.clone(),
        }
    }
}
/// Dense index of a node within the [`NodeStore`] of its own node type.
///
/// Indices are assigned per store in insertion order (`0, 1, 2, …`), so the same number
/// refers to unrelated nodes in different node types.
pub type TypedNodeIndex = usize;
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EdgeStore {
src: Vec<TypedNodeIndex>,
dst: Vec<TypedNodeIndex>,
#[serde(skip)]
fwd_adj: HashMap<TypedNodeIndex, Vec<TypedNodeIndex>>,
#[serde(skip)]
rev_adj: HashMap<TypedNodeIndex, Vec<TypedNodeIndex>>,
}
impl EdgeStore {
pub fn new() -> Self {
Self::default()
}
pub fn from_edges(src: Vec<TypedNodeIndex>, dst: Vec<TypedNodeIndex>) -> Self {
debug_assert_eq!(src.len(), dst.len());
let mut store = Self {
src,
dst,
fwd_adj: HashMap::new(),
rev_adj: HashMap::new(),
};
store.rebuild_adj();
store
}
pub fn num_edges(&self) -> usize {
self.src.len()
}
pub fn len(&self) -> usize {
self.src.len()
}
pub fn is_empty(&self) -> bool {
self.src.is_empty()
}
pub fn src(&self) -> &[TypedNodeIndex] {
&self.src
}
pub fn dst(&self) -> &[TypedNodeIndex] {
&self.dst
}
pub fn edge_index(&self) -> (&[TypedNodeIndex], &[TypedNodeIndex]) {
(&self.src, &self.dst)
}
pub fn add_edge(&mut self, src: TypedNodeIndex, dst: TypedNodeIndex) {
self.src.push(src);
self.dst.push(dst);
self.fwd_adj.entry(src).or_default().push(dst);
self.rev_adj.entry(dst).or_default().push(src);
}
pub fn iter(&self) -> impl Iterator<Item = (TypedNodeIndex, TypedNodeIndex)> + '_ {
self.src.iter().copied().zip(self.dst.iter().copied())
}
pub fn neighbors(&self, src: TypedNodeIndex) -> &[TypedNodeIndex] {
self.fwd_adj.get(&src).map(|v| v.as_slice()).unwrap_or(&[])
}
pub fn incoming(&self, dst: TypedNodeIndex) -> &[TypedNodeIndex] {
self.rev_adj.get(&dst).map(|v| v.as_slice()).unwrap_or(&[])
}
pub fn rebuild_adj(&mut self) {
self.fwd_adj.clear();
self.rev_adj.clear();
for (&s, &d) in self.src.iter().zip(self.dst.iter()) {
self.fwd_adj.entry(s).or_default().push(d);
self.rev_adj.entry(d).or_default().push(s);
}
}
}
/// Interns the node ids of a single node type, assigning dense indices in insertion order.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NodeStore {
    /// Node ids; position `i` holds the id of the node with index `i`.
    pub ids: Vec<String>,
    /// Reverse lookup from id to index.
    id_to_idx: HashMap<String, TypedNodeIndex>,
}
impl NodeStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of distinct ids stored.
    pub fn num_nodes(&self) -> usize {
        self.ids.len()
    }

    /// Returns the index of `id`, appending it first if it is not yet known.
    /// Re-adding an existing id is a no-op that returns its original index.
    pub fn add_node(&mut self, id: impl Into<String>) -> TypedNodeIndex {
        let id: String = id.into();
        if let Some(existing) = self.get_index(&id) {
            return existing;
        }
        let fresh = self.num_nodes();
        self.ids.push(id.clone());
        self.id_to_idx.insert(id, fresh);
        fresh
    }

    /// Looks up the index of `id`.
    pub fn get_index(&self, id: &str) -> Option<TypedNodeIndex> {
        match self.id_to_idx.get(id) {
            Some(&idx) => Some(idx),
            None => None,
        }
    }

    /// Looks up the id stored at `idx`.
    pub fn get_id(&self, idx: TypedNodeIndex) -> Option<&str> {
        self.ids.get(idx).map(String::as_str)
    }

    /// Whether `id` has been added.
    pub fn contains(&self, id: &str) -> bool {
        self.get_index(id).is_some()
    }
}
/// Serde proxy for [`HeteroGraph`], used by its `Deserialize` impl below.
///
/// Edge stores are held as a list of `(EdgeType, EdgeStore)` pairs rather than a map keyed by
/// `EdgeType`. This is presumably because `EdgeType` is a struct, and formats such as JSON only
/// accept string-like map keys (TODO confirm intent).
#[derive(Serialize, Deserialize)]
struct HeteroGraphSerde {
// Node store of each node type.
node_stores: HashMap<NodeType, NodeStore>,
// Edge store of each edge type, as key/value pairs.
edge_stores: Vec<(EdgeType, EdgeStore)>,
}
/// A heterogeneous graph: one [`NodeStore`] per [`NodeType`] and one [`EdgeStore`] per
/// [`EdgeType`].
///
/// Node indices are local to their node type's store, and each edge store holds indices into
/// the stores of its edge type's `src_type` and `dst_type`.
#[derive(Debug, Clone, Default)]
pub struct HeteroGraph {
// Interned node ids, keyed by node type.
node_stores: HashMap<NodeType, NodeStore>,
// Edge lists and adjacency, keyed by edge type.
edge_stores: HashMap<EdgeType, EdgeStore>,
}
impl serde::Serialize for HeteroGraph {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let proxy = HeteroGraphSerde {
node_stores: self.node_stores.clone(),
edge_stores: self
.edge_stores
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect(),
};
proxy.serialize(serializer)
}
}
impl<'de> serde::Deserialize<'de> for HeteroGraph {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = HeteroGraphSerde::deserialize(deserializer)?;
let mut hg = Self {
node_stores: raw.node_stores,
edge_stores: raw.edge_stores.into_iter().collect(),
};
hg.rebuild_adjacency();
Ok(hg)
}
}
impl HeteroGraph {
    /// Creates an empty graph with no node or edge types.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates an empty graph whose type maps are pre-sized for `node_types` node types and
    /// `edge_types` edge types.
    pub fn with_capacity(node_types: usize, edge_types: usize) -> Self {
        Self {
            node_stores: HashMap::with_capacity(node_types),
            edge_stores: HashMap::with_capacity(edge_types),
        }
    }

    /// Number of node types that have a node store.
    pub fn num_node_types(&self) -> usize {
        self.node_stores.len()
    }

    /// Number of edge types that have an edge store.
    pub fn num_edge_types(&self) -> usize {
        self.edge_stores.len()
    }

    /// Iterates over the node types, in unspecified order.
    pub fn node_types(&self) -> impl Iterator<Item = &NodeType> {
        self.node_stores.keys()
    }

    /// Iterates over the edge types, in unspecified order.
    pub fn edge_types(&self) -> impl Iterator<Item = &EdgeType> {
        self.edge_stores.keys()
    }

    /// Interns `id` in the store for `node_type`, creating that store if needed, and returns
    /// the node's index. Re-adding an existing id returns its existing index.
    pub fn add_node(&mut self, node_type: NodeType, id: impl Into<String>) -> TypedNodeIndex {
        self.node_stores.entry(node_type).or_default().add_node(id)
    }

    /// Adds the directed edge `src_id -> dst_id` of `edge_type`. Both endpoints are interned
    /// in the stores of the edge type's `src_type` and `dst_type`. Parallel edges are kept.
    pub fn add_edge(&mut self, edge_type: &EdgeType, src_id: &str, dst_id: &str) {
        let src_idx = self.add_node(edge_type.src_type.clone(), src_id);
        let dst_idx = self.add_node(edge_type.dst_type.clone(), dst_id);
        self.edge_stores
            .entry(edge_type.clone())
            .or_default()
            .add_edge(src_idx, dst_idx);
    }

    /// Adds `src_id -> dst_id` under `edge_type` and `dst_id -> src_id` under
    /// `edge_type.reverse()`, which is a separate edge type.
    pub fn add_edge_bidirectional(&mut self, edge_type: &EdgeType, src_id: &str, dst_id: &str) {
        self.add_edge(edge_type, src_id, dst_id);
        self.add_edge(&edge_type.reverse(), dst_id, src_id);
    }

    /// Node store for `node_type`, if any node of that type exists.
    pub fn node_store(&self, node_type: &NodeType) -> Option<&NodeStore> {
        self.node_stores.get(node_type)
    }

    /// Edge store for `edge_type`, if any edge of that type exists.
    pub fn edge_store(&self, edge_type: &EdgeType) -> Option<&EdgeStore> {
        self.edge_stores.get(edge_type)
    }

    /// Mutable node store for `node_type`, if present.
    pub fn node_store_mut(&mut self, node_type: &NodeType) -> Option<&mut NodeStore> {
        self.node_stores.get_mut(node_type)
    }

    /// Mutable edge store for `edge_type`, if present.
    pub fn edge_store_mut(&mut self, edge_type: &EdgeType) -> Option<&mut EdgeStore> {
        self.edge_stores.get_mut(edge_type)
    }

    /// Number of nodes of `node_type` (0 if the type is unknown).
    pub fn num_nodes(&self, node_type: &NodeType) -> usize {
        self.node_stores
            .get(node_type)
            .map(|s| s.num_nodes())
            .unwrap_or(0)
    }

    /// Number of edges of `edge_type` (0 if the type is unknown).
    pub fn num_edges(&self, edge_type: &EdgeType) -> usize {
        self.edge_stores
            .get(edge_type)
            .map(|s| s.num_edges())
            .unwrap_or(0)
    }

    /// Sum of node counts over all node types.
    pub fn total_nodes(&self) -> usize {
        self.node_stores.values().map(|s| s.num_nodes()).sum()
    }

    /// Sum of edge counts over all edge types.
    pub fn total_edges(&self) -> usize {
        self.edge_stores.values().map(|s| s.num_edges()).sum()
    }

    /// Index of node `id` within `node_type`, if present.
    pub fn get_node_index(&self, node_type: &NodeType, id: &str) -> Option<TypedNodeIndex> {
        self.node_stores.get(node_type)?.get_index(id)
    }

    /// Id of the node at `idx` within `node_type`, if present.
    pub fn get_node_id(&self, node_type: &NodeType, idx: TypedNodeIndex) -> Option<&str> {
        self.node_stores.get(node_type)?.get_id(idx)
    }

    /// Owned copy of the destination indices reached from `src_idx` along `edge_type`, in
    /// edge order. Empty if the edge type or node has no edges.
    pub fn neighbors(&self, edge_type: &EdgeType, src_idx: TypedNodeIndex) -> Vec<TypedNodeIndex> {
        self.edge_stores
            .get(edge_type)
            .map(|store| store.neighbors(src_idx).to_vec())
            .unwrap_or_default()
    }

    /// Owned copy of the source indices with an `edge_type` edge into `dst_idx`, in edge
    /// order. Empty if there are none.
    pub fn incoming_neighbors(
        &self,
        edge_type: &EdgeType,
        dst_idx: TypedNodeIndex,
    ) -> Vec<TypedNodeIndex> {
        self.edge_stores
            .get(edge_type)
            .map(|store| store.incoming(dst_idx).to_vec())
            .unwrap_or_default()
    }

    /// Ids of the nodes reached from node `src_id` along `edge_type`, in edge order.
    /// Empty if the source id, the destination node type, or the edge type is unknown.
    pub fn neighbors_by_id<'a>(&'a self, edge_type: &EdgeType, src_id: &str) -> Vec<&'a str> {
        let src_idx = match self.get_node_index(&edge_type.src_type, src_id) {
            Some(idx) => idx,
            None => return Vec::new(),
        };
        let (store, dst_store) = match (
            self.edge_stores.get(edge_type),
            self.node_stores.get(&edge_type.dst_type),
        ) {
            (Some(store), Some(dst_store)) => (store, dst_store),
            _ => return Vec::new(),
        };
        // Read the adjacency slice directly rather than cloning it via `neighbors`.
        store
            .neighbors(src_idx)
            .iter()
            .filter_map(|&idx| dst_store.get_id(idx))
            .collect()
    }

    /// Number of `edge_type` edges leaving `node_idx` (parallel edges counted separately).
    pub fn out_degree(&self, edge_type: &EdgeType, node_idx: TypedNodeIndex) -> usize {
        self.edge_stores
            .get(edge_type)
            .map(|store| store.neighbors(node_idx).len())
            .unwrap_or(0)
    }

    /// Number of `edge_type` edges entering `node_idx` (parallel edges counted separately).
    pub fn in_degree(&self, edge_type: &EdgeType, node_idx: TypedNodeIndex) -> usize {
        self.edge_stores
            .get(edge_type)
            .map(|store| store.incoming(node_idx).len())
            .unwrap_or(0)
    }

    /// Rebuilds the adjacency maps of every edge store from its edge columns.
    pub fn rebuild_adjacency(&mut self) {
        for store in self.edge_stores.values_mut() {
            store.rebuild_adj();
        }
    }

    /// Returns the set of node indices reachable from `start_idx` (a node of `start_type`)
    /// by following `metapath` hop by hop. Each hop follows every edge of that edge type
    /// from the current frontier.
    ///
    /// The result holds indices of the last hop's `dst_type`. When `metapath` is empty the
    /// result is `{start_idx}`.
    ///
    /// Node indices are only meaningful within one node type, so the metapath must be
    /// well-typed:
    /// - the first hop's `src_type` must equal `start_type`;
    /// - each later hop's `src_type` must equal the previous hop's `dst_type`.
    ///
    /// An ill-typed metapath has no instances and yields an empty set. So does a hop over an
    /// edge type the graph does not contain.
    pub fn metapath_neighbors(
        &self,
        start_type: &NodeType,
        start_idx: TypedNodeIndex,
        metapath: &[EdgeType],
    ) -> HashSet<TypedNodeIndex> {
        let mut current_type = start_type;
        let mut frontier: HashSet<TypedNodeIndex> = [start_idx].into_iter().collect();
        for edge_type in metapath {
            if edge_type.src_type != *current_type {
                return HashSet::new();
            }
            let store = match self.edge_stores.get(edge_type) {
                Some(store) => store,
                None => return HashSet::new(),
            };
            frontier = frontier
                .iter()
                .flat_map(|&idx| store.neighbors(idx).iter().copied())
                .collect();
            if frontier.is_empty() {
                // Nothing further can be reached.
                return frontier;
            }
            current_type = &edge_type.dst_type;
        }
        frontier
    }
}
/// Summary counts for a [`HeteroGraph`], as produced by `HeteroGraph::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeteroGraphStats {
// Number of distinct node types.
pub num_node_types: usize,
// Number of distinct edge types.
pub num_edge_types: usize,
// Nodes summed over all node types.
pub total_nodes: usize,
// Edges summed over all edge types.
pub total_edges: usize,
// Node count keyed by node type name.
pub nodes_by_type: HashMap<String, usize>,
// Edge count keyed by a "src->dst:relation" label built by `stats`.
pub edges_by_type: HashMap<String, usize>,
}
impl HeteroGraph {
pub fn stats(&self) -> HeteroGraphStats {
HeteroGraphStats {
num_node_types: self.num_node_types(),
num_edge_types: self.num_edge_types(),
total_nodes: self.total_nodes(),
total_edges: self.total_edges(),
nodes_by_type: self
.node_stores
.iter()
.map(|(t, s)| (t.0.clone(), s.num_nodes()))
.collect(),
edges_by_type: self
.edge_stores
.iter()
.map(|(t, s)| {
(
format!("{}->{}:{}", t.src_type.0, t.dst_type.0, t.relation),
s.num_edges(),
)
})
.collect(),
}
}
}
impl HeteroGraph {
    /// Flattens the graph into a `crate::KnowledgeGraph`, emitting one
    /// `(source id, relation, destination id)` triple per edge.
    ///
    /// Node types are dropped. Edge types whose endpoint stores are missing are skipped, as
    /// are edges whose indices have no id.
    pub fn to_knowledge_graph(&self) -> crate::KnowledgeGraph {
        let mut kg = crate::KnowledgeGraph::new();
        for (edge_type, edge_store) in &self.edge_stores {
            let endpoint_stores = (
                self.node_stores.get(&edge_type.src_type),
                self.node_stores.get(&edge_type.dst_type),
            );
            if let (Some(src_store), Some(dst_store)) = endpoint_stores {
                for (s, d) in edge_store.iter() {
                    if let (Some(subject), Some(object)) = (src_store.get_id(s), dst_store.get_id(d))
                    {
                        kg.add_triple(crate::Triple::new(
                            subject,
                            edge_type.relation.as_str(),
                            object,
                        ));
                    }
                }
            }
        }
        kg
    }
}
impl HeteroGraph {
    /// Builds a graph from a `crate::KnowledgeGraph`. This is the same conversion as
    /// `HeteroGraph::from(kg)`.
    pub fn from_knowledge_graph(kg: &crate::KnowledgeGraph) -> Self {
        kg.into()
    }
}
impl From<&crate::KnowledgeGraph> for HeteroGraph {
    /// Converts each triple into an edge between two nodes of the single node type `"entity"`.
    /// The triple's predicate becomes the relation name, so each distinct predicate becomes its
    /// own edge type.
    fn from(kg: &crate::KnowledgeGraph) -> Self {
        let entity = NodeType::new("entity");
        let mut graph = Self::new();
        for triple in kg.triples() {
            let relation_type =
                EdgeType::new(entity.clone(), triple.predicate().as_str(), entity.clone());
            graph.add_edge(
                &relation_type,
                triple.subject().as_str(),
                triple.object().as_str(),
            );
        }
        graph
    }
}
// Unit tests for HeteroGraph construction, queries, conversions and serde round-trips.
#[cfg(test)]
mod tests {
use super::*;
// Nodes added under two types are counted per type.
#[test]
fn test_hetero_graph_basic() {
let mut hg = HeteroGraph::new();
let user = NodeType::new("user");
let item = NodeType::new("item");
hg.add_node(user.clone(), "alice");
hg.add_node(user.clone(), "bob");
hg.add_node(item.clone(), "book1");
hg.add_node(item.clone(), "book2");
assert_eq!(hg.num_node_types(), 2);
assert_eq!(hg.num_nodes(&user), 2);
assert_eq!(hg.num_nodes(&item), 2);
}
// Edges of one type are counted, and out-neighbors are found via the index of the source node.
#[test]
fn test_hetero_graph_edges() {
let mut hg = HeteroGraph::new();
let buys = EdgeType::new("user", "buys", "item");
hg.add_edge(&buys, "alice", "book1");
hg.add_edge(&buys, "bob", "book1");
hg.add_edge(&buys, "alice", "book2");
assert_eq!(hg.num_edge_types(), 1);
assert_eq!(hg.num_edges(&buys), 3);
let alice_idx = hg.get_node_index(&NodeType::new("user"), "alice").unwrap();
let neighbors = hg.neighbors(&buys, alice_idx);
assert_eq!(neighbors.len(), 2);
}
// A bidirectional edge creates a second edge type (`rev_follows`) holding the reverse edge.
#[test]
fn test_hetero_graph_bidirectional() {
let mut hg = HeteroGraph::new();
let follows = EdgeType::new("user", "follows", "user");
hg.add_edge_bidirectional(&follows, "alice", "bob");
assert_eq!(hg.num_edge_types(), 2); assert_eq!(hg.total_edges(), 2);
}
// author -writes-> paper -rev_cites-> paper: alice reaches paper2 through paper1.
#[test]
fn test_metapath() {
let mut hg = HeteroGraph::new();
let writes = EdgeType::new("author", "writes", "paper");
let cites = EdgeType::new("paper", "cites", "paper");
hg.add_edge(&writes, "alice", "paper1");
hg.add_edge(&writes, "bob", "paper2");
hg.add_edge_bidirectional(&cites, "paper2", "paper1");
let alice_idx = hg
.get_node_index(&NodeType::new("author"), "alice")
.unwrap();
let metapath = vec![writes.clone(), cites.reverse()];
let reachable = hg.metapath_neighbors(&NodeType::new("author"), alice_idx, &metapath);
assert_eq!(reachable.len(), 1);
}
// Conversion from a KnowledgeGraph uses a single "entity" node type and one edge type per predicate.
#[test]
fn test_from_knowledge_graph() {
let mut kg = crate::KnowledgeGraph::new();
kg.add_triple(crate::Triple::new("Alice", "knows", "Bob"));
kg.add_triple(crate::Triple::new("Bob", "works_at", "Acme"));
let hg = HeteroGraph::from(&kg);
assert_eq!(hg.num_node_types(), 1); assert_eq!(hg.num_edge_types(), 2); assert_eq!(hg.total_nodes(), 3);
assert_eq!(hg.total_edges(), 2);
}
// Forward and reverse adjacency agree with the inserted edges.
#[test]
fn test_adjacency_index_neighbors() {
let mut hg = HeteroGraph::new();
let buys = EdgeType::new("user", "buys", "item");
hg.add_edge(&buys, "alice", "book1");
hg.add_edge(&buys, "alice", "book2");
hg.add_edge(&buys, "bob", "book1");
let alice_idx = hg.get_node_index(&NodeType::new("user"), "alice").unwrap();
let bob_idx = hg.get_node_index(&NodeType::new("user"), "bob").unwrap();
let book1_idx = hg.get_node_index(&NodeType::new("item"), "book1").unwrap();
let alice_neighbors = hg.neighbors(&buys, alice_idx);
assert_eq!(alice_neighbors.len(), 2);
let bob_neighbors = hg.neighbors(&buys, bob_idx);
assert_eq!(bob_neighbors.len(), 1);
let book1_incoming = hg.incoming_neighbors(&buys, book1_idx);
assert_eq!(book1_incoming.len(), 2);
}
// Id-based neighbor lookup returns destination ids, and an unknown source id gives an empty result.
#[test]
fn test_neighbors_by_id() {
let mut hg = HeteroGraph::new();
let buys = EdgeType::new("user", "buys", "item");
hg.add_edge(&buys, "alice", "book1");
hg.add_edge(&buys, "alice", "book2");
let mut neighbors = hg.neighbors_by_id(&buys, "alice");
neighbors.sort();
assert_eq!(neighbors, vec!["book1", "book2"]);
assert!(hg.neighbors_by_id(&buys, "nobody").is_empty());
}
// In/out degree per node, and 0 for an edge type the graph does not contain.
#[test]
fn test_degree_methods() {
let mut hg = HeteroGraph::new();
let buys = EdgeType::new("user", "buys", "item");
hg.add_edge(&buys, "alice", "book1");
hg.add_edge(&buys, "alice", "book2");
hg.add_edge(&buys, "bob", "book1");
let alice_idx = hg.get_node_index(&NodeType::new("user"), "alice").unwrap();
let book1_idx = hg.get_node_index(&NodeType::new("item"), "book1").unwrap();
assert_eq!(hg.out_degree(&buys, alice_idx), 2);
assert_eq!(hg.in_degree(&buys, book1_idx), 2);
let fake = EdgeType::new("a", "b", "c");
assert_eq!(hg.out_degree(&fake, 0), 0);
}
// Wipes the cached adjacency through private fields (visible to this child module), then checks
// that rebuild_adjacency restores it from the edge columns.
#[test]
fn test_rebuild_adjacency() {
let mut hg = HeteroGraph::new();
let buys = EdgeType::new("user", "buys", "item");
hg.add_edge(&buys, "alice", "book1");
hg.add_edge(&buys, "alice", "book2");
for store in hg.edge_stores.values_mut() {
store.fwd_adj.clear();
store.rev_adj.clear();
}
let alice_idx = hg.get_node_index(&NodeType::new("user"), "alice").unwrap();
assert!(hg.neighbors(&buys, alice_idx).is_empty());
hg.rebuild_adjacency();
let neighbors = hg.neighbors(&buys, alice_idx);
assert_eq!(neighbors.len(), 2);
}
// Each edge becomes one triple; node ids become entities (alice, book1, bob).
#[test]
fn test_to_knowledge_graph() {
let mut hg = HeteroGraph::new();
let buys = EdgeType::new("user", "buys", "item");
let follows = EdgeType::new("user", "follows", "user");
hg.add_edge(&buys, "alice", "book1");
hg.add_edge(&follows, "alice", "bob");
let kg = hg.to_knowledge_graph();
assert_eq!(kg.triple_count(), 2);
assert_eq!(kg.entity_count(), 3); }
// KnowledgeGraph -> HeteroGraph -> KnowledgeGraph keeps entity and triple counts.
#[test]
fn test_to_knowledge_graph_roundtrip() {
let mut kg = crate::KnowledgeGraph::new();
kg.add_triple(crate::Triple::new("Alice", "knows", "Bob"));
kg.add_triple(crate::Triple::new("Bob", "works_at", "Acme"));
let hg = HeteroGraph::from(&kg);
let kg2 = hg.to_knowledge_graph();
assert_eq!(kg2.entity_count(), kg.entity_count());
assert_eq!(kg2.triple_count(), kg.triple_count());
}
// from_edges indexes adjacency in edge order for both directions.
#[test]
fn test_edge_store_from_edges_builds_adj() {
let store = EdgeStore::from_edges(vec![0, 0, 1], vec![1, 2, 2]);
assert_eq!(store.neighbors(0), &[1, 2]);
assert_eq!(store.neighbors(1), &[2]);
assert_eq!(store.incoming(2), &[0, 1]);
}
// Queries answered before a JSON round-trip give the same answers afterwards, which shows the
// skipped adjacency maps are rebuilt on deserialize.
#[test]
fn test_serde_roundtrip_queries_work() {
let mut hg = HeteroGraph::new();
let buys = EdgeType::new("user", "buys", "item");
let follows = EdgeType::new("user", "follows", "user");
hg.add_edge(&buys, "alice", "book1");
hg.add_edge(&buys, "alice", "book2");
hg.add_edge(&buys, "bob", "book1");
hg.add_edge(&follows, "alice", "bob");
let alice_buys_pre = {
let mut v = hg.neighbors_by_id(&buys, "alice");
v.sort();
v
};
let bob_buys_pre = hg.neighbors_by_id(&buys, "bob");
let alice_follows_pre = hg.neighbors_by_id(&follows, "alice");
let book1_incoming_pre = {
let idx = hg.get_node_index(&NodeType::new("item"), "book1").unwrap();
let mut v = hg.incoming_neighbors(&buys, idx);
v.sort();
v
};
let json = serde_json::to_string(&hg).expect("serialize");
let recovered: HeteroGraph = serde_json::from_str(&json).expect("deserialize");
assert_eq!(recovered.total_nodes(), hg.total_nodes());
assert_eq!(recovered.total_edges(), hg.total_edges());
assert_eq!(recovered.num_node_types(), hg.num_node_types());
assert_eq!(recovered.num_edge_types(), hg.num_edge_types());
let mut alice_buys_post = recovered.neighbors_by_id(&buys, "alice");
alice_buys_post.sort();
assert_eq!(alice_buys_post, alice_buys_pre);
assert_eq!(recovered.neighbors_by_id(&buys, "bob"), bob_buys_pre);
assert_eq!(
recovered.neighbors_by_id(&follows, "alice"),
alice_follows_pre
);
let book1_idx = recovered
.get_node_index(&NodeType::new("item"), "book1")
.unwrap();
let mut book1_incoming_post = recovered.incoming_neighbors(&buys, book1_idx);
book1_incoming_post.sort();
assert_eq!(book1_incoming_post, book1_incoming_pre);
}
}