1use std::collections::HashMap;
8use std::fmt;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::time::Duration;
12
13use crate::CancellationToken;
14use crate::slug::Slug;
15use async_trait::async_trait;
16use futures::StreamExt;
17use serde::{Deserialize, Serialize};
18
19mod mem;
20pub use mem::MemoryStore;
21mod nats;
22pub use nats::NATSStore;
23mod etcd;
24pub use etcd::EtcdStore;
25
/// Key identifying an entry within a [`KeyValueBucket`].
///
/// Build with [`Key::new`] (slugifies its input) or [`Key::from_raw`] (stores the string
/// verbatim). `Eq` and `Hash` are derived alongside `PartialEq` so keys can be used in
/// hashed collections (`HashMap`/`HashSet`); all three compare only the inner string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(String);
29
30impl Key {
31 pub fn new(s: &str) -> Key {
32 Key(Slug::slugify(s).to_string())
33 }
34
35 pub fn from_raw(s: String) -> Key {
37 Key(s)
38 }
39}
40
41impl From<&str> for Key {
42 fn from(s: &str) -> Key {
43 Key::new(s)
44 }
45}
46
47impl fmt::Display for Key {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 write!(f, "{}", self.0)
50 }
51}
52
53impl AsRef<str> for Key {
54 fn as_ref(&self) -> &str {
55 &self.0
56 }
57}
58
59impl From<&Key> for String {
60 fn from(k: &Key) -> String {
61 k.0.clone()
62 }
63}
64
/// A key-value backend that hands out [`KeyValueBucket`]s by name.
///
/// NOTE(review): presumably implemented by the `MemoryStore`, `NATSStore` and `EtcdStore`
/// types re-exported above; confirm in their modules.
#[async_trait]
pub trait KeyValueStore: Send + Sync {
    /// Returns the bucket called `bucket_name`, creating it if needed.
    ///
    /// `ttl` is forwarded to the backend. Presumably it is a time-to-live for the bucket's
    /// entries (TODO confirm per backend); `None` is used throughout the tests here.
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        ttl: Option<Duration>,
    ) -> Result<Box<dyn KeyValueBucket>, StoreError>;

    /// Returns the bucket called `bucket_name` without creating it.
    ///
    /// `KeyValueStoreManager::load` treats `Ok(None)` as "bucket does not exist".
    async fn get_bucket(
        &self,
        bucket_name: &str,
    ) -> Result<Option<Box<dyn KeyValueBucket>>, StoreError>;

    /// Identifier of the underlying connection. The meaning is backend-specific; TODO
    /// confirm.
    fn connection_id(&self) -> u64;
}
81
82pub struct KeyValueStoreManager(Box<dyn KeyValueStore>);
83
84impl KeyValueStoreManager {
85 pub fn new(s: Box<dyn KeyValueStore>) -> KeyValueStoreManager {
86 KeyValueStoreManager(s)
87 }
88
89 pub async fn load<T: for<'a> Deserialize<'a>>(
90 &self,
91 bucket: &str,
92 key: &Key,
93 ) -> Result<Option<T>, StoreError> {
94 let Some(bucket) = self.0.get_bucket(bucket).await? else {
95 return Ok(None);
97 };
98 match bucket.get(key).await {
99 Ok(Some(card_bytes)) => {
100 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
101 Ok(Some(card))
102 }
103 Ok(None) => Ok(None),
104 Err(err) => {
105 Err(StoreError::NATSError(err.to_string()))
107 }
108 }
109 }
110
111 pub fn watch<T: for<'a> Deserialize<'a> + Send + 'static>(
115 self: Arc<Self>,
116 bucket_name: &str,
117 bucket_ttl: Option<Duration>,
118 ) -> (
119 tokio::task::JoinHandle<Result<(), StoreError>>,
120 tokio::sync::mpsc::UnboundedReceiver<T>,
121 ) {
122 let bucket_name = bucket_name.to_string();
123 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
124 let watch_task = tokio::spawn(async move {
125 let bucket = self
127 .0
128 .get_or_create_bucket(&bucket_name, bucket_ttl)
129 .await?;
130 let mut stream = bucket.watch().await?;
131
132 for (_, card_bytes) in bucket.entries().await? {
134 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
135 let _ = tx.send(card);
136 }
137
138 while let Some(card_bytes) = stream.next().await {
140 let card: T = serde_json::from_slice(card_bytes.as_ref())?;
141 let _ = tx.send(card);
142 }
143
144 Ok::<(), StoreError>(())
145 });
146 (watch_task, rx)
147 }
148
149 pub async fn publish<T: Serialize + Versioned + Send + Sync>(
150 &self,
151 bucket_name: &str,
152 bucket_ttl: Option<Duration>,
153 key: &Key,
154 obj: &mut T,
155 ) -> anyhow::Result<StoreOutcome> {
156 let obj_json = serde_json::to_string(obj)?;
157 let bucket = self.0.get_or_create_bucket(bucket_name, bucket_ttl).await?;
158
159 let outcome = bucket.insert(key, &obj_json, obj.revision()).await?;
160
161 match outcome {
162 StoreOutcome::Created(revision) | StoreOutcome::Exists(revision) => {
163 obj.set_revision(revision);
164 }
165 }
166 Ok(outcome)
167 }
168}
169
/// A single named bucket of a [`KeyValueStore`], holding raw byte values under [`Key`]s.
///
/// The trait requires `Send` but not `Sync`.
#[async_trait]
pub trait KeyValueBucket: Send {
    /// Writes `value` under `key`, passing `revision` to the backend, and reports the
    /// outcome.
    ///
    /// The memory-store test in this file shows the following sequence:
    /// - A new key at revision 0 yields `Created(0)`.
    /// - Repeating the same key and revision 0 yields `Exists(0)`.
    /// - The same key at revision 1 yields `Created(1)`.
    ///
    /// NOTE(review): exact revision semantics are backend-specific; confirm.
    async fn insert(
        &self,
        key: &Key,
        value: &str,
        revision: u64,
    ) -> Result<StoreOutcome, StoreError>;

