agent_chain_core/
stores.rs

//! **Store** provides key-value stores and storage helpers.
//!
//! This module provides implementations of various key-value stores that
//! conform to a simple, batch-oriented key-value interface.
//!
//! The primary goal of these stores is to support the implementation of caching.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use tokio::sync::RwLock;
13
14use crate::error::Error;
15
16/// Abstract interface for a key-value store.
17///
18/// This is an interface that's meant to abstract away the details of
19/// different key-value stores. It provides a simple interface for
20/// getting, setting, and deleting key-value pairs.
21///
22/// The basic methods are `mget`, `mset`, and `mdelete` for getting,
23/// setting, and deleting multiple key-value pairs at once. The `yield_keys`
24/// method is used to iterate over keys that match a given prefix.
25///
26/// The async versions of these methods are also provided, which are
27/// meant to be used in async contexts. The async methods have the same names
28/// but return futures.
29///
30/// By default, the async methods are implemented using the synchronous methods
31/// wrapped in tokio's spawn_blocking. If the store can natively support async
32/// operations, it should override these methods.
33///
34/// By design the methods only accept batches of keys and values, and not
35/// single keys or values. This is done to force user code to work with batches
36/// which will usually be more efficient by saving on round trips to the store.
37#[async_trait]
38pub trait BaseStore<K, V>: Send + Sync
39where
40    K: Send + Sync,
41    V: Send + Sync + Clone,
42{
43    /// Get the values associated with the given keys.
44    ///
45    /// # Arguments
46    ///
47    /// * `keys` - A sequence of keys.
48    ///
49    /// # Returns
50    ///
51    /// A sequence of optional values associated with the keys.
52    /// If a key is not found, the corresponding value will be `None`.
53    fn mget(&self, keys: &[K]) -> Vec<Option<V>>;
54
55    /// Async get the values associated with the given keys.
56    ///
57    /// # Arguments
58    ///
59    /// * `keys` - A sequence of keys.
60    ///
61    /// # Returns
62    ///
63    /// A sequence of optional values associated with the keys.
64    /// If a key is not found, the corresponding value will be `None`.
65    async fn amget(&self, keys: &[K]) -> Vec<Option<V>>
66    where
67        K: 'static,
68        V: 'static,
69    {
70        self.mget(keys)
71    }
72
73    /// Set the values for the given keys.
74    ///
75    /// # Arguments
76    ///
77    /// * `key_value_pairs` - A sequence of key-value pairs.
78    fn mset(&self, key_value_pairs: &[(K, V)]);
79
80    /// Async set the values for the given keys.
81    ///
82    /// # Arguments
83    ///
84    /// * `key_value_pairs` - A sequence of key-value pairs.
85    async fn amset(&self, key_value_pairs: &[(K, V)])
86    where
87        K: 'static,
88        V: 'static,
89    {
90        self.mset(key_value_pairs)
91    }
92
93    /// Delete the given keys and their associated values.
94    ///
95    /// # Arguments
96    ///
97    /// * `keys` - A sequence of keys to delete.
98    fn mdelete(&self, keys: &[K]);
99
100    /// Async delete the given keys and their associated values.
101    ///
102    /// # Arguments
103    ///
104    /// * `keys` - A sequence of keys to delete.
105    async fn amdelete(&self, keys: &[K])
106    where
107        K: 'static,
108        V: 'static,
109    {
110        self.mdelete(keys)
111    }
112
113    /// Get an iterator over keys that match the given prefix.
114    ///
115    /// # Arguments
116    ///
117    /// * `prefix` - The prefix to match.
118    ///
119    /// # Returns
120    ///
121    /// A vector of keys that match the given prefix.
122    fn yield_keys(&self, prefix: Option<&str>) -> Vec<String>;
123
124    /// Async get an iterator over keys that match the given prefix.
125    ///
126    /// # Arguments
127    ///
128    /// * `prefix` - The prefix to match.
129    ///
130    /// # Returns
131    ///
132    /// A vector of keys that match the given prefix.
133    async fn ayield_keys(&self, prefix: Option<&str>) -> Vec<String>
134    where
135        K: 'static,
136        V: 'static,
137    {
138        self.yield_keys(prefix)
139    }
140}
141
142/// Type alias for a store with string keys and byte values.
143pub type ByteStore = dyn BaseStore<String, Vec<u8>>;
144
/// In-memory implementation of `BaseStore` backed by a `HashMap` with
/// `String` keys.
///
/// The map is held in an `Arc<RwLock<HashMap<String, V>>>` so that it can be
/// accessed and mutated through a shared reference to the store.
pub struct InMemoryBaseStore<V: Clone + Send + Sync> {
    /// Shared, lock-protected map that holds every entry.
    store: Arc<RwLock<HashMap<String, V>>>,
}
155
156impl<V> Default for InMemoryBaseStore<V>
157where
158    V: Clone + Send + Sync,
159{
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165impl<V> InMemoryBaseStore<V>
166where
167    V: Clone + Send + Sync,
168{
169    /// Initialize an empty store.
170    pub fn new() -> Self {
171        Self {
172            store: Arc::new(RwLock::new(HashMap::new())),
173        }
174    }
175
176    /// Get a reference to the internal store for direct access.
177    pub fn store(&self) -> &Arc<RwLock<HashMap<String, V>>> {
178        &self.store
179    }
180}
181
182#[async_trait]
183impl<V> BaseStore<String, V> for InMemoryBaseStore<V>
184where
185    V: Clone + Send + Sync + 'static,
186{
187    fn mget(&self, keys: &[String]) -> Vec<Option<V>> {
188        let store = self.store.blocking_read();
189        keys.iter().map(|key| store.get(key).cloned()).collect()
190    }
191
192    async fn amget(&self, keys: &[String]) -> Vec<Option<V>> {
193        let store = self.store.read().await;
194        keys.iter().map(|key| store.get(key).cloned()).collect()
195    }
196
197    fn mset(&self, key_value_pairs: &[(String, V)]) {
198        let mut store = self.store.blocking_write();
199        for (key, value) in key_value_pairs {
200            store.insert(key.clone(), value.clone());
201        }
202    }
203
204    async fn amset(&self, key_value_pairs: &[(String, V)]) {
205        let mut store = self.store.write().await;
206        for (key, value) in key_value_pairs {
207            store.insert(key.clone(), value.clone());
208        }
209    }
210
211    fn mdelete(&self, keys: &[String]) {
212        let mut store = self.store.blocking_write();
213        for key in keys {
214            store.remove(key);
215        }
216    }
217
218    async fn amdelete(&self, keys: &[String]) {
219        let mut store = self.store.write().await;
220        for key in keys {
221            store.remove(key);
222        }
223    }
224
225    fn yield_keys(&self, prefix: Option<&str>) -> Vec<String> {
226        let store = self.store.blocking_read();
227        match prefix {
228            None => store.keys().cloned().collect(),
229            Some(prefix) => store
230                .keys()
231                .filter(|key| key.starts_with(prefix))
232                .cloned()
233                .collect(),
234        }
235    }
236
237    async fn ayield_keys(&self, prefix: Option<&str>) -> Vec<String> {
238        let store = self.store.read().await;
239        match prefix {
240            None => store.keys().cloned().collect(),
241            Some(prefix) => store
242                .keys()
243                .filter(|key| key.starts_with(prefix))
244                .cloned()
245                .collect(),
246        }
247    }
248}
249
250/// In-memory store for any type of data.
251///
252/// This is a type alias for `InMemoryBaseStore<serde_json::Value>` which
253/// can store any JSON-serializable value.
254///
255/// # Examples
256///
257/// ```
258/// use agent_chain_core::stores::InMemoryStore;
259/// use agent_chain_core::stores::BaseStore;
260/// use serde_json::json;
261///
262/// let store = InMemoryStore::new();
263/// store.mset(&[
264///     ("key1".to_string(), json!("value1")),
265///     ("key2".to_string(), json!("value2")),
266/// ]);
267///
268/// let values = store.mget(&["key1".to_string(), "key2".to_string()]);
269/// assert_eq!(values[0], Some(json!("value1")));
270/// assert_eq!(values[1], Some(json!("value2")));
271///
272/// store.mdelete(&["key1".to_string()]);
273/// let keys: Vec<String> = store.yield_keys(None);
274/// assert_eq!(keys, vec!["key2".to_string()]);
275/// ```
276pub type InMemoryStore = InMemoryBaseStore<serde_json::Value>;
277
278/// In-memory store for bytes.
279///
280/// This is a type alias for `InMemoryBaseStore<Vec<u8>>` which stores
281/// byte vectors.
282///
283/// # Examples
284///
285/// ```
286/// use agent_chain_core::stores::InMemoryByteStore;
287/// use agent_chain_core::stores::BaseStore;
288///
289/// let store = InMemoryByteStore::new();
290/// store.mset(&[
291///     ("key1".to_string(), b"value1".to_vec()),
292///     ("key2".to_string(), b"value2".to_vec()),
293/// ]);
294///
295/// let values = store.mget(&["key1".to_string(), "key2".to_string()]);
296/// assert_eq!(values[0], Some(b"value1".to_vec()));
297/// assert_eq!(values[1], Some(b"value2".to_vec()));
298///
299/// store.mdelete(&["key1".to_string()]);
300/// let keys: Vec<String> = store.yield_keys(None);
301/// assert_eq!(keys, vec!["key2".to_string()]);
302/// ```
303pub type InMemoryByteStore = InMemoryBaseStore<Vec<u8>>;
304
/// Error raised when a key is invalid; e.g., uses incorrect characters.
///
/// Two values compare equal when both the key and the message are equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidKeyException {
    /// The invalid key that caused the error.
    pub key: String,
    /// A message describing why the key is invalid.
    pub message: String,
}
313
314impl InvalidKeyException {
315    /// Create a new InvalidKeyException.
316    pub fn new(key: impl Into<String>, message: impl Into<String>) -> Self {
317        Self {
318            key: key.into(),
319            message: message.into(),
320        }
321    }
322}
323
324impl std::fmt::Display for InvalidKeyException {
325    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
326        write!(f, "Invalid key '{}': {}", self.key, self.message)
327    }
328}
329
330impl std::error::Error for InvalidKeyException {}
331
332impl From<InvalidKeyException> for Error {
333    fn from(e: InvalidKeyException) -> Self {
334        Error::Other(e.to_string())
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Turn string literals into the owned `String` keys the store API takes.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Values written with `mset` come back from `mget` in request order,
    /// and an unknown key yields `None`.
    #[test]
    fn test_in_memory_store_mget_mset() {
        let store = InMemoryStore::new();
        let entries = [
            ("key1".to_string(), json!("value1")),
            ("key2".to_string(), json!(42)),
        ];
        store.mset(&entries);

        let fetched = store.mget(&owned(&["key1", "key2", "key3"]));
        assert_eq!(
            fetched,
            vec![Some(json!("value1")), Some(json!(42)), None]
        );
    }

    /// `mdelete` removes only the listed keys.
    #[test]
    fn test_in_memory_store_mdelete() {
        let store = InMemoryStore::new();
        let entries = [
            ("key1".to_string(), json!("value1")),
            ("key2".to_string(), json!("value2")),
        ];
        store.mset(&entries);

        store.mdelete(&owned(&["key1"]));

        let fetched = store.mget(&owned(&["key1", "key2"]));
        assert_eq!(fetched, vec![None, Some(json!("value2"))]);
    }

    /// `yield_keys` returns every key for `None` and only matching keys for
    /// a prefix.
    #[test]
    fn test_in_memory_store_yield_keys() {
        let store = InMemoryStore::new();
        let entries = [
            ("prefix_a".to_string(), json!("a")),
            ("prefix_b".to_string(), json!("b")),
            ("other".to_string(), json!("c")),
        ];
        store.mset(&entries);

        assert_eq!(store.yield_keys(None).len(), 3);

        // Map iteration order is unspecified, so sort before comparing.
        let mut matching = store.yield_keys(Some("prefix_"));
        matching.sort();
        assert_eq!(matching, owned(&["prefix_a", "prefix_b"]));
    }

    /// The byte-valued alias stores and returns byte vectors.
    #[test]
    fn test_in_memory_byte_store() {
        let store = InMemoryByteStore::new();
        let entries = [
            ("key1".to_string(), b"bytes1".to_vec()),
            ("key2".to_string(), b"bytes2".to_vec()),
        ];
        store.mset(&entries);

        let fetched = store.mget(&owned(&["key1", "key2"]));
        assert_eq!(
            fetched,
            vec![Some(b"bytes1".to_vec()), Some(b"bytes2".to_vec())]
        );
    }

    /// The async methods set, get, delete, and list keys.
    #[tokio::test]
    async fn test_in_memory_store_async() {
        let store = InMemoryStore::new();
        let entries = [
            ("key1".to_string(), json!("async_value1")),
            ("key2".to_string(), json!("async_value2")),
        ];
        store.amset(&entries).await;

        let fetched = store.amget(&owned(&["key1", "key2"])).await;
        assert_eq!(
            fetched,
            vec![Some(json!("async_value1")), Some(json!("async_value2"))]
        );

        store.amdelete(&owned(&["key1"])).await;

        let remaining = store.ayield_keys(None).await;
        assert_eq!(remaining, owned(&["key2"]));
    }

    /// The `Display` output includes both the key and the message.
    #[test]
    fn test_invalid_key_exception() {
        let rendered =
            InvalidKeyException::new("bad/key", "keys cannot contain slashes").to_string();
        assert_eq!(
            rendered,
            "Invalid key 'bad/key': keys cannot contain slashes"
        );
    }
}