celers_core/
config.rs

1//! Celery-compatible configuration for `CeleRS`
2//!
3//! This module provides configuration structures that are compatible with
4//! Python Celery's configuration format, making it easy to migrate from
5//! Celery to `CeleRS` or run them side-by-side.
6//!
7//! # Example
8//!
9//! ```rust
10//! use celers_core::config::{CeleryConfig, TaskConfig, BrokerTransport};
11//! use std::time::Duration;
12//!
13//! let config = CeleryConfig::default()
14//!     .with_broker_url("redis://localhost:6379/0")
15//!     .with_result_backend("redis://localhost:6379/1")
16//!     .with_task_serializer("json")
17//!     .with_timezone("UTC")
18//!     .with_worker_concurrency(4);
19//!
20//! assert_eq!(config.broker_url, "redis://localhost:6379/0");
21//! assert_eq!(config.worker_concurrency, 4);
22//! ```
23
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::time::Duration;
27
/// Celery-compatible main configuration.
///
/// Most field docs name the Celery setting the field mirrors.
///
/// When deserializing, `broker_url` is the only required key, because it is a plain `String`
/// with no `#[serde(default)]`. The other keys are handled as follows:
/// - Absent `Option` fields become `None`.
/// - Other fields fall back to the `default_*` helper named in their
///   `#[serde(default = ...)]` attribute, or to the type's `Default` for a bare
///   `#[serde(default)]`.
/// - Any key that matches no field is collected into `custom` by `#[serde(flatten)]`.
///
/// Note: a deserialized config that omits `result_backend` gets `None`, whereas
/// [`CeleryConfig::default`] sets it to a local Redis URL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CeleryConfig {
    /// Broker connection URL (`CELERY_BROKER_URL`).
    ///
    /// Required when deserializing.
    pub broker_url: String,

    /// Result backend URL (`CELERY_RESULT_BACKEND`).
    ///
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result_backend: Option<String>,

    /// Task serializer format (`CELERY_TASK_SERIALIZER`); defaults to `"json"`.
    ///
    /// `validate` accepts `json`, `msgpack`, `yaml` and `pickle`.
    #[serde(default = "default_serializer")]
    pub task_serializer: String,

    /// Result serializer format (`CELERY_RESULT_SERIALIZER`); defaults to `"json"`.
    #[serde(default = "default_serializer")]
    pub result_serializer: String,

    /// Accepted content types (`CELERY_ACCEPT_CONTENT`).
    ///
    /// Defaults to `["json", "msgpack"]`.
    #[serde(default = "default_accept_content")]
    pub accept_content: Vec<String>,

    /// Timezone for scheduling (`CELERY_TIMEZONE`); defaults to `"UTC"`.
    #[serde(default = "default_timezone")]
    pub timezone: String,

    /// Use UTC timestamps (`CELERY_ENABLE_UTC`); defaults to `true`.
    #[serde(default = "default_true")]
    pub enable_utc: bool,

    /// Track task started events (`CELERY_TASK_TRACK_STARTED`); defaults to `false`.
    #[serde(default)]
    pub task_track_started: bool,

    /// Send task sent events (`CELERY_TASK_SEND_SENT_EVENT`); defaults to `false`.
    #[serde(default)]
    pub task_send_sent_event: bool,

    /// Acknowledge tasks late (`CELERY_TASK_ACKS_LATE`); defaults to `false`.
    #[serde(default)]
    pub task_acks_late: bool,

    /// Reject on worker lost (`CELERY_TASK_REJECT_ON_WORKER_LOST`); defaults to `false`.
    #[serde(default)]
    pub task_reject_on_worker_lost: bool,

    /// Worker concurrency (`CELERYD_CONCURRENCY`).
    ///
    /// Defaults to the value returned by `num_cpus::get()`. `validate` rejects `0`.
    #[serde(default = "default_concurrency")]
    pub worker_concurrency: usize,

    /// Worker prefetch multiplier (`CELERYD_PREFETCH_MULTIPLIER`); defaults to `4`.
    #[serde(default = "default_prefetch_multiplier")]
    pub worker_prefetch_multiplier: usize,

    /// Maximum tasks per child before restart (`CELERYD_MAX_TASKS_PER_CHILD`).
    ///
    /// Unset (`None`) by default and omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_max_tasks_per_child: Option<usize>,

    /// Maximum memory per child in KB (`CELERYD_MAX_MEMORY_PER_CHILD`).
    ///
    /// Unset (`None`) by default and omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub worker_max_memory_per_child: Option<usize>,

    /// Worker heartbeat interval in seconds (`CELERY_WORKER_HEARTBEAT`); defaults to `10`.
    #[serde(default = "default_heartbeat_interval")]
    pub worker_heartbeat: u64,

    /// Task default queue (`CELERY_DEFAULT_QUEUE`); defaults to `"celery"`.
    #[serde(default = "default_queue_name")]
    pub task_default_queue: String,

    /// Task default exchange (`CELERY_DEFAULT_EXCHANGE`).
    ///
    /// Defaults to `"celery"`, the same value as the default queue name.
    #[serde(default = "default_queue_name")]
    pub task_default_exchange: String,

    /// Task default exchange type (`CELERY_DEFAULT_EXCHANGE_TYPE`); defaults to `"direct"`.
    #[serde(default = "default_exchange_type")]
    pub task_default_exchange_type: String,

    /// Task default routing key (`CELERY_DEFAULT_ROUTING_KEY`).
    ///
    /// Defaults to `"celery"`, the same value as the default queue name.
    #[serde(default = "default_queue_name")]
    pub task_default_routing_key: String,

    /// Task routes (`CELERY_TASK_ROUTES`), keyed by exact task name.
    #[serde(default)]
    pub task_routes: HashMap<String, TaskRoute>,

    /// Task time limit in seconds (`CELERY_TASK_TIME_LIMIT`); unset by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_time_limit: Option<u64>,

    /// Task soft time limit in seconds (`CELERY_TASK_SOFT_TIME_LIMIT`); unset by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub task_soft_time_limit: Option<u64>,

    /// Task default retry delay in seconds (`CELERY_TASK_DEFAULT_RETRY_DELAY`).
    ///
    /// Defaults to `180` (3 minutes).
    #[serde(default = "default_retry_delay")]
    pub task_default_retry_delay: u64,

    /// Task max retries (`CELERY_TASK_MAX_RETRIES`); defaults to `3`.
    #[serde(default = "default_max_retries")]
    pub task_max_retries: u32,

    /// Result expiry in seconds (`CELERY_RESULT_EXPIRES`).
    ///
    /// Defaults to `86400` (24 hours).
    #[serde(default = "default_result_expires")]
    pub result_expires: u64,

    /// Result compression algorithm name (`CELERY_RESULT_COMPRESSION`).
    ///
    /// `None` (the default) means compression has not been enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result_compression: Option<String>,

    /// Result compression threshold in bytes; defaults to 1 MiB (`1024 * 1024`).
    #[serde(default = "default_compression_threshold")]
    pub result_compression_threshold: usize,

    /// Task-specific configurations, keyed by exact task name.
    #[serde(default)]
    pub task_annotations: HashMap<String, TaskConfig>,

    /// Broker transport options (`CELERY_BROKER_TRANSPORT_OPTIONS`).
    #[serde(default)]
    pub broker_transport_options: BrokerTransport,

    /// Result backend transport options.
    #[serde(default)]
    pub result_backend_transport_options: BackendTransport,

    /// Beat schedule configuration (`CELERYBEAT_SCHEDULE`), keyed by entry name.
    #[serde(default)]
    pub beat_schedule: HashMap<String, BeatSchedule>,

    /// Custom configuration extensions.
    ///
    /// Because of `#[serde(flatten)]`, top-level keys that match no other field are
    /// collected here on deserialization. On serialization they are written back
    /// out as top-level keys.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
