bunbun_worker/
lib.rs

1// TODO clean up code...
2use futures::{
3    future::{join_all, BoxFuture},
4    StreamExt, TryStreamExt,
5};
6use lapin::{
7    options::{
8        BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions,
9        BasicQosOptions,
10    },
11    tcp::{OwnedIdentity, OwnedTLSConfig},
12    types::{DeliveryTag, FieldTable, ShortString},
13    BasicProperties, Channel, Connection, ConnectionProperties, Consumer,
14};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use std::{
17    fmt::{Debug, Display},
18    pin::Pin,
19    str::{from_utf8, FromStr},
20    sync::Arc,
21};
22use tokio::task::JoinError;
23
24/// The client module that interacts with the server part of `bunbun-worker`
25pub mod client;
26
27mod test;
28
29/// The worker object that contains all the threads and runners.
30pub struct Worker {
31    channel: Channel,
32    /// A consumer for each rpc handler
33    rpc_consumers: Vec<ListenerConfig>,
34    rpc_handlers: Vec<
35        Arc<
36            dyn Fn(
37                    lapin::message::Delivery,
38                ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
39                + Send
40                + Sync,
41        >,
42    >,
43    /// A consumer for each non-rpc handler
44    consumers: Vec<ListenerConfig>,
45    handlers: Vec<
46        Arc<
47            dyn Fn(
48                    lapin::message::Delivery,
49                ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
50                + Send
51                + Sync,
52        >,
53    >,
54}
55
// Example taken from https://github.com/amqp-rs/lapin/blob/main/examples/client-certificate.rs
/// Custom certificate configuration
///
/// Holds the three pieces of TLS material that are handed over to the AMQP connection
/// when TLS is enabled with a custom configuration.
pub struct TlsConfig {
    /// The certificate chain.
    cert_chain: String,
    /// The client certificate together with its key.
    // NOTE(review): the bytes of this string are used as the identity's DER payload;
    // confirm that binary material survives being carried in a `String`.
    client_cert_and_key: String,
    /// The password protecting `client_cert_and_key`.
    client_cert_and_key_password: String,
}

impl TlsConfig {
    /// Create a custom TLS config
    ///
    /// # Arguments
    /// * `cert_chain` - The certificate chain
    /// * `client_cert_and_key` - The client certificate and its key
    /// * `client_cert_and_key_password` - The password for `client_cert_and_key`
    pub fn new(
        cert_chain: String,
        client_cert_and_key: String,
        client_cert_and_key_password: String,
    ) -> Self {
        TlsConfig {
            client_cert_and_key_password,
            client_cert_and_key,
            cert_chain,
        }
    }
}
78
79#[derive(Debug)]
80/// A worker configuration  
81pub struct WorkerConfig {
82    tls: Option<OwnedTLSConfig>,
83}
84impl WorkerConfig {
85    /// Creates a new worker config
86    pub fn default() -> Self {
87        Self { tls: None }
88    }
89
90    /// Enable secure connection to amqp.
91    ///
92    /// # Arguments
93    /// * `custom_tls` - Optional [`TlsConfig`] (if none defaults to lapins choice)
94    pub fn enable_tls(mut self, custom_tls: Option<TlsConfig>) -> Self {
95        match custom_tls {
96            Some(tls) => {
97                let tls = OwnedTLSConfig {
98                    identity: Some(OwnedIdentity {
99                        der: tls.client_cert_and_key.as_bytes().to_vec(),
100                        password: tls.client_cert_and_key_password,
101                    }),
102                    cert_chain: Some(tls.cert_chain.to_string()),
103                };
104                self.tls = tls.into();
105            }
106            None => self.tls = OwnedTLSConfig::default().into(),
107        }
108        self
109    }
110}
111
/// A listener's configuration.
///
/// The queue that is actually consumed is `{queue_name}-{message_version}`; the two parts
/// are joined when the listeners are started.
#[derive(Debug, Clone)]
pub struct ListenerConfig {
    /// Prefetch count requested before the consumer is started.
    prefetch_count: u16,
    /// Queue name without the version suffix.
    queue_name: String,
    /// Consumer tag passed along when the consumer is created.
    consumer_tag: String,
    /// Version suffix appended to `queue_name`.
    message_version: String,
}

impl ListenerConfig {
    /// Create a new listener config
    ///
    /// Defaults: prefetch count `0`, an empty consumer tag and message version `v1.0.0`.
    ///
    /// # Arguments
    /// * `queue_name` - The name of the queue to listen to, without the version suffix
    ///   (e.g. `service-serviceJobName`, which is consumed as `service-serviceJobName-v1.0.0`)
    #[must_use]
    pub fn default(queue_name: impl Into<String>) -> Self {
        Self {
            prefetch_count: 0,
            queue_name: queue_name.into(),
            consumer_tag: String::new(),
            message_version: "v1.0.0".into(),
        }
    }

    /// Set the prefetch count for the listener
    ///
    /// This serves as a maximum job count that can be processed at a time. (0 is unlimited)
    #[must_use]
    pub fn set_prefetch_count(mut self, prefetch_count: u16) -> Self {
        self.prefetch_count = prefetch_count;
        self
    }

    /// Set the consumer tag for the listener
    #[must_use]
    pub fn set_consumer_tag(mut self, consumer_tag: impl Into<String>) -> Self {
        self.consumer_tag = consumer_tag.into();
        self
    }

    /// Set the message version that is appended to the queue name (e.g. `v1.0.0`)
    #[must_use]
    pub fn set_message_version(mut self, version: impl Into<String>) -> Self {
        self.message_version = version.into();
        self
    }
}
151
152// TODO implement reconnect
153impl Worker {
154    /// Create a new instance of [`Worker`]
155    /// # Arguments
156    /// * `amqp_server_url` - A string slice that holds the url of the amqp server (e.g. amqp://localhost:5672)
157    /// * `config` - A [`WorkerConfig`], containing the TLS config for now.
158    pub async fn new(amqp_server_url: impl Into<String>, config: WorkerConfig) -> Self {
159        let channel = Self::create_channel(amqp_server_url.into(), config).await;
160
161        Worker {
162            channel,
163            handlers: Vec::new(),
164            consumers: Vec::new(),
165
166            rpc_handlers: Vec::new(),
167            rpc_consumers: Vec::new(),
168        }
169    }
170
171    async fn create_channel(amqp_server_url: String, config: WorkerConfig) -> Channel {
172        // TODO handle unwraps
173        let channel = match config.tls {
174            None => Connection::connect(&amqp_server_url, ConnectionProperties::default())
175                .await
176                .expect("connection error")
177                .create_channel()
178                .await
179                .unwrap(),
180            Some(tls) => Connection::connect_uri_with_config(
181                lapin::uri::AMQPUri::from_str(&amqp_server_url).unwrap(),
182                ConnectionProperties::default(),
183                tls,
184            )
185            .await
186            .unwrap()
187            .create_channel()
188            .await
189            .unwrap(),
190        };
191        channel
192    }
193
194    /// Add a [`Task`] listener to the worker object
195    ///
196    /// # Arguments
197    /// * `state` - An Arc of the state object that will be passed to the listener
198    /// * `listener_config` -  [`ListenerConfig`] that holds the configuration for the listener
199    /// ```
200    ///    use bunbun_worker::{Worker, ListenerConfig, WorkerConfig};
201    ///    let server = Worker::new("amqp://localhost:5672", Workerconfig::default()).await;
202    ///    server.add_non_rpc_consumer::<MyTask>(ListenerConfig::default("service-jobname").set_message_version("v2.0.0") ));
203    ///    server.start_all_listeners().await;
204    /// ```
205    pub fn add_non_rpc_consumer<J: Task + 'static + Send>(
206        &mut self,
207        state: Arc<J::State>,
208        listener_config: ListenerConfig,
209    ) where
210        <J as Task>::State: std::marker::Send + Sync,
211    {
212        let handler: Arc<
213            dyn Fn(
214                    lapin::message::Delivery,
215                ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
216                + Send
217                + Sync,
218        > = Arc::new(move |delivery: lapin::message::Delivery| {
219            let state = Arc::clone(&state);
220            Box::pin(async move {
221                if let Ok(job) = J::decode(delivery.data.clone()) {
222                    // Running job
223
224                    tracing::debug!("Running before job");
225                    let job = match tokio::task::spawn(async move { job.before_job().await }).await
226                    {
227                        Err(error) => {
228                            tracing::error!(
229                                "The before_job function has failed for a job of type: {}, {}",
230                                std::any::type_name::<J>(),
231                                error
232                            );
233                            return ();
234                        }
235                        Ok(j) => j,
236                    };
237
238                    match tokio::task::spawn(async move { job.run(state).await }).await {
239                        Err(error) => {
240                            tracing::error!("Failed to run task job: {}", error);
241                            let _ = delivery.nack(BasicNackOptions::default()).await;
242                        }
243                        Ok(_) => {
244                            tracing::info!("Non-rpc job has finished.");
245                            let _ = delivery.ack(BasicAckOptions::default()).await;
246                        }
247                    };
248                    // TODO run afterjob
249                } else {
250                    delivery.nack(BasicNackOptions::default()).await.unwrap();
251                }
252            })
253        });
254
255        self.handlers.push(handler);
256        self.consumers.push(listener_config);
257    }
258    /// Add an rpc job listener to the worker object
259    /// Make sure the type you pass in implements RPCTask
260    /// # Arguments
261    /// * `queue_name` - A string slice that holds the name of the queue to listen to (e.g. service-serviceJobName-v1.0.0)
262    /// * `state` - An Arc of the state object that will be passed to the listener
263    ///
264    ///
265    /// # Examples
266    ///
267    /// ```
268    ///    use bunbun_worker::{Worker, ListenerConfig, WorkerConfig};
269    ///    let server = Worker::new("amqp://localhost:5672", Workerconfig::default()).await;
270    ///    server.add_rpc_consumer::<MyRPCTask>(ListenerConfig::default("service-jobname").set_message_version("v2.0.0") ));
271    ///    server.start_all_listeners().await;
272    /// ```
273    pub fn add_rpc_consumer<J: RPCTask + 'static + Send>(
274        &mut self,
275        state: Arc<J::State>,
276        listener_config: ListenerConfig,
277    ) where
278        <J as RPCTask>::State: std::marker::Send + Sync,
279        <J as RPCTask>::Result: std::marker::Send + Sync,
280        <J as RPCTask>::ErroredResult: std::marker::Send + Sync,
281    {
282        let channel = self.channel.clone();
283        let handler: Arc<
284            dyn Fn(
285                    lapin::message::Delivery,
286                ) -> Pin<Box<dyn std::future::Future<Output = ()> + Send>>
287                + Send
288                + Sync,
289        > = Arc::new(move |delivery: lapin::message::Delivery| {
290            let state = Arc::clone(&state);
291            let channel = channel.clone();
292            Box::pin(async move {
293                let routing_key = match delivery.properties.reply_to().as_ref() {
294                    Some(key) => key.clone().to_owned(),
295                    None => {
296                        tracing::warn!("Received a job with no reply_to!");
297                        tracing::trace!("No reply_to for job {:?}, skipping loop", delivery);
298                        let _ = nack(channel.clone(), delivery.delivery_tag).await;
299                        return ();
300                    }
301                };
302
303                let correlation_id = match delivery.properties.correlation_id().clone() {
304                    None => {
305                        tracing::warn!("received a job with no correlation id");
306                        tracing::trace!("no correlation id for delivery {:?}", delivery);
307                        let _ = nack(channel.clone(), delivery.delivery_tag).await;
308                        return ();
309                    }
310                    Some(id) => id,
311                };
312
313                if let Ok(job) = J::decode(delivery.data.clone()) {
314                    let job = tokio::task::spawn(async move { job.before_job().await })
315                        .await
316                        .unwrap();
317                    let outcome = tokio::task::spawn(async move { job.run(state).await }).await;
318
319                    match outcome {
320                        Err(ref error) => {
321                            tracing::error!("Failed to start thread for worker {}", error);
322                            let headers = create_header(ResultHeader::Panic);
323                            let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery
324                            respond_to_rpc_queue(
325                                channel.clone(),
326                                routing_key,
327                                headers,
328                                correlation_id,
329                                None::<J::ErroredResult>,
330                            )
331                            .await
332                        }
333                        Ok(Ok(ref res)) => {
334                            let headers = create_header(ResultHeader::Ok);
335                            let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery
336                            respond_to_rpc_queue(
337                                channel.clone(),
338                                routing_key,
339                                headers,
340                                correlation_id,
341                                Some(res.clone()),
342                            )
343                            .await
344                        }
345                        Ok(Err(ref err)) => {
346                            //
347                            let headers = create_header(ResultHeader::Ok);
348                            let _ = delivery.ack(BasicAckOptions::default()).await; // acking the delivery
349                            respond_to_rpc_queue(
350                                channel.clone(),
351                                routing_key,
352                                headers,
353                                correlation_id,
354                                Some(err.clone()),
355                            )
356                            .await
357                        }
358                    };
359                    tracing::debug!("Running after job");
360                    let _ = tokio::task::spawn(async move { J::after_job(outcome).await }).await;
361                } else {
362                    delivery.nack(BasicNackOptions::default()).await.unwrap();
363                }
364            })
365        });
366
367        self.rpc_handlers.push(handler);
368        self.rpc_consumers.push(listener_config);
369    }
370
371    /// Start all the listeners added to the worker object
372    // TODO implement reconnect
373    // TODO better error handling
374    pub async fn start_all_listeners(&self) -> Result<(), String> {
375        let mut listeners = vec![];
376
377        // Start all the non-rpc listeners
378        for (handler, consumer_config) in self.handlers.iter().zip(self.consumers.iter()) {
379            // Clone channel
380            let mut channel = self.channel.clone();
381
382            // Set prefetch count
383            set_consumer_qos(&mut channel, consumer_config.prefetch_count)
384                .await
385                .map_err(|e| {
386                    tracing::error!("Failed to set qos: {}", e);
387                    "Failed to set qos".to_string()
388                })?;
389
390            // Create a consumer with modified channel
391            let consumer = channel
392                .basic_consume(
393                    format!(
394                        "{}-{}",
395                        consumer_config.queue_name, consumer_config.message_version
396                    )
397                    .as_str(),
398                    &consumer_config.consumer_tag,
399                    BasicConsumeOptions::default(),
400                    FieldTable::default(),
401                )
402                .await
403                .map_err(|e| {
404                    tracing::error!("Failed to start consumer: {}", e);
405                    "Failed to start consumer".to_string()
406                })?;
407
408            let handler = Arc::clone(handler);
409
410            tracing::info!(
411                "Started listening for incoming messages on queue: {} | Non-rpc",
412                consumer.queue().as_str()
413            );
414
415            listeners.push(tokio::spawn(async move {
416                consumer
417                    .for_each_concurrent(None, move |delivery| {
418                        let handler = Arc::clone(&handler);
419                        async move {
420                            match delivery {
421                                Err(error) => {
422                                    tracing::warn!("Received bad msg: {}", error);
423                                }
424                                Ok(delivery) => {
425                                    handler(delivery).await;
426                                }
427                            }
428                        }
429                    })
430                    .await;
431            }));
432        }
433
434        for (handler, consumer_config) in self.rpc_handlers.iter().zip(self.rpc_consumers.iter()) {
435            let mut channel = self.channel.clone();
436            // Set prefetch count
437            set_consumer_qos(&mut channel, consumer_config.prefetch_count)
438                .await
439                .map_err(|e| {
440                    tracing::error!("Failed to set qos: {}", e);
441                    "Failed to set qos".to_string()
442                })?;
443
444            // Create a consumer with modified channel
445            let consumer = channel
446                .basic_consume(
447                    format!(
448                        "{}-{}",
449                        consumer_config.queue_name, consumer_config.message_version
450                    )
451                    .as_str(),
452                    &consumer_config.consumer_tag,
453                    BasicConsumeOptions::default(),
454                    FieldTable::default(),
455                )
456                .await
457                .map_err(|e| {
458                    tracing::error!("Failed to start consumer: {}", e);
459                    "Failed to start consumer".to_string()
460                })?;
461            let handler = Arc::clone(handler);
462
463            tracing::debug!(
464                "Started listening for incoming messages on queue: {} | RPC",
465                consumer.queue().as_str()
466            );
467            listeners.push(tokio::spawn(async move {
468                consumer
469                    .for_each_concurrent(None, move |delivery| {
470                        let handler = Arc::clone(&handler);
471                        async move {
472                            match delivery {
473                                Err(error) => {
474                                    tracing::warn!("Received bad msg: {}", error);
475                                }
476                                Ok(delivery) => {
477                                    handler(delivery).await;
478                                }
479                            }
480                        }
481                    })
482                    .await;
483            }));
484        }
485
486        join_all(listeners).await;
487        Ok(())
488    }
489}
490
491/// A trait that defines the structure of a task that can be run by the worker  
492/// BoxFuture is from a crate called `futures`  
493///
494/// # Examples
495/// ```
496/// #[derive(Debug, Serialize, Deserialize)]
497/// struct MyRPCTask {
498///    pub name: String,
499/// }
500/// #[derive(Debug, Serialize, Deserialize)]
501/// struct MyRPCTaskResult {
502///    pub something: String,
503/// }
504/// #[derive(Debug, Serialize, Deserialize)]
505/// struct MyRPCTaskErroredResult {
506///    pub what_failed: String,
507/// }
508/// impl RPCTask for MyRPCTask {
509///   type Result = MyRPCTaskResult;
510///   type Error = MyRPCTaskErroredResult;
511///   type State = SomeState;
512///   
513///   fn run(self, state: Arc<Self::State>) -> BoxFuture<'static, Result<Self::Result, Self::Error>> {
514///    async move {
515///       Ok(MyRPCTaskResult{ something: "Hello I ran ok!".into() })
516///   }.boxed()
517/// }
518pub trait RPCTask: Sized + Debug + DeserializeOwned {
519    type Result: Serialize + DeserializeOwned + Debug + Clone;
520    type ErroredResult: Serialize + DeserializeOwned + Debug + Clone;
521    type State: Clone + Debug;
522
523    /// Decoding for the message. Overriding is possible.
524    fn decode(data: Vec<u8>) -> Result<Self, RabbitDecodeError> {
525        let job = match from_utf8(&data) {
526            Err(_) => {
527                return Err(RabbitDecodeError::NotUtf8);
528            }
529            Ok(data) => match serde_json::from_str::<Self>(data) {
530                Err(_) => return Err(RabbitDecodeError::NotJson),
531                Ok(data) => data,
532            },
533        };
534        Ok(job)
535    }
536
537    /// The function that will run once a message is received
538    fn run(
539        self,
540        state: Arc<Self::State>,
541    ) -> BoxFuture<'static, Result<Self::Result, Self::ErroredResult>>;
542
543    /// A function to display the task
544    fn display(&self) -> String {
545        format!("{:?}", self)
546    }
547    /// A function that runs before the job is ran by the worker. This allows you to modify any values inside it, add tracing ect.
548    fn before_job(self) -> impl std::future::Future<Output = Self> + Send
549    where
550        Self: Send,
551    {
552        async move { self }
553    }
554    /// A function that runs after the job has ran and the worker has responded to the callback queue. Any modifications made will not be reflected on the client side.
555    fn after_job(
556        res: Result<Result<Self::Result, Self::ErroredResult>, JoinError>,
557    ) -> impl std::future::Future<Output = ()> + Send
558    where
559        Self: Send,
560    {
561        async move {}
562    }
563}
564
565/// A regular task  
566/// Implement this trait to any struct to make it a runnable `non-rpc` [`Task`] job.
567///
568/// Examples
569///
570/// ```
571///
572/// #[derive(Deserialize, Serialize, Clone, Debug)]
573/// pub struct EmailJob {
574///     send_to: String,
575///     contents: String,
576/// }
577/// impl Task for EmailJob {
578///     type State = State;
579///     fn run(
580///         self,
581///         state: Self::State,
582///     ) -> futures::prelude::future::BoxFuture<'static, Result<(), ()>>
583///     {
584///         Box::pin(async move {
585///             todo!();
586///         })
587///     }
588/// }
589/// ```
590pub trait Task: Sized + Debug + DeserializeOwned {
591    type State: Clone + Debug;
592
593    fn decode(data: Vec<u8>) -> Result<Self, RabbitDecodeError> {
594        let job = match from_utf8(&data) {
595            Err(_) => {
596                return Err(RabbitDecodeError::NotUtf8);
597            }
598            Ok(data) => match serde_json::from_str::<Self>(data) {
599                Err(e) => {
600                    tracing::error!("Failed to decode job: {e} \n {:?}", data);
601                    return Err(RabbitDecodeError::NotJson);
602                }
603                Ok(data) => data,
604            },
605        };
606        Ok(job)
607    }
608
609    // TODO Attribute-Based Extraction
610    /// The method that will be run by the worker
611    fn run(self, state: Arc<Self::State>) -> BoxFuture<'static, Result<(), ()>>;
612
613    /// A function to display the task
614    fn display(&self) -> String {
615        format!("{:?}", self)
616    }
617
618    /// A function that runs before the job is ran by the worker. This allows you to modify any values inside it, add tracing ect.
619    fn before_job(self) -> impl std::future::Future<Output = Self> + Send
620    where
621        Self: Send,
622    {
623        async move { self }
624    }
625    /// A function that runs after the job has finished and the worker has acked the request.
626    fn after_job(self) -> impl std::future::Future<Output = ()> + Send
627    where
628        Self: Sync + Send,
629    {
630        async move {}
631    }
632}
633
/// A decode error
///
/// Returned by the `decode` functions of the task traits when an incoming message body
/// cannot be turned into a job.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RabbitDecodeError {
    /// The body is valid UTF-8 but could not be deserialized from JSON into the job type.
    NotJson,
    /// A field of the message is invalid.
    // NOTE(review): not produced by the default decoders; presumably meant for custom ones.
    InvalidField,
    /// The body is not valid UTF-8.
    NotUtf8,
}

