1use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16
17use anyhow::Result;
18use async_trait::async_trait;
19use tokio::sync::Mutex;
20
21use hadb_storage::{CasResult, StorageBackend};
22
23#[derive(Default)]
25pub struct MemStorage {
26 entries: Mutex<HashMap<String, (Vec<u8>, String)>>,
27 etag_counter: AtomicU64,
28}
29
30impl MemStorage {
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 fn next_etag(&self) -> String {
36 let n = self.etag_counter.fetch_add(1, Ordering::SeqCst);
37 format!("{}", n + 1)
40 }
41}
42
43#[async_trait]
44impl StorageBackend for MemStorage {
45 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
46 Ok(self
47 .entries
48 .lock()
49 .await
50 .get(key)
51 .map(|(bytes, _)| bytes.clone()))
52 }
53
54 async fn put(&self, key: &str, data: &[u8]) -> Result<()> {
55 let etag = self.next_etag();
56 self.entries
57 .lock()
58 .await
59 .insert(key.to_string(), (data.to_vec(), etag));
60 Ok(())
61 }
62
63 async fn delete(&self, key: &str) -> Result<()> {
64 self.entries.lock().await.remove(key);
65 Ok(())
66 }
67
68 async fn list(&self, prefix: &str, after: Option<&str>) -> Result<Vec<String>> {
69 let guard = self.entries.lock().await;
70 let mut keys: Vec<String> = guard
71 .keys()
72 .filter(|k| k.starts_with(prefix))
73 .filter(|k| match after {
74 Some(cursor) => k.as_str() > cursor,
75 None => true,
76 })
77 .cloned()
78 .collect();
79 keys.sort();
80 Ok(keys)
81 }
82
83 async fn exists(&self, key: &str) -> Result<bool> {
84 Ok(self.entries.lock().await.contains_key(key))
85 }
86
87 async fn put_if_absent(&self, key: &str, data: &[u8]) -> Result<CasResult> {
88 let mut guard = self.entries.lock().await;
89 if guard.contains_key(key) {
90 return Ok(CasResult {
91 success: false,
92 etag: None,
93 });
94 }
95 let etag = self.next_etag();
96 guard.insert(key.to_string(), (data.to_vec(), etag.clone()));
97 Ok(CasResult {
98 success: true,
99 etag: Some(etag),
100 })
101 }
102
103 async fn put_if_match(&self, key: &str, data: &[u8], etag: &str) -> Result<CasResult> {
104 let mut guard = self.entries.lock().await;
105 match guard.get(key) {
106 Some((_, current)) if current == etag => {
107 let new_etag = self.next_etag();
108 guard.insert(key.to_string(), (data.to_vec(), new_etag.clone()));
109 Ok(CasResult {
110 success: true,
111 etag: Some(new_etag),
112 })
113 }
114 _ => Ok(CasResult {
115 success: false,
116 etag: None,
117 }),
118 }
119 }
120
121 async fn range_get(&self, key: &str, start: u64, len: u32) -> Result<Option<Vec<u8>>> {
122 let guard = self.entries.lock().await;
123 let Some((bytes, _)) = guard.get(key) else {
124 return Ok(None);
125 };
126 let start = start as usize;
127 if start >= bytes.len() {
128 return Ok(Some(Vec::new()));
129 }
130 let end = start.saturating_add(len as usize).min(bytes.len());
131 Ok(Some(bytes[start..end].to_vec()))
132 }
133
134 fn backend_name(&self) -> &str {
135 "mem"
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[tokio::test]
    async fn put_then_get_roundtrips() {
        let s = MemStorage::new();
        s.put("k", b"v").await.unwrap();
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v");
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let s = MemStorage::new();
        assert!(s.get("nope").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn put_overwrites() {
        let s = MemStorage::new();
        s.put("k", b"first").await.unwrap();
        s.put("k", b"second").await.unwrap();
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"second");
    }

    #[tokio::test]
    async fn delete_removes_and_is_idempotent() {
        let s = MemStorage::new();
        s.put("k", b"v").await.unwrap();
        s.delete("k").await.unwrap();
        assert!(s.get("k").await.unwrap().is_none());
        // Deleting an already-absent key must not be an error.
        s.delete("k").await.unwrap();
    }

    #[tokio::test]
    async fn exists_reflects_state() {
        let s = MemStorage::new();
        assert!(!s.exists("k").await.unwrap());
        s.put("k", b"").await.unwrap();
        assert!(s.exists("k").await.unwrap());
        s.delete("k").await.unwrap();
        assert!(!s.exists("k").await.unwrap());
    }

    #[tokio::test]
    async fn list_filters_by_prefix_and_sorts() {
        let s = MemStorage::new();
        s.put("a/1", b"").await.unwrap();
        s.put("a/3", b"").await.unwrap();
        s.put("a/2", b"").await.unwrap();
        s.put("b/1", b"").await.unwrap();
        assert_eq!(s.list("a/", None).await.unwrap(), vec!["a/1", "a/2", "a/3"]);
        assert_eq!(s.list("b/", None).await.unwrap(), vec!["b/1"]);
        assert!(s.list("c/", None).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn list_after_is_exclusive() {
        let s = MemStorage::new();
        for k in ["a/1", "a/2", "a/3"] {
            s.put(k, b"").await.unwrap();
        }
        let got = s.list("a/", Some("a/1")).await.unwrap();
        assert_eq!(got, vec!["a/2", "a/3"]);
    }

    #[tokio::test]
    async fn list_after_cursor_outside_prefix() {
        let s = MemStorage::new();
        for k in ["a/1", "a/2", "b/1"] {
            s.put(k, b"").await.unwrap();
        }
        // A cursor sorting before every prefixed key returns them all (and only them).
        assert_eq!(s.list("a/", Some("a")).await.unwrap(), vec!["a/1", "a/2"]);
        // A cursor past the last prefixed key returns nothing.
        assert!(s.list("a/", Some("a/9")).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn put_if_absent_first_wins() {
        let s = MemStorage::new();
        let a = s.put_if_absent("k", b"first").await.unwrap();
        assert!(a.success);
        let b = s.put_if_absent("k", b"second").await.unwrap();
        assert!(!b.success);
        assert!(b.etag.is_none());
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"first");
    }

    #[tokio::test]
    async fn put_if_absent_succeeds_after_delete() {
        let s = MemStorage::new();
        assert!(s.put_if_absent("k", b"v1").await.unwrap().success);
        s.delete("k").await.unwrap();
        assert!(s.put_if_absent("k", b"v2").await.unwrap().success);
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v2");
    }

    #[tokio::test]
    async fn put_if_match_advances_etag() {
        let s = MemStorage::new();
        let a = s.put_if_absent("k", b"v1").await.unwrap();
        let e1 = a.etag.unwrap();

        let b = s.put_if_match("k", b"v2", &e1).await.unwrap();
        assert!(b.success);
        let e2 = b.etag.unwrap();
        assert_ne!(e1, e2);

        // The stale etag `e1` must now be rejected and leave the value untouched.
        let c = s.put_if_match("k", b"v3", &e1).await.unwrap();
        assert!(!c.success);
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v2");
    }

    #[tokio::test]
    async fn put_if_match_on_missing_fails() {
        let s = MemStorage::new();
        let r = s.put_if_match("nope", b"x", "any").await.unwrap();
        assert!(!r.success);
        assert!(r.etag.is_none());
    }

    #[tokio::test]
    async fn put_overwrite_invalidates_etag() {
        let s = MemStorage::new();
        let a = s.put_if_absent("k", b"v1").await.unwrap();
        s.put("k", b"v2").await.unwrap();
        // An unconditional put assigns a new etag, so the old one no longer matches.
        let r = s
            .put_if_match("k", b"v3", a.etag.as_deref().unwrap())
            .await
            .unwrap();
        assert!(!r.success);
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v2");
    }

    #[tokio::test]
    async fn concurrent_put_if_absent_exactly_one_wins() {
        let s = Arc::new(MemStorage::new());
        let mut handles = Vec::new();
        for i in 0..16 {
            let s = Arc::clone(&s);
            handles.push(tokio::spawn(async move {
                s.put_if_absent("lease", format!("node-{i}").as_bytes())
                    .await
                    .unwrap()
            }));
        }
        let mut wins = 0;
        for h in handles {
            if h.await.unwrap().success {
                wins += 1;
            }
        }
        assert_eq!(wins, 1);
        // Whichever task won, its value is what ended up stored.
        let stored = s.get("lease").await.unwrap().unwrap();
        assert!(stored.starts_with(b"node-"));
    }

    #[tokio::test]
    async fn range_get_slices_correctly() {
        let s = MemStorage::new();
        s.put("k", b"abcdefghij").await.unwrap();
        assert_eq!(s.range_get("k", 0, 3).await.unwrap().unwrap(), b"abc");
        assert_eq!(s.range_get("k", 2, 4).await.unwrap().unwrap(), b"cdef");
        assert_eq!(s.range_get("k", 8, 10).await.unwrap().unwrap(), b"ij");
        assert_eq!(
            s.range_get("k", 50, 10).await.unwrap().unwrap(),
            Vec::<u8>::new()
        );
    }

    #[tokio::test]
    async fn range_get_boundaries() {
        let s = MemStorage::new();
        s.put("k", b"abc").await.unwrap();
        // Zero-length read inside the value.
        assert!(s.range_get("k", 1, 0).await.unwrap().unwrap().is_empty());
        // Start exactly at the end of the value.
        assert!(s.range_get("k", 3, 5).await.unwrap().unwrap().is_empty());
        // A huge offset must never wrap around into a valid range.
        assert!(s
            .range_get("k", u64::MAX, 1)
            .await
            .unwrap()
            .unwrap()
            .is_empty());
        // The maximum length is clamped to the end of the value.
        assert_eq!(s.range_get("k", 1, u32::MAX).await.unwrap().unwrap(), b"bc");
    }

    #[tokio::test]
    async fn range_get_on_missing_returns_none() {
        let s = MemStorage::new();
        assert!(s.range_get("nope", 0, 10).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn etags_are_unique_per_write() {
        let s = MemStorage::new();
        let a = s.put_if_absent("k", b"v1").await.unwrap();
        let b = s
            .put_if_match("k", b"v2", a.etag.as_ref().unwrap())
            .await
            .unwrap();
        let c = s
            .put_if_match("k", b"v3", b.etag.as_ref().unwrap())
            .await
            .unwrap();
        let etags = [a.etag.unwrap(), b.etag.unwrap(), c.etag.unwrap()];
        assert_eq!(
            etags.iter().collect::<std::collections::HashSet<_>>().len(),
            3
        );
    }

    /// `MemStorage` must coerce into, and work through, a shared trait object.
    #[tokio::test]
    async fn usable_as_arc_dyn() {
        let s: Arc<dyn StorageBackend> = Arc::new(MemStorage::new());
        s.put("k", b"v").await.unwrap();
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v");
        assert_eq!(s.backend_name(), "mem");
    }

    #[test]
    fn backend_name_is_stable() {
        assert_eq!(MemStorage::new().backend_name(), "mem");
    }
}