162
/// Routing target for a task.
///
/// Stored in `CeleryConfig::task_routes`, keyed by task name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskRoute {
    /// Target queue name (required when deserializing).
    pub queue: String,

    /// Exchange name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exchange: Option<String>,

    /// Routing key; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub routing_key: Option<String>,

    /// Priority; the `u8` type bounds it to 0-255.
    ///
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub priority: Option<u8>,
}
181
182/// Per-task configuration
183#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct TaskConfig {
185    /// Task time limit in seconds
186    #[serde(skip_serializing_if = "Option::is_none")]
187    pub time_limit: Option<u64>,
188
189    /// Task soft time limit in seconds
190    #[serde(skip_serializing_if = "Option::is_none")]
191    pub soft_time_limit: Option<u64>,
192
193    /// Max retries for this task
194    #[serde(skip_serializing_if = "Option::is_none")]
195    pub max_retries: Option<u32>,
196
197    /// Default retry delay
198    #[serde(skip_serializing_if = "Option::is_none")]
199    pub default_retry_delay: Option<u64>,
200
201    /// Task priority
202    #[serde(skip_serializing_if = "Option::is_none")]
203    pub priority: Option<u8>,
204
205    /// Target queue
206    #[serde(skip_serializing_if = "Option::is_none")]
207    pub queue: Option<String>,
208
209    /// Acknowledge late
210    #[serde(skip_serializing_if = "Option::is_none")]
211    pub acks_late: Option<bool>,
212
213    /// Track started
214    #[serde(skip_serializing_if = "Option::is_none")]
215    pub track_started: Option<bool>,
216
217    /// Rate limit (e.g., "10/s", "100/m", "1000/h")
218    #[serde(skip_serializing_if = "Option::is_none")]
219    pub rate_limit: Option<String>,
220}
221
/// Broker transport options (`CELERY_BROKER_TRANSPORT_OPTIONS`).
///
/// The named options are optional and omitted from serialized output when `None`.
/// Any other keys are kept in `custom` via `#[serde(flatten)]`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BrokerTransport {
    /// Visibility timeout in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub visibility_timeout: Option<u64>,

    /// Connection pool size.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_connections: Option<usize>,

    /// Maximum number of connection retries.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,

    /// Initial retry interval in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub interval_start: Option<u64>,

    /// Maximum retry interval in seconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub interval_max: Option<u64>,

    /// Additional transport-specific options.
    ///
    /// Holds every key that does not match a named field above.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
