1use std::collections::HashMap;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use tokio::sync::RwLock;
13
14use crate::error::Error;
15
/// Batch-oriented key-value store abstraction.
///
/// Every operation has a synchronous form (`mget`, `mset`, `mdelete`,
/// `yield_keys`) that implementors must provide, plus an `a`-prefixed
/// asynchronous form. The asynchronous defaults simply call the synchronous
/// method inline on the current task (nothing is moved to a blocking thread),
/// so implementors whose synchronous methods block the thread should override
/// the async variants.
///
/// `K` is the key type and `V` the value type.
#[async_trait]
pub trait BaseStore<K, V>: Send + Sync
where
    K: Send + Sync,
    V: Send + Sync + Clone,
{
    /// Fetches the values stored under `keys`.
    ///
    /// Implementations are expected to return one entry per key, in input
    /// order, using `None` for keys with no stored value.
    fn mget(&self, keys: &[K]) -> Vec<Option<V>>;

    /// Async form of [`BaseStore::mget`].
    ///
    /// The default implementation calls `mget` directly on the current task.
    async fn amget(&self, keys: &[K]) -> Vec<Option<V>>
    where
        // NOTE(review): these `'static` bounds presumably exist so the boxed
        // future generated by `async_trait` type-checks for generic `K`/`V`;
        // confirm before relaxing them.
        K: 'static,
        V: 'static,
    {
        self.mget(keys)
    }

    /// Stores every `(key, value)` pair, expected to overwrite any value
    /// already held under the same key.
    fn mset(&self, key_value_pairs: &[(K, V)]);

    /// Async form of [`BaseStore::mset`].
    ///
    /// The default implementation calls `mset` directly on the current task.
    async fn amset(&self, key_value_pairs: &[(K, V)])
    where
        // Same `async_trait`-related bounds as `amget`.
        K: 'static,
        V: 'static,
    {
        self.mset(key_value_pairs)
    }

    /// Removes the entries for `keys`; keys that are not present are expected
    /// to be ignored.
    fn mdelete(&self, keys: &[K]);

    /// Async form of [`BaseStore::mdelete`].
    ///
    /// The default implementation calls `mdelete` directly on the current task.
    async fn amdelete(&self, keys: &[K])
    where
        // Same `async_trait`-related bounds as `amget`.
        K: 'static,
        V: 'static,
    {
        self.mdelete(keys)
    }

    /// Lists the store's keys as `String`s (whatever `K` is), optionally
    /// restricted by `prefix`; `None` means "all keys".
    ///
    /// Despite the name, the result is an eagerly collected `Vec`, not a lazy
    /// iterator. The exact meaning of `prefix` is left to the implementor.
    fn yield_keys(&self, prefix: Option<&str>) -> Vec<String>;

    /// Async form of [`BaseStore::yield_keys`].
    ///
    /// The default implementation calls `yield_keys` directly on the current
    /// task.
    async fn ayield_keys(&self, prefix: Option<&str>) -> Vec<String>
    where
        // Same `async_trait`-related bounds as `amget`.
        K: 'static,
        V: 'static,
    {
        self.yield_keys(prefix)
    }
}
141
/// Type-erased [`BaseStore`] with `String` keys and raw byte (`Vec<u8>`)
/// values; use behind a pointer such as `Box<ByteStore>` or `Arc<ByteStore>`.
pub type ByteStore = dyn BaseStore<String, Vec<u8>>;
144
145pub struct InMemoryBaseStore<V>
150where
151 V: Clone + Send + Sync,
152{
153 store: Arc<RwLock<HashMap<String, V>>>,
154}
155
156impl<V> Default for InMemoryBaseStore<V>
157where
158 V: Clone + Send + Sync,
159{
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165impl<V> InMemoryBaseStore<V>
166where
167 V: Clone + Send + Sync,
168{
169 pub fn new() -> Self {
171 Self {
172 store: Arc::new(RwLock::new(HashMap::new())),
173 }
174 }
175
176 pub fn store(&self) -> &Arc<RwLock<HashMap<String, V>>> {
178 &self.store
179 }
180}
181
182#[async_trait]
183impl<V> BaseStore<String, V> for InMemoryBaseStore<V>
184where
185 V: Clone + Send + Sync + 'static,
186{
187 fn mget(&self, keys: &[String]) -> Vec<Option<V>> {
188 let store = self.store.blocking_read();
189 keys.iter().map(|key| store.get(key).cloned()).collect()
190 }
191
192 async fn amget(&self, keys: &[String]) -> Vec<Option<V>> {
193 let store = self.store.read().await;
194 keys.iter().map(|key| store.get(key).cloned()).collect()
195 }
196
197 fn mset(&self, key_value_pairs: &[(String, V)]) {
198 let mut store = self.store.blocking_write();
199 for (key, value) in key_value_pairs {
200 store.insert(key.clone(), value.clone());
201 }
202 }
203
204 async fn amset(&self, key_value_pairs: &[(String, V)]) {
205 let mut store = self.store.write().await;
206 for (key, value) in key_value_pairs {
207 store.insert(key.clone(), value.clone());
208 }
209 }
210
211 fn mdelete(&self, keys: &[String]) {
212 let mut store = self.store.blocking_write();
213 for key in keys {
214 store.remove(key);
215 }
216 }
217
218 async fn amdelete(&self, keys: &[String]) {
219 let mut store = self.store.write().await;
220 for key in keys {
221 store.remove(key);
222 }
223 }
224
225 fn yield_keys(&self, prefix: Option<&str>) -> Vec<String> {
226 let store = self.store.blocking_read();
227 match prefix {
228 None => store.keys().cloned().collect(),
229 Some(prefix) => store
230 .keys()
231 .filter(|key| key.starts_with(prefix))
232 .cloned()
233 .collect(),
234 }
235 }
236
237 async fn ayield_keys(&self, prefix: Option<&str>) -> Vec<String> {
238 let store = self.store.read().await;
239 match prefix {
240 None => store.keys().cloned().collect(),
241 Some(prefix) => store
242 .keys()
243 .filter(|key| key.starts_with(prefix))
244 .cloned()
245 .collect(),
246 }
247 }
248}
249
/// In-memory string-keyed store whose values are `serde_json::Value`s.
pub type InMemoryStore = InMemoryBaseStore<serde_json::Value>;
277
/// In-memory string-keyed store whose values are raw bytes (`Vec<u8>`).
pub type InMemoryByteStore = InMemoryBaseStore<Vec<u8>>;
304
/// Error reporting that a key is invalid, together with the reason.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidKeyException {
    /// The offending key.
    pub key: String,
    /// Human-readable explanation of why the key is invalid.
    pub message: String,
}
313
314impl InvalidKeyException {
315 pub fn new(key: impl Into<String>, message: impl Into<String>) -> Self {
317 Self {
318 key: key.into(),
319 message: message.into(),
320 }
321 }
322}
323
324impl std::fmt::Display for InvalidKeyException {
325 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326 write!(f, "Invalid key '{}': {}", self.key, self.message)
327 }
328}
329
/// Makes `InvalidKeyException` a standard error type (usable as
/// `dyn std::error::Error`). The empty body keeps the default `source()`, so
/// no underlying cause is reported.
impl std::error::Error for InvalidKeyException {}
331
332impl From<InvalidKeyException> for Error {
333 fn from(e: InvalidKeyException) -> Self {
334 Error::Other(e.to_string())
335 }
336}
337
#[cfg(test)]
mod tests {
    // Unit tests for the in-memory stores and `InvalidKeyException`.
    //
    // The synchronous store API takes Tokio's `blocking_read`/`blocking_write`
    // locks, which panic inside an async execution context, so the sync tests
    // use plain `#[test]` (no runtime); only the async test uses `#[tokio::test]`.
    use super::*;
    use serde_json::json;

