Skip to main content

hadb_storage_mem/
lib.rs

1//! `hadb-storage-mem`: in-memory `StorageBackend` implementation.
2//!
3//! Used as the test fixture across the hadb/walrust/turbolite/haqlite
4//! workspaces. Not intended for production: everything lives in a single
5//! `Mutex<HashMap>` and etags are a monotonic counter.
6//!
7//! # Etags
8//!
9//! An etag is an opaque string that survives across the CAS boundary:
10//! `put_if_match(key, data, etag)` only commits if the current stored etag
11//! equals `etag`. We use a monotonic `u64` counter as a string so etags are
12//! unique per-write and totally ordered.
13
14use std::collections::HashMap;
15use std::sync::atomic::{AtomicU64, Ordering};
16
17use anyhow::Result;
18use async_trait::async_trait;
19use tokio::sync::Mutex;
20
21use hadb_storage::{CasResult, StorageBackend};
22
23/// In-memory blob store. Cheap to clone via `Arc`.
24#[derive(Default)]
25pub struct MemStorage {
26    entries: Mutex<HashMap<String, (Vec<u8>, String)>>,
27    etag_counter: AtomicU64,
28}
29
30impl MemStorage {
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    fn next_etag(&self) -> String {
36        let n = self.etag_counter.fetch_add(1, Ordering::SeqCst);
37        // `+ 1` because fetch_add returns the previous value; we want 1-based
38        // etags so "0" never appears (some implementors treat 0 as "missing").
39        format!("{}", n + 1)
40    }
41}
42
43#[async_trait]
44impl StorageBackend for MemStorage {
45    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
46        Ok(self
47            .entries
48            .lock()
49            .await
50            .get(key)
51            .map(|(bytes, _)| bytes.clone()))
52    }
53
54    async fn put(&self, key: &str, data: &[u8]) -> Result<()> {
55        let etag = self.next_etag();
56        self.entries
57            .lock()
58            .await
59            .insert(key.to_string(), (data.to_vec(), etag));
60        Ok(())
61    }
62
63    async fn delete(&self, key: &str) -> Result<()> {
64        self.entries.lock().await.remove(key);
65        Ok(())
66    }
67
68    async fn list(&self, prefix: &str, after: Option<&str>) -> Result<Vec<String>> {
69        let guard = self.entries.lock().await;
70        let mut keys: Vec<String> = guard
71            .keys()
72            .filter(|k| k.starts_with(prefix))
73            .filter(|k| match after {
74                Some(cursor) => k.as_str() > cursor,
75                None => true,
76            })
77            .cloned()
78            .collect();
79        keys.sort();
80        Ok(keys)
81    }
82
83    async fn exists(&self, key: &str) -> Result<bool> {
84        Ok(self.entries.lock().await.contains_key(key))
85    }
86
87    async fn put_if_absent(&self, key: &str, data: &[u8]) -> Result<CasResult> {
88        let mut guard = self.entries.lock().await;
89        if guard.contains_key(key) {
90            return Ok(CasResult {
91                success: false,
92                etag: None,
93            });
94        }
95        let etag = self.next_etag();
96        guard.insert(key.to_string(), (data.to_vec(), etag.clone()));
97        Ok(CasResult {
98            success: true,
99            etag: Some(etag),
100        })
101    }
102
103    async fn put_if_match(&self, key: &str, data: &[u8], etag: &str) -> Result<CasResult> {
104        let mut guard = self.entries.lock().await;
105        match guard.get(key) {
106            Some((_, current)) if current == etag => {
107                let new_etag = self.next_etag();
108                guard.insert(key.to_string(), (data.to_vec(), new_etag.clone()));
109                Ok(CasResult {
110                    success: true,
111                    etag: Some(new_etag),
112                })
113            }
114            _ => Ok(CasResult {
115                success: false,
116                etag: None,
117            }),
118        }
119    }
120
121    async fn range_get(&self, key: &str, start: u64, len: u32) -> Result<Option<Vec<u8>>> {
122        let guard = self.entries.lock().await;
123        let Some((bytes, _)) = guard.get(key) else {
124            return Ok(None);
125        };
126        let start = start as usize;
127        if start >= bytes.len() {
128            return Ok(Some(Vec::new()));
129        }
130        let end = start.saturating_add(len as usize).min(bytes.len());
131        Ok(Some(bytes[start..end].to_vec()))
132    }
133
134    fn backend_name(&self) -> &str {
135        "mem"
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[tokio::test]
    async fn put_then_get_roundtrips() {
        let s = MemStorage::new();
        s.put("k", b"v").await.unwrap();
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v");
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let s = MemStorage::new();
        assert!(s.get("nope").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn put_overwrites() {
        let s = MemStorage::new();
        s.put("k", b"first").await.unwrap();
        s.put("k", b"second").await.unwrap();
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"second");
    }

    #[tokio::test]
    async fn delete_removes_and_is_idempotent() {
        let s = MemStorage::new();
        s.put("k", b"v").await.unwrap();
        s.delete("k").await.unwrap();
        assert!(s.get("k").await.unwrap().is_none());
        // Second delete must not error.
        s.delete("k").await.unwrap();
    }

    #[tokio::test]
    async fn exists_reflects_state() {
        let s = MemStorage::new();
        assert!(!s.exists("k").await.unwrap());
        s.put("k", b"").await.unwrap();
        assert!(s.exists("k").await.unwrap());
        s.delete("k").await.unwrap();
        assert!(!s.exists("k").await.unwrap());
    }

    #[tokio::test]
    async fn list_filters_by_prefix_and_sorts() {
        let s = MemStorage::new();
        s.put("a/1", b"").await.unwrap();
        s.put("a/3", b"").await.unwrap();
        s.put("a/2", b"").await.unwrap();
        s.put("b/1", b"").await.unwrap();
        assert_eq!(s.list("a/", None).await.unwrap(), vec!["a/1", "a/2", "a/3"]);
        assert_eq!(s.list("b/", None).await.unwrap(), vec!["b/1"]);
        assert!(s.list("c/", None).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn list_after_is_exclusive() {
        let s = MemStorage::new();
        for k in ["a/1", "a/2", "a/3"] {
            s.put(k, b"").await.unwrap();
        }
        let got = s.list("a/", Some("a/1")).await.unwrap();
        assert_eq!(got, vec!["a/2", "a/3"]);
    }

    #[tokio::test]
    async fn list_after_last_key_is_empty() {
        let s = MemStorage::new();
        for k in ["a/1", "a/2"] {
            s.put(k, b"").await.unwrap();
        }
        assert!(s.list("a/", Some("a/2")).await.unwrap().is_empty());
        assert!(s.list("a/", Some("z")).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn put_if_absent_first_wins() {
        let s = MemStorage::new();
        let a = s.put_if_absent("k", b"first").await.unwrap();
        assert!(a.success);
        let b = s.put_if_absent("k", b"second").await.unwrap();
        assert!(!b.success);
        assert!(b.etag.is_none());
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"first");
    }

    #[tokio::test]
    async fn put_if_absent_succeeds_after_delete() {
        let s = MemStorage::new();
        assert!(s.put_if_absent("k", b"v1").await.unwrap().success);
        s.delete("k").await.unwrap();
        assert!(s.put_if_absent("k", b"v2").await.unwrap().success);
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v2");
    }

    #[tokio::test]
    async fn put_if_match_advances_etag() {
        let s = MemStorage::new();
        let a = s.put_if_absent("k", b"v1").await.unwrap();
        let e1 = a.etag.unwrap();

        let b = s.put_if_match("k", b"v2", &e1).await.unwrap();
        assert!(b.success);
        let e2 = b.etag.unwrap();
        assert_ne!(e1, e2);

        // Stale etag now fails.
        let c = s.put_if_match("k", b"v3", &e1).await.unwrap();
        assert!(!c.success);
        assert_eq!(s.get("k").await.unwrap().unwrap(), b"v2");
    }

    #[tokio::test]
    async fn put_if_match_on_missing_fails() {
        let s = MemStorage::new();
        let r = s.put_if_match("nope", b"x", "any").await.unwrap();
        assert!(!r.success);
        assert!(r.etag.is_none());
    }

    #[tokio::test]
    async fn put_if_match_after_delete_fails() {
        let s = MemStorage::new();
        let e1 = s.put_if_absent("k", b"v1").await.unwrap().etag.unwrap();
        s.delete("k").await.unwrap();
        let r = s.put_if_match("k", b"v2", &e1).await.unwrap();
        assert!(!r.success);
        assert!(s.get("k").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn concurrent_put_if_absent_exactly_one_wins() {
        let s = Arc::new(MemStorage::new());
        let mut handles = Vec::new();
        for i in 0..16 {
            let s = Arc::clone(&s);
            handles.push(tokio::spawn(async move {
                s.put_if_absent("lease", format!("node-{i}").as_bytes())
                    .await
                    .unwrap()
            }));
        }
        let mut wins = 0;
        for h in handles {
            if h.await.unwrap().success {
                wins += 1;
            }
        }
        assert_eq!(wins, 1);
    }

    #[tokio::test]
    async fn range_get_slices_correctly() {
        let s = MemStorage::new();
        s.put("k", b"abcdefghij").await.unwrap();
        assert_eq!(s.range_get("k", 0, 3).await.unwrap().unwrap(), b"abc");
        assert_eq!(s.range_get("k", 2, 4).await.unwrap().unwrap(), b"cdef");
        assert_eq!(s.range_get("k", 8, 10).await.unwrap().unwrap(), b"ij");
        assert_eq!(
            s.range_get("k", 50, 10).await.unwrap().unwrap(),
            Vec::<u8>::new()
        );
    }

    #[tokio::test]
    async fn range_get_edge_cases_return_empty() {
        let s = MemStorage::new();
        s.put("k", b"abc").await.unwrap();
        // Zero-length read inside the blob.
        assert!(s.range_get("k", 1, 0).await.unwrap().unwrap().is_empty());
        // Offset exactly at the end.
        assert!(s.range_get("k", 3, 5).await.unwrap().unwrap().is_empty());
        // Offset that may not fit in `usize` on 32-bit targets.
        assert!(s
            .range_get("k", u64::MAX, 5)
            .await
            .unwrap()
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn range_get_on_missing_returns_none() {
        let s = MemStorage::new();
        assert!(s.range_get("nope", 0, 10).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn etags_are_unique_per_write() {
        let s = MemStorage::new();
        let a = s.put_if_absent("k", b"v1").await.unwrap();
        let b = s
            .put_if_match("k", b"v2", a.etag.as_ref().unwrap())
            .await
            .unwrap();
        let c = s
            .put_if_match("k", b"v3", b.etag.as_ref().unwrap())
            .await
            .unwrap();
        let etags = [a.etag.unwrap(), b.etag.unwrap(), c.etag.unwrap()];
        assert_eq!(
            etags.iter().collect::<std::collections::HashSet<_>>().len(),
            3
        );
    }

    // Compile-time: exercises the trait object shape our consumers depend on.
    // Never called, hence the dead_code allow.
    #[allow(dead_code)]
    fn _usable_as_arc_dyn(_: Arc<dyn StorageBackend>) {}

    #[test]
    fn backend_name_is_stable() {
        assert_eq!(MemStorage::new().backend_name(), "mem");
    }
}