dynamo_llm/
key_value_store.rs

// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
//! Interface to a traditional key-value store such as etcd.
//! "key_value_store" is spelled out because in AI land "KV" means something else.
18
19use std::collections::HashMap;
20use std::fmt;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dynamo_runtime::slug::Slug;
27use dynamo_runtime::CancellationToken;
28use futures::StreamExt;
29use serde::{Deserialize, Serialize};
30
31mod mem;
32pub use mem::MemoryStorage;
33mod nats;
34pub use nats::NATSStorage;
35
36#[async_trait]
37pub trait KeyValueStore: Send + Sync {
38    async fn get_or_create_bucket(
39        &self,
40        bucket_name: &str,
41        // auto-delete items older than this
42        ttl: Option<Duration>,
43    ) -> Result<Box<dyn KeyValueBucket>, StorageError>;
44
45    async fn get_bucket(
46        &self,
47        bucket_name: &str,
48    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError>;
49}
50
51pub struct KeyValueStoreManager(Box<dyn KeyValueStore>);
52
53impl KeyValueStoreManager {
54    pub fn new(s: Box<dyn KeyValueStore>) -> KeyValueStoreManager {
55        KeyValueStoreManager(s)
56    }
57
58    pub async fn load<T: for<'a> Deserialize<'a>>(
59        &self,
60        bucket: &str,
61        key: &Slug,
62    ) -> Result<Option<T>, StorageError> {
63        let Some(bucket) = self.0.get_bucket(bucket).await? else {
64            // No bucket means no cards
65            return Ok(None);
66        };
67        match bucket.get(key.as_ref()).await {
68            Ok(Some(card_bytes)) => {
69                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
70                Ok(Some(card))
71            }
72            Ok(None) => Ok(None),
73            Err(err) => {
74                // TODO look at what errors NATS can give us and make more specific wrappers
75                Err(StorageError::NATSError(err.to_string()))
76            }
77        }
78    }
79
80    /// Returns a receiver that will receive all the existing keys, and
81    /// then block and receive new keys as they are created.
82    /// Starts a task that runs forever, watches the store.
83    pub fn watch<T: for<'a> Deserialize<'a> + Send + 'static>(
84        self: Arc<Self>,
85        bucket_name: &str,
86        bucket_ttl: Option<Duration>,
87    ) -> (
88        tokio::task::JoinHandle<Result<(), StorageError>>,
89        tokio::sync::mpsc::UnboundedReceiver<T>,
90    ) {
91        let bucket_name = bucket_name.to_string();
92        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
93        let watch_task = tokio::spawn(async move {
94            // Start listening for changes but don't poll this yet
95            let bucket = self
96                .0
97                .get_or_create_bucket(&bucket_name, bucket_ttl)
98                .await?;
99            let mut stream = bucket.watch().await?;
100
101            // Send all the existing keys
102            for (_, card_bytes) in bucket.entries().await? {
103                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
104                let _ = tx.send(card);
105            }
106
107            // Now block waiting for new entries
108            while let Some(card_bytes) = stream.next().await {
109                let card: T = serde_json::from_slice(card_bytes.as_ref())?;
110                let _ = tx.send(card);
111            }
112
113            Ok::<(), StorageError>(())
114        });
115        (watch_task, rx)
116    }
117
118    pub async fn publish<T: Serialize + Versioned + Send + Sync>(
119        &self,
120        bucket_name: &str,
121        bucket_ttl: Option<Duration>,
122        key: &str,
123        obj: &mut T,
124    ) -> anyhow::Result<StorageOutcome> {
125        let obj_json = serde_json::to_string(obj)?;
126        let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
127
128        let outcome = bucket
129            .insert(key.to_string(), obj_json, obj.revision())
130            .await?;
131
132        match outcome {
133            StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => {
134                obj.set_revision(revision);
135            }
136        }
137        Ok(outcome)
138    }
139
140    /// Re-publish the model card to the store regularly. Spawns a task and returns.
141    /// Takes most arguments by value because it will hold on to them in the publish task.
142    /// Deletes the card on cancellation.
143    pub fn publish_until_cancelled<T: Serialize + Versioned + Send + Sync + 'static>(
144        self: Arc<Self>,
145        cancel_token: CancellationToken,
146        bucket_name: String,
147        bucket_ttl: Option<Duration>,
148        publish_interval: Duration,
149        key: String,
150        mut obj: T,
151    ) {
152        tokio::spawn(async move {
153            loop {
154                let publish_result = self
155                    .clone()
156                    .publish(&bucket_name, bucket_ttl, &key, &mut obj)
157                    .await;
158                if let Err(err) = publish_result {
159                    tracing::error!(
160                        model = key,
161                        error = %err,
162                        "Failed publishing to KV storage. Ending publish task.",
163                    );
164                }
165                tokio::select! {
166                    _ = tokio::time::sleep(publish_interval) => {},
167                    _ = cancel_token.cancelled() => {
168                        tracing::trace!(model_service_name = key, "Publish loop cancelled");
169                        match self.0.get_bucket(&bucket_name).await {
170                            Ok(Some(bucket)) => {
171                                if let Err(err) = bucket.delete(&key).await {
172                                    // This is usually expected, our NATS connection is closed
173                                    tracing::trace!(bucket_name, key, %err, "Error delete published card from NATS on publish stop");
174                                }
175
176                                tracing::trace!(bucket_name, key, "Deleted Model Deployment Card from NATS");
177                            }
178                            Ok(None) => {
179                                tracing::trace!(bucket_name, key, "Bucket does not exist");
180                            }
181                            Err(err) => {
182                                tracing::trace!(bucket_name, %err, "publish_until_cancelled shutdown error");
183                            }
184                        }
185                        // Stop publishing
186                        break;
187                    }
188                }
189            }
190        });
191    }
192}
193
194/// An online storage for key-value config values.
195/// Usually backed by `nats-server`.
196#[async_trait]
197pub trait KeyValueBucket: Send {
198    /// A bucket is a collection of key/value pairs.
199    /// Insert a value into a bucket, if it doesn't exist already
200    async fn insert(
201        &self,
202        key: String,
203        value: String,
204        revision: u64,
205    ) -> Result<StorageOutcome, StorageError>;
206
207    /// Fetch an item from the key-value storage
208    async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError>;
209
210    /// Delete an item from the bucket
211    async fn delete(&self, key: &str) -> Result<(), StorageError>;
212
213    /// A stream of items inserted into the bucket.
214    /// Every time the stream is polled it will either return a newly created entry, or block until
215    /// such time.
216    async fn watch(
217        &self,
218    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>;
219
220    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError>;
221}
222
/// What a successful insert did, together with the revision involved.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StorageOutcome {
    /// The operation succeeded and created a new entry with this revision.
    /// Note that "create" also means update, because each new revision is a "create".
    Created(u64),
    /// The operation did not do anything, the value was already present, with this revision.
    Exists(u64),
}

