dynamo_llm/key_value_store/
nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{collections::HashMap, pin::Pin, time::Duration};
17
18use async_trait::async_trait;
19use dynamo_runtime::{protocols::Endpoint, slug::Slug, transports::nats::Client};
20use futures::StreamExt;
21
22use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
23
24#[derive(Clone)]
25pub struct NATSStorage {
26    client: Client,
27    endpoint: Endpoint,
28}
29
30pub struct NATSBucket {
31    nats_store: async_nats::jetstream::kv::Store,
32}
33
34#[async_trait]
35impl KeyValueStore for NATSStorage {
36    async fn get_or_create_bucket(
37        &self,
38        bucket_name: &str,
39        ttl: Option<Duration>,
40    ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
41        let name = Slug::slugify(bucket_name);
42        let nats_store = self
43            .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
44            .await?;
45        Ok(Box::new(NATSBucket { nats_store }))
46    }
47
48    async fn get_bucket(
49        &self,
50        bucket_name: &str,
51    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
52        let name = Slug::slugify(bucket_name);
53        match self.get_key_value(&self.endpoint.namespace, &name).await? {
54            Some(nats_store) => Ok(Some(Box::new(NATSBucket { nats_store }))),
55            None => Ok(None),
56        }
57    }
58}
59
60impl NATSStorage {
61    pub fn new(client: Client, endpoint: Endpoint) -> Self {
62        NATSStorage { client, endpoint }
63    }
64
65    /// Get or create a key-value store (aka bucket) in NATS.
66    ///
67    /// ttl is only used if we are creating the bucket, so if that has
68    /// changed first delete the bucket.
69    async fn get_or_create_key_value(
70        &self,
71        namespace: &str,
72        bucket_name: &Slug,
73        // Delete entries older than this
74        ttl: Option<Duration>,
75    ) -> Result<async_nats::jetstream::kv::Store, StorageError> {
76        if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
77            return Ok(kv);
78        }
79
80        // It doesn't exist, create it
81
82        let bucket_name = single_name(namespace, bucket_name);
83        let js = self.client.jetstream();
84        let create_result = js
85            .create_key_value(
86                // TODO: configure the bucket, probably need to pass some of these values in
87                async_nats::jetstream::kv::Config {
88                    bucket: bucket_name.clone(),
89                    max_age: ttl.unwrap_or_default(),
90                    ..Default::default()
91                },
92            )
93            .await;
94        tracing::debug!("Created bucket {bucket_name}");
95        create_result.map_err(|err| StorageError::KeyValueError(err.to_string(), bucket_name))
96    }
97
98    async fn get_key_value(
99        &self,
100        namespace: &str,
101        bucket_name: &Slug,
102    ) -> Result<Option<async_nats::jetstream::kv::Store>, StorageError> {
103        let bucket_name = single_name(namespace, bucket_name);
104        let js = self.client.jetstream();
105
106        use async_nats::jetstream::context::KeyValueErrorKind;
107        match js.get_key_value(&bucket_name).await {
108            Ok(store) => Ok(Some(store)),
109            Err(err) if err.kind() == KeyValueErrorKind::GetBucket => {
110                // bucket doesn't exist
111                Ok(None)
112            }
113            Err(err) => Err(StorageError::KeyValueError(err.to_string(), bucket_name)),
114        }
115    }
116}
117
118#[async_trait]
119impl KeyValueBucket for NATSBucket {
120    async fn insert(
121        &self,
122        key: String,
123        value: String,
124        revision: u64,
125    ) -> Result<StorageOutcome, StorageError> {
126        if revision == 0 {
127            self.create(key, value).await
128        } else {
129            self.update(key, value, revision).await
130        }
131    }
132
133    async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
134        self.nats_store
135            .get(key)
136            .await
137            .map_err(|e| StorageError::NATSError(e.to_string()))
138    }
139
140    async fn delete(&self, key: &str) -> Result<(), StorageError> {
141        self.nats_store
142            .delete(key)
143            .await
144            .map_err(|e| StorageError::NATSError(e.to_string()))
145    }
146
147    async fn watch(
148        &self,
149    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
150    {
151        let watch_stream = self
152            .nats_store
153            .watch_all()
154            .await
155            .map_err(|e| StorageError::NATSError(e.to_string()))?;
156        // Map the `Entry` to `Entry.value` which is Bytes of the stored value.
157        Ok(Box::pin(
158            watch_stream.filter_map(
159                |maybe_entry: Result<
160                    async_nats::jetstream::kv::Entry,
161                    async_nats::error::Error<_>,
162                >| async move {
163                    match maybe_entry {
164                        Ok(entry) => Some(entry.value),
165                        Err(e) => {
166                            tracing::error!(error=%e, "watch fatal err");
167                            None
168                        }
169                    }
170                },
171            ),
172        ))
173    }
174
175    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
176        let mut key_stream = self
177            .nats_store
178            .keys()
179            .await
180            .map_err(|e| StorageError::NATSError(e.to_string()))?;
181        let mut out = HashMap::new();
182        while let Some(Ok(key)) = key_stream.next().await {
183            if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
184                out.insert(key, entry.value);
185            }
186        }
187        Ok(out)
188    }
189}
190
191impl NATSBucket {
192    async fn create(&self, key: String, value: String) -> Result<StorageOutcome, StorageError> {
193        match self.nats_store.create(&key, value.into()).await {
194            Ok(revision) => Ok(StorageOutcome::Created(revision)),
195            Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
196                // key exists, get the revsion
197                match self.nats_store.entry(&key).await {
198                    Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)),
199                    Ok(None) => {
200                        tracing::error!(
201                            key,
202                            "Race condition, key deleted between create and fetch. Retry."
203                        );
204                        Err(StorageError::Retry)
205                    }
206                    Err(err) => Err(StorageError::NATSError(err.to_string())),
207                }
208            }
209            Err(err) => Err(StorageError::NATSError(err.to_string())),
210        }
211    }
212
213    async fn update(
214        &self,
215        key: String,
216        value: String,
217        revision: u64,
218    ) -> Result<StorageOutcome, StorageError> {
219        match self
220            .nats_store
221            .update(key.clone(), value.clone().into(), revision)
222            .await
223        {
224            Ok(revision) => Ok(StorageOutcome::Created(revision)),
225            Err(err)
226                if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
227            {
228                tracing::warn!(revision, key, "Update WrongLastRevision, resync");
229                self.resync_update(key, value).await
230            }
231            Err(err) => Err(StorageError::NATSError(err.to_string())),
232        }
233    }
234
235    /// We have the wrong revision for a key. Fetch it's entry to get the correct revision,
236    /// and try the update again.
237    async fn resync_update(
238        &self,
239        key: String,
240        value: String,
241    ) -> Result<StorageOutcome, StorageError> {
242        match self.nats_store.entry(&key).await {
243            Ok(Some(entry)) => {
244                // Re-try the update with new version number
245                let next_rev = entry.revision + 1;
246                match self
247                    .nats_store
248                    .update(key.clone(), value.into(), next_rev)
249                    .await
250                {
251                    Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)),
252                    Err(err) => Err(StorageError::NATSError(format!(
253                        "Error during update of key {key} after resync: {err}"
254                    ))),
255                }
256            }
257            Ok(None) => {
258                tracing::warn!(key, "Entry does not exist during resync, creating.");
259                self.create(key, value).await
260            }
261            Err(err) => {
262                tracing::error!(key, %err, "Failed fetching entry during resync");
263                Err(StorageError::NATSError(err.to_string()))
264            }
265        }
266    }
267}
268
269/// async-nats won't let us use a multi-part subject to create KV buckets (and probably many other
270/// things).
271fn single_name(namespace: &str, name: &Slug) -> String {
272    format!("{namespace}_{name}")
273}