1use std::borrow::Cow;
8use std::pin::Pin;
9use std::str::FromStr;
10use std::sync::Arc;
11use std::time::Duration;
12use std::{collections::HashMap, path::PathBuf};
13use std::{env, fmt};
14
15use crate::CancellationToken;
16use crate::transports::etcd as etcd_transport;
17use async_trait::async_trait;
18use futures::StreamExt;
19use percent_encoding::{NON_ALPHANUMERIC, percent_decode_str, percent_encode};
20use serde::{Deserialize, Serialize};
21
22mod mem;
23pub use mem::MemoryStore;
24mod nats;
25pub use nats::NATSStore;
26mod etcd;
27pub use etcd::EtcdStore;
28mod file;
29pub use file::FileStore;
30
/// Maximum time `Manager::watch` waits for room in its outgoing channel when
/// forwarding a single event; a send that does not complete in time is logged.
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(1000);
32
/// Newtype around an owned `String` used to address a value inside a bucket.
///
/// The wrapped string is stored as given; see `from_url_safe` / `url_safe`
/// for the percent-encoded representation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);
37
38impl Key {
39 pub fn new(s: String) -> Key {
40 Key(s)
41 }
42
43 pub fn from_url_safe(s: &str) -> Key {
46 Key(percent_decode_str(s).decode_utf8_lossy().to_string())
47 }
48
49 pub fn url_safe(&self) -> Cow<'_, str> {
52 percent_encode(self.0.as_bytes(), NON_ALPHANUMERIC).into()
53 }
54}
55
56impl From<&str> for Key {
57 fn from(s: &str) -> Key {
58 Key::new(s.to_string())
59 }
60}
61
62impl fmt::Display for Key {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 write!(f, "{}", self.0)
65 }
66}
67
68impl AsRef<str> for Key {
69 fn as_ref(&self) -> &str {
70 &self.0
71 }
72}
73
74impl From<&Key> for String {
75 fn from(k: &Key) -> String {
76 k.0.clone()
77 }
78}
79
/// A key together with the raw bytes stored under it.
#[derive(Debug, Clone, PartialEq)]
pub struct KeyValue {
    // Key the value is stored under.
    key: Key,
    // Raw value bytes; not required to be UTF-8 (see `value_str`).
    value: bytes::Bytes,
}
85
86impl KeyValue {
87 pub fn new(key: Key, value: bytes::Bytes) -> Self {
88 KeyValue { key, value }
89 }
90
91 pub fn key(&self) -> String {
92 self.key.clone().to_string()
93 }
94
95 pub fn key_str(&self) -> &str {
96 self.key.as_ref()
97 }
98
99 pub fn value(&self) -> &[u8] {
100 &self.value
101 }
102
103 pub fn value_str(&self) -> anyhow::Result<&str> {
104 std::str::from_utf8(self.value()).map_err(From::from)
105 }
106}
107
/// A change notification delivered by a bucket watch stream.
#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
    /// A key/value pair is being reported; carries the key and its value bytes.
    Put(KeyValue),
    /// A key is being reported as deleted; carries only the key.
    Delete(Key),
}
113
/// A storage provider that hands out named buckets.
#[async_trait]
pub trait Store: Send + Sync {
    /// Concrete bucket type produced by this store.
    type Bucket: Bucket + Send + Sync + 'static;

