1use std::collections::HashMap;
20use std::fmt;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use dynamo_runtime::slug::Slug;
27use dynamo_runtime::CancellationToken;
28use futures::StreamExt;
29use serde::{Deserialize, Serialize};
30
31mod mem;
32pub use mem::MemoryStorage;
33mod nats;
34pub use nats::NATSStorage;
35
36#[async_trait]
37pub trait KeyValueStore: Send + Sync {
38 async fn get_or_create_bucket(
39 &self,
40 bucket_name: &str,
41 ttl: Option<Duration>,
43 ) -> Result<Box<dyn KeyValueBucket>, StorageError>;
44
45 async fn get_bucket(
46 &self,
47 bucket_name: &str,
48 ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError>;
49}
50
51pub struct KeyValueStoreManager(Box<dyn KeyValueStore>);
52
53impl KeyValueStoreManager {
54 pub fn new(s: Box<dyn KeyValueStore>) -> KeyValueStoreManager {
55 KeyValueStoreManager(s)
56 }
57
58 pub async fn load<T: for<'a> Deserialize<'a>>(
59 &self,
60 bucket: &str,
61 key: &Slug,
62 ) -> Result<Option<T>, StorageError> {
63 let Some(bucket) = self.0.get_bucket(bucket).await? else {
64 return Ok(None);
66 };
67 match bucket.get(key.as_ref()).await {
68 Ok(Some(card_bytes)) => {
69 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
70 Ok(Some(card))
71 }
72 Ok(None) => Ok(None),
73 Err(err) => {
74 Err(StorageError::NATSError(err.to_string()))
76 }
77 }
78 }
79
80 pub fn watch<T: for<'a> Deserialize<'a> + Send + 'static>(
84 self: Arc<Self>,
85 bucket_name: &str,
86 bucket_ttl: Option<Duration>,
87 ) -> (
88 tokio::task::JoinHandle<Result<(), StorageError>>,
89 tokio::sync::mpsc::UnboundedReceiver<T>,
90 ) {
91 let bucket_name = bucket_name.to_string();
92 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
93 let watch_task = tokio::spawn(async move {
94 let bucket = self
96 .0
97 .get_or_create_bucket(&bucket_name, bucket_ttl)
98 .await?;
99 let mut stream = bucket.watch().await?;
100
101 for (_, card_bytes) in bucket.entries().await? {
103 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
104 let _ = tx.send(card);
105 }
106
107 while let Some(card_bytes) = stream.next().await {
109 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
110 let _ = tx.send(card);
111 }
112
113 Ok::<(), StorageError>(())
114 });
115 (watch_task, rx)
116 }
117
118 pub async fn publish<T: Serialize + Versioned + Send + Sync>(
119 &self,
120 bucket_name: &str,
121 bucket_ttl: Option<Duration>,
122 key: &str,
123 obj: &mut T,
124 ) -> anyhow::Result<StorageOutcome> {
125 let obj_json = serde_json::to_string(obj)?;
126 let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
127
128 let outcome = bucket
129 .insert(key.to_string(), obj_json, obj.revision())
130 .await?;
131
132 match outcome {
133 StorageOutcome::Created(revision) | StorageOutcome::Exists(revision) => {
134 obj.set_revision(revision);
135 }
136 }
137 Ok(outcome)
138 }
139
140 pub fn publish_until_cancelled<T: Serialize + Versioned + Send + Sync + 'static>(
144 self: Arc<Self>,
145 cancel_token: CancellationToken,
146 bucket_name: String,
147 bucket_ttl: Option<Duration>,
148 publish_interval: Duration,
149 key: String,
150 mut obj: T,
151 ) {
152 tokio::spawn(async move {
153 loop {
154 let publish_result = self
155 .clone()
156 .publish(&bucket_name, bucket_ttl, &key, &mut obj)
157 .await;
158 if let Err(err) = publish_result {
159 tracing::error!(
160 model = key,
161 error = %err,
162 "Failed publishing to KV storage. Ending publish task.",
163 );
164 }
165 tokio::select! {
166 _ = tokio::time::sleep(publish_interval) => {},
167 _ = cancel_token.cancelled() => {
168 tracing::trace!(model_service_name = key, "Publish loop cancelled");
169 match self.0.get_bucket(&bucket_name).await {
170 Ok(Some(bucket)) => {
171 if let Err(err) = bucket.delete(&key).await {
172 tracing::trace!(bucket_name, key, %err, "Error delete published card from NATS on publish stop");
174 }
175
176 tracing::trace!(bucket_name, key, "Deleted Model Deployment Card from NATS");
177 }
178 Ok(None) => {
179 tracing::trace!(bucket_name, key, "Bucket does not exist");
180 }
181 Err(err) => {
182 tracing::trace!(bucket_name, %err, "publish_until_cancelled shutdown error");
183 }
184 }
185 break;
187 }
188 }
189 }
190 });
191 }
192}
193
194#[async_trait]
197pub trait KeyValueBucket: Send {
198 async fn insert(
201 &self,
202 key: String,
203 value: String,
204 revision: u64,
205 ) -> Result<StorageOutcome, StorageError>;
206
207 async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError>;
209
210 async fn delete(&self, key: &str) -> Result<(), StorageError>;
212
213 async fn watch(
217 &self,
218 ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>;
219
220 async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError>;
221}
222
/// Result of a bucket insert, carrying the revision reported by the backend.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum StorageOutcome {
    /// The backend reports that it stored the value at this revision.
    Created(u64),
    /// The backend reports that the value was already present at this revision
    /// (presumably nothing was written — confirm per backend).
    Exists(u64),
}

