// dynamo_runtime/storage/key_value_store.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Interface to a traditional key-value store such as etcd.
5//! "key_value_store" spelt out because in AI land "KV" means something else.
6
7use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::time::Duration;
11use std::{collections::HashMap, path::PathBuf};
12use std::{env, fmt};
13
14use crate::CancellationToken;
15use crate::slug::Slug;
16use crate::transports::etcd as etcd_transport;
17use async_trait::async_trait;
18use futures::StreamExt;
19use serde::{Deserialize, Serialize};
20
21mod mem;
22pub use mem::MemoryStore;
23mod nats;
24pub use nats::NATSStore;
25mod etcd;
26pub use etcd::EtcdStore;
27mod file;
28pub use file::FileStore;
29
/// How long `KeyValueStoreManager::watch` waits for the channel receiver to
/// accept an event before giving up on that event (the failure is logged and
/// the watch loop continues).
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
31
/// A key that is safe to use directly in the KV store.
///
/// Newtype over `String` so store-safe keys are distinguished from arbitrary
/// strings at the type level. `Eq` and `Hash` are derived (sound, since the
/// wrapped `String`'s equality is a total equivalence) so a `Key` can be used
/// directly in hash maps and sets.
///
/// TODO: Need to re-think this. etcd uses slash separators, so we often use from_raw
/// to avoid the slug. But other impl's, particularly file, need a real slug.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);
38
39impl Key {
40    pub fn new(s: &str) -> Key {
41        Key(Slug::slugify(s).to_string())
42    }
43
44    /// Create a Key without changing the string, it is assumed already KV store safe.
45    pub fn from_raw(s: String) -> Key {
46        Key(s)
47    }
48}
49
50impl From<&str> for Key {
51    fn from(s: &str) -> Key {
52        Key::new(s)
53    }
54}
55
56impl fmt::Display for Key {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        write!(f, "{}", self.0)
59    }
60}
61
62impl AsRef<str> for Key {
63    fn as_ref(&self) -> &str {
64        &self.0
65    }
66}
67
68impl From<&Key> for String {
69    fn from(k: &Key) -> String {
70        k.0.clone()
71    }
72}
73
/// A key paired with the raw bytes stored under it, as reported by a watch
/// or an `entries()` listing.
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue {
    // The entry's key as a plain string (not re-slugified here).
    key: String,
    // Raw stored bytes; callers often decode as UTF-8/JSON but the type
    // makes no assumption.
    value: bytes::Bytes,
}
79
80impl KeyValue {
81    pub fn new(key: String, value: bytes::Bytes) -> Self {
82        KeyValue { key, value }
83    }
84
85    pub fn key(&self) -> String {
86        self.key.clone()
87    }
88
89    pub fn key_str(&self) -> &str {
90        &self.key
91    }
92
93    pub fn value(&self) -> &[u8] {
94        &self.value
95    }
96
97    pub fn value_str(&self) -> anyhow::Result<&str> {
98        std::str::from_utf8(self.value()).map_err(From::from)
99    }
100}
101
/// A single change observed while watching a bucket.
#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
    /// A key was created or updated with the given value.
    Put(KeyValue),
    /// A key was removed.
    Delete(Key),
}
107
/// A connection to a key-value store backend.
///
/// A store hands out buckets; all reads, writes and watches go through a
/// bucket.
#[async_trait]
pub trait KeyValueStore: Send + Sync {
    /// The backend-specific bucket type.
    type Bucket: KeyValueBucket + Send + Sync + 'static;

