1use base64::{
8 alphabet,
9 engine::{general_purpose::PAD, GeneralPurpose},
10 Engine,
11};
12use chrono::{DateTime, Duration, Utc};
13use log::{debug, warn};
14use once_cell::sync::Lazy;
15use serde::{Deserialize, Serialize};
16use serde_json::{from_slice, from_value, json, Value};
17use std::convert::TryFrom;
18use std::process;
19use std::time::SystemTime;
20use uuid::Uuid;
21
22use crate::error::{ContentTypeError, ProtocolError};
23use crate::task::{Signature, Task};
24
/// Base64 engine (standard alphabet, with padding) used to encode message bodies in
/// `Message::json_serialized` and decode them in `Delivery::try_deserialize_message`.
pub(crate) const ENGINE: GeneralPurpose = GeneralPurpose::new(&alphabet::STANDARD, PAD);
26
27static ORIGIN: Lazy<Option<String>> = Lazy::new(|| {
28 hostname::get()
29 .ok()
30 .and_then(|sys_hostname| sys_hostname.into_string().ok())
31 .map(|sys_hostname| format!("gen{}@{}", process::id(), sys_hostname))
32});
33
/// Serialization formats a message body can use.
///
/// `MessageBuilder::content_type` maps the variants to the MIME types
/// `application/json`, `application/x-yaml`, `application/x-python-serialize` and
/// `application/x-msgpack`, respectively. `Json` is the default.
#[derive(Default, Copy, Clone)]
pub enum MessageContentType {
    #[default]
    Json,
    Yaml,
    Pickle,
    MsgPack,
}
43
/// Builder for a [`Message`] carrying task `T`.
///
/// Create one with `MessageBuilder::new`, chain the setters, then call `build`. `build`
/// serializes the task params (if any were set) into the message body.
pub struct MessageBuilder<T>
where
    T: Task,
{
    /// The message under construction; its `raw_body` is only filled in by `build`.
    message: Message,
    /// Task parameters to serialize into the body when `build` runs.
    params: Option<T::Params>,
}
52
53impl<T> MessageBuilder<T>
54where
55 T: Task,
56{
57 pub fn new(id: String) -> Self {
59 Self {
60 message: Message {
61 properties: MessageProperties {
62 correlation_id: id.clone(),
63 content_type: "application/json".into(),
64 content_encoding: "utf-8".into(),
65 reply_to: None,
66 delivery_info: None,
67 },
68 headers: MessageHeaders {
69 id,
70 task: T::NAME.into(),
71 origin: ORIGIN.to_owned(),
72 ..Default::default()
73 },
74 raw_body: Vec::new(),
75 },
76 params: None,
77 }
78 }
79 #[cfg(any(test, feature = "extra_content_types"))]
83 pub fn content_type(mut self, content_type: MessageContentType) -> Self {
84 use MessageContentType::*;
85 let content_type_name = match content_type {
86 Json => "application/json",
87 Yaml => "application/x-yaml",
88 Pickle => "application/x-python-serialize",
89 MsgPack => "application/x-msgpack",
90 };
91 self.message.properties.content_type = content_type_name.into();
92 self
93 }
94
95 pub fn content_encoding(mut self, content_encoding: String) -> Self {
96 self.message.properties.content_encoding = content_encoding;
97 self
98 }
99
100 pub fn correlation_id(mut self, correlation_id: String) -> Self {
101 self.message.properties.correlation_id = correlation_id;
102 self
103 }
104
105 pub fn reply_to(mut self, reply_to: String) -> Self {
106 self.message.properties.reply_to = Some(reply_to);
107 self
108 }
109
110 pub fn delivery_info(mut self, delivery_info: DeliveryInfo) -> Self {
111 self.message.properties.delivery_info = Some(delivery_info);
112 self
113 }
114
115 pub fn id(mut self, id: String) -> Self {
116 self.message.headers.id = id;
117 self
118 }
119
120 pub fn task(mut self, task: String) -> Self {
121 self.message.headers.task = task;
122 self
123 }
124
125 pub fn lang(mut self, lang: String) -> Self {
126 self.message.headers.lang = Some(lang);
127 self
128 }
129
130 pub fn root_id(mut self, root_id: String) -> Self {
131 self.message.headers.root_id = Some(root_id);
132 self
133 }
134
135 pub fn parent_id(mut self, parent_id: String) -> Self {
136 self.message.headers.parent_id = Some(parent_id);
137 self
138 }
139
140 pub fn group(mut self, group: String) -> Self {
141 self.message.headers.group = Some(group);
142 self
143 }
144
145 pub fn meth(mut self, meth: String) -> Self {
146 self.message.headers.meth = Some(meth);
147 self
148 }
149
150 pub fn shadow(mut self, shadow: String) -> Self {
151 self.message.headers.shadow = Some(shadow);
152 self
153 }
154
155 pub fn retries(mut self, retries: u32) -> Self {
156 self.message.headers.retries = Some(retries);
157 self
158 }
159
160 pub fn argsrepr(mut self, argsrepr: String) -> Self {
161 self.message.headers.argsrepr = Some(argsrepr);
162 self
163 }
164
165 pub fn kwargsrepr(mut self, kwargsrepr: String) -> Self {
166 self.message.headers.kwargsrepr = Some(kwargsrepr);
167 self
168 }
169
170 pub fn origin(mut self, origin: String) -> Self {
171 self.message.headers.origin = Some(origin);
172 self
173 }
174
175 pub fn time_limit(mut self, time_limit: u32) -> Self {
176 self.message.headers.timelimit.1 = Some(time_limit);
177 self
178 }
179
180 pub fn hard_time_limit(mut self, time_limit: u32) -> Self {
181 self.message.headers.timelimit.0 = Some(time_limit);
182 self
183 }
184
185 pub fn eta(mut self, eta: DateTime<Utc>) -> Self {
186 self.message.headers.eta = Some(eta);
187 self
188 }
189
190 pub fn countdown(self, countdown: u32) -> Self {
191 let now = DateTime::<Utc>::from(SystemTime::now());
192 let eta = now + Duration::seconds(countdown as i64);
193 self.eta(eta)
194 }
195
196 pub fn expires(mut self, expires: DateTime<Utc>) -> Self {
197 self.message.headers.expires = Some(expires);
198 self
199 }
200
201 pub fn expires_in(self, expires_in: Duration) -> Self {
202 let now = DateTime::<Utc>::from(SystemTime::now());
203 let expires = now + expires_in;
204 self.expires(expires)
205 }
206
207 pub fn params(mut self, params: T::Params) -> Self {
208 self.params = Some(params);
209 self
210 }
211
212 pub fn build(mut self) -> Result<Message, ProtocolError> {
214 if let Some(params) = self.params.take() {
215 let body = MessageBody::<T>::new(params);
216
217 let raw_body = match self.message.properties.content_type.as_str() {
218 "application/json" => serde_json::to_vec(&body)?,
219 #[cfg(any(test, feature = "extra_content_types"))]
220 "application/x-yaml" => {
221 let mut vec = Vec::with_capacity(128);
222 serde_yaml::to_writer(&mut vec, &body)?;
223 vec
224 }
225 #[cfg(any(test, feature = "extra_content_types"))]
226 "application/x-python-serialize" => {
227 serde_pickle::to_vec(&body, serde_pickle::SerOptions::new())?
228 }
229 #[cfg(any(test, feature = "extra_content_types"))]
230 "application/x-msgpack" => rmp_serde::to_vec(&body)?,
231 _ => {
232 return Err(ProtocolError::BodySerializationError(
233 ContentTypeError::Unknown,
234 ));
235 }
236 };
237 self.message.raw_body = raw_body;
238 };
239 Ok(self.message)
240 }
241}
242
/// A Celery task message: broker-facing properties, protocol headers, and the
/// serialized body.
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct Message {
    /// Correlation ID, content type/encoding, reply-to and delivery info.
    pub properties: MessageProperties,

    /// Protocol headers (task ID and name, scheduling, retries, time limits, ...).
    pub headers: MessageHeaders,

    /// Serialized body in the format named by `properties.content_type`; decode it
    /// with [`Message::body`].
    pub raw_body: Vec<u8>,
}
260
261impl Message {
262 pub fn body<T: Task>(&self) -> Result<MessageBody<T>, ProtocolError> {
264 match self.properties.content_type.as_str() {
265 "application/json" => {
266 let value: Value = from_slice(&self.raw_body)?;
267 debug!("Deserialized message body: {:?}", value);
268 if let Value::Array(ref vec) = value {
269 if let [Value::Array(ref args), Value::Object(ref kwargs), Value::Object(ref embed)] =
270 vec[..]
271 {
272 if !args.is_empty() {
273 let mut kwargs = kwargs.clone();
275 let embed = embed.clone();
276 let arg_names = T::ARGS;
277 for (i, arg) in args.iter().enumerate() {
278 if let Some(arg_name) = arg_names.get(i) {
279 kwargs.insert((*arg_name).into(), arg.clone());
280 } else {
281 break;
282 }
283 }
284 return Ok(MessageBody(
285 vec![],
286 from_value::<T::Params>(Value::Object(kwargs))?,
287 from_value::<MessageBodyEmbed>(Value::Object(embed))?,
288 ));
289 }
290 }
291 }
292 Ok(from_value::<MessageBody<T>>(value)?)
293 }
294 #[cfg(any(test, feature = "extra_content_types"))]
295 "application/x-yaml" => {
296 use serde_yaml::{from_slice, from_value, Value};
297 let value: Value = from_slice(&self.raw_body)?;
298 debug!("Deserialized message body: {:?}", value);
299 if let Value::Sequence(ref vec) = value {
300 if let [Value::Sequence(ref args), Value::Mapping(ref kwargs), Value::Mapping(ref embed)] =
301 vec[..]
302 {
303 if !args.is_empty() {
304 let mut kwargs = kwargs.clone();
306 let embed = embed.clone();
307 let arg_names = T::ARGS;
308 for (i, arg) in args.iter().enumerate() {
309 if let Some(arg_name) = arg_names.get(i) {
310 kwargs.insert((*arg_name).into(), arg.clone());
311 } else {
312 break;
313 }
314 }
315 return Ok(MessageBody(
316 vec![],
317 from_value::<T::Params>(Value::Mapping(kwargs))?,
318 from_value::<MessageBodyEmbed>(Value::Mapping(embed))?,
319 ));
320 }
321 }
322 }
323 Ok(from_value(value)?)
324 }
325 #[cfg(any(test, feature = "extra_content_types"))]
326 "application/x-python-serialize" => {
327 use serde_pickle::{from_slice, from_value, DeOptions, HashableValue, Value};
328 let value: Value = from_slice(&self.raw_body, DeOptions::new())?;
329 if let Value::List(ref vec) = value {
331 if let [Value::List(ref args), Value::Dict(ref kwargs), Value::Dict(ref embed)] =
332 vec[..]
333 {
334 if !args.is_empty() {
335 let mut kwargs = kwargs.clone();
337 let embed = embed.clone();
338 let arg_names = T::ARGS;
339 for (i, arg) in args.iter().enumerate() {
340 if let Some(arg_name) = arg_names.get(i) {
341 let key = HashableValue::String((*arg_name).into());
342 kwargs.insert(key, arg.clone());
343 } else {
344 break;
345 }
346 }
347 return Ok(MessageBody(
348 vec![],
349 from_value::<T::Params>(Value::Dict(kwargs))?,
350 from_value::<MessageBodyEmbed>(Value::Dict(embed))?,
351 ));
352 }
353 }
354 }
355 Ok(from_value(value)?)
356 }
357 #[cfg(any(test, feature = "extra_content_types"))]
358 "application/x-msgpack" => {
359 use rmp_serde::from_slice;
360 use rmpv::{ext::from_value, Value};
361 let value: Value = from_slice(&self.raw_body)?;
362 debug!("Deserialized message body: {:?}", value);
363 if let Value::Array(ref vec) = value {
364 if let [Value::Array(ref args), Value::Map(ref kwargs), Value::Map(ref embed)] =
365 vec[..]
366 {
367 if !args.is_empty() {
368 let mut kwargs = kwargs.clone();
370 let embed = embed.clone();
371 let arg_names = T::ARGS;
372 for (i, arg) in args.iter().enumerate() {
373 if let Some(arg_name) = arg_names.get(i) {
374 let existing_entry = kwargs
379 .iter()
380 .enumerate()
381 .filter(|(_, (key, _))| {
382 if let Value::String(key) = key {
383 if let Some(key) = key.as_str() {
384 key == *arg_name
385 } else {
386 false
387 }
388 } else {
389 false
390 }
391 })
392 .map(|(i, _)| i)
393 .next();
394 if let Some(index) = existing_entry {
395 kwargs[index] = ((*arg_name).into(), arg.clone());
396 } else {
397 kwargs.push(((*arg_name).into(), arg.clone()));
398 }
399 } else {
400 break;
401 }
402 }
403 return Ok(MessageBody(
404 vec![],
405 from_value::<T::Params>(Value::Map(kwargs))?,
406 from_value::<MessageBodyEmbed>(Value::Map(embed))?,
407 ));
408 }
409 }
410 }
411 Ok(from_value(value)?)
412 }
413 _ => Err(ProtocolError::BodySerializationError(
414 ContentTypeError::Unknown,
415 )),
416 }
417 }
418
419 pub fn task_id(&self) -> &str {
421 &self.headers.id
422 }
423
424 pub fn json_serialized(
425 &self,
426 delivery_info: Option<DeliveryInfo>,
427 ) -> Result<Vec<u8>, ProtocolError> {
428 let root_id = match &self.headers.root_id {
429 Some(root_id) => json!(root_id.clone()),
430 None => Value::Null,
431 };
432 let reply_to = match &self.properties.reply_to {
433 Some(reply_to) => json!(reply_to.clone()),
434 None => Value::Null,
435 };
436 let eta = match self.headers.eta {
437 Some(time) => json!(time.to_rfc3339()),
438 None => Value::Null,
439 };
440 let expires = match self.headers.expires {
441 Some(time) => json!(time.to_rfc3339()),
442 None => Value::Null,
443 };
444 let mut buffer = Uuid::encode_buffer();
445 let uuid = Uuid::new_v4().hyphenated().encode_lower(&mut buffer);
446 let delivery_tag = uuid.to_owned();
447 let msg_json_value = json!({
448 "body": ENGINE.encode(self.raw_body.clone()),
449 "content-encoding": self.properties.content_encoding.clone(),
450 "content-type": self.properties.content_type.clone(),
451 "headers": {
452 "id": self.headers.id.clone(),
453 "task": self.headers.task.clone(),
454 "lang": self.headers.lang.clone(),
455 "root_id": root_id,
456 "parent_id": self.headers.parent_id.clone(),
457 "group": self.headers.group.clone(),
458 "meth": self.headers.meth.clone(),
459 "shadow": self.headers.shadow.clone(),
460 "eta": eta,
461 "expires": expires,
462 "retries": self.headers.retries.clone(),
463 "timelimit": self.headers.timelimit.clone(),
464 "argsrepr": self.headers.argsrepr.clone(),
465 "kwargsrepr": self.headers.kwargsrepr.clone(),
466 "origin": self.headers.origin.clone()
467 },
468 "properties": json!({
469 "correlation_id": self.properties.correlation_id.clone(),
470 "reply_to": reply_to,
471 "delivery_tag": delivery_tag,
472 "body_encoding": "base64",
473 "delivery_info": self.properties.delivery_info.clone().or(delivery_info).and_then(|i| serde_json::to_value(i).ok()).unwrap_or(Value::Null)
474 })
475 });
476 let res = serde_json::to_string(&msg_json_value)?;
477 Ok(res.into_bytes())
478 }
479}
480
481impl<T> TryFrom<Signature<T>> for Message
482where
483 T: Task,
484{
485 type Error = ProtocolError;
486
487 fn try_from(mut task_sig: Signature<T>) -> Result<Self, Self::Error> {
489 let mut buffer = Uuid::encode_buffer();
491 let uuid = Uuid::new_v4().hyphenated().encode_lower(&mut buffer);
492 let id = uuid.to_owned();
493
494 let mut builder = MessageBuilder::<T>::new(id);
495
496 if let Some(countdown) = task_sig.countdown.take() {
498 builder = builder.countdown(countdown);
499 if task_sig.eta.is_some() {
500 warn!(
501 "Task {} specified both a 'countdown' and an 'eta'. Ignoring 'eta'.",
502 T::NAME
503 )
504 }
505 } else if let Some(eta) = task_sig.eta.take() {
506 builder = builder.eta(eta);
507 }
508
509 match (
512 task_sig.expires_in.take(),
513 task_sig.expires.take(),
514 task_sig.options.expires,
515 ) {
516 (Some(expires_in), None, None) => {
517 builder = builder.expires_in(Duration::seconds(expires_in as i64));
518 }
519 (Some(_), Some(expires), None) => {
520 warn!(
521 "Task {} specified both 'expires_in' and 'expires'. Ignoring 'expires'.",
522 T::NAME
523 );
524 builder = builder.expires(expires);
525 }
526 (None, Some(expires), None) => {
527 builder = builder.expires(expires);
528 }
529 (None, None, Some(expires)) => {
530 builder =
531 builder.expires_in(Duration::from_std(expires).unwrap_or(Duration::zero()))
532 }
533 _ => {}
534 };
535
536 #[cfg(any(test, feature = "extra_content_types"))]
537 if let Some(content_type) = task_sig.options.content_type {
538 builder = builder.content_type(content_type);
539 }
540
541 if let Some(time_limit) = task_sig.options.time_limit.take() {
542 builder = builder.time_limit(time_limit);
543 }
544
545 if let Some(time_limit) = task_sig.options.hard_time_limit.take() {
546 builder = builder.hard_time_limit(time_limit);
547 }
548
549 builder.params(task_sig.params).build()
550 }
551}
552
/// Types that can be turned into a [`Message`] without being consumed.
pub trait TryCreateMessage {
    /// Build a [`Message`] from `self`.
    fn try_create_message(&self) -> Result<Message, ProtocolError>;
}
559
560impl<T> TryCreateMessage for Signature<T>
561where
562 T: Task + Clone,
563{
564 fn try_create_message(&self) -> Result<Message, ProtocolError> {
567 Message::try_from(self.clone())
568 }
569}
570
/// Types (such as a broker [`Delivery`]) that can be decoded into a [`Message`].
pub trait TryDeserializeMessage {
    /// Decode `self` into a [`Message`].
    fn try_deserialize_message(&self) -> Result<Message, ProtocolError>;
}
576
/// Broker-facing properties of a [`Message`].
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct MessageProperties {
    /// Correlation ID; `MessageBuilder::new` initializes it to the task ID.
    pub correlation_id: String,

    /// MIME type of `Message::raw_body` (e.g. `application/json`). It selects the
    /// (de)serializer used by `MessageBuilder::build` and `Message::body`.
    pub content_type: String,

    /// Body character encoding; `MessageBuilder::new` defaults it to `utf-8`.
    pub content_encoding: String,

    /// Optional `reply_to` destination, passed through unchanged.
    pub reply_to: Option<String>,

    /// Optional routing info. When set, it takes precedence over the `delivery_info`
    /// argument of `Message::json_serialized`. Messages produced by
    /// `Delivery::try_deserialize_message` leave it `None`.
    pub delivery_info: Option<DeliveryInfo>,
}
595
/// Exchange / routing-key pair, emitted as the `delivery_info` property by
/// `Message::json_serialized`.
#[derive(Eq, PartialEq, Debug, Clone, Serialize)]
pub struct DeliveryInfo {
    pub exchange: String,
    pub routing_key: String,
}
604
605impl DeliveryInfo {
606 pub fn for_redis_default() -> Self {
607 Self {
608 exchange: String::new(),
609 routing_key: "celery".to_string(),
610 }
611 }
612}
613
/// Celery protocol headers carried by a [`Message`].
///
/// Deserialized from the `headers` object of a [`Delivery`].
/// `Message::json_serialized` writes every field back out under the same key.
#[derive(Eq, PartialEq, Debug, Default, Deserialize, Clone)]
pub struct MessageHeaders {
    /// Task ID. Messages created from a `Signature` get a random UUID v4 here.
    pub id: String,

