use std::collections::HashMap;
use std::sync::Arc;
use arc_swap::ArcSwap;
use omnigraph::db::Omnigraph;
use omnigraph::storage::normalize_root_uri;
#[cfg(test)]
use tokio::sync::Mutex;
use crate::identity::GraphKey;
use crate::policy::PolicyEngine;
/// A single registered graph: its registry key, storage root URI, shared engine,
/// and optional per-graph policy.
pub struct GraphHandle {
    /// Key under which this graph is registered.
    pub key: GraphKey,
    /// Storage root URI. Handles stored by `GraphRegistry` (via `from_handles` or
    /// `insert`) carry the form returned by `normalize_root_uri`.
    pub uri: String,
    /// Shared engine instance for this graph.
    pub engine: Arc<Omnigraph>,
    /// Per-graph policy engine, or `None` when this graph has no policy of its own.
    pub policy: Option<Arc<PolicyEngine>>,
}
/// An immutable view of every registered graph at one point in time.
///
/// `GraphRegistry` publishes a new snapshot on each change rather than mutating
/// one in place.
pub struct RegistrySnapshot {
    /// All registered graphs, keyed by their `GraphKey`.
    pub graphs: HashMap<GraphKey, Arc<GraphHandle>>,
    /// `true` when at least one handle in `graphs` has `policy: Some(_)`.
    /// Computed once by `RegistrySnapshot::new`; it is NOT recomputed if `graphs`
    /// is modified directly afterwards.
    pub any_per_graph_policy: bool,
}
impl RegistrySnapshot {
    /// Wraps `graphs` in a snapshot and precomputes `any_per_graph_policy`.
    ///
    /// The flag is `true` exactly when at least one handle carries its own policy.
    pub fn new(graphs: HashMap<GraphKey, Arc<GraphHandle>>) -> Self {
        // "Not every handle lacks a policy" == "some handle has a policy".
        let has_policy = !graphs.values().all(|handle| handle.policy.is_none());
        Self {
            any_per_graph_policy: has_policy,
            graphs,
        }
    }
}
impl Default for RegistrySnapshot {
fn default() -> Self {
Self::new(HashMap::new())
}
}
/// Outcome of looking a key up with `GraphRegistry::get`.
pub enum RegistryLookup {
    /// The key is present in the current snapshot; carries a shared handle to it.
    Ready(Arc<GraphHandle>),
    /// The key is not present in the current snapshot.
    Gone,
}
/// Reasons a handle can be rejected when building or extending a `GraphRegistry`.
#[derive(Debug, thiserror::Error)]
pub enum InsertError {
    /// Another handle with the same `GraphKey` is already registered.
    #[error("graph '{0}' is already registered")]
    DuplicateKey(GraphKey),
    /// Another handle already uses the same URI after normalization. The payload
    /// is the normalized URI.
    #[error("URI '{0}' is already registered as another graph")]
    DuplicateUri(String),
    /// `normalize_root_uri` rejected the URI. `message` is that error's `to_string()`.
    #[error("URI '{uri}' is invalid: {message}")]
    InvalidUri { uri: String, message: String },
}
/// Registry of graphs keyed by `GraphKey`.
///
/// Readers load the current `RegistrySnapshot` through an `ArcSwap`. Writers build a
/// complete replacement snapshot and store it, so a loaded snapshot never changes
/// underneath a reader.
pub struct GraphRegistry {
    /// The currently published snapshot.
    snapshot: ArcSwap<RegistrySnapshot>,
    /// Test-only lock that serializes `insert`, so its load → check → store
    /// sequence cannot interleave with another insert.
    #[cfg(test)]
    mutate: Mutex<()>,
}
impl GraphRegistry {
    /// Creates a registry with no graphs.
    pub fn new() -> Self {
        Self {
            snapshot: ArcSwap::from_pointee(RegistrySnapshot::default()),
            #[cfg(test)]
            mutate: Mutex::new(()),
        }
    }