    /// Fetch the named bucket, creating it if it does not exist yet.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
    ) -> Result<Self::Bucket, StoreError>;

    /// Fetch the named bucket, or `None` if it does not exist.
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;

    /// Backend-specific identifier for this connection.
    fn connection_id(&self) -> u64;

    /// Release any resources held by the connection.
    fn shutdown(&self);
}
125
/// Which key-value store backend to use, plus its backend-specific
/// configuration. Parsed from a string via `FromStr` ("etcd" / "file" / "mem").
#[derive(Clone, Debug, Default)]
pub enum KeyValueStoreSelect {
    // Box it because it is significantly bigger than the other variants
    Etcd(Box<etcd_transport::ClientOptions>),
    /// File-backed store rooted at this directory.
    File(PathBuf),
    #[default]
    Memory,
    // Nats not listed because likely we want to remove that impl. It is not currently used and not well tested.
}
135
136impl fmt::Display for KeyValueStoreSelect {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        match self {
139            KeyValueStoreSelect::Etcd(opts) => {
140                let urls = opts.etcd_url.join(",");
141                write!(f, "Etcd({urls})")
142            }
143            KeyValueStoreSelect::File(path) => write!(f, "File({})", path.display()),
144            KeyValueStoreSelect::Memory => write!(f, "Memory"),
145        }
146    }
147}
148
149impl FromStr for KeyValueStoreSelect {
150    type Err = anyhow::Error;
151
152    fn from_str(s: &str) -> anyhow::Result<KeyValueStoreSelect> {
153        match s {
154            "etcd" => Ok(Self::Etcd(Box::default())),
155            "file" => {
156                let root = env::var("DYN_FILE_KV")
157                    .map(PathBuf::from)
158                    .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv"));
159                Ok(Self::File(root))
160            }
161            "mem" => Ok(Self::Memory),
162            x => anyhow::bail!("Unknown key-value store type '{x}'"),
163        }
164    }
165}
166
167impl TryFrom<String> for KeyValueStoreSelect {
168    type Error = anyhow::Error;
169
170    fn try_from(s: String) -> anyhow::Result<KeyValueStoreSelect> {
171        s.parse()
172    }
173}
174
/// Closed set of store backends that [`KeyValueStoreManager`] dispatches over.
// The size difference between variants is accepted here because the enum is
// constructed once and then shared behind an Arc (see KeyValueStoreManager).
#[allow(clippy::large_enum_variant)]
pub enum KeyValueStoreEnum {
    Memory(MemoryStore),
    Nats(NATSStore),
    Etcd(EtcdStore),
    File(FileStore),
}
182
183impl KeyValueStoreEnum {
184    async fn get_or_create_bucket(
185        &self,
186        bucket_name: &str,
187        // auto-delete items older than this
188        ttl: Option<Duration>,
189    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
190        use KeyValueStoreEnum::*;
191        Ok(match self {
192            Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
193            Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
194            Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
195            File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
196        })
197    }
198
199    async fn get_bucket(
200        &self,
201        bucket_name: &str,
202    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
203        use KeyValueStoreEnum::*;
204        let maybe_bucket: Option<Box<dyn KeyValueBucket>> = match self {
205            Memory(x) => x
206                .get_bucket(bucket_name)
207                .await?
208                .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
209            Nats(x) => x
210                .get_bucket(bucket_name)
211                .await?
212                .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
213            Etcd(x) => x
214                .get_bucket(bucket_name)
215                .await?
216                .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
217            File(x) => x
218                .get_bucket(bucket_name)
219                .await?
220                .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
221        };
222        Ok(maybe_bucket)
223    }
224
225    fn connection_id(&self) -> u64 {
226        use KeyValueStoreEnum::*;
227        match self {
228            Memory(x) => x.connection_id(),
229            Etcd(x) => x.connection_id(),
230            Nats(x) => x.connection_id(),
231            File(x) => x.connection_id(),
232        }
233    }
234
235    fn shutdown(&self) {
236        use KeyValueStoreEnum::*;
237        match self {
238            Memory(x) => x.shutdown(),
239            Etcd(x) => x.shutdown(),
240            Nats(x) => x.shutdown(),
241            File(x) => x.shutdown(),
242        }
243    }
244}
245
/// Cheaply cloneable handle over the selected key-value store backend.
/// Cloning shares the same underlying connection (Arc).
#[derive(Clone)]
pub struct KeyValueStoreManager(pub Arc<KeyValueStoreEnum>);
248
249impl Default for KeyValueStoreManager {
250    fn default() -> Self {
251        KeyValueStoreManager::memory()
252    }
253}
254
impl KeyValueStoreManager {
    /// In-memory KeyValueStoreManager for testing
    pub fn memory() -> Self {
        Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
    }

