1use futures::{
3 future::{join_all, BoxFuture},
4 StreamExt, TryStreamExt,
5};
6use lapin::{
7 options::{
8 BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions,
9 BasicQosOptions,
10 },
11 tcp::{OwnedIdentity, OwnedTLSConfig},
12 types::{DeliveryTag, FieldTable, ShortString},
13 BasicProperties, Channel, Connection, ConnectionProperties, Consumer,
14};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use std::{
17 fmt::{Debug, Display},
18 pin::Pin,
19 str::{from_utf8, FromStr},
20 sync::Arc,
21};
22use tokio::task::JoinError;
23
24pub mod client;
26
27mod test;
28
29pub struct Worker {
31 channel: Channel,
32 rpc_consumers: Vec<ListenerConfig>,
34 rpc_handlers: Vec<
35 Arc<
36 dyn Fn(
37 lapin::message::Delivery,
38 ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
39 + Send
40 + Sync,
41 >,
42 >,
43 consumers: Vec<ListenerConfig>,
45 handlers: Vec<
46 Arc<
47 dyn Fn(
48 lapin::message::Delivery,
49 ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
50 + Send
51 + Sync,
52 >,
53 >,
54}
55
/// User-supplied TLS material. `WorkerConfig::enable_tls` converts it into
/// lapin's `OwnedTLSConfig`.
pub struct TlsConfig {
    /// Certificate chain, forwarded unchanged as `OwnedTLSConfig::cert_chain`
    /// (presumably the CA chain used to verify the server — confirm).
    cert_chain: String,
    /// Client identity. Its bytes are used verbatim as `OwnedIdentity::der`.
    /// NOTE(review): DER/PKCS#12 data is usually binary, so holding it in a
    /// UTF-8 `String` looks questionable — verify what callers pass here.
    client_cert_and_key: String,
    /// Password that protects `client_cert_and_key`.
    client_cert_and_key_password: String,
}
63
64impl TlsConfig {
65 pub fn new(
67 cert_chain: String,
68 client_cert_and_key: String,
69 client_cert_and_key_password: String,
70 ) -> Self {
71 Self {
72 cert_chain,
73 client_cert_and_key,
74 client_cert_and_key_password,
75 }
76 }
77}
78
79#[derive(Debug)]
80pub struct WorkerConfig {
82 tls: Option<OwnedTLSConfig>,
83}
84impl WorkerConfig {
85 pub fn default() -> Self {
87 Self { tls: None }
88 }
89
90 pub fn enable_tls(mut self, custom_tls: Option<TlsConfig>) -> Self {
95 match custom_tls {
96 Some(tls) => {
97 let tls = OwnedTLSConfig {
98 identity: Some(OwnedIdentity {
99 der: tls.client_cert_and_key.as_bytes().to_vec(),
100 password: tls.client_cert_and_key_password,
101 }),
102 cert_chain: Some(tls.cert_chain.to_string()),
103 };
104 self.tls = tls.into();
105 }
106 None => self.tls = OwnedTLSConfig::default().into(),
107 }
108 self
109 }
110}
111
/// Per-consumer settings used when a worker starts listening on a queue.
pub struct ListenerConfig {
    /// Prefetch count applied via `basic_qos` before the consumer starts.
    prefetch_count: u16,
    /// Base queue name. The queue actually consumed is
    /// `"<queue_name>-<message_version>"`.
    queue_name: String,
    /// Consumer tag passed to `basic_consume`.
    consumer_tag: String,
    /// Version suffix appended to `queue_name`.
    message_version: String,
}
120
121impl ListenerConfig {
122 pub fn default(queue_name: impl Into<String>) -> Self {
126 Self {
127 prefetch_count: 0,
128 queue_name: queue_name.into(),
129 consumer_tag: "".into(),
130 message_version: "v1.0.0".into(),
131 }
132 }
133 pub fn set_prefetch_count(mut self, prefetch_count: u16) -> Self {
136 self.prefetch_count = prefetch_count;
137 self
138 }
139 pub fn set_consumer_tag(mut self, consumer_tag: impl Into<String>) -> Self {
141 self.consumer_tag = consumer_tag.into();
142 self
143 }
144
145 pub fn set_message_version(mut self, version: impl Into<String>) -> Self {
147 self.message_version = version.into();
148 self
149 }
150}
151
152impl Worker {
154 pub async fn new(amqp_server_url: impl Into<String>, config: WorkerConfig) -> Self {
159 let channel = Self::create_channel(amqp_server_url.into(), config).await;
160
161 Worker {
162 channel,
163 handlers: Vec::new(),
164 consumers: Vec::new(),
165
166 rpc_handlers: Vec::new(),
167 rpc_consumers: Vec::new(),
168 }
169 }
170
171 async fn create_channel(amqp_server_url: String, config: WorkerConfig) -> Channel {
172 let channel = match config.tls {
174 None => Connection::connect(&amqp_server_url, ConnectionProperties::default())
175 .await
176 .expect("connection error")
177 .create_channel()
178 .await
179 .unwrap(),
180 Some(tls) => Connection::connect_uri_with_config(
181 lapin::uri::AMQPUri::from_str(&amqp_server_url).unwrap(),
182 ConnectionProperties::default(),
183 tls,
184 )
185 .await
186 .unwrap()
187 .create_channel()
188 .await
189 .unwrap(),
190 };
191 channel
192 }
193
194 pub fn add_non_rpc_consumer<J: Task + 'static + Send>(
206 &mut self,
207 state: Arc<J::State>,
208 listener_config: ListenerConfig,
209 ) where
210 <J as Task>::State: std::marker::Send + Sync,
211 {
212 let handler: Arc<
213 dyn Fn(
214 lapin::message::Delivery,
215 ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
216 + Send
217 + Sync,
218 > = Arc::new(move |delivery: lapin::message::Delivery| {
219 let state = Arc::clone(&state);
220 Box::pin(async move {
221 if let Ok(job) = J::decode(delivery.data.clone()) {
222 tracing::debug!("Running before job");
225 let job = match tokio::task::spawn(async move { job.before_job().await }).await
226 {
227 Err(error) => {
228 tracing::error!(
229 "The before_job function has failed for a job of type: {}, {}",
230 std::any::type_name::<J>(),
231 error
232 );
233 return ();
234 }
235 Ok(j) => j,
236 };
237
238 match tokio::task::spawn(async move { job.run(state).await }).await {
239 Err(error) => {
240 tracing::error!("Failed to run task job: {}", error);
241 let _ = delivery.nack(BasicNackOptions::default()).await;
242 }
243 Ok(_) => {
244 tracing::info!("Non-rpc job has finished.");
245 let _ = delivery.ack(BasicAckOptions::default()).await;
246 }
247 };
248 } else {
250 delivery.nack(BasicNackOptions::default()).await.unwrap();
251 }
252 })
253 });
254
255 self.handlers.push(handler);
256 self.consumers.push(listener_config);
257 }
258 pub fn add_rpc_consumer<J: RPCTask + 'static + Send>(
274 &mut self,
275 state: Arc<J::State>,
276 listener_config: ListenerConfig,
277 ) where
278 <J as RPCTask>::State: std::marker::Send + Sync,
279 <J as RPCTask>::Result: std::marker::Send + Sync,
280 <J as RPCTask>::ErroredResult: std::marker::Send + Sync,
281 {
282 let channel = self.channel.clone();
283 let handler: Arc<
284 dyn Fn(
285 lapin::message::Delivery,
286 ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
287 + Send
288 + Sync,
289 > = Arc::new(move |delivery: lapin::message::Delivery| {
290 let state = Arc::clone(&state);
291 let channel = channel.clone();
292 Box::pin(async move {
293 let routing_key = match delivery.properties.reply_to().as_ref() {
294 Some(key) => key.clone().to_owned(),
295 None => {
296 tracing::warn!("Received a job with no reply_to!");
297 tracing::trace!("No reply_to for job {:?}, skipping loop", delivery);
298 let _ = nack(channel.clone(), delivery.delivery_tag).await;
299 return ();
300 }
301 };
302
303 let correlation_id = match delivery.properties.correlation_id().clone() {
304 None => {
305 tracing::warn!("received a job with no correlation id");
306 tracing::trace!("no correlation id for delivery {:?}", delivery);
307 let _ = nack(channel.clone(), delivery.delivery_tag).await;
308 return ();
309 }
310 Some(id) => id,
311 };
312
313 if let Ok(job) = J::decode(delivery.data.clone()) {
314 let job = tokio::task::spawn(async move { job.before_job().await })
315 .await
316 .unwrap();
317 let outcome = tokio::task::spawn(async move { job.run(state).await }).await;
318
319 match outcome {
320 Err(ref error) => {
321 tracing::error!("Failed to start thread for worker {}", error);
322 let headers = create_header(ResultHeader::Panic);
323 let _ = delivery.ack(BasicAckOptions::default()).await; respond_to_rpc_queue(
325 channel.clone(),
326 routing_key,
327 headers,
328 correlation_id,
329 None::<J::ErroredResult>,
330 )
331 .await
332 }
333 Ok(Ok(ref res)) => {
334 let headers = create_header(ResultHeader::Ok);
335 let _ = delivery.ack(BasicAckOptions::default()).await; respond_to_rpc_queue(
337 channel.clone(),
338 routing_key,
339 headers,
340 correlation_id,
341 Some(res.clone()),
342 )
343 .await
344 }
345 Ok(Err(ref err)) => {
346 let headers = create_header(ResultHeader::Ok);
348 let _ = delivery.ack(BasicAckOptions::default()).await; respond_to_rpc_queue(
350 channel.clone(),
351 routing_key,
352 headers,
353 correlation_id,
354 Some(err.clone()),
355 )
356 .await
357 }
358 };
359 tracing::debug!("Running after job");
360 let _ = tokio::task::spawn(async move { J::after_job(outcome).await }).await;
361 } else {
362 delivery.nack(BasicNackOptions::default()).await.unwrap();
363 }
364 })
365 });
366
367 self.rpc_handlers.push(handler);
368 self.rpc_consumers.push(listener_config);
369 }
370
371 pub async fn start_all_listeners(&self) -> Result<(), String> {
375 let mut listeners = vec![];
376
377 for (handler, consumer_config) in self.handlers.iter().zip(self.consumers.iter()) {
379 let mut channel = self.channel.clone();
381
382 set_consumer_qos(&mut channel, consumer_config.prefetch_count)
384 .await
385 .map_err(|e| {
386 tracing::error!("Failed to set qos: {}", e);
387 "Failed to set qos".to_string()
388 })?;
389
390 let consumer = channel
392 .basic_consume(
393 format!(
394 "{}-{}",
395 consumer_config.queue_name, consumer_config.message_version
396 )
397 .as_str(),
398 &consumer_config.consumer_tag,
399 BasicConsumeOptions::default(),
400 FieldTable::default(),
401 )
402 .await
403 .map_err(|e| {
404 tracing::error!("Failed to start consumer: {}", e);
405 "Failed to start consumer".to_string()
406 })?;
407
408 let handler = Arc::clone(handler);
409
410 tracing::info!(
411 "Started listening for incoming messages on queue: {} | Non-rpc",
412 consumer.queue().as_str()
413 );
414
415 listeners.push(tokio::spawn(async move {
416 consumer
417 .for_each_concurrent(None, move |delivery| {
418 let handler = Arc::clone(&handler);
419 async move {
420 match delivery {
421 Err(error) => {
422 tracing::warn!("Received bad msg: {}", error);
423 }
424 Ok(delivery) => {
425 handler(delivery).await;
426 }
427 }
428 }
429 })
430 .await;
431 }));
432 }
433
434 for (handler, consumer_config) in self.rpc_handlers.iter().zip(self.rpc_consumers.iter()) {
435 let mut channel = self.channel.clone();
436 set_consumer_qos(&mut channel, consumer_config.prefetch_count)
438 .await
439 .map_err(|e| {
440 tracing::error!("Failed to set qos: {}", e);
441 "Failed to set qos".to_string()
442 })?;
443
444 let consumer = channel
446 .basic_consume(
447 format!(
448 "{}-{}",
449 consumer_config.queue_name, consumer_config.message_version
450 )
451 .as_str(),
452 &consumer_config.consumer_tag,
453 BasicConsumeOptions::default(),
454 FieldTable::default(),
455 )
456 .await
457 .map_err(|e| {
458 tracing::error!("Failed to start consumer: {}", e);
459 "Failed to start consumer".to_string()
460 })?;
461 let handler = Arc::clone(handler);
462
463 tracing::debug!(
464 "Started listening for incoming messages on queue: {} | RPC",
465 consumer.queue().as_str()
466 );
467 listeners.push(tokio::spawn(async move {
468 consumer
469 .for_each_concurrent(None, move |delivery| {
470 let handler = Arc::clone(&handler);
471 async move {
472 match delivery {
473 Err(error) => {
474 tracing::warn!("Received bad msg: {}", error);
475 }
476 Ok(delivery) => {
477 handler(delivery).await;
478 }
479 }
480 }
481 })
482 .await;
483 }));
484 }
485
486 join_all(listeners).await;
487 Ok(())
488 }
489}
490
491pub trait RPCTask: Sized + Debug + DeserializeOwned {
519 type Result: Serialize + DeserializeOwned + Debug + Clone;
520 type ErroredResult: Serialize + DeserializeOwned + Debug + Clone;
521 type State: Clone + Debug;
522
523 fn decode(data: Vec<u8>) -> Result<Self, RabbitDecodeError> {
525 let job = match from_utf8(&data) {
526 Err(_) => {
527 return Err(RabbitDecodeError::NotUtf8);
528 }
529 Ok(data) => match serde_json::from_str::<Self>(data) {
530 Err(_) => return Err(RabbitDecodeError::NotJson),
531 Ok(data) => data,
532 },
533 };
534 Ok(job)
535 }
536
537 fn run(
539 self,
540 state: Arc<Self::State>,
541 ) -> BoxFuture<'static, Result<Self::Result, Self::ErroredResult>>;
542
543 fn display(&self) -> String {
545 format!("{:?}", self)
546 }
547 fn before_job(self) -> impl std::future::Future<Output = Self> + Send
549 where
550 Self: Send,
551 {
552 async move { self }
553 }
554 fn after_job(
556 res: Result<Result<Self::Result, Self::ErroredResult>, JoinError>,
557 ) -> impl std::future::Future<Output = ()> + Send
558 where
559 Self: Send,
560 {
561 async move {}
562 }
563}
564
565pub trait Task: Sized + Debug + DeserializeOwned {
591 type State: Clone + Debug;
592
593 fn decode(data: Vec<u8>) -> Result<Self, RabbitDecodeError> {
594 let job = match from_utf8(&data) {
595 Err(_) => {
596 return Err(RabbitDecodeError::NotUtf8);
597 }
598 Ok(data) => match serde_json::from_str::<Self>(data) {
599 Err(e) => {
600 tracing::error!("Failed to decode job: {e} \n {:?}", data);
601 return Err(RabbitDecodeError::NotJson);
602 }
603 Ok(data) => data,
604 },
605 };
606 Ok(job)
607 }
608
609 fn run(self, state: Arc<Self::State>) -> BoxFuture<'static, Result<(), ()>>;
612
613 fn display(&self) -> String {
615 format!("{:?}", self)
616 }
617
618 fn before_job(self) -> impl std::future::Future<Output = Self> + Send
620 where
621 Self: Send,
622 {
623 async move { self }
624 }
625 fn after_job(self) -> impl std::future::Future<Output = ()> + Send
627 where
628 Self: Sync + Send,
629 {
630 async move {}
631 }
632}
633
/// Reasons a delivery body could not be turned into a job.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RabbitDecodeError {
    /// The body was UTF-8 but not valid JSON for the job type.
    NotJson,
    /// A field held an invalid value.
    InvalidField,
    /// The body was not valid UTF-8.
    NotUtf8,
}

impl std::fmt::Display for RabbitDecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::NotJson => "payload is not valid JSON",
            Self::InvalidField => "payload contains an invalid field",
            Self::NotUtf8 => "payload is not valid UTF-8",
        })
    }
}

