dynamo_runtime/storage/key_value_store/
nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{collections::HashMap, pin::Pin, time::Duration};
5
6use crate::{
7    protocols::EndpointId, slug::Slug, storage::key_value_store::Key, transports::nats::Client,
8};
9use async_trait::async_trait;
10use futures::StreamExt;
11
12use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
13
/// A [`KeyValueStore`] implementation on top of NATS JetStream key-value buckets.
///
/// Bucket names are combined with the namespace of `endpoint` before being sent to NATS.
#[derive(Clone)]
pub struct NATSStore {
    /// NATS client used to reach JetStream and to report the connection id.
    client: Client,
    /// Endpoint whose `namespace` scopes every bucket name handled by this store.
    endpoint: EndpointId,
}
19
/// A single key-value bucket, exposed through the [`KeyValueBucket`] trait.
pub struct NATSBucket {
    /// Handle to the underlying JetStream key-value store that all operations go through.
    nats_store: async_nats::jetstream::kv::Store,
}
23
24#[async_trait]
25impl KeyValueStore for NATSStore {
26    async fn get_or_create_bucket(
27        &self,
28        bucket_name: &str,
29        ttl: Option<Duration>,
30    ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
31        let name = Slug::slugify(bucket_name);
32        let nats_store = self
33            .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
34            .await?;
35        Ok(Box::new(NATSBucket { nats_store }))
36    }
37
38    async fn get_bucket(
39        &self,
40        bucket_name: &str,
41    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
42        let name = Slug::slugify(bucket_name);
43        match self.get_key_value(&self.endpoint.namespace, &name).await? {
44            Some(nats_store) => Ok(Some(Box::new(NATSBucket { nats_store }))),
45            None => Ok(None),
46        }
47    }
48
49    fn connection_id(&self) -> u64 {
50        self.client.client().server_info().client_id
51    }
52}
53
54impl NATSStore {
55    pub fn new(client: Client, endpoint: EndpointId) -> Self {
56        NATSStore { client, endpoint }
57    }
58
59    /// Get or create a key-value store (aka bucket) in NATS.
60    ///
61    /// ttl is only used if we are creating the bucket, so if that has
62    /// changed first delete the bucket.
63    async fn get_or_create_key_value(
64        &self,
65        namespace: &str,
66        bucket_name: &Slug,
67        // Delete entries older than this
68        ttl: Option<Duration>,
69    ) -> Result<async_nats::jetstream::kv::Store, StoreError> {
70        if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
71            return Ok(kv);
72        }
73
74        // It doesn't exist, create it
75
76        let bucket_name = single_name(namespace, bucket_name);
77        let js = self.client.jetstream();
78        let create_result = js
79            .create_key_value(
80                // TODO: configure the bucket, probably need to pass some of these values in
81                async_nats::jetstream::kv::Config {
82                    bucket: bucket_name.clone(),
83                    max_age: ttl.unwrap_or_default(),
84                    ..Default::default()
85                },
86            )
87            .await;
88        let nats_store = create_result
89            .map_err(|err| StoreError::KeyValueError(err.to_string(), bucket_name.clone()))?;
90        tracing::debug!("Created bucket {bucket_name}");
91        Ok(nats_store)
92    }
93
94    async fn get_key_value(
95        &self,
96        namespace: &str,
97        bucket_name: &Slug,
98    ) -> Result<Option<async_nats::jetstream::kv::Store>, StoreError> {
99        let bucket_name = single_name(namespace, bucket_name);
100        let js = self.client.jetstream();
101
102        use async_nats::jetstream::context::KeyValueErrorKind;
103        match js.get_key_value(&bucket_name).await {
104            Ok(store) => Ok(Some(store)),
105            Err(err) if err.kind() == KeyValueErrorKind::GetBucket => {
106                // bucket doesn't exist
107                Ok(None)
108            }
109            Err(err) => Err(StoreError::KeyValueError(err.to_string(), bucket_name)),
110        }
111    }
112}
113
114#[async_trait]
115impl KeyValueBucket for NATSBucket {
116    async fn insert(
117        &self,
118        key: &Key,
119        value: &str,
120        revision: u64,
121    ) -> Result<StoreOutcome, StoreError> {
122        if revision == 0 {
123            self.create(key, value).await
124        } else {
125            self.update(key, value, revision).await
126        }
127    }
128
129    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
130        self.nats_store
131            .get(key)
132            .await
133            .map_err(|e| StoreError::NATSError(e.to_string()))
134    }
135
136    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
137        self.nats_store
138            .delete(key)
139            .await
140            .map_err(|e| StoreError::NATSError(e.to_string()))
141    }
142
143    async fn watch(
144        &self,
145    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>
146    {
147        let watch_stream = self
148            .nats_store
149            .watch_all()
150            .await
151            .map_err(|e| StoreError::NATSError(e.to_string()))?;
152        // Map the `Entry` to `Entry.value` which is Bytes of the stored value.
153        Ok(Box::pin(
154            watch_stream.filter_map(
155                |maybe_entry: Result<
156                    async_nats::jetstream::kv::Entry,
157                    async_nats::error::Error<_>,
158                >| async move {
159                    match maybe_entry {
160                        Ok(entry) => Some(entry.value),
161                        Err(e) => {
162                            tracing::error!(error=%e, "watch fatal err");
163                            None
164                        }
165                    }
166                },
167            ),
168        ))
169    }
170
171    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
172        let mut key_stream = self
173            .nats_store
174            .keys()
175            .await
176            .map_err(|e| StoreError::NATSError(e.to_string()))?;
177        let mut out = HashMap::new();
178        while let Some(Ok(key)) = key_stream.next().await {
179            if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
180                out.insert(key, entry.value);
181            }
182        }
183        Ok(out)
184    }
185}
186
187impl NATSBucket {
188    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
189        match self.nats_store.create(&key, value.to_string().into()).await {
190            Ok(revision) => Ok(StoreOutcome::Created(revision)),
191            Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
192                // key exists, get the revsion
193                match self.nats_store.entry(key).await {
194                    Ok(Some(entry)) => Ok(StoreOutcome::Exists(entry.revision)),
195                    Ok(None) => {
196                        tracing::error!(
197                            %key,
198                            "Race condition, key deleted between create and fetch. Retry."
199                        );
200                        Err(StoreError::Retry)
201                    }
202                    Err(err) => Err(StoreError::NATSError(err.to_string())),
203                }
204            }
205            Err(err) => Err(StoreError::NATSError(err.to_string())),
206        }
207    }
208
209    async fn update(
210        &self,
211        key: &Key,
212        value: &str,
213        revision: u64,
214    ) -> Result<StoreOutcome, StoreError> {
215        match self
216            .nats_store
217            .update(key, value.to_string().into(), revision)
218            .await
219        {
220            Ok(revision) => Ok(StoreOutcome::Created(revision)),
221            Err(err)
222                if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
223            {
224                tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
225                self.resync_update(key, value).await
226            }
227            Err(err) => Err(StoreError::NATSError(err.to_string())),
228        }
229    }
230
231    /// We have the wrong revision for a key. Fetch it's entry to get the correct revision,
232    /// and try the update again.
233    async fn resync_update(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
234        match self.nats_store.entry(key).await {
235            Ok(Some(entry)) => {
236                // Re-try the update with new version number
237                let next_rev = entry.revision + 1;
238                match self
239                    .nats_store
240                    .update(key, value.to_string().into(), next_rev)
241                    .await
242                {
243                    Ok(correct_revision) => Ok(StoreOutcome::Created(correct_revision)),
244                    Err(err) => Err(StoreError::NATSError(format!(
245                        "Error during update of key {key} after resync: {err}"
246                    ))),
247                }
248            }
249            Ok(None) => {
250                tracing::warn!(%key, "Entry does not exist during resync, creating.");
251                self.create(key, value).await
252            }
253            Err(err) => {
254                tracing::error!(%key, %err, "Failed fetching entry during resync");
255                Err(StoreError::NATSError(err.to_string()))
256            }
257        }
258    }
259}
260
261/// async-nats won't let us use a multi-part subject to create KV buckets (and probably many other
262/// things).
263fn single_name(namespace: &str, name: &Slug) -> String {
264    format!("{namespace}_{name}")
265}