249
/// Result backend transport options (`CeleryConfig::result_backend_transport_options`).
///
/// The named options are omitted from serialized output when `None`. Any other
/// keys are kept in `custom` via `#[serde(flatten)]`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BackendTransport {
    /// Result expiration in seconds.
    ///
    /// NOTE(review): this is separate from `CeleryConfig::result_expires`, and
    /// nothing in this file reconciles the two. Confirm which one the backend honors.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result_expires: Option<u64>,

    /// Connection pool size.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_connections: Option<usize>,

    /// Additional backend-specific options.
    ///
    /// Holds every key that does not match a named field above.
    #[serde(flatten)]
    pub custom: HashMap<String, serde_json::Value>,
}
265
/// A single beat (periodic task) schedule entry.
///
/// Stored in `CeleryConfig::beat_schedule` under the entry's name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BeatSchedule {
    /// Name of the task to execute (required when deserializing).
    pub task: String,

    /// When the task runs (required when deserializing).
    ///
    /// See [`ScheduleDefinition`] for the accepted forms.
    pub schedule: ScheduleDefinition,

    /// Positional task arguments; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub args: Option<Vec<serde_json::Value>>,

    /// Keyword task arguments; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub kwargs: Option<HashMap<String, serde_json::Value>>,

    /// Task options, using the same shape as a `task_annotations` entry.
    ///
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<TaskConfig>,
}
287
/// Schedule definition for beat tasks.
///
/// Deserialized with `#[serde(untagged)]`. Serde tries the variants in declaration
/// order and keeps the first that matches. The three variants accept disjoint
/// JSON shapes:
/// - a string becomes `Crontab`,
/// - a non-negative integer becomes `Interval`,
/// - an object with `type` and `value` keys becomes `Complex`.
///
/// Input matching none of these fails to deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ScheduleDefinition {
    /// Crontab schedule string (e.g., "0 0 * * *").
    ///
    /// Stored as-is; this type does not parse or validate the expression.
    Crontab(String),

    /// Interval in whole seconds.
    ///
    /// NOTE(review): as a `u64`, this rejects fractional or negative numbers, and
    /// such input matches no variant. Confirm that is intended.
    Interval(u64),

    /// Complex schedule, written as `{"type": ..., "value": ...}`.
    Complex {
        /// Schedule type (crontab, interval, solar)
        #[serde(rename = "type")]
        schedule_type: String,

        /// Schedule value.
        ///
        /// This type does not interpret it; its shape presumably depends on
        /// `schedule_type`.
        value: serde_json::Value,
    },
}
308
309impl Default for CeleryConfig {
310    fn default() -> Self {
311        Self {
312            broker_url: "redis://localhost:6379/0".to_string(),
313            result_backend: Some("redis://localhost:6379/1".to_string()),
314            task_serializer: default_serializer(),
315            result_serializer: default_serializer(),
316            accept_content: default_accept_content(),
317            timezone: default_timezone(),
318            enable_utc: true,
319            task_track_started: false,
320            task_send_sent_event: false,
321            task_acks_late: false,
322            task_reject_on_worker_lost: false,
323            worker_concurrency: default_concurrency(),
324            worker_prefetch_multiplier: default_prefetch_multiplier(),
325            worker_max_tasks_per_child: None,
326            worker_max_memory_per_child: None,
327            worker_heartbeat: default_heartbeat_interval(),
328            task_default_queue: default_queue_name(),
329            task_default_exchange: default_queue_name(),
330            task_default_exchange_type: default_exchange_type(),
331            task_default_routing_key: default_queue_name(),
332            task_routes: HashMap::new(),
333            task_time_limit: None,
334            task_soft_time_limit: None,
335            task_default_retry_delay: default_retry_delay(),
336            task_max_retries: default_max_retries(),
337            result_expires: default_result_expires(),
338            result_compression: None,
339            result_compression_threshold: default_compression_threshold(),
340            task_annotations: HashMap::new(),
341            broker_transport_options: BrokerTransport::default(),
342            result_backend_transport_options: BackendTransport::default(),
343            beat_schedule: HashMap::new(),
344            custom: HashMap::new(),
345        }
346    }
347}
348
349impl CeleryConfig {
350    /// Create a new configuration with broker URL
351    #[inline]
352    pub fn new(broker_url: impl Into<String>) -> Self {
353        Self {
354            broker_url: broker_url.into(),
355            ..Default::default()
356        }
357    }
358
359    /// Set broker URL
360    #[inline]
361    #[must_use]
362    pub fn with_broker_url(mut self, url: impl Into<String>) -> Self {
363        self.broker_url = url.into();
364        self
365    }
366
367    /// Set result backend URL
368    #[inline]
369    #[must_use]
370    pub fn with_result_backend(mut self, url: impl Into<String>) -> Self {
371        self.result_backend = Some(url.into());
372        self
373    }
374
375    /// Set task serializer
376    #[inline]
377    #[must_use]
378    pub fn with_task_serializer(mut self, serializer: impl Into<String>) -> Self {
379        self.task_serializer = serializer.into();
380        self
381    }
382
383    /// Set result serializer
384    #[inline]
385    #[must_use]
386    pub fn with_result_serializer(mut self, serializer: impl Into<String>) -> Self {
387        self.result_serializer = serializer.into();
388        self
389    }
390
391    /// Set accepted content types
392    #[inline]
393    #[must_use]
394    pub fn with_accept_content(mut self, content: Vec<String>) -> Self {
395        self.accept_content = content;
396        self
397    }
398
399    /// Set timezone
400    #[inline]
401    #[must_use]
402    pub fn with_timezone(mut self, tz: impl Into<String>) -> Self {
403        self.timezone = tz.into();
404        self
405    }
406
407    /// Enable/disable UTC
408    #[must_use]
409    pub const fn with_enable_utc(mut self, enabled: bool) -> Self {
410        self.enable_utc = enabled;
411        self
412    }
413
414    /// Set worker concurrency
415    #[must_use]
416    pub const fn with_worker_concurrency(mut self, concurrency: usize) -> Self {
417        self.worker_concurrency = concurrency;
418        self
419    }
420
421    /// Set worker prefetch multiplier
422    #[must_use]
423    pub const fn with_prefetch_multiplier(mut self, multiplier: usize) -> Self {
424        self.worker_prefetch_multiplier = multiplier;
425        self
426    }
427
428    /// Set default queue name
429    #[inline]
430    #[must_use]
431    pub fn with_default_queue(mut self, queue: impl Into<String>) -> Self {
432        self.task_default_queue = queue.into();
433        self
434    }
435
436    /// Add task route
437    #[inline]
438    #[must_use]
439    pub fn with_task_route(mut self, task: impl Into<String>, route: TaskRoute) -> Self {
440        self.task_routes.insert(task.into(), route);
441        self
442    }
443
444    /// Add task annotation
445    #[inline]
446    #[must_use]
447    pub fn with_task_annotation(mut self, task: impl Into<String>, config: TaskConfig) -> Self {
448        self.task_annotations.insert(task.into(), config);
449        self
450    }
451
452    /// Set result expiration
453    #[must_use]
454    pub const fn with_result_expires(mut self, expires: u64) -> Self {
455        self.result_expires = expires;
456        self
457    }
458
459    /// Enable result compression
460    #[inline]
461    #[must_use]
462    pub fn with_result_compression(mut self, algorithm: impl Into<String>) -> Self {
463        self.result_compression = Some(algorithm.into());
464        self
465    }
466
467    /// Set compression threshold
468    #[must_use]
469    pub const fn with_compression_threshold(mut self, threshold: usize) -> Self {
470        self.result_compression_threshold = threshold;
471        self
472    }
473
474    /// Add beat schedule
475    #[inline]
476    #[must_use]
477    pub fn with_beat_schedule(mut self, name: impl Into<String>, schedule: BeatSchedule) -> Self {
478        self.beat_schedule.insert(name.into(), schedule);
479        self
480    }
481
482    /// Get task configuration for a specific task
483    #[inline]
484    #[must_use]
485    pub fn get_task_config(&self, task_name: &str) -> Option<&TaskConfig> {
486        self.task_annotations.get(task_name)
487    }
488
489    /// Get task route for a specific task
490    #[inline]
491    #[must_use]
492    pub fn get_task_route(&self, task_name: &str) -> Option<&TaskRoute> {
493        self.task_routes.get(task_name)
494    }
495
496    /// Get result expiration duration
497    #[inline]
498    #[must_use]
499    pub const fn result_expires_duration(&self) -> Duration {
500        Duration::from_secs(self.result_expires)
501    }
502
503    /// Get task time limit duration
504    #[inline]
505    #[must_use]
506    pub fn task_time_limit_duration(&self) -> Option<Duration> {
507        self.task_time_limit.map(Duration::from_secs)
508    }
509
510    /// Get task soft time limit duration
511    #[inline]
512    #[must_use]
513    pub fn task_soft_time_limit_duration(&self) -> Option<Duration> {
514        self.task_soft_time_limit.map(Duration::from_secs)
515    }
516
517    /// Load configuration from environment variables
518    #[must_use]
519    pub fn from_env() -> Self {
520        let mut config = Self::default();
521
522        if let Ok(url) = std::env::var("CELERY_BROKER_URL") {
523            config.broker_url = url;
524        }
525        if let Ok(backend) = std::env::var("CELERY_RESULT_BACKEND") {
526            config.result_backend = Some(backend);
527        }
528        if let Ok(serializer) = std::env::var("CELERY_TASK_SERIALIZER") {
529            config.task_serializer = serializer;
530        }
531        if let Ok(tz) = std::env::var("CELERY_TIMEZONE") {
532            config.timezone = tz;
533        }
534        if let Ok(concurrency) = std::env::var("CELERYD_CONCURRENCY") {
535            if let Ok(val) = concurrency.parse() {
536                config.worker_concurrency = val;
537            }
538        }
539
540        config
541    }
542
543    /// Validate configuration
544    ///
545    /// # Errors
546    ///
547    /// Returns an error if the configuration is invalid (e.g., empty broker URL, invalid concurrency, unsupported serializer).
548    pub fn validate(&self) -> Result<(), String> {
549        if self.broker_url.is_empty() {
550            return Err("broker_url is required".to_string());
551        }
552
553        if self.worker_concurrency == 0 {
554            return Err("worker_concurrency must be greater than 0".to_string());
555        }
556
557        if !["json", "msgpack", "yaml", "pickle"].contains(&self.task_serializer.as_str()) {
558            return Err(format!(
559                "Unsupported task_serializer: {}",
560                self.task_serializer
561            ));
562        }
563
564        Ok(())
565    }
566}
567
// Default value functions
//
// These back the `#[serde(default = "...")]` attributes on the config structs
// and are also used by `CeleryConfig::default()`.

