1use std::{
2 collections::HashSet,
3 error::Error,
4 sync::{
5 atomic::{AtomicU64, Ordering},
6 Arc,
7 },
8};
9
10use async_nats::{
11 jetstream::{
12 self,
13 consumer::PullConsumer,
14 context::{
15 CreateKeyValueErrorKind, CreateStreamError, GetStreamErrorKind, KeyValueErrorKind,
16 },
17 kv::{self, Operation, Store},
18 stream::ConsumerError,
19 Context,
20 },
21 subject::ToSubject,
22 Client, ConnectError, ConnectErrorKind,
23};
24use futures::{StreamExt, TryStreamExt};
25use tokio::task::JoinHandle;
26
27pub use async_nats;
28
29pub mod subject;
30
31#[derive(Clone, serde::Deserialize)]
32pub struct Config {
33 app_name: Option<String>,
34 host: Option<String>,
35 port: Option<u16>,
36 #[serde(skip)]
37 address: Option<String>,
38 sys_locks: Option<String>,
39 events_stream_name: Option<String>,
40 events_stream_subject: Option<String>,
41}
42
43impl Config {
44 pub fn new() -> envy::Result<Self> {
45 ConfigBuilder::default().build()
46 }
47
48 pub fn builder<'a>() -> ConfigBuilder<'a> {
49 ConfigBuilder::default()
50 }
51
52 pub fn address(&self) -> &str {
53 self.address.as_deref().unwrap()
54 }
55
56 pub fn port(&self) -> u16 {
57 self.port.unwrap_or(3000)
58 }
59 pub fn sys_locks(&self) -> &str {
60 self.sys_locks.as_deref().unwrap_or("SYS_LOCKS")
61 }
62 pub fn events_stream_name(&self) -> &str {
63 self.events_stream_name.as_deref().unwrap_or("EVENTS")
64 }
65 pub fn events_stream_subject(&self) -> &str {
66 self.events_stream_subject.as_deref().unwrap_or("ev.>")
67 }
68}
69
70#[derive(Default)]
71pub struct ConfigBuilder<'a> {
72 prefix: Option<&'a str>,
73}
74
75impl<'a> ConfigBuilder<'a> {
76 pub fn with_prefix(mut self, prefix: &'a str) -> Self {
77 self.prefix = Some(prefix);
78 self
79 }
80
81 pub fn build(self) -> envy::Result<Config> {
82 let prefix = self.prefix.unwrap_or("NATS_");
83 let mut cfg: Config = envy::prefixed(prefix).from_env()?;
84 if cfg.app_name.is_none() {
85 cfg.app_name = Some("edd-service-rs".into());
86 }
87 let host = cfg.host.as_deref().unwrap_or("127.0.0.1");
88 let port = cfg.port.unwrap_or(4222);
89 cfg.address = Some(format!("nats://{}:{}", host, port));
90 Ok(cfg)
91 }
92}
93
/// Shared state behind a [`Nats`] handle: the connected client and the configuration it was
/// created from.
pub struct Inner {
    client: Client,
    config: Config,
}