    /// Returns the bucket called `bucket_name`.
    ///
    /// NOTE(review): going by the name, the bucket is presumably created when
    /// it does not exist yet, with `ttl` applied to it — the exact TTL
    /// semantics are defined by each implementation; confirm there.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        ttl: Option<Duration>,
    ) -> Result<Self::Bucket, StoreError>;

    /// Looks up the bucket called `bucket_name`; the `Option` allows
    /// implementations to report "no such bucket" without an error.
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;

    /// Returns a numeric identifier for this store's connection.
    ///
    /// NOTE(review): what the id denotes is provider-specific and not visible here.
    fn connection_id(&self) -> u64;

    /// Asks the store to shut down.
    fn shutdown(&self);
}
131
/// Chooses which key-value store backend to use.
///
/// Parsed from the strings `"etcd"`, `"file"` and `"mem"` via `FromStr`;
/// defaults to `Memory`.
#[derive(Clone, Debug, Default)]
pub enum Selector {
    /// etcd backend with its client options (boxed).
    Etcd(Box<etcd_transport::ClientOptions>),
    /// File backend using the given path.
    File(PathBuf),
    /// In-memory backend (the default).
    #[default]
    Memory,
}
141
142impl fmt::Display for Selector {
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 match self {
145 Selector::Etcd(opts) => {
146 let urls = opts.etcd_url.join(",");
147 write!(f, "Etcd({urls})")
148 }
149 Selector::File(path) => write!(f, "File({})", path.display()),
150 Selector::Memory => write!(f, "Memory"),
151 }
152 }
153}
154
155impl FromStr for Selector {
156 type Err = anyhow::Error;
157
158 fn from_str(s: &str) -> anyhow::Result<Selector> {
159 match s {
160 "etcd" => Ok(Self::Etcd(Box::default())),
161 "file" => {
162 let root = env::var("DYN_FILE_KV")
163 .map(PathBuf::from)
164 .unwrap_or_else(|_| env::temp_dir().join("dynamo_store_kv"));
165 Ok(Self::File(root))
166 }
167 "mem" => Ok(Self::Memory),
168 x => anyhow::bail!("Unknown key-value store type '{x}'"),
169 }
170 }
171}
172
173impl TryFrom<String> for Selector {
174 type Error = anyhow::Error;
175
176 fn try_from(s: String) -> anyhow::Result<Selector> {
177 s.parse()
178 }
179}
180
/// The closed set of store backends that `Manager` dispatches over.
// NOTE(review): the lint is silenced rather than boxing the large variant(s);
// presumably deliberate since only one value lives behind the manager's `Arc`.
#[allow(clippy::large_enum_variant)]
enum KeyValueStoreEnum {
    /// In-memory backend.
    Memory(MemoryStore),
    /// NATS backend.
    Nats(NATSStore),
    /// etcd backend.
    Etcd(EtcdStore),
    /// Filesystem backend.
    File(FileStore),
}
188
189impl KeyValueStoreEnum {
190 async fn get_or_create_bucket(
191 &self,
192 bucket_name: &str,
193 ttl: Option<Duration>,
195 ) -> Result<Box<dyn Bucket>, StoreError> {
196 use KeyValueStoreEnum::*;
197 Ok(match self {
198 Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
199 Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
200 Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
201 File(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
202 })
203 }
204
205 async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Box<dyn Bucket>>, StoreError> {
206 use KeyValueStoreEnum::*;
207 let maybe_bucket: Option<Box<dyn Bucket>> = match self {
208 Memory(x) => x
209 .get_bucket(bucket_name)
210 .await?
211 .map(|b| Box::new(b) as Box<dyn Bucket>),
212 Nats(x) => x
213 .get_bucket(bucket_name)
214 .await?
215 .map(|b| Box::new(b) as Box<dyn Bucket>),
216 Etcd(x) => x
217 .get_bucket(bucket_name)
218 .await?
219 .map(|b| Box::new(b) as Box<dyn Bucket>),
220 File(x) => x
221 .get_bucket(bucket_name)
222 .await?
223 .map(|b| Box::new(b) as Box<dyn Bucket>),
224 };
225 Ok(maybe_bucket)
226 }
227
228 fn connection_id(&self) -> u64 {
229 use KeyValueStoreEnum::*;
230 match self {
231 Memory(x) => x.connection_id(),
232 Etcd(x) => x.connection_id(),
233 Nats(x) => x.connection_id(),
234 File(x) => x.connection_id(),
235 }
236 }
237
238 fn shutdown(&self) {
239 use KeyValueStoreEnum::*;
240 match self {
241 Memory(x) => x.shutdown(),
242 Etcd(x) => x.shutdown(),
243 Nats(x) => x.shutdown(),
244 File(x) => x.shutdown(),
245 }
246 }
247}
248
/// Backend-agnostic handle to a key-value store.
///
/// Cloning is cheap: clones share the same underlying store through an `Arc`.
#[derive(Clone)]
pub struct Manager(Arc<KeyValueStoreEnum>);
251
252impl Default for Manager {
253 fn default() -> Self {
254 Manager::memory()
255 }
256}
257
258impl Manager {
259 pub fn memory() -> Self {
261 Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
262 }
263
264 pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
265 Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
266 }
267
268 pub fn file<P: Into<PathBuf>>(cancel_token: CancellationToken, root: P) -> Self {
269 Self::new(KeyValueStoreEnum::File(FileStore::new(cancel_token, root)))
270 }
271
272 fn new(s: KeyValueStoreEnum) -> Manager {
273 Manager(Arc::new(s))
274 }
275
276 pub async fn get_or_create_bucket(
277 &self,
278 bucket_name: &str,
279 ttl: Option<Duration>,
281 ) -> Result<Box<dyn Bucket>, StoreError> {
282 self.0.get_or_create_bucket(bucket_name, ttl).await
283 }
284
285 pub async fn get_bucket(
286 &self,
287 bucket_name: &str,
288 ) -> Result<Option<Box<dyn Bucket>>, StoreError> {
289 self.0.get_bucket(bucket_name).await
290 }
291
292 pub fn connection_id(&self) -> u64 {
293 self.0.connection_id()
294 }
295
296 pub async fn load<T: for<'a> Deserialize<'a>>(
297 &self,
298 bucket: &str,
299 key: &Key,
300 ) -> Result<Option<T>, StoreError> {
301 let Some(bucket) = self.0.get_bucket(bucket).await? else {
302 return Ok(None);
304 };
305 Ok(match bucket.get(key).await? {
306 Some(card_bytes) => {
307 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
308 Some(card)
309 }
310 None => None,
311 })
312 }
313
314 pub fn watch(
318 self: Arc<Self>,
319 bucket_name: &str,
320 bucket_ttl: Option<Duration>,
321 cancel_token: CancellationToken,
322 ) -> (
323 tokio::task::JoinHandle<Result<(), StoreError>>,
324 tokio::sync::mpsc::Receiver<WatchEvent>,
325 ) {
326 let bucket_name = bucket_name.to_string();
327 let (tx, rx) = tokio::sync::mpsc::channel(16384);
331 let watch_task = tokio::spawn(async move {
332 let bucket = self
334 .0
335 .get_or_create_bucket(&bucket_name, bucket_ttl)
336 .await?;
337 let mut stream = bucket.watch().await?;
338
339 for (key, bytes) in bucket.entries().await? {
341 if let Err(err) = tx
342 .send_timeout(
343 WatchEvent::Put(KeyValue::new(key, bytes)),
344 WATCH_SEND_TIMEOUT,
345 )
346 .await
347 {
348 tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
349 }
350 }
351
352 loop {
354 let event = tokio::select! {
355 _ = cancel_token.cancelled() => break,
356 result = stream.next() => match result {
357 Some(event) => event,
358 None => break,
359 }
360 };
361 if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
362 tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
363 }
364 }
365
366 Ok::<(), StoreError>(())
367 });
368 (watch_task, rx)
369 }
370
371 pub async fn publish<T: Serialize + Versioned + Send + Sync>(
372 &self,
373 bucket_name: &str,
374 bucket_ttl: Option<Duration>,
375 key: &Key,
376 obj: &mut T,
377 ) -> anyhow::Result<StoreOutcome> {
378 let obj_json = serde_json::to_vec(obj)?;
379 let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
380
381 let outcome = bucket.insert(key, obj_json.into(), obj.revision()).await?;
382
383 match outcome {
384 StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
385 obj.set_revision(revision);
386 }
387 }
388 Ok(outcome)
389 }
390
391 pub fn shutdown(&self) {
394 self.0.shutdown()
395 }
396}
397
/// A named collection of key/value pairs inside a store.
#[async_trait]
pub trait Bucket: Send + Sync {
    /// Stores `value` under `key` at the given `revision`.
    ///
    /// The outcome reports whether the entry was created or already existed,
    /// together with a revision number.
    /// NOTE(review): how `revision` is compared against existing entries is
    /// implementation-defined; confirm against each backend.
    async fn insert(
        &self,
        key: &Key,
        value: bytes::Bytes,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError>;

    /// Fetches the bytes stored under `key`; the `Option` allows reporting an
    /// absent key without an error.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;

    /// Removes `key` from the bucket.
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;

    /// Opens a stream of [`WatchEvent`]s for this bucket.
    ///
    /// The returned stream borrows the bucket, so the bucket must outlive it.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;

    /// Returns the bucket's entries as a key → bytes map.
    async fn entries(&self) -> Result<HashMap<Key, bytes::Bytes>, StoreError>;
}
431
/// Result of inserting a value into a bucket.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The entry was created; carries a revision number.
    Created(u64),
    /// The entry already existed; carries a revision number.
    Exists(u64),
}

