dynamo_llm/key_value_store/
nats.rs1use std::{collections::HashMap, pin::Pin, time::Duration};
17
18use async_trait::async_trait;
19use dynamo_runtime::{protocols::Endpoint, slug::Slug, transports::nats::Client};
20use futures::StreamExt;
21
22use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
23
24#[derive(Clone)]
25pub struct NATSStorage {
26 client: Client,
27 endpoint: Endpoint,
28}
29
30pub struct NATSBucket {
31 nats_store: async_nats::jetstream::kv::Store,
32}
33
34#[async_trait]
35impl KeyValueStore for NATSStorage {
36 async fn get_or_create_bucket(
37 &self,
38 bucket_name: &str,
39 ttl: Option<Duration>,
40 ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
41 let name = Slug::slugify(bucket_name);
42 let nats_store = self
43 .get_or_create_key_value(&self.endpoint.namespace, &name, ttl)
44 .await?;
45 Ok(Box::new(NATSBucket { nats_store }))
46 }
47
48 async fn get_bucket(
49 &self,
50 bucket_name: &str,
51 ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
52 let name = Slug::slugify(bucket_name);
53 match self.get_key_value(&self.endpoint.namespace, &name).await? {
54 Some(nats_store) => Ok(Some(Box::new(NATSBucket { nats_store }))),
55 None => Ok(None),
56 }
57 }
58}
59
60impl NATSStorage {
61 pub fn new(client: Client, endpoint: Endpoint) -> Self {
62 NATSStorage { client, endpoint }
63 }
64
65 async fn get_or_create_key_value(
70 &self,
71 namespace: &str,
72 bucket_name: &Slug,
73 ttl: Option<Duration>,
75 ) -> Result<async_nats::jetstream::kv::Store, StorageError> {
76 if let Ok(Some(kv)) = self.get_key_value(namespace, bucket_name).await {
77 return Ok(kv);
78 }
79
80 let bucket_name = single_name(namespace, bucket_name);
83 let js = self.client.jetstream();
84 let create_result = js
85 .create_key_value(
86 async_nats::jetstream::kv::Config {
88 bucket: bucket_name.clone(),
89 max_age: ttl.unwrap_or_default(),
90 ..Default::default()
91 },
92 )
93 .await;
94 tracing::debug!("Created bucket {bucket_name}");
95 create_result.map_err(|err| StorageError::KeyValueError(err.to_string(), bucket_name))
96 }
97
98 async fn get_key_value(
99 &self,
100 namespace: &str,
101 bucket_name: &Slug,
102 ) -> Result<Option<async_nats::jetstream::kv::Store>, StorageError> {
103 let bucket_name = single_name(namespace, bucket_name);
104 let js = self.client.jetstream();
105
106 use async_nats::jetstream::context::KeyValueErrorKind;
107 match js.get_key_value(&bucket_name).await {
108 Ok(store) => Ok(Some(store)),
109 Err(err) if err.kind() == KeyValueErrorKind::GetBucket => {
110 Ok(None)
112 }
113 Err(err) => Err(StorageError::KeyValueError(err.to_string(), bucket_name)),
114 }
115 }
116}
117
118#[async_trait]
119impl KeyValueBucket for NATSBucket {
120 async fn insert(
121 &self,
122 key: String,
123 value: String,
124 revision: u64,
125 ) -> Result<StorageOutcome, StorageError> {
126 if revision == 0 {
127 self.create(key, value).await
128 } else {
129 self.update(key, value, revision).await
130 }
131 }
132
133 async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
134 self.nats_store
135 .get(key)
136 .await
137 .map_err(|e| StorageError::NATSError(e.to_string()))
138 }
139
140 async fn delete(&self, key: &str) -> Result<(), StorageError> {
141 self.nats_store
142 .delete(key)
143 .await
144 .map_err(|e| StorageError::NATSError(e.to_string()))
145 }
146
147 async fn watch(
148 &self,
149 ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
150 {
151 let watch_stream = self
152 .nats_store
153 .watch_all()
154 .await
155 .map_err(|e| StorageError::NATSError(e.to_string()))?;
156 Ok(Box::pin(
158 watch_stream.filter_map(
159 |maybe_entry: Result<
160 async_nats::jetstream::kv::Entry,
161 async_nats::error::Error<_>,
162 >| async move {
163 match maybe_entry {
164 Ok(entry) => Some(entry.value),
165 Err(e) => {
166 tracing::error!(error=%e, "watch fatal err");
167 None
168 }
169 }
170 },
171 ),
172 ))
173 }
174
175 async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
176 let mut key_stream = self
177 .nats_store
178 .keys()
179 .await
180 .map_err(|e| StorageError::NATSError(e.to_string()))?;
181 let mut out = HashMap::new();
182 while let Some(Ok(key)) = key_stream.next().await {
183 if let Ok(Some(entry)) = self.nats_store.entry(&key).await {
184 out.insert(key, entry.value);
185 }
186 }
187 Ok(out)
188 }
189}
190
191impl NATSBucket {
192 async fn create(&self, key: String, value: String) -> Result<StorageOutcome, StorageError> {
193 match self.nats_store.create(&key, value.into()).await {
194 Ok(revision) => Ok(StorageOutcome::Created(revision)),
195 Err(err) if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists => {
196 match self.nats_store.entry(&key).await {
198 Ok(Some(entry)) => Ok(StorageOutcome::Exists(entry.revision)),
199 Ok(None) => {
200 tracing::error!(
201 key,
202 "Race condition, key deleted between create and fetch. Retry."
203 );
204 Err(StorageError::Retry)
205 }
206 Err(err) => Err(StorageError::NATSError(err.to_string())),
207 }
208 }
209 Err(err) => Err(StorageError::NATSError(err.to_string())),
210 }
211 }
212
213 async fn update(
214 &self,
215 key: String,
216 value: String,
217 revision: u64,
218 ) -> Result<StorageOutcome, StorageError> {
219 match self
220 .nats_store
221 .update(key.clone(), value.clone().into(), revision)
222 .await
223 {
224 Ok(revision) => Ok(StorageOutcome::Created(revision)),
225 Err(err)
226 if err.kind() == async_nats::jetstream::kv::UpdateErrorKind::WrongLastRevision =>
227 {
228 tracing::warn!(revision, key, "Update WrongLastRevision, resync");
229 self.resync_update(key, value).await
230 }
231 Err(err) => Err(StorageError::NATSError(err.to_string())),
232 }
233 }
234
235 async fn resync_update(
238 &self,
239 key: String,
240 value: String,
241 ) -> Result<StorageOutcome, StorageError> {
242 match self.nats_store.entry(&key).await {
243 Ok(Some(entry)) => {
244 let next_rev = entry.revision + 1;
246 match self
247 .nats_store
248 .update(key.clone(), value.into(), next_rev)
249 .await
250 {
251 Ok(correct_revision) => Ok(StorageOutcome::Created(correct_revision)),
252 Err(err) => Err(StorageError::NATSError(format!(
253 "Error during update of key {key} after resync: {err}"
254 ))),
255 }
256 }
257 Ok(None) => {
258 tracing::warn!(key, "Entry does not exist during resync, creating.");
259 self.create(key, value).await
260 }
261 Err(err) => {
262 tracing::error!(key, %err, "Failed fetching entry during resync");
263 Err(StorageError::NATSError(err.to_string()))
264 }
265 }
266 }
267}
268
269fn single_name(namespace: &str, name: &Slug) -> String {
272 format!("{namespace}_{name}")
273}