    /// Builds a registry from an initial set of handles in a single step.
    ///
    /// Each handle's URI is canonicalized with `normalize_root_uri` before it is
    /// stored, so the resulting snapshot only contains canonical URIs. Handles are
    /// processed in order and the first offending handle determines the error.
    ///
    /// # Errors
    ///
    /// - [`InsertError::InvalidUri`] if a handle's URI cannot be normalized.
    /// - [`InsertError::DuplicateKey`] if two handles share a key.
    /// - [`InsertError::DuplicateUri`] if two handles normalize to the same URI.
    pub fn from_handles(handles: Vec<Arc<GraphHandle>>) -> Result<Self, InsertError> {
        use std::collections::HashSet;

        let mut graphs: HashMap<GraphKey, Arc<GraphHandle>> =
            HashMap::with_capacity(handles.len());
        // Only membership matters, so a set avoids a second clone of every key.
        let mut seen_uris: HashSet<String> = HashSet::with_capacity(handles.len());
        for handle in handles {
            let (canonical_uri, handle) = canonicalize_handle_uri(handle)?;
            if graphs.contains_key(&handle.key) {
                return Err(InsertError::DuplicateKey(handle.key.clone()));
            }
            // `HashSet::insert` returns `false` when the URI was already present.
            if !seen_uris.insert(canonical_uri) {
                return Err(InsertError::DuplicateUri(handle.uri.clone()));
            }
            graphs.insert(handle.key.clone(), handle);
        }
        Ok(Self {
            snapshot: ArcSwap::from_pointee(RegistrySnapshot::new(graphs)),
            #[cfg(test)]
            mutate: Mutex::new(()),
        })
    }

    /// Returns a guard over the currently published snapshot.
    ///
    /// The guard pins the snapshot that was current when it was loaded; later
    /// changes are not visible through it. `arc_swap` guards are intended to be
    /// short-lived.
    pub fn snapshot_ref(&self) -> arc_swap::Guard<Arc<RegistrySnapshot>> {
        self.snapshot.load()
    }

    /// Looks `key` up in the current snapshot.
    ///
    /// Returns [`RegistryLookup::Ready`] with a shared handle when present, and
    /// [`RegistryLookup::Gone`] otherwise.
    pub fn get(&self, key: &GraphKey) -> RegistryLookup {
        let snapshot = self.snapshot.load();
        match snapshot.graphs.get(key) {
            Some(handle) => RegistryLookup::Ready(Arc::clone(handle)),
            None => RegistryLookup::Gone,
        }
    }

    /// Returns every handle in the current snapshot, in unspecified (hash map) order.
    pub fn list(&self) -> Vec<Arc<GraphHandle>> {
        let snapshot = self.snapshot.load();
        snapshot.graphs.values().cloned().collect()
    }

    /// Number of graphs in the current snapshot.
    pub fn len(&self) -> usize {
        self.snapshot.load().graphs.len()
    }

