1#![deny(missing_docs)]
2
3use std::{
54 collections::HashSet,
55 error::Error,
56 sync::{
57 atomic::{AtomicU64, Ordering},
58 Arc,
59 },
60};
61
62use async_nats::{
63 jetstream::{
64 self,
65 consumer::PullConsumer,
66 context::{
67 CreateKeyValueErrorKind, CreateStreamError, GetStreamErrorKind, KeyValueErrorKind,
68 },
69 kv::{self, Operation, Store},
70 stream::ConsumerError,
71 Context,
72 },
73 subject::ToSubject,
74 Client, ConnectError, ConnectErrorKind,
75};
76use futures::{StreamExt, TryStreamExt};
77use tokio::task::JoinHandle;
78
79pub use async_nats;
80
81pub mod subject;
83
84#[derive(Clone, serde::Deserialize)]
89pub struct Config {
90 app_name: Option<String>,
91 host: Option<String>,
92 port: Option<u16>,
93 #[serde(skip)]
94 address: Option<String>,
95 sys_locks: Option<String>,
96 events_stream_name: Option<String>,
97 events_stream_subject: Option<String>,
98}
99
100impl Config {
101 pub fn new() -> envy::Result<Self> {
103 ConfigBuilder::default().build()
104 }
105
106 pub fn builder<'a>() -> ConfigBuilder<'a> {
108 ConfigBuilder::default()
109 }
110
111 pub fn address(&self) -> &str {
113 self.address.as_deref().unwrap()
114 }
115
116 pub fn port(&self) -> u16 {
118 self.port.unwrap_or(3000)
119 }
120
121 pub fn sys_locks(&self) -> &str {
123 self.sys_locks.as_deref().unwrap_or("SYS_LOCKS")
124 }
125
126 pub fn events_stream_name(&self) -> &str {
128 self.events_stream_name.as_deref().unwrap_or("EVENTS")
129 }
130
131 pub fn events_stream_subject(&self) -> &str {
133 self.events_stream_subject.as_deref().unwrap_or("ev.>")
134 }
135}
136
/// Builder for [`Config`] that allows choosing the environment-variable prefix.
#[derive(Default)]
pub struct ConfigBuilder<'a> {
    /// Prefix for environment variables; `NATS_` is used when `None`.
    prefix: Option<&'a str>,
}
142
143impl<'a> ConfigBuilder<'a> {
144 pub fn with_prefix(mut self, prefix: &'a str) -> Self {
146 self.prefix = Some(prefix);
147 self
148 }
149
150 pub fn build(self) -> envy::Result<Config> {
152 let prefix = self.prefix.unwrap_or("NATS_");
153 let mut cfg: Config = envy::prefixed(prefix).from_env()?;
154 if cfg.app_name.is_none() {
155 cfg.app_name = Some("edd-service-rs".into());
156 }
157 let host = cfg.host.as_deref().unwrap_or("127.0.0.1");
158 let port = cfg.port.unwrap_or(4222);
159 cfg.address = Some(format!("nats://{}:{}", host, port));
160 Ok(cfg)
161 }
162}
163
164pub struct Inner {
166 client: Client,
167 config: Config,
168}
169
170#[derive(Clone)]
175pub struct Nats {
176 inner: Arc<Inner>,
177}
178
179impl Nats {
180 pub async fn new(config: Config) -> Result<Self, ConnectError> {
182 let client = async_nats::ConnectOptions::new()
183 .max_reconnects(Some(1))
184 .connect(config.address())
185 .await?;
186 Ok(Self {
187 inner: Arc::new(Inner { client, config }),
188 })
189 }
190
191 pub fn client(&self) -> &Client {
193 &self.inner.client
194 }
195
196 pub fn config(&self) -> &Config {
198 &self.inner.config
199 }
200
201 pub async fn publisher(&self) -> Result<Publisher, CreateStreamError> {
203 let ctx = jetstream::new(self.inner.client.clone());
204 let p = Publisher { ctx };
205 p.init(&self.inner.config).await?;
206 Ok(p)
207 }
208
209 pub async fn sys_consumer(&self, name: String) -> Result<PullConsumer, ConsumerError> {
211 let ctx = jetstream::new(self.inner.client.clone());
212 ctx.create_consumer_on_stream(
213 jetstream::consumer::pull::Config {
214 durable_name: Some(name),
215 ..Default::default()
216 },
217 self.inner.config.events_stream_name(),
218 )
219 .await
220 }
221
222 pub async fn sys_consumer_with_filter(
224 &self,
225 name: String,
226 filter_subject: String,
227 ) -> Result<PullConsumer, ConsumerError> {
228 let ctx = jetstream::new(self.inner.client.clone());
229 ctx.create_consumer_on_stream(
230 jetstream::consumer::pull::Config {
231 durable_name: Some(name),
232 filter_subject,
233 ..Default::default()
234 },
235 self.inner.config.events_stream_name(),
236 )
237 .await
238 }
239
240 pub async fn sys_consumer_with_filters(
242 &self,
243 name: String,
244 filter_subjects: Vec<String>,
245 ) -> Result<PullConsumer, ConsumerError> {
246 let ctx = jetstream::new(self.inner.client.clone());
247 ctx.create_consumer_on_stream(
248 jetstream::consumer::pull::Config {
249 durable_name: Some(name),
250 filter_subjects,
251 ..Default::default()
252 },
253 self.inner.config.events_stream_name(),
254 )
255 .await
256 }
257
258 pub async fn tmp_sys_consumer_with_filter(
260 &self,
261 filter_subject: String,
262 ) -> Result<PullConsumer, ConsumerError> {
263 let ctx = jetstream::new(self.inner.client.clone());
264 ctx.create_consumer_on_stream(
265 jetstream::consumer::pull::Config {
266 filter_subject,
267 deliver_policy: jetstream::consumer::DeliverPolicy::Last,
268 ..Default::default()
269 },
270 self.inner.config.events_stream_name(),
271 )
272 .await
273 }
274
275 pub async fn distributed_locks(&self) -> Result<DistributedLocks, DistributedLocksError> {
277 let ctx = jetstream::new(self.inner.client.clone());
278 DistributedLocks::new(ctx, &self.inner.config).await
279 }
280
281 pub fn sequence_manager(&self) -> SequenceManager {
283 let ctx = jetstream::new(self.inner.client.clone());
284 SequenceManager { ctx }
285 }
286}
287
288pub trait EventToSubject<M> {
293 fn event_to_subject(&self) -> async_nats::Subject;
295}
296
297pub struct Publisher {
302 ctx: Context,
303}
304
305impl Publisher {
306 async fn init(&self, config: &Config) -> Result<(), CreateStreamError> {
307 let names: HashSet<String> = self.ctx.stream_names().try_collect().await?;
308 if !names.contains(config.events_stream_name()) {
309 self.ctx
310 .create_stream(jetstream::stream::Config {
311 name: config.events_stream_name().to_string(),
312 subjects: vec![config.events_stream_subject().into()],
313 allow_direct: true,
314 deny_delete: true,
315 deny_purge: true,
316 ..Default::default()
317 })
318 .await?;
319 }
320 Ok(())
321 }
322
323 pub async fn publish<S: ToSubject, P: ?Sized + serde::Serialize>(
325 &self,
326 subject: S,
327 payload: &P,
328 ) -> anyhow::Result<()> {
329 self.ctx
330 .publish(subject, serde_json::to_vec(payload)?.into())
331 .await?;
332 Ok(())
333 }
334
335 pub async fn publish_event<S, M, P>(&self, subject: &S, payload: &P) -> anyhow::Result<()>
337 where
338 S: ?Sized + EventToSubject<M>,
339 P: ?Sized + serde::Serialize,
340 {
341 self.ctx
342 .publish(
343 subject.event_to_subject(),
344 serde_json::to_vec(payload)?.into(),
345 )
346 .await?;
347 Ok(())
348 }
349}
350
351impl AsRef<Context> for Publisher {
352 fn as_ref(&self) -> &Context {
353 &self.ctx
354 }
355}
356
357#[derive(thiserror::Error, Debug)]
359pub enum DistributedLocksError {
360 #[error(transparent)]
362 Connect(#[from] async_nats::error::Error<ConnectErrorKind>),
363 #[error(transparent)]
365 CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
366 #[error(transparent)]
368 KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
369}
370
371#[derive(Clone)]
375pub struct DistributedLocks {
376 ctx: Context,
377 sys_locks: String,
378}
379
380impl DistributedLocks {
381 async fn new(ctx: Context, config: &Config) -> Result<Self, DistributedLocksError> {
382 let lm = DistributedLocks {
383 ctx,
384 sys_locks: config.sys_locks().to_string(),
385 };
386 if !lm.exists(config.sys_locks()).await? {
387 lm.create(config.sys_locks(), 5).await?;
388 }
389 Ok(lm)
390 }
391
392 async fn create<T: Into<String>>(
393 &self,
394 name: T,
395 max_age: u64,
396 ) -> Result<Store, DistributedLocksError> {
397 Ok(self
398 .ctx
399 .create_key_value(kv::Config {
400 bucket: name.into(),
401 max_age: std::time::Duration::from_secs(max_age),
402 history: 1,
403 ..Default::default()
404 })
405 .await?)
406 }
407
408 async fn exists<T: Into<String>>(&self, bucket: T) -> Result<bool, DistributedLocksError> {
409 if let Err(err) = self.ctx.get_key_value(bucket).await {
410 if err.kind() == KeyValueErrorKind::GetBucket {
411 if let Some(src) = err.source() {
412 let err = src.downcast_ref::<async_nats::error::Error<GetStreamErrorKind>>();
413 if let Some(err) = err {
414 if let GetStreamErrorKind::JetStream(err) = err.kind() {
415 if err.code() == 404 {
416 return Ok(false);
417 }
418 }
419 }
420 }
421 }
422 Err(err)?;
423 }
424 Ok(true)
425 }
426
427 pub async fn sys_locks(&self) -> anyhow::Result<LockManager> {
429 let kv = self.ctx.get_key_value(&self.sys_locks).await?;
430 Ok(LockManager { kv: Arc::new(kv) })
431 }
432}
433
434#[derive(thiserror::Error, Debug)]
436pub enum LockManagerError {
437 #[error(transparent)]
439 CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
440 #[error(transparent)]
442 KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
443 #[error(transparent)]
445 Watch(#[from] async_nats::error::Error<kv::WatchErrorKind>),
446 #[error("unable to lock resource after {0:?}")]
448 OutOfRetries(std::time::Duration),
449}
450
451#[derive(thiserror::Error, Debug)]
453pub enum SequenceManagerError {
454 #[error(transparent)]
456 Connect(#[from] async_nats::error::Error<ConnectErrorKind>),
457 #[error(transparent)]
459 CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
460 #[error(transparent)]
462 KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
463 #[error(transparent)]
465 Put(#[from] async_nats::error::Error<async_nats::jetstream::kv::PutErrorKind>),
466 #[error(transparent)]
468 Entry(#[from] async_nats::error::Error<async_nats::jetstream::kv::EntryErrorKind>),
469}
470
471pub struct SequenceManager {
476 ctx: Context,
477}
478
479impl SequenceManager {
480 async fn create<T: Into<String>>(&self, name: T) -> Result<Store, SequenceManagerError> {
481 Ok(self
482 .ctx
483 .create_key_value(kv::Config {
484 bucket: name.into(),
485 ..Default::default()
486 })
487 .await?)
488 }
489
490 async fn exists<T: Into<String>>(&self, bucket: T) -> Result<bool, SequenceManagerError> {
491 if let Err(err) = self.ctx.get_key_value(bucket).await {
492 if err.kind() == KeyValueErrorKind::GetBucket {
493 if let Some(src) = err.source() {
494 let err = src.downcast_ref::<async_nats::error::Error<GetStreamErrorKind>>();
495 if let Some(err) = err {
496 if let GetStreamErrorKind::JetStream(err) = err.kind() {
497 if err.code() == 404 {
498 return Ok(false);
499 }
500 }
501 }
502 }
503 }
504 Err(err)?;
505 }
506 Ok(true)
507 }
508
509 async fn get<T: Into<String>>(&self, bucket: T) -> Result<Store, SequenceManagerError> {
510 Ok(self.ctx.get_key_value(bucket).await?)
511 }
512
513 pub async fn next(&self, prefix: &str, id: i64) -> Result<i64, SequenceManagerError> {
515 let bucket = format!("sm-{prefix}");
516 if !self.exists(&bucket).await? {
517 let store = self.create(&bucket).await?;
518 let result = store.put("id", id.to_be_bytes().to_vec().into()).await?;
519 Ok(result as i64)
520 } else {
521 let store = self.get(&bucket).await?;
522 let e = store.entry("id").await?;
523 if let Some(e) = e {
524 Ok(e.revision as i64)
525 } else {
526 let result = store.put("id", id.to_be_bytes().to_vec().into()).await?;
527 Ok(result as i64)
528 }
529 }
530 }
531
532 pub async fn increment(&self, prefix: &str, id: i64) -> Result<i64, SequenceManagerError> {
534 let bucket = format!("sm-{prefix}");
535 let store = self.get(&bucket).await?;
536 let e = store.put("id", id.to_be_bytes().to_vec().into()).await?;
537 Ok(e as i64)
538 }
539}
540
541pub struct LockManager {
547 kv: Arc<Store>,
548}
549
550impl LockManager {
551 pub async fn run_locked<N, O, F, E>(&self, name: N, f: F) -> Result<O, E>
556 where
557 N: Into<String>,
558 F: std::future::Future<Output = Result<O, E>>,
559 E: From<LockManagerError>,
560 {
561 let lock = self.try_lock(name.into(), 3, 5).await?;
562 let result = f.await;
563 let w_kv = self.kv.clone();
564 tokio::spawn(async move {
565 if !lock.jh.is_finished() {
566 lock.jh.abort();
567 let result = lock.jh.await;
568 if let Err(err) = result {
569 if !err.is_cancelled() {
570 tracing::error!("{err:#?}");
571 }
572 }
573 }
574 w_kv.delete_expect_revision(lock.name, Some(lock.revision.load(Ordering::SeqCst)))
575 .await
576 .ok();
577 });
578 result
579 }
580
581 async fn try_lock(
582 &self,
583 name: String,
584 timeout: u64,
585 retries: usize,
586 ) -> Result<Lock, LockManagerError> {
587 let now = std::time::Instant::now();
588 let max_retries = retries;
589 let mut tries = 0;
590 let revision = Arc::new(AtomicU64::new(0));
591 let kv = &self.kv;
592 loop {
593 if tries >= max_retries {
594 return Err(LockManagerError::OutOfRetries(now.elapsed()));
595 }
596 let v = kv.create(&name, "r".into()).await;
597 if let Err(err) = v {
598 if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists {
599 tracing::debug!("seems to be locked already, {tries} try to watch for changes");
600 let mut w = kv.watch(&name).await?;
601 let f = async {
602 'inner: while let Some(m) = w.next().await {
603 if let Ok(e) = m {
604 if e.operation == Operation::Delete {
605 tracing::debug!("retry because prev lock was deleted");
606 break 'inner;
607 }
608 }
609 }
610 };
611 let t = async {
612 tokio::time::sleep(std::time::Duration::from_secs(timeout)).await;
613 };
614 let change = tokio::select! {
615 _ = f => true,
616 _ = t => false,
617 };
618 if !change {
619 tries += 1;
620 }
621 }
622 } else {
623 let r = v.unwrap();
624 revision.store(r, Ordering::SeqCst);
625 tracing::debug!("got lock: '{name}'");
626 break;
627 }
628 }
629 let w_kv = self.kv.clone();
630 let w_name = name.clone();
631 let w_revision = revision.clone();
632
633 let jh = tokio::spawn(async move {
634 let mut run = 0;
635 loop {
636 run += 1;
637 tokio::time::sleep(std::time::Duration::from_secs(2)).await;
638 tracing::debug!("refresh lock {w_name}");
639 let result = w_kv
640 .update(&w_name, "u".into(), w_revision.load(Ordering::SeqCst))
641 .await;
642 if let Err(err) = result {
643 tracing::error!("{err:#?}");
644 break;
645 } else {
646 w_revision.store(result.unwrap(), Ordering::SeqCst);
647 }
648 if run >= 5 {
649 tracing::debug!("release lock after timeout");
650 break;
651 }
652 }
653 anyhow::Ok(())
654 });
655
656 Ok(Lock { name, revision, jh })
657 }
658}
659
/// Registration state of a lock.
///
/// NOTE(review): not referenced anywhere in this file; presumably used by callers. Confirm.
#[derive(Debug, PartialEq, Eq)]
pub enum LockState {
    /// Registration is in progress.
    Registering,
    /// Registration has completed.
    Registered,
}
668
669#[derive(Debug)]
674pub struct Lock {
675 name: String,
676 revision: Arc<AtomicU64>,
677 jh: JoinHandle<anyhow::Result<()>>,
678}