    /// Manager backed by etcd, using an already-connected client.
    pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
        Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
    }

    /// Manager backed by the local filesystem, rooted at `root`.
    pub fn file<P: Into<PathBuf>>(root: P) -> Self {
        Self::new(KeyValueStoreEnum::File(FileStore::new(root)))
    }

    // Wrap the chosen backend in the shared handle.
    fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager {
        KeyValueStoreManager(Arc::new(s))
    }

    /// Fetch the named bucket, creating it if it does not exist yet.
    pub async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        // auto-delete items older than this
        ttl: Option<Duration>,
    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
        self.0.get_or_create_bucket(bucket_name, ttl).await
    }

    /// Fetch the named bucket, or `None` if it does not exist.
    pub async fn get_bucket(
        &self,
        bucket_name: &str,
    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
        self.0.get_bucket(bucket_name).await
    }

    /// Identifier of the underlying store connection.
    pub fn connection_id(&self) -> u64 {
        self.0.connection_id()
    }

    /// Load `key` from `bucket` and JSON-decode it into `T`.
    ///
    /// Returns `Ok(None)` when either the bucket or the key does not exist;
    /// decode failures surface as `StoreError::JSONDecodeError`.
    pub async fn load<T: for<'a> Deserialize<'a>>(
        &self,
        bucket: &str,
        key: &Key,
    ) -> Result<Option<T>, StoreError> {
        let Some(bucket) = self.0.get_bucket(bucket).await? else {
            // No bucket means no cards
            return Ok(None);
        };
        Ok(match bucket.get(key).await? {
            Some(card_bytes) => {
                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
                Some(card)
            }
            None => None,
        })
    }

    /// Returns a receiver that will receive all the existing keys, and
    /// then block and receive new keys as they are created.
    /// Starts a task that runs forever, watches the store.
    ///
    /// The task ends when `cancel_token` fires or when the backend's watch
    /// stream closes. If the receiver is slow or dropped, the send times out
    /// after `WATCH_SEND_TIMEOUT`; the event is logged and discarded but the
    /// watch keeps running.
    pub fn watch(
        self: Arc<Self>,
        bucket_name: &str,
        bucket_ttl: Option<Duration>,
        cancel_token: CancellationToken,
    ) -> (
        tokio::task::JoinHandle<Result<(), StoreError>>,
        tokio::sync::mpsc::Receiver<WatchEvent>,
    ) {
        let bucket_name = bucket_name.to_string();
        let (tx, rx) = tokio::sync::mpsc::channel(128);
        let watch_task = tokio::spawn(async move {
            // Start listening for changes but don't poll this yet
            let bucket = self
                .0
                .get_or_create_bucket(&bucket_name, bucket_ttl)
                .await?;
            let mut stream = bucket.watch().await?;

            // Send all the existing keys
            for (key, bytes) in bucket.entries().await? {
                if let Err(err) = tx
                    .send_timeout(
                        WatchEvent::Put(KeyValue::new(key, bytes)),
                        WATCH_SEND_TIMEOUT,
                    )
                    .await
                {
                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
                }
            }

            // Now block waiting for new entries
            loop {
                let event = tokio::select! {
                    _ = cancel_token.cancelled() => break,
                    result = stream.next() => match result {
                        Some(event) => event,
                        // Stream closed by the backend: stop watching.
                        None => break,
                    }
                };
                if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
                }
            }

            Ok::<(), StoreError>(())
        });
        (watch_task, rx)
    }

    /// Serialize `obj` as JSON and insert it under `key`, creating the bucket
    /// if needed. On success the store-assigned revision is written back into
    /// `obj` via [`Versioned::set_revision`].
    pub async fn publish<T: Serialize + Versioned + Send + Sync>(
        &self,
        bucket_name: &str,
        bucket_ttl: Option<Duration>,
        key: &Key,
        obj: &mut T,
    ) -> anyhow::Result<StoreOutcome> {
        let obj_json = serde_json::to_vec(obj)?;
        let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;

        let outcome = bucket.insert(key, obj_json.into(), obj.revision()).await?;

        match outcome {
            StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
                obj.set_revision(revision);
            }
        }
        Ok(outcome)
    }

    /// Cleanup any temporary state.
    /// TODO: Should this be async? Take &mut self?
    pub fn shutdown(&self) {
        self.0.shutdown()
    }
}
391
/// An online storage for key-value config values.
/// A bucket is a collection of key/value pairs.
#[async_trait]
pub trait KeyValueBucket: Send + Sync {
    /// Insert a value into a bucket, if it doesn't exist already.
    ///
    /// `revision` is the revision the caller believes the entry has; the
    /// returned [`StoreOutcome`] reports whether a new revision was created
    /// or the value already existed at that revision.
    async fn insert(
        &self,
        key: &Key,
        value: bytes::Bytes,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError>;

    /// Fetch an item from the key-value storage
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;

    /// Delete an item from the bucket
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;

    /// A stream of items inserted into the bucket.
    /// Every time the stream is polled it will either return a newly created entry, or block until
    /// such time.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;

    /// All current entries, keyed by their raw string key.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
}
419
/// Result of an insert: what the store did with the value.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The operation succeeded and created a new entry with this revision.
    /// Note that "create" also means update, because each new revision is a "create".
    Created(u64),
    /// The operation did not do anything, the value was already present, with this revision.
    Exists(u64),
}
428impl fmt::Display for StoreOutcome {
429    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
430        match self {
431            StoreOutcome::Created(revision) => write!(f, "Created at {revision}"),
432            StoreOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
433        }
434    }
435}
436
/// Errors produced by the key-value store backends and the manager.
#[derive(thiserror::Error, Debug)]
pub enum StoreError {
    #[error("Could not find bucket '{0}'")]
    MissingBucket(String),