/// Default serializer name, used for both tasks and results.
fn default_serializer() -> String {
    String::from("json")
}
572
/// Content types accepted by default: JSON and MessagePack, in that order.
fn default_accept_content() -> Vec<String> {
    ["json", "msgpack"]
        .iter()
        .map(|&content_type| String::from(content_type))
        .collect()
}
576
/// Default scheduling timezone.
fn default_timezone() -> String {
    "UTC".to_owned()
}
580
/// Serde default for boolean settings that are on unless configured otherwise
/// (used by `enable_utc`).
///
/// `const fn`, matching the file's convention of making trivially-constant
/// functions `const` (as with the `with_*` setters).
const fn default_true() -> bool {
    true
}
584
585fn default_concurrency() -> usize {
586    num_cpus::get()
587}
588
/// Default worker prefetch multiplier.
///
/// `const fn`, matching the file's convention for trivially-constant functions.
const fn default_prefetch_multiplier() -> usize {
    4
}
592
/// Default worker heartbeat interval, in seconds.
///
/// `const fn`, matching the file's convention for trivially-constant functions.
const fn default_heartbeat_interval() -> u64 {
    10
}
596
/// Default name shared by the default queue, exchange and routing key.
fn default_queue_name() -> String {
    String::from("celery")
}
600
/// Default exchange type for the default exchange.
fn default_exchange_type() -> String {
    "direct".to_owned()
}
604
/// Default task retry delay: 3 minutes, expressed in seconds.
///
/// `const fn`, matching the file's convention for trivially-constant functions.
const fn default_retry_delay() -> u64 {
    3 * 60
}
608
/// Default maximum number of task retries.
///
/// `const fn`, matching the file's convention for trivially-constant functions.
const fn default_max_retries() -> u32 {
    3
}
612
/// Default result expiry: 24 hours, expressed in seconds.
///
/// `const fn`, matching the file's convention for trivially-constant functions.
const fn default_result_expires() -> u64 {
    24 * 60 * 60
}
616
/// Default result compression threshold: 1 MiB, expressed in bytes.
///
/// `const fn`, matching the file's convention for trivially-constant functions.
const fn default_compression_threshold() -> usize {
    1024 * 1024
}
620
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = CeleryConfig::default();
        assert_eq!(config.broker_url, "redis://localhost:6379/0");
        assert_eq!(config.task_serializer, "json");
        assert_eq!(config.timezone, "UTC");
        assert!(config.enable_utc);
    }

    #[test]
    fn test_config_builder() {
        let config = CeleryConfig::new("redis://localhost:6379/0")
            .with_result_backend("redis://localhost:6379/1")
            .with_worker_concurrency(8)
            .with_default_queue("my_queue");

        assert_eq!(config.worker_concurrency, 8);
        assert_eq!(config.task_default_queue, "my_queue");
    }

    #[test]
    fn test_config_validation() {
        let config = CeleryConfig::default();
        assert!(config.validate().is_ok());

        let invalid = CeleryConfig {
            broker_url: String::new(),
            ..Default::default()
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_config_validation_rejects_zero_concurrency() {
        let invalid = CeleryConfig::default().with_worker_concurrency(0);
        let err = invalid.validate().unwrap_err();
        assert!(err.contains("worker_concurrency"));
    }

    #[test]
    fn test_config_validation_rejects_unsupported_task_serializer() {
        let invalid = CeleryConfig::default().with_task_serializer("bogus");
        let err = invalid.validate().unwrap_err();
        assert!(err.contains("task_serializer"));
    }

    #[test]
    fn test_task_route() {
        let route = TaskRoute {
            queue: "high_priority".to_string(),
            exchange: Some("tasks".to_string()),
            routing_key: Some("task.high".to_string()),
            priority: Some(9),
        };

        let config = CeleryConfig::default().with_task_route("important_task", route);

        assert!(config.get_task_route("important_task").is_some());
        // Lookup is by exact task name.
        assert!(config.get_task_route("other_task").is_none());
    }

    #[test]
    fn test_duration_conversions() {
        let config = CeleryConfig::default();
        assert_eq!(config.result_expires_duration(), Duration::from_secs(86400));
        // Both time limits are unset by default.
        assert_eq!(config.task_time_limit_duration(), None);
        assert_eq!(config.task_soft_time_limit_duration(), None);

        let limited = CeleryConfig {
            task_time_limit: Some(30),
            task_soft_time_limit: Some(20),
            ..Default::default()
        };
        assert_eq!(
            limited.task_time_limit_duration(),
            Some(Duration::from_secs(30))
        );
        assert_eq!(
            limited.task_soft_time_limit_duration(),
            Some(Duration::from_secs(20))
        );
    }

    #[test]
    fn test_deserialize_minimal_config_uses_defaults_and_captures_extras() {
        let config: CeleryConfig = serde_json::from_str(
            r#"{"broker_url": "amqp://guest@localhost//", "my_extension": 42}"#,
        )
        .expect("a config with only broker_url should deserialize");

        assert_eq!(config.broker_url, "amqp://guest@localhost//");
        // A missing Option key deserializes to None, unlike CeleryConfig::default().
        assert!(config.result_backend.is_none());
        assert_eq!(config.task_serializer, "json");
        assert_eq!(config.task_default_queue, "celery");
        assert_eq!(config.result_expires, 86400);
        assert!(config.enable_utc);
        // Unknown top-level keys land in `custom` via #[serde(flatten)].
        assert_eq!(
            config.custom.get("my_extension"),
            Some(&serde_json::Value::from(42))
        );
        assert!(!config.custom.contains_key("broker_url"));
    }
}