use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
use super::provider::{BucketLogError, BucketLogProvider};
use crate::linked_data::Link;
/// In-memory implementation of `BucketLogProvider`.
///
/// All state lives behind an `Arc<RwLock<...>>`, so cloning the provider is
/// cheap and every clone shares (and observes) the same underlying log.
#[derive(Debug, Clone)]
pub struct MemoryBucketLogProvider {
// Shared, lock-protected state; `RwLock` allows concurrent readers.
inner: Arc<RwLock<MemoryBucketLogProviderInner>>,
}
/// Backing state for `MemoryBucketLogProvider`; every map is keyed by bucket id.
#[derive(Debug, Default)]
struct MemoryBucketLogProviderInner {
// Links recorded at each height (several links may share one height).
entries: HashMap<Uuid, HashMap<u64, Vec<Link>>>,
// Greatest height ever appended per bucket (high-water mark for `height()`).
max_heights: HashMap<Uuid, u64>,
// Reverse index: link -> every height it appears at (serves `has()`).
link_index: HashMap<Uuid, HashMap<Link, Vec<u64>>>,
// Most recently supplied name per bucket (overwritten on every append).
names: HashMap<Uuid, String>,
// Whether each link was appended with the `published` flag set.
published: HashMap<Uuid, HashMap<Link, bool>>,
}
/// Errors specific to the in-memory provider.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum MemoryBucketLogProviderError {
// Only produced when an internal `RwLock` is poisoned by a panicking holder.
#[error("memory provider error: {0}")]
Internal(String),
}
impl MemoryBucketLogProvider {
    /// Creates an empty provider with no buckets recorded.
    pub fn new() -> Self {
        let state = MemoryBucketLogProviderInner::default();
        Self {
            inner: Arc::new(RwLock::new(state)),
        }
    }
}
impl Default for MemoryBucketLogProvider {
/// Equivalent to [`MemoryBucketLogProvider::new`].
fn default() -> Self {
Self::new()
}
}
/// Builds the provider error returned when an internal `RwLock` is poisoned.
///
/// `kind` is either `"read"` or `"write"` and only affects the message text.
fn lock_error(
    kind: &str,
    err: impl std::fmt::Display,
) -> BucketLogError<MemoryBucketLogProviderError> {
    BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
        "failed to acquire {kind} lock: {err}"
    )))
}

#[async_trait]
impl BucketLogProvider for MemoryBucketLogProvider {
    type Error = MemoryBucketLogProviderError;

    /// Returns `true` once at least one entry has been successfully appended
    /// for `id`.
    async fn exists(&self, id: Uuid) -> Result<bool, BucketLogError<Self::Error>> {
        let inner = self.inner.read().map_err(|e| lock_error("read", e))?;
        Ok(inner.entries.contains_key(&id))
    }

    /// Returns every link recorded at exactly `height` for bucket `id`
    /// (empty when the bucket or height is unknown).
    async fn heads(&self, id: Uuid, height: u64) -> Result<Vec<Link>, BucketLogError<Self::Error>> {
        let inner = self.inner.read().map_err(|e| lock_error("read", e))?;
        Ok(inner
            .entries
            .get(&id)
            .and_then(|heights| heights.get(&height))
            .cloned()
            .unwrap_or_default())
    }

