1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::{Arc, RwLock};
4use uuid::Uuid;
5
6use super::provider::{BucketLogError, BucketLogProvider};
7use crate::linked_data::Link;
8
/// In-memory implementation of [`BucketLogProvider`].
///
/// All state lives in `HashMap`s behind an `Arc<RwLock<..>>`, so clones of this value
/// share (and see each other's writes to) the same underlying log. Nothing is written
/// outside process memory.
#[derive(Debug, Clone)]
pub struct MemoryBucketLogProvider {
    inner: Arc<RwLock<MemoryBucketLogProviderInner>>,
}
14
/// Mutable state shared by all clones of [`MemoryBucketLogProvider`], guarded by its `RwLock`.
#[derive(Debug, Default)]
struct MemoryBucketLogProviderInner {
    /// Per bucket: height -> links appended at that height. Several distinct links may
    /// share one height; only an exact duplicate at the same height is rejected.
    entries: HashMap<Uuid, HashMap<u64, Vec<Link>>>,
    /// Per bucket: the highest height appended so far (what `height` reports).
    max_heights: HashMap<Uuid, u64>,
    /// Per bucket: reverse index from a link to every height it was appended at
    /// (what `has` reports).
    link_index: HashMap<Uuid, HashMap<Link, Vec<u64>>>,
    /// Per bucket: the name last recorded by `append`.
    /// NOTE(review): no method in this file reads it — confirm whether it should be exposed.
    names: HashMap<Uuid, String>,
}
27
/// Errors specific to the in-memory provider; surfaced to callers wrapped in
/// `BucketLogError::Provider`.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum MemoryBucketLogProviderError {
    /// An internal failure such as a poisoned `RwLock`; the payload describes the cause.
    #[error("memory provider error: {0}")]
    Internal(String),
}
33
34impl MemoryBucketLogProvider {
35 pub fn new() -> Self {
36 Self {
37 inner: Arc::new(RwLock::new(MemoryBucketLogProviderInner::default())),
38 }
39 }
40}
41
42impl Default for MemoryBucketLogProvider {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48#[async_trait]
49impl BucketLogProvider for MemoryBucketLogProvider {
50 type Error = MemoryBucketLogProviderError;
51
52 async fn exists(&self, id: Uuid) -> Result<bool, BucketLogError<Self::Error>> {
53 let inner = self.inner.read().map_err(|e| {
54 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
55 "failed to acquire read lock: {}",
56 e
57 )))
58 })?;
59
60 Ok(inner.entries.contains_key(&id))
61 }
62
63 async fn heads(&self, id: Uuid, height: u64) -> Result<Vec<Link>, BucketLogError<Self::Error>> {
64 let inner = self.inner.read().map_err(|e| {
65 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
66 "failed to acquire read lock: {}",
67 e
68 )))
69 })?;
70
71 Ok(inner
72 .entries
73 .get(&id)
74 .and_then(|heights| heights.get(&height))
75 .cloned()
76 .unwrap_or_default())
77 }
78
79 async fn append(
80 &self,
81 id: Uuid,
82 name: String,
83 current: Link,
84 previous: Option<Link>,
85 height: u64,
86 ) -> Result<(), BucketLogError<Self::Error>> {
87 let mut inner = self.inner.write().map_err(|e| {
88 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
89 "failed to acquire write lock: {}",
90 e
91 )))
92 })?;
93
94 inner.names.insert(id, name);
96
97 let bucket_entries = inner.entries.entry(id).or_insert_with(HashMap::new);
99
100 if let Some(existing_links) = bucket_entries.get(&height) {
102 if existing_links.contains(¤t) {
103 return Err(BucketLogError::Conflict);
104 }
105 }
106
107 if let Some(prev_link) = &previous {
109 if height == 0 {
111 return Err(BucketLogError::InvalidAppend(
112 current,
113 prev_link.clone(),
114 height,
115 ));
116 }
117
118 let expected_prev_height = height - 1;
119 let prev_exists = bucket_entries
120 .get(&expected_prev_height)
121 .map(|links| links.contains(prev_link))
122 .unwrap_or(false);
123
124 if !prev_exists {
125 return Err(BucketLogError::InvalidAppend(
126 current,
127 prev_link.clone(),
128 expected_prev_height,
129 ));
130 }
131 } else {
132 if height != 0 {
134 return Err(BucketLogError::InvalidAppend(
136 current,
137 Link::default(), height,
139 ));
140 }
141 }
142
143 bucket_entries
145 .entry(height)
146 .or_insert_with(Vec::new)
147 .push(current.clone());
148
149 let current_max = inner.max_heights.get(&id).copied();
151 if current_max.is_none() || height > current_max.unwrap() {
152 inner.max_heights.insert(id, height);
153 }
154
155 inner
157 .link_index
158 .entry(id)
159 .or_insert_with(HashMap::new)
160 .entry(current)
161 .or_insert_with(Vec::new)
162 .push(height);
163
164 Ok(())
165 }
166
167 async fn height(&self, id: Uuid) -> Result<u64, BucketLogError<Self::Error>> {
168 let inner = self.inner.read().map_err(|e| {
169 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
170 "failed to acquire read lock: {}",
171 e
172 )))
173 })?;
174
175 inner
176 .max_heights
177 .get(&id)
178 .copied()
179 .ok_or(BucketLogError::HeadNotFound(0))
180 }
181
182 async fn has(&self, id: Uuid, link: Link) -> Result<Vec<u64>, BucketLogError<Self::Error>> {
183 let inner = self.inner.read().map_err(|e| {
184 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
185 "failed to acquire read lock: {}",
186 e
187 )))
188 })?;
189
190 Ok(inner
191 .link_index
192 .get(&id)
193 .and_then(|links| links.get(&link))
194 .cloned()
195 .unwrap_or_default())
196 }
197
198 async fn list_buckets(&self) -> Result<Vec<Uuid>, BucketLogError<Self::Error>> {
199 let inner = self.inner.read().map_err(|e| {
200 BucketLogError::Provider(MemoryBucketLogProviderError::Internal(format!(
201 "failed to acquire read lock: {}",
202 e
203 )))
204 })?;
205
206 Ok(inner.entries.keys().copied().collect())
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use iroh_blobs::Hash;

    /// Bucket name passed to every append in these tests.
    const NAME: &str = "test";

    /// Builds a test link (codec 0x55) whose hash is 32 copies of `byte`.
    fn link_of(byte: u8) -> Link {
        Link::new(0x55, Hash::from_bytes([byte; 32]))
    }

    #[tokio::test]
    async fn test_genesis_append() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = link_of(1);

        let outcome = provider
            .append(bucket, NAME.to_string(), genesis.clone(), None, 0)
            .await;
        assert!(outcome.is_ok());

        assert_eq!(provider.height(bucket).await.unwrap(), 0);
        assert_eq!(provider.heads(bucket, 0).await.unwrap(), vec![genesis]);
    }

    #[tokio::test]
    async fn test_conflict() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = link_of(1);

        provider
            .append(bucket, NAME.to_string(), genesis.clone(), None, 0)
            .await
            .unwrap();

        // Re-appending the exact same link at the same height must be rejected.
        let duplicate = provider
            .append(bucket, NAME.to_string(), genesis, None, 0)
            .await;
        assert!(matches!(duplicate, Err(BucketLogError::Conflict)));
    }

    #[tokio::test]
    async fn test_invalid_append() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = link_of(1);
        let orphan = link_of(2);

        provider
            .append(bucket, NAME.to_string(), genesis, None, 0)
            .await
            .unwrap();

        // `orphan` claims itself as predecessor, but it was never recorded at height 0.
        let outcome = provider
            .append(bucket, NAME.to_string(), orphan.clone(), Some(orphan), 1)
            .await;
        assert!(matches!(
            outcome,
            Err(BucketLogError::InvalidAppend(_, _, _))
        ));
    }

    #[tokio::test]
    async fn test_valid_chain() {
        let provider = MemoryBucketLogProvider::new();
        let bucket = Uuid::new_v4();
        let genesis = link_of(1);
        let child = link_of(2);

        provider
            .append(bucket, NAME.to_string(), genesis.clone(), None, 0)
            .await
            .unwrap();
        provider
            .append(bucket, NAME.to_string(), child.clone(), Some(genesis), 1)
            .await
            .unwrap();

        assert_eq!(provider.height(bucket).await.unwrap(), 1);
        assert_eq!(provider.has(bucket, child).await.unwrap(), vec![1]);
    }
}