1use base64::{
8 alphabet,
9 engine::{general_purpose::PAD, GeneralPurpose},
10 Engine,
11};
12use chrono::{DateTime, Duration, Utc};
13use log::{debug, warn};
14use once_cell::sync::Lazy;
15use serde::{Deserialize, Serialize};
16use serde_json::{from_slice, from_value, json, Value};
17use std::convert::TryFrom;
18use std::process;
19use std::time::SystemTime;
20use uuid::Uuid;
21
22use crate::error::{ContentTypeError, ProtocolError};
23use crate::task::{Signature, Task};
24
25pub(crate) const ENGINE: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, PAD);
26
27static ORIGIN: Lazy<Option<String>> = Lazy::new(|| {
28 hostname::get()
29 .ok()
30 .and_then(|sys_hostname| sys_hostname.into_string().ok())
31 .map(|sys_hostname| format!("gen{}@{}", process::id(), sys_hostname))
32});
33
/// Serialization format for a message body.
///
/// Each variant is mapped to a MIME type by `MessageBuilder::content_type`;
/// `Json` is the default.
#[derive(Default, Copy, Clone)]
pub enum MessageContentType {
    /// `application/json`
    #[default]
    Json,
    /// `application/x-yaml`
    Yaml,
    /// `application/x-python-serialize`
    Pickle,
    /// `application/x-msgpack`
    MsgPack,
}
43
44pub struct MessageBuilder<T>
46where
47 T: Task,
48{
49 message: Message,
50 params: Option<T::Params>,
51}
52
53impl<T> MessageBuilder<T>
54where
55 T: Task,
56{
57 pub fn new(id: String) -> Self {
59 Self {
60 message: Message {
61 properties: MessageProperties {
62 correlation_id: id.clone(),
63 content_type: "application/json".into(),
64 content_encoding: "utf-8".into(),
65 reply_to: None,
66 delivery_info: None,
67 },
68 headers: MessageHeaders {
69 id,
70 task: T::NAME.into(),
71 origin: ORIGIN.to_owned(),
72 retries: Some(0),
73 ..Default::default()
74 },
75 raw_body: Vec::new(),
76 },
77 params: None,
78 }
79 }
80 #[cfg(any(test, feature = "extra_content_types"))]
84 pub fn content_type(mut self, content_type: MessageContentType) -> Self {
85 use MessageContentType::*;
86 let content_type_name = match content_type {
87 Json => "application/json",
88 Yaml => "application/x-yaml",
89 Pickle => "application/x-python-serialize",
90 MsgPack => "application/x-msgpack",
91 };
92 self.message.properties.content_type = content_type_name.into();
93 self
94 }
95
96 pub fn content_encoding(mut self, content_encoding: String) -> Self {
97 self.message.properties.content_encoding = content_encoding;
98 self
99 }
100
101 pub fn correlation_id(mut self, correlation_id: String) -> Self {
102 self.message.properties.correlation_id = correlation_id;
103 self
104 }
105
106 pub fn reply_to(mut self, reply_to: String) -> Self {
107 self.message.properties.reply_to = Some(reply_to);
108 self
109 }
110
111 pub fn delivery_info(mut self, delivery_info: DeliveryInfo) -> Self {
112 self.message.properties.delivery_info = Some(delivery_info);
113 self
114 }
115
116 pub fn id(mut self, id: String) -> Self {
117 self.message.headers.id = id;
118 self
119 }
120
121 pub fn task(mut self, task: String) -> Self {
122 self.message.headers.task = task;
123 self
124 }
125
126 pub fn lang(mut self, lang: String) -> Self {
127 self.message.headers.lang = Some(lang);
128 self
129 }
130
131 pub fn root_id(mut self, root_id: String) -> Self {
132 self.message.headers.root_id = Some(root_id);
133 self
134 }
135
136 pub fn parent_id(mut self, parent_id: String) -> Self {
137 self.message.headers.parent_id = Some(parent_id);
138 self
139 }
140
141 pub fn group(mut self, group: String) -> Self {
142 self.message.headers.group = Some(group);
143 self
144 }
145
146 pub fn meth(mut self, meth: String) -> Self {
147 self.message.headers.meth = Some(meth);
148 self
149 }
150
151 pub fn shadow(mut self, shadow: String) -> Self {
152 self.message.headers.shadow = Some(shadow);
153 self
154 }
155
156 pub fn retries(mut self, retries: u32) -> Self {
157 self.message.headers.retries = Some(retries);
158 self
159 }
160
161 pub fn argsrepr(mut self, argsrepr: String) -> Self {
162 self.message.headers.argsrepr = Some(argsrepr);
163 self
164 }
165
166 pub fn kwargsrepr(mut self, kwargsrepr: String) -> Self {
167 self.message.headers.kwargsrepr = Some(kwargsrepr);
168 self
169 }
170
171 pub fn origin(mut self, origin: String) -> Self {
172 self.message.headers.origin = Some(origin);
173 self
174 }
175
176 pub fn time_limit(mut self, time_limit: u32) -> Self {
177 self.message.headers.timelimit.1 = Some(time_limit);
178 self
179 }
180
181 pub fn hard_time_limit(mut self, time_limit: u32) -> Self {
182 self.message.headers.timelimit.0 = Some(time_limit);
183 self
184 }
185
186 pub fn eta(mut self, eta: DateTime<Utc>) -> Self {
187 self.message.headers.eta = Some(eta);
188 self
189 }
190
191 pub fn countdown(self, countdown: u32) -> Self {
192 let now = DateTime::<Utc>::from(SystemTime::now());
193 let eta = now + Duration::seconds(countdown as i64);
194 self.eta(eta)
195 }
196
197 pub fn expires(mut self, expires: DateTime<Utc>) -> Self {
198 self.message.headers.expires = Some(expires);
199 self
200 }
201
202 pub fn expires_in(self, expires_in: Duration) -> Self {
203 let now = DateTime::<Utc>::from(SystemTime::now());
204 let expires = now + expires_in;
205 self.expires(expires)
206 }
207
208 pub fn params(mut self, params: T::Params) -> Self {
209 self.params = Some(params);
210 self
211 }
212
213 pub fn build(mut self) -> Result<Message, ProtocolError> {
215 if let Some(params) = self.params.take() {
216 let body = MessageBody::<T>::new(params);
217
218 let raw_body = match self.message.properties.content_type.as_str() {
219 "application/json" => serde_json::to_vec(&body)?,
220 #[cfg(any(test, feature = "extra_content_types"))]
221 "application/x-yaml" => {
222 let mut vec = Vec::with_capacity(128);
223 serde_yaml::to_writer(&mut vec, &body)?;
224 vec
225 }
226 #[cfg(any(test, feature = "extra_content_types"))]
227 "application/x-python-serialize" => {
228 serde_pickle::to_vec(&body, serde_pickle::SerOptions::new())?
229 }
230 #[cfg(any(test, feature = "extra_content_types"))]
231 "application/x-msgpack" => rmp_serde::to_vec(&body)?,
232 _ => {
233 return Err(ProtocolError::BodySerializationError(
234 ContentTypeError::Unknown,
235 ));
236 }
237 };
238 self.message.raw_body = raw_body;
239 };
240 Ok(self.message)
241 }
242}
243
244#[derive(Eq, PartialEq, Debug, Clone)]
251pub struct Message {
252 pub properties: MessageProperties,
254
255 pub headers: MessageHeaders,
257
258 pub raw_body: Vec<u8>,
260}
261
262impl Message {
263 pub fn body<T: Task>(&self) -> Result<MessageBody<T>, ProtocolError> {
265 match self.properties.content_type.as_str() {
266 "application/json" => {
267 let value: Value = from_slice(&self.raw_body)?;
268 debug!("Deserialized message body: {:?}", value);
269 if let Value::Array(ref vec) = value {
270 if let [Value::Array(ref args), Value::Object(ref kwargs), Value::Object(ref embed)] =
271 vec[..]
272 {
273 if !args.is_empty() {
274 let mut kwargs = kwargs.clone();
276 let embed = embed.clone();
277 let arg_names = T::ARGS;
278 for (i, arg) in args.iter().enumerate() {
279 if let Some(arg_name) = arg_names.get(i) {
280 kwargs.insert((*arg_name).into(), arg.clone());
281 } else {
282 break;
283 }
284 }
285 return Ok(MessageBody(
286 vec![],
287 from_value::<T::Params>(Value::Object(kwargs))?,
288 from_value::<MessageBodyEmbed>(Value::Object(embed))?,
289 ));
290 }
291 }
292 }
293 Ok(from_value::<MessageBody<T>>(value)?)
294 }
295 #[cfg(any(test, feature = "extra_content_types"))]
296 "application/x-yaml" => {
297 use serde_yaml::{from_slice, from_value, Value};
298 let value: Value = from_slice(&self.raw_body)?;
299 debug!("Deserialized message body: {:?}", value);
300 if let Value::Sequence(ref vec) = value {
301 if let [Value::Sequence(ref args), Value::Mapping(ref kwargs), Value::Mapping(ref embed)] =
302 vec[..]
303 {
304 if !args.is_empty() {
305 let mut kwargs = kwargs.clone();
307 let embed = embed.clone();
308 let arg_names = T::ARGS;
309 for (i, arg) in args.iter().enumerate() {
310 if let Some(arg_name) = arg_names.get(i) {
311 kwargs.insert((*arg_name).into(), arg.clone());
312 } else {
313 break;
314 }
315 }
316 return Ok(MessageBody(
317 vec![],
318 from_value::<T::Params>(Value::Mapping(kwargs))?,
319 from_value::<MessageBodyEmbed>(Value::Mapping(embed))?,
320 ));
321 }
322 }
323 }
324 Ok(from_value(value)?)
325 }
326 #[cfg(any(test, feature = "extra_content_types"))]
327 "application/x-python-serialize" => {
328 use serde_pickle::{from_slice, from_value, DeOptions, HashableValue, Value};
329 let value: Value = from_slice(&self.raw_body, DeOptions::new())?;
330 if let Value::List(ref vec) = value {
332 if let [Value::List(ref args), Value::Dict(ref kwargs), Value::Dict(ref embed)] =
333 vec[..]
334 {
335 if !args.is_empty() {
336 let mut kwargs = kwargs.clone();
338 let embed = embed.clone();
339 let arg_names = T::ARGS;
340 for (i, arg) in args.iter().enumerate() {
341 if let Some(arg_name) = arg_names.get(i) {
342 let key = HashableValue::String((*arg_name).into());
343 kwargs.insert(key, arg.clone());
344 } else {
345 break;
346 }
347 }
348 return Ok(MessageBody(
349 vec![],
350 from_value::<T::Params>(Value::Dict(kwargs))?,
351 from_value::<MessageBodyEmbed>(Value::Dict(embed))?,
352 ));
353 }
354 }
355 }
356 Ok(from_value(value)?)
357 }
358 #[cfg(any(test, feature = "extra_content_types"))]
359 "application/x-msgpack" => {
360 use rmp_serde::from_slice;
361 use rmpv::{ext::from_value, Value};
362 let value: Value = from_slice(&self.raw_body)?;
363 debug!("Deserialized message body: {:?}", value);
364 if let Value::Array(ref vec) = value {
365 if let [Value::Array(ref args), Value::Map(ref kwargs), Value::Map(ref embed)] =
366 vec[..]
367 {
368 if !args.is_empty() {
369 let mut kwargs = kwargs.clone();
371 let embed = embed.clone();
372 let arg_names = T::ARGS;
373 for (i, arg) in args.iter().enumerate() {
374 if let Some(arg_name) = arg_names.get(i) {
375 let existing_entry = kwargs
380 .iter()
381 .enumerate()
382 .filter(|(_, (key, _))| {
383 if let Value::String(key) = key {
384 if let Some(key) = key.as_str() {
385 key == *arg_name
386 } else {
387 false
388 }
389 } else {
390 false
391 }
392 })
393 .map(|(i, _)| i)
394 .next();
395 if let Some(index) = existing_entry {
396 kwargs[index] = ((*arg_name).into(), arg.clone());
397 } else {
398 kwargs.push(((*arg_name).into(), arg.clone()));
399 }
400 } else {
401 break;
402 }
403 }
404 return Ok(MessageBody(
405 vec![],
406 from_value::<T::Params>(Value::Map(kwargs))?,
407 from_value::<MessageBodyEmbed>(Value::Map(embed))?,
408 ));
409 }
410 }
411 }
412 Ok(from_value(value)?)
413 }
414 _ => Err(ProtocolError::BodySerializationError(
415 ContentTypeError::Unknown,
416 )),
417 }
418 }
419
420 pub fn task_id(&self) -> &str {
422 &self.headers.id
423 }
424
425 pub fn json_serialized(
426 &self,
427 delivery_info: Option<DeliveryInfo>,
428 ) -> Result<Vec<u8>, ProtocolError> {
429 let root_id = match &self.headers.root_id {
430 Some(root_id) => json!(root_id.clone()),
431 None => Value::Null,
432 };
433 let reply_to = match &self.properties.reply_to {
434 Some(reply_to) => json!(reply_to.clone()),
435 None => Value::Null,
436 };
437 let eta = match self.headers.eta {
438 Some(time) => json!(time.to_rfc3339()),
439 None => Value::Null,
440 };
441 let expires = match self.headers.expires {
442 Some(time) => json!(time.to_rfc3339()),
443 None => Value::Null,
444 };
445 let mut buffer = Uuid::encode_buffer();
446 let uuid = Uuid::new_v4().hyphenated().encode_lower(&mut buffer);
447 let delivery_tag = uuid.to_owned();
448 let msg_json_value = json!({
449 "body": ENGINE.encode(self.raw_body.clone()),
450 "content-encoding": self.properties.content_encoding.clone(),
451 "content-type": self.properties.content_type.clone(),
452 "headers": {
453 "id": self.headers.id.clone(),
454 "task": self.headers.task.clone(),
455 "lang": self.headers.lang.clone(),
456 "root_id": root_id,
457 "parent_id": self.headers.parent_id.clone(),
458 "group": self.headers.group.clone(),
459 "meth": self.headers.meth.clone(),
460 "shadow": self.headers.shadow.clone(),
461 "eta": eta,
462 "expires": expires,
463 "retries": self.headers.retries.clone(),
464 "timelimit": self.headers.timelimit.clone(),
465 "argsrepr": self.headers.argsrepr.clone(),
466 "kwargsrepr": self.headers.kwargsrepr.clone(),
467 "origin": self.headers.origin.clone()
468 },
469 "properties": json!({
470 "correlation_id": self.properties.correlation_id.clone(),
471 "reply_to": reply_to,
472 "delivery_tag": delivery_tag,
473 "body_encoding": "base64",
474 "delivery_info": self.properties.delivery_info.clone().or(delivery_info).and_then(|i| serde_json::to_value(i).ok()).unwrap_or(Value::Null)
475 })
476 });
477 let res = serde_json::to_string(&msg_json_value)?;
478 Ok(res.into_bytes())
479 }
480}
481
482impl<T> TryFrom<Signature<T>> for Message
483where
484 T: Task,
485{
486 type Error = ProtocolError;
487
488 fn try_from(mut task_sig: Signature<T>) -> Result<Self, Self::Error> {
490 let mut buffer = Uuid::encode_buffer();
492 let uuid = Uuid::new_v4().hyphenated().encode_lower(&mut buffer);
493 let id = uuid.to_owned();
494
495 let mut builder = MessageBuilder::<T>::new(id);
496
497 if let Some(countdown) = task_sig.countdown.take() {
499 builder = builder.countdown(countdown);
500 if task_sig.eta.is_some() {
501 warn!(
502 "Task {} specified both a 'countdown' and an 'eta'. Ignoring 'eta'.",
503 T::NAME
504 )
505 }
506 } else if let Some(eta) = task_sig.eta.take() {
507 builder = builder.eta(eta);
508 }
509
510 match (
513 task_sig.expires_in.take(),
514 task_sig.expires.take(),
515 task_sig.options.expires,
516 ) {
517 (Some(expires_in), None, None) => {
518 builder = builder.expires_in(Duration::seconds(expires_in as i64));
519 }
520 (Some(_), Some(expires), None) => {
521 warn!(
522 "Task {} specified both 'expires_in' and 'expires'. Ignoring 'expires'.",
523 T::NAME
524 );
525 builder = builder.expires(expires);
526 }
527 (None, Some(expires), None) => {
528 builder = builder.expires(expires);
529 }
530 (None, None, Some(expires)) => {
531 builder =
532 builder.expires_in(Duration::from_std(expires).unwrap_or(Duration::zero()))
533 }
534 _ => {}
535 };
536
537 #[cfg(any(test, feature = "extra_content_types"))]
538 if let Some(content_type) = task_sig.options.content_type {
539 builder = builder.content_type(content_type);
540 }
541
542 if let Some(time_limit) = task_sig.options.time_limit.take() {
543 builder = builder.time_limit(time_limit);
544 }
545
546 if let Some(time_limit) = task_sig.options.hard_time_limit.take() {
547 builder = builder.hard_time_limit(time_limit);
548 }
549
550 builder.params(task_sig.params).build()
551 }
552}
553
554pub trait TryCreateMessage {
558 fn try_create_message(&self) -> Result<Message, ProtocolError>;
559}
560
561impl<T> TryCreateMessage for Signature<T>
562where
563 T: Task + Clone,
564{
565 fn try_create_message(&self) -> Result<Message, ProtocolError> {
568 Message::try_from(self.clone())
569 }
570}
571
572pub trait TryDeserializeMessage {
575 fn try_deserialize_message(&self) -> Result<Message, ProtocolError>;
576}
577
578#[derive(Eq, PartialEq, Debug, Clone)]
580pub struct MessageProperties {
581 pub correlation_id: String,
583
584 pub content_type: String,
586
587 pub content_encoding: String,
589
590 pub reply_to: Option<String>,
592
593 pub delivery_info: Option<DeliveryInfo>,
595}
596
597#[derive(Eq, PartialEq, Debug, Clone, Serialize)]
601pub struct DeliveryInfo {
602 pub exchange: String,
603 pub routing_key: String,
604}
605
606impl DeliveryInfo {
607 pub fn for_redis_default() -> Self {
608 Self {
609 exchange: String::new(),
610 routing_key: "celery".to_string(),
611 }
612 }
613}
614
615#[derive(Eq, PartialEq, Debug, Default, Deserialize, Clone)]
617pub struct MessageHeaders {
618 pub id: String,
620
621 pub task: String,
623
624 pub lang: Option<String>,
626
627 pub root_id: Option<String>,
629
630 pub parent_id: Option<String>,
632
633 pub group: Option<String>,
635
636 pub meth: Option<String>,
638
639 pub shadow: Option<String>,
641
642 pub eta: Option<DateTime<Utc>>,
644
645 pub expires: Option<DateTime<Utc>>,
648
649 pub retries: Option<u32>,
651
652 pub timelimit: (Option<u32>, Option<u32>),
657
658 pub argsrepr: Option<String>,
660
661 pub kwargsrepr: Option<String>,
663
664 pub origin: Option<String>,
666}
667
668#[derive(Eq, PartialEq, Debug, Serialize, Deserialize)]
671pub struct MessageBody<T: Task>(Vec<u8>, pub(crate) T::Params, pub(crate) MessageBodyEmbed);
672
673impl<T> MessageBody<T>
674where
675 T: Task,
676{
677 pub fn new(params: T::Params) -> Self {
678 Self(vec![], params, MessageBodyEmbed::default())
679 }
680
681 pub fn parts(self) -> (T::Params, MessageBodyEmbed) {
682 (self.1, self.2)
683 }
684}
685
686#[derive(Eq, PartialEq, Debug, Default, Serialize, Deserialize)]
688pub struct MessageBodyEmbed {
689 #[serde(default)]
691 pub callbacks: Option<Vec<String>>,
692
693 #[serde(default)]
699 pub errbacks: Option<Vec<String>>,
700
701 #[serde(default)]
703 pub chain: Option<Vec<String>>,
704
705 #[serde(default)]
707 pub chord: Option<String>,
708}
709
710#[derive(Debug, Clone, Deserialize)]
711#[serde(rename_all = "snake_case")]
712pub enum BodyEncoding {
713 Base64,
714}
715#[derive(Debug, Clone, Deserialize)]
716pub struct DeliveryProperties {
717 pub correlation_id: String,
718 pub reply_to: Option<String>,
719 pub delivery_tag: String,
720 pub body_encoding: BodyEncoding,
721}
722
723#[derive(Debug, Deserialize, Clone)]
724pub struct Delivery {
725 pub body: String,
726 #[serde(rename = "content-encoding")]
727 pub content_encoding: String,
728 #[serde(rename = "content-type")]
729 pub content_type: String,
730 pub headers: MessageHeaders,
731 pub properties: DeliveryProperties,
732}
733
734impl TryDeserializeMessage for Delivery {
735 fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
736 let raw_body = match self.properties.body_encoding {
737 BodyEncoding::Base64 => ENGINE
738 .decode(self.body.clone())
739 .map_err(|e| ProtocolError::InvalidProperty(format!("body error: {e}")))?,
740 };
741 Ok(Message {
742 properties: MessageProperties {
743 correlation_id: self.properties.correlation_id.clone(),
744 content_type: self.content_type.clone(),
745 content_encoding: self.content_encoding.clone(),
746 reply_to: self.properties.reply_to.clone(),
747 delivery_info: None,
748 },
749 headers: MessageHeaders {
750 id: self.headers.id.clone(),
751 task: self.headers.task.clone(),
752 lang: self.headers.lang.clone(),
753 root_id: self.headers.root_id.clone(),
754 parent_id: self.headers.parent_id.clone(),
755 group: self.headers.group.clone(),
756 meth: self.headers.meth.clone(),
757 shadow: self.headers.shadow.clone(),
758 eta: self.headers.eta,
759 expires: self.headers.expires,
760 retries: self.headers.retries,
761 timelimit: self.headers.timelimit,
762 argsrepr: self.headers.argsrepr.clone(),
763 kwargsrepr: self.headers.kwargsrepr.clone(),
764 origin: self.headers.origin.clone(),
765 },
766 raw_body,
767 })
768 }
769}
770
771#[cfg(test)]
772mod tests;