1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4use uuid::Uuid;
5
6use super::provider::{BucketLogError, BucketLogProvider};
7use crate::linked_data::Link;
8
/// In-memory [`BucketLogProvider`] whose state lives in hash maps behind an `Arc<RwLock<..>>`.
///
/// Cloning is cheap: every clone holds the same `Arc`, so all clones observe and mutate the
/// same state. Nothing is persisted; the data is dropped together with the last clone.
#[derive(Debug, Clone)]
pub struct MemoryBucketLogProvider {
    /// Shared, lock-protected provider state.
    inner: Arc<RwLock<MemoryBucketLogProviderInner>>,
}
14
/// Mutable state behind [`MemoryBucketLogProvider`]'s lock. Every map is keyed by bucket id.
#[derive(Debug, Default)]
struct MemoryBucketLogProviderInner {
    /// Log entries per bucket: height -> links appended at that height. A height can hold
    /// more than one link, because `append` only rejects an *identical* link at the same height.
    entries: HashMap<Uuid, HashMap<u64, Vec<Link>>>,
    /// Highest height appended so far for each bucket.
    max_heights: HashMap<Uuid, u64>,
    /// Reverse index per bucket: link -> every height at which that link was appended.
    link_index: HashMap<Uuid, HashMap<Link, Vec<u64>>>,
    /// Bucket name passed to `append`; overwritten by later `append` calls.
    names: HashMap<Uuid, String>,
    /// Published flag per bucket and link, as given by the most recent append of that link.
    published: HashMap<Uuid, HashMap<Link, bool>>,
}
29
/// Errors originating in the memory provider itself; surfaced as `BucketLogError::Provider`.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum MemoryBucketLogProviderError {
    /// Internal failure with a human-readable description. In this file it is only produced
    /// when the state lock cannot be acquired (i.e. the `RwLock` is poisoned).
    #[error("memory provider error: {0}")]
    Internal(String),
}
35
36impl MemoryBucketLogProvider {
37 pub fn new() -> Self {
38 Self {
39 inner: Arc::new(RwLock::new(MemoryBucketLogProviderInner::default())),
40 }
41 }
42}
43
44impl Default for MemoryBucketLogProvider {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50#[async_trait]
51impl BucketLogProvider for MemoryBucketLogProvider {
52 type Error = MemoryBucketLogProviderError;
53
54 async fn exists(&self, id: Uuid) -> Result<bool, BucketLogError<Self::Error>> {
55 let inner = self.inner.read().map_err(|e| {
56 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
57 "failed to acquire read lock: {}",
58 e
59 )))
60 })?;
61
62 Ok(inner.entries.contains_key(&id))
63 }
64
65 async fn heads(&self, id: Uuid, height: u64) -> Result<Vec<Link>, BucketLogError<Self::Error>> {
66 let inner = self.inner.read().map_err(|e| {
67 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
68 "failed to acquire read lock: {}",
69 e
70 )))
71 })?;
72
73 Ok(inner
74 .entries
75 .get(&id)
76 .and_then(|heights| heights.get(&height))
77 .cloned()
78 .unwrap_or_default())
79 }
80
81 async fn append(
82 &self,
83 id: Uuid,
84 name: String,
85 current: Link,
86 previous: Option<Link>,
87 height: u64,
88 published: bool,
89 ) -> Result<(), BucketLogError<Self::Error>> {
90 let mut inner = self.inner.write().map_err(|e| {
91 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
92 "failed to acquire write lock: {}",
93 e
94 )))
95 })?;
96
97 inner.names.insert(id, name);
99
100 let bucket_entries = inner.entries.entry(id).or_insert_with(HashMap::new);
102
103 if let Some(existing_links) = bucket_entries.get(&height) {
105 if existing_links.contains(¤t) {
106 return Err(BucketLogError::Conflict);
107 }
108 }
109
110 if let Some(prev_link) = &previous {
112 if height == 0 {
114 return Err(BucketLogError::InvalidAppend(
115 current,
116 prev_link.clone(),
117 height,
118 ));
119 }
120
121 let expected_prev_height = height - 1;
122 let prev_exists = bucket_entries
123 .get(&expected_prev_height)
124 .map(|links| links.contains(prev_link))
125 .unwrap_or(false);
126
127 if !prev_exists {
128 return Err(BucketLogError::InvalidAppend(
129 current,
130 prev_link.clone(),
131 expected_prev_height,
132 ));
133 }
134 } else {
135 if height != 0 {
137 return Err(BucketLogError::InvalidAppend(
139 current,
140 Link::default(), height,
142 ));
143 }
144 }
145
146 bucket_entries
148 .entry(height)
149 .or_insert_with(Vec::new)
150 .push(current.clone());
151
152 let current_max = inner.max_heights.get(&id).copied();
154 if current_max.is_none() || height > current_max.unwrap() {
155 inner.max_heights.insert(id, height);
156 }
157
158 inner
160 .link_index
161 .entry(id)
162 .or_insert_with(HashMap::new)
163 .entry(current.clone())
164 .or_insert_with(Vec::new)
165 .push(height);
166
167 inner
169 .published
170 .entry(id)
171 .or_insert_with(HashMap::new)
172 .insert(current, published);
173
174 Ok(())
175 }
176
177 async fn height(&self, id: Uuid) -> Result<u64, BucketLogError<Self::Error>> {
178 let inner = self.inner.read().map_err(|e| {
179 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
180 "failed to acquire read lock: {}",
181 e
182 )))
183 })?;
184
185 inner
186 .max_heights
187 .get(&id)
188 .copied()
189 .ok_or(BucketLogError::HeadNotFound(0))
190 }
191
192 async fn has(&self, id: Uuid, link: Link) -> Result<Vec<u64>, BucketLogError<Self::Error>> {
193 let inner = self.inner.read().map_err(|e| {
194 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
195 "failed to acquire read lock: {}",
196 e
197 )))
198 })?;
199
200 Ok(inner
201 .link_index
202 .get(&id)
203 .and_then(|links| links.get(&link))
204 .cloned()
205 .unwrap_or_default())
206 }
207
208 async fn list_buckets(&self) -> Result<Vec<Uuid>, BucketLogError<Self::Error>> {
209 let inner = self.inner.read().map_err(|e| {
210 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
211 "failed to acquire read lock: {}",
212 e
213 )))
214 })?;
215
216 Ok(inner.entries.keys().copied().collect())
217 }
218
219 async fn latest_published(
220 &self,
221 id: Uuid,
222 ) -> Result<Option<(Link, u64)>, BucketLogError<Self::Error>> {
223 let inner = self.inner.read().map_err(|e| {
224 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
225 "failed to acquire read lock: {}",
226 e
227 )))
228 })?;
229
230 let Some(published_map) = inner.published.get(&id) else {
232 return Ok(None);
233 };
234
235 let Some(entries) = inner.entries.get(&id) else {
237 return Ok(None);
238 };
239
240 let mut best: Option<(Link, u64)> = None;
242 for (height, links) in entries.iter() {
243 for link in links {
244 if published_map.get(link).copied().unwrap_or(false)
245 && (best.is_none() || *height > best.as_ref().unwrap().1)
246 {
247 best = Some((link.clone(), *height));
248 }
249 }
250 }
251
252 Ok(best)
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use iroh_blobs::Hash;

    /// Builds a test link with codec `0x55` whose 32 hash bytes are all `fill`.
    fn sample_link(fill: u8) -> Link {
        Link::new(0x55, Hash::from_bytes([fill; 32]))
    }

    /// A genesis append (height 0, no previous) is accepted and becomes the only head.
    #[tokio::test]
    async fn test_genesis_append() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = sample_link(1);

        let outcome = provider
            .append(bucket, "test".to_string(), genesis.clone(), None, 0, false)
            .await;
        assert!(outcome.is_ok());

        let tip = provider.height(bucket).await.unwrap();
        assert_eq!(tip, 0);

        let heads_at_genesis = provider.heads(bucket, 0).await.unwrap();
        assert_eq!(heads_at_genesis, vec![genesis]);
    }

    /// Appending the identical link at the same height twice is rejected as a conflict.
    #[tokio::test]
    async fn test_conflict() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = sample_link(1);

        provider
            .append(bucket, "test".to_string(), genesis.clone(), None, 0, false)
            .await
            .unwrap();

        let duplicate = provider
            .append(bucket, "test".to_string(), genesis, None, 0, false)
            .await;
        assert!(matches!(duplicate, Err(BucketLogError::Conflict)));
    }

    /// A `previous` link that is not recorded at `height - 1` is rejected.
    #[tokio::test]
    async fn test_invalid_append() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = sample_link(1);
        let orphan = sample_link(2);

        provider
            .append(bucket, "test".to_string(), genesis, None, 0, false)
            .await
            .unwrap();

        // `orphan` names itself as predecessor, but only `genesis` exists at height 0.
        let outcome = provider
            .append(bucket, "test".to_string(), orphan.clone(), Some(orphan), 1, false)
            .await;
        assert!(matches!(
            outcome,
            Err(BucketLogError::InvalidAppend(_, _, _))
        ));
    }

    /// A correctly linked second entry advances the height and is indexed by link.
    #[tokio::test]
    async fn test_valid_chain() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = sample_link(1);
        let child = sample_link(2);

        provider
            .append(bucket, "test".to_string(), genesis.clone(), None, 0, false)
            .await
            .unwrap();

        provider
            .append(bucket, "test".to_string(), child.clone(), Some(genesis), 1, false)
            .await
            .unwrap();

        let tip = provider.height(bucket).await.unwrap();
        assert_eq!(tip, 1);

        let child_heights = provider.has(bucket, child).await.unwrap();
        assert_eq!(child_heights, vec![1]);
    }

    /// `latest_published` ignores unpublished entries, including newer unpublished ones.
    #[tokio::test]
    async fn test_latest_published() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = sample_link(1);
        let published_entry = sample_link(2);
        let unpublished_tip = sample_link(3);

        provider
            .append(bucket, "test".to_string(), genesis.clone(), None, 0, false)
            .await
            .unwrap();

        // Only an unpublished entry exists so far.
        assert!(provider.latest_published(bucket).await.unwrap().is_none());

        provider
            .append(
                bucket,
                "test".to_string(),
                published_entry.clone(),
                Some(genesis.clone()),
                1,
                true,
            )
            .await
            .unwrap();

        let (found_link, found_height) =
            provider.latest_published(bucket).await.unwrap().unwrap();
        assert_eq!(found_link, published_entry);
        assert_eq!(found_height, 1);

        provider
            .append(
                bucket,
                "test".to_string(),
                unpublished_tip.clone(),
                Some(published_entry.clone()),
                2,
                false,
            )
            .await
            .unwrap();

        // The newer height-2 entry is unpublished, so height 1 is still the answer.
        let (found_link, found_height) =
            provider.latest_published(bucket).await.unwrap().unwrap();
        assert_eq!(found_link, published_entry);
        assert_eq!(found_height, 1);
    }
}