dynamo_runtime/storage/
key_value_store.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Interface to a traditional key-value store such as etcd.
5//! "key_value_store" spelt out because in AI land "KV" means something else.
6
7use std::collections::HashMap;
8use std::fmt;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::CancellationToken;
14use crate::slug::Slug;
15use async_trait::async_trait;
16use futures::StreamExt;
17use serde::{Deserialize, Serialize};
18
19mod mem;
20pub use mem::MemoryStore;
21mod nats;
22pub use nats::NATSStore;
23mod etcd;
24pub use etcd::EtcdStore;
25
/// How long `KeyValueStoreManager::watch` waits for room in its output channel
/// before giving up on delivering a single event.
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
27
/// A key that is safe to use directly in the KV store.
///
/// Build one with [`Key::new`] to have the input slugified, or with
/// [`Key::from_raw`] when the text is already known to be store safe.
///
/// Equality and hashing are those of the wrapped string, so a `Key` can be used
/// as a `HashMap` / `HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);
31
32impl Key {
33    pub fn new(s: &str) -> Key {
34        Key(Slug::slugify(s).to_string())
35    }
36
37    /// Create a Key without changing the string, it is assumed already KV store safe.
38    pub fn from_raw(s: String) -> Key {
39        Key(s)
40    }
41}
42
43impl From<&str> for Key {
44    fn from(s: &str) -> Key {
45        Key::new(s)
46    }
47}
48
49impl fmt::Display for Key {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        write!(f, "{}", self.0)
52    }
53}
54
55impl AsRef<str> for Key {
56    fn as_ref(&self) -> &str {
57        &self.0
58    }
59}
60
61impl From<&Key> for String {
62    fn from(k: &Key) -> String {
63        k.0.clone()
64    }
65}
66
67#[derive(Debug, Clone, PartialEq)]
68pub struct KeyValue {
69    key: String,
70    value: bytes::Bytes,
71}
72
73impl KeyValue {
74    pub fn new(key: String, value: bytes::Bytes) -> Self {
75        KeyValue { key, value }
76    }
77
78    pub fn key(&self) -> String {
79        self.key.clone()
80    }
81
82    pub fn key_str(&self) -> &str {
83        &self.key
84    }
85
86    pub fn value(&self) -> &[u8] {
87        &self.value
88    }
89
90    pub fn value_str(&self) -> anyhow::Result<&str> {
91        std::str::from_utf8(self.value()).map_err(From::from)
92    }
93}
94
95#[derive(Debug, Clone, PartialEq)]
96pub enum WatchEvent {
97    Put(KeyValue),
98    Delete(KeyValue),
99}
100
101#[async_trait]
102pub trait KeyValueStore: Send + Sync {
103    type Bucket: KeyValueBucket + Send + Sync + 'static;
104
105    async fn get_or_create_bucket(
106        &self,
107        bucket_name: &str,
108        // auto-delete items older than this
109        ttl: Option<Duration>,
110    ) -> Result<Self::Bucket, StoreError>;
111
112    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;
113
114    fn connection_id(&self) -> u64;
115}
116
117#[allow(clippy::large_enum_variant)]
118pub enum KeyValueStoreEnum {
119    Memory(MemoryStore),
120    Nats(NATSStore),
121    Etcd(EtcdStore),
122}
123
124impl KeyValueStoreEnum {
125    async fn get_or_create_bucket(
126        &self,
127        bucket_name: &str,
128        // auto-delete items older than this
129        ttl: Option<Duration>,
130    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
131        use KeyValueStoreEnum::*;
132        Ok(match self {
133            Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
134            Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
135            Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
136        })
137    }
138
139    async fn get_bucket(
140        &self,
141        bucket_name: &str,
142    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
143        use KeyValueStoreEnum::*;
144        let maybe_bucket: Option<Box<dyn KeyValueBucket>> = match self {
145            Memory(x) => x
146                .get_bucket(bucket_name)
147                .await?
148                .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
149            Nats(x) => x
150                .get_bucket(bucket_name)
151                .await?
152                .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
153            Etcd(x) => x
154                .get_bucket(bucket_name)
155                .await?
156                .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
157        };
158        Ok(maybe_bucket)
159    }
160
161    fn connection_id(&self) -> u64 {
162        use KeyValueStoreEnum::*;
163        match self {
164            Memory(x) => x.connection_id(),
165            Etcd(x) => x.connection_id(),
166            Nats(x) => x.connection_id(),
167        }
168    }
169}
170
171#[derive(Clone)]
172pub struct KeyValueStoreManager(Arc<KeyValueStoreEnum>);
173
174impl Default for KeyValueStoreManager {
175    fn default() -> Self {
176        KeyValueStoreManager::memory()
177    }
178}
179
180impl KeyValueStoreManager {
181    /// In-memory KeyValueStoreManager for testing
182    pub fn memory() -> Self {
183        Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
184    }
185
186    pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
187        Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
188    }
189
190    fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager {
191        KeyValueStoreManager(Arc::new(s))
192    }
193
194    pub async fn get_or_create_bucket(
195        &self,
196        bucket_name: &str,
197        // auto-delete items older than this
198        ttl: Option<Duration>,
199    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
200        self.0.get_or_create_bucket(bucket_name, ttl).await
201    }
202
203    pub async fn get_bucket(
204        &self,
205        bucket_name: &str,
206    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
207        self.0.get_bucket(bucket_name).await
208    }
209
210    pub fn connection_id(&self) -> u64 {
211        self.0.connection_id()
212    }
213
214    pub async fn load<T: for<'a> Deserialize<'a>>(
215        &self,
216        bucket: &str,
217        key: &Key,
218    ) -> Result<Option<T>, StoreError> {
219        let Some(bucket) = self.0.get_bucket(bucket).await? else {
220            // No bucket means no cards
221            return Ok(None);
222        };
223        Ok(match bucket.get(key).await? {
224            Some(card_bytes) => {
225                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
226                Some(card)
227            }
228            None => None,
229        })
230    }
231
232    /// Returns a receiver that will receive all the existing keys, and
233    /// then block and receive new keys as they are created.
234    /// Starts a task that runs forever, watches the store.
235    pub fn watch(
236        self: Arc<Self>,
237        bucket_name: &str,
238        bucket_ttl: Option<Duration>,
239        cancel_token: CancellationToken,
240    ) -> (
241        tokio::task::JoinHandle<Result<(), StoreError>>,
242        tokio::sync::mpsc::Receiver<WatchEvent>,
243    ) {
244        let bucket_name = bucket_name.to_string();
245        let (tx, rx) = tokio::sync::mpsc::channel(128);
246        let watch_task = tokio::spawn(async move {
247            // Start listening for changes but don't poll this yet
248            let bucket = self
249                .0
250                .get_or_create_bucket(&bucket_name, bucket_ttl)
251                .await?;
252            let mut stream = bucket.watch().await?;
253
254            // Send all the existing keys
255            for (key, bytes) in bucket.entries().await? {
256                if let Err(err) = tx
257                    .send_timeout(
258                        WatchEvent::Put(KeyValue::new(key, bytes)),
259                        WATCH_SEND_TIMEOUT,
260                    )
261                    .await
262                {
263                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
264                }
265            }
266
267            // Now block waiting for new entries
268            loop {
269                let event = tokio::select! {
270                    _ = cancel_token.cancelled() => break,
271                    result = stream.next() => match result {
272                        Some(event) => event,
273                        None => break,
274                    }
275                };
276                if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
277                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
278                }
279            }
280
281            Ok::<(), StoreError>(())
282        });
283        (watch_task, rx)
284    }
285
286    pub async fn publish<T: Serialize + Versioned + Send + Sync>(
287        &self,
288        bucket_name: &str,
289        bucket_ttl: Option<Duration>,
290        key: &Key,
291        obj: &mut T,
292    ) -> anyhow::Result<StoreOutcome> {
293        let obj_json = serde_json::to_string(obj)?;
294        let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
295
296        let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
297
298        match outcome {
299            StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
300                obj.set_revision(revision);
301            }
302        }
303        Ok(outcome)
304    }
305}
306
307/// An online storage for key-value config values.
308/// Usually backed by `nats-server`.
309#[async_trait]
310pub trait KeyValueBucket: Send + Sync {
311    /// A bucket is a collection of key/value pairs.
312    /// Insert a value into a bucket, if it doesn't exist already
313    async fn insert(
314        &self,
315        key: &Key,
316        value: &str,
317        revision: u64,
318    ) -> Result<StoreOutcome, StoreError>;
319
320    /// Fetch an item from the key-value storage
321    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
322
323    /// Delete an item from the bucket
324    async fn delete(&self, key: &Key) -> Result<(), StoreError>;
325
326    /// A stream of items inserted into the bucket.
327    /// Every time the stream is polled it will either return a newly created entry, or block until
328    /// such time.
329    async fn watch(
330        &self,
331    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;
332
333    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
334}
335
/// Result of a successful insert into a bucket.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The operation succeeded and created a new entry with this revision.
    /// Note that "create" also means update, because each new revision is a "create".
    Created(u64),
    /// The operation did not do anything, the value was already present, with this revision.
    Exists(u64),
}

