qm_nats/
lib.rs

1use std::{
2    collections::HashSet,
3    error::Error,
4    sync::{
5        atomic::{AtomicU64, Ordering},
6        Arc,
7    },
8};
9
10use async_nats::{
11    jetstream::{
12        self,
13        consumer::PullConsumer,
14        context::{
15            CreateKeyValueErrorKind, CreateStreamError, GetStreamErrorKind, KeyValueErrorKind,
16        },
17        kv::{self, Operation, Store},
18        stream::ConsumerError,
19        Context,
20    },
21    subject::ToSubject,
22    Client, ConnectError, ConnectErrorKind,
23};
24use futures::{StreamExt, TryStreamExt};
25use tokio::task::JoinHandle;
26
27pub use async_nats;
28
29pub mod subject;
30
31#[derive(Clone, serde::Deserialize)]
32pub struct Config {
33    app_name: Option<String>,
34    host: Option<String>,
35    port: Option<u16>,
36    #[serde(skip)]
37    address: Option<String>,
38    sys_locks: Option<String>,
39    events_stream_name: Option<String>,
40    events_stream_subject: Option<String>,
41}
42
43impl Config {
44    pub fn new() -> envy::Result<Self> {
45        ConfigBuilder::default().build()
46    }
47
48    pub fn builder<'a>() -> ConfigBuilder<'a> {
49        ConfigBuilder::default()
50    }
51
52    pub fn address(&self) -> &str {
53        self.address.as_deref().unwrap()
54    }
55
56    pub fn port(&self) -> u16 {
57        self.port.unwrap_or(3000)
58    }
59    pub fn sys_locks(&self) -> &str {
60        self.sys_locks.as_deref().unwrap_or("SYS_LOCKS")
61    }
62    pub fn events_stream_name(&self) -> &str {
63        self.events_stream_name.as_deref().unwrap_or("EVENTS")
64    }
65    pub fn events_stream_subject(&self) -> &str {
66        self.events_stream_subject.as_deref().unwrap_or("ev.>")
67    }
68}
69
/// Builder for [`Config`] that controls which environment-variable prefix is read.
#[derive(Debug, Default, Clone)]
pub struct ConfigBuilder<'a> {
    /// Prefix for environment variable names; `None` means `"NATS_"`.
    prefix: Option<&'a str>,
}
74
75impl<'a> ConfigBuilder<'a> {
76    pub fn with_prefix(mut self, prefix: &'a str) -> Self {
77        self.prefix = Some(prefix);
78        self
79    }
80
81    pub fn build(self) -> envy::Result<Config> {
82        let prefix = self.prefix.unwrap_or("NATS_");
83        let mut cfg: Config = envy::prefixed(prefix).from_env()?;
84        if cfg.app_name.is_none() {
85            cfg.app_name = Some("edd-service-rs".into());
86        }
87        let host = cfg.host.as_deref().unwrap_or("127.0.0.1");
88        let port = cfg.port.unwrap_or(4222);
89        cfg.address = Some(format!("nats://{}:{}", host, port));
90        Ok(cfg)
91    }
92}
93
94pub struct Inner {
95    client: Client,
96    config: Config,
97}
98
99#[derive(Clone)]
100pub struct Nats {
101    inner: Arc<Inner>,
102}
103
104impl Nats {
105    pub async fn new(config: Config) -> Result<Self, ConnectError> {
106        let client = async_nats::ConnectOptions::new()
107            .max_reconnects(Some(1))
108            .connect(config.address())
109            .await?;
110        Ok(Self {
111            inner: Arc::new(Inner { client, config }),
112        })
113    }
114
115    pub fn client(&self) -> &Client {
116        &self.inner.client
117    }
118
119    pub fn config(&self) -> &Config {
120        &self.inner.config
121    }
122
123    pub async fn publisher(&self) -> Result<Publisher, CreateStreamError> {
124        let ctx = jetstream::new(self.inner.client.clone());
125        let p = Publisher { ctx };
126        p.init(&self.inner.config).await?;
127        Ok(p)
128    }
129
130    pub async fn sys_consumer(&self, name: String) -> Result<PullConsumer, ConsumerError> {
131        let ctx = jetstream::new(self.inner.client.clone());
132        ctx.create_consumer_on_stream(
133            jetstream::consumer::pull::Config {
134                durable_name: Some(name),
135                ..Default::default()
136            },
137            self.inner.config.events_stream_name(),
138        )
139        .await
140    }
141
142    pub async fn sys_consumer_with_filter(
143        &self,
144        name: String,
145        filter_subject: String,
146    ) -> Result<PullConsumer, ConsumerError> {
147        let ctx = jetstream::new(self.inner.client.clone());
148        ctx.create_consumer_on_stream(
149            jetstream::consumer::pull::Config {
150                durable_name: Some(name),
151                filter_subject,
152                ..Default::default()
153            },
154            self.inner.config.events_stream_name(),
155        )
156        .await
157    }
158
159    pub async fn sys_consumer_with_filters(
160        &self,
161        name: String,
162        filter_subjects: Vec<String>,
163    ) -> Result<PullConsumer, ConsumerError> {
164        let ctx = jetstream::new(self.inner.client.clone());
165        ctx.create_consumer_on_stream(
166            jetstream::consumer::pull::Config {
167                durable_name: Some(name),
168                filter_subjects,
169                ..Default::default()
170            },
171            self.inner.config.events_stream_name(),
172        )
173        .await
174    }
175
176    pub async fn tmp_sys_consumer_with_filter(
177        &self,
178        filter_subject: String,
179    ) -> Result<PullConsumer, ConsumerError> {
180        let ctx = jetstream::new(self.inner.client.clone());
181        ctx.create_consumer_on_stream(
182            jetstream::consumer::pull::Config {
183                filter_subject,
184                deliver_policy: jetstream::consumer::DeliverPolicy::Last,
185                ..Default::default()
186            },
187            self.inner.config.events_stream_name(),
188        )
189        .await
190    }
191
192    pub async fn distributed_locks(&self) -> Result<DistributedLocks, DistributedLocksError> {
193        let ctx = jetstream::new(self.inner.client.clone());
194        DistributedLocks::new(ctx, &self.inner.config).await
195    }
196
197    pub fn sequence_manager(&self) -> SequenceManager {
198        let ctx = jetstream::new(self.inner.client.clone());
199        SequenceManager { ctx }
200    }
201}
202
203pub trait EventToSubject<M> {
204    fn event_to_subject(&self) -> async_nats::Subject;
205}
206
207pub struct Publisher {
208    ctx: Context,
209}
210
211impl Publisher {
212    async fn init(&self, config: &Config) -> Result<(), CreateStreamError> {
213        let names: HashSet<String> = self.ctx.stream_names().try_collect().await?;
214        if !names.contains(config.events_stream_name()) {
215            self.ctx
216                .create_stream(jetstream::stream::Config {
217                    name: config.events_stream_name().to_string(),
218                    subjects: vec![config.events_stream_subject().into()],
219                    allow_direct: true,
220                    deny_delete: true,
221                    deny_purge: true,
222                    ..Default::default()
223                })
224                .await?;
225        }
226        Ok(())
227    }
228
229    pub async fn publish<S: ToSubject, P: ?Sized + serde::Serialize>(
230        &self,
231        subject: S,
232        payload: &P,
233    ) -> anyhow::Result<()> {
234        self.ctx
235            .publish(subject, serde_json::to_vec(payload)?.into())
236            .await?;
237        Ok(())
238    }
239
240    pub async fn publish_event<S, M, P>(&self, subject: &S, payload: &P) -> anyhow::Result<()>
241    where
242        S: ?Sized + EventToSubject<M>,
243        P: ?Sized + serde::Serialize,
244    {
245        self.ctx
246            .publish(
247                subject.event_to_subject(),
248                serde_json::to_vec(payload)?.into(),
249            )
250            .await?;
251        Ok(())
252    }
253}
254
255impl AsRef<Context> for Publisher {
256    fn as_ref(&self) -> &Context {
257        &self.ctx
258    }
259}
260
261#[derive(thiserror::Error, Debug)]
262pub enum DistributedLocksError {
263    #[error(transparent)]
264    Connect(#[from] async_nats::error::Error<ConnectErrorKind>),
265    #[error(transparent)]
266    CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
267    #[error(transparent)]
268    KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
269}
270
271#[derive(Clone)]
272pub struct DistributedLocks {
273    ctx: Context,
274    sys_locks: String,
275}
276
277impl DistributedLocks {
278    async fn new(ctx: Context, config: &Config) -> Result<Self, DistributedLocksError> {
279        let lm = DistributedLocks {
280            ctx,
281            sys_locks: config.sys_locks().to_string(),
282        };
283        if !lm.exists(config.sys_locks()).await? {
284            lm.create(config.sys_locks(), 5).await?;
285        }
286        Ok(lm)
287    }
288
289    async fn create<T: Into<String>>(
290        &self,
291        name: T,
292        max_age: u64,
293    ) -> Result<Store, DistributedLocksError> {
294        Ok(self
295            .ctx
296            .create_key_value(kv::Config {
297                bucket: name.into(),
298                max_age: std::time::Duration::from_secs(max_age),
299                history: 1,
300                ..Default::default()
301            })
302            .await?)
303    }
304
305    async fn exists<T: Into<String>>(&self, bucket: T) -> Result<bool, DistributedLocksError> {
306        if let Err(err) = self.ctx.get_key_value(bucket).await {
307            if err.kind() == KeyValueErrorKind::GetBucket {
308                if let Some(src) = err.source() {
309                    let err = src.downcast_ref::<async_nats::error::Error<GetStreamErrorKind>>();
310                    if let Some(err) = err {
311                        if let GetStreamErrorKind::JetStream(err) = err.kind() {
312                            if err.code() == 404 {
313                                return Ok(false);
314                            }
315                        }
316                    }
317                }
318            }
319            Err(err)?;
320        }
321        Ok(true)
322    }
323
324    pub async fn sys_locks(&self) -> anyhow::Result<LockManager> {
325        let kv = self.ctx.get_key_value(&self.sys_locks).await?;
326        Ok(LockManager { kv: Arc::new(kv) })
327    }
328}
329
330#[derive(thiserror::Error, Debug)]
331pub enum LockManagerError {
332    #[error(transparent)]
333    CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
334    #[error(transparent)]
335    KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
336    #[error(transparent)]
337    Watch(#[from] async_nats::error::Error<kv::WatchErrorKind>),
338    #[error("unable to lock resource after {0:?}")]
339    OutOfRetries(std::time::Duration),
340}
341
342#[derive(thiserror::Error, Debug)]
343pub enum SequenceManagerError {
344    #[error(transparent)]
345    Connect(#[from] async_nats::error::Error<ConnectErrorKind>),
346    #[error(transparent)]
347    CreateKeyValue(#[from] async_nats::error::Error<CreateKeyValueErrorKind>),
348    #[error(transparent)]
349    KeyValue(#[from] async_nats::error::Error<KeyValueErrorKind>),
350    #[error(transparent)]
351    Put(#[from] async_nats::error::Error<async_nats::jetstream::kv::PutErrorKind>),
352    #[error(transparent)]
353    Entry(#[from] async_nats::error::Error<async_nats::jetstream::kv::EntryErrorKind>),
354}
355
356pub struct SequenceManager {
357    ctx: Context,
358}
359
360impl SequenceManager {
361    async fn create<T: Into<String>>(&self, name: T) -> Result<Store, SequenceManagerError> {
362        Ok(self
363            .ctx
364            .create_key_value(kv::Config {
365                bucket: name.into(),
366                ..Default::default()
367            })
368            .await?)
369    }
370
371    async fn exists<T: Into<String>>(&self, bucket: T) -> Result<bool, SequenceManagerError> {
372        if let Err(err) = self.ctx.get_key_value(bucket).await {
373            if err.kind() == KeyValueErrorKind::GetBucket {
374                if let Some(src) = err.source() {
375                    let err = src.downcast_ref::<async_nats::error::Error<GetStreamErrorKind>>();
376                    if let Some(err) = err {
377                        if let GetStreamErrorKind::JetStream(err) = err.kind() {
378                            if err.code() == 404 {
379                                return Ok(false);
380                            }
381                        }
382                    }
383                }
384            }
385            Err(err)?;
386        }
387        Ok(true)
388    }
389
390    async fn get<T: Into<String>>(&self, bucket: T) -> Result<Store, SequenceManagerError> {
391        Ok(self.ctx.get_key_value(bucket).await?)
392    }
393
394    pub async fn next(&self, prefix: &str, id: i64) -> Result<i64, SequenceManagerError> {
395        let bucket = format!("sm-{prefix}");
396        if !self.exists(&bucket).await? {
397            let store = self.create(&bucket).await?;
398            let result = store.put("id", id.to_be_bytes().to_vec().into()).await?;
399            Ok(result as i64)
400        } else {
401            let store = self.get(&bucket).await?;
402            let e = store.entry("id").await?;
403            if let Some(e) = e {
404                Ok(e.revision as i64)
405            } else {
406                let result = store.put("id", id.to_be_bytes().to_vec().into()).await?;
407                Ok(result as i64)
408            }
409        }
410    }
411
412    pub async fn increment(&self, prefix: &str, id: i64) -> Result<i64, SequenceManagerError> {
413        let bucket = format!("sm-{prefix}");
414        let store = self.get(&bucket).await?;
415        let e = store.put("id", id.to_be_bytes().to_vec().into()).await?;
416        Ok(e as i64)
417    }
418}
419
420pub struct LockManager {
421    kv: Arc<Store>,
422}
423
424impl LockManager {
425    pub async fn run_locked<N, O, F, E>(&self, name: N, f: F) -> Result<O, E>
426    where
427        N: Into<String>,
428        F: std::future::Future<Output = Result<O, E>>,
429        E: From<LockManagerError>,
430    {
431        let lock = self.try_lock(name.into(), 3, 5).await?;
432        let result = f.await;
433        let w_kv = self.kv.clone();
434        tokio::spawn(async move {
435            if !lock.jh.is_finished() {
436                lock.jh.abort();
437                let result = lock.jh.await;
438                if let Err(err) = result {
439                    if !err.is_cancelled() {
440                        tracing::error!("{err:#?}");
441                    }
442                }
443            }
444            w_kv.delete_expect_revision(lock.name, Some(lock.revision.load(Ordering::SeqCst)))
445                .await
446                .ok();
447        });
448        result
449    }
450
451    async fn try_lock(
452        &self,
453        name: String,
454        timeout: u64,
455        retries: usize,
456    ) -> Result<Lock, LockManagerError> {
457        let now = std::time::Instant::now();
458        let max_retries = retries;
459        let mut tries = 0;
460        let revision = Arc::new(AtomicU64::new(0));
461        let kv = &self.kv;
462        loop {
463            if tries >= max_retries {
464                return Err(LockManagerError::OutOfRetries(now.elapsed()));
465            }
466            let v = kv.create(&name, "r".into()).await;
467            if let Err(err) = v {
468                if err.kind() == async_nats::jetstream::kv::CreateErrorKind::AlreadyExists {
469                    tracing::debug!("seems to be locked already, {tries} try to watch for changes");
470                    let mut w = kv.watch(&name).await?;
471                    let f = async {
472                        'inner: while let Some(m) = w.next().await {
473                            if let Ok(e) = m {
474                                if e.operation == Operation::Delete {
475                                    tracing::debug!("retry because prev lock was deleted");
476                                    break 'inner;
477                                }
478                            }
479                        }
480                    };
481                    let t = async {
482                        tokio::time::sleep(std::time::Duration::from_secs(timeout)).await;
483                    };
484                    let change = tokio::select! {
485                        _ = f => true,
486                        _ = t => false,
487                    };
488                    if !change {
489                        tries += 1;
490                    }
491                }
492            } else {
493                let r = v.unwrap();
494                revision.store(r, Ordering::SeqCst);
495                tracing::debug!("got lock: '{name}'");
496                break;
497            }
498        }
499        let w_kv = self.kv.clone();
500        let w_name = name.clone();
501        let w_revision = revision.clone();
502
503        let jh = tokio::spawn(async move {
504            let mut run = 0;
505            loop {
506                run += 1;
507                tokio::time::sleep(std::time::Duration::from_secs(2)).await;
508                tracing::debug!("refresh lock {w_name}");
509                let result = w_kv
510                    .update(&w_name, "u".into(), w_revision.load(Ordering::SeqCst))
511                    .await;
512                if let Err(err) = result {
513                    tracing::error!("{err:#?}");
514                    break;
515                } else {
516                    w_revision.store(result.unwrap(), Ordering::SeqCst);
517                }
518                if run >= 5 {
519                    tracing::debug!("release lock after timeout");
520                    break;
521                }
522            }
523            anyhow::Ok(())
524        });
525
526        Ok(Lock { name, revision, jh })
527    }
528}
529
/// Two-phase registration status.
///
/// NOTE(review): not referenced anywhere in this file; presumably consumed by callers.
#[derive(PartialEq, Eq, Debug)]
pub enum LockState {
    /// Presumably: registration in progress.
    Registering,
    /// Presumably: registration finished.
    Registered,
}
535
536#[derive(Debug)]
537pub struct Lock {
538    name: String,
539    revision: Arc<AtomicU64>,
540    jh: JoinHandle<anyhow::Result<()>>,
541}