1use std::pin::Pin;
8use std::str::FromStr;
9use std::sync::Arc;
10use std::time::Duration;
11use std::{collections::HashMap, path::PathBuf};
12use std::{env, fmt};
13
14use crate::CancellationToken;
15use crate::slug::Slug;
16use crate::transports::etcd as etcd_transport;
17use async_trait::async_trait;
18use futures::StreamExt;
19use serde::{Deserialize, Serialize};
20
21mod mem;
22pub use mem::MemoryStore;
23mod nats;
24pub use nats::NATSStore;
25mod etcd;
26pub use etcd::EtcdStore;
27mod file;
28pub use file::FileStore;
29
/// Longest time `KeyValueStoreManager::watch` waits for room in its output
/// channel when forwarding a single event; it is the timeout handed to
/// `send_timeout` there.
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
31
32#[derive(Debug, Clone, PartialEq)]
37pub struct Key(String);
38
39impl Key {
40 pub fn new(s: &str) -> Key {
41 Key(Slug::slugify(s).to_string())
42 }
43
44 pub fn from_raw(s: String) -> Key {
46 Key(s)
47 }
48}
49
50impl From<&str> for Key {
51 fn from(s: &str) -> Key {
52 Key::new(s)
53 }
54}
55
56impl fmt::Display for Key {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 write!(f, "{}", self.0)
59 }
60}
61
62impl AsRef<str> for Key {
63 fn as_ref(&self) -> &str {
64 &self.0
65 }
66}
67
68impl From<&Key> for String {
69 fn from(k: &Key) -> String {
70 k.0.clone()
71 }
72}
73
74#[derive(Debug, Clone, PartialEq)]
75pub struct KeyValue {
76 key: String,
77 value: bytes::Bytes,
78}
79
80impl KeyValue {
81 pub fn new(key: String, value: bytes::Bytes) -> Self {
82 KeyValue { key, value }
83 }
84
85 pub fn key(&self) -> String {
86 self.key.clone()
87 }
88
89 pub fn key_str(&self) -> &str {
90 &self.key
91 }
92
93 pub fn value(&self) -> &[u8] {
94 &self.value
95 }
96
97 pub fn value_str(&self) -> anyhow::Result<&str> {
98 std::str::from_utf8(self.value()).map_err(From::from)
99 }
100}
101
/// A change notification carried by bucket watch streams and by the channel
/// returned from `KeyValueStoreManager::watch`.
#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
    /// A key together with its value. `KeyValueStoreManager::watch` also emits
    /// one of these for every entry already present when the watch starts.
    Put(KeyValue),
    /// Carries only a key.
    /// NOTE(review): presumably emitted when that key is removed — confirm
    /// against the backend implementations.
    Delete(Key),
}
107
/// Interface implemented by each store backend; `KeyValueStoreEnum` dispatches
/// to these methods.
#[async_trait]
pub trait KeyValueStore: Send + Sync {
    /// Bucket handle type handed out by this store.
    type Bucket: KeyValueBucket + Send + Sync + 'static;

