1use std::collections::HashMap;
8use std::fmt;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::CancellationToken;
14use crate::slug::Slug;
15use async_trait::async_trait;
16use futures::StreamExt;
17use serde::{Deserialize, Serialize};
18
19mod mem;
20pub use mem::MemoryStore;
21mod nats;
22pub use nats::NATSStore;
23mod etcd;
24pub use etcd::EtcdStore;
25
/// Maximum time `KeyValueStoreManager::watch` waits for room in its output channel when
/// sending one event. If the send times out, the failure is logged and that event is dropped.
const WATCH_SEND_TIMEOUT: Duration = Duration::from_millis(100);
27
28#[derive(Debug, Clone, PartialEq)]
30pub struct Key(String);
31
32impl Key {
33 pub fn new(s: &str) -> Key {
34 Key(Slug::slugify(s).to_string())
35 }
36
37 pub fn from_raw(s: String) -> Key {
39 Key(s)
40 }
41}
42
43impl From<&str> for Key {
44 fn from(s: &str) -> Key {
45 Key::new(s)
46 }
47}
48
49impl fmt::Display for Key {
50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51 write!(f, "{}", self.0)
52 }
53}
54
55impl AsRef<str> for Key {
56 fn as_ref(&self) -> &str {
57 &self.0
58 }
59}
60
61impl From<&Key> for String {
62 fn from(k: &Key) -> String {
63 k.0.clone()
64 }
65}
66
67#[derive(Debug, Clone, PartialEq)]
68pub struct KeyValue {
69 key: String,
70 value: bytes::Bytes,
71}
72
73impl KeyValue {
74 pub fn new(key: String, value: bytes::Bytes) -> Self {
75 KeyValue { key, value }
76 }
77
78 pub fn key(&self) -> String {
79 self.key.clone()
80 }
81
82 pub fn key_str(&self) -> &str {
83 &self.key
84 }
85
86 pub fn value(&self) -> &[u8] {
87 &self.value
88 }
89
90 pub fn value_str(&self) -> anyhow::Result<&str> {
91 std::str::from_utf8(self.value()).map_err(From::from)
92 }
93}
94
/// A change notification for a bucket.
///
/// This is the item type of the stream returned by `KeyValueBucket::watch`.
/// `KeyValueStoreManager::watch` also emits a `Put` for every entry already present when the
/// watch starts.
#[derive(Debug, Clone, PartialEq)]
pub enum WatchEvent {
    /// A key/value pair was written (or already existed when a manager watch started).
    Put(KeyValue),
    /// A key was removed. NOTE(review): what the carried value holds on delete is
    /// backend-specific — confirm.
    Delete(KeyValue),
}
100
/// A storage backend that hands out [`KeyValueBucket`]s by name.
#[async_trait]
pub trait KeyValueStore: Send + Sync {
    /// The concrete bucket type this backend produces.
    type Bucket: KeyValueBucket + Send + Sync + 'static;

