use std::collections::BTreeMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use bytes::Bytes;
use dashmap::DashMap;
use object_store::ObjectStore;
use object_store::path::Path as ObjectStorePath;
use tokio::sync::Mutex as AsyncMutex;
use tracing::{debug, instrument, warn};
use uni_common::api::error::UniError;
use uni_common::core::fork::{ForkId, ForkInfo, ForkRegistryFile, ForkStatus, SchemaDelta};
use crate::store_utils::{
DEFAULT_TIMEOUT, delete_with_timeout, get_with_timeout, list_with_timeout, put_with_timeout,
};
/// Cloneable handle to the fork registry.
///
/// Every clone shares the same reference-counted inner state, so all clones
/// observe the same cache, name locks and holder counters.
#[derive(Clone)]
pub struct ForkRegistryHandle {
    inner: Arc<ForkRegistryInner>,
}

impl std::fmt::Debug for ForkRegistryHandle {
    /// Prints only the number of name locks and holder counters; the
    /// remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = &self.inner;
        let mut out = f.debug_struct("ForkRegistryHandle");
        out.field("name_locks", &state.name_locks.len());
        out.field("holders", &state.holders.len());
        out.finish_non_exhaustive()
    }
}
/// State shared by every clone of `ForkRegistryHandle`.
struct ForkRegistryInner {
/// Object store that the registry, schema-overlay and tombstone objects are
/// read from and written to.
store: Arc<dyn ObjectStore>,
/// In-memory copy of the registry file; mutated and then persisted under
/// this lock.
cache: AsyncMutex<ForkRegistryFile>,
/// One async mutex per fork name, created on first request by `name_lock`.
name_locks: DashMap<String, Arc<AsyncMutex<()>>>,
/// Live-holder counter per fork id, created on first `register_holder`.
holders: DashMap<ForkId, Arc<AtomicUsize>>,
/// Optional cap on the number of registry entries, checked in
/// `begin_create`; `None` means no cap is enforced.
max_forks: AsyncMutex<Option<usize>>,
}
/// RAII guard representing one live holder of a fork.
///
/// Constructing the guard increments the shared counter and dropping it
/// decrements the counter again, so the counter always equals the number of
/// guards currently alive for that fork.
#[derive(Debug)]
#[must_use = "the holder is unregistered as soon as the guard is dropped"]
pub struct ForkHolderGuard {
    counter: Arc<AtomicUsize>,
}

impl ForkHolderGuard {
    /// Registers one more holder on `counter` and returns the guard that
    /// will unregister it on drop.
    fn new(counter: Arc<AtomicUsize>) -> Self {
        counter.fetch_add(1, Ordering::AcqRel);
        Self { counter }
    }
}