/// Lets callers use `?` into `Box<dyn Error>` / error-chaining crates.
impl std::error::Error for RabbitDecodeError {}
641
642async fn respond_to_rpc_queue(
643 channel: Channel,
644 routing_key: ShortString,
645 headers: FieldTable,
646 correlation_id: ShortString,
647 body: Option<impl Serialize + DeserializeOwned>,
648) {
649 match body {
650 None => {
651 let _ = channel
653 .basic_publish(
654 "",
655 &routing_key.to_string().as_str(),
656 BasicPublishOptions::default(),
657 "".as_bytes(),
658 BasicProperties::default()
659 .with_correlation_id(correlation_id)
660 .with_headers(headers),
661 )
662 .await;
663 }
664 Some(body) => {
665 let _ = channel
667 .basic_publish(
668 "",
669 &routing_key.to_string().as_str(),
670 BasicPublishOptions::default(),
671 serde_json::to_string(&body).unwrap().as_bytes(),
672 BasicProperties::default()
673 .with_correlation_id(correlation_id)
674 .with_headers(headers),
675 )
676 .await;
677 }
678 }
679}
680
681async fn nack(channel: Channel, delivery_tag: DeliveryTag) {
682 let asd = BasicNackOptions::default();
683
684 match channel.basic_nack(delivery_tag, asd).await {
685 Err(error) => {
686 tracing::warn!(
687 "Failed to nack to server about delivery tag: {}, {}",
688 delivery_tag,
689 error
690 )
691 }
692 Ok(_) => {
693 tracing::debug!("Sent nack back to server")
694 }
695 }
696}
697fn create_header(header: ResultHeader) -> FieldTable {
698 let mut headers = FieldTable::default();
699 headers.insert(
700 "outcome".into(),
701 lapin::types::AMQPValue::LongString(serde_json::to_string(&header).unwrap().into()),
702 );
703 headers
704}
705#[derive(Debug, Serialize, Deserialize)]
708enum ResultHeader {
709 Ok,
710 Error,
711 Panic,
712}
713impl Display for ResultHeader {
714 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
715 match self {
716 Self::Error => write!(f, "Error"),
717 Self::Ok => write!(f, "Ok"),
718 Self::Panic => write!(f, "Panic"),
719 }
720 }
721}
722
723async fn set_consumer_qos(channel: &mut Channel, prefetch_count: u16) -> Result<(), lapin::Error> {
724 channel
725 .basic_qos(prefetch_count, BasicQosOptions::default())
726 .await
727}