dynamo_runtime/storage/key_value_store/nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::{collections::HashMap, pin::Pin, time::Duration};
5
6use crate::{
7    protocols::EndpointId,
8    slug::Slug,
9    storage::key_value_store::{Key, KeyValue, WatchEvent},
10    transports::nats::Client,
11};
12use async_nats::jetstream::kv::Operation;
13use async_trait::async_trait;
14use futures::StreamExt;
15
16use super::{KeyValueBucket, KeyValueStore, StoreError, StoreOutcome};
17
18#[derive(Clone)]
19pub struct NATSStore {
20    client: Client,
21    endpoint: EndpointId,
22}
23
24pub struct NATSBucket {
25    nats_store: async_nats::jetstream::kv::Store,
26}
27
28#[async_trait]
29impl KeyValueStore for NATSStore {
30    type Bucket = NATSBucket;
31
32    async fn get_or_create_bucket(
33        &self,
34        bucket_name: &str,
35        ttl: Option<Duration>,
36    ) -> Result<Self::Bucket, StoreError> {
37        let name = Slug::slugify(bucket_name);
38        let nats_store = self
39            .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
40            .await?;
41        Ok(NATSBucket { nats_store })
42    }
43
44    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError> {
45        let name = Slug::slugify(bucket_name);
46        match self.get_key_value(&self.endpoint.namespace, &name).await? {
47            Some(nats_store) => Ok(Some(NATSBucket { nats_store })),
48            None => Ok(None),
49        }
50    }
51
52    fn connection_id(&self) -> u64 {
53        self.client.client().server_info().client_id
54    }
55
56    fn shutdown(&self) {
57        // TODO: Track and delete any owned keys
58        // The TTL should ensure NATS does it, but best we do it immediately
59    }
60}
61
62impl NATSStore {
63    pub fn new(client: Client, endpoint: EndpointId) -> Self {
64        NATSStore { client, endpoint }
65    }
66
67    /// Get or create a key-value store (aka bucket) in NATS.
68    ///
69    /// ttl is only used if we are creating the bucket, so if that has
70    /// changed first delete the bucket.
71    async fn get_or_create_key_value(
72        &self,
73        namespace: &str,
74        bucket_name: &Slug,
75        // Delete entries older than this
76        ttl: Option<Duration>,
77    ) -> Result<async_nats::jetstream::kv::Store, StoreError> {
78        if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
79            return Ok(kv);
80        }
81
82        // It doesn't exist, create it
83
84        let bucket_name = single_name(namespace, bucket_name);
85        let js = self.client.jetstream();
86        let create_result = js
87            .create_key_value(
88                // TODO: configure the bucket, probably need to pass some of these values in
89                async_nats::jetstream::kv::Config {
90                    bucket: bucket_name.clone(),
91                    max_age: ttl.unwrap_or_default(),
92                    ..Default::default()
93                },
94            )
95            .await;
96        let nats_store = create_result
97            .map_err(|err| StoreError::KeyValueError(err.to_string(), bucket_name.clone()))?;
98        tracing::debug!("Created bucket {bucket_name}");
99        Ok(nats_store)
100    }
101
102    async fn get_key_value(
103        &self,
104        namespace: &str,
105        bucket_name: &Slug,
106    ) -> Result<Option<async_nats::jetstream::kv::Store>, StoreError> {
107        let bucket_name = single_name(namespace, bucket_name);
108        let js = self.client.jetstream();
109
110        use async_nats::jetstream::context::KeyValueErrorKind;
111        match js.get_key_value(&bucket_name).await {
112            Ok(store) => Ok(Some(store)),
113            Err(err) if err.kind() == KeyValueErrorKind::GetBucket => {
114                // bucket doesn't exist
115                Ok(None)
116            }
117            Err(err) => Err(StoreError::KeyValueError(err.to_string(), bucket_name)),
118        }
119    }
120}
121
122#[async_trait]
123impl KeyValueBucket for NATSBucket {
124    async fn insert(
125        &self,
126        key: &Key,
127        value: bytes::Bytes,
128        revision: u64,
129    ) -> Result<StoreOutcome, StoreError> {
130        if revision == 0 {
131            self.create(key, value).await
132        } else {
133            self.update(key, value, revision).await
134        }
135    }
136
137    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError> {
138        self.nats_store
139            .get(key)
140            .await
141            .map_err(|e| StoreError::NATSError(e.to_string()))
142    }
143
144    async fn delete(&self, key: &Key) -> Result<(), StoreError> {
145        self.nats_store
146            .delete(key)
147            .await
148            .map_err(|e| StoreError::NATSError(e.to_string()))
149    }
150
151    async fn watch(
152        &self,
153    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + 'life0>>, StoreError> {
154        let watch_stream = self
155            .nats_store
156            .watch_all()
157            .await
158            .map_err(|e| StoreError::NATSError(e.to_string()))?;
159        // Map the `Entry` to `Entry.value` which is Bytes of the stored value.
160        Ok(Box::pin(
161            watch_stream.filter_map(
162                |maybe_entry: Result<
163                    async_nats::jetstream::kv::Entry,
164                    async_nats::error::Error<_>,
165                >| async move {
166                    match maybe_entry {
167                        Ok(entry) => {
168                            Some(match entry.operation {
169                                Operation::Put => {
170                                    let item = KeyValue::new(entry.key, entry.value);
171                                    WatchEvent::Put(item)
172                                }
173                                Operation::Delete => WatchEvent::Delete(Key::from_raw(entry.key)),
174                                // TODO: What is Purge? Not urgent, NATS impl not used
175                                Operation::Purge => WatchEvent::Delete(Key::from_raw(entry.key)),
176                            })
177                        }
178                        Err(e) => {
179                            tracing::error!(error=%e, "watch fatal err");
180                            None
181                        }
182                    }
183                },
184            ),
185        ))
186    }
187
188    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError> {
189        let mut key_stream = self
190            .nats_store
191            .keys()
192            .await
193            .map_err(|e| StoreError::NATSError(e.to_string()))?;
194        let mut out = HashMap::new();
195        while let Some(Ok(key)) = key_stream.next().await {
196            if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
197                out.insert(key, entry.value);
198            }
199        }
200        Ok(out)
201    }
202}
203
204impl NATSBucket {
205    async fn create(&self, key: &Key, value: bytes::Bytes) -> Result<StoreOutcome, StoreError> {
206        match self.nats_store.create(&key, value).await {
207            Ok(revision) => Ok(StoreOutcome::Created(revision)),
208            Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
209                // key exists, get the revsion
210                match self.nats_store.entry(key).await {
211                    Ok(Some(entry)) => Ok(StoreOutcome::Exists(entry.revision)),
212                    Ok(None) => {
213                        tracing::error!(
214                            %key,
215                            "Race condition, key deleted between create and fetch. Retry."
216                        );
217                        Err(StoreError::Retry)
218                    }
219                    Err(err) => Err(StoreError::NATSError(err.to_string())),
220                }
221            }
222            Err(err) => Err(StoreError::NATSError(err.to_string())),
223        }
224    }
225
226    async fn update(
227        &self,
228        key: &Key,
229        value: bytes::Bytes,
230        revision: u64,
231    ) -> Result<StoreOutcome, StoreError> {
232        match self.nats_store.update(key, value.clone(), revision).await {
233            Ok(revision) => Ok(StoreOutcome::Created(revision)),
234            Err(err)
235                if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
236            {
237                tracing::warn!(revision, %key, "Update WrongLastRevision, resync");
238                self.resync_update(key, value).await
239            }
240            Err(err) => Err(StoreError::NATSError(err.to_string())),
241        }
242    }
243
244    /// We have the wrong revision for a key. Fetch it's entry to get the correct revision,
245    /// and try the update again.
246    async fn resync_update(
247        &self,
248        key: &Key,
249        value: bytes::Bytes,
250    ) -> Result<StoreOutcome, StoreError> {
251        match self.nats_store.entry(key).await {
252            Ok(Some(entry)) => {
253                // Re-try the update with new version number
254                let next_rev = entry.revision + 1;
255                match self.nats_store.update(key, value, next_rev).await {
256                    Ok(correct_revision) => Ok(StoreOutcome::Created(correct_revision)),
257                    Err(err) => Err(StoreError::NATSError(format!(
258                        "Error during update of key {key} after resync: {err}"
259                    ))),
260                }
261            }
262            Ok(None) => {
263                tracing::warn!(%key, "Entry does not exist during resync, creating.");
264                self.create(key, value).await
265            }
266            Err(err) => {
267                tracing::error!(%key, %err, "Failed fetching entry during resync");
268                Err(StoreError::NATSError(err.to_string()))
269            }
270        }
271    }
272}
273
274/// async-nats won't let us use a multi-part subject to create KV buckets (and probably many other
275/// things).
276fn single_name(namespace: &str, name: &Slug) -> String {
277    format!("{namespace}_{name}")
278}