    #[error("Could not find key '{0}'")]
    MissingKey(String),

    /// Catch-all for backend failures not covered by a specific variant.
    #[error("Internal storage error: '{0}'")]
    ProviderError(String),

    #[error("Internal NATS error: {0}")]
    NATSError(String),

    #[error("Internal etcd error: {0}")]
    EtcdError(String),

    #[error("Internal filesystem error: {0}")]
    FilesystemError(String),

    #[error("Key Value Error: {0} for bucket '{1}'")]
    KeyValueError(String, String),

    /// JSON (de)serialization failure, converted automatically via `#[from]`.
    #[error("Error decoding bytes: {0}")]
    JSONDecodeError(#[from] serde_json::error::Error),

    /// A concurrent writer won; the caller should retry the operation.
    #[error("Race condition, retry the call")]
    Retry,
}
466
/// A trait allowing to get/set a revision on an object.
/// NATS uses this to ensure atomic updates.
pub trait Versioned {
    /// The object's last-known store revision.
    fn revision(&self) -> u64;
    /// Record the revision assigned by the store.
    fn set_revision(&mut self, r: u64);
}
473
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Convert the value returned by `watch()` into a broadcast stream that multiple
    /// clients can listen to.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawn a task that drains `stream` and re-broadcasts each event to
        /// all subscribers. `max_size` is the broadcast channel capacity.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // A send error only means there are currently no receivers.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// Open a new listener on the broadcast stream.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                format!("test{i}"),
                format!("value{i}").into(),
            ));
            expected.push(item);
        }

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // Put in before starting the watch-all
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            got_first_tx.send(()).unwrap();

            // Put in after
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
        // fetched before test2 is inserted, otherwise they can come out in any order, and we
        // wouldn't be testing the watch behavior.
        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Repeat a key and revision. Ignored.
        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Increment revision
        let res = bucket.insert(&"test2".into(), "value2".into(), 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // ingress exits once it has received all values
        let _ = ingress.await?;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // Leak the store and bucket so the watch stream can be moved into
        // spawned tasks with a 'static lifetime.
        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let item = WatchEvent::Put(KeyValue::new("test1".to_string(), "GK".into()));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        // Both subscribers should observe this single insert.
        bucket.insert(&"test1".into(), "GK".into(), 1).await?;

        let _ = futures::join!(handle1, handle2);
        Ok(())
    }
}