// File: dynamo_runtime/storage/key_value_store/nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{collections::HashMap, pin::Pin, time::Duration};
5
6use crate::{
7    protocols::EndpointId,
8    slug::Slug,
9    storage::key_value_store::{Key, KeyValue, WatchEvent},
10    transports::nats::Client,
11};
12use async_nats::jetstream::kv::Operation;
13use async_trait::async_trait;
14use futures::StreamExt;
15
16use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
17
18#[derive(Clone)]
19pub struct NATSStore {
20    client: Client,
21    endpoint: EndpointId,
22}
23
24pub struct NATSBucket {
25    nats_store: async_nats::jetstream::kv::Store,
26}
27
28#[async_trait]
29impl KeyValueStore for NATSStore {
30    type Bucket = NATSBucket;
31
32    async fn get_or_create_bucket(
33        &self,
34        bucket_name: &str,
35        ttl: Option<Duration>,
36    ) -> Result<Self::Bucket, StoreError> {
37        let name = Slug::slugify(bucket_name);
38        let nats_store = self
39            .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
40            .await?;
41        Ok(NATSBucket { nats_store })
42    }
43
44    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
45        let name = Slug::slugify(bucket_name);
46        match self.get_key_value(&self.endpoint.namespace, &name).await? {
47            Some(nats_store) => Ok(Some(NATSBucket { nats_store })),
48            None => Ok(None),
49        }
50    }
51
52    fn connection_id(&self) -> u64 {
53        self.client.client().server_info().client_id
54    }
55}
56
57impl NATSStore {
58    pub fn new(client: Client, endpoint: EndpointId) -> Self {
59        NATSStore { client, endpoint }
60    }
61
62    /// Get or create a key-value store (aka bucket) in NATS.
63    ///
64    /// ttl is only used if we are creating the bucket, so if that has
65    /// changed first delete the bucket.
66    async fn get_or_create_key_value(
67        &self,
68        namespace: &str,
69        bucket_name: &Slug,
70        // Delete entries older than this
71        ttl: Option<Duration>,
72    ) -> Result<async_nats::jetstream::kv::Store, StoreError> {
73        if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
74            return Ok(kv);
75        }
76
77        // It doesn't exist, create it
78
79        let bucket_name = single_name(namespace, bucket_name);
80        let js = self.client.jetstream();
81        let create_result = js
82            .create_key_value(
83                // TODO: configure the bucket, probably need to pass some of these values in
84                async_nats::jetstream::kv::Config {
85                    bucket: bucket_name.clone(),
86                    max_age: ttl.unwrap_or_default(),
87                    ..Default::default()
88                },
89            )
90            .await;
91        let nats_store = create_result
92            .map_err(|err| StoreError::KeyValueError(err.to_string(), bucket_name.clone()))?;
93        tracing::debug!("Created bucket {bucket_name}");
94        Ok(nats_store)
95    }
96
97    async fn get_key_value(
98        &self,
99        namespace: &str,
100        bucket_name: &Slug,
101    ) -> Result<Option<async_nats::jetstream::kv::Store>, StoreError> {
102        let bucket_name = single_name(namespace, bucket_name);
103        let js = self.client.jetstream();
104
105        use async_nats::jetstream::context::KeyValueErrorKind;
106        match js.get_key_value(&bucket_name).await {
107            Ok(store) => Ok(Some(store)),
108            Err(err) if err.kind() == KeyValueErrorKind::GetBucket => {
109                // bucket doesn't exist
110                Ok(None)
111            }
112            Err(err) => Err(StoreError::KeyValueError(err.to_string(), bucket_name)),
113        }
114    }
115}
116
117#[async_trait]
118impl KeyValueBucket for NATSBucket {
119    async fn insert(
120        &self,
121        key: &Key,
122        value: &str,
123        revision: u64,
124    ) -> Result<StoreOutcome, StoreError> {
125        if revision == 0 {
126            self.create(key, value).await
127        } else {
128            self.update(key, value, revision).await
129        }
130    }
131
132    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
133        self.nats_store
134            .get(key)
135            .await
136            .map_err(|e| StoreError::NATSError(e.to_string()))
137    }
138
139    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
140        self.nats_store
141            .delete(key)
142            .await
143            .map_err(|e| StoreError::NATSError(e.to_string()))
144    }
145
146    async fn watch(
147        &self,
148    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
149        let watch_stream = self
150            .nats_store
151            .watch_all()
152            .await
153            .map_err(|e| StoreError::NATSError(e.to_string()))?;
154        // Map the `Entry` to `Entry.value` which is Bytes of the stored value.
155        Ok(Box::pin(
156            watch_stream.filter_map(
157                |maybe_entry: Result<
158                    async_nats::jetstream::kv::Entry,
159                    async_nats::error::Error<_>,
160                >| async move {
161                    match maybe_entry {
162                        Ok(entry) => {
163                            let item = KeyValue::new(entry.key, entry.value);
164                            Some(match entry.operation {
165                                Operation::Put => WatchEvent::Put(item),
166                                Operation::Delete => WatchEvent::Delete(item),
167                                // TODO: What is Purge? Not urgent, NATS impl not used
168                                Operation::Purge => WatchEvent::Delete(item),
169                            })
170                        }
171                        Err(e) => {
172                            tracing::error!(error=%e, "watch fatal err");
173                            None
174                        }
175                    }
176                },
177            ),
178        ))
179    }
180
181    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
182        let mut key_stream = self
183            .nats_store
184            .keys()
185            .await
186            .map_err(|e| StoreError::NATSError(e.to_string()))?;
187        let mut out = HashMap::new();
188        while let Some(Ok(key)) = key_stream.next().await {
189            if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
190                out.insert(key, entry.value);
191            }
192        }
193        Ok(out)
194    }
195}
196
197impl NATSBucket {
198    async fn create(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
199        match self.nats_store.create(&key, value.to_string().into()).await {
200            Ok(revision) => Ok(StoreOutcome::Created(revision)),
201            Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
202                // key exists, get the revsion
203                match self.nats_store.entry(key).await {
204                    Ok(Some(entry)) => Ok(StoreOutcome::Exists(entry.revision)),
205                    Ok(None) => {
206                        tracing::error!(
207                            %key,
208                            "Race condition, key deleted between create and fetch. Retry."
209                        );
210                        Err(StoreError::Retry)
211                    }
212                    Err(err) => Err(StoreError::NATSError(err.to_string())),
213                }
214            }
215            Err(err) => Err(StoreError::NATSError(err.to_string())),
216        }
217    }
218
219    async fn update(
220        &self,
221        key: &Key,
222        value: &str,
223        revision: u64,
224    ) -> Result<StoreOutcome, StoreError> {
225        match self
226            .nats_store
227            .update(key, value.to_string().into(), revision)
228            .await
229        {
230            Ok(revision) => Ok(StoreOutcome::Created(revision)),
231            Err(err)
232                if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
233            {
234                tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
235                self.resync_update(key, value).await
236            }
237            Err(err) => Err(StoreError::NATSError(err.to_string())),
238        }
239    }
240
241    /// We have the wrong revision for a key. Fetch it's entry to get the correct revision,
242    /// and try the update again.
243    async fn resync_update(&self, key: &Key, value: &str) -> Result<StoreOutcome, StoreError> {
244        match self.nats_store.entry(key).await {
245            Ok(Some(entry)) => {
246                // Re-try the update with new version number
247                let next_rev = entry.revision + 1;
248                match self
249                    .nats_store
250                    .update(key, value.to_string().into(), next_rev)
251                    .await
252                {
253                    Ok(correct_revision) => Ok(StoreOutcome::Created(correct_revision)),
254                    Err(err) => Err(StoreError::NATSError(format!(
255                        "Error during update of key {key} after resync: {err}"
256                    ))),
257                }
258            }
259            Ok(None) => {
260                tracing::warn!(%key, "Entry does not exist during resync, creating.");
261                self.create(key, value).await
262            }
263            Err(err) => {
264                tracing::error!(%key, %err, "Failed fetching entry during resync");
265                Err(StoreError::NATSError(err.to_string()))
266            }
267        }
268    }
269}
270
271/// async-nats won't let us use a multi-part subject to create KV buckets (and probably many other
272/// things).
273fn single_name(namespace: &str, name: &Slug) -> String {
274    format!("{namespace}_{name}")
275}