use crate::store_utils::{DEFAULT_TIMEOUT, get_with_timeout, put_with_timeout};
use anyhow::Result;
use bytes::Bytes;
use object_store::path::Path;
use object_store::{ObjectStore, PutMode, PutOptions, UpdateVersion};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
use uni_common::core::id::{Eid, Vid};
/// Persisted JSON manifest holding, for each ID kind, the exclusive upper bound of the IDs
/// reserved so far.
///
/// A freshly constructed allocator resumes from these values, so every ID below them is
/// treated as already consumed.
#[derive(Clone, Default, Deserialize, Serialize)]
struct CounterManifest {
    /// Exclusive upper bound of reserved vertex IDs.
    next_vid_batch: u64,
    /// Exclusive upper bound of reserved edge IDs.
    next_eid_batch: u64,
}
/// Mutable allocator state; lives behind the `Mutex` inside `IdAllocator`.
struct AllocatorState {
    /// In-memory copy of the counter manifest; its fields are the current reservation limits.
    manifest: CounterManifest,
    /// E-tag of the stored manifest when the store reported one; used for conditional writes.
    manifest_version: Option<String>,
    /// Next vertex ID to hand out.
    current_vid: u64,
    /// Next edge ID to hand out.
    current_eid: u64,
}
/// Hands out increasing vertex and edge IDs, reserving them in batches recorded in a manifest
/// object so that a restarted allocator resumes past the reserved range.
pub struct IdAllocator {
    /// Object store that holds the counter manifest.
    store: Arc<dyn ObjectStore>,
    /// Location of the counter manifest inside `store`.
    path: Path,
    /// Cursors plus cached manifest; an async mutex because it is held across store I/O.
    state: Mutex<AllocatorState>,
    /// Controls how far ahead of the cursor each manifest write reserves.
    batch_size: u64,
}
impl IdAllocator {
pub async fn new(store: Arc<dyn ObjectStore>, path: Path, batch_size: u64) -> Result<Self> {
let (manifest, version) = match get_with_timeout(&store, &path, DEFAULT_TIMEOUT).await {
Ok(get_result) => {
let version = get_result.meta.e_tag.clone();
let bytes = get_result.bytes().await?;
let manifest: CounterManifest = serde_json::from_slice(&bytes)?;
(manifest, version)
}
Err(e) if e.to_string().contains("not found") => (CounterManifest::default(), None),
Err(e) => return Err(e),
};
let current_vid = manifest.next_vid_batch;
let current_eid = manifest.next_eid_batch;
Ok(Self {
store,
path,
state: Mutex::new(AllocatorState {
manifest,
manifest_version: version,
current_vid,
current_eid,
}),
batch_size,
})
}
pub async fn allocate_vid(&self) -> Result<Vid> {
let mut state = self.state.lock().await;
if state.current_vid >= state.manifest.next_vid_batch {
state.manifest.next_vid_batch = state.current_vid + self.batch_size;
self.persist_manifest(&mut state).await?;
}
let vid = Vid::new(state.current_vid);
state.current_vid += 1;
Ok(vid)
}
pub async fn allocate_vids(&self, count: usize) -> Result<Vec<Vid>> {
let mut state = self.state.lock().await;
let needed = count as u64;
if state.current_vid + needed > state.manifest.next_vid_batch {
state.manifest.next_vid_batch = state.current_vid + needed + self.batch_size;
self.persist_manifest(&mut state).await?;
}
let vids: Vec<Vid> = (0..count)
.map(|i| Vid::new(state.current_vid + i as u64))
.collect();
state.current_vid += needed;
Ok(vids)
}
pub async fn allocate_eid(&self) -> Result<Eid> {
let mut state = self.state.lock().await;
if state.current_eid >= state.manifest.next_eid_batch {
state.manifest.next_eid_batch = state.current_eid + self.batch_size;
self.persist_manifest(&mut state).await?;
}
let eid = Eid::new(state.current_eid);
state.current_eid += 1;
Ok(eid)
}
pub async fn allocate_eids(&self, count: usize) -> Result<Vec<Eid>> {
let mut state = self.state.lock().await;
let needed = count as u64;
if state.current_eid + needed > state.manifest.next_eid_batch {
state.manifest.next_eid_batch = state.current_eid + needed + self.batch_size;
self.persist_manifest(&mut state).await?;
}
let eids: Vec<Eid> = (0..count)
.map(|i| Eid::new(state.current_eid + i as u64))
.collect();
state.current_eid += needed;
Ok(eids)
}
pub async fn current_vid(&self) -> u64 {
self.state.lock().await.current_vid
}
pub async fn current_eid(&self) -> u64 {
self.state.lock().await.current_eid
}
async fn persist_manifest(&self, state: &mut AllocatorState) -> Result<()> {
let json = serde_json::to_vec_pretty(&state.manifest)?;
let bytes = Bytes::from(json);
let put_result = if let Some(version) = &state.manifest_version {
let opts: PutOptions = PutMode::Update(UpdateVersion {
e_tag: Some(version.clone()),
version: None,
})
.into();
match tokio::time::timeout(
DEFAULT_TIMEOUT,
self.store.put_opts(&self.path, bytes.clone().into(), opts),
)
.await
{
Ok(Ok(result)) => result,
Ok(Err(e))
if e.to_string().contains("not yet implemented")
|| e.to_string().contains("not supported") =>
{
put_with_timeout(&self.store, &self.path, bytes, DEFAULT_TIMEOUT).await?
}
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
return Err(anyhow::anyhow!(
"Object store put_opts timed out after {:?}",
DEFAULT_TIMEOUT
));
}
}
} else {
let opts: PutOptions = PutMode::Create.into();
match tokio::time::timeout(
DEFAULT_TIMEOUT,
self.store.put_opts(&self.path, bytes.clone().into(), opts),
)
.await
{
Ok(Ok(result)) => result,
Ok(Err(object_store::Error::AlreadyExists { .. })) => {
put_with_timeout(&self.store, &self.path, bytes, DEFAULT_TIMEOUT).await?
}
Ok(Err(e)) if e.to_string().contains("not yet implemented") => {
put_with_timeout(&self.store, &self.path, bytes, DEFAULT_TIMEOUT).await?
}
Ok(Err(e)) => return Err(e.into()),
Err(_) => {
return Err(anyhow::anyhow!(
"Object store put_opts timed out after {:?}",
DEFAULT_TIMEOUT
));
}
}
};
state.manifest_version = put_result.e_tag;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::memory::InMemory;

