Skip to main content

apalis_nats/
lib.rs

1#![doc = include_str!("../README.md")]
2use apalis_codec::json::JsonCodec;
3use apalis_core::{
4    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
5    task::{Task, task_id::TaskId},
6    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
7};
8use async_nats::{
9    Client, HeaderMap, StatusCode, Subject, header,
10    jetstream::{
11        self,
12        consumer::{Consumer, FromConsumer, IntoConsumerConfig},
13    },
14};
15use futures::{
16    StreamExt, TryStreamExt,
17    stream::{self, BoxStream, once},
18};
19use std::marker::PhantomData;
20use ulid::Ulid;
21
22pub use crate::{
23    config::Config, consumer::IntoMessageStream, error::JetStreamError, sink::JetStreamSink,
24};
25
26mod ack;
27mod config;
28mod consumer;
29mod error;
30mod sink;
31
32pub type JetStreamTask<Args> = Task<Args, JetStreamContext, ulid::Ulid>;
33
34#[derive(Debug, Default, Clone)]
35pub struct JetStreamContext {
36    pub subject: Option<Subject>,
37    pub reply: Option<Subject>,
38    pub headers: Option<HeaderMap>,
39    pub status: Option<StatusCode>,
40    pub description: Option<String>,
41}
42
43pub struct NatsJetStream<Args, Codec, C> {
44    _args: PhantomData<(Args, Codec)>,
45    client: Client,
46    config: Config<C>,
47    sink: JetStreamSink<Args, Codec, C>,
48}
49
50impl<Args, C: Clone> NatsJetStream<Args, JsonCodec<Vec<u8>>, C> {
51    pub async fn new(client: Client, config: Config<C>) -> Self {
52        let context = jetstream::new(client.clone());
53        context
54            .get_or_create_stream(config.clone().stream)
55            .await
56            .expect("Could not create stream");
57        Self {
58            sink: JetStreamSink::new(context, config.clone()),
59            _args: PhantomData,
60            client,
61            config,
62        }
63    }
64}
65
66impl<Args, Codec, C: Clone> Clone for NatsJetStream<Args, Codec, C> {
67    fn clone(&self) -> Self {
68        Self {
69            _args: PhantomData,
70            client: self.client.clone(),
71            config: self.config.clone(),
72            sink: self.sink.clone(),
73        }
74    }
75}
76
77impl<Args, Decode, C, PollErr> Backend for NatsJetStream<Args, Decode, C>
78where
79    Args: Send + Sync + 'static + Unpin,
80    Decode: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
81    Decode::Error: std::error::Error + Send + Sync + 'static,
82    C: Clone + IntoConsumerConfig + FromConsumer + Send + 'static,
83    Consumer<C>: IntoMessageStream<Error = PollErr>,
84    PollErr: Send + 'static,
85{
86    type Args = Args;
87
88    type Context = JetStreamContext;
89
90    type Beat = BoxStream<'static, Result<(), JetStreamError<PollErr>>>;
91
92    type Error = JetStreamError<PollErr>;
93
94    type IdType = ulid::Ulid;
95
96    type Layer = AcknowledgeLayer<Self>;
97
98    type Stream = TaskStream<JetStreamTask<Args>, JetStreamError<PollErr>>;
99
100    fn heartbeat(&self, _worker: &WorkerContext) -> Self::Beat {
101        let heartbeat = self.config.heartbeat;
102        stream::unfold(self.client.clone(), move |client| async move {
103            apalis_core::timer::sleep(heartbeat).await;
104            let res = client
105                .flush()
106                .await
107                .map_err(|e| JetStreamError::FlushError(e));
108            Some((res, client))
109        })
110        .boxed()
111    }
112
113    fn middleware(&self) -> Self::Layer {
114        AcknowledgeLayer::new(self.clone())
115    }
116
117    fn poll(self, _worker: &WorkerContext) -> Self::Stream {
118        self.poll_general()
119            .map(|t| match t {
120                Ok(Some(task)) => Ok(Some(task.try_map(|task| {
121                    Decode::decode(&task).map_err(|e| JetStreamError::ParseError(e.into()))
122                })?)),
123                Ok(None) => Ok(None),
124                Err(e) => Err(e),
125            })
126            .boxed()
127    }
128}
129
130impl<Args, Decode, C, PollErr> BackendExt for NatsJetStream<Args, Decode, C>
131where
132    Args: Send + Sync + 'static + Unpin,
133    Decode: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
134    Decode::Error: std::error::Error + Send + Sync + 'static,
135    C: Clone + IntoConsumerConfig + FromConsumer + Send + 'static,
136    Consumer<C>: IntoMessageStream<Error = PollErr>,
137    PollErr: Send + 'static,
138{
139    type Codec = Decode;
140    type Compact = Vec<u8>;
141
142    type CompactStream = TaskStream<JetStreamTask<Self::Compact>, JetStreamError<PollErr>>;
143
144    fn get_queue(&self) -> Queue {
145        Queue::from(self.config.stream.name.as_str())
146    }
147
148    fn poll_compact(self, _worker: &WorkerContext) -> Self::CompactStream {
149        self.poll_general()
150    }
151}
152
153impl<Args, Decode, C, PollErr> NatsJetStream<Args, Decode, C>
154where
155    Args: Send + Sync + 'static + Unpin,
156    Decode: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
157    Decode::Error: std::error::Error + Send + Sync + 'static,
158    C: Clone + IntoConsumerConfig + FromConsumer + Send + 'static,
159    Consumer<C>: IntoMessageStream<Error = PollErr>,
160    PollErr: Send + 'static,
161{
162    pub fn poll_general(self) -> TaskStream<JetStreamTask<Vec<u8>>, JetStreamError<PollErr>> {
163        let config = self.config;
164        once(async move {
165            let jetstream = jetstream::new(self.client);
166            let consumer = jetstream
167                .create_stream(config.stream)
168                .await
169                .map_err(|e| JetStreamError::CreateStreamError(e))?
170                // Then, on that `Stream` use method to create Consumer and bind to it too.
171                .create_consumer(config.consumer)
172                .await
173                .map_err(|e| JetStreamError::ConsumerError(e))?;
174            let stream = consumer
175                .into_messages()
176                .await
177                .map_err(|e| JetStreamError::StreamError(e))?
178                .map_err(|e| JetStreamError::PollError(e));
179            Ok::<_, JetStreamError<PollErr>>(stream)
180        })
181        .try_flatten()
182        .map(|message| match message {
183            Ok(msg) => {
184                let args = msg.payload[..].to_vec();
185                let mut task = Task::builder(args);
186                if let Some(headers) = &msg.headers {
187                    let task_id = headers
188                        .get(header::NATS_MESSAGE_ID)
189                        .and_then(|s| Ulid::from_string(s.as_str()).ok());
190                    if let Some(task_id) = task_id {
191                        task = task.with_task_id(TaskId::new(task_id));
192                    }
193                }
194                let ctx = JetStreamContext {
195                    subject: Some(msg.subject.clone()),
196                    reply: msg.reply.clone(),
197                    status: msg.status,
198                    headers: msg.headers.clone(),
199                    description: msg.description.clone(),
200                };
201                task = task.with_ctx(ctx);
202                Ok(Some(task.build()))
203            }
204            Err(e) => Err(e),
205        })
206        .boxed()
207    }
208}
209
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env, time::Duration};

    use apalis_core::{error::BoxDynError, worker::builder::WorkerBuilder};
    use futures::SinkExt;

    use super::*;

    /// Job handler: sleeps for five seconds, then asks its worker to stop so
    /// that `run()` in the test below returns.
    async fn send_reminder(
        _: HashMap<String, String>,
        wrk: WorkerContext,
    ) -> Result<(), BoxDynError> {
        tokio::time::sleep(Duration::from_secs(5)).await;
        wrk.stop().unwrap();
        Ok(())
    }

    /// End-to-end check against a live NATS server.
    ///
    /// The server address comes from `NATS_URL`, falling back to
    /// `nats://localhost:4222`. The test enqueues one task, then runs a worker
    /// until the handler stops it.
    #[tokio::test]
    async fn basic_worker() {
        let url = env::var("NATS_URL").unwrap_or_else(|_| String::from("nats://localhost:4222"));

        // Unauthenticated connection to NATS.
        let client = async_nats::connect(url).await.unwrap();

        let config = Config::new("push_messages")
            .with_pull_consumer()
            .durable()
            .with_max_ack_pending(1);
        let mut backend = NatsJetStream::new(client.clone(), config).await;

        backend.send(Task::new(HashMap::new())).await.unwrap();

        WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder)
            .run()
            .await
            .unwrap();

        // Push anything still buffered on the client out to the server.
        client.flush().await.unwrap();
    }
}