1use std::{
2 fmt::{Debug, Display},
3 str::from_utf8,
4 time::{Duration, Instant},
5};
6
7use futures::StreamExt;
8use lapin::{
9 options::*, types::FieldTable, types::ShortString, BasicProperties, Channel, Connection,
10 ConnectionProperties,
11};
12use serde::{de::DeserializeOwned, Serialize};
13use tokio::time::timeout;
14use uuid::Uuid;
15
16use crate::ResultHeader;
17
/// AMQP client that publishes jobs to worker queues and, for RPC calls, waits for the reply.
///
/// Wraps a single broker [`Connection`]; every call opens its own channel on it.
#[derive(Debug)]
pub struct Client {
    /// Broker connection established by `Client::new`.
    conn: Connection,
}
23
24impl Client {
27 pub async fn new(address: &str) -> Result<Self, lapin::Error> {
37 let conn = Connection::connect(address, ConnectionProperties::default()).await?;
38 tracing::trace!("Created a listener");
39 Ok(Self { conn })
40 }
41
42 pub async fn rpc_call<T: RPCClientTask + Send + Debug>(
49 &self,
50 data: T,
51 options: BasicCallOptions,
52 ) -> Result<Result<T::Result, T::ErroredResult>, ClientError>
53 where
54 T: Serialize + DeserializeOwned,
55 {
56 let now = Instant::now();
57 let correlation_id = Uuid::new_v4();
58
59 tracing::debug!("Creating channel");
61 let channel = self.conn.create_channel().await.unwrap();
62 tracing::debug!("Creating callback queue for result/error messages from rpc call");
63 let callback_queue = channel
64 .queue_declare(
65 "",
66 QueueDeclareOptions {
67 exclusive: true,
68 ..Default::default()
69 },
70 create_timeout_headers(),
71 )
72 .await
73 .unwrap();
74
75 tracing::debug!(
76 "Creating consumer to listen for error/result messages {}",
77 callback_queue.name()
78 );
79
80 let mut consumer = channel
81 .basic_consume(
82 callback_queue.name().as_str(),
83 "",
84 BasicConsumeOptions {
85 no_ack: true,
86 ..Default::default()
87 },
88 FieldTable::default(),
89 )
90 .await
91 .unwrap();
92
93 tracing::debug!("Publishing message for {}", callback_queue.name());
94 match channel
95 .basic_publish(
96 "",
97 format!("{}-{}", &options.queue_name, options.message_version).as_str(),
98 BasicPublishOptions::default(),
99 serde_json::to_string(&data).unwrap().as_bytes(),
100 BasicProperties::default()
101 .with_reply_to(callback_queue.name().clone())
102 .with_correlation_id(correlation_id.clone().to_string().into()),
103 )
104 .await
105 {
106 Err(error) => {
107 tracing::error!("Failed to send job to AMQP queue: {}", error)
108 }
109 Ok(confirmation) => match confirmation.await {
110 Err(error) => {
111 tracing::error!("AMQP failed to confirm dispatch of job {error}")
112 }
113 Ok(confirmation) => {
114 tracing::info!(
115 "Sent RPC job of type {} to channel {} Ack: {} Ver: {}",
116 std::any::type_name::<T>(),
117 options.queue_name,
118 confirmation.is_ack(),
119 options.message_version
120 );
121 }
122 },
123 }
124
125 tracing::debug!("Awaiting response from callback queue");
126 let listen = async move {
127 match consumer.next().await {
128 None => {
129 tracing::error!("Received empty data after {:?}", now.elapsed());
130 return Err(ClientError::InvalidResponse);
131 }
132 Some(del) => match del {
133 Err(error) => {
134 tracing::error!(
135 "Received error as response: {} after: {:?}",
136 error,
137 now.elapsed()
138 );
139 return Err(ClientError::FailedDecode);
140 }
141 Ok(del) => {
142 tracing::debug!("Received response after {:?}", now.elapsed());
143 return Ok(del);
144 }
145 },
146 };
147 };
148
149 let del = match options.timeout {
150 None => listen.await?,
151 Some(dur) => match timeout(dur, listen).await {
152 Err(elapsed) => {
153 tracing::warn!("RPC job has reached timeout after: {}", elapsed);
154 return Err(ClientError::TimeoutReached);
155 }
156 Ok(r) => match r {
157 Err(error) => return Err(error),
158 Ok(r) => r,
159 },
160 },
161 };
162
163 tracing::debug!("Decoding headers");
165 let result_type = match del.properties.headers().to_owned() {
166 None => {
167 tracing::error!(
168 "Got a response with no headers, this might be an issue with version mismatch"
169 );
170 return Err(ClientError::InvalidResponse);
171 }
172 Some(headers) => match headers.inner().get("outcome") {
173 None => {
174 tracing::error!("Got a response with no outcome header");
175 return Err(ClientError::InvalidResponse);
176 }
177 Some(res) => match res.as_long_string() {
178 None => {
179 tracing::error!("Got a response with no headers");
180 return Err(ClientError::InvalidResponse);
181 }
182 Some(outcome) => {
183 match serde_json::from_str::<ResultHeader>(outcome.to_string().as_str()) {
184 Ok(result_header) => {
185 tracing::trace!("Result header: {:?}", result_header);
186 result_header
187 }
188 Err(_) => {
189 tracing::warn!("Received a result header but it's not a type that can be deserailized ");
190 return Err(ClientError::InvalidResponse);
191 }
192 }
193 }
194 },
195 },
196 };
197
198 tracing::debug!("Result type is: {result_type}, decoding...");
199 let utf8 = match from_utf8(&del.data) {
200 Ok(r) => r,
201 Err(error) => {
202 tracing::error!("Failed to decode response message to utf8 {error}");
203 return Err(ClientError::FailedDecode);
204 }
205 };
206 let _ = channel.close(0, "byebye").await;
207
208 let _ = del.ack(BasicAckOptions::default()).await;
210
211 match result_type {
212 ResultHeader::Error => match serde_json::from_str::<T::ErroredResult>(utf8) {
213 Err(_) => {
215 tracing::error!("Failed to decode response message to E");
216 return Err(ClientError::FailedDecode);
217 }
218 Ok(res) => return Ok(Err(res)),
219 },
220 ResultHeader::Panic => return Err(ClientError::ServerPanicked),
221 ResultHeader::Ok =>
222 {
224 match serde_json::from_str::<T::Result>(utf8) {
225 Err(_) => {
226 tracing::error!("Failed to decode response message to R");
227 return Err(ClientError::FailedDecode);
228 }
229 Ok(res) => return Ok(Ok(res)),
230 }
231 }
232 }
233 }
234
235 pub async fn call<T>(&self, data: T, options: BasicCallOptions) -> Result<(), ClientError>
241 where
242 T: Serialize + DeserializeOwned,
243 {
244 let channel = self.conn.create_channel().await.unwrap();
245 match channel
246 .basic_publish(
247 "",
248 format!("{}-{}", &options.queue_name, options.message_version).as_str(),
249 BasicPublishOptions::default(),
250 serde_json::to_string(&data).unwrap().as_bytes(),
251 BasicProperties::default(),
252 )
253 .await
254 {
255 Err(error) => {
256 tracing::error!("Failed to send job to AMQP queue: {}", error);
257
258 let _ = channel.close(0, "byebye").await;
259 return Err(ClientError::FailedToSend);
260 }
261 Ok(confirmation) => match confirmation.await {
262 Err(error) => {
263 let _ = channel.close(0, "byebye").await;
264 tracing::error!("AMQP failed to confirm dispatch of job {error}")
265 }
266 Ok(confirmation) => {
267 let _ = channel.close(0, "byebye").await;
268 tracing::info!(
269 "Sent nonRPC job of type {} to channel {} Ack: {} Ver: {}",
270 std::any::type_name::<T>(),
271 options.queue_name,
272 confirmation.is_ack(),
273 options.message_version
274 );
275 tracing::debug!(
276 "AMQP confirmed dispatch of job | Acknowledged? {}",
277 confirmation.is_ack()
278 )
279 }
280 },
281 }
282 Ok(())
283 }
284}
/// Routing and timing options for `Client::rpc_call` and `Client::call`.
///
/// Jobs are published to the queue named `<queue_name>-<message_version>`.
#[derive(Debug, Clone)]
pub struct BasicCallOptions {
    /// How long `rpc_call` waits for a reply; `None` waits indefinitely.
    timeout: Option<Duration>,
    /// Base name of the destination queue.
    queue_name: String,
    /// Version suffix appended to `queue_name`; `"v1.0.0"` unless overridden.
    message_version: String,
}

