1use std::sync::Arc;
4use std::time::Duration;
5use tokio::sync::Semaphore;
6use tracing::{error, info, warn};
7
8use crate::backend::QueueBackend;
9use crate::job::{Job, JobResult, JobStatus};
10
/// Tuning knobs for a `WorkerPool`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkerConfig {
    /// Maximum number of jobs the pool executes at the same time.
    pub max_concurrency: usize,
    /// How long the dispatch loop sleeps after finding the queue empty.
    pub poll_interval: Duration,
}
16
17impl Default for WorkerConfig {
18 fn default() -> Self {
19 Self {
20 max_concurrency: 5,
21 poll_interval: Duration::from_millis(100),
22 }
23 }
24}
25
26use serde::de::DeserializeOwned;
27
28use std::sync::RwLock;
29
30pub struct WorkerPool<B: QueueBackend + ?Sized> {
31 pub backend: Arc<B>,
32 config: WorkerConfig,
33 registry: Arc<JobRegistry>,
34}
35
36type JobFactory = Box<dyn Fn(serde_json::Value) -> Box<dyn Job> + Send + Sync>;
37
38struct JobRegistry {
39 factories: RwLock<std::collections::HashMap<String, JobFactory>>,
40}
41
42impl<B: QueueBackend + 'static> WorkerPool<B> {
43 pub fn new(backend: B, config: WorkerConfig) -> Self {
44 Self::new_with_arc(Arc::new(backend), config)
45 }
46}
47
48impl<B: QueueBackend + ?Sized + 'static> WorkerPool<B> {
49 pub fn new_with_arc(backend: Arc<B>, config: WorkerConfig) -> Self {
51 Self {
52 backend,
53 config,
54 registry: Arc::new(JobRegistry {
55 factories: RwLock::new(std::collections::HashMap::new()),
56 }),
57 }
58 }
59
60 pub fn register_job_type<J: Job + DeserializeOwned + 'static>(&self, name: &str) {
62 let factory = Box::new(|payload: serde_json::Value| {
63 let job: J =
64 serde_json::from_value(payload).expect("Job payload deserialization failed");
65 Box::new(job) as Box<dyn Job>
66 });
67
68 self.registry
69 .factories
70 .write()
71 .expect("Job registry RwLock poisoned")
72 .insert(name.to_string(), factory);
73 }
74
75 pub fn register_job_factory<F>(&self, name: &str, factory: F)
77 where
78 F: Fn(serde_json::Value) -> Box<dyn Job> + Send + Sync + 'static,
79 {
80 self.registry
81 .factories
82 .write()
83 .expect("Job registry RwLock poisoned")
84 .insert(name.to_string(), Box::new(factory));
85 }
86
87 pub async fn start(&self) {
88 let semaphore = Arc::new(Semaphore::new(self.config.max_concurrency));
89
90 info!(
91 "Worker pool started with concurrency {}",
92 self.config.max_concurrency
93 );
94
95 loop {
96 if semaphore.available_permits() > 0 {
97 match self.backend.dequeue().await {
98 Ok(Some(entry)) => {
99 let permit = semaphore
100 .clone()
101 .acquire_owned()
102 .await
103 .expect("Worker semaphore closed unexpectedly");
104 let backend = self.backend.clone();
105 let registry = self.registry.clone();
106
107 tokio::spawn(async move {
108 let job_opt = {
109 let factories = registry
110 .factories
111 .read()
112 .expect("Job registry RwLock poisoned");
113 factories
114 .get(&entry.job_type)
115 .map(|f| f(entry.payload.clone()))
116 };
117
118 if let Some(mut job) = job_opt {
119 info!("Processing job {} ({})", entry.id, entry.job_type);
120
121 let result = job.execute().await;
122
123 match result {
124 JobResult::Success(value) => {
125 if let Some(val) = value {
126 let _ = backend.set_result(entry.id, val).await;
127 }
128 let _ = backend
129 .update_status(
130 entry.id,
131 JobStatus::Completed,
132 None,
133 None,
134 )
135 .await;
136 }
137 JobResult::Retry(e) => {
138 let delay = job.backoff_strategy().delay(entry.attempts);
140 let delay_secs = delay.as_secs();
141
142 info!(
143 job_id = %entry.id,
144 attempt = entry.attempts + 1,
145 delay_secs = delay_secs,
146 "Job failed, scheduling retry with backoff"
147 );
148
149 let _ = backend
150 .update_status(
151 entry.id,
152 JobStatus::Failed(entry.attempts + 1),
153 Some(e),
154 Some(delay_secs),
155 )
156 .await;
157 }
158 JobResult::Fatal(e) => {
159 let _ = backend
160 .update_status(
161 entry.id,
162 JobStatus::DeadLetter,
163 Some(e),
164 None,
165 )
166 .await;
167 }
168 }
169 } else {
170 warn!("No handler registered for job type: {}", entry.job_type);
171 let _ = backend
172 .update_status(
173 entry.id,
174 JobStatus::DeadLetter,
175 Some(format!("No handler for {}", entry.job_type)),
176 None,
177 )
178 .await;
179 }
180
181 drop(permit);
182 });
183 }
184 Ok(None) => {
185 tokio::time::sleep(self.config.poll_interval).await;
187 }
188 Err(e) => {
189 error!("Queue error: {}", e);
190 tokio::time::sleep(Duration::from_secs(1)).await;
191 }
192 }
193 } else {
194 tokio::time::sleep(Duration::from_millis(50)).await;
196 }
197 }
198 }
199}