1use async_trait::async_trait;
4use std::collections::HashMap;
5use std::sync::{Arc, RwLock};
6
7use crate::types::{to_hex, Hash};
8
9#[async_trait]
11pub trait Store: Send + Sync {
12 async fn put(&self, hash: Hash, data: Vec<u8>) -> Result<bool, StoreError>;
15
16 async fn get(&self, hash: &Hash) -> Result<Option<Vec<u8>>, StoreError>;
19
20 async fn has(&self, hash: &Hash) -> Result<bool, StoreError>;
22
23 async fn delete(&self, hash: &Hash) -> Result<bool, StoreError>;
26}
27
28#[derive(Debug, thiserror::Error)]
30pub enum StoreError {
31 #[error("IO error: {0}")]
32 Io(#[from] std::io::Error),
33 #[error("Store error: {0}")]
34 Other(String),
35}
36
/// In-memory [`Store`] backed by a `HashMap` behind `Arc<RwLock<..>>`.
///
/// Because the map sits behind an `Arc`, the derived `Clone` copies only the
/// pointer: every clone is a handle to the same underlying data, not a snapshot.
#[derive(Clone, Debug, Default)]
pub struct MemoryStore {
    /// Stored blobs; in this file keys are produced by `to_hex` of the blob's hash.
    data: Arc<RwLock<HashMap<String, Vec<u8>>>>,
}
43
44impl MemoryStore {
45 pub fn new() -> Self {
46 Self {
47 data: Arc::new(RwLock::new(HashMap::new())),
48 }
49 }
50
51 pub fn size(&self) -> usize {
53 self.data.read().unwrap().len()
54 }
55
56 pub fn total_bytes(&self) -> usize {
58 self.data
59 .read()
60 .unwrap()
61 .values()
62 .map(|v| v.len())
63 .sum()
64 }
65
66 pub fn clear(&self) {
68 self.data.write().unwrap().clear();
69 }
70
71 pub fn keys(&self) -> Vec<Hash> {
73 self.data
74 .read()
75 .unwrap()
76 .keys()
77 .filter_map(|hex| {
78 let bytes = hex::decode(hex).ok()?;
79 if bytes.len() != 32 {
80 return None;
81 }
82 let mut hash = [0u8; 32];
83 hash.copy_from_slice(&bytes);
84 Some(hash)
85 })
86 .collect()
87 }
88}
89
90#[async_trait]
91impl Store for MemoryStore {
92 async fn put(&self, hash: Hash, data: Vec<u8>) -> Result<bool, StoreError> {
93 let key = to_hex(&hash);
94 let mut store = self.data.write().unwrap();
95 if store.contains_key(&key) {
96 return Ok(false);
97 }
98 store.insert(key, data);
100 Ok(true)
101 }
102
103 async fn get(&self, hash: &Hash) -> Result<Option<Vec<u8>>, StoreError> {
104 let key = to_hex(hash);
105 let store = self.data.read().unwrap();
106 Ok(store.get(&key).cloned())
108 }
109
110 async fn has(&self, hash: &Hash) -> Result<bool, StoreError> {
111 let key = to_hex(hash);
112 Ok(self.data.read().unwrap().contains_key(&key))
113 }
114
115 async fn delete(&self, hash: &Hash) -> Result<bool, StoreError> {
116 let key = to_hex(hash);
117 Ok(self.data.write().unwrap().remove(&key).is_some())
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hash::sha256;

    #[tokio::test]
    async fn test_put_returns_true_for_new() {
        let store = MemoryStore::new();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        let result = store.put(hash, data).await.unwrap();
        assert!(result);
    }

    #[tokio::test]
    async fn test_put_returns_false_for_duplicate() {
        let store = MemoryStore::new();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        store.put(hash, data.clone()).await.unwrap();
        let result = store.put(hash, data).await.unwrap();
        assert!(!result);
    }

    /// A second `put` under an existing hash must not replace the stored bytes.
    #[tokio::test]
    async fn test_put_duplicate_keeps_original_data() {
        let store = MemoryStore::new();
        let original = vec![1u8, 2, 3];
        let hash = sha256(&original);

        assert!(store.put(hash, original.clone()).await.unwrap());
        // Same key, different payload: the first write must win.
        assert!(!store.put(hash, vec![9u8, 9]).await.unwrap());

        assert_eq!(store.get(&hash).await.unwrap(), Some(original));
        assert_eq!(store.total_bytes(), 3);
    }

    #[tokio::test]
    async fn test_put_after_delete_returns_true() {
        let store = MemoryStore::new();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        store.put(hash, data.clone()).await.unwrap();
        assert!(store.delete(&hash).await.unwrap());
        assert!(store.put(hash, data).await.unwrap());
        assert_eq!(store.size(), 1);
    }

    #[tokio::test]
    async fn test_get_returns_data() {
        let store = MemoryStore::new();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        store.put(hash, data.clone()).await.unwrap();
        let result = store.get(&hash).await.unwrap();

        assert_eq!(result, Some(data));
    }

    #[tokio::test]
    async fn test_get_returns_none_for_missing() {
        let store = MemoryStore::new();
        let hash = [0u8; 32];

        let result = store.get(&hash).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_has_returns_true() {
        let store = MemoryStore::new();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        store.put(hash, data).await.unwrap();
        assert!(store.has(&hash).await.unwrap());
    }

    #[tokio::test]
    async fn test_has_returns_false() {
        let store = MemoryStore::new();
        let hash = [0u8; 32];

        assert!(!store.has(&hash).await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_returns_true() {
        let store = MemoryStore::new();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        store.put(hash, data).await.unwrap();
        let result = store.delete(&hash).await.unwrap();

        assert!(result);
        assert!(!store.has(&hash).await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_returns_false() {
        let store = MemoryStore::new();
        let hash = [0u8; 32];

        let result = store.delete(&hash).await.unwrap();
        assert!(!result);
    }

    #[tokio::test]
    async fn test_size() {
        let store = MemoryStore::new();
        assert_eq!(store.size(), 0);

        let data1 = vec![1u8];
        let data2 = vec![2u8];
        let hash1 = sha256(&data1);
        let hash2 = sha256(&data2);

        store.put(hash1, data1).await.unwrap();
        store.put(hash2, data2).await.unwrap();

        assert_eq!(store.size(), 2);
    }

    #[tokio::test]
    async fn test_total_bytes() {
        let store = MemoryStore::new();
        assert_eq!(store.total_bytes(), 0);

        let data1 = vec![1u8, 2, 3];
        let data2 = vec![4u8, 5];
        let hash1 = sha256(&data1);
        let hash2 = sha256(&data2);

        store.put(hash1, data1).await.unwrap();
        store.put(hash2, data2).await.unwrap();

        assert_eq!(store.total_bytes(), 5);
    }

    #[tokio::test]
    async fn test_clear() {
        let store = MemoryStore::new();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        store.put(hash, data).await.unwrap();
        store.clear();

        assert_eq!(store.size(), 0);
        assert!(!store.has(&hash).await.unwrap());
    }

    /// Cloning a store yields another handle to the same data, not a copy.
    #[tokio::test]
    async fn test_clones_share_storage() {
        let store = MemoryStore::new();
        let handle = store.clone();
        let data = vec![1u8, 2, 3];
        let hash = sha256(&data);

        handle.put(hash, data.clone()).await.unwrap();
        assert_eq!(store.get(&hash).await.unwrap(), Some(data));

        store.clear();
        assert_eq!(handle.size(), 0);
    }

    #[tokio::test]
    async fn test_keys() {
        let store = MemoryStore::new();
        assert!(store.keys().is_empty());

        let data1 = vec![1u8];
        let data2 = vec![2u8];
        let hash1 = sha256(&data1);
        let hash2 = sha256(&data2);

        store.put(hash1, data1).await.unwrap();
        store.put(hash2, data2).await.unwrap();

        let keys = store.keys();
        assert_eq!(keys.len(), 2);

        let mut hex_keys: Vec<_> = keys.iter().map(to_hex).collect();
        hex_keys.sort();
        let mut expected: Vec<_> = vec![to_hex(&hash1), to_hex(&hash2)];
        expected.sort();
        assert_eq!(hex_keys, expected);
    }

    /// `keys()` must skip map keys that are not valid hex or are not 32 bytes long.
    #[test]
    fn test_keys_skips_malformed_entries() {
        let store = MemoryStore::new();
        {
            let mut map = store.data.write().unwrap();
            map.insert("not hex".to_string(), vec![]);
            // Valid hex, but decodes to only 2 bytes.
            map.insert("abcd".to_string(), vec![]);
            map.insert(to_hex(&[7u8; 32]), vec![1]);
        }

        assert_eq!(store.keys(), vec![[7u8; 32]]);
    }
}