1#![doc = include_str!("../README.md")]
2use apalis_codec::json::JsonCodec;
3use apalis_core::{
4 backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
5 task::{Task, task_id::TaskId},
6 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
7};
8use async_nats::{
9 Client, HeaderMap, StatusCode, Subject, header,
10 jetstream::{
11 self,
12 consumer::{Consumer, FromConsumer, IntoConsumerConfig},
13 },
14};
15use futures::{
16 StreamExt, TryStreamExt,
17 stream::{self, BoxStream, once},
18};
19use std::marker::PhantomData;
20use ulid::Ulid;
21
22pub use crate::{
23 config::Config, consumer::IntoMessageStream, error::JetStreamError, sink::JetStreamSink,
24};
25
26mod ack;
27mod config;
28mod consumer;
29mod error;
30mod sink;
31
32pub type JetStreamTask<Args> = Task<Args, JetStreamContext, ulid::Ulid>;
33
/// Per-message metadata attached to every [`JetStreamTask`].
///
/// When tasks are polled, each field is copied from the received NATS message.
/// In the `Default` value every field is `None`.
#[derive(Debug, Default, Clone)]
pub struct JetStreamContext {
    /// Subject the message was received on.
    pub subject: Option<Subject>,
    /// Reply subject copied from the message, if any.
    pub reply: Option<Subject>,
    /// Headers copied from the message, if any.
    pub headers: Option<HeaderMap>,
    /// Status code copied from the message, if any.
    pub status: Option<StatusCode>,
    /// Status description copied from the message, if any.
    pub description: Option<String>,
}
42
43pub struct NatsJetStream<Args, Codec, C> {
44 _args: PhantomData<(Args, Codec)>,
45 client: Client,
46 config: Config<C>,
47 sink: JetStreamSink<Args, Codec, C>,
48}
49
50impl<Args, C: Clone> NatsJetStream<Args, JsonCodec<Vec<u8>>, C> {
51 pub async fn new(client: Client, config: Config<C>) -> Self {
52 let context = jetstream::new(client.clone());
53 context
54 .get_or_create_stream(config.clone().stream)
55 .await
56 .expect("Could not create stream");
57 Self {
58 sink: JetStreamSink::new(context, config.clone()),
59 _args: PhantomData,
60 client,
61 config,
62 }
63 }
64}
65
66impl<Args, Codec, C: Clone> Clone for NatsJetStream<Args, Codec, C> {
67 fn clone(&self) -> Self {
68 Self {
69 _args: PhantomData,
70 client: self.client.clone(),
71 config: self.config.clone(),
72 sink: self.sink.clone(),
73 }
74 }
75}
76
77impl<Args, Decode, C, PollErr> Backend for NatsJetStream<Args, Decode, C>
78where
79 Args: Send + Sync + 'static + Unpin,
80 Decode: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
81 Decode::Error: std::error::Error + Send + Sync + 'static,
82 C: Clone + IntoConsumerConfig + FromConsumer + Send + 'static,
83 Consumer<C>: IntoMessageStream<Error = PollErr>,
84 PollErr: Send + 'static,
85{
86 type Args = Args;
87
88 type Context = JetStreamContext;
89
90 type Beat = BoxStream<'static, Result<(), JetStreamError<PollErr>>>;
91
92 type Error = JetStreamError<PollErr>;
93
94 type IdType = ulid::Ulid;
95
96 type Layer = AcknowledgeLayer<Self>;
97
98 type Stream = TaskStream<JetStreamTask<Args>, JetStreamError<PollErr>>;
99
100 fn heartbeat(&self, _worker: &WorkerContext) -> Self::Beat {
101 let heartbeat = self.config.heartbeat;
102 stream::unfold(self.client.clone(), move |client| async move {
103 apalis_core::timer::sleep(heartbeat).await;
104 let res = client
105 .flush()
106 .await
107 .map_err(|e| JetStreamError::FlushError(e));
108 Some((res, client))
109 })
110 .boxed()
111 }
112
113 fn middleware(&self) -> Self::Layer {
114 AcknowledgeLayer::new(self.clone())
115 }
116
117 fn poll(self, _worker: &WorkerContext) -> Self::Stream {
118 self.poll_general()
119 .map(|t| match t {
120 Ok(Some(task)) => Ok(Some(task.try_map(|task| {
121 Decode::decode(&task).map_err(|e| JetStreamError::ParseError(e.into()))
122 })?)),
123 Ok(None) => Ok(None),
124 Err(e) => Err(e),
125 })
126 .boxed()
127 }
128}
129
130impl<Args, Decode, C, PollErr> BackendExt for NatsJetStream<Args, Decode, C>
131where
132 Args: Send + Sync + 'static + Unpin,
133 Decode: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
134 Decode::Error: std::error::Error + Send + Sync + 'static,
135 C: Clone + IntoConsumerConfig + FromConsumer + Send + 'static,
136 Consumer<C>: IntoMessageStream<Error = PollErr>,
137 PollErr: Send + 'static,
138{
139 type Codec = Decode;
140 type Compact = Vec<u8>;
141
142 type CompactStream = TaskStream<JetStreamTask<Self::Compact>, JetStreamError<PollErr>>;
143
144 fn get_queue(&self) -> Queue {
145 Queue::from(self.config.stream.name.as_str())
146 }
147
148 fn poll_compact(self, _worker: &WorkerContext) -> Self::CompactStream {
149 self.poll_general()
150 }
151}
152
153impl<Args, Decode, C, PollErr> NatsJetStream<Args, Decode, C>
154where
155 Args: Send + Sync + 'static + Unpin,
156 Decode: Codec<Args, Compact = Vec<u8>> + Send + Sync + 'static,
157 Decode::Error: std::error::Error + Send + Sync + 'static,
158 C: Clone + IntoConsumerConfig + FromConsumer + Send + 'static,
159 Consumer<C>: IntoMessageStream<Error = PollErr>,
160 PollErr: Send + 'static,
161{
162 pub fn poll_general(self) -> TaskStream<JetStreamTask<Vec<u8>>, JetStreamError<PollErr>> {
163 let config = self.config;
164 once(async move {
165 let jetstream = jetstream::new(self.client);
166 let consumer = jetstream
167 .create_stream(config.stream)
168 .await
169 .map_err(|e| JetStreamError::CreateStreamError(e))?
170 .create_consumer(config.consumer)
172 .await
173 .map_err(|e| JetStreamError::ConsumerError(e))?;
174 let stream = consumer
175 .into_messages()
176 .await
177 .map_err(|e| JetStreamError::StreamError(e))?
178 .map_err(|e| JetStreamError::PollError(e));
179 Ok::<_, JetStreamError<PollErr>>(stream)
180 })
181 .try_flatten()
182 .map(|message| match message {
183 Ok(msg) => {
184 let args = msg.payload[..].to_vec();
185 let mut task = Task::builder(args);
186 if let Some(headers) = &msg.headers {
187 let task_id = headers
188 .get(header::NATS_MESSAGE_ID)
189 .and_then(|s| Ulid::from_string(s.as_str()).ok());
190 if let Some(task_id) = task_id {
191 task = task.with_task_id(TaskId::new(task_id));
192 }
193 }
194 let ctx = JetStreamContext {
195 subject: Some(msg.subject.clone()),
196 reply: msg.reply.clone(),
197 status: msg.status,
198 headers: msg.headers.clone(),
199 description: msg.description.clone(),
200 };
201 task = task.with_ctx(ctx);
202 Ok(Some(task.build()))
203 }
204 Err(e) => Err(e),
205 })
206 .boxed()
207 }
208}
209
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env, time::Duration};

    use apalis_core::{error::BoxDynError, worker::builder::WorkerBuilder};
    use futures::SinkExt;

    use super::*;

    /// Handler for `basic_worker`. It holds the task for five seconds, then stops
    /// the worker so that `run` returns.
    async fn send_reminder(
        _: HashMap<String, String>,
        wrk: WorkerContext,
    ) -> Result<(), BoxDynError> {
        tokio::time::sleep(Duration::from_secs(5)).await;
        wrk.stop().unwrap();
        Ok(())
    }

    /// Integration test that needs a reachable NATS server with JetStream enabled.
    /// The server address comes from `NATS_URL`, defaulting to
    /// `nats://localhost:4222`. The test enqueues one task and runs a worker that
    /// processes it.
    #[tokio::test]
    async fn basic_worker() {
        let nats_url = match env::var("NATS_URL") {
            Ok(url) => url,
            Err(_) => "nats://localhost:4222".to_string(),
        };
        let client = async_nats::connect(nats_url).await.unwrap();

        let mut backend = NatsJetStream::new(
            client.clone(),
            Config::new("push_messages")
                .with_pull_consumer()
                .durable()
                .with_max_ack_pending(1),
        )
        .await;

        backend.send(Task::new(HashMap::new())).await.unwrap();

        WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder)
            .run()
            .await
            .unwrap();

        client.flush().await.unwrap();
    }
}