    /// Appends `current` at `height`, chained to `previous` at `height - 1`.
    ///
    /// Validation rules:
    /// * the same link may not be appended twice at the same height
    ///   (`BucketLogError::Conflict`);
    /// * a genesis entry (height 0) must have no `previous`;
    /// * a non-genesis entry must name a `previous` link that already exists
    ///   at `height - 1` (`BucketLogError::InvalidAppend` otherwise).
    ///
    /// All validation happens before any mutation, so a rejected append
    /// leaves no trace: no phantom bucket visible to `exists()`/`list_buckets()`
    /// and no name recorded. (The original mutated first, then validated.)
    async fn append(
        &self,
        id: Uuid,
        name: String,
        current: Link,
        previous: Option<Link>,
        height: u64,
        published: bool,
    ) -> Result<(), BucketLogError<Self::Error>> {
        let mut inner = self.inner.write().map_err(|e| lock_error("write", e))?;

        // --- validation phase (read-only) -------------------------------
        let existing = inner.entries.get(&id);

        // Duplicate link at the same height is a conflict.
        if let Some(links) = existing.and_then(|heights| heights.get(&height)) {
            if links.contains(&current) {
                return Err(BucketLogError::Conflict);
            }
        }

        match &previous {
            Some(prev_link) => {
                // A genesis entry cannot have a parent.
                if height == 0 {
                    return Err(BucketLogError::InvalidAppend(
                        current,
                        prev_link.clone(),
                        height,
                    ));
                }
                // The parent must already be recorded exactly one height below.
                let expected_prev_height = height - 1;
                let prev_exists = existing
                    .and_then(|heights| heights.get(&expected_prev_height))
                    .map_or(false, |links| links.contains(prev_link));
                if !prev_exists {
                    return Err(BucketLogError::InvalidAppend(
                        current,
                        prev_link.clone(),
                        expected_prev_height,
                    ));
                }
            }
            // Omitting the parent is only legal for the genesis entry.
            None => {
                if height != 0 {
                    return Err(BucketLogError::InvalidAppend(
                        current,
                        Link::default(),
                        height,
                    ));
                }
            }
        }

        // --- commit phase (all checks passed) ---------------------------
        inner.names.insert(id, name);
        inner
            .entries
            .entry(id)
            .or_default()
            .entry(height)
            .or_default()
            .push(current.clone());

        // Track the high-water mark served by `height()`.
        let max_height = inner.max_heights.entry(id).or_insert(height);
        if height > *max_height {
            *max_height = height;
        }

        // Maintain the reverse index served by `has()`.
        inner
            .link_index
            .entry(id)
            .or_default()
            .entry(current.clone())
            .or_default()
            .push(height);

        inner
            .published
            .entry(id)
            .or_default()
            .insert(current, published);
        Ok(())
    }

    /// Returns the greatest height ever appended for `id`, or
    /// `HeadNotFound(0)` when the bucket has no entries.
    async fn height(&self, id: Uuid) -> Result<u64, BucketLogError<Self::Error>> {
        let inner = self.inner.read().map_err(|e| lock_error("read", e))?;
        inner
            .max_heights
            .get(&id)
            .copied()
            .ok_or(BucketLogError::HeadNotFound(0))
    }

    /// Returns every height at which `link` was appended to bucket `id`
    /// (empty when the bucket or link is unknown).
    async fn has(&self, id: Uuid, link: Link) -> Result<Vec<u64>, BucketLogError<Self::Error>> {
        let inner = self.inner.read().map_err(|e| lock_error("read", e))?;
        Ok(inner
            .link_index
            .get(&id)
            .and_then(|links| links.get(&link))
            .cloned()
            .unwrap_or_default())
    }

    /// Lists every bucket id that has at least one successfully appended entry.
    async fn list_buckets(&self) -> Result<Vec<Uuid>, BucketLogError<Self::Error>> {
        let inner = self.inner.read().map_err(|e| lock_error("read", e))?;
        Ok(inner.entries.keys().copied().collect())
    }