    /// Single vertex allocations are sequential from zero.
    #[tokio::test]
    async fn test_allocate_vid() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        let allocator = IdAllocator::new(store, path, 100).await.unwrap();
        let vid1 = allocator.allocate_vid().await.unwrap();
        let vid2 = allocator.allocate_vid().await.unwrap();
        let vid3 = allocator.allocate_vid().await.unwrap();
        assert_eq!(vid1.as_u64(), 0);
        assert_eq!(vid2.as_u64(), 1);
        assert_eq!(vid3.as_u64(), 2);
    }

    /// Single edge allocations are sequential from zero.
    #[tokio::test]
    async fn test_allocate_eid() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        let allocator = IdAllocator::new(store, path, 100).await.unwrap();
        let eid1 = allocator.allocate_eid().await.unwrap();
        let eid2 = allocator.allocate_eid().await.unwrap();
        assert_eq!(eid1.as_u64(), 0);
        assert_eq!(eid2.as_u64(), 1);
    }

    /// Bulk vertex allocation returns a contiguous range and advances the cursor past it.
    #[tokio::test]
    async fn test_allocate_many() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        let allocator = IdAllocator::new(store, path, 100).await.unwrap();
        let vids = allocator.allocate_vids(5).await.unwrap();
        assert_eq!(vids.len(), 5);
        for (i, vid) in vids.iter().enumerate() {
            assert_eq!(vid.as_u64(), i as u64);
        }
        let next = allocator.allocate_vid().await.unwrap();
        assert_eq!(next.as_u64(), 5);
    }

    /// Bulk edge allocation returns a contiguous range and advances the cursor past it.
    #[tokio::test]
    async fn test_allocate_many_eids() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        let allocator = IdAllocator::new(store, path, 100).await.unwrap();
        let eids = allocator.allocate_eids(3).await.unwrap();
        let raw: Vec<u64> = eids.into_iter().map(|eid| eid.as_u64()).collect();
        assert_eq!(raw, vec![0, 1, 2]);
        assert_eq!(allocator.allocate_eid().await.unwrap().as_u64(), 3);
    }

    /// Requesting zero IDs returns nothing and leaves the cursor untouched.
    #[tokio::test]
    async fn test_allocate_zero_ids() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        let allocator = IdAllocator::new(store, path, 100).await.unwrap();
        assert!(allocator.allocate_vids(0).await.unwrap().is_empty());
        assert!(allocator.allocate_eids(0).await.unwrap().is_empty());
        assert_eq!(allocator.current_vid().await, 0);
        assert_eq!(allocator.current_eid().await, 0);
        assert_eq!(allocator.allocate_vid().await.unwrap().as_u64(), 0);
    }

    /// Vertex and edge counters advance independently of each other.
    #[tokio::test]
    async fn test_vid_and_eid_counters_are_independent() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        let allocator = IdAllocator::new(store, path, 100).await.unwrap();
        allocator.allocate_vid().await.unwrap();
        allocator.allocate_vid().await.unwrap();
        let eid = allocator.allocate_eid().await.unwrap();
        assert_eq!(eid.as_u64(), 0);
        assert_eq!(allocator.current_vid().await, 2);
        assert_eq!(allocator.current_eid().await, 1);
    }

    /// A new allocator over the same manifest resumes past the previously reserved batch.
    #[tokio::test]
    async fn test_persistence() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        {
            let allocator = IdAllocator::new(store.clone(), path.clone(), 10)
                .await
                .unwrap();
            for _ in 0..15 {
                allocator.allocate_vid().await.unwrap();
            }
        }
        {
            let allocator = IdAllocator::new(store, path, 10).await.unwrap();
            let vid = allocator.allocate_vid().await.unwrap();
            assert_eq!(vid.as_u64(), 20);
        }
    }

    /// Bulk reservations (range plus batch headroom) survive a restart for both ID kinds.
    #[tokio::test]
    async fn test_bulk_reservation_persistence() {
        let store = Arc::new(InMemory::new());
        let path = Path::from("id_counters.json");
        {
            let allocator = IdAllocator::new(store.clone(), path.clone(), 10)
                .await
                .unwrap();
            allocator.allocate_vids(25).await.unwrap();
            allocator.allocate_eids(3).await.unwrap();
        }
        {
            let allocator = IdAllocator::new(store, path, 10).await.unwrap();
            assert_eq!(allocator.allocate_vid().await.unwrap().as_u64(), 35);
            assert_eq!(allocator.allocate_eid().await.unwrap().as_u64(), 13);
        }
    }
}