    /// Returns a handle to the bucket called `bucket_name`.
    ///
    /// NOTE(review): going by the name, the bucket is created when it does not
    /// exist yet, and `ttl` is presumably an expiry setting — the exact
    /// semantics are defined by each implementation; confirm there.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        ttl: Option<Duration>,
    ) -> Result<Self::Bucket, StoreError>;

    /// Looks up the bucket called `bucket_name`, returning `Ok(None)` rather
    /// than an error when there is no bucket to return.
    /// (`KeyValueStoreManager::load` treats `None` as "nothing stored".)
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;

    /// Numeric identifier reported by the store.
    /// NOTE(review): presumably identifies the backend connection — confirm.
    fn connection_id(&self) -> u64;

    /// Asks the store to shut down; what that entails is backend-specific.
    fn shutdown(&self);
}
125
/// Choice of key-value store backend, together with the settings needed to
/// open it. Defaults to [`KeyValueStoreSelect::Memory`].
///
/// The `FromStr` impl in this file parses it from `"etcd"`, `"file"` or `"mem"`.
#[derive(Clone, Debug, Default)]
pub enum KeyValueStoreSelect {
    /// etcd backend with its client options (boxed).
    Etcd(Box<etcd_transport::ClientOptions>),
    /// File backend rooted at the given path.
    File(PathBuf),
    /// In-memory backend.
    #[default]
    Memory,
}
135
136impl fmt::Display for KeyValueStoreSelect {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 match self {
139 KeyValueStoreSelect::Etcd(opts) => {
140 let urls = opts.etcd_url.join(",");
141 write!(f, "Etcd({urls})")
142 }
143 KeyValueStoreSelect::File(path) => write!(f, "File({})", path.display()),
144 KeyValueStoreSelect::Memory => write!(f, "Memory"),
145 }
146 }
147}
148
149impl FromStr for KeyValueStoreSelect {
150 type Err = anyhow::Error;
151
152 fn from_str(s: &str) -> anyhow::Result<KeyValueStoreSelect> {
153 match s {
154 "etcd" => Ok(Self::Etcd(Box::default())),
155 "file" => {
156 let root = env::var("DYN_FILE_KV")
157 .map(PathBuf::from)
158 .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv"));
159 Ok(Self::File(root))
160 }
161 "mem" => Ok(Self::Memory),
162 x => anyhow::bail!("Unknown key-value store type '{x}'"),
163 }
164 }
165}
166
167impl TryFrom<String> for KeyValueStoreSelect {
168 type Error = anyhow::Error;
169
170 fn try_from(s: String) -> anyhow::Result<KeyValueStoreSelect> {
171 s.parse()
172 }
173}
174
/// The concrete store backends that `KeyValueStoreManager` can wrap; its
/// methods dispatch to whichever variant is held.
// The variants differ noticeably in size, hence the lint exemption.
// NOTE(review): presumably left unboxed on purpose — confirm.
#[allow(clippy::large_enum_variant)]
pub enum KeyValueStoreEnum {
    /// In-memory backend.
    Memory(MemoryStore),
    /// NATS backend.
    Nats(NATSStore),
    /// etcd backend.
    Etcd(EtcdStore),
    /// File backend.
    File(FileStore),
}
182
183impl KeyValueStoreEnum {
184 async fn get_or_create_bucket(
185 &self,
186 bucket_name: &str,
187 ttl: Option<Duration>,
189 ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
190 use KeyValueStoreEnum::*;
191 Ok(match self {
192 Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
193 Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
194 Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
195 File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
196 })
197 }
198
199 async fn get_bucket(
200 &self,
201 bucket_name: &str,
202 ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
203 use KeyValueStoreEnum::*;
204 let maybe_bucket: Option<Box<dyn KeyValueBucket>> = match self {
205 Memory(x) => x
206 .get_bucket(bucket_name)
207 .await?
208 .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
209 Nats(x) => x
210 .get_bucket(bucket_name)
211 .await?
212 .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
213 Etcd(x) => x
214 .get_bucket(bucket_name)
215 .await?
216 .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
217 File(x) => x
218 .get_bucket(bucket_name)
219 .await?
220 .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
221 };
222 Ok(maybe_bucket)
223 }
224
225 fn connection_id(&self) -> u64 {
226 use KeyValueStoreEnum::*;
227 match self {
228 Memory(x) => x.connection_id(),
229 Etcd(x) => x.connection_id(),
230 Nats(x) => x.connection_id(),
231 File(x) => x.connection_id(),
232 }
233 }
234
235 fn shutdown(&self) {
236 use KeyValueStoreEnum::*;
237 match self {
238 Memory(x) => x.shutdown(),
239 Etcd(x) => x.shutdown(),
240 Nats(x) => x.shutdown(),
241 File(x) => x.shutdown(),
242 }
243 }
244}
245
/// Handle to a store backend. The backend sits behind an `Arc`, so cloning the
/// manager yields another handle to the same backend.
#[derive(Clone)]
pub struct KeyValueStoreManager(pub Arc<KeyValueStoreEnum>);
248
249impl Default for KeyValueStoreManager {
250 fn default() -> Self {
251 KeyValueStoreManager::memory()
252 }
253}
254
255impl KeyValueStoreManager {
256 pub fn memory() -> Self {
258 Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
259 }
260
261 pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
262 Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
263 }
264
265 pub fn file<P: Into<PathBuf>>(root: P) -> Self {
266 Self::new(KeyValueStoreEnum::File(FileStore::new(root)))
267 }
268
269 fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager {
270 KeyValueStoreManager(Arc::new(s))
271 }
272
273 pub async fn get_or_create_bucket(
274 &self,
275 bucket_name: &str,
276 ttl: Option<Duration>,
278 ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
279 self.0.get_or_create_bucket(bucket_name, ttl).await
280 }
281
282 pub async fn get_bucket(
283 &self,
284 bucket_name: &str,
285 ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
286 self.0.get_bucket(bucket_name).await
287 }
288
289 pub fn connection_id(&self) -> u64 {
290 self.0.connection_id()
291 }
292
293 pub async fn load<T: for<'a> Deserialize<'a>>(
294 &self,
295 bucket: &str,
296 key: &Key,
297 ) -> Result<Option<T>, StoreError> {
298 let Some(bucket) = self.0.get_bucket(bucket).await? else {
299 return Ok(None);
301 };
302 Ok(match bucket.get(key).await? {
303 Some(card_bytes) => {
304 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
305 Some(card)
306 }
307 None => None,
308 })
309 }
310
311 pub fn watch(
315 self: Arc<Self>,
316 bucket_name: &str,
317 bucket_ttl: Option<Duration>,
318 cancel_token: CancellationToken,
319 ) -> (
320 tokio::task::JoinHandle<Result<(), StoreError>>,
321 tokio::sync::mpsc::Receiver<WatchEvent>,
322 ) {
323 let bucket_name = bucket_name.to_string();
324 let (tx, rx) = tokio::sync::mpsc::channel(128);
325 let watch_task = tokio::spawn(async move {
326 let bucket = self
328 .0
329 .get_or_create_bucket(&bucket_name, bucket_ttl)
330 .await?;
331 let mut stream = bucket.watch().await?;
332
333 for (key, bytes) in bucket.entries().await? {
335 if let Err(err) = tx
336 .send_timeout(
337 WatchEvent::Put(KeyValue::new(key, bytes)),
338 WATCH_SEND_TIMEOUT,
339 )
340 .await
341 {
342 tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
343 }
344 }
345
346 loop {
348 let event = tokio::select! {
349 _ = cancel_token.cancelled() => break,
350 result = stream.next() => match result {
351 Some(event) => event,
352 None => break,
353 }
354 };
355 if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
356 tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
357 }
358 }
359
360 Ok::<(), StoreError>(())
361 });
362 (watch_task, rx)
363 }
364
365 pub async fn publish<T: Serialize + Versioned + Send + Sync>(
366 &self,
367 bucket_name: &str,
368 bucket_ttl: Option<Duration>,
369 key: &Key,
370 obj: &mut T,
371 ) -> anyhow::Result<StoreOutcome> {
372 let obj_json = serde_json::to_vec(obj)?;
373 let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
374
375 let outcome = bucket.insert(key, obj_json.into(), obj.revision()).await?;
376
377 match outcome {
378 StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
379 obj.set_revision(revision);
380 }
381 }
382 Ok(outcome)
383 }
384
385 pub fn shutdown(&self) {
388 self.0.shutdown()
389 }
390}
391
/// Operations available on a single bucket of a key-value store.
#[async_trait]
pub trait KeyValueBucket: Send + Sync {
    /// Stores `value` under `key` at the given `revision` and reports the
    /// outcome, which carries a revision number.
    ///
    /// NOTE(review): how `revision` decides between `Created` and `Exists` is
    /// backend-defined; the memory-store test in this file shows a repeated
    /// insert at the same revision yielding `Exists` — confirm for other
    /// backends.
    async fn insert(
        &self,
        key: &Key,
        value: bytes::Bytes,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError>;

    /// Fetches the bytes stored under `key`, or `None` when there are none.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;

    /// Removes `key` from the bucket.
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;

    /// Opens a stream of `WatchEvent`s for this bucket. The stream borrows the
    /// bucket (`'_`), so the bucket must outlive it.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;

    /// Returns the bucket's entries as a map from key string to value bytes.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
}
419
/// Result of an insert, carrying the revision reported by the store.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The entry was created at this revision.
    Created(u64),
    /// The entry already existed at this revision.
    Exists(u64),
}