    // `mset` then `mget`: hits come back in key order, a missing key is `None`.
    #[test]
    fn test_in_memory_store_mget_mset() {
        let store = InMemoryStore::new();

        store.mset(&[
            ("key1".to_string(), json!("value1")),
            ("key2".to_string(), json!(42)),
        ]);

        let values = store.mget(&["key1".to_string(), "key2".to_string(), "key3".to_string()]);
        assert_eq!(values.len(), 3);
        assert_eq!(values[0], Some(json!("value1")));
        assert_eq!(values[1], Some(json!(42)));
        assert_eq!(values[2], None);
    }

    // `mdelete` removes only the listed key and leaves the others intact.
    #[test]
    fn test_in_memory_store_mdelete() {
        let store = InMemoryStore::new();

        store.mset(&[
            ("key1".to_string(), json!("value1")),
            ("key2".to_string(), json!("value2")),
        ]);

        store.mdelete(&["key1".to_string()]);

        let values = store.mget(&["key1".to_string(), "key2".to_string()]);
        assert_eq!(values[0], None);
        assert_eq!(values[1], Some(json!("value2")));
    }

    // `yield_keys(None)` lists everything; `Some(prefix)` filters by prefix.
    #[test]
    fn test_in_memory_store_yield_keys() {
        let store = InMemoryStore::new();

        store.mset(&[
            ("prefix_a".to_string(), json!("a")),
            ("prefix_b".to_string(), json!("b")),
            ("other".to_string(), json!("c")),
        ]);

        let all_keys = store.yield_keys(None);
        assert_eq!(all_keys.len(), 3);

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut prefix_keys = store.yield_keys(Some("prefix_"));
        prefix_keys.sort();
        assert_eq!(prefix_keys, vec!["prefix_a", "prefix_b"]);
    }

    // The byte-valued alias round-trips `Vec<u8>` values.
    #[test]
    fn test_in_memory_byte_store() {
        let store = InMemoryByteStore::new();

        store.mset(&[
            ("key1".to_string(), b"bytes1".to_vec()),
            ("key2".to_string(), b"bytes2".to_vec()),
        ]);

        let values = store.mget(&["key1".to_string(), "key2".to_string()]);
        assert_eq!(values[0], Some(b"bytes1".to_vec()));
        assert_eq!(values[1], Some(b"bytes2".to_vec()));
    }

    // Exercises the async variants end to end inside a Tokio runtime.
    #[tokio::test]
    async fn test_in_memory_store_async() {
        let store = InMemoryStore::new();

        store
            .amset(&[
                ("key1".to_string(), json!("async_value1")),
                ("key2".to_string(), json!("async_value2")),
            ])
            .await;

        let values = store.amget(&["key1".to_string(), "key2".to_string()]).await;
        assert_eq!(values[0], Some(json!("async_value1")));
        assert_eq!(values[1], Some(json!("async_value2")));

        store.amdelete(&["key1".to_string()]).await;

        // Only one key remains, so the unspecified ordering cannot matter here.
        let keys = store.ayield_keys(None).await;
        assert_eq!(keys, vec!["key2".to_string()]);
    }

    // `Display` output format of `InvalidKeyException`.
    #[test]
    fn test_invalid_key_exception() {
        let exception = InvalidKeyException::new("bad/key", "keys cannot contain slashes");
        assert_eq!(
            exception.to_string(),
            "Invalid key 'bad/key': keys cannot contain slashes"
        );
    }
}