use std::collections::HashMap;
use std::collections::hash_map::{Entry, Iter, IterMut};
use nodedb_types::TenantId;
use crate::GraphError;
use crate::csr::CsrIndex;
/// Graph CSR index split into one independent [`CsrIndex`] partition per
/// tenant.
pub struct ShardedCsrIndex {
    // One partition per tenant, keyed by tenant id. A tenant has no entry
    // until a partition is inserted for it.
    partitions: HashMap<TenantId, CsrIndex>,
}
impl ShardedCsrIndex {
    /// Creates an index with no tenant partitions.
    #[must_use]
    pub fn new() -> Self {
        Self {
            partitions: HashMap::new(),
        }
    }

    /// Returns the partition for `tid`, or `None` if that tenant has none.
    #[must_use]
    pub fn partition(&self, tid: TenantId) -> Option<&CsrIndex> {
        self.partitions.get(&tid)
    }

    /// Returns a mutable reference to the partition for `tid`, or `None` if
    /// that tenant has none.
    ///
    /// Unlike [`Self::get_or_create`], this never inserts a partition.
    pub fn partition_mut(&mut self, tid: TenantId) -> Option<&mut CsrIndex> {
        self.partitions.get_mut(&tid)
    }

    /// Returns the partition for `tid`, first inserting `CsrIndex::default()`
    /// if the tenant does not have one yet.
    pub fn get_or_create(&mut self, tid: TenantId) -> &mut CsrIndex {
        self.partitions.entry(tid).or_default()
    }

    /// Removes the partition for `tid`.
    ///
    /// Returns `true` if a partition existed and was removed. Other tenants'
    /// partitions are untouched.
    pub fn drop_partition(&mut self, tid: TenantId) -> bool {
        self.partitions.remove(&tid).is_some()
    }

    /// Presumably meant to drop the state of one collection within a tenant.
    ///
    /// Currently a no-op: both arguments are ignored and no partition is
    /// read or modified.
    pub fn drop_collection(&mut self, _tid: TenantId, _collection: &str) {
        // NOTE(review): confirm whether partitions need per-collection
        // cleanup; callers expecting edges to be removed get nothing today.
    }

    /// Returns `true` if a partition exists for `tid`.
    #[must_use]
    pub fn contains_partition(&self, tid: TenantId) -> bool {
        self.partitions.contains_key(&tid)
    }

    /// Returns the number of tenants that currently have a partition.
    #[must_use]
    pub fn partition_count(&self) -> usize {
        self.partitions.len()
    }

    /// Iterates over `(tenant, partition)` pairs in unspecified order.
    pub fn iter(&self) -> Iter<'_, TenantId, CsrIndex> {
        self.partitions.iter()
    }

    /// Iterates mutably over `(tenant, partition)` pairs in unspecified order.
    pub fn iter_mut(&mut self) -> IterMut<'_, TenantId, CsrIndex> {
        self.partitions.iter_mut()
    }

    /// Runs [`CsrIndex::compact`] on every tenant partition.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by a partition's compaction. Work
    /// stops at that partition: partitions visited earlier stay compacted
    /// and the remaining ones are left untouched. The visiting order is
    /// unspecified.
    pub fn compact_all(&mut self) -> Result<(), GraphError> {
        // Only the partitions are needed, so skip the keys. `?` (rather than
        // passing `compact` straight to `try_for_each`) keeps the original
        // error conversion into `GraphError`.
        for part in self.partitions.values_mut() {
            part.compact()?;
        }
        Ok(())
    }

    /// Stores `csr` as the partition for `tid`.
    ///
    /// Any existing partition for that tenant is replaced and dropped.
    pub fn install_partition(&mut self, tid: TenantId, csr: CsrIndex) {
        self.partitions.insert(tid, csr);
    }

    /// Gives direct access to the underlying map entry for `tid`.
    pub fn entry(&mut self, tid: TenantId) -> Entry<'_, TenantId, CsrIndex> {
        self.partitions.entry(tid)
    }
}
impl Default for ShardedCsrIndex {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a tenant id in tests.
    fn tid(n: u64) -> TenantId {
        TenantId::new(n)
    }

    #[test]
    fn empty_sharded_has_no_partitions() {
        let sharded = ShardedCsrIndex::new();
        assert_eq!(sharded.partition_count(), 0);
        assert!(!sharded.contains_partition(tid(1)));
        assert!(sharded.partition(tid(1)).is_none());
    }

    #[test]
    fn default_is_empty() {
        let sharded = ShardedCsrIndex::default();
        assert_eq!(sharded.partition_count(), 0);
        assert!(sharded.iter().next().is_none());
    }

    #[test]
    fn get_or_create_installs_empty_partition() {
        let mut sharded = ShardedCsrIndex::new();
        let part = sharded.get_or_create(tid(7));
        assert_eq!(part.node_count(), 0);
        assert!(sharded.contains_partition(tid(7)));
        assert_eq!(sharded.partition_count(), 1);
    }