    /// `true` when the current snapshot contains no graphs.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Test-only: adds one handle by publishing a new snapshot.
    ///
    /// Inserts are serialized by `mutate`, so concurrent calls cannot both pass the
    /// duplicate checks. The URI is canonicalized exactly as in `from_handles`.
    ///
    /// # Errors
    ///
    /// - [`InsertError::InvalidUri`] if the new handle's URI cannot be normalized.
    /// - [`InsertError::DuplicateKey`] if the key is already registered.
    /// - [`InsertError::DuplicateUri`] if the canonical URI is already registered.
    #[cfg(test)]
    pub async fn insert(&self, handle: Arc<GraphHandle>) -> Result<(), InsertError> {
        let _guard = self.mutate.lock().await;
        let current = self.snapshot.load();
        let (canonical_uri, handle) = canonicalize_handle_uri(handle)?;
        if current.graphs.contains_key(&handle.key) {
            return Err(InsertError::DuplicateKey(handle.key.clone()));
        }
        // Every stored handle went through `canonicalize_handle_uri` (here or in
        // `from_handles`), so its `uri` is already canonical. Comparing directly
        // matches `from_handles` and avoids re-normalizing (and possibly failing on)
        // every existing graph on each insert.
        if current
            .graphs
            .values()
            .any(|existing| existing.uri == canonical_uri)
        {
            return Err(InsertError::DuplicateUri(handle.uri.clone()));
        }
        let mut new_graphs = current.graphs.clone();
        new_graphs.insert(handle.key.clone(), handle);
        self.snapshot
            .store(Arc::new(RegistrySnapshot::new(new_graphs)));
        Ok(())
    }
}
/// Normalizes `handle.uri` with `normalize_root_uri`.
///
/// Returns the canonical URI together with a handle whose `uri` equals it. When the
/// URI is already canonical, the input `Arc` is handed back unchanged. Otherwise a
/// fresh `GraphHandle` is built that shares the same engine and policy.
///
/// # Errors
///
/// [`InsertError::InvalidUri`] when `normalize_root_uri` rejects the URI. The
/// message is that error rendered with `to_string()`.
fn canonicalize_handle_uri(
    handle: Arc<GraphHandle>,
) -> Result<(String, Arc<GraphHandle>), InsertError> {
    let canonical_uri = match normalize_root_uri(&handle.uri) {
        Ok(normalized) => normalized,
        Err(err) => {
            return Err(InsertError::InvalidUri {
                uri: handle.uri.clone(),
                message: err.to_string(),
            });
        }
    };

    let canonical_handle = if canonical_uri == handle.uri {
        // Already canonical: reuse the caller's allocation.
        handle
    } else {
        Arc::new(GraphHandle {
            key: handle.key.clone(),
            uri: canonical_uri.clone(),
            engine: Arc::clone(&handle.engine),
            policy: handle.policy.clone(),
        })
    };

    Ok((canonical_uri, canonical_handle))
}
impl Default for GraphRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use std::path::Path;

    use tempfile::TempDir;

    use super::*;
    use crate::graph_id::GraphId;

    /// Minimal schema used to initialise every test engine.
    const TEST_SCHEMA: &str = "node Person { name: String @key }\n";

    /// Initialises a fresh engine under `dir/graph_id` and wraps it in a
    /// cluster-keyed handle with no policy.
    async fn build_handle(graph_id: &str, dir: &Path) -> Arc<GraphHandle> {
        let graph_uri = dir.join(graph_id).to_str().unwrap().to_string();
        let engine = Omnigraph::init(&graph_uri, TEST_SCHEMA)
            .await
            .expect("init engine for registry test");
        Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from(graph_id).unwrap()),
            uri: graph_uri,
            engine: Arc::new(engine),
            policy: None,
        })
    }

    #[tokio::test]
    async fn new_registry_is_empty() {
        let registry = GraphRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
    }

    #[tokio::test]
    async fn insert_then_get_returns_ready() {
        let dir = TempDir::new().unwrap();
        let registry = GraphRegistry::new();
        let handle = build_handle("alpha", dir.path()).await;
        registry.insert(Arc::clone(&handle)).await.unwrap();
        match registry.get(&handle.key) {
            RegistryLookup::Ready(found) => {
                // NOTE: pointer equality holds only because the temp-dir URI is
                // already canonical; otherwise `insert` stores a rebuilt handle.
                assert!(Arc::ptr_eq(&found, &handle));
            }
            RegistryLookup::Gone => panic!("expected Ready, got Gone"),
        }
    }

    #[tokio::test]
    async fn get_nonexistent_returns_gone() {
        let registry = GraphRegistry::new();
        let key = GraphKey::cluster(GraphId::try_from("ghost").unwrap());
        match registry.get(&key) {
            RegistryLookup::Gone => {}
            RegistryLookup::Ready(_) => panic!("expected Gone"),
        }
    }

    #[tokio::test]
    async fn insert_duplicate_key_returns_error() {
        let dir = TempDir::new().unwrap();
        let registry = GraphRegistry::new();
        let h1 = build_handle("alpha", dir.path()).await;
        // Different directory, so only the key collides, not the URI.
        let dir2 = TempDir::new().unwrap();
        let h2 = build_handle("alpha", dir2.path()).await;
        registry.insert(h1).await.unwrap();
        match registry.insert(h2).await {
            Err(InsertError::DuplicateKey(_)) => {}
            other => panic!("expected DuplicateKey, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn insert_duplicate_uri_returns_error() {
        let dir = TempDir::new().unwrap();
        let shared_uri = dir.path().join("shared").to_str().unwrap().to_string();
        let engine = Omnigraph::init(&shared_uri, TEST_SCHEMA).await.unwrap();
        let engine = Arc::new(engine);
        let h1 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("alpha").unwrap()),
            uri: shared_uri.clone(),
            engine: Arc::clone(&engine),
            policy: None,
        });
        let h2 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("beta").unwrap()),
            uri: shared_uri,
            engine,
            policy: None,
        });
        let registry = GraphRegistry::new();
        registry.insert(h1).await.unwrap();
        match registry.insert(h2).await {
            Err(InsertError::DuplicateUri(_)) => {}
            other => panic!("expected DuplicateUri, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn list_returns_all_inserted_handles() {
        let dir = TempDir::new().unwrap();
        let registry = GraphRegistry::new();
        for name in ["alpha", "beta", "gamma"] {
            let h = build_handle(name, dir.path()).await;
            registry.insert(h).await.unwrap();
        }
        assert_eq!(registry.len(), 3);
        // `list` order is unspecified, so compare sorted ids.
        let mut ids: Vec<_> = registry
            .list()
            .into_iter()
            .map(|h| h.key.graph_id.as_str().to_string())
            .collect();
        ids.sort();
        assert_eq!(ids, vec!["alpha", "beta", "gamma"]);
    }

    #[tokio::test]
    async fn from_handles_bulk_init_succeeds() {
        let dir = TempDir::new().unwrap();
        let handles = vec![
            build_handle("alpha", dir.path()).await,
            build_handle("beta", dir.path()).await,
        ];
        let registry = GraphRegistry::from_handles(handles).unwrap();
        assert_eq!(registry.len(), 2);
    }

    #[tokio::test]
    async fn from_handles_rejects_duplicate_keys() {
        let dir1 = TempDir::new().unwrap();
        let dir2 = TempDir::new().unwrap();
        let h1 = build_handle("alpha", dir1.path()).await;
        let h2 = build_handle("alpha", dir2.path()).await;
        // Matched by hand because `GraphRegistry` does not implement `Debug`.
        let err = match GraphRegistry::from_handles(vec![h1, h2]) {
            Ok(_) => panic!("expected DuplicateKey, got Ok"),
            Err(err) => err,
        };
        assert!(
            matches!(err, InsertError::DuplicateKey(_)),
            "expected DuplicateKey, got {err}",
        );
    }

    #[tokio::test]
    async fn from_handles_rejects_duplicate_uris() {
        let dir = TempDir::new().unwrap();
        let shared_uri = dir.path().join("shared").to_str().unwrap().to_string();
        let engine = Arc::new(Omnigraph::init(&shared_uri, TEST_SCHEMA).await.unwrap());
        let h1 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("alpha").unwrap()),
            uri: shared_uri.clone(),
            engine: Arc::clone(&engine),
            policy: None,
        });
        let h2 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("beta").unwrap()),
            uri: shared_uri,
            engine,
            policy: None,
        });
        let err = match GraphRegistry::from_handles(vec![h1, h2]) {
            Ok(_) => panic!("expected DuplicateUri, got Ok"),
            Err(err) => err,
        };
        assert!(
            matches!(err, InsertError::DuplicateUri(_)),
            "expected DuplicateUri, got {err}",
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_insert_same_key_exactly_one_succeeds() {
        const N: usize = 8;
        let registry = Arc::new(GraphRegistry::new());
        let mut handles = Vec::with_capacity(N);
        // Keep the temp dirs alive until every insert has finished.
        let mut dirs = Vec::with_capacity(N);
        for _ in 0..N {
            let d = TempDir::new().unwrap();
            handles.push(build_handle("contested", d.path()).await);
            dirs.push(d);
        }
        // Release all inserts at once to maximise contention.
        let barrier = Arc::new(tokio::sync::Barrier::new(N));
        let mut tasks = Vec::with_capacity(N);
        for handle in handles {
            let registry = Arc::clone(&registry);
            let barrier = Arc::clone(&barrier);
            tasks.push(tokio::spawn(async move {
                barrier.wait().await;
                registry.insert(handle).await
            }));
        }
        let mut ok_count = 0usize;
        let mut dup_count = 0usize;
        for t in tasks {
            match t.await.unwrap() {
                Ok(()) => ok_count += 1,
                Err(InsertError::DuplicateKey(_)) => dup_count += 1,
                Err(other) => panic!("unexpected error: {other:?}"),
            }
        }
        assert_eq!(ok_count, 1, "exactly one insert must succeed");
        assert_eq!(dup_count, N - 1, "the rest must return DuplicateKey");
        assert_eq!(registry.len(), 1);
        drop(dirs);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_insert_distinct_keys_all_succeed() {
        const N: usize = 8;
        let registry = Arc::new(GraphRegistry::new());
        let mut handles = Vec::with_capacity(N);
        // Keep the temp dirs alive until every insert has finished.
        let mut dirs = Vec::with_capacity(N);
        for i in 0..N {
            let d = TempDir::new().unwrap();
            handles.push(build_handle(&format!("graph-{i}"), d.path()).await);
            dirs.push(d);
        }
        let barrier = Arc::new(tokio::sync::Barrier::new(N));
        let mut tasks = Vec::with_capacity(N);
        for handle in handles {
            let registry = Arc::clone(&registry);
            let barrier = Arc::clone(&barrier);
            tasks.push(tokio::spawn(async move {
                barrier.wait().await;
                registry.insert(handle).await
            }));
        }
        for t in tasks {
            t.await.unwrap().unwrap();
        }
        assert_eq!(registry.len(), N);
        drop(dirs);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_reads_during_inserts_see_consistent_snapshots() {
        let dir = TempDir::new().unwrap();
        let registry = Arc::new(GraphRegistry::new());
        const N_WRITES: usize = 10;
        let writer_registry = Arc::clone(&registry);
        let writer_dir = dir.path().to_path_buf();
        let writer = tokio::spawn(async move {
            for i in 0..N_WRITES {
                let h = build_handle(&format!("graph-{i}"), &writer_dir).await;
                writer_registry.insert(h).await.unwrap();
            }
        });
        let reader_registry = Arc::clone(&registry);
        let reader = tokio::spawn(async move {
            for _ in 0..200 {
                let snap = reader_registry.list();
                assert!(snap.len() <= N_WRITES);
                // Inserts only add keys, so anything listed must still resolve to
                // the same handle through a later `get`.
                for handle in &snap {
                    match reader_registry.get(&handle.key) {
                        RegistryLookup::Ready(found) => {
                            assert!(Arc::ptr_eq(&found, handle));
                        }
                        RegistryLookup::Gone => panic!(
                            "snapshot listed key {} but get() returned Gone",
                            handle.key.graph_id
                        ),
                    }
                }
                tokio::task::yield_now().await;
            }
        });
        writer.await.unwrap();
        reader.await.unwrap();
        assert_eq!(registry.len(), N_WRITES);
    }
}