impl fmt::Display for StoreOutcome {
    /// Renders as `Created at <revision>` or `Exists at <revision>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (verb, revision) = match self {
            StoreOutcome::Created(revision) => ("Created", revision),
            StoreOutcome::Exists(revision) => ("Exists", revision),
        };
        write!(f, "{verb} at {revision}")
    }
}
352
353#[derive(thiserror::Error, Debug)]
354pub enum StoreError {
355    #[error("Could not find bucket '{0}'")]
356    MissingBucket(String),
357
358    #[error("Could not find key '{0}'")]
359    MissingKey(String),
360
361    #[error("Internal storage error: '{0}'")]
362    ProviderError(String),
363
364    #[error("Internal NATS error: {0}")]
365    NATSError(String),
366
367    #[error("Internal etcd error: {0}")]
368    EtcdError(String),
369
370    #[error("Key Value Error: {0} for bucket '{1}'")]
371    KeyValueError(String, String),
372
373    #[error("Error decoding bytes: {0}")]
374    JSONDecodeError(#[from] serde_json::error::Error),
375
376    #[error("Race condition, retry the call")]
377    Retry,
378}
379
/// Implemented by objects that carry a revision number which can be read and
/// replaced.
/// NATS uses this to ensure atomic updates.
pub trait Versioned {
    /// The revision currently recorded on the object.
    fn revision(&self) -> u64;
    /// Records `r` as the object's revision.
    fn set_revision(&mut self, r: u64);
}
386
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Convert the value returned by `watch()` into a broadcast stream that multiple
    /// clients can listen to.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task that forwards every item of `stream` to a broadcast
        /// channel holding at most `max_size` items.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // Sending fails when there are no subscribers; that is fine here.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// A new receiver that sees every event broadcast from now on.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                format!("test{i}"),
                bytes::Bytes::from(format!("value{i}").into_bytes()),
            ));
            expected.push(item);
        }

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // Put in before starting the watch-all
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            got_first_tx.send(()).unwrap();

            // Put in after
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
        // fetched before test2 is inserted, otherwise they can come out in any order, and we
        // wouldn't be testing the watch behavior.
        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Repeat a key and revision. Ignored.
        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Increment revision
        let res = bucket.insert(&"test2".into(), "value2", 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // ingress exits once it has received all values.
        // First `?`: the task panicked (failed assertion). Second `?`: the task
        // returned a StoreError. Either must fail the test.
        ingress.await??;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // Leaked so the watch stream, which borrows the bucket, can be 'static.
        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let item = WatchEvent::Put(KeyValue::new(
            "test1".to_string(),
            bytes::Bytes::from(b"GK".as_slice()),
        ));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        bucket.insert(&"test1".into(), "GK", 1).await?;

        // A failed assertion inside a spawned task only shows up as a JoinError,
        // so the join results must be checked for the test to be able to fail.
        let (first, second) = futures::join!(handle1, handle2);
        first?;
        second?;
        Ok(())
    }
}
529        Ok(())
530    }
531}