/// Renders as `Created at <revision>` or `Exists at <revision>`.
impl fmt::Display for StoreOutcome {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (verb, revision) = match *self {
            Self::Created(revision) => ("Created", revision),
            Self::Exists(revision) => ("Exists", revision),
        };
        write!(f, "{verb} at {revision}")
    }
}
436
/// Errors reported by the key-value store layer. The `#[error]` attribute on
/// each variant is its display message.
#[derive(thiserror::Error, Debug)]
pub enum StoreError {
    /// The named bucket could not be found.
    #[error("Could not find bucket '{0}'")]
    MissingBucket(String),

    /// The named key could not be found.
    #[error("Could not find key '{0}'")]
    MissingKey(String),

    /// Internal storage error, described by the message.
    #[error("Internal storage error: '{0}'")]
    ProviderError(String),

    /// Internal NATS error, described by the message.
    #[error("Internal NATS error: {0}")]
    NATSError(String),

    /// Internal etcd error, described by the message.
    #[error("Internal etcd error: {0}")]
    EtcdError(String),

    /// Internal filesystem error, described by the message.
    #[error("Internal filesystem error: {0}")]
    FilesystemError(String),

    /// Key-value error: the first field is the message, the second the bucket
    /// name.
    #[error("Key Value Error: {0} for bucket '{1}'")]
    KeyValueError(String, String),

