1use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::time::Duration;
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct CeleryConfig {
31 pub broker_url: String,
33
34 #[serde(skip_serializing_if = "Option::is_none")]
36 pub result_backend: Option<String>,
37
38 #[serde(default = "default_serializer")]
40 pub task_serializer: String,
41
42 #[serde(default = "default_serializer")]
44 pub result_serializer: String,
45
46 #[serde(default = "default_accept_content")]
48 pub accept_content: Vec<String>,
49
50 #[serde(default = "default_timezone")]
52 pub timezone: String,
53
54 #[serde(default = "default_true")]
56 pub enable_utc: bool,
57
58 #[serde(default)]
60 pub task_track_started: bool,
61
62 #[serde(default)]
64 pub task_send_sent_event: bool,
65
66 #[serde(default)]
68 pub task_acks_late: bool,
69
70 #[serde(default)]
72 pub task_reject_on_worker_lost: bool,
73
74 #[serde(default = "default_concurrency")]
76 pub worker_concurrency: usize,
77
78 #[serde(default = "default_prefetch_multiplier")]
80 pub worker_prefetch_multiplier: usize,
81
82 #[serde(skip_serializing_if = "Option::is_none")]
84 pub worker_max_tasks_per_child: Option<usize>,
85
86 #[serde(skip_serializing_if = "Option::is_none")]
88 pub worker_max_memory_per_child: Option<usize>,
89
90 #[serde(default = "default_heartbeat_interval")]
92 pub worker_heartbeat: u64,
93
94 #[serde(default = "default_queue_name")]
96 pub task_default_queue: String,
97
98 #[serde(default = "default_queue_name")]
100 pub task_default_exchange: String,
101
102 #[serde(default = "default_exchange_type")]
104 pub task_default_exchange_type: String,
105
106 #[serde(default = "default_queue_name")]
108 pub task_default_routing_key: String,
109
110 #[serde(default)]
112 pub task_routes: HashMap<String, TaskRoute>,
113
114 #[serde(skip_serializing_if = "Option::is_none")]
116 pub task_time_limit: Option<u64>,
117
118 #[serde(skip_serializing_if = "Option::is_none")]
120 pub task_soft_time_limit: Option<u64>,
121
122 #[serde(default = "default_retry_delay")]
124 pub task_default_retry_delay: u64,
125
126 #[serde(default = "default_max_retries")]
128 pub task_max_retries: u32,
129
130 #[serde(default = "default_result_expires")]
132 pub result_expires: u64,
133
134 #[serde(skip_serializing_if = "Option::is_none")]
136 pub result_compression: Option<String>,
137
138 #[serde(default = "default_compression_threshold")]
140 pub result_compression_threshold: usize,
141
142 #[serde(default)]
144 pub task_annotations: HashMap<String, TaskConfig>,
145
146 #[serde(default)]
148 pub broker_transport_options: BrokerTransport,
149
150 #[serde(default)]
152 pub result_backend_transport_options: BackendTransport,
153
154 #[serde(default)]
156 pub beat_schedule: HashMap<String, BeatSchedule>,
157
158 #[serde(flatten)]
160 pub custom: HashMap<String, serde_json::Value>,
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct TaskRoute {
166 pub queue: String,
168
169 #[serde(skip_serializing_if = "Option::is_none")]
171 pub exchange: Option<String>,
172
173 #[serde(skip_serializing_if = "Option::is_none")]
175 pub routing_key: Option<String>,
176
177 #[serde(skip_serializing_if = "Option::is_none")]
179 pub priority: Option<u8>,
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct TaskConfig {
185 #[serde(skip_serializing_if = "Option::is_none")]
187 pub time_limit: Option<u64>,
188
189 #[serde(skip_serializing_if = "Option::is_none")]
191 pub soft_time_limit: Option<u64>,
192
193 #[serde(skip_serializing_if = "Option::is_none")]
195 pub max_retries: Option<u32>,
196
197 #[serde(skip_serializing_if = "Option::is_none")]
199 pub default_retry_delay: Option<u64>,
200
201 #[serde(skip_serializing_if = "Option::is_none")]
203 pub priority: Option<u8>,
204
205 #[serde(skip_serializing_if = "Option::is_none")]
207 pub queue: Option<String>,
208
209 #[serde(skip_serializing_if = "Option::is_none")]
211 pub acks_late: Option<bool>,
212
213 #[serde(skip_serializing_if = "Option::is_none")]
215 pub track_started: Option<bool>,
216
217 #[serde(skip_serializing_if = "Option::is_none")]
219 pub rate_limit: Option<String>,
220}
221
222#[derive(Debug, Clone, Default, Serialize, Deserialize)]
224pub struct BrokerTransport {
225 #[serde(skip_serializing_if = "Option::is_none")]
227 pub visibility_timeout: Option<u64>,
228
229 #[serde(skip_serializing_if = "Option::is_none")]
231 pub max_connections: Option<usize>,
232
233 #[serde(skip_serializing_if = "Option::is_none")]
235 pub max_retries: Option<u32>,
236
237 #[serde(skip_serializing_if = "Option::is_none")]
239 pub interval_start: Option<u64>,
240
241 #[serde(skip_serializing_if = "Option::is_none")]
243 pub interval_max: Option<u64>,
244
245 #[serde(flatten)]
247 pub custom: HashMap<String, serde_json::Value>,
248}
249
250#[derive(Debug, Clone, Default, Serialize, Deserialize)]
252pub struct BackendTransport {
253 #[serde(skip_serializing_if = "Option::is_none")]
255 pub result_expires: Option<u64>,
256
257 #[serde(skip_serializing_if = "Option::is_none")]
259 pub max_connections: Option<usize>,
260
261 #[serde(flatten)]
263 pub custom: HashMap<String, serde_json::Value>,
264}
265
266#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct BeatSchedule {
269 pub task: String,
271
272 pub schedule: ScheduleDefinition,
274
275 #[serde(skip_serializing_if = "Option::is_none")]
277 pub args: Option<Vec<serde_json::Value>>,
278
279 #[serde(skip_serializing_if = "Option::is_none")]
281 pub kwargs: Option<HashMap<String, serde_json::Value>>,
282
283 #[serde(skip_serializing_if = "Option::is_none")]
285 pub options: Option<TaskConfig>,
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize)]
290#[serde(untagged)]
291pub enum ScheduleDefinition {
292 Crontab(String),
294
295 Interval(u64),
297
298 Complex {
300 #[serde(rename = "type")]
302 schedule_type: String,
303
304 value: serde_json::Value,
306 },
307}
308
309impl Default for CeleryConfig {
310 fn default() -> Self {
311 Self {
312 broker_url: "redis://localhost:6379/0".to_string(),
313 result_backend: Some("redis://localhost:6379/1".to_string()),
314 task_serializer: default_serializer(),
315 result_serializer: default_serializer(),
316 accept_content: default_accept_content(),
317 timezone: default_timezone(),
318 enable_utc: true,
319 task_track_started: false,
320 task_send_sent_event: false,
321 task_acks_late: false,
322 task_reject_on_worker_lost: false,
323 worker_concurrency: default_concurrency(),
324 worker_prefetch_multiplier: default_prefetch_multiplier(),
325 worker_max_tasks_per_child: None,
326 worker_max_memory_per_child: None,
327 worker_heartbeat: default_heartbeat_interval(),
328 task_default_queue: default_queue_name(),
329 task_default_exchange: default_queue_name(),
330 task_default_exchange_type: default_exchange_type(),
331 task_default_routing_key: default_queue_name(),
332 task_routes: HashMap::new(),
333 task_time_limit: None,
334 task_soft_time_limit: None,
335 task_default_retry_delay: default_retry_delay(),
336 task_max_retries: default_max_retries(),
337 result_expires: default_result_expires(),
338 result_compression: None,
339 result_compression_threshold: default_compression_threshold(),
340 task_annotations: HashMap::new(),
341 broker_transport_options: BrokerTransport::default(),
342 result_backend_transport_options: BackendTransport::default(),
343 beat_schedule: HashMap::new(),
344 custom: HashMap::new(),
345 }
346 }
347}
348
349impl CeleryConfig {
350 #[inline]
352 pub fn new(broker_url: impl Into<String>) -> Self {
353 Self {
354 broker_url: broker_url.into(),
355 ..Default::default()
356 }
357 }
358
359 #[inline]
361 #[must_use]
362 pub fn with_broker_url(mut self, url: impl Into<String>) -> Self {
363 self.broker_url = url.into();
364 self
365 }
366
367 #[inline]
369 #[must_use]
370 pub fn with_result_backend(mut self, url: impl Into<String>) -> Self {
371 self.result_backend = Some(url.into());
372 self
373 }
374
375 #[inline]
377 #[must_use]
378 pub fn with_task_serializer(mut self, serializer: impl Into<String>) -> Self {
379 self.task_serializer = serializer.into();
380 self
381 }
382
383 #[inline]
385 #[must_use]
386 pub fn with_result_serializer(mut self, serializer: impl Into<String>) -> Self {
387 self.result_serializer = serializer.into();
388 self
389 }
390
391 #[inline]
393 #[must_use]
394 pub fn with_accept_content(mut self, content: Vec<String>) -> Self {
395 self.accept_content = content;
396 self
397 }
398
399 #[inline]
401 #[must_use]
402 pub fn with_timezone(mut self, tz: impl Into<String>) -> Self {
403 self.timezone = tz.into();
404 self
405 }
406
407 #[must_use]
409 pub const fn with_enable_utc(mut self, enabled: bool) -> Self {
410 self.enable_utc = enabled;
411 self
412 }
413
414 #[must_use]
416 pub const fn with_worker_concurrency(mut self, concurrency: usize) -> Self {
417 self.worker_concurrency = concurrency;
418 self
419 }
420
421 #[must_use]
423 pub const fn with_prefetch_multiplier(mut self, multiplier: usize) -> Self {
424 self.worker_prefetch_multiplier = multiplier;
425 self
426 }
427
428 #[inline]
430 #[must_use]
431 pub fn with_default_queue(mut self, queue: impl Into<String>) -> Self {
432 self.task_default_queue = queue.into();
433 self
434 }
435
436 #[inline]
438 #[must_use]
439 pub fn with_task_route(mut self, task: impl Into<String>, route: TaskRoute) -> Self {
440 self.task_routes.insert(task.into(), route);
441 self
442 }
443
444 #[inline]
446 #[must_use]
447 pub fn with_task_annotation(mut self, task: impl Into<String>, config: TaskConfig) -> Self {
448 self.task_annotations.insert(task.into(), config);
449 self
450 }
451
452 #[must_use]
454 pub const fn with_result_expires(mut self, expires: u64) -> Self {
455 self.result_expires = expires;
456 self
457 }
458
459 #[inline]
461 #[must_use]
462 pub fn with_result_compression(mut self, algorithm: impl Into<String>) -> Self {
463 self.result_compression = Some(algorithm.into());
464 self
465 }
466
467 #[must_use]
469 pub const fn with_compression_threshold(mut self, threshold: usize) -> Self {
470 self.result_compression_threshold = threshold;
471 self
472 }
473
474 #[inline]
476 #[must_use]
477 pub fn with_beat_schedule(mut self, name: impl Into<String>, schedule: BeatSchedule) -> Self {
478 self.beat_schedule.insert(name.into(), schedule);
479 self
480 }
481
482 #[inline]
484 #[must_use]
485 pub fn get_task_config(&self, task_name: &str) -> Option<&TaskConfig> {
486 self.task_annotations.get(task_name)
487 }
488
489 #[inline]
491 #[must_use]
492 pub fn get_task_route(&self, task_name: &str) -> Option<&TaskRoute> {
493 self.task_routes.get(task_name)
494 }
495
496 #[inline]
498 #[must_use]
499 pub const fn result_expires_duration(&self) -> Duration {
500 Duration::from_secs(self.result_expires)
501 }
502
503 #[inline]
505 #[must_use]
506 pub fn task_time_limit_duration(&self) -> Option<Duration> {
507 self.task_time_limit.map(Duration::from_secs)
508 }
509
510 #[inline]
512 #[must_use]
513 pub fn task_soft_time_limit_duration(&self) -> Option<Duration> {
514 self.task_soft_time_limit.map(Duration::from_secs)
515 }
516
517 #[must_use]
519 pub fn from_env() -> Self {
520 let mut config = Self::default();
521
522 if let Ok(url) = std::env::var("CELERY_BROKER_URL") {
523 config.broker_url = url;
524 }
525 if let Ok(backend) = std::env::var("CELERY_RESULT_BACKEND") {
526 config.result_backend = Some(backend);
527 }
528 if let Ok(serializer) = std::env::var("CELERY_TASK_SERIALIZER") {
529 config.task_serializer = serializer;
530 }
531 if let Ok(tz) = std::env::var("CELERY_TIMEZONE") {
532 config.timezone = tz;
533 }
534 if let Ok(concurrency) = std::env::var("CELERYD_CONCURRENCY") {
535 if let Ok(val) = concurrency.parse() {
536 config.worker_concurrency = val;
537 }
538 }
539
540 config
541 }
542
543 pub fn validate(&self) -> Result<(), String> {
549 if self.broker_url.is_empty() {
550 return Err("broker_url is required".to_string());
551 }
552
553 if self.worker_concurrency == 0 {
554 return Err("worker_concurrency must be greater than 0".to_string());
555 }
556
557 if !["json", "msgpack", "yaml", "pickle"].contains(&self.task_serializer.as_str()) {
558 return Err(format!(
559 "Unsupported task_serializer: {}",
560 self.task_serializer
561 ));
562 }
563
564 Ok(())
565 }
566}
567
/// Serde fallback for `task_serializer` and `result_serializer`.
fn default_serializer() -> String {
    String::from("json")
}
572
/// Serde fallback for `accept_content`: JSON first, then msgpack.
fn default_accept_content() -> Vec<String> {
    ["json", "msgpack"]
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
576
/// Serde fallback for `timezone`.
fn default_timezone() -> String {
    String::from("UTC")
}
580
/// Serde fallback for boolean fields that default to on (`enable_utc`). Serde's
/// `default = "..."` needs a function path, so a literal cannot be used directly.
fn default_true() -> bool {
    true
}
584
585fn default_concurrency() -> usize {
586 num_cpus::get()
587}
588
/// Serde fallback for `worker_prefetch_multiplier`.
fn default_prefetch_multiplier() -> usize {
    4
}
592
/// Serde fallback for `worker_heartbeat` (presumably seconds; confirm with the worker).
fn default_heartbeat_interval() -> u64 {
    10
}
596
/// Serde fallback shared by `task_default_queue`, `task_default_exchange` and
/// `task_default_routing_key`.
fn default_queue_name() -> String {
    String::from("celery")
}
600
/// Serde fallback for `task_default_exchange_type`.
fn default_exchange_type() -> String {
    String::from("direct")
}
604
/// Serde fallback for `task_default_retry_delay`: 180, i.e. three minutes if the unit is
/// seconds (presumed; confirm with the retry logic).
fn default_retry_delay() -> u64 {
    3 * 60
}
608
/// Serde fallback for `task_max_retries`.
fn default_max_retries() -> u32 {
    3
}
612
/// Serde fallback for `result_expires`: one day in seconds.
/// `CeleryConfig::result_expires_duration` reads the field as seconds.
fn default_result_expires() -> u64 {
    24 * 60 * 60
}
616
/// Serde fallback for `result_compression_threshold`: 1 MiB, assuming the unit is bytes
/// (not stated in this file; confirm).
fn default_compression_threshold() -> usize {
    1 << 20
}
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = CeleryConfig::default();
        assert_eq!(config.broker_url, "redis://localhost:6379/0");
        assert_eq!(config.task_serializer, "json");
        assert_eq!(config.timezone, "UTC");
        assert!(config.enable_utc);
    }

    #[test]
    fn test_config_builder() {
        let config = CeleryConfig::new("redis://localhost:6379/0")
            .with_result_backend("redis://localhost:6379/1")
            .with_worker_concurrency(8)
            .with_default_queue("my_queue");

        assert_eq!(config.worker_concurrency, 8);
        assert_eq!(config.task_default_queue, "my_queue");
    }

    #[test]
    fn test_config_validation() {
        let config = CeleryConfig::default();
        assert!(config.validate().is_ok());

        let invalid = CeleryConfig {
            broker_url: String::new(),
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_validation_rejects_zero_concurrency() {
        let config = CeleryConfig::default().with_worker_concurrency(0);
        assert_eq!(
            config.validate(),
            Err("worker_concurrency must be greater than 0".to_string())
        );
    }

    #[test]
    fn test_validation_rejects_unknown_task_serializer() {
        let config = CeleryConfig::default().with_task_serializer("xml");
        assert_eq!(
            config.validate(),
            Err("Unsupported task_serializer: xml".to_string())
        );
    }

    #[test]
    fn test_task_route() {
        let route = TaskRoute {
            queue: "high_priority".to_string(),
            exchange: Some("tasks".to_string()),
            routing_key: Some("task.high".to_string()),
            priority: Some(9),
        };

        let config = CeleryConfig::default().with_task_route("important_task", route);

        let found = config
            .get_task_route("important_task")
            .expect("route was registered");
        assert_eq!(found.queue, "high_priority");
        assert!(config.get_task_route("other_task").is_none());
    }

    #[test]
    fn test_task_annotation_lookup() {
        let annotation = TaskConfig {
            time_limit: None,
            soft_time_limit: None,
            max_retries: Some(5),
            default_retry_delay: None,
            priority: None,
            queue: None,
            acks_late: None,
            track_started: None,
            rate_limit: Some("10/m".to_string()),
        };
        let config = CeleryConfig::default().with_task_annotation("tasks.add", annotation);

        let found = config
            .get_task_config("tasks.add")
            .expect("annotation was registered");
        assert_eq!(found.max_retries, Some(5));
        assert_eq!(found.rate_limit.as_deref(), Some("10/m"));
        // Lookups are exact matches on the task name.
        assert!(config.get_task_config("tasks").is_none());
    }

    #[test]
    fn test_duration_conversions() {
        let mut config = CeleryConfig::default();
        assert_eq!(config.result_expires_duration(), Duration::from_secs(86400));
        assert_eq!(config.task_time_limit_duration(), None);
        assert_eq!(config.task_soft_time_limit_duration(), None);

        config.task_time_limit = Some(300);
        config.task_soft_time_limit = Some(240);
        assert_eq!(
            config.task_time_limit_duration(),
            Some(Duration::from_secs(300))
        );
        assert_eq!(
            config.task_soft_time_limit_duration(),
            Some(Duration::from_secs(240))
        );
    }

    #[test]
    fn test_deserialize_fills_defaults_and_keeps_unknown_keys() {
        let config: CeleryConfig = serde_json::from_str(
            r#"{"broker_url": "amqp://guest@localhost//", "worker_pool": "prefork"}"#,
        )
        .expect("minimal config should deserialize");

        assert_eq!(config.broker_url, "amqp://guest@localhost//");
        // Unlike `CeleryConfig::default()`, a missing backend deserializes to `None`.
        assert_eq!(config.result_backend, None);
        assert_eq!(config.task_serializer, "json");
        assert_eq!(config.accept_content, vec!["json", "msgpack"]);
        assert!(config.enable_utc);
        assert_eq!(config.task_default_queue, "celery");
        assert_eq!(config.task_default_exchange_type, "direct");
        assert_eq!(
            config.custom.get("worker_pool"),
            Some(&serde_json::Value::String("prefork".to_string()))
        );
    }

    #[test]
    fn test_deserialize_requires_broker_url() {
        assert!(serde_json::from_str::<CeleryConfig>("{}").is_err());
    }

    #[test]
    fn test_round_trip_preserves_config() {
        let original = CeleryConfig::default()
            .with_worker_concurrency(3)
            .with_default_queue("jobs");

        let json = serde_json::to_string(&original).expect("config should serialize");
        let restored: CeleryConfig =
            serde_json::from_str(&json).expect("config should deserialize");

        assert_eq!(restored.broker_url, original.broker_url);
        assert_eq!(restored.result_backend, original.result_backend);
        assert_eq!(restored.worker_concurrency, 3);
        assert_eq!(restored.task_default_queue, "jobs");
        assert!(restored.custom.is_empty());
    }

    #[test]
    fn test_schedule_definition_untagged_variants() {
        let crontab: ScheduleDefinition =
            serde_json::from_str(r#""*/5 * * * *""#).expect("string schedule");
        assert!(matches!(
            crontab,
            ScheduleDefinition::Crontab(ref expr) if expr == "*/5 * * * *"
        ));

        let interval: ScheduleDefinition = serde_json::from_str("30").expect("integer schedule");
        assert!(matches!(interval, ScheduleDefinition::Interval(30)));

        let complex: ScheduleDefinition =
            serde_json::from_str(r#"{"type": "solar", "value": {"event": "sunset"}}"#)
                .expect("object schedule");
        assert!(matches!(
            complex,
            ScheduleDefinition::Complex { ref schedule_type, .. } if schedule_type == "solar"
        ));
    }

    #[test]
    fn test_beat_schedule_deserialization() {
        let entry: BeatSchedule = serde_json::from_str(
            r#"{"task": "tasks.cleanup", "schedule": 3600, "args": [1, "x"]}"#,
        )
        .expect("beat entry should deserialize");

        assert_eq!(entry.task, "tasks.cleanup");
        assert!(matches!(entry.schedule, ScheduleDefinition::Interval(3600)));
        assert_eq!(entry.args.map(|args| args.len()), Some(2));
        assert!(entry.kwargs.is_none());
        assert!(entry.options.is_none());
    }
}