impl std::fmt::Display for RabbitDecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::NotJson => "message body is not valid JSON for the expected type",
            Self::InvalidField => "message contains an invalid field",
            Self::NotUtf8 => "message body is not valid UTF-8",
        })
    }
}

impl std::error::Error for RabbitDecodeError {}
641
642async fn respond_to_rpc_queue(
643    channel: Channel,
644    routing_key: ShortString,
645    headers: FieldTable,
646    correlation_id: ShortString,
647    body: Option<impl Serialize + DeserializeOwned>,
648) {
649    match body {
650        None => {
651            //
652            let _ = channel
653                .basic_publish(
654                    "",
655                    &routing_key.to_string().as_str(),
656                    BasicPublishOptions::default(),
657                    "".as_bytes(),
658                    BasicProperties::default()
659                        .with_correlation_id(correlation_id)
660                        .with_headers(headers),
661                )
662                .await;
663        }
664        Some(body) => {
665            // TODO handle errors? idk
666            let _ = channel
667                .basic_publish(
668                    "",
669                    &routing_key.to_string().as_str(),
670                    BasicPublishOptions::default(),
671                    serde_json::to_string(&body).unwrap().as_bytes(),
672                    BasicProperties::default()
673                        .with_correlation_id(correlation_id)
674                        .with_headers(headers),
675                )
676                .await;
677        }
678    }
679}
680
681async fn nack(channel: Channel, delivery_tag: DeliveryTag) {
682    let asd = BasicNackOptions::default();
683
684    match channel.basic_nack(delivery_tag, asd).await {
685        Err(error) => {
686            tracing::warn!(
687                "Failed to nack to server about delivery tag: {}, {}",
688                delivery_tag,
689                error
690            )
691        }
692        Ok(_) => {
693            tracing::debug!("Sent nack back to server")
694        }
695    }
696}
697fn create_header(header: ResultHeader) -> FieldTable {
698    let mut headers = FieldTable::default();
699    headers.insert(
700        "outcome".into(),
701        lapin::types::AMQPValue::LongString(serde_json::to_string(&header).unwrap().into()),
702    );
703    headers
704}
705/// A result header that is included in the header of the AMQP message.  
706/// It indicates the status of the returned message
707#[derive(Debug, Serialize, Deserialize)]
708enum ResultHeader {
709    Ok,
710    Error,
711    Panic,
712}
713impl Display for ResultHeader {
714    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
715        match self {
716            Self::Error => write!(f, "Error"),
717            Self::Ok => write!(f, "Ok"),
718            Self::Panic => write!(f, "Panic"),
719        }
720    }
721}
722
723async fn set_consumer_qos(channel: &mut Channel, prefetch_count: u16) -> Result<(), lapin::Error> {
724    channel
725        .basic_qos(prefetch_count, BasicQosOptions::default())
726        .await
727}