impl BasicCallOptions {
    /// Creates options for `queue_name` with message version `"v1.0.0"` and no timeout.
    #[must_use]
    pub fn default(queue_name: impl Into<String>) -> Self {
        Self {
            timeout: None,
            queue_name: queue_name.into(),
            message_version: "v1.0.0".into(),
        }
    }

    /// Replaces the message version that is appended to the queue name.
    #[must_use]
    pub fn message_version(mut self, message_version: impl Into<String>) -> Self {
        self.message_version = message_version.into();
        self
    }

    /// Sets how long an RPC call waits for its reply before giving up.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}
312
/// Errors returned by the client's call methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientError {
    /// The reply could not be decoded: the consumer yielded an error, or the body is not
    /// UTF-8 or does not deserialize into the expected type.
    FailedDecode,
    /// The job could not be sent to the broker.
    FailedToSend,
    /// The reply was missing or malformed (e.g. no usable `outcome` header).
    InvalidResponse,
    /// The worker answered with a panic outcome.
    ServerPanicked,
    /// No reply arrived within the configured timeout.
    TimeoutReached,
}

impl Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::FailedDecode => "Failed to decode response",
            Self::FailedToSend => "Failed to send job",
            Self::InvalidResponse => "Received invalid response",
            Self::ServerPanicked => "The server on the other end has panicked",
            Self::TimeoutReached => "The client has reached the timeout",
        };
        f.write_str(message)
    }
}

// Lets `ClientError` flow through `?` into `Box<dyn Error>` and error-reporting libraries.
impl std::error::Error for ClientError {}
333
334pub trait RPCClientTask: Sized + Debug + DeserializeOwned {
357 type Result: Serialize + DeserializeOwned + Debug;
358 type ErroredResult: Serialize + DeserializeOwned + Debug;
359
360 fn display(&self) -> String {
362 format!("{:?}", self)
363 }
364}
365
366fn create_timeout_headers() -> FieldTable {
367 let mut table = FieldTable::default();
368 table.insert("x-expires".into(), lapin::types::AMQPValue::LongInt(6_000));
370 table
371}