Skip to main content

dynamo_runtime/storage/
kv.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Interface to a traditional key-value store such as etcd.
5//! "key_value_store" spelt out because in AI land "KV" means something else.
6
7use std::borrow::Cow;
8use std::pin::Pin;
9use std::str::FromStr;
10use std::sync::Arc;
11use std::time::Duration;
12use std::{collections::HashMap, path::PathBuf};
13use std::{env, fmt};
14
15use crate::CancellationToken;
16use crate::transports::etcd as etcd_transport;
17use async_trait::async_trait;
18use futures::StreamExt;
19use percent_encoding::{NON_ALPHANUMERIC, percent_decode_str, percent_encode};
20use serde::{Deserialize, Serialize};
21
22mod mem;
23pub use mem::MemoryStore;
24mod nats;
25pub use nats::NATSStore;
26mod etcd;
27pub use etcd::EtcdStore;
28mod file;
29pub use file::FileStore;
30
/// How long the watch task waits for room in the consumer channel before giving up on an event.
const WATCH_SEND_TIMEOUT: Duration = Duration::from_secs(1);
32
33/// String we use as the Key in a key-value storage operation. Simple String wrapper
34/// that can encode / decode a string.
35#[derive(Debug, Clone, PartialEq, Eq, Hash)]
36pub struct Key(String);
37
38impl Key {
39    pub fn new(s: String) -> Key {
40        Key(s)
41    }
42
43    /// Takes a URL-safe percent-encoded string and creates a Key from it by decoding first.
44    /// dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f becomes dynamo/backend/generate/17216e63492ef21f
45    pub fn from_url_safe(s: &str) -> Key {
46        Key(percent_decode_str(s).decode_utf8_lossy().to_string())
47    }
48
49    /// A URL-safe percent-encoded representation of this key.
50    /// e.g.  dynamo/backend/generate/17216e63492ef21f becomes dynamo%2Fbackend%2Fgenerate%2F17216e63492ef21f
51    pub fn url_safe(&self) -> Cow<'_, str> {
52        percent_encode(self.0.as_bytes(), NON_ALPHANUMERIC).into()
53    }
54}
55
56impl From<&str> for Key {
57    fn from(s: &str) -> Key {
58        Key::new(s.to_string())
59    }
60}
61
62impl fmt::Display for Key {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        write!(f, "{}", self.0)
65    }
66}
67
68impl AsRef<str> for Key {
69    fn as_ref(&self) -> &str {
70        &self.0
71    }
72}
73
74impl From<&Key> for String {
75    fn from(k: &Key) -> String {
76        k.0.clone()
77    }
78}
79
80#[derive(Debug, Clone, PartialEq)]
81pub struct KeyValue {
82    key: Key,
83    value: bytes::Bytes,
84}
85
86impl KeyValue {
87    pub fn new(key: Key, value: bytes::Bytes) -> Self {
88        KeyValue { key, value }
89    }
90
91    pub fn key(&self) -> String {
92        self.key.clone().to_string()
93    }
94
95    pub fn key_str(&self) -> &str {
96        self.key.as_ref()
97    }
98
99    pub fn value(&self) -> &[u8] {
100        &self.value
101    }
102
103    pub fn value_str(&self) -> anyhow::Result<&str> {
104        std::str::from_utf8(self.value()).map_err(From::from)
105    }
106}
107
108#[derive(Debug, Clone, PartialEq)]
109pub enum WatchEvent {
110    Put(KeyValue),
111    Delete(Key),
112}
113
114#[async_trait]
115pub trait Store: Send + Sync {
116    type Bucket: Bucket + Send + Sync + 'static;
117
118    async fn get_or_create_bucket(
119        &self,
120        bucket_name: &str,
121        // auto-delete items older than this
122        ttl: Option<Duration>,
123    ) -> Result<Self::Bucket, StoreError>;
124
125    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;
126
127    fn connection_id(&self) -> u64;
128
129    fn shutdown(&self);
130}
131
132#[derive(Clone, Debug, Default)]
133pub enum Selector {
134    // Box it because it is significantly bigger than the other variants
135    Etcd(Box<etcd_transport::ClientOptions>),
136    File(PathBuf),
137    #[default]
138    Memory,
139    // Nats not listed because likely we want to remove that impl. It is not currently used and not well tested.
140}
141
142impl fmt::Display for Selector {
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        match self {
145            Selector::Etcd(opts) => {
146                let urls = opts.etcd_url.join(",");
147                write!(f, "Etcd({urls})")
148            }
149            Selector::File(path) => write!(f, "File({})", path.display()),
150            Selector::Memory => write!(f, "Memory"),
151        }
152    }
153}
154
155impl FromStr for Selector {
156    type Err = anyhow::Error;
157
158    fn from_str(s: &str) -> anyhow::Result<Selector> {
159        match s {
160            "etcd" => Ok(Self::Etcd(Box::default())),
161            "file" => {
162                let root = env::var("DYN_FILE_KV")
163                    .map(PathBuf::from)
164                    .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv"));
165                Ok(Self::File(root))
166            }
167            "mem" => Ok(Self::Memory),
168            x => anyhow::bail!("Unknown key-value store type '{x}'"),
169        }
170    }
171}
172
173impl TryFrom<String> for Selector {
174    type Error = anyhow::Error;
175
176    fn try_from(s: String) -> anyhow::Result<Selector> {
177        s.parse()
178    }
179}
180
181#[allow(clippy::large_enum_variant)]
182enum KeyValueStoreEnum {
183    Memory(MemoryStore),
184    Nats(NATSStore),
185    Etcd(EtcdStore),
186    File(FileStore),
187}
188
189impl KeyValueStoreEnum {
190    async fn get_or_create_bucket(
191        &self,
192        bucket_name: &str,
193        // auto-delete items older than this
194        ttl: Option<Duration>,
195    ) -> Result<Box<dyn Bucket>, StoreError> {
196        use KeyValueStoreEnum::*;
197        Ok(match self {
198            Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
199            Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
200            Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
201            File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
202        })
203    }
204
205    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Box<dyn Bucket>>, StoreError> {
206        use KeyValueStoreEnum::*;
207        let maybe_bucket: Option<Box<dyn Bucket>> = match self {
208            Memory(x) => x
209                .get_bucket(bucket_name)
210                .await?
211                .map(|b| Box::new(b) as Box<dyn Bucket>),
212            Nats(x) => x
213                .get_bucket(bucket_name)
214                .await?
215                .map(|b| Box::new(b) as Box<dyn Bucket>),
216            Etcd(x) => x
217                .get_bucket(bucket_name)
218                .await?
219                .map(|b| Box::new(b) as Box<dyn Bucket>),
220            File(x) => x
221                .get_bucket(bucket_name)
222                .await?
223                .map(|b| Box::new(b) as Box<dyn Bucket>),
224        };
225        Ok(maybe_bucket)
226    }
227
228    fn connection_id(&self) -> u64 {
229        use KeyValueStoreEnum::*;
230        match self {
231            Memory(x) => x.connection_id(),
232            Etcd(x) => x.connection_id(),
233            Nats(x) => x.connection_id(),
234            File(x) => x.connection_id(),
235        }
236    }
237
238    fn shutdown(&self) {
239        use KeyValueStoreEnum::*;
240        match self {
241            Memory(x) => x.shutdown(),
242            Etcd(x) => x.shutdown(),
243            Nats(x) => x.shutdown(),
244            File(x) => x.shutdown(),
245        }
246    }
247}
248
249#[derive(Clone)]
250pub struct Manager(Arc<KeyValueStoreEnum>);
251
252impl Default for Manager {
253    fn default() -> Self {
254        Manager::memory()
255    }
256}
257
258impl Manager {
259    /// In-memory KeyValueStoreManager for testing
260    pub fn memory() -> Self {
261        Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
262    }
263
264    pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
265        Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
266    }
267
268    pub fn file<P: Into<PathBuf>>(cancel_token: CancellationToken, root: P) -> Self {
269        Self::new(KeyValueStoreEnum::File(FileStore::new(cancel_token, root)))
270    }
271
272    fn new(s: KeyValueStoreEnum) -> Manager {
273        Manager(Arc::new(s))
274    }
275
276    pub async fn get_or_create_bucket(
277        &self,
278        bucket_name: &str,
279        // auto-delete items older than this
280        ttl: Option<Duration>,
281    ) -> Result<Box<dyn Bucket>, StoreError> {
282        self.0.get_or_create_bucket(bucket_name, ttl).await
283    }
284
285    pub async fn get_bucket(
286        &self,
287        bucket_name: &str,
288    ) -> Result<Option<Box<dyn Bucket>>, StoreError> {
289        self.0.get_bucket(bucket_name).await
290    }
291
292    pub fn connection_id(&self) -> u64 {
293        self.0.connection_id()
294    }
295
296    pub async fn load<T: for<'a> Deserialize<'a>>(
297        &self,
298        bucket: &str,
299        key: &Key,
300    ) -> Result<Option<T>, StoreError> {
301        let Some(bucket) = self.0.get_bucket(bucket).await? else {
302            // No bucket means no cards
303            return Ok(None);
304        };
305        Ok(match bucket.get(key).await? {
306            Some(card_bytes) => {
307                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
308                Some(card)
309            }
310            None => None,
311        })
312    }
313
314    /// Returns a receiver that will receive all the existing keys, and
315    /// then block and receive new keys as they are created.
316    /// Starts a task that runs forever, watches the store.
317    pub fn watch(
318        self: Arc<Self>,
319        bucket_name: &str,
320        bucket_ttl: Option<Duration>,
321        cancel_token: CancellationToken,
322    ) -> (
323        tokio::task::JoinHandle<Result<(), StoreError>>,
324        tokio::sync::mpsc::Receiver<WatchEvent>,
325    ) {
326        let bucket_name = bucket_name.to_string();
327        // Use a larger channel capacity to reduce the likelihood that a slow consumer
328        // during the initial KV-store replay phase triggers send timeouts. Events may
329        // still be dropped if the consumer cannot keep up within `WATCH_SEND_TIMEOUT`.
330        let (tx, rx) = tokio::sync::mpsc::channel(16384);
331        let watch_task = tokio::spawn(async move {
332            // Start listening for changes but don't poll this yet
333            let bucket = self
334                .0
335                .get_or_create_bucket(&bucket_name, bucket_ttl)
336                .await?;
337            let mut stream = bucket.watch().await?;
338
339            // Send all the existing keys
340            for (key, bytes) in bucket.entries().await? {
341                if let Err(err) = tx
342                    .send_timeout(
343                        WatchEvent::Put(KeyValue::new(key, bytes)),
344                        WATCH_SEND_TIMEOUT,
345                    )
346                    .await
347                {
348                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
349                }
350            }
351
352            // Now block waiting for new entries
353            loop {
354                let event = tokio::select! {
355                    _ = cancel_token.cancelled() => break,
356                    result = stream.next() => match result {
357                        Some(event) => event,
358                        None => break,
359                    }
360                };
361                if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
362                    tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
363                }
364            }
365
366            Ok::<(), StoreError>(())
367        });
368        (watch_task, rx)
369    }
370
371    pub async fn publish<T: Serialize + Versioned + Send + Sync>(
372        &self,
373        bucket_name: &str,
374        bucket_ttl: Option<Duration>,
375        key: &Key,
376        obj: &mut T,
377    ) -> anyhow::Result<StoreOutcome> {
378        let obj_json = serde_json::to_vec(obj)?;
379        let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
380
381        let outcome = bucket.insert(key, obj_json.into(), obj.revision()).await?;
382
383        match outcome {
384            StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
385                obj.set_revision(revision);
386            }
387        }
388        Ok(outcome)
389    }
390
391    /// Cleanup any temporary state.
392    /// TODO: Should this be async? Take &mut self?
393    pub fn shutdown(&self) {
394        self.0.shutdown()
395    }
396}
397
398/// An online storage for key-value config values.
399#[async_trait]
400pub trait Bucket: Send + Sync {
401    /// A bucket is a collection of key/value pairs.
402    /// Insert a value into a bucket, if it doesn't exist already
403    /// The Key should be the name of the item, not including the bucket name.
404    async fn insert(
405        &self,
406        key: &Key,
407        value: bytes::Bytes,
408        revision: u64,
409    ) -> Result<StoreOutcome, StoreError>;
410
411    /// Fetch an item from the key-value storage
412    /// The Key should be the name of the item, not including the bucket name.
413    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;
414
415    /// Delete an item from the bucket
416    /// The Key should be the name of the item, not including the bucket name.
417    async fn delete(&self, key: &Key) -> Result<(), StoreError>;
418
419    /// A stream of items inserted into the bucket.
420    /// Every time the stream is polled it will either return a newly created entry, or block until
421    /// such time.
422    async fn watch(
423        &self,
424    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;
425
426    /// The entries in this bucket.
427    /// The Key includes the full path including the bucket name.
428    /// That means you cannot directory get a Key from `entries` and pass it to `get` or `delete`.
429    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError>;
430}
431
/// What a [`Bucket::insert`] did, along with the revision involved.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StoreOutcome {
    /// The operation succeeded and created a new entry with this revision.
    /// Note that "create" also means update, because each new revision is a "create".
    Created(u64),
    /// The operation did not do anything, the value was already present, with this revision.
    Exists(u64),
}
impl fmt::Display for StoreOutcome {
    /// Renders as `"Created at <rev>"` or `"Exists at <rev>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (verb, revision) = match *self {
            StoreOutcome::Created(rev) => ("Created", rev),
            StoreOutcome::Exists(rev) => ("Exists", rev),
        };
        write!(f, "{verb} at {revision}")
    }
}
448
449#[derive(thiserror::Error, Debug)]
450pub enum StoreError {
451    #[error("Could not find bucket '{0}'")]
452    MissingBucket(String),
453
454    #[error("Could not find key '{0}'")]
455    MissingKey(String),
456
457    #[error("Internal storage error: '{0}'")]
458    ProviderError(String),
459
460    #[error("Internal NATS error: {0}")]
461    NATSError(String),
462
463    #[error("Internal etcd error: {0}")]
464    EtcdError(String),
465
466    #[error("Internal filesystem error: {0}")]
467    FilesystemError(String),
468
469    #[error("Key Value Error: {0} for bucket '{1}'")]
470    KeyValueError(String, String),
471
472    #[error("Error decoding bytes: {0}")]
473    JSONDecodeError(#[from] serde_json::error::Error),
474
475    #[error("Race condition, retry the call")]
476    Retry,
477}
478
/// An object that carries a storage revision which can be read and updated.
///
/// [`Manager::publish`] passes `revision()` to [`Bucket::insert`] and stores the resulting
/// revision back with `set_revision`. NATS uses this to ensure atomic updates.
pub trait Versioned {
    /// The revision currently associated with this object.
    fn revision(&self) -> u64;
    /// Replaces the revision associated with this object with `r`.
    fn set_revision(&mut self, r: u64);
}
485
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Convert the value returned by `watch()` into a broadcast stream that multiple
    /// clients can listen to.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                Key::new(format!("test{i}")),
                format!("value{i}").into(),
            ));
            expected.push(item);
        }

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // Put in before starting the watch-all
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            got_first_tx.send(()).unwrap();

            // Put in after
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        // MemoryStore uses a HashMap with no inherent ordering, so we must ensure test1 is
        // fetched before test2 is inserted, otherwise they can come out in any order, and we
        // wouldn't be testing the watch behavior.
        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Repeat a key and revision. Ignored.
        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Increment revision
        let res = bucket.insert(&"test2".into(), "value2".into(), 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // ingress exits once it has received all values. The outer `?` surfaces a panic in
        // the task (e.g. a failed assertion); the inner `?` surfaces a `StoreError` it returned.
        ingress.await??;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let item = WatchEvent::Put(KeyValue::new(Key::new("test1".to_string()), "GK".into()));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        bucket.insert(&"test1".into(), "GK".into(), 1).await?;

        // Propagate panics from the subscriber tasks; otherwise a failed assertion inside
        // them would be swallowed and the test would still pass.
        let (res1, res2) = futures::join!(handle1, handle2);
        res1?;
        res2?;
        Ok(())
    }
}