impl Drop for ForkHolderGuard {
    /// Unregisters the holder that `new` registered.
    fn drop(&mut self) {
        self.counter.fetch_sub(1, Ordering::AcqRel);
    }
}
/// Object-store location of the persisted fork registry file.
fn registry_path() -> ObjectStorePath {
    let key = "catalog/fork_registry.json";
    ObjectStorePath::from(key)
}
/// Object-store location of the schema overlay file for fork `id`.
fn schema_overlay_path(id: &ForkId) -> ObjectStorePath {
    let key = format!("catalog/fork_schemas/{id}.json");
    ObjectStorePath::from(key)
}
/// Object-store location of the tombstone file for fork `id`.
fn tombstone_path(id: &ForkId) -> ObjectStorePath {
    let key = format!("catalog/fork_tombstones/{id}.json");
    ObjectStorePath::from(key)
}
/// Wraps a concrete error into `UniError::ForkLifecycle`, tagging it with the
/// fork `name` and the lifecycle `stage` in which it occurred.
fn lifecycle<E>(name: &str, stage: &'static str, source: E) -> UniError
where
    E: std::error::Error + Send + Sync + 'static,
{
    let name = String::from(name);
    UniError::ForkLifecycle {
        name,
        stage,
        source: Box::new(source),
    }
}
/// Same as `lifecycle`, but for errors that already arrive as
/// `anyhow::Error` (as returned by the `*_with_timeout` store helpers).
fn lifecycle_anyhow(name: &str, stage: &'static str, source: anyhow::Error) -> UniError {
    let name = String::from(name);
    UniError::ForkLifecycle {
        name,
        stage,
        source: source.into(),
    }
}
impl ForkRegistryHandle {
#[instrument(skip(store), level = "info")]
pub async fn load(store: Arc<dyn ObjectStore>) -> Result<Self, UniError> {
let path = registry_path();
let cache = match get_with_timeout(&store, &path, DEFAULT_TIMEOUT).await {
Ok(result) => {
let bytes = result
.bytes()
.await
.map_err(|e| lifecycle("<registry>", "load", e))?;
serde_json::from_slice::<ForkRegistryFile>(&bytes).map_err(|e| {
UniError::ForkCorruptRegistry {
message: format!("parse fork_registry.json: {e}"),
}
})?
}
Err(_) => {
ForkRegistryFile::default()
}
};
Ok(Self {
inner: Arc::new(ForkRegistryInner {
store,
cache: AsyncMutex::new(cache),
name_locks: DashMap::new(),
holders: DashMap::new(),
max_forks: AsyncMutex::new(None),
}),
})
}
/// Replaces the cap that `begin_create` checks; `None` removes the cap.
pub async fn set_max_forks(&self, cap: Option<usize>) {
    let mut limit = self.inner.max_forks.lock().await;
    *limit = cap;
}
pub async fn snapshot(&self) -> ForkRegistryFile {
self.inner.cache.lock().await.clone()
}
/// Returns clones of all registry entries whose status is `Active`.
pub async fn list_active(&self) -> Vec<ForkInfo> {
    let registry = self.inner.cache.lock().await;
    let mut active = Vec::new();
    for fork in registry.forks.values() {
        if fork.status == ForkStatus::Active {
            active.push(fork.clone());
        }
    }
    active
}
/// Returns clones of all `Active` entries whose TTL deadline is at or before
/// `now`. Entries without a TTL never count as expired.
pub async fn list_expired(&self, now: chrono::DateTime<chrono::Utc>) -> Vec<ForkInfo> {
    let registry = self.inner.cache.lock().await;
    let mut expired = Vec::new();
    for fork in registry.forks.values() {
        if fork.status != ForkStatus::Active {
            continue;
        }
        if let Some(deadline) = fork.ttl_expires_at {
            if deadline <= now {
                expired.push(fork.clone());
            }
        }
    }
    expired
}
/// Returns clones of all entries whose parent fork is `parent_id`,
/// regardless of their status.
pub async fn list_children(&self, parent_id: ForkId) -> Vec<ForkInfo> {
    let wanted = Some(parent_id);
    let registry = self.inner.cache.lock().await;
    let mut children = Vec::new();
    for fork in registry.forks.values() {
        if fork.parent_fork_id == wanted {
            children.push(fork.clone());
        }
    }
    children
}
/// Looks up a fork by id with a linear scan over the registry entries.
///
/// # Errors
/// `UniError::ForkNotFound` (carrying the id rendered as a string) when no
/// entry has this id.
pub async fn get_by_id(&self, id: ForkId) -> Result<ForkInfo, UniError> {
    let registry = self.inner.cache.lock().await;
    match registry.forks.values().find(|fork| fork.id == id) {
        Some(fork) => Ok(fork.clone()),
        None => Err(UniError::ForkNotFound {
            name: id.to_string(),
        }),
    }
}
/// Looks up a fork by name.
///
/// # Errors
/// `UniError::ForkNotFound` when the registry has no entry under `name`.
pub async fn get(&self, name: &str) -> Result<ForkInfo, UniError> {
    let registry = self.inner.cache.lock().await;
    match registry.forks.get(name) {
        Some(fork) => Ok(fork.clone()),
        None => Err(UniError::ForkNotFound {
            name: name.to_string(),
        }),
    }
}
/// Returns the async mutex associated with `name`, creating it on first use.
/// Repeated calls with the same name return the same `Arc`.
pub async fn name_lock(&self, name: &str) -> Arc<AsyncMutex<()>> {
    let slot = self
        .inner
        .name_locks
        .entry(name.to_string())
        .or_insert_with(|| Arc::new(AsyncMutex::new(())));
    Arc::clone(slot.value())
}
/// Registers a live holder for `fork_id` and returns the guard that keeps it
/// registered. The per-fork counter is created on first use.
pub fn register_holder(&self, fork_id: ForkId) -> ForkHolderGuard {
    let slot = self
        .inner
        .holders
        .entry(fork_id)
        .or_insert_with(|| Arc::new(AtomicUsize::new(0)));
    let counter = Arc::clone(slot.value());
    // Release the map entry before touching the counter, as the original
    // expression did once its temporary went out of scope.
    drop(slot);
    ForkHolderGuard::new(counter)
}
pub async fn holder_count_for(&self, fork_id: ForkId) -> usize {
self.holder_count(fork_id)
}
/// Number of live holder guards for `fork_id`, or `0` when no counter has
/// been created for it yet.
fn holder_count(&self, fork_id: ForkId) -> usize {
    self.inner
        .holders
        .get(&fork_id)
        .map_or(0, |counter| counter.load(Ordering::Acquire))
}
async fn put_registry(
&self,
cache: &ForkRegistryFile,
name: &str,
stage: &'static str,
) -> Result<(), UniError> {
let json = serde_json::to_vec_pretty(cache).map_err(|e| lifecycle(name, stage, e))?;
put_with_timeout(
&self.inner.store,
®istry_path(),
Bytes::from(json),
DEFAULT_TIMEOUT,
)
.await
.map_err(|e| lifecycle_anyhow(name, stage, e))?;
Ok(())
}
/// Serializes `delta` as pretty-printed JSON and writes it to the schema
/// overlay path of fork `id`. `name` only labels errors.
///
/// # Errors
/// `UniError::ForkLifecycle` (stage `schema_overlay`) on a serialization or
/// store failure.
async fn put_schema_overlay(
    &self,
    id: &ForkId,
    delta: &SchemaDelta,
    name: &str,
) -> Result<(), UniError> {
    let payload = match serde_json::to_vec_pretty(delta) {
        Ok(json) => Bytes::from(json),
        Err(e) => return Err(lifecycle(name, "schema_overlay", e)),
    };
    let target = schema_overlay_path(id);
    put_with_timeout(&self.inner.store, &target, payload, DEFAULT_TIMEOUT)
        .await
        .map_err(|e| lifecycle_anyhow(name, "schema_overlay", e))?;
    Ok(())
}
/// Serializes `info` as pretty-printed JSON and writes it to the tombstone
/// path of that fork.
///
/// # Errors
/// `UniError::ForkLifecycle` (stage `tombstone`) on a serialization or store
/// failure.
async fn put_tombstone(&self, info: &ForkInfo) -> Result<(), UniError> {
    let fork_name: &str = &info.name;
    let payload = match serde_json::to_vec_pretty(info) {
        Ok(json) => Bytes::from(json),
        Err(e) => return Err(lifecycle(fork_name, "tombstone", e)),
    };
    let target = tombstone_path(&info.id);
    put_with_timeout(&self.inner.store, &target, payload, DEFAULT_TIMEOUT)
        .await
        .map_err(|e| lifecycle_anyhow(fork_name, "tombstone", e))?;
    Ok(())
}
/// Best-effort removal of the tombstone object for fork `id`.
///
/// A failed delete is logged as a warning and otherwise ignored, so this
/// currently always returns `Ok(())`. `_name` is unused.
async fn delete_tombstone(&self, id: &ForkId, _name: &str) -> Result<(), UniError> {
    let target = tombstone_path(id);
    let outcome = delete_with_timeout(&self.inner.store, &target, DEFAULT_TIMEOUT).await;
    if let Err(e) = outcome {
        warn!(fork_id = %id, "delete tombstone returned {e}");
    }
    Ok(())
}
/// Best-effort removal of the schema overlay object for fork `id`; the
/// outcome of the delete is deliberately discarded.
async fn delete_schema_overlay(&self, id: &ForkId) {
    let target = schema_overlay_path(id);
    let _ = delete_with_timeout(&self.inner.store, &target, DEFAULT_TIMEOUT).await;
}
/// Records that `dataset` of fork `fork_id` lives on `branch`, and persists
/// the registry.
///
/// The call is idempotent: registering the same branch again is a no-op.
/// Registering a *different* branch for an already-mapped dataset is refused.
///
/// If persisting the registry fails, the in-memory mapping is removed again.
/// Without that undo a retry would hit the "already mapped to the same
/// branch" shortcut and report success although nothing was ever written.
///
/// # Errors
/// * `UniError::ForkNotFound` when no entry has `fork_id`.
/// * `UniError::ForkCorruptRegistry` when `dataset` already maps to another
///   branch.
/// * `UniError::ForkLifecycle` when the registry cannot be persisted.
pub async fn register_dataset_branch(
    &self,
    fork_id: ForkId,
    dataset: &str,
    branch: &str,
) -> Result<(), UniError> {
    let mut cache = self.inner.cache.lock().await;
    let entry = cache
        .forks
        .values_mut()
        .find(|f| f.id == fork_id)
        .ok_or_else(|| UniError::ForkNotFound {
            name: format!("<fork:{fork_id}>"),
        })?;
    match entry.datasets.get(dataset) {
        Some(existing) if existing == branch => return Ok(()),
        Some(existing) => {
            return Err(UniError::ForkCorruptRegistry {
                message: format!(
                    "register_dataset_branch: dataset '{dataset}' already \
                     maps to '{existing}', refusing to overwrite with \
                     '{branch}'"
                ),
            });
        }
        None => {}
    }
    entry
        .datasets
        .insert(dataset.to_string(), branch.to_string());
    let name = entry.name.clone();
    if let Err(err) = self
        .put_registry(&cache, &name, "registry_dynamic_branch")
        .await
    {
        // Keep the cache in line with what was persisted: the key was absent
        // before this call (checked above), so removing it restores the
        // previous state exactly.
        if let Some(entry) = cache.forks.values_mut().find(|f| f.id == fork_id) {
            entry.datasets.remove(dataset);
        }
        return Err(err);
    }
    Ok(())
}
/// Overwrites the persisted schema overlay of fork `id` with `delta`.
///
/// The fork's name is looked up only to label a possible error; when the id
/// is not in the registry a `<fork:ID>` placeholder is used and the overlay
/// is written anyway.
pub async fn update_schema_overlay(
    &self,
    id: &ForkId,
    delta: &SchemaDelta,
) -> Result<(), UniError> {
    let registry = self.inner.cache.lock().await;
    let name = match registry.forks.values().find(|fork| fork.id == *id) {
        Some(fork) => fork.name.clone(),
        None => format!("<fork:{id}>"),
    };
    // The cache lock is not held during the overlay write.
    drop(registry);
    self.put_schema_overlay(id, delta, &name).await
}
pub async fn load_schema_overlay(&self, id: &ForkId) -> Result<SchemaDelta, UniError> {
match get_with_timeout(&self.inner.store, &schema_overlay_path(id), DEFAULT_TIMEOUT).await {
Ok(result) => {
let bytes = result
.bytes()
.await
.map_err(|e| lifecycle("<schema-overlay>", "schema_overlay", e))?;
serde_json::from_slice(&bytes).map_err(|e| UniError::ForkCorruptRegistry {
message: format!("parse fork_schemas/{id}.json: {e}"),
})
}
Err(_) => Ok(SchemaDelta::empty()),
}
}
/// First phase of fork creation: inserts `info` (expected to be `Pending`)
/// into the registry and persists it.
///
/// If persisting fails, the entry is removed from the in-memory cache again.
/// Otherwise the cache would keep a phantom pending fork that makes a retry
/// fail with `ForkAlreadyExists` and counts against the fork cap.
///
/// # Errors
/// * `UniError::ForkAlreadyExists` when the name is already registered.
/// * `UniError::ForkBudgetExceeded` when a cap is set and the registry
///   already holds that many entries (entries of every status are counted).
/// * `UniError::ForkLifecycle` when the registry cannot be persisted.
pub async fn begin_create(&self, info: ForkInfo) -> Result<(), UniError> {
    debug_assert_eq!(info.status, ForkStatus::Pending);
    let mut cache = self.inner.cache.lock().await;
    if cache.forks.contains_key(&info.name) {
        return Err(UniError::ForkAlreadyExists {
            name: info.name.clone(),
        });
    }
    // Copy the cap out so the `max_forks` lock is released immediately.
    let cap = *self.inner.max_forks.lock().await;
    if let Some(cap) = cap {
        let current = cache.forks.len();
        if current >= cap {
            return Err(UniError::ForkBudgetExceeded { current, max: cap });
        }
    }
    let name = info.name.clone();
    cache.forks.insert(name.clone(), info);
    if let Err(err) = self.put_registry(&cache, &name, "registry_pending").await {
        // The name was absent before this call (checked above), so removing
        // it restores the previous cache state exactly.
        cache.forks.remove(&name);
        return Err(err);
    }
    Ok(())
}
/// Second phase of fork creation: stores `datasets` on the entry, marks it
/// `Active`, persists the registry, then writes an empty schema overlay.
///
/// Returns the entry as re-read from the registry after both writes.
///
/// # Errors
/// `UniError::ForkNotFound` when `name` is not registered, or the error from
/// persisting the registry / overlay.
pub async fn finish_create(
    &self,
    name: &str,
    datasets: BTreeMap<String, String>,
) -> Result<ForkInfo, UniError> {
    let mut registry = self.inner.cache.lock().await;
    let id = match registry.forks.get_mut(name) {
        Some(entry) => {
            entry.datasets = datasets;
            entry.status = ForkStatus::Active;
            entry.id
        }
        None => {
            return Err(UniError::ForkNotFound {
                name: name.to_string(),
            });
        }
    };
    self.put_registry(&registry, name, "registry_active").await?;
    // The overlay write and the re-read below happen without the cache lock.
    drop(registry);

    let empty_overlay = SchemaDelta::empty();
    self.put_schema_overlay(&id, &empty_overlay, name).await?;
    let info = self.get(name).await?;
    debug!(fork_id = %id, fork_name = %name, "fork active");
    Ok(info)
}
/// Undoes a creation: removes `name` from the registry (if present),
/// persists the registry, and then best-effort deletes the schema overlay of
/// the removed fork.
///
/// The registry is persisted even when `name` was not registered.
///
/// # Errors
/// The error from persisting the registry; in that case the overlay is not
/// touched.
pub async fn rollback_create(&self, name: &str) -> Result<(), UniError> {
    let mut registry = self.inner.cache.lock().await;
    let removed_id = registry.forks.remove(name).map(|info| info.id);
    self.put_registry(&registry, name, "registry_pending").await?;
    drop(registry);

    if let Some(id) = removed_id {
        self.delete_schema_overlay(&id).await;
    }
    Ok(())
}
/// First phase of dropping the fork called `name`.
///
/// Steps, in order: clone the registry entry under the cache lock, release
/// the lock, refuse if any holder guard is still alive, write the tombstone
/// object, then re-acquire the cache lock, mark the entry `Tombstoned` and
/// persist the registry.
///
/// Returns the clone taken *before* the status change, so its `status` field
/// still holds the pre-drop value.
///
/// # Errors
/// * `UniError::ForkNotFound` when `name` is not registered.
/// * `UniError::ForkInUse` when the holder count for the fork is non-zero.
/// * The error from writing the tombstone or persisting the registry.
pub async fn begin_drop(&self, name: &str) -> Result<ForkInfo, UniError> {
// Clone the entry inside a block so the cache lock is released before the
// holder check and the tombstone write.
let info = {
let cache = self.inner.cache.lock().await;
cache
.forks
.get(name)
.cloned()
.ok_or_else(|| UniError::ForkNotFound {
name: name.to_string(),
})?
};
let holders = self.holder_count(info.id);
if holders > 0 {
return Err(UniError::ForkInUse {
name: info.name.clone(),
holder_count: holders,
});
}
// The tombstone is written before the registry is updated.
self.put_tombstone(&info).await?;
// NOTE(review): the cache lock was not held since the lookup above. If the
// entry disappeared in between, the tombstone stays written, the registry is
// left untouched and `Ok` is still returned — presumably callers serialize
// through `name_lock`; confirm.
let mut cache = self.inner.cache.lock().await;
if let Some(entry) = cache.forks.get_mut(name) {
entry.status = ForkStatus::Tombstoned;
self.put_registry(&cache, name, "tombstone").await?;
}
Ok(info)
}
/// Second phase of a drop: removes the entry named `info.name` from the
/// registry and persists it, then deletes the tombstone and the schema
/// overlay of `info.id` (both best-effort).
///
/// # Errors
/// The error from persisting the registry; in that case neither file is
/// deleted.
pub async fn finish_drop(&self, info: &ForkInfo) -> Result<(), UniError> {
    let mut registry = self.inner.cache.lock().await;
    registry.forks.remove(&info.name);
    self.put_registry(&registry, &info.name, "registry_clear")
        .await?;
    // File clean-up happens without the cache lock.
    drop(registry);

    self.delete_tombstone(&info.id, &info.name).await?;
    self.delete_schema_overlay(&info.id).await;
    Ok(())
}
/// Lists every object under `catalog/fork_tombstones` and parses each one as
/// a `ForkInfo`.
///
/// # Errors
/// `UniError::ForkLifecycle` (stage `recovery`) when listing or reading
/// fails; `UniError::ForkCorruptRegistry` when a tombstone is not valid JSON.
/// The first failure aborts the whole listing.
pub async fn list_tombstones(&self) -> Result<Vec<ForkInfo>, UniError> {
    let store = &self.inner.store;
    let prefix = ObjectStorePath::from("catalog/fork_tombstones");
    let listing = list_with_timeout(store, Some(&prefix), DEFAULT_TIMEOUT)
        .await
        .map_err(|e| lifecycle_anyhow("<tombstones>", "recovery", e))?;

    let mut tombstones = Vec::new();
    for object in listing {
        let location = &object.location;
        let fetched = get_with_timeout(store, location, DEFAULT_TIMEOUT)
            .await
            .map_err(|e| lifecycle_anyhow("<tombstones>", "recovery", e))?;
        let raw = fetched
            .bytes()
            .await
            .map_err(|e| lifecycle("<tombstones>", "recovery", e))?;
        match serde_json::from_slice::<ForkInfo>(&raw) {
            Ok(info) => tombstones.push(info),
            Err(e) => {
                return Err(UniError::ForkCorruptRegistry {
                    message: format!("parse tombstone {}: {e}", location),
                });
            }
        }
    }
    Ok(tombstones)
}
}
#[cfg(test)]
mod tests {
use super::*;
use object_store::ObjectStoreExt;
use object_store::local::LocalFileSystem;
use tempfile::TempDir;
use uni_common::core::fork::ForkId;
/// Builds a registry handle on top of a brand-new temporary directory. The
/// `TempDir` is returned so the directory lives as long as the test does.
async fn fresh_handle() -> (TempDir, ForkRegistryHandle) {
    let tmp = TempDir::new().unwrap();
    let local_fs = LocalFileSystem::new_with_prefix(tmp.path()).unwrap();
    let store: Arc<dyn ObjectStore> = Arc::new(local_fs);
    let handle = ForkRegistryHandle::load(store).await.unwrap();
    (tmp, handle)
}
/// Loading from an empty store yields a registry without forks.
#[tokio::test]
async fn empty_registry_loads_clean() {
    let (_dir, registry) = fresh_handle().await;
    let snap = registry.snapshot().await;
    assert!(snap.forks.is_empty());
}
#[tokio::test]
async fn begin_create_persists_and_rejects_duplicate() {
let (_dir, h) = fresh_handle().await;
let info = ForkInfo::new_pending(ForkId::new(), "scenario_a", "snap-1", 17);
h.begin_create(info.clone()).await.unwrap();
let store = h.inner.store.clone();
let h2 = ForkRegistryHandle::load(store).await.unwrap();
let snap = h2.snapshot().await;
assert_eq!(snap.forks.len(), 1);
assert_eq!(snap.forks["scenario_a"].status, ForkStatus::Pending);
let dup = ForkInfo::new_pending(ForkId::new(), "scenario_a", "snap-1", 17);
let err = h.begin_create(dup).await.unwrap_err();
assert!(matches!(err, UniError::ForkAlreadyExists { .. }));
}
/// `finish_create` must flip the entry to `Active`, keep the dataset map and
/// leave an empty schema overlay behind.
#[tokio::test]
async fn finish_create_promotes_and_writes_overlay() {
    let (_dir, registry) = fresh_handle().await;
    let pending = ForkInfo::new_pending(ForkId::new(), "scenario_b", "snap-1", 1);
    registry.begin_create(pending).await.unwrap();

    let mut datasets = BTreeMap::new();
    datasets.insert(
        String::from("vertices_Person"),
        String::from("fork-b__v_Person"),
    );
    let promoted = registry
        .finish_create("scenario_b", datasets)
        .await
        .unwrap();
    assert_eq!(promoted.status, ForkStatus::Active);
    assert_eq!(promoted.datasets.len(), 1);

    let overlay = registry.load_schema_overlay(&promoted.id).await.unwrap();
    assert!(overlay.is_empty());
}
/// Walks both drop phases: after `begin_drop` the entry is tombstoned and one
/// tombstone file exists; after `finish_drop` both are gone.
#[tokio::test]
async fn drop_2pc_clears_registry_and_files() {
    let (_dir, registry) = fresh_handle().await;
    let pending = ForkInfo::new_pending(ForkId::new(), "scenario_c", "snap-1", 1);
    registry.begin_create(pending).await.unwrap();
    let active = registry
        .finish_create("scenario_c", BTreeMap::new())
        .await
        .unwrap();

    // Phase one: tombstone written, entry marked.
    let tombstoned = registry.begin_drop("scenario_c").await.unwrap();
    assert_eq!(tombstoned.id, active.id);
    let after_begin = registry.snapshot().await;
    assert_eq!(
        after_begin.forks["scenario_c"].status,
        ForkStatus::Tombstoned
    );
    let pending_tombstones = registry.list_tombstones().await.unwrap();
    assert_eq!(pending_tombstones.len(), 1);

    // Phase two: entry and tombstone removed.
    registry.finish_drop(&active).await.unwrap();
    let after_finish = registry.snapshot().await;
    assert!(!after_finish.forks.contains_key("scenario_c"));
    let remaining_tombstones = registry.list_tombstones().await.unwrap();
    assert!(remaining_tombstones.is_empty());
}
/// `begin_drop` must fail with `ForkInUse` while a holder guard is alive and
/// succeed once the guard has been dropped.
#[tokio::test]
async fn drop_blocked_when_holders_alive() {
    let (_dir, registry) = fresh_handle().await;
    let pending = ForkInfo::new_pending(ForkId::new(), "scenario_d", "snap-1", 1);
    registry.begin_create(pending).await.unwrap();
    let active = registry
        .finish_create("scenario_d", BTreeMap::new())
        .await
        .unwrap();

    let holder = registry.register_holder(active.id);
    let err = registry.begin_drop("scenario_d").await.unwrap_err();
    match err {
        UniError::ForkInUse { holder_count, .. } => assert_eq!(holder_count, 1),
        other => panic!("expected ForkInUse, got {other:?}"),
    }

    drop(holder);
    registry.begin_drop("scenario_d").await.unwrap();
}
/// Two concurrent `begin_create` calls for different names must both succeed
/// and both end up in the registry.
#[tokio::test]
async fn concurrent_distinct_creates_serialize_only_per_name() {
    let (_dir, registry) = fresh_handle().await;

    let mut tasks = Vec::new();
    for fork_name in ["fork-a", "fork-b"].iter().copied() {
        let handle = registry.clone();
        tasks.push(tokio::spawn(async move {
            handle
                .begin_create(ForkInfo::new_pending(ForkId::new(), fork_name, "snap-1", 1))
                .await
        }));
    }
    for task in tasks {
        task.await.unwrap().unwrap();
    }

    let snap = registry.snapshot().await;
    assert_eq!(snap.forks.len(), 2);
}
/// Asking twice for the lock of the same name returns the very same mutex.
#[tokio::test]
async fn name_lock_grants_exclusive_per_name() {
    let (_dir, registry) = fresh_handle().await;
    let first = registry.name_lock("scenario_e").await;
    let held = first.try_lock();
    assert!(held.is_ok());

    let second = registry.name_lock("scenario_e").await;
    assert!(Arc::ptr_eq(&first, &second));
}
#[tokio::test]
async fn corrupt_registry_surfaces_typed_error() {
let dir = TempDir::new().unwrap();
let store: Arc<dyn ObjectStore> =
Arc::new(LocalFileSystem::new_with_prefix(dir.path()).unwrap());
put_with_timeout(
&store,
®istry_path(),
Bytes::from_static(b"{ not json"),
DEFAULT_TIMEOUT,
)
.await
.unwrap();
let err = ForkRegistryHandle::load(store).await.unwrap_err();
assert!(matches!(err, UniError::ForkCorruptRegistry { .. }));
}
/// Rolling back a fork that already reached `finish_create` must remove both
/// its registry entry and its schema overlay file.
#[tokio::test]
async fn rollback_create_after_finish_cleans_overlay_file() {
    let (_dir, registry) = fresh_handle().await;
    let pending = ForkInfo::new_pending(ForkId::new(), "scenario_rb", "snap-1", 1);
    registry.begin_create(pending).await.unwrap();
    let active = registry
        .finish_create("scenario_rb", BTreeMap::new())
        .await
        .unwrap();

    let overlay = schema_overlay_path(&active.id);
    let exists_before = registry.inner.store.head(&overlay).await.is_ok();
    assert!(exists_before);

    registry.rollback_create("scenario_rb").await.unwrap();
    let snap = registry.snapshot().await;
    assert!(!snap.forks.contains_key("scenario_rb"));

    let exists_after = registry.inner.store.head(&overlay).await.is_ok();
    assert!(
        !exists_after,
        "overlay file at {} should be deleted after rollback",
        overlay
    );
}
/// An active fork created through one handle must still be there, and still
/// `Active`, when a new handle is loaded from the same store.
#[tokio::test]
async fn restart_preserves_active_forks() {
    let tmp = TempDir::new().unwrap();
    let local_fs = LocalFileSystem::new_with_prefix(tmp.path()).unwrap();
    let store: Arc<dyn ObjectStore> = Arc::new(local_fs);

    {
        // First "process": create and activate the fork, then drop the handle.
        let before = ForkRegistryHandle::load(Arc::clone(&store)).await.unwrap();
        let pending = ForkInfo::new_pending(ForkId::new(), "persist_me", "snap-1", 1);
        before.begin_create(pending).await.unwrap();
        before
            .finish_create("persist_me", BTreeMap::new())
            .await
            .unwrap();
    }

    let after = ForkRegistryHandle::load(store).await.unwrap();
    let snap = after.snapshot().await;
    assert_eq!(snap.forks.len(), 1);
    assert_eq!(snap.forks["persist_me"].status, ForkStatus::Active);
}
/// After 100 tasks have each registered and released a holder, the count must
/// be back at zero so that `begin_drop` goes through.
#[tokio::test]
async fn holder_count_round_trips_under_concurrent_register_drop() {
    let (_dir, registry) = fresh_handle().await;
    let pending = ForkInfo::new_pending(ForkId::new(), "concurrent", "snap-1", 1);
    registry.begin_create(pending).await.unwrap();
    let active = registry
        .finish_create("concurrent", BTreeMap::new())
        .await
        .unwrap();

    let fork_id = active.id;
    let tasks: Vec<_> = (0..100)
        .map(|_| {
            let handle = registry.clone();
            tokio::spawn(async move {
                let _guard = handle.register_holder(fork_id);
                tokio::task::yield_now().await;
            })
        })
        .collect();
    for task in tasks {
        task.await.unwrap();
    }

    registry.begin_drop("concurrent").await.unwrap();
}
}