use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use crate::cache::InvalidationEvent;
use crate::commit::TenantId;
/// One neighbor of an entity in the adjacency index: the neighbor's id and the
/// weight of the edge that connects them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AdjacencyRow {
/// Id of the adjacent (neighbor) entity.
pub rid: String,
/// Edge weight. `EntityAdjacencyIndex` keeps each list ordered by this value,
/// highest first, with ties broken by ascending `rid`.
pub weight: f64,
}
/// All neighbors of a single entity, as stored by `EntityAdjacencyIndex`.
type AdjacencyList = Vec<AdjacencyRow>;
/// Read-side interface for looking up the neighbors of entities within a tenant.
pub trait EntityAdjacencyProvider: Send + Sync {
    /// Returns up to `limit` neighbors of `entity_id` inside `tenant_id`.
    ///
    /// Implementors are expected to return the best-ranked neighbors first
    /// (`EntityAdjacencyIndex` returns the front of its weight-sorted list).
    fn top_adjacent(
        &self,
        tenant_id: TenantId,
        entity_id: &str,
        limit: usize,
    ) -> Vec<AdjacencyRow>;

    /// Fetches up to `per_entity_limit` neighbors for every seed in `entity_ids`
    /// and concatenates them, keeping only the first row seen for each `rid`.
    ///
    /// Output order is seed order, then the order `top_adjacent` produced for that
    /// seed. The combined result is *not* re-sorted by weight. A `rid` reachable from
    /// several seeds keeps the weight reported by the earliest seed. Seeds are not
    /// filtered out, so a seed that neighbors another seed appears in the result.
    fn top_adjacent_union(
        &self,
        tenant_id: TenantId,
        entity_ids: &[&str],
        per_entity_limit: usize,
    ) -> Vec<AdjacencyRow> {
        let mut seen_rids = std::collections::HashSet::new();
        entity_ids
            .iter()
            .flat_map(|seed| self.top_adjacent(tenant_id, seed, per_entity_limit))
            // `insert` returns false for a rid already emitted, which drops the duplicate.
            .filter(|row| seen_rids.insert(row.rid.clone()))
            .collect()
    }
}
/// In-memory adjacency index, partitioned by tenant: `tenant -> entity id -> neighbors`.
///
/// Edges are undirected. `upsert_edge` / `delete_edge` write both directions, so an
/// edge `a`–`b` appears in both `a`'s and `b`'s lists. The map lives behind
/// `Arc<RwLock<..>>`, and the `Clone` impl shares that `Arc`, so clones are handles to
/// the same data rather than copies.
pub struct EntityAdjacencyIndex {
inner: Arc<RwLock<HashMap<TenantId, HashMap<String, AdjacencyList>>>>,
}
impl EntityAdjacencyIndex {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn upsert_edge(&self, tenant_id: TenantId, src: &str, dst: &str, weight: f64) {
let mut map = self.inner.write();
let tenant_map = map.entry(tenant_id).or_default();
Self::upsert_directed(tenant_map, src, dst, weight);
Self::upsert_directed(tenant_map, dst, src, weight);
}
pub fn delete_edge(&self, tenant_id: TenantId, src: &str, dst: &str) {
let mut map = self.inner.write();
if let Some(tenant_map) = map.get_mut(&tenant_id) {
Self::delete_directed(tenant_map, src, dst);
Self::delete_directed(tenant_map, dst, src);
}
}
fn upsert_directed(
tenant_map: &mut HashMap<String, AdjacencyList>,
from: &str,
to: &str,
weight: f64,
) {
let list = tenant_map.entry(from.to_string()).or_default();
if let Some(existing) = list.iter_mut().find(|r| r.rid == to) {
existing.weight = weight;
} else {
list.push(AdjacencyRow {
rid: to.to_string(),
weight,
});
}
list.sort_by(|a, b| {
b.weight
.partial_cmp(&a.weight)
.unwrap_or(std::cmp::Ordering::Equal)
.then_with(|| a.rid.cmp(&b.rid))
});
}
fn delete_directed(tenant_map: &mut HashMap<String, AdjacencyList>, from: &str, to: &str) {
if let Some(list) = tenant_map.get_mut(from) {
list.retain(|r| r.rid != to);
if list.is_empty() {
tenant_map.remove(from);
}
}
}
pub fn clear_tenant(&self, tenant_id: TenantId) {
self.inner.write().remove(&tenant_id);
}
pub fn total_entries(&self) -> usize {
self.inner
.read()
.values()
.map(|m| m.values().map(|v| v.len()).sum::<usize>())
.sum()
}
pub fn tenant_count(&self) -> usize {
self.inner.read().len()
}
pub fn entity_count(&self, tenant_id: TenantId) -> usize {
self.inner
.read()
.get(&tenant_id)
.map(|m| m.len())
.unwrap_or(0)
}
}
impl Default for EntityAdjacencyIndex {
fn default() -> Self {
Self::new()
}
}
impl Clone for EntityAdjacencyIndex {
    /// Returns another handle to the *same* underlying map. Only the `Arc` refcount is
    /// bumped, so writes through either handle are visible through both.
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.inner);
        Self { inner: shared }
    }
}
impl EntityAdjacencyProvider for EntityAdjacencyIndex {
    /// Returns clones of the first `limit` rows of `entity_id`'s stored list in
    /// `tenant_id`, in stored order. Returns an empty `Vec` when the tenant or entity
    /// is unknown.
    fn top_adjacent(
        &self,
        tenant_id: TenantId,
        entity_id: &str,
        limit: usize,
    ) -> Vec<AdjacencyRow> {
        self.inner
            .read()
            .get(&tenant_id)
            .and_then(|entities| entities.get(entity_id))
            .map_or_else(Vec::new, |rows| rows.iter().take(limit).cloned().collect())
    }
}
/// Spawns a task that applies edge events from `bus` to `index` until the bus closes.
///
/// * `EdgeChanged { tenant_id, src, dst }` upserts the undirected edge `src`–`dst` with
///   weight `1.0`. NOTE(review): the event carries no weight, so any weight stored
///   earlier for that edge is overwritten. Confirm this is intended.
/// * `TenantConfigChanged`, `Tombstoned` and `Updated` are ignored.
/// * On `Lagged` (the receiver fell behind and events were dropped), the whole index is
///   cleared for every tenant. There is no way to know which edges the missed events
///   touched, and an empty index is treated as safer than a stale one.
/// * On `Closed` the loop ends and the task completes.
///
/// The subscription is taken synchronously, before the task is spawned.
pub fn spawn_invalidation_bus_subscriber(
    index: EntityAdjacencyIndex,
    bus: &crate::cache::InvalidationBus,
) -> tokio::task::JoinHandle<()> {
    let mut rx = bus.subscribe();
    tokio::spawn(async move {
        loop {
            match rx.recv().await {
                Ok(InvalidationEvent::EdgeChanged {
                    tenant_id,
                    src,
                    dst,
                }) => {
                    // `upsert_edge` already overwrites an existing weight, under one write
                    // lock. A preceding `delete_edge` would release the lock between the two
                    // steps and let concurrent readers see the edge transiently missing.
                    index.upsert_edge(tenant_id, &src, &dst, 1.0);
                }
                Ok(
                    InvalidationEvent::TenantConfigChanged { .. }
                    | InvalidationEvent::Tombstoned { .. }
                    | InvalidationEvent::Updated { .. },
                ) => {}
                Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => {
                    index.inner.write().clear();
                }
                Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
            }
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Polls `cond` every few milliseconds until it holds, failing after two seconds.
    ///
    /// Used instead of a fixed sleep, so the async tests neither race a slow subscriber
    /// task on a loaded machine nor wait longer than necessary.
    async fn wait_until(mut cond: impl FnMut() -> bool) {
        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(2);
        while !cond() {
            assert!(
                tokio::time::Instant::now() < deadline,
                "timed out waiting for the invalidation subscriber"
            );
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
    }

    #[test]
    fn empty_index_returns_no_neighbors() {
        let idx = EntityAdjacencyIndex::new();
        assert!(idx.top_adjacent(TenantId::new(1), "alice", 10).is_empty());
        assert_eq!(idx.tenant_count(), 0);
        assert_eq!(idx.total_entries(), 0);
    }

    #[test]
    fn upsert_then_lookup_returns_both_directions() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.7);
        let alice = idx.top_adjacent(TenantId::new(1), "alice", 10);
        assert_eq!(alice.len(), 1);
        assert_eq!(alice[0].rid, "bob");
        assert_eq!(alice[0].weight, 0.7);
        let bob = idx.top_adjacent(TenantId::new(1), "bob", 10);
        assert_eq!(bob.len(), 1);
        assert_eq!(bob[0].rid, "alice");
    }

    #[test]
    fn list_is_sorted_by_weight_desc() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "low", 0.1);
        idx.upsert_edge(TenantId::new(1), "alice", "high", 0.9);
        idx.upsert_edge(TenantId::new(1), "alice", "mid", 0.5);
        let res = idx.top_adjacent(TenantId::new(1), "alice", 10);
        assert_eq!(res.len(), 3);
        assert_eq!(res[0].rid, "high");
        assert_eq!(res[1].rid, "mid");
        assert_eq!(res[2].rid, "low");
    }

    #[test]
    fn limit_truncates_prefix() {
        let idx = EntityAdjacencyIndex::new();
        for i in 0..10 {
            idx.upsert_edge(TenantId::new(1), "alice", &format!("n{i}"), i as f64);
        }
        let res = idx.top_adjacent(TenantId::new(1), "alice", 3);
        assert_eq!(res.len(), 3);
        assert_eq!(res[0].rid, "n9");
        assert_eq!(res[1].rid, "n8");
        assert_eq!(res[2].rid, "n7");
    }

    #[test]
    fn upsert_replaces_weight_not_duplicates() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.3);
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.8);
        let res = idx.top_adjacent(TenantId::new(1), "alice", 10);
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].weight, 0.8);
    }

    #[test]
    fn delete_edge_removes_both_directions() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.5);
        idx.delete_edge(TenantId::new(1), "alice", "bob");
        assert!(idx.top_adjacent(TenantId::new(1), "alice", 10).is_empty());
        assert!(idx.top_adjacent(TenantId::new(1), "bob", 10).is_empty());
    }

    #[test]
    fn delete_one_edge_keeps_others() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.5);
        idx.upsert_edge(TenantId::new(1), "alice", "carol", 0.7);
        idx.delete_edge(TenantId::new(1), "alice", "bob");
        let alice = idx.top_adjacent(TenantId::new(1), "alice", 10);
        assert_eq!(alice.len(), 1);
        assert_eq!(alice[0].rid, "carol");
    }

    #[test]
    fn delete_nonexistent_edge_is_noop() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.5);
        // Unknown neighbor inside a known tenant.
        idx.delete_edge(TenantId::new(1), "alice", "ghost");
        // Known edge, but addressed to a tenant that holds no data.
        idx.delete_edge(TenantId::new(2), "alice", "bob");
        assert_eq!(idx.top_adjacent(TenantId::new(1), "alice", 10).len(), 1);
    }

    #[test]
    fn per_tenant_isolation() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.5);
        idx.upsert_edge(TenantId::new(2), "alice", "carol", 0.9);
        let t1 = idx.top_adjacent(TenantId::new(1), "alice", 10);
        let t2 = idx.top_adjacent(TenantId::new(2), "alice", 10);
        assert_eq!(t1.len(), 1);
        assert_eq!(t1[0].rid, "bob");
        assert_eq!(t2.len(), 1);
        assert_eq!(t2[0].rid, "carol");
    }

    #[test]
    fn clear_tenant_drops_only_that_tenant() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "a", "b", 0.5);
        idx.upsert_edge(TenantId::new(2), "a", "b", 0.5);
        idx.clear_tenant(TenantId::new(1));
        assert!(idx.top_adjacent(TenantId::new(1), "a", 10).is_empty());
        assert_eq!(idx.top_adjacent(TenantId::new(2), "a", 10).len(), 1);
    }

    #[test]
    fn metrics_track_counts() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "a", "b", 0.5);
        idx.upsert_edge(TenantId::new(1), "a", "c", 0.5);
        idx.upsert_edge(TenantId::new(2), "x", "y", 0.5);
        assert_eq!(idx.total_entries(), 6);
        assert_eq!(idx.tenant_count(), 2);
        assert_eq!(idx.entity_count(TenantId::new(1)), 3);
        assert_eq!(idx.entity_count(TenantId::new(2)), 2);
    }

    #[test]
    fn top_adjacent_union_dedupes_across_seeds() {
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.5);
        idx.upsert_edge(TenantId::new(1), "alice", "charlie", 0.7);
        idx.upsert_edge(TenantId::new(1), "bob", "charlie", 0.3);
        let union = idx.top_adjacent_union(TenantId::new(1), &["alice", "bob"], 10);
        let rids: Vec<&str> = union.iter().map(|r| r.rid.as_str()).collect();
        assert!(rids.contains(&"alice"));
        assert!(rids.contains(&"bob"));
        assert!(rids.contains(&"charlie"));
        let mut seen = std::collections::HashSet::new();
        for r in &rids {
            assert!(seen.insert(r), "duplicate in union: {r}");
        }
    }

    #[test]
    fn dyn_dispatch_works() {
        let idx: Arc<dyn EntityAdjacencyProvider> = Arc::new(EntityAdjacencyIndex::new());
        let r = idx.top_adjacent(TenantId::new(1), "ghost", 10);
        assert!(r.is_empty());
    }

    #[tokio::test]
    async fn invalidation_bus_subscriber_applies_edge_changed() {
        let bus = crate::cache::InvalidationBus::new();
        let idx = EntityAdjacencyIndex::new();
        let handle = spawn_invalidation_bus_subscriber(idx.clone(), &bus);
        bus.publish(InvalidationEvent::EdgeChanged {
            tenant_id: TenantId::new(1),
            src: "alice".into(),
            dst: "bob".into(),
        });
        wait_until(|| !idx.top_adjacent(TenantId::new(1), "alice", 10).is_empty()).await;
        let r = idx.top_adjacent(TenantId::new(1), "alice", 10);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].rid, "bob");
        handle.abort();
    }

    #[tokio::test]
    async fn invalidation_bus_subscriber_ignores_unrelated_events() {
        let bus = crate::cache::InvalidationBus::new();
        let idx = EntityAdjacencyIndex::new();
        idx.upsert_edge(TenantId::new(1), "alice", "bob", 0.5);
        let handle = spawn_invalidation_bus_subscriber(idx.clone(), &bus);
        bus.publish(InvalidationEvent::Tombstoned {
            tenant_id: TenantId::new(1),
            rid: "some_memory".into(),
        });
        bus.publish(InvalidationEvent::Updated {
            tenant_id: TenantId::new(1),
            rid: "some_memory".into(),
        });
        bus.publish(InvalidationEvent::TenantConfigChanged {
            tenant_id: TenantId::new(1),
            key: "some_key".into(),
        });
        // A single receiver sees events in publish order, so once this sentinel edge is
        // visible, every earlier (unrelated) event has already been handled.
        bus.publish(InvalidationEvent::EdgeChanged {
            tenant_id: TenantId::new(1),
            src: "sentinel_a".into(),
            dst: "sentinel_b".into(),
        });
        wait_until(|| !idx.top_adjacent(TenantId::new(1), "sentinel_a", 1).is_empty()).await;
        let r = idx.top_adjacent(TenantId::new(1), "alice", 10);
        assert_eq!(r.len(), 1);
        assert_eq!(r[0].rid, "bob");
        assert_eq!(r[0].weight, 0.5);
        handle.abort();
    }
}