/// Renders as `Created at <revision>` or `Exists at <revision>`.
impl fmt::Display for StoreOutcome {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The payload is a plain `u64`, so match by value.
        match *self {
            Self::Created(revision) => write!(f, "Created at {revision}"),
            Self::Exists(revision) => write!(f, "Exists at {revision}"),
        }
    }
}
448
/// Errors reported by stores and buckets.
///
/// Most variants carry a pre-formatted message string; only JSON decoding
/// errors wrap their source error (and convert automatically via `?`).
#[derive(thiserror::Error, Debug)]
pub enum StoreError {
    /// The named bucket could not be found.
    #[error("Could not find bucket '{0}'")]
    MissingBucket(String),

    /// The named key could not be found.
    #[error("Could not find key '{0}'")]
    MissingKey(String),

    /// Generic storage-provider failure, described by the message.
    #[error("Internal storage error: '{0}'")]
    ProviderError(String),

    /// NATS failure, described by the message.
    #[error("Internal NATS error: {0}")]
    NATSError(String),

    /// etcd failure, described by the message.
    #[error("Internal etcd error: {0}")]
    EtcdError(String),

    /// Filesystem failure, described by the message.
    #[error("Internal filesystem error: {0}")]
    FilesystemError(String),

    /// Key-value failure: first field is the message, second the bucket name.
    #[error("Key Value Error: {0} for bucket '{1}'")]
    KeyValueError(String, String),

