// dynamo_runtime/storage/key_value_store.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Interface to a traditional key-value store such as etcd.
//! "key_value_store" spelt out because in AI land "KV" means something else.
6
7use std::collections::HashMap;
8use std::fmt;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::CancellationToken;
14use crate::slug::Slug;
15use async_trait::async_trait;
16use futures::StreamExt;
17use serde::{Deserialize, Serialize};
18
19mod mem;
20pub use mem::MemoryStore;
21mod nats;
22pub use nats::NATSStore;
23mod etcd;
24pub use etcd::EtcdStore;
25
/// A key that is safe to use directly in the KV store.
/// Newtype over the already-sanitized string form; build via [`Key::new`]
/// (slugifies) or [`Key::from_raw`] (trusts the caller).
#[derive(Debug, Clone, PartialEq)]
pub struct Key(String);
29
30impl Key {
31    pub fn new(s: &str) -> Key {
32        Key(Slug::slugify(s).to_string())
33    }
34
35    /// Create a Key without changing the string, it is assumed already KV store safe.
36    pub fn from_raw(s: String) -> Key {
37        Key(s)
38    }
39}
40
41impl From<&str> for Key {
42    fn from(s: &str) -> Key {
43        Key::new(s)
44    }
45}
46
47impl fmt::Display for Key {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        write!(f, "{}", self.0)
50    }
51}
52
53impl AsRef<str> for Key {
54    fn as_ref(&self) -> &str {
55        &self.0
56    }
57}
58
59impl From<&Key> for String {
60    fn from(k: &Key) -> String {
61        k.0.clone()
62    }
63}
64
/// Connection to a key-value store backend (memory, NATS, or etcd —
/// see the `mem`, `nats` and `etcd` submodules).
#[async_trait]
pub trait KeyValueStore: Send + Sync {
    /// Fetch the named bucket, creating it if it does not already exist.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
    ) -> Result<Box<dyn KeyValueBucket>, StoreError>;

    /// Fetch the named bucket, returning `Ok(None)` if it does not exist.
    async fn get_bucket(
        &self,
        bucket_name: &str,
    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError>;

    /// Numeric identifier for the underlying backend connection.
    fn connection_id(&self) -> u64;
}
81
82pub struct KeyValueStoreManager(Box<dyn KeyValueStore>);
83
84impl KeyValueStoreManager {
85    pub fn new(s: Box<dyn KeyValueStore>) -> KeyValueStoreManager {
86        KeyValueStoreManager(s)
87    }
88
89    pub async fn load<T: for<'a> Deserialize<'a>>(
90        &self,
91        bucket: &str,
92        key: &Key,
93    ) -> Result<Option<T>, StoreError> {
94        let Some(bucket) = self.0.get_bucket(bucket).await? else {
95            // No bucket means no cards
96            return Ok(None);
97        };
98        match bucket.get(key).await {
99            Ok(Some(card_bytes)) => {
100                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
101                Ok(Some(card))
102            }
103            Ok(None) => Ok(None),
104            Err(err) => {
105                // TODO look at what errors NATS can give us and make more specific wrappers
106                Err(StoreError::NATSError(err.to_string()))
107            }
108        }
109    }
110
111    /// Returns a receiver that will receive all the existing keys, and
112    /// then block and receive new keys as they are created.
113    /// Starts a task that runs forever, watches the store.
114    pub fn watch<T: for<'a> Deserialize<'a> + Send + 'static>(
115        self: Arc<Self>,
116        bucket_name: &str,
117        bucket_ttl: Option<Duration>,
118    ) -> (
119        tokio::task::JoinHandle<Result<(), StoreError>>,
120        tokio::sync::mpsc::UnboundedReceiver<T>,
121    ) {
122        let bucket_name = bucket_name.to_string();
123        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
124        let watch_task = tokio::spawn(async move {
125            // Start listening for changes but don't poll this yet
126            let bucket = self
127                .0
128                .get_or_create_bucket(&bucket_name, bucket_ttl)
129                .await?;
130            let mut stream = bucket.watch().await?;
131
132            // Send all the existing keys
133            for (_, card_bytes) in bucket.entries().await? {
134                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
135                let _ = tx.send(card);
136            }
137
138            // Now block waiting for new entries
139            while let Some(card_bytes) = stream.next().await {
140                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
141                let _ = tx.send(card);
142            }
143
144            Ok::<(), StoreError>(())
145        });
146        (watch_task, rx)
147    }
148
149    pub async fn publish<T: Serialize + Versioned + Send + Sync>(
150        &self,
151        bucket_name: &str,
152        bucket_ttl: Option<Duration>,
153        key: &Key,
154        obj: &mut T,
155    ) -> anyhow::Result<StoreOutcome> {
156        let obj_json = serde_json::to_string(obj)?;
157        let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
158
159        let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
160
161        match outcome {
162            StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
163                obj.set_revision(revision);
164            }
165        }
166        Ok(outcome)
167    }
168}
169
/// An online storage for key-value config values.
/// A bucket is a collection of key/value pairs.
/// Usually backed by `nats-server`.
#[async_trait]
pub trait KeyValueBucket: Send {
    /// Insert a value into the bucket, if it doesn't exist already.
    /// `revision` lets backends do atomic updates (see [`Versioned`]);
    /// the returned [`StoreOutcome`] says whether a new revision was created.
    async fn insert(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError>;

    /// Fetch an item from the key-value storage.
    /// `Ok(None)` means the key is absent (treated as "no value" by callers).
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;

    /// Delete an item from the bucket
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;

    /// A stream of items inserted into the bucket.
    /// Every time the stream is polled it will either return a newly created entry, or block until
    /// such time.
    // NOTE: `'life0` is the lifetime #[async_trait] generates for `&self`.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>;

    /// Snapshot of all current entries, keyed by their store key.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
}
198
/// Result of a conditional insert into a [`KeyValueBucket`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The operation succeeded and created a new entry with this revision.
    /// Note that "create" also means update, because each new revision is a "create".
    Created(u64),
    /// The operation did not do anything, the value was already present, with this revision.
    Exists(u64),
}
207impl fmt::Display for StoreOutcome {
208    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
209        match self {
210            StoreOutcome::Created(revision) => write!(f, "Created at {revision}"),
211            StoreOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
212        }
213    }
214}
215
216#[derive(thiserror::Error, Debug)]
217pub enum StoreError {
218    #[error("Could not find bucket '{0}'")]
219    MissingBucket(String),
220
221    #[error("Could not find key '{0}'")]
222    MissingKey(String),
223
224    #[error("Internal storage error: '{0}'")]
225    ProviderError(String),
226
227    #[error("Internal NATS error: {0}")]
228    NATSError(String),
229
230    #[error("Internal etcd error: {0}")]
231    EtcdError(String),
232
233    #[error("Key Value Error: {0} for bucket '{1}")]
234    KeyValueError(String, String),
235
236    #[error("Error decoding bytes: {0}")]
237    JSONDecodeError(#[from] serde_json::error::Error),
238
239    #[error("Race condition, retry the call")]
240    Retry,
241}
242
/// A trait allowing to get/set a revision on an object.
/// NATS uses this to ensure atomic updates.
pub trait Versioned {
    /// The revision last assigned by the store for this object.
    fn revision(&self) -> u64;
    /// Record the revision the store assigned after an insert.
    fn set_revision(&mut self, r: u64);
}
249
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Convert the value returned by `watch()` into a broadcast stream that multiple
    /// clients can listen to.
    #[allow(dead_code)]
    pub struct TappableStream {
        // Broadcast side; each call to subscribe() yields an independent receiver.
        tx: tokio::sync::broadcast::Sender<bytes::Bytes>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawn a forwarding task that copies every item from `stream` to all
        /// current subscribers. `max_size` is the broadcast channel capacity.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = bytes::Bytes> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // send() errors only when there are no subscribers; ignore.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// A new receiver that sees every item forwarded from now on.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<bytes::Bytes> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    /// End-to-end check of MemoryStore insert/watch semantics, including
    /// revision-based create-vs-exists outcomes.
    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // Put in before starting the watch-all
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value1".as_bytes());

            // Signal the main task that value1 was observed before it inserts more.
            got_first_tx.send(()).unwrap();

            // Put in after
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value2".as_bytes());
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value3".as_bytes());

            Ok::<_, StoreError>(())
        });

        // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
        // fetched before test2 is inserted, otherwise they can come out in any order, and we
        // wouldn't be testing the watch behavior.
        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Repeat a key and revision. Ignored.
        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Increment revision
        let res = bucket.insert(&"test2".into(), "value2", 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // ingress exits once it has received all values
        let _ = ingress.await?;

        Ok(())
    }

    /// Verifies TappableStream fans a single watch stream out to
    /// multiple broadcast subscribers.
    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // Leak store and bucket to get 'static references the spawned
        // forwarding task can hold for the rest of the test process.
        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        // Both subscribers must independently observe the "GK" update below.
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });

        bucket.insert(&"test1".into(), "GK", 1).await?;

        let _ = futures::join!(handle1, handle2);
        Ok(())
    }
}
379}