impl fmt::Display for StorageOutcome {
    /// Renders the outcome as `"Created at <rev>"` or `"Exists at <rev>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Created(revision) => write!(f, "Created at {revision}"),
            Self::Exists(revision) => write!(f, "Exists at {revision}"),
        }
    }
}
239
240#[derive(thiserror::Error, Debug)]
241pub enum StorageError {
242 #[error("Could not find bucket '{0}'")]
243 MissingBucket(String),
244
245 #[error("Could not find key '{0}'")]
246 MissingKey(String),
247
248 #[error("Internal storage error: '{0}'")]
249 ProviderError(String),
250
251 #[error("Internal NATS error: {0}")]
252 NATSError(String),
253
254 #[error("Internal etcd error: {0}")]
255 EtcdError(String),
256
257 #[error("Key Value Error: {0} for bucket '{1}")]
258 KeyValueError(String, String),
259
260 #[error("Error decoding bytes: {0}")]
261 JSONDecodeError(#[from] serde_json::error::Error),
262
263 #[error("Race condition, retry the call")]
264 Retry,
265}
266
/// A value that carries a storage revision.
///
/// `KeyValueStoreManager::publish` reads the revision before inserting and writes back the
/// revision reported by the backend.
pub trait Versioned {
    /// The revision currently recorded on this value.
    fn revision(&self) -> u64;

    /// Records `revision` as this value's current revision.
    fn set_revision(&mut self, revision: u64);
}
273
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{pin_mut, StreamExt};

    const BUCKET_NAME: &str = "mdc";

    /// Fans one byte stream out to any number of broadcast subscribers.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<bytes::Bytes>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task that forwards every item of `stream` into a broadcast channel of
        /// capacity `max_size`.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = bytes::Bytes> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// Returns a new receiver for the forwarded items.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<bytes::Bytes> {
            self.tx.subscribe()
        }
    }

    fn init() {
        dynamo_runtime::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStorage::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket
            .insert("test1".to_string(), "value1".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            let v = stream.next().await.unwrap();
            assert_eq!(v, "value1".as_bytes());

            got_first_tx.send(()).unwrap();

            let v = stream.next().await.unwrap();
            assert_eq!(v, "value2".as_bytes());
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value3".as_bytes());

            Ok::<_, StorageError>(())
        });

        got_first_rx.await?;

        let res = bucket
            .insert("test2".to_string(), "value2".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        let res = bucket
            .insert("test2".to_string(), "value2".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Exists(0));

        let res = bucket
            .insert("test2".to_string(), "value2".to_string(), 1)
            .await?;
        assert_eq!(res, StorageOutcome::Created(1));

        let res = bucket
            .insert("test3".to_string(), "value3".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        // Outer `?`: the task panicked (a failed assertion). Inner `?`: the watcher returned
        // a `StorageError`. Previously the inner result was discarded with `let _ =`.
        ingress.await??;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        let s: &'static _ = Box::leak(Box::new(MemoryStorage::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket
            .insert("test1".to_string(), "value1".to_string(), 0)
            .await?;
        assert_eq!(res, StorageOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });

        bucket
            .insert("test1".to_string(), "GK".to_string(), 1)
            .await?;

        // A failed assertion inside a spawned task surfaces only as a `JoinError`. Propagate
        // it, otherwise this test would pass even when the subscribers saw the wrong bytes.
        let (res1, res2) = futures::join!(handle1, handle2);
        res1?;
        res2?;
        Ok(())
    }
}