    /// Task name; `MessageBuilder::new` initializes it to `T::NAME`.
    pub task: String,

    /// Optional `lang` header, passed through unchanged.
    pub lang: Option<String>,

    /// Optional `root_id` header. Presumably the ID of the first task in a workflow;
    /// it is not interpreted by this module.
    pub root_id: Option<String>,

    /// Optional `parent_id` header. Presumably the ID of the task that spawned this
    /// one; it is not interpreted by this module.
    pub parent_id: Option<String>,

    /// Optional `group` header, passed through unchanged.
    pub group: Option<String>,

    /// Optional `meth` header, passed through unchanged.
    pub meth: Option<String>,

    /// Optional `shadow` header, passed through unchanged.
    pub shadow: Option<String>,

    /// Scheduled time, set by `MessageBuilder::eta` or `MessageBuilder::countdown`.
    /// Written as RFC 3339 by `Message::json_serialized`.
    pub eta: Option<DateTime<Utc>>,

    /// Expiration time, set by `MessageBuilder::expires` or
    /// `MessageBuilder::expires_in`. Written as RFC 3339 by `Message::json_serialized`.
    pub expires: Option<DateTime<Utc>>,

    /// Optional `retries` header.
    pub retries: Option<u32>,

    /// Time limits as `(hard, soft)`. `MessageBuilder::hard_time_limit` writes index 0
    /// and `MessageBuilder::time_limit` writes index 1. Serialized as a two-element
    /// array.
    pub timelimit: (Option<u32>, Option<u32>),