    #[test]
    fn partition_mut_does_not_create() {
        let mut sharded = ShardedCsrIndex::new();
        assert!(sharded.partition_mut(tid(3)).is_none());
        assert!(!sharded.contains_partition(tid(3)));

        sharded.get_or_create(tid(3));
        sharded
            .partition_mut(tid(3))
            .expect("partition was just created")
            .add_edge("x", "l", "y")
            .unwrap();
        assert!(sharded.partition(tid(3)).unwrap().contains_node("x"));
    }

    #[test]
    fn entry_inserts_through_underlying_map() {
        let mut sharded = ShardedCsrIndex::new();
        sharded.entry(tid(5)).or_default();
        assert!(sharded.contains_partition(tid(5)));
        assert_eq!(sharded.partition_count(), 1);
    }

    #[test]
    fn iter_and_iter_mut_visit_every_partition() {
        let mut sharded = ShardedCsrIndex::new();
        for t in 1..=3 {
            sharded.get_or_create(tid(t));
        }
        for (_tid, part) in sharded.iter_mut() {
            part.add_edge("a", "l", "b").unwrap();
        }
        assert_eq!(sharded.iter().count(), 3);
        assert!(sharded.iter().all(|(_tid, part)| part.contains_node("a")));
    }

    #[test]
    fn partitions_are_isolated_by_tenant() {
        let mut sharded = ShardedCsrIndex::new();
        sharded
            .get_or_create(tid(1))
            .add_edge("alice", "knows", "bob")
            .unwrap();
        sharded
            .get_or_create(tid(2))
            .add_edge("alice", "knows", "carol")
            .unwrap();
        let p1 = sharded.partition(tid(1)).unwrap();
        let p2 = sharded.partition(tid(2)).unwrap();
        assert!(p1.contains_node("alice"));
        assert!(p1.contains_node("bob"));
        assert!(!p1.contains_node("carol"));
        assert!(p2.contains_node("alice"));
        assert!(p2.contains_node("carol"));
        assert!(!p2.contains_node("bob"));
    }

    #[test]
    fn node_names_are_unprefixed() {
        let mut sharded = ShardedCsrIndex::new();
        sharded
            .get_or_create(tid(42))
            .add_edge("alice", "knows", "bob")
            .unwrap();
        sharded
            .get_or_create(tid(42))
            .compact()
            .expect("no governor, cannot fail");
        let part = sharded.partition(tid(42)).unwrap();
        let alice_id = part.node_id("alice").expect("alice must be present");
        assert_eq!(part.node_name(alice_id), "alice");
    }

    #[test]
    fn drop_partition_removes_tenant_state() {
        let mut sharded = ShardedCsrIndex::new();
        sharded
            .get_or_create(tid(1))
            .add_edge("a", "l", "b")
            .unwrap();
        assert!(sharded.contains_partition(tid(1)));
        assert!(sharded.drop_partition(tid(1)));
        assert!(!sharded.contains_partition(tid(1)));
        assert_eq!(sharded.partition_count(), 0);
        assert!(!sharded.drop_partition(tid(1)));
    }

    #[test]
    fn drop_partition_does_not_touch_other_tenants() {
        let mut sharded = ShardedCsrIndex::new();
        sharded
            .get_or_create(tid(1))
            .add_edge("a", "l", "b")
            .unwrap();
        sharded
            .get_or_create(tid(2))
            .add_edge("c", "l", "d")
            .unwrap();
        sharded.drop_partition(tid(1));
        assert!(!sharded.contains_partition(tid(1)));
        assert!(sharded.contains_partition(tid(2)));
        assert!(sharded.partition(tid(2)).unwrap().contains_node("c"));
    }

    #[test]
    fn install_partition_replaces_existing() {
        let mut sharded = ShardedCsrIndex::new();
        sharded
            .get_or_create(tid(1))
            .add_edge("old", "l", "value")
            .unwrap();
        let mut replacement = CsrIndex::new();
        replacement.add_edge("new", "l", "value").unwrap();
        sharded.install_partition(tid(1), replacement);
        let part = sharded.partition(tid(1)).unwrap();
        assert!(part.contains_node("new"));
        assert!(!part.contains_node("old"));
    }

    #[test]
    fn compact_all_applies_to_every_partition() {
        let mut sharded = ShardedCsrIndex::new();
        for t in 1..=3 {
            sharded
                .get_or_create(tid(t))
                .add_edge("a", "l", "b")
                .unwrap();
        }
        sharded.compact_all().expect("no governor, cannot fail");
        for t in 1..=3 {
            let part = sharded.partition(tid(t)).unwrap();
            assert_eq!(part.edge_count(), 1);
            assert_eq!(part.node_count(), 2);
        }
    }
}