use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use arc_swap::ArcSwap;
use omnigraph::db::Omnigraph;
use omnigraph::storage::normalize_root_uri;
#[cfg(test)]
use tokio::sync::Mutex;

use crate::identity::GraphKey;
use crate::policy::PolicyEngine;
32
33pub struct GraphHandle {
37 pub key: GraphKey,
39 pub uri: String,
43 pub engine: Arc<Omnigraph>,
46 pub policy: Option<Arc<PolicyEngine>>,
50}
51
52pub struct RegistrySnapshot {
60 pub graphs: HashMap<GraphKey, Arc<GraphHandle>>,
61 pub any_per_graph_policy: bool,
67}
68
69impl RegistrySnapshot {
70 pub fn new(graphs: HashMap<GraphKey, Arc<GraphHandle>>) -> Self {
74 let any_per_graph_policy = graphs.values().any(|h| h.policy.is_some());
75 Self {
76 graphs,
77 any_per_graph_policy,
78 }
79 }
80}
81
82impl Default for RegistrySnapshot {
83 fn default() -> Self {
84 Self::new(HashMap::new())
85 }
86}
87
88pub enum RegistryLookup {
90 Ready(Arc<GraphHandle>),
92 Gone,
95}
96
97#[derive(Debug, thiserror::Error)]
99pub enum InsertError {
100 #[error("graph '{0}' is already registered")]
102 DuplicateKey(GraphKey),
103 #[error("URI '{0}' is already registered as another graph")]
107 DuplicateUri(String),
108 #[error("URI '{uri}' is invalid: {message}")]
110 InvalidUri { uri: String, message: String },
111}
112
113pub struct GraphRegistry {
114 snapshot: ArcSwap<RegistrySnapshot>,
115 #[cfg(test)]
120 mutate: Mutex<()>,
121}
122
123impl GraphRegistry {
124 pub fn new() -> Self {
126 Self {
127 snapshot: ArcSwap::from_pointee(RegistrySnapshot::default()),
128 #[cfg(test)]
129 mutate: Mutex::new(()),
130 }
131 }
132
133 pub fn from_handles(handles: Vec<Arc<GraphHandle>>) -> Result<Self, InsertError> {
136 let mut graphs: HashMap<GraphKey, Arc<GraphHandle>> = HashMap::with_capacity(handles.len());
137 let mut seen_uris: HashMap<String, GraphKey> = HashMap::with_capacity(handles.len());
138 for handle in handles {
139 let (canonical_uri, handle) = canonicalize_handle_uri(handle)?;
140 if graphs.contains_key(&handle.key) {
141 return Err(InsertError::DuplicateKey(handle.key.clone()));
142 }
143 if seen_uris.contains_key(&canonical_uri) {
144 return Err(InsertError::DuplicateUri(handle.uri.clone()));
145 }
146 seen_uris.insert(canonical_uri, handle.key.clone());
147 graphs.insert(handle.key.clone(), handle);
148 }
149 Ok(Self {
150 snapshot: ArcSwap::from_pointee(RegistrySnapshot::new(graphs)),
151 #[cfg(test)]
152 mutate: Mutex::new(()),
153 })
154 }
155
156 pub fn snapshot_ref(&self) -> arc_swap::Guard<Arc<RegistrySnapshot>> {
161 self.snapshot.load()
162 }
163
164 pub fn get(&self, key: &GraphKey) -> RegistryLookup {
167 let snapshot = self.snapshot.load();
168 match snapshot.graphs.get(key) {
169 Some(handle) => RegistryLookup::Ready(Arc::clone(handle)),
170 None => RegistryLookup::Gone,
171 }
172 }
173
174 pub fn list(&self) -> Vec<Arc<GraphHandle>> {
179 let snapshot = self.snapshot.load();
180 snapshot.graphs.values().cloned().collect()
181 }
182
183 pub fn len(&self) -> usize {
185 self.snapshot.load().graphs.len()
186 }
187
188 pub fn is_empty(&self) -> bool {
189 self.len() == 0
190 }
191
192 #[cfg(test)]
208 pub async fn insert(&self, handle: Arc<GraphHandle>) -> Result<(), InsertError> {
209 let _guard = self.mutate.lock().await;
210 let current = self.snapshot.load();
211 let (canonical_uri, handle) = canonicalize_handle_uri(handle)?;
212 if current.graphs.contains_key(&handle.key) {
213 return Err(InsertError::DuplicateKey(handle.key.clone()));
214 }
215 for existing in current.graphs.values() {
216 let existing_uri =
217 normalize_root_uri(&existing.uri).map_err(|err| InsertError::InvalidUri {
218 uri: existing.uri.clone(),
219 message: err.to_string(),
220 })?;
221 if existing_uri == canonical_uri {
222 return Err(InsertError::DuplicateUri(handle.uri.clone()));
223 }
224 }
225 let mut new_graphs = current.graphs.clone();
226 new_graphs.insert(handle.key.clone(), handle);
227 self.snapshot
228 .store(Arc::new(RegistrySnapshot::new(new_graphs)));
229 Ok(())
230 }
231}
232
233fn canonicalize_handle_uri(
234 handle: Arc<GraphHandle>,
235) -> Result<(String, Arc<GraphHandle>), InsertError> {
236 let canonical_uri = normalize_root_uri(&handle.uri).map_err(|err| InsertError::InvalidUri {
237 uri: handle.uri.clone(),
238 message: err.to_string(),
239 })?;
240 if canonical_uri == handle.uri {
241 return Ok((canonical_uri, handle));
242 }
243 let canonical_handle = Arc::new(GraphHandle {
244 key: handle.key.clone(),
245 uri: canonical_uri.clone(),
246 engine: Arc::clone(&handle.engine),
247 policy: handle.policy.clone(),
248 });
249 Ok((canonical_uri, canonical_handle))
250}
251
252impl Default for GraphRegistry {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
#[cfg(test)]
mod tests {
    use std::path::Path;

    use tempfile::TempDir;

    use super::*;
    use crate::graph_id::GraphId;

    /// Schema used to initialize every test engine.
    const TEST_SCHEMA: &str = "node Person { name: String @key }\n";

    /// Initializes a fresh engine at `dir/<graph_id>` and wraps it in a
    /// cluster-scoped handle with no per-graph policy.
    async fn build_handle(graph_id: &str, dir: &Path) -> Arc<GraphHandle> {
        let graph_uri = dir.join(graph_id).to_str().unwrap().to_string();
        let engine = Omnigraph::init(&graph_uri, TEST_SCHEMA)
            .await
            .expect("init engine for registry test");
        Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from(graph_id).unwrap()),
            uri: graph_uri,
            engine: Arc::new(engine),
            policy: None,
        })
    }

    #[tokio::test]
    async fn new_registry_is_empty() {
        let registry = GraphRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.list().is_empty());
    }

    #[tokio::test]
    async fn insert_then_get_returns_ready() {
        let dir = TempDir::new().unwrap();
        let registry = GraphRegistry::new();
        let handle = build_handle("alpha", dir.path()).await;
        registry.insert(Arc::clone(&handle)).await.unwrap();

        // `ptr_eq` holds only when the temp-dir URI is already normalized.
        // Otherwise `canonicalize_handle_uri` stores a rebuilt handle.
        match registry.get(&handle.key) {
            RegistryLookup::Ready(found) => {
                assert!(Arc::ptr_eq(&found, &handle));
            }
            RegistryLookup::Gone => panic!("expected Ready, got Gone"),
        }
    }

    #[tokio::test]
    async fn get_nonexistent_returns_gone() {
        let registry = GraphRegistry::new();
        let key = GraphKey::cluster(GraphId::try_from("ghost").unwrap());
        match registry.get(&key) {
            RegistryLookup::Gone => {}
            RegistryLookup::Ready(_) => panic!("expected Gone"),
        }
    }

    #[tokio::test]
    async fn insert_duplicate_key_returns_error() {
        let dir = TempDir::new().unwrap();
        let registry = GraphRegistry::new();
        let h1 = build_handle("alpha", dir.path()).await;
        let dir2 = TempDir::new().unwrap();
        // Separate directory: same key, different URI, so only the key clashes.
        let h2 = build_handle("alpha", dir2.path()).await;
        registry.insert(h1).await.unwrap();

        match registry.insert(h2).await {
            Err(InsertError::DuplicateKey(_)) => {}
            other => panic!("expected DuplicateKey, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn insert_duplicate_uri_returns_error() {
        let dir = TempDir::new().unwrap();
        let shared_uri = dir.path().join("shared").to_str().unwrap().to_string();
        // Two distinct keys pointing at the same storage root.
        let engine = Omnigraph::init(&shared_uri, TEST_SCHEMA).await.unwrap();
        let engine = Arc::new(engine);
        let h1 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("alpha").unwrap()),
            uri: shared_uri.clone(),
            engine: Arc::clone(&engine),
            policy: None,
        });
        let h2 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("beta").unwrap()),
            uri: shared_uri,
            engine,
            policy: None,
        });

        let registry = GraphRegistry::new();
        registry.insert(h1).await.unwrap();
        match registry.insert(h2).await {
            Err(InsertError::DuplicateUri(_)) => {}
            other => panic!("expected DuplicateUri, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn list_returns_all_inserted_handles() {
        let dir = TempDir::new().unwrap();
        let registry = GraphRegistry::new();
        for name in ["alpha", "beta", "gamma"] {
            let h = build_handle(name, dir.path()).await;
            registry.insert(h).await.unwrap();
        }
        assert_eq!(registry.len(), 3);
        // `list` order is unspecified, so sort before comparing.
        let mut ids: Vec<_> = registry
            .list()
            .into_iter()
            .map(|h| h.key.graph_id.as_str().to_string())
            .collect();
        ids.sort();
        assert_eq!(ids, vec!["alpha", "beta", "gamma"]);
    }

    #[tokio::test]
    async fn from_handles_bulk_init_succeeds() {
        let dir = TempDir::new().unwrap();
        let handles = vec![
            build_handle("alpha", dir.path()).await,
            build_handle("beta", dir.path()).await,
        ];
        let registry = GraphRegistry::from_handles(handles).unwrap();
        assert_eq!(registry.len(), 2);
    }

    #[tokio::test]
    async fn from_handles_rejects_duplicate_keys() {
        let dir1 = TempDir::new().unwrap();
        let dir2 = TempDir::new().unwrap();
        let h1 = build_handle("alpha", dir1.path()).await;
        let h2 = build_handle("alpha", dir2.path()).await;
        let err = match GraphRegistry::from_handles(vec![h1, h2]) {
            Ok(_) => panic!("expected DuplicateKey, got Ok"),
            Err(err) => err,
        };
        assert!(
            matches!(err, InsertError::DuplicateKey(_)),
            "expected DuplicateKey, got {err}",
        );
    }

    #[tokio::test]
    async fn from_handles_rejects_duplicate_uris() {
        let dir = TempDir::new().unwrap();
        let shared_uri = dir.path().join("shared").to_str().unwrap().to_string();
        let engine = Arc::new(Omnigraph::init(&shared_uri, TEST_SCHEMA).await.unwrap());
        let h1 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("alpha").unwrap()),
            uri: shared_uri.clone(),
            engine: Arc::clone(&engine),
            policy: None,
        });
        let h2 = Arc::new(GraphHandle {
            key: GraphKey::cluster(GraphId::try_from("beta").unwrap()),
            uri: shared_uri,
            engine,
            policy: None,
        });
        let err = match GraphRegistry::from_handles(vec![h1, h2]) {
            Ok(_) => panic!("expected DuplicateUri, got Ok"),
            Err(err) => err,
        };
        assert!(
            matches!(err, InsertError::DuplicateUri(_)),
            "expected DuplicateUri, got {err}",
        );
    }

    /// N tasks race to insert the same key (each with its own URI); the
    /// `mutate` lock must let exactly one win.
    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_insert_same_key_exactly_one_succeeds() {
        const N: usize = 8;

        let registry = Arc::new(GraphRegistry::new());
        let mut handles = Vec::with_capacity(N);
        // Each handle lives in its own temp dir so only the key collides.
        let mut dirs = Vec::with_capacity(N);
        for _ in 0..N {
            let d = TempDir::new().unwrap();
            handles.push(build_handle("contested", d.path()).await);
            dirs.push(d);
        }

        let barrier = Arc::new(tokio::sync::Barrier::new(N));
        let mut tasks = Vec::with_capacity(N);
        for handle in handles {
            let registry = Arc::clone(&registry);
            let barrier = Arc::clone(&barrier);
            tasks.push(tokio::spawn(async move {
                barrier.wait().await;
                registry.insert(handle).await
            }));
        }

        let mut ok_count = 0usize;
        let mut dup_count = 0usize;
        for t in tasks {
            match t.await.unwrap() {
                Ok(()) => ok_count += 1,
                Err(InsertError::DuplicateKey(_)) => dup_count += 1,
                Err(other) => panic!("unexpected error: {other:?}"),
            }
        }
        assert_eq!(ok_count, 1, "exactly one insert must succeed");
        assert_eq!(dup_count, N - 1, "the rest must return DuplicateKey");
        assert_eq!(registry.len(), 1);

        // Temp dirs must outlive every insert; drop them explicitly at the end.
        drop(dirs);
    }

    /// N tasks insert distinct keys concurrently; serialized copy-and-swap
    /// must not lose any of them.
    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_insert_distinct_keys_all_succeed() {
        const N: usize = 8;

        let registry = Arc::new(GraphRegistry::new());
        let mut handles = Vec::with_capacity(N);
        // Keep each temp dir alive until the inserts complete.
        let mut dirs = Vec::with_capacity(N);
        for i in 0..N {
            let d = TempDir::new().unwrap();
            handles.push(build_handle(&format!("graph-{i}"), d.path()).await);
            dirs.push(d);
        }

        let barrier = Arc::new(tokio::sync::Barrier::new(N));
        let mut tasks = Vec::with_capacity(N);
        for handle in handles {
            let registry = Arc::clone(&registry);
            let barrier = Arc::clone(&barrier);
            tasks.push(tokio::spawn(async move {
                barrier.wait().await;
                registry.insert(handle).await
            }));
        }
        for t in tasks {
            t.await.unwrap().unwrap();
        }
        assert_eq!(registry.len(), N);
        drop(dirs);
    }

    /// A reader repeatedly lists and re-looks-up handles while a writer
    /// inserts. Because the registry only grows here, every listed key must
    /// still resolve to the same `Arc`. The reader may finish before the
    /// writer inserts anything; the test checks consistency, not interleaving.
    #[tokio::test(flavor = "multi_thread")]
    async fn concurrent_reads_during_inserts_see_consistent_snapshots() {
        let dir = TempDir::new().unwrap();
        let registry = Arc::new(GraphRegistry::new());

        const N_WRITES: usize = 10;
        let writer_registry = Arc::clone(&registry);
        let writer_dir = dir.path().to_path_buf();
        let writer = tokio::spawn(async move {
            for i in 0..N_WRITES {
                let h = build_handle(&format!("graph-{i}"), &writer_dir).await;
                writer_registry.insert(h).await.unwrap();
            }
        });

        let reader_registry = Arc::clone(&registry);
        let reader = tokio::spawn(async move {
            for _ in 0..200 {
                let snap = reader_registry.list();
                assert!(snap.len() <= N_WRITES);
                for handle in &snap {
                    match reader_registry.get(&handle.key) {
                        RegistryLookup::Ready(found) => {
                            assert!(Arc::ptr_eq(&found, handle));
                        }
                        RegistryLookup::Gone => panic!(
                            "snapshot listed key {} but get() returned Gone",
                            handle.key.graph_id
                        ),
                    }
                }
                tokio::task::yield_now().await;
            }
        });

        writer.await.unwrap();
        reader.await.unwrap();
        assert_eq!(registry.len(), N_WRITES);
    }
}