    /// Optional `argsrepr` header; presumably a display string for positional args.
    pub argsrepr: Option<String>,

    /// Optional `kwargsrepr` header; presumably a display string for keyword args.
    pub kwargsrepr: Option<String>,

    /// Sender origin. `MessageBuilder::new` defaults it to `gen<pid>@<hostname>` when
    /// the hostname is available.
    pub origin: Option<String>,
}
666
/// Message body for task `T`, laid out as the triple `(args, params, embed)`.
///
/// Positional args are not carried in the first field: `MessageBody::new` leaves it
/// empty, and `Message::body` folds incoming positional args into `params`.
#[derive(Eq, PartialEq, Debug, Serialize, Deserialize)]
pub struct MessageBody<T: Task>(Vec<u8>, pub(crate) T::Params, pub(crate) MessageBodyEmbed);
671
672impl<T> MessageBody<T>
673where
674 T: Task,
675{
676 pub fn new(params: T::Params) -> Self {
677 Self(vec![], params, MessageBodyEmbed::default())
678 }
679
680 pub fn parts(self) -> (T::Params, MessageBodyEmbed) {
681 (self.1, self.2)
682 }
683}
684
/// Workflow data carried as the third element of a message body.
///
/// Every field is `#[serde(default)]`, so a key missing from an incoming body
/// deserializes as `None`.
#[derive(Eq, PartialEq, Debug, Default, Serialize, Deserialize)]
pub struct MessageBodyEmbed {
    /// `callbacks` entry; not interpreted by this module.
    #[serde(default)]
    pub callbacks: Option<Vec<String>>,