    /// Returns the published link with the greatest height, if any.
    ///
    /// NOTE(review): when several published links share the maximal height,
    /// which one is returned depends on `HashMap` iteration order and is
    /// therefore unspecified — confirm callers tolerate this.
    async fn latest_published(
        &self,
        id: Uuid,
    ) -> Result<Option<(Link, u64)>, BucketLogError<Self::Error>> {
        let inner = self.inner.read().map_err(|e| lock_error("read", e))?;
        let Some(published_map) = inner.published.get(&id) else {
            return Ok(None);
        };
        let Some(entries) = inner.entries.get(&id) else {
            return Ok(None);
        };
        let mut best: Option<(Link, u64)> = None;
        for (&height, links) in entries {
            for link in links {
                let is_published = published_map.get(link).copied().unwrap_or(false);
                let improves = best
                    .as_ref()
                    .map_or(true, |(_, best_height)| height > *best_height);
                if is_published && improves {
                    best = Some((link.clone(), height));
                }
            }
        }
        Ok(best)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use iroh_blobs::Hash;

    /// Builds a deterministic test link whose hash bytes are all `byte`.
    fn mk_link(byte: u8) -> Link {
        Link::new(0x55, Hash::from_bytes([byte; 32]))
    }

    #[tokio::test]
    async fn test_genesis_append() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = mk_link(1);

        let outcome = provider
            .append(bucket, "test".to_string(), genesis.clone(), None, 0, false)
            .await;
        assert!(outcome.is_ok());

        assert_eq!(provider.height(bucket).await.unwrap(), 0);
        assert_eq!(provider.heads(bucket, 0).await.unwrap(), vec![genesis]);
    }

    #[tokio::test]
    async fn test_conflict() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = mk_link(1);

        provider
            .append(bucket, "test".to_string(), genesis.clone(), None, 0, false)
            .await
            .unwrap();

        // Appending the exact same link at the same height must be rejected.
        let duplicate = provider
            .append(bucket, "test".to_string(), genesis, None, 0, false)
            .await;
        assert!(matches!(duplicate, Err(BucketLogError::Conflict)));
    }

    #[tokio::test]
    async fn test_invalid_append() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = mk_link(1);
        let orphan = mk_link(2);

        provider
            .append(bucket, "test".to_string(), genesis, None, 0, false)
            .await
            .unwrap();

        // `orphan` names itself as parent, but it was never stored at height 0.
        let bad = provider
            .append(
                bucket,
                "test".to_string(),
                orphan.clone(),
                Some(orphan),
                1,
                false,
            )
            .await;
        assert!(matches!(bad, Err(BucketLogError::InvalidAppend(_, _, _))));
    }

    #[tokio::test]
    async fn test_valid_chain() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let first = mk_link(1);
        let second = mk_link(2);

        provider
            .append(bucket, "test".to_string(), first.clone(), None, 0, false)
            .await
            .unwrap();
        provider
            .append(bucket, "test".to_string(), second.clone(), Some(first), 1, false)
            .await
            .unwrap();

        assert_eq!(provider.height(bucket).await.unwrap(), 1);
        assert_eq!(provider.has(bucket, second).await.unwrap(), vec![1]);
    }

    #[tokio::test]
    async fn test_latest_published() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let first = mk_link(1);
        let second = mk_link(2);
        let third = mk_link(3);

        // An unpublished genesis yields no published head.
        provider
            .append(bucket, "test".to_string(), first.clone(), None, 0, false)
            .await
            .unwrap();
        assert!(provider.latest_published(bucket).await.unwrap().is_none());

        // Publishing at height 1 makes that link the latest published head.
        provider
            .append(
                bucket,
                "test".to_string(),
                second.clone(),
                Some(first.clone()),
                1,
                true,
            )
            .await
            .unwrap();
        let (head, head_height) = provider.latest_published(bucket).await.unwrap().unwrap();
        assert_eq!(head, second);
        assert_eq!(head_height, 1);

        // An unpublished entry above it does not displace the published head.
        provider
            .append(
                bucket,
                "test".to_string(),
                third.clone(),
                Some(second.clone()),
                2,
                false,
            )
            .await
            .unwrap();
        let (head, head_height) = provider.latest_published(bucket).await.unwrap().unwrap();
        assert_eq!(head, second);
        assert_eq!(head_height, 1);
    }
}