    /// Returns the bucket named `bucket_name`, creating it if it does not exist.
    ///
    /// `ttl` is handed to the backend. NOTE(review): how (and whether) each backend applies
    /// it, in particular to an already-existing bucket, should be confirmed.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        ttl: Option<Duration>,
    ) -> Result<Self::Bucket, StoreError>;

    /// Returns the bucket named `bucket_name`, or `Ok(None)` if it does not exist.
    async fn get_bucket(&self, bucket_name: &str) -> Result<Option<Self::Bucket>, StoreError>;

    /// Identifier of the backend connection. NOTE(review): meaning is backend-specific.
    fn connection_id(&self) -> u64;
}
116
/// One of the supported [`KeyValueStore`] backends.
///
/// The inherent methods forward to the wrapped store and box the resulting bucket as
/// `dyn KeyValueBucket`.
// The variants presumably differ a lot in size, which is what `large_enum_variant` flags.
// The enum is held once behind an `Arc` in `KeyValueStoreManager`, so the padding cost is
// paid once rather than per value — NOTE(review): confirm that is the reason for the allow.
#[allow(clippy::large_enum_variant)]
pub enum KeyValueStoreEnum {
    Memory(MemoryStore),
    Nats(NATSStore),
    Etcd(EtcdStore),
}
123
124impl KeyValueStoreEnum {
125 async fn get_or_create_bucket(
126 &self,
127 bucket_name: &str,
128 ttl: Option<Duration>,
130 ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
131 use KeyValueStoreEnum::*;
132 Ok(match self {
133 Memory(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
134 Nats(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
135 Etcd(x) => Box::new(x.get_or_create_bucket(bucket_name, ttl).await?),
136 })
137 }
138
139 async fn get_bucket(
140 &self,
141 bucket_name: &str,
142 ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
143 use KeyValueStoreEnum::*;
144 let maybe_bucket: Option<Box<dyn KeyValueBucket>> = match self {
145 Memory(x) => x
146 .get_bucket(bucket_name)
147 .await?
148 .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
149 Nats(x) => x
150 .get_bucket(bucket_name)
151 .await?
152 .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
153 Etcd(x) => x
154 .get_bucket(bucket_name)
155 .await?
156 .map(|b| Box::new(b) as Box<dyn KeyValueBucket>),
157 };
158 Ok(maybe_bucket)
159 }
160
161 fn connection_id(&self) -> u64 {
162 use KeyValueStoreEnum::*;
163 match self {
164 Memory(x) => x.connection_id(),
165 Etcd(x) => x.connection_id(),
166 Nats(x) => x.connection_id(),
167 }
168 }
169}
170
/// Handle to a key-value store backend.
///
/// The backend is held in an `Arc`, so cloning the manager shares the same backend.
#[derive(Clone)]
pub struct KeyValueStoreManager(Arc<KeyValueStoreEnum>);
173
174impl Default for KeyValueStoreManager {
175 fn default() -> Self {
176 KeyValueStoreManager::memory()
177 }
178}
179
180impl KeyValueStoreManager {
181 pub fn memory() -> Self {
183 Self::new(KeyValueStoreEnum::Memory(MemoryStore::new()))
184 }
185
186 pub fn etcd(etcd_client: crate::transports::etcd::Client) -> Self {
187 Self::new(KeyValueStoreEnum::Etcd(EtcdStore::new(etcd_client)))
188 }
189
190 fn new(s: KeyValueStoreEnum) -> KeyValueStoreManager {
191 KeyValueStoreManager(Arc::new(s))
192 }
193
194 pub async fn get_or_create_bucket(
195 &self,
196 bucket_name: &str,
197 ttl: Option<Duration>,
199 ) -> Result<Box<dyn KeyValueBucket>, StoreError> {
200 self.0.get_or_create_bucket(bucket_name, ttl).await
201 }
202
203 pub async fn get_bucket(
204 &self,
205 bucket_name: &str,
206 ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError> {
207 self.0.get_bucket(bucket_name).await
208 }
209
210 pub fn connection_id(&self) -> u64 {
211 self.0.connection_id()
212 }
213
214 pub async fn load<T: for<'a> Deserialize<'a>>(
215 &self,
216 bucket: &str,
217 key: &Key,
218 ) -> Result<Option<T>, StoreError> {
219 let Some(bucket) = self.0.get_bucket(bucket).await? else {
220 return Ok(None);
222 };
223 Ok(match bucket.get(key).await? {
224 Some(card_bytes) => {
225 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
226 Some(card)
227 }
228 None => None,
229 })
230 }
231
232 pub fn watch(
236 self: Arc<Self>,
237 bucket_name: &str,
238 bucket_ttl: Option<Duration>,
239 cancel_token: CancellationToken,
240 ) -> (
241 tokio::task::JoinHandle<Result<(), StoreError>>,
242 tokio::sync::mpsc::Receiver<WatchEvent>,
243 ) {
244 let bucket_name = bucket_name.to_string();
245 let (tx, rx) = tokio::sync::mpsc::channel(128);
246 let watch_task = tokio::spawn(async move {
247 let bucket = self
249 .0
250 .get_or_create_bucket(&bucket_name, bucket_ttl)
251 .await?;
252 let mut stream = bucket.watch().await?;
253
254 for (key, bytes) in bucket.entries().await? {
256 if let Err(err) = tx
257 .send_timeout(
258 WatchEvent::Put(KeyValue::new(key, bytes)),
259 WATCH_SEND_TIMEOUT,
260 )
261 .await
262 {
263 tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding existing key to channel");
264 }
265 }
266
267 loop {
269 let event = tokio::select! {
270 _ = cancel_token.cancelled() => break,
271 result = stream.next() => match result {
272 Some(event) => event,
273 None => break,
274 }
275 };
276 if let Err(err) = tx.send_timeout(event, WATCH_SEND_TIMEOUT).await {
277 tracing::error!(bucket_name, %err, "KeyValueStoreManager.watch failed adding new key to channel");
278 }
279 }
280
281 Ok::<(), StoreError>(())
282 });
283 (watch_task, rx)
284 }
285
286 pub async fn publish<T: Serialize + Versioned + Send + Sync>(
287 &self,
288 bucket_name: &str,
289 bucket_ttl: Option<Duration>,
290 key: &Key,
291 obj: &mut T,
292 ) -> anyhow::Result<StoreOutcome> {
293 let obj_json = serde_json::to_string(obj)?;
294 let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
295
296 let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
297
298 match outcome {
299 StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
300 obj.set_revision(revision);
301 }
302 }
303 Ok(outcome)
304 }
305}
306
/// A named collection of key/value pairs obtained from a [`KeyValueStore`].
#[async_trait]
pub trait KeyValueBucket: Send + Sync {
    /// Writes `value` under `key`, passing the caller's `revision`.
    ///
    /// Returns a [`StoreOutcome`] that carries a revision number.
    /// NOTE(review): the exact meaning of `revision` and of `Created` vs `Exists` is
    /// backend-specific — confirm for each implementation.
    async fn insert(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError>;

    /// Returns the raw bytes stored under `key`, or `Ok(None)` if there are none.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;

    /// Removes `key`. NOTE(review): whether deleting a missing key is an error is
    /// backend-specific.
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;

    /// Opens a stream of [`WatchEvent`]s for this bucket.
    ///
    /// The returned stream borrows `self`.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = WatchEvent> + Send + '_>>, StoreError>;

    /// Returns every current entry, mapped from key string to raw value bytes.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
}
335
/// Result of `KeyValueBucket::insert`. Both variants carry a revision number.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The insert wrote the value.
    Created(u64),
    /// NOTE(review): appears to mean the entry was already present and nothing new was
    /// written — confirm per backend.
    Exists(u64),
}

/// Renders as `"Created at <revision>"` or `"Exists at <revision>"`.
impl fmt::Display for StoreOutcome {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::Created(revision) => f.write_fmt(format_args!("Created at {revision}")),
            Self::Exists(revision) => f.write_fmt(format_args!("Exists at {revision}")),
        }
    }
}
352
/// Errors returned by key-value store backends and `KeyValueStoreManager`.
#[derive(thiserror::Error, Debug)]
pub enum StoreError {
    /// The named bucket could not be found.
    #[error("Could not find bucket '{0}'")]
    MissingBucket(String),

