bunbun_worker/
client.rs

1use std::{
2    fmt::{Debug, Display},
3    str::from_utf8,
4    time::{Duration, Instant},
5};
6
7use futures::StreamExt;
8use lapin::{
9    options::*, types::FieldTable, types::ShortString, BasicProperties, Channel, Connection,
10    ConnectionProperties,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use tokio::time::timeout;
14use uuid::Uuid;
15
16use crate::ResultHeader;
17
/// A client for the server part of `bunbun-worker`
///
/// Wraps a single AMQP connection. Both [`Client::rpc_call`] and [`Client::call`]
/// open their own channel on this connection for every invocation.
#[derive(Debug)]
pub struct Client {
    /// Connection to the AMQP broker, established in [`Client::new`].
    conn: Connection,
}
23
24// TODO implement tls
25
26impl Client {
27    /// Creates an `bunbun-worker` client.
28    ///
29    /// # Examples
30    ///
31    /// ```
32    /// // Create a client and send a message
33    /// use bunbun_worker::client::Client;
34    /// let client = Client::new("amqp://127.0.0.1:5672");
35    /// ```
36    pub async fn new(address: &str) -> Result<Self, lapin::Error> {
37        let conn = Connection::connect(address, ConnectionProperties::default()).await?;
38        tracing::trace!("Created a listener");
39        Ok(Self { conn })
40    }
41
42    /// A method to call for a RPC
43    ///
44    /// Arguments
45    /// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize
46    /// * `options` BasicCallOptions, used to control the timeout and message version
47    // TODO if the queue is nonexistent return error
48    pub async fn rpc_call<T: RPCClientTask + Send + Debug>(
49        &self,
50        data: T,
51        options: BasicCallOptions,
52    ) -> Result<Result<T::Result, T::ErroredResult>, ClientError>
53    where
54        T: Serialize + DeserializeOwned,
55    {
56        let now = Instant::now();
57        let correlation_id = Uuid::new_v4();
58
59        // TODO handle errors
60        tracing::debug!("Creating channel");
61        let channel = self.conn.create_channel().await.unwrap();
62        tracing::debug!("Creating callback queue for result/error messages from rpc call");
63        let callback_queue = channel
64            .queue_declare(
65                "",
66                QueueDeclareOptions {
67                    exclusive: true,
68                    ..Default::default()
69                },
70                create_timeout_headers(),
71            )
72            .await
73            .unwrap();
74
75        tracing::debug!(
76            "Creating consumer to listen for error/result messages {}",
77            callback_queue.name()
78        );
79
80        let mut consumer = channel
81            .basic_consume(
82                callback_queue.name().as_str(),
83                "",
84                BasicConsumeOptions {
85                    no_ack: true,
86                    ..Default::default()
87                },
88                FieldTable::default(),
89            )
90            .await
91            .unwrap();
92
93        tracing::debug!("Publishing message for {}", callback_queue.name());
94        match channel
95            .basic_publish(
96                "",
97                format!("{}-{}", &options.queue_name, options.message_version).as_str(),
98                BasicPublishOptions::default(),
99                serde_json::to_string(&data).unwrap().as_bytes(),
100                BasicProperties::default()
101                    .with_reply_to(callback_queue.name().clone())
102                    .with_correlation_id(correlation_id.clone().to_string().into()),
103            )
104            .await
105        {
106            Err(error) => {
107                tracing::error!("Failed to send job to AMQP queue: {}", error)
108            }
109            Ok(confirmation) => match confirmation.await {
110                Err(error) => {
111                    tracing::error!("AMQP failed to confirm dispatch of job {error}")
112                }
113                Ok(confirmation) => {
114                    tracing::info!(
115                        "Sent RPC job of type {} to channel {} Ack: {} Ver: {}",
116                        std::any::type_name::<T>(),
117                        options.queue_name,
118                        confirmation.is_ack(),
119                        options.message_version
120                    );
121                }
122            },
123        }
124
125        tracing::debug!("Awaiting response from callback queue");
126        let listen = async move {
127            match consumer.next().await {
128                None => {
129                    tracing::error!("Received empty data after {:?}", now.elapsed());
130                    return Err(ClientError::InvalidResponse);
131                }
132                Some(del) => match del {
133                    Err(error) => {
134                        tracing::error!(
135                            "Received error as response: {} after: {:?}",
136                            error,
137                            now.elapsed()
138                        );
139                        return Err(ClientError::FailedDecode);
140                    }
141                    Ok(del) => {
142                        tracing::debug!("Received response after {:?}", now.elapsed());
143                        return Ok(del);
144                    }
145                },
146            };
147        };
148
149        let del = match options.timeout {
150            None => listen.await?,
151            Some(dur) => match timeout(dur, listen).await {
152                Err(elapsed) => {
153                    tracing::warn!("RPC job has reached timeout after: {}", elapsed);
154                    return Err(ClientError::TimeoutReached);
155                }
156                Ok(r) => match r {
157                    Err(error) => return Err(error),
158                    Ok(r) => r,
159                },
160            },
161        };
162
163        // TODO better implementation of this
164        tracing::debug!("Decoding headers");
165        let result_type = match del.properties.headers().to_owned() {
166            None => {
167                tracing::error!(
168                    "Got a response with no headers, this might be an issue with version mismatch"
169                );
170                return Err(ClientError::InvalidResponse);
171            }
172            Some(headers) => match headers.inner().get("outcome") {
173                None => {
174                    tracing::error!("Got a response with no outcome header");
175                    return Err(ClientError::InvalidResponse);
176                }
177                Some(res) => match res.as_long_string() {
178                    None => {
179                        tracing::error!("Got a response with no headers");
180                        return Err(ClientError::InvalidResponse);
181                    }
182                    Some(outcome) => {
183                        match serde_json::from_str::<ResultHeader>(outcome.to_string().as_str()) {
184                            Ok(result_header) => {
185                                tracing::trace!("Result header: {:?}", result_header);
186                                result_header
187                            }
188                            Err(_) => {
189                                tracing::warn!("Received a result header but it's not a type that can be deserailized ");
190                                return Err(ClientError::InvalidResponse);
191                            }
192                        }
193                    }
194                },
195            },
196        };
197
198        tracing::debug!("Result type is: {result_type}, decoding...");
199        let utf8 = match from_utf8(&del.data) {
200            Ok(r) => r,
201            Err(error) => {
202                tracing::error!("Failed to decode response message to utf8 {error}");
203                return Err(ClientError::FailedDecode);
204            }
205        };
206        let _ = channel.close(0, "byebye").await;
207
208        // ack message
209        let _ = del.ack(BasicAckOptions::default()).await;
210
211        match result_type {
212            ResultHeader::Error => match serde_json::from_str::<T::ErroredResult>(utf8) {
213                // get result header
214                Err(_) => {
215                    tracing::error!("Failed to decode response message to E");
216                    return Err(ClientError::FailedDecode);
217                }
218                Ok(res) => return Ok(Err(res)),
219            },
220            ResultHeader::Panic => return Err(ClientError::ServerPanicked),
221            ResultHeader::Ok =>
222            // get result
223            {
224                match serde_json::from_str::<T::Result>(utf8) {
225                    Err(_) => {
226                        tracing::error!("Failed to decode response message to R");
227                        return Err(ClientError::FailedDecode);
228                    }
229                    Ok(res) => return Ok(Ok(res)),
230                }
231            }
232        }
233    }
234
235    /// Sends a basic Task to the queue
236    ///
237    /// Arguments
238    /// * `data` The job that will be sent to the queue, must implement Deserialize and Serialize
239    /// * `options` BasicCallOptions, used to control the timeout and message version
240    pub async fn call<T>(&self, data: T, options: BasicCallOptions) -> Result<(), ClientError>
241    where
242        T: Serialize + DeserializeOwned,
243    {
244        let channel = self.conn.create_channel().await.unwrap();
245        match channel
246            .basic_publish(
247                "",
248                format!("{}-{}", &options.queue_name, options.message_version).as_str(),
249                BasicPublishOptions::default(),
250                serde_json::to_string(&data).unwrap().as_bytes(),
251                BasicProperties::default(),
252            )
253            .await
254        {
255            Err(error) => {
256                tracing::error!("Failed to send job to AMQP queue: {}", error);
257
258                let _ = channel.close(0, "byebye").await;
259                return Err(ClientError::FailedToSend);
260            }
261            Ok(confirmation) => match confirmation.await {
262                Err(error) => {
263                    let _ = channel.close(0, "byebye").await;
264                    tracing::error!("AMQP failed to confirm dispatch of job {error}")
265                }
266                Ok(confirmation) => {
267                    let _ = channel.close(0, "byebye").await;
268                    tracing::info!(
269                        "Sent nonRPC job of type {} to channel {} Ack: {} Ver: {}",
270                        std::any::type_name::<T>(),
271                        options.queue_name,
272                        confirmation.is_ack(),
273                        options.message_version
274                    );
275                    tracing::debug!(
276                        "AMQP confirmed dispatch of job |  Acknowledged? {}",
277                        confirmation.is_ack()
278                    )
279                }
280            },
281        }
282        Ok(())
283    }
284}
/// A call option class that is used to control how calls are handled.
///
/// You can define the timeout, and the message versions.
/// Built with [`BasicCallOptions::default`] and refined with the chained setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicCallOptions {
    /// How long to wait for a reply; `None` means wait without a time limit.
    timeout: Option<Duration>,
    /// Base name of the target queue (the message version is appended when publishing).
    queue_name: String,
    /// Version suffix appended to the queue name, `"v1.0.0"` unless overridden.
    message_version: String,
}
impl BasicCallOptions {
    /// Create a default BasicCallOptions object by using a queue name.
    ///
    /// The defaults are: no timeout and message version `"v1.0.0"`.
    pub fn default(queue_name: impl Into<String>) -> Self {
        Self {
            timeout: None,
            queue_name: queue_name.into(),
            message_version: "v1.0.0".into(),
        }
    }
    /// Set the version of the message, by appending `message_version` after `queue_name`
    pub fn message_version(mut self, message_version: impl Into<String>) -> Self {
        self.message_version = message_version.into();
        self
    }
    /// Set the timeout interval on how long the client shall listen to on the callback queue.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}
312
/// An error that the client returns
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientError {
    /// The reply (or the delivery carrying it) could not be decoded.
    FailedDecode,
    /// The job could not be dispatched to the broker.
    FailedToSend,
    /// The reply was missing, or lacked a usable outcome header.
    InvalidResponse,
    /// The worker reported a panic while handling the job.
    ServerPanicked,
    /// No reply arrived within the configured timeout.
    TimeoutReached,
}
impl Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Match on the reference directly; no copy of the error is needed to format it.
        f.write_str(match self {
            Self::FailedDecode => "Failed to decode response",
            Self::FailedToSend => "Failed to send job",
            Self::InvalidResponse => "Received invalid response",
            Self::ServerPanicked => "The server on the other end has panicked",
            Self::TimeoutReached => "The client has reached the timeout",
        })
    }
}