    /// `errbacks` entry; not interpreted by this module.
    #[serde(default)]
    pub errbacks: Option<Vec<String>>,

    /// `chain` entry; not interpreted by this module.
    #[serde(default)]
    pub chain: Option<Vec<String>>,

    /// `chord` entry; not interpreted by this module.
    #[serde(default)]
    pub chord: Option<String>,
}
708
/// Encoding of the `body` string in a [`Delivery`]. It is deserialized from snake_case
/// names, so `Base64` is read from `"base64"`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum BodyEncoding {
    Base64,
}
/// The `properties` object of a [`Delivery`].
#[derive(Debug, Clone, Deserialize)]
pub struct DeliveryProperties {
    pub correlation_id: String,
    pub reply_to: Option<String>,
    pub delivery_tag: String,
    /// How `Delivery::body` is encoded.
    pub body_encoding: BodyEncoding,
}
721
/// JSON message envelope with the same top-level keys that `Message::json_serialized`
/// emits: `body`, `content-encoding`, `content-type`, `headers` and `properties`.
#[derive(Debug, Deserialize, Clone)]
pub struct Delivery {
    /// Encoded message body; see `properties.body_encoding`.
    pub body: String,
    #[serde(rename = "content-encoding")]
    pub content_encoding: String,
    #[serde(rename = "content-type")]
    pub content_type: String,
    pub headers: MessageHeaders,
    pub properties: DeliveryProperties,
}
732
733impl TryDeserializeMessage for Delivery {
734 fn try_deserialize_message(&self) -> Result<Message, ProtocolError> {
735 let raw_body = match self.properties.body_encoding {
736 BodyEncoding::Base64 => ENGINE
737 .decode(self.body.clone())
738 .map_err(|e| ProtocolError::InvalidProperty(format!("body error: {e}")))?,
739 };
740 Ok(Message {
741 properties: MessageProperties {
742 correlation_id: self.properties.correlation_id.clone(),
743 content_type: self.content_type.clone(),
744 content_encoding: self.content_encoding.clone(),
745 reply_to: self.properties.reply_to.clone(),
746 delivery_info: None,
747 },
748 headers: MessageHeaders {
749 id: self.headers.id.clone(),
750 task: self.headers.task.clone(),
751 lang: self.headers.lang.clone(),
752 root_id: self.headers.root_id.clone(),
753 parent_id: self.headers.parent_id.clone(),
754 group: self.headers.group.clone(),
755 meth: self.headers.meth.clone(),
756 shadow: self.headers.shadow.clone(),
757 eta: self.headers.eta,
758 expires: self.headers.expires,
759 retries: self.headers.retries,
760 timelimit: self.headers.timelimit,
761 argsrepr: self.headers.argsrepr.clone(),
762 kwargsrepr: self.headers.kwargsrepr.clone(),
763 origin: self.headers.origin.clone(),
764 },
765 raw_body,
766 })
767 }
768}
769
770#[cfg(test)]
771mod tests;