    /// The named key could not be found.
    #[error("Could not find key '{0}'")]
    MissingKey(String),

    /// Generic backend failure, carrying a description.
    #[error("Internal storage error: '{0}'")]
    ProviderError(String),

    /// Failure reported by the NATS backend.
    #[error("Internal NATS error: {0}")]
    NATSError(String),

    /// Failure reported by the etcd backend.
    #[error("Internal etcd error: {0}")]
    EtcdError(String),

    /// Key/value operation failure: `(message, bucket name)`.
    #[error("Key Value Error: {0} for bucket '{1}'")]
    KeyValueError(String, String),

    /// JSON (de)serialization failed; converted automatically from `serde_json::Error` via `?`.
    #[error("Error decoding bytes: {0}")]
    JSONDecodeError(#[from] serde_json::error::Error),

    /// A race condition was detected; the caller should retry the operation.
    #[error("Race condition, retry the call")]
    Retry,
}
379
/// An object that tracks the store revision it was last written at.
///
/// `KeyValueStoreManager::publish` passes `revision()` to `KeyValueBucket::insert` and then
/// calls `set_revision` with the revision from the returned `StoreOutcome`.
pub trait Versioned {
    /// Current revision of this object.
    fn revision(&self) -> u64;
    /// Replaces the stored revision with `r`.
    fn set_revision(&mut self, r: u64);
}
386
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Fans a single `WatchEvent` stream out to any number of broadcast subscribers.
    // NOTE(review): `dead_code` allow kept from the original; the helper is used by
    // `test_broadcast_stream`, so the allow may be stale.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<WatchEvent>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task forwarding every item of `stream` into a broadcast channel of
        /// capacity `max_size`. Send errors (e.g. no subscribers yet) are ignored.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = WatchEvent> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// Returns a new receiver that sees events sent after this call.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<WatchEvent> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let mut expected = Vec::with_capacity(3);
        for i in 1..=3 {
            let item = WatchEvent::Put(KeyValue::new(
                format!("test{i}"),
                bytes::Bytes::from(format!("value{i}").into_bytes()),
            ));
            expected.push(item);
        }

        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[0]);

            got_first_tx.send(()).unwrap();

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[1]);

            let v = stream.next().await.unwrap();
            assert_eq!(v, expected[2]);

            Ok::<_, StoreError>(())
        });

        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        let res = bucket.insert(&"test2".into(), "value2", 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // The outer `?` surfaces a panic in the ingress task (JoinError). The inner `?`
        // surfaces a StoreError the task returned, which was previously discarded.
        ingress.await??;

        Ok(())
    }

    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let item = WatchEvent::Put(KeyValue::new(
            "test1".to_string(),
            bytes::Bytes::from(b"GK".as_slice()),
        ));
        let item_clone = item.clone();
        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, item_clone);
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, item);
        });

        bucket.insert(&"test1".into(), "GK", 1).await?;

        // Propagate panics (failed assertions) from the subscriber tasks; discarding the
        // JoinErrors would let the test pass even when an assertion fails.
        let (first, second) = futures::join!(handle1, handle2);
        first?;
        second?;
        Ok(())
    }
}