// Lets callers use `?` into `Box<dyn Error>` and other error-handling helpers.
impl std::error::Error for ClientError {}
333
334/// A Client-side trait that needs to be implemented for a type in order for the client to know return types.
335///
336/// Examples
337/// ```
338///  #[derive(Deserialize, Serialize, Clone, Debug)]
339///  pub struct EmailJob {
340///      send_to: String,
341///      contents: String,
342///  }
343///  #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
344///  pub struct EmailJobResult {
345///      contents: String,
346///  }
347///  #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
348///  pub enum EmailJobResultError {
349///      Errored,
350///  }
351/// impl RPCClientTask for EmailJob {
352///      type ErroredResult = EmailJobResultError;
353///      type Result = EmailJobResult;
354///  }
355/// ```
356pub trait RPCClientTask: Sized + Debug + DeserializeOwned {
357    type Result: Serialize + DeserializeOwned + Debug;
358    type ErroredResult: Serialize + DeserializeOwned + Debug;
359
360    /// A function to display the task
361    fn display(&self) -> String {
362        format!("{:?}", self)
363    }
364}
365
366fn create_timeout_headers() -> FieldTable {
367    let mut table = FieldTable::default();
368    // 60 second expiry, will not start counting down, only once there are no consumers on the channel
369    table.insert("x-expires".into(), lapin::types::AMQPValue::LongInt(6_000));
370    table
371}