    /// Stored bytes could not be decoded as JSON.
    #[error("Error decoding bytes: {0}")]
    JSONDecodeError(#[from] serde_json::error::Error),

    /// The operation lost a race; the caller should retry it.
    #[error("Race condition, retry the call")]
    Retry,
}
478
/// An object that carries a revision number.
///
/// `Manager::publish` reads the revision before inserting and writes back the
/// revision reported by the store.
pub trait Versioned {
    /// Returns the object's current revision.
    fn revision(&self) -> u64;
    /// Replaces the object's revision with `r`.
    fn set_revision(&mut self, r: u64);
}
485
// Tests exercising the in-memory store's insert/watch behaviour.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    // Bucket shared by all tests in this module.
    const BUCKET_NAME: &str = "v1/mdc";

    /// Fans a single `WatchEvent` stream out to any number of subscribers
    /// through a tokio broadcast channel.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task that drains `stream` into a broadcast channel of
        /// capacity `max_size`.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    // Send errors are deliberately ignored here.
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// Creates a new receiver on the broadcast channel.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    // Shared test setup: initialises logging.
    fn init() {
        crate::logging::init();
    }

    // Checks that a watcher observes a pre-existing key followed by later
    // inserts, and checks the outcome of repeated inserts.
    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        // Insert the first key before any watcher exists.
        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // The three Put events the watcher is expected to see, in order.
        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                Key::new(format!("test{i}")),
                format!("value{i}").into(),
            ));
            expected.push(item);
        }

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // The key inserted before the watch started must arrive first.
            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            // Tell the main task it may insert the remaining keys.
            got_first_tx.send(()).unwrap();

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        // Wait until the watcher has seen the first event.
        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Same key, same revision: reported as already existing.
        let res = bucket.insert(&"test2".into(), "value2".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Same key, different revision: reported as created at that revision.
        let res = bucket.insert(&"test2".into(), "value2".into(), 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // NOTE(review): `?` surfaces a panicked watcher task, but the task's own
        // `Result<(), StoreError>` is discarded by `let _ =` — consider `??`.
        let _ = ingress.await?;

        Ok(())
    }

    // Checks that two broadcast subscribers both receive an event from a
    // tapped watch stream.
    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // The watch stream borrows the bucket, while `TappableStream::new`
        // requires a `'static` stream, so the store and bucket are leaked to
        // obtain `'static` references.
        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1".into(), 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        // Both subscribers expect to see the updated value for "test1".
        let item = WatchEvent::Put(KeyValue::new(Key::new("test1".to_string()), "GK".into()));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        bucket.insert(&"test1".into(), "GK".into(), 1).await?;

        // NOTE(review): the join results are discarded, so a failed assertion
        // inside either spawned task would not fail this test — confirm whether
        // that is intended.
        let _ = futures::join!(handle1, handle2);
        Ok(())
    }
}