    /// Fetches the raw bytes stored under `key`.
    ///
    /// Presumably `Ok(None)` means the key is absent; `KeyValueStoreManager::load` treats
    /// it that way.
    async fn get(&self, key: &Key) -> Result<Option<bytes::Bytes>, StoreError>;

    /// Removes `key` from the bucket.
    async fn delete(&self, key: &Key) -> Result<(), StoreError>;

    /// Opens a stream of raw values as the bucket changes.
    ///
    /// The stream is bound to `'life0`. NOTE(review): this assumes `async_trait`'s name for
    /// the `&self` lifetime, which would mean the stream cannot outlive the bucket borrow.
    ///
    /// `test_memory_storage` expects the memory implementation to also yield a value that
    /// was inserted before the watch was opened.
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StoreError>;

    /// Returns the bucket's current entries as a map from key string to raw bytes.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StoreError>;
}
198
/// Result of [`KeyValueBucket::insert`], carrying the revision reported by the bucket.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StoreOutcome {
    /// The bucket reported the value as created; the payload is the reported revision.
    Created(u64),
    /// The bucket reported the key as already existing; the payload is the reported
    /// revision.
    Exists(u64),
}
impl fmt::Display for StoreOutcome {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Both variants render as "<label> at <revision>".
        let (label, rev) = match *self {
            Self::Created(rev) => ("Created", rev),
            Self::Exists(rev) => ("Exists", rev),
        };
        write!(f, "{label} at {rev}")
    }
}
215
216#[derive(thiserror::Error, Debug)]
217pub enum StoreError {
218 #[error("Could not find bucket '{0}'")]
219 MissingBucket(String),
220
221 #[error("Could not find key '{0}'")]
222 MissingKey(String),
223
224 #[error("Internal storage error: '{0}'")]
225 ProviderError(String),
226
227 #[error("Internal NATS error: {0}")]
228 NATSError(String),
229
230 #[error("Internal etcd error: {0}")]
231 EtcdError(String),
232
233 #[error("Key Value Error: {0} for bucket '{1}")]
234 KeyValueError(String, String),
235
236 #[error("Error decoding bytes: {0}")]
237 JSONDecodeError(#[from] serde_json::error::Error),
238
239 #[error("Race condition, retry the call")]
240 Retry,
241}
242
/// An object that carries a store revision number.
///
/// Used by `KeyValueStoreManager::publish`:
/// - `revision()` is passed to [`KeyValueBucket::insert`].
/// - `set_revision` then receives the revision from the returned [`StoreOutcome`].
pub trait Versioned {
    /// Current revision of this object.
    fn revision(&self) -> u64;
    /// Overwrites this object's revision with `r`.
    fn set_revision(&mut self, r: u64);
}
249
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use futures::{StreamExt, pin_mut};

    const BUCKET_NAME: &str = "v1/mdc";

    /// Fans one byte stream out to any number of `tokio::sync::broadcast` subscribers.
    // NOTE(review): every item here is used by `test_broadcast_stream`, so this allow may be
    // stale; confirm before removing.
    #[allow(dead_code)]
    pub struct TappableStream {
        tx: tokio::sync::broadcast::Sender<bytes::Bytes>,
    }

    #[allow(dead_code)]
    impl TappableStream {
        /// Spawns a task that forwards every item of `stream` into a broadcast channel of
        /// capacity `max_size`.
        ///
        /// The initial receiver is dropped immediately and send errors are ignored. Tokio's
        /// `broadcast::Sender::send` fails when there are no receivers, so items forwarded
        /// before anyone subscribes are lost.
        async fn new<T>(stream: T, max_size: usize) -> Self
        where
            T: futures::Stream<Item = bytes::Bytes> + Send + 'static,
        {
            let (tx, _) = tokio::sync::broadcast::channel(max_size);
            let tx2 = tx.clone();
            tokio::spawn(async move {
                pin_mut!(stream);
                while let Some(x) = stream.next().await {
                    let _ = tx2.send(x);
                }
            });
            TappableStream { tx }
        }

        /// Returns a new receiver that sees items forwarded after this call.
        fn subscribe(&self) -> tokio::sync::broadcast::Receiver<bytes::Bytes> {
            self.tx.subscribe()
        }
    }

    fn init() {
        crate::logging::init();
    }

    /// Checks insert outcomes on the memory store and the order in which a watcher sees
    /// values.
    #[tokio::test]
    async fn test_memory_storage() -> anyhow::Result<()> {
        init();

        let s = Arc::new(MemoryStore::new());
        let s2 = Arc::clone(&s);

        let bucket = s.get_or_create_bucket(BUCKET_NAME, None).await?;
        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Signals the main task once the watcher has seen "value1", so the later inserts
        // happen after the watch is established.
        let (got_first_tx, got_first_rx) = tokio::sync::oneshot::channel();
        let ingress = tokio::spawn(async move {
            let b2 = s2.get_or_create_bucket(BUCKET_NAME, None).await?;
            let mut stream = b2.watch().await?;

            // "value1" was inserted before this watch was opened, yet it is expected first.
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value1".as_bytes());

            got_first_tx.send(()).unwrap();

            // NOTE(review): "value3" is expected immediately after "value2". The test
            // therefore assumes neither the repeated revision-0 insert nor the revision-1
            // insert of "test2" below produces a watch item; confirm this is the intended
            // memory-store behavior.
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value2".as_bytes());
            let v = stream.next().await.unwrap();
            assert_eq!(v, "value3".as_bytes());

            Ok::<_, StoreError>(())
        });

        got_first_rx.await?;

        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // Same key and revision again: reported as Exists.
        let res = bucket.insert(&"test2".into(), "value2", 0).await?;
        assert_eq!(res, StoreOutcome::Exists(0));

        // Same key at the next revision: reported as Created at that revision.
        let res = bucket.insert(&"test2".into(), "value2", 1).await?;
        assert_eq!(res, StoreOutcome::Created(1));

        let res = bucket.insert(&"test3".into(), "value3", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        // NOTE(review): `?` surfaces a JoinError, such as a panicking assert in `ingress`.
        // `let _ =` still discards the task's own `Result<(), StoreError>`; `ingress.await??`
        // would propagate it.
        let _ = ingress.await?;

        Ok(())
    }

    /// Checks that two `TappableStream` subscribers both receive an update made to a
    /// watched bucket.
    #[tokio::test]
    async fn test_broadcast_stream() -> anyhow::Result<()> {
        init();

        // Leaked so the store and bucket are `'static`. This lets the watch stream, whose
        // type is tied to the bucket borrow, satisfy `TappableStream::new`'s `'static` bound.
        let s: &'static _ = Box::leak(Box::new(MemoryStore::new()));
        let bucket: &'static _ =
            Box::leak(Box::new(s.get_or_create_bucket(BUCKET_NAME, None).await?));

        let res = bucket.insert(&"test1".into(), "value1", 0).await?;
        assert_eq!(res, StoreOutcome::Created(0));

        let stream = bucket.watch().await?;
        let tap = TappableStream::new(stream, 10).await;

        let mut rx1 = tap.subscribe();
        let mut rx2 = tap.subscribe();

        let handle1 = tokio::spawn(async move {
            let b = rx1.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });
        let handle2 = tokio::spawn(async move {
            let b = rx2.recv().await.unwrap();
            assert_eq!(b, bytes::Bytes::from(vec![b'G', b'K']));
        });

        bucket.insert(&"test1".into(), "GK", 1).await?;

        // NOTE(review): `let _ =` discards both JoinErrors, so a panicking `assert_eq!` in
        // `handle1`/`handle2` does not fail this test. Before propagating them, confirm what
        // the first received item actually is. `test_memory_storage` expects a watch to
        // yield a value inserted before it was opened, which would make "value1" (not
        // "GK") arrive first here.
        let _ = futures::join!(handle1, handle2);
        Ok(())
    }
}