/// Cheaply clonable handle to a NATS connection.
///
/// All clones share one [`Inner`] through an [`Arc`], so they use the same client and
/// configuration.
#[derive(Clone)]
pub struct Nats {
    inner: Arc<Inner>,
}
103
104impl Nats {
105 pub async fn new(config: Config) -> Result<Self, ConnectError> {
106 let client = async_nats::ConnectOptions::new()
107 .max_reconnects(Some(1))
108 .connect(config.address())
109 .await?;
110 Ok(Self {
111 inner: Arc::new(Inner { client, config }),
112 })
113 }
114
115 pub fn client(&self) -> &Client {
116 &self.inner.client
117 }
118
119 pub fn config(&self) -> &Config {
120 &self.inner.config
121 }
122
123 pub async fn publisher(&self) -> Result<Publisher, CreateStreamError> {
124 let ctx = jetstream::new(self.inner.client.clone());
125 let p = Publisher { ctx };
126 p.init(&self.inner.config).await?;
127 Ok(p)
128 }
129
130 pub async fn sys_consumer(&self, name: String) -> Result<PullConsumer, ConsumerError> {
131 let ctx = jetstream::new(self.inner.client.clone());
132 ctx.create_consumer_on_stream(
133 jetstream::consumer::pull::Config {
134 durable_name: Some(name),
135 ..Default::default()
136 },
137 self.inner.config.events_stream_name(),
138 )
139 .await
140 }
141
142 pub async fn sys_consumer_with_filter(
143 &self,
144 name: String,
145 filter_subject: String,
146 ) -> Result<PullConsumer, ConsumerError> {
147 let ctx = jetstream::new(self.inner.client.clone());
148 ctx.create_consumer_on_stream(
149 jetstream::consumer::pull::Config {
150 durable_name: Some(name),
151 filter_subject,
152 ..Default::default()
153 },
154 self.inner.config.events_stream_name(),
155 )
156 .await
157 }
158
159 pub async fn sys_consumer_with_filters(
160 &self,
161 name: String,
162 filter_subjects: Vec<String>,
163 ) -> Result<PullConsumer, ConsumerError> {
164 let ctx = jetstream::new(self.inner.client.clone());
165 ctx.create_consumer_on_stream(
166 jetstream::consumer::pull::Config {
167 durable_name: Some(name),
168 filter_subjects,
169 ..Default::default()
170 },
171 self.inner.config.events_stream_name(),
172 )
173 .await
174 }
175
176 pub async fn tmp_sys_consumer_with_filter(
177 &self,
178 filter_subject: String,
179 ) -> Result<PullConsumer, ConsumerError> {
180 let ctx = jetstream::new(self.inner.client.clone());
181 ctx.create_consumer_on_stream(
182 jetstream::consumer::pull::Config {
183 filter_subject,
184 deliver_policy: jetstream::consumer::DeliverPolicy::Last,
185 ..Default::default()
186 },
187 self.inner.config.events_stream_name(),
188 )
189 .await
190 }
191
192 pub async fn distributed_locks(&self) -> Result<DistributedLocks, DistributedLocksError> {
193 let ctx = jetstream::new(self.inner.client.clone());
194 DistributedLocks::new(ctx, &self.inner.config).await
195 }
196
197 pub fn sequence_manager(&self) -> SequenceManager {
198 let ctx = jetstream::new(self.inner.client.clone());
199 SequenceManager { ctx }
200 }
201}
202
/// Maps a value to the NATS subject its event is published on.
///
/// `M` does not appear in the method signature. NOTE(review): presumably it lets one type
/// implement the trait several times (one per marker type); confirm with implementors.
pub trait EventToSubject<M> {
    /// Subject that [`Publisher::publish_event`] publishes to.
    fn event_to_subject(&self) -> async_nats::Subject;
}