/// Renders as `Created at <revision>` or `Exists at <revision>`.
impl fmt::Display for StorageOutcome {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StorageOutcome::Created(revision) => write!(f, "Created at {revision}"),
            StorageOutcome::Exists(revision) => write!(f, "Exists at {revision}"),
        }
    }
}
239
240#[derive(thiserror::Error, Debug)]
241pub enum StorageError {
242    #[error("Could not find bucket '{0}'")]
243    MissingBucket(String),
244
245    #[error("Could not find key '{0}'")]
246    MissingKey(String),
247
248    #[error("Internal storage error: '{0}'")]
249    ProviderError(String),
250
251    #[error("Internal NATS error: {0}")]
252    NATSError(String),
253
254    #[error("Internal etcd error: {0}")]
255    EtcdError(String),
256
257    #[error("Key Value Error: {0} for bucket '{1}")]
258    KeyValueError(String, String),
259
260    #[error("Error decoding bytes: {0}")]
261    JSONDecodeError(#[from] serde_json::error::Error),
262
263    #[error("Race condition, retry the call")]
264    Retry,
265}
266
/// A trait allowing to get/set a revision on an object.
/// NATS uses this to ensure atomic updates.
pub trait Versioned {
    /// The revision currently recorded on the object.
    fn revision(&self) -> u64;

    /// Records `r` as the object's revision.
    fn set_revision(&mut self, r: u64);
}
273
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{pin_mut, StreamExt};

    const BUCKET_NAME: &str = "mdc";

    /// Convert the value returned by `watch()` into a broadcast stream that multiple
    /// clients can listen to.
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<bytes::Bytes>,
    }

    impl TappableStream {
        /// Spawns a task that forwards every item of `stream` to a broadcast
        /// channel holding at most `max_size` pending items.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = bytes::Bytes> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // Sending fails only while there are no subscribers; such
                    // items are intentionally dropped.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// A new receiver for items forwarded from now on.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<bytes::Bytes> {
            self.tx.subscribe()
        }
    }

    fn init() {
        dynamo_runtime::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStorage::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket
            .insert("test1".to_string(), "value1".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // Put in before starting the watch-all
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value1".as_bytes());

            got_first_tx.send(()).unwrap();

            // Put in after
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value2".as_bytes());
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value3".as_bytes());

            Ok::<_, StorageError>(())
        });

        // MemoryStorage uses a HashMap with no inherent ordering, so we must ensure test1 is
        // fetched before test2 is inserted, otherwise they can come out in any order, and we
        // wouldn't be testing the watch behavior.
        got_first_rx.await?;

        let res = bucket
            .insert("test2".to_string(), "value2".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        // Repeat a key and revision. Ignored.
        let res = bucket
            .insert("test2".to_string(), "value2".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Exists(0));

        // Increment revision
        let res = bucket
            .insert("test2".to_string(), "value2".to_string(), 1)
            .await?;
        assert_eq!(res, StorageOutcome::Created(1));

        let res = bucket
            .insert("test3".to_string(), "value3".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        // ingress exits once it has received all values.
        // First `?`: the task panicked (a failed assert). Second `?`: the task
        // returned a storage error. Neither may be silently dropped.
        ingress.await??;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // The watch stream borrows the bucket and `TappableStream::new` needs a
        // `'static` stream, so leak both to get `'static` references.
        let s: &'static _ = Box::leak(Box::new(MemoryStorage::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket
            .insert("test1".to_string(), "value1".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });

        bucket
            .insert("test1".to_string(), "GK".to_string(), 1)
            .await?;

        // A failed assert inside a subscriber task surfaces as a `JoinError`;
        // propagate it so the test actually fails.
        let (first, second) = futures::join!(handle1, handle2);
        first?;
        second?;
        Ok(())
    }
}