    /// JSON decoding failed. `#[from]` lets `?` convert a `serde_json` error
    /// into this variant, as `KeyValueStoreManager::load` does.
    #[error("Error decoding bytes: {0}")]
    JSONDecodeError(#[from] serde_json::error::Error),

    /// A race was detected; the caller should retry the call.
    #[error("Race condition, retry the call")]
    Retry,
}
466
/// An object that carries a store revision number.
///
/// `KeyValueStoreManager::publish` passes `revision()` to the bucket insert and
/// then stores the revision reported back through `set_revision`.
pub trait Versioned {
    /// Returns the object's current revision.
    fn revision(&self) -> u64;
    /// Replaces the object's revision with `r`.
    fn set_revision(&mut self, r: u64);
}
473
/// Tests that exercise the in-memory store through the bucket API.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    /// Bucket name shared by the tests below.
    const BUCKET_NAME: &str = "v1/mdc";

    /// Fans a single `WatchEvent` stream out to any number of subscribers
    /// through a tokio broadcast channel.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task that drains `stream` into a broadcast channel with
        /// room for `max_size` events.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // The send result is deliberately ignored here.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// Returns a new receiver on the broadcast channel.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    /// Initializes logging for a test.
    fn init() {
        crate::logging::init();
    }

    /// Inserts keys into a memory-store bucket while a second task watches the
    /// same bucket, and checks both the insert outcomes and the watched events.
    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // The watcher should see test1..test3 as Put events, in this order.
        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                format!("test{i}"),
                format!("value{i}").into(),
            ));
            expected.push(item);
        }

        // Signalled by the watcher once it has received its first event; the
        // remaining inserts wait for that signal.
        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // "test1" was inserted before this watch was opened.
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            got_first_tx.send(()).unwrap();

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Same key, same revision: reported as already existing.
        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Same key, next revision: reported as created at that revision.
        let res = bucket.insert(&"test2".into(), "value2".into(), 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // `?` surfaces a join error (for example a failed assert inside the
        // watcher); the watcher's own `Result` is discarded.
        let _ = ingress.await?;

        Ok(())
    }

    /// Watches a bucket through `TappableStream` with two subscribers and has
    /// each compare its first received event with an updated "test1" entry.
    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // Leaked to obtain `'static` references: `TappableStream::new` needs a
        // `'static` stream, and the watch stream borrows the bucket.
        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let item = WatchEvent::Put(KeyValue::new("test1".to_string(), "GK".into()));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        bucket.insert(&"test1".into(), "GK".into(), 1).await?;

        // NOTE(review): both join results are discarded, so a panic from the
        // asserts inside `handle1`/`handle2` would not fail this test — confirm
        // whether that is intended.
        let _ = futures::join!(handle1, handle2);
        Ok(())
    }
}
615}