/// Publishes JSON-encoded messages through a JetStream context.
///
/// Obtained from [`Nats::publisher`], which runs the stream initialisation before returning it.
pub struct Publisher {
    ctx: Context,
}
210
211impl Publisher {
212 async fn init(&self, config: &Config) -> Result<(), CreateStreamError> {
213 let names: HashSet<String> = self.ctx.stream_names().try_collect().await?;
214 if !names.contains(config.events_stream_name()) {
215 self.ctx
216 .create_stream(jetstream::stream::Config {
217 name: config.events_stream_name().to_string(),
218 subjects: vec![config.events_stream_subject().into()],
219 allow_direct: true,
220 deny_delete: true,
221 deny_purge: true,
222 ..Default::default()
223 })
224 .await?;
225 }
226 Ok(())
227 }
228
229 pub async fn publish<S: ToSubject, P: ?Sized + serde::Serialize>(
230 &self,
231 subject: S,
232 payload: &P,
233 ) -> anyhow::Result<()> {
234 self.ctx
235 .publish(subject, serde_json::to_vec(payload)?.into())
236 .await?;
237 Ok(())
238 }
239
240 pub async fn publish_event<S, M, P>(&self, subject: &S, payload: &P) -> anyhow::Result<()>
241 where
242 S: ?Sized + EventToSubject<M>,
243 P: ?Sized + serde::Serialize,
244 {
245 self.ctx
246 .publish(
247 subject.event_to_subject(),
248 serde_json::to_vec(payload)?.into(),
249 )
250 .await?;
251 Ok(())
252 }
253}
254
255impl AsRef<Context> for Publisher {
256 fn as_ref(&self) -> &Context {
257 &self.ctx
258 }
259}
260
/// Errors returned while preparing the distributed-lock key-value bucket.
#[derive(thiserror::Error, Debug)]
pub enum DistributedLocksError {
    /// Connection failure.
    #[error(transparent)]
    Connect(#[from] async_nats::error::Error<ConnectErrorKind>),
    /// Creating the lock bucket failed.
    #[error(transparent)]
    CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
    /// Looking up a key-value bucket failed.
    #[error(transparent)]
    KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
}

/// Handle to the JetStream key-value bucket that stores distributed locks.
///
/// Created by [`Nats::distributed_locks`]; call [`DistributedLocks::sys_locks`] to obtain a
/// [`LockManager`].
#[derive(Clone)]
pub struct DistributedLocks {
    ctx: Context,
    // Name of the lock bucket, copied from `Config::sys_locks` at construction.
    sys_locks: String,
}
276
277impl DistributedLocks {
278 async fn new(ctx: Context, config: &Config) -> Result<Self, DistributedLocksError> {
279 let lm = DistributedLocks {
280 ctx,
281 sys_locks: config.sys_locks().to_string(),
282 };
283 if !lm.exists(config.sys_locks()).await? {
284 lm.create(config.sys_locks(), 5).await?;
285 }
286 Ok(lm)
287 }
288
289 async fn create<T: Into<String>>(
290 &self,
291 name: T,
292 max_age: u64,
293 ) -> Result<Store, DistributedLocksError> {
294 Ok(self
295 .ctx
296 .create_key_value(kv::Config {
297 bucket: name.into(),
298 max_age: std::time::Duration::from_secs(max_age),
299 history: 1,
300 ..Default::default()
301 })
302 .await?)
303 }
304
305 async fn exists<T: Into<String>>(&self, bucket: T) -> Result<bool, DistributedLocksError> {
306 if let Err(err) = self.ctx.get_key_value(bucket).await {
307 if err.kind() == KeyValueErrorKind::GetBucket {
308 if let Some(src) = err.source() {
309 let err = src.downcast_ref::<async_nats::error::Error<GetStreamErrorKind>>();
310 if let Some(err) = err {
311 if let GetStreamErrorKind::JetStream(err) = err.kind() {
312 if err.code() == 404 {
313 return Ok(false);
314 }
315 }
316 }
317 }
318 }
319 Err(err)?;
320 }
321 Ok(true)
322 }
323
324 pub async fn sys_locks(&self) -> anyhow::Result<LockManager> {
325 let kv = self.ctx.get_key_value(&self.sys_locks).await?;
326 Ok(LockManager { kv: Arc::new(kv) })
327 }
328}
329
/// Errors produced by [`LockManager`] while acquiring a lock.
#[derive(thiserror::Error, Debug)]
pub enum LockManagerError {
    /// Creating a key-value bucket failed.
    #[error(transparent)]
    CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
    /// Looking up a key-value bucket failed.
    #[error(transparent)]
    KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
    /// Watching the lock key for changes failed.
    #[error(transparent)]
    Watch(#[from] async_nats::error::Error<kv::WatchErrorKind>),
    /// The lock could not be acquired within the allowed attempts; carries the total time
    /// spent trying.
    #[error("unable to lock resource after {0:?}")]
    OutOfRetries(std::time::Duration),
}
341
/// Errors returned by [`SequenceManager`].
#[derive(thiserror::Error, Debug)]
pub enum SequenceManagerError {
    /// Connection failure.
    #[error(transparent)]
    Connect(#[from] async_nats::error::Error<ConnectErrorKind>),
    /// Creating a sequence bucket failed.
    #[error(transparent)]
    CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
    /// Looking up a sequence bucket failed.
    #[error(transparent)]
    KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
    /// Writing the `id` key failed.
    #[error(transparent)]
    Put(#[from] async_nats::error::Error<async_nats::jetstream::kv::PutErrorKind>),
    /// Reading the `id` key failed.
    #[error(transparent)]
    Entry(#[from] async_nats::error::Error<async_nats::jetstream::kv::EntryErrorKind>),
}

/// Issues sequence numbers backed by per-prefix JetStream key-value buckets (`sm-{prefix}`).
///
/// The returned numbers are revisions of the bucket's `id` key.
pub struct SequenceManager {
    ctx: Context,
}
359
360impl SequenceManager {
361 async fn create<T: Into<String>>(&self, name: T) -> Result<Store, SequenceManagerError> {
362 Ok(self
363 .ctx
364 .create_key_value(kv::Config {
365 bucket: name.into(),
366 ..Default::default()
367 })
368 .await?)
369 }
370
371 async fn exists<T: Into<String>>(&self, bucket: T) -> Result<bool, SequenceManagerError> {
372 if let Err(err) = self.ctx.get_key_value(bucket).await {
373 if err.kind() == KeyValueErrorKind::GetBucket {
374 if let Some(src) = err.source() {
375 let err = src.downcast_ref::<async_nats::error::Error<GetStreamErrorKind>>();
376 if let Some(err) = err {
377 if let GetStreamErrorKind::JetStream(err) = err.kind() {
378 if err.code() == 404 {
379 return Ok(false);
380 }
381 }
382 }
383 }
384 }
385 Err(err)?;
386 }
387 Ok(true)
388 }
389
390 async fn get<T: Into<String>>(&self, bucket: T) -> Result<Store, SequenceManagerError> {
391 Ok(self.ctx.get_key_value(bucket).await?)
392 }
393
394 pub async fn next(&self, prefix: &str, id: i64) -> Result<i64, SequenceManagerError> {
395 let bucket = format!("sm-{prefix}");
396 if !self.exists(&bucket).await? {
397 let store = self.create(&bucket).await?;
398 let result = store.put("id", id.to_be_bytes().to_vec().into()).await?;
399 Ok(result as i64)
400 } else {
401 let store = self.get(&bucket).await?;
402 let e = store.entry("id").await?;
403 if let Some(e) = e {
404 Ok(e.revision as i64)
405 } else {
406 let result = store.put("id", id.to_be_bytes().to_vec().into()).await?;
407 Ok(result as i64)
408 }
409 }
410 }
411
412 pub async fn increment(&self, prefix: &str, id: i64) -> Result<i64, SequenceManagerError> {
413 let bucket = format!("sm-{prefix}");
414 let store = self.get(&bucket).await?;
415 let e = store.put("id", id.to_be_bytes().to_vec().into()).await?;
416 Ok(e as i64)
417 }
418}
419
/// Acquires and releases named locks stored as keys of a JetStream key-value bucket.
///
/// Obtained from [`DistributedLocks::sys_locks`].
pub struct LockManager {
    // Lock bucket; wrapped in `Arc` so the spawned refresh and release tasks can hold it.
    kv: Arc<Store>,
}
423
424impl LockManager {
425 pub async fn run_locked<N, O, F, E>(&self, name: N, f: F) -> Result<O, E>
426 where
427 N: Into<String>,
428 F: std::future::Future<Output = Result<O, E>>,
429 E: From<LockManagerError>,
430 {
431 let lock = self.try_lock(name.into(), 3, 5).await?;
432 let result = f.await;
433 let w_kv = self.kv.clone();
434 tokio::spawn(async move {
435 if !lock.jh.is_finished() {
436 lock.jh.abort();
437 let result = lock.jh.await;
438 if let Err(err) = result {
439 if !err.is_cancelled() {
440 tracing::error!("{err:#?}");
441 }
442 }
443 }
444 w_kv.delete_expect_revision(lock.name, Some(lock.revision.load(Ordering::SeqCst)))
445 .await
446 .ok();
447 });
448 result
449 }
450
451 async fn try_lock(
452 &self,
453 name: String,
454 timeout: u64,
455 retries: usize,
456 ) -> Result<Lock, LockManagerError> {
457 let now = std::time::Instant::now();
458 let max_retries = retries;
459 let mut tries = 0;
460 let revision = Arc::new(AtomicU64::new(0));
461 let kv = &self.kv;
462 loop {
463 if tries >= max_retries {
464 return Err(LockManagerError::OutOfRetries(now.elapsed()));
465 }
466 let v = kv.create(&name, "r".into()).await;
467 if let Err(err) = v {
468 if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists {
469 tracing::debug!("seems to be locked already, {tries} try to watch for changes");
470 let mut w = kv.watch(&name).await?;
471 let f = async {
472 'inner: while let Some(m) = w.next().await {
473 if let Ok(e) = m {
474 if e.operation == Operation::Delete {
475 tracing::debug!("retry because prev lock was deleted");
476 break 'inner;
477 }
478 }
479 }
480 };
481 let t = async {
482 tokio::time::sleep(std::time::Duration::from_secs(timeout)).await;
483 };
484 let change = tokio::select! {
485 _ = f => true,
486 _ = t => false,
487 };
488 if !change {
489 tries += 1;
490 }
491 }
492 } else {
493 let r = v.unwrap();
494 revision.store(r, Ordering::SeqCst);
495 tracing::debug!("got lock: '{name}'");
496 break;
497 }
498 }
499 let w_kv = self.kv.clone();
500 let w_name = name.clone();
501 let w_revision = revision.clone();
502
503 let jh = tokio::spawn(async move {
504 let mut run = 0;
505 loop {
506 run += 1;
507 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
508 tracing::debug!("refresh lock {w_name}");
509 let result = w_kv
510 .update(&w_name, "u".into(), w_revision.load(Ordering::SeqCst))
511 .await;
512 if let Err(err) = result {
513 tracing::error!("{err:#?}");
514 break;
515 } else {
516 w_revision.store(result.unwrap(), Ordering::SeqCst);
517 }
518 if run >= 5 {
519 tracing::debug!("release lock after timeout");
520 break;
521 }
522 }
523 anyhow::Ok(())
524 });
525
526 Ok(Lock { name, revision, jh })
527 }
528}
529
/// Registration state of a lock.
///
/// NOTE(review): not referenced by the code in this file — confirm whether it is still used.
#[derive(Debug, PartialEq, Eq)]
pub enum LockState {
    Registering,
    Registered,
}

/// A lock held by this process, as produced by `LockManager::try_lock`.
#[derive(Debug)]
pub struct Lock {
    // Key of the lock inside the lock bucket.
    name: String,
    // Latest revision of the key written by this process; updated by the refresh task and read
    // when the lock is deleted.
    revision: Arc<AtomicU64>,
    // Handle of the background task that refreshes the lock key.
    jh: JoinHandle<anyhow::Result<()>>,
}