use std::collections::{HashMap, HashSet};
use anyhow::{Context, Result};
use crate::memory::graph::types::{EntityOccurrenceIndex, GraphEdge};
/// Fallback cap on the number of edges [`co_occurring_entities`] returns when
/// the caller passes `limit = None`.
const DEFAULT_LIMIT: usize = 100;
/// Builds one edge from `subject_entity` to each other entity that the index
/// reports on at least one of the subject's nodes.
///
/// An edge's `weight` is the number of distinct node ids on which the object
/// entity was seen alongside the subject, saturated at `u32::MAX`. The subject
/// itself never appears as an object.
///
/// The result is ordered by descending weight, ties broken by ascending
/// object name, and then cut down to `limit` entries (`DEFAULT_LIMIT` when
/// `limit` is `None`).
///
/// # Errors
///
/// Returns the first error produced by `nodes_for_entity` or
/// `entities_on_node`, wrapped with context naming the failing call and its
/// argument.
pub fn co_occurring_entities(
    index: &dyn EntityOccurrenceIndex,
    subject_entity: &str,
    limit: Option<usize>,
) -> Result<Vec<GraphEdge>> {
    // object entity -> distinct node ids it shares with the subject
    let mut nodes_by_object: HashMap<String, HashSet<String>> = HashMap::new();

    for node_id in index
        .nodes_for_entity(subject_entity)
        .with_context(|| format!("nodes_for_entity({subject_entity})"))?
    {
        let on_node = index
            .entities_on_node(&node_id)
            .with_context(|| format!("entities_on_node({node_id})"))?;
        for other in on_node {
            // Skip the subject: an entity is not its own neighbour.
            if other != subject_entity {
                nodes_by_object
                    .entry(other)
                    .or_default()
                    .insert(node_id.clone());
            }
        }
    }

    let mut edges: Vec<GraphEdge> = Vec::with_capacity(nodes_by_object.len());
    for (object, nodes) in nodes_by_object {
        // Saturate rather than wrap if the node count does not fit in `u32`.
        let clamped = nodes.len().min(u32::MAX as usize);
        edges.push(GraphEdge {
            subject: subject_entity.to_owned(),
            object,
            weight: clamped as u32,
        });
    }

    // Heaviest edges first; equal weights fall back to alphabetical object
    // order so the output does not depend on hash-map iteration order.
    edges.sort_by(|lhs, rhs| match rhs.weight.cmp(&lhs.weight) {
        std::cmp::Ordering::Equal => lhs.object.cmp(&rhs.object),
        unequal => unequal,
    });

    edges.truncate(limit.unwrap_or(DEFAULT_LIMIT));
    Ok(edges)
}
/// Returns only the object names of the edges produced by
/// [`co_occurring_entities`] for `subject_entity`, in the same order and
/// subject to the same `limit`.
///
/// # Errors
///
/// Propagates any error returned by [`co_occurring_entities`].
pub fn neighbors(
    index: &dyn EntityOccurrenceIndex,
    subject_entity: &str,
    limit: Option<usize>,
) -> Result<Vec<String>> {
    let edges = co_occurring_entities(index, subject_entity, limit)?;

    // Keep just the far end of each edge; subject and weight are dropped.
    let mut names = Vec::with_capacity(edges.len());
    for edge in edges {
        names.push(edge.object);
    }
    Ok(names)
}
/// Buckets edge objects by their edge weight.
///
/// Consumes `edges` and returns a map from each weight to the `object` names
/// of the edges carrying that weight. Within a bucket, names keep the order in
/// which their edges appeared in `edges`. The `subject` of each edge is
/// discarded.
pub fn group_by_weight(edges: Vec<GraphEdge>) -> HashMap<u32, Vec<String>> {
    edges.into_iter().fold(
        HashMap::new(),
        |mut buckets: HashMap<u32, Vec<String>>, edge| {
            buckets.entry(edge.weight).or_default().push(edge.object);
            buckets
        },
    )
}
// Unit tests for this module are kept out of line: the `#[path]` attribute
// loads them from `query_tests.rs`, and `#[cfg(test)]` compiles them only in
// test builds.
#[cfg(test)]
#[path = "query_tests.rs"]
mod tests;