1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::Mutex;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use anyhow::Context;
8use redis::Commands;
9use serde::Serialize;
10use tokio::sync::watch;
11use tokio::task::JoinHandle;
12
13use crate::internals::context::context_manager::{Managers, RedisPool, Track};
14use crate::internals::database::DbPool;
15use crate::internals::query::query_manager::QueryManager;
16use crate::internals::search::search_manager::SearchItem;
17use crate::internals::utils::config::config_manager::Config;
18
/// Public description of one spawned worker, as returned by
/// `WorkerSupervisor::start` and listed by `WorkerSupervisor::status`.
#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
pub struct WorkerInfo {
    /// Supervisor-assigned id, taken from an incrementing counter that starts at 1.
    pub id: usize,
    /// Login name for this worker: `username_prefix` followed by a 1-based worker number.
    pub username: String,
    /// Listen port for this worker, derived from `port_base` plus the worker's index.
    pub port: u16,
    /// Run identifier: `run_id_prefix`, a dash, then the 1-based worker number.
    pub run_id: String,
    /// Wall-clock creation time in whole seconds since the Unix epoch (0 if the clock is before it).
    pub started_at_epoch_secs: u64,
}
27
/// Snapshot produced by `WorkerSupervisor::status`.
#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
pub struct WorkerStatus {
    /// Workers whose tasks had not finished when the snapshot was taken.
    pub workers: Vec<WorkerInfo>,
    /// Length of the most recently built Redis chunk queue (0 if none was built
    /// or the Redis query failed).
    pub queue_len: usize,
    /// Cardinality of the `dl:failed` Redis set (0 if the Redis query failed).
    pub failed_count: usize,
}
34
/// Parameters for one `WorkerSupervisor::start` call.
#[derive(Clone, Debug)]
pub struct WorkerStartOptions {
    /// Number of worker tasks to spawn.
    pub worker_count: usize,
    /// Username prefix; worker `n` (1-based) is named `{username_prefix}{n}`.
    pub username_prefix: String,
    /// Listen port of the first worker; later workers are offset by their index.
    pub port_base: u16,
    /// Run-id prefix; worker `n` (1-based) gets `{run_id_prefix}-{n}`.
    pub run_id_prefix: String,
    /// Playlist to fetch; also part of the Redis chunk-queue key.
    pub playlist_id: String,
    /// Maximum number of items per queued chunk (0 is treated as 1 when chunking);
    /// also part of the Redis chunk-queue key.
    pub chunk_size: usize,
    /// Optional zero-based, end-exclusive `(start, end)` slice of the playlist,
    /// clamped to its length; an empty range after clamping keeps the whole playlist.
    pub playlist_range: Option<(usize, usize)>,
}
45
/// Owns the set of spawned download workers and the Redis chunk queue they drain.
pub struct WorkerSupervisor {
    /// Handles of spawned worker tasks; finished ones are pruned by `status`,
    /// stopped ones removed by `stop`.
    workers: Mutex<Vec<WorkerHandle>>,
    /// Source of worker ids; starts at 1.
    next_worker_id: AtomicUsize,
    /// Redis key of the most recently built chunk queue; `None` until the first `start`.
    queue_key: Mutex<Option<String>>,
    /// Download directory handed to every worker.
    download_path: PathBuf,
    /// Database pool cloned into every worker.
    db_pool: DbPool,
    /// Redis pool used for the queue and cloned into every worker.
    redis_pool: RedisPool,
    /// Shutdown signal cloned into every worker; a `true` value makes workers exit.
    shutdown: watch::Receiver<bool>,
}
55
/// A spawned worker: its reported metadata plus the task handle used to
/// abort it or check whether it has finished.
struct WorkerHandle {
    info: WorkerInfo,
    handle: JoinHandle<()>,
}
60
/// Everything a single worker task needs; moved into `run_worker`.
struct WorkerRunConfig {
    /// Per-worker copy of the base config (username, password, port and run id overridden).
    worker_config: Config,
    /// Together with `chunk_size`, identifies the Redis chunk-queue key.
    playlist_id: String,
    chunk_size: usize,
    /// Passed to `Managers::new`.
    download_path: PathBuf,
    redis_pool: RedisPool,
    db_pool: DbPool,
    /// True for the first worker spawned by a `start` call; the leader runs a
    /// retry pass over `dl:failed` after the queue is empty.
    is_leader: bool,
    /// Every item queued by this `start` call; used to rebuild the leader's retry batch.
    all_items: Vec<SearchItem>,
    /// Shutdown signal; a `true` value makes the worker exit.
    shutdown: watch::Receiver<bool>,
}
72
73impl WorkerSupervisor {
74 pub fn new(
75 download_path: PathBuf,
76 db_pool: DbPool,
77 redis_pool: RedisPool,
78 shutdown: watch::Receiver<bool>,
79 ) -> Self {
80 Self {
81 workers: Mutex::new(Vec::new()),
82 next_worker_id: AtomicUsize::new(1),
83 queue_key: Mutex::new(None),
84 download_path,
85 db_pool,
86 redis_pool,
87 shutdown,
88 }
89 }
90
91 pub async fn start(
92 &self,
93 options: WorkerStartOptions,
94 base_config: Config,
95 user_password: String,
96 ) -> anyhow::Result<Vec<WorkerInfo>> {
97 let query_manager = QueryManager::new_with_timeout(
98 options.playlist_id.clone(),
99 base_config.client_id.clone(),
100 base_config.client_secret.clone(),
101 base_config.search_timeout_secs,
102 );
103 let playlist_tracks = query_manager
104 .fetch_playlist()
105 .await
106 .context("Fetch worker playlist")?;
107 let playlist_tracks = apply_playlist_range(playlist_tracks, options.playlist_range);
108 let items = playlist_tracks
109 .into_iter()
110 .filter_map(|track| match track {
111 Track::Query(item) => Some(item),
112 _ => None,
113 })
114 .collect::<Vec<_>>();
115
116 self.replace_queue(&options.playlist_id, options.chunk_size, &items)
117 .context("Build worker queue")?;
118
119 let mut spawned = Vec::with_capacity(options.worker_count);
120 let mut guard = self
121 .workers
122 .lock()
123 .map_err(|_| anyhow::anyhow!("worker lock poisoned"))?;
124
125 for index in 0..options.worker_count {
126 let info = worker_info(
127 self.next_worker_id.fetch_add(1, Ordering::Relaxed),
128 &options.username_prefix,
129 options.port_base,
130 &options.run_id_prefix,
131 index,
132 );
133
134 let mut worker_config = base_config.clone();
135 worker_config.user_name = info.username.clone();
136 worker_config.user_password = user_password.clone();
137 worker_config.listen_port = u32::from(info.port);
138 worker_config.run_id = info.run_id.clone();
139
140 let handle = tokio::spawn(run_worker(WorkerRunConfig {
141 worker_config,
142 playlist_id: options.playlist_id.clone(),
143 chunk_size: options.chunk_size,
144 download_path: self.download_path.clone(),
145 redis_pool: self.redis_pool.clone(),
146 db_pool: self.db_pool.clone(),
147 is_leader: index == 0,
148 all_items: items.clone(),
149 shutdown: self.shutdown.clone(),
150 }));
151
152 spawned.push(info.clone());
153 guard.push(WorkerHandle { info, handle });
154 }
155
156 Ok(spawned)
157 }
158
159 pub fn stop(&self, target_ids: Option<&[u32]>) -> anyhow::Result<Vec<u32>> {
160 let mut guard = self
161 .workers
162 .lock()
163 .map_err(|_| anyhow::anyhow!("worker lock poisoned"))?;
164 let mut stopped = Vec::new();
165 let mut remaining = Vec::with_capacity(guard.len());
166
167 for entry in guard.drain(..) {
168 let should_stop = match target_ids {
169 None => true,
170 Some(ids) => ids.contains(&(entry.info.id as u32)),
171 };
172
173 if should_stop {
174 entry.handle.abort();
175 stopped.push(entry.info.id as u32);
176 } else {
177 remaining.push(entry);
178 }
179 }
180
181 *guard = remaining;
182 Ok(stopped)
183 }
184
185 pub fn status(&self) -> anyhow::Result<WorkerStatus> {
186 let mut guard = self
187 .workers
188 .lock()
189 .map_err(|_| anyhow::anyhow!("worker lock poisoned"))?;
190 let mut workers = Vec::with_capacity(guard.len());
191
192 guard.retain_mut(|entry| {
193 if entry.handle.is_finished() {
194 false
195 } else {
196 workers.push(entry.info.clone());
197 true
198 }
199 });
200
201 let queue_key = self
202 .queue_key
203 .lock()
204 .map_err(|_| anyhow::anyhow!("queue key lock poisoned"))?
205 .clone();
206 let (queue_len, failed_count) = if let Some(key) = queue_key {
207 let mut redis_con = self.redis_pool.get().context("Acquire Redis connection")?;
208 let queue_len = redis_con.llen(key).unwrap_or(0);
209 let failed_count = redis_con.scard("dl:failed").unwrap_or(0);
210 (queue_len, failed_count)
211 } else {
212 (0, 0)
213 };
214
215 Ok(WorkerStatus {
216 workers,
217 queue_len,
218 failed_count,
219 })
220 }
221
222 fn replace_queue(
223 &self,
224 playlist_id: &str,
225 chunk_size: usize,
226 items: &[SearchItem],
227 ) -> anyhow::Result<()> {
228 let queue_key = chunk_queue_key(playlist_id, chunk_size);
229 self.queue_key
230 .lock()
231 .map_err(|_| anyhow::anyhow!("queue key lock poisoned"))?
232 .replace(queue_key.clone());
233
234 let chunks = build_chunks(items, chunk_size);
235 let mut redis_con = self.redis_pool.get().context("Acquire Redis connection")?;
236 let _: usize = redis_con.del(queue_key.clone()).unwrap_or(0);
237 for chunk in chunks {
238 let payload = serde_json::to_string(&chunk).context("Serialise worker chunk")?;
239 let _: usize = redis_con.rpush(&queue_key, payload).unwrap_or(0);
240 }
241 Ok(())
242 }
243}
244
/// Redis list key under which the serialised chunks for one playlist and
/// chunk size are stored.
///
/// The chunk size is part of the key, so the same playlist chunked
/// differently gets a separate queue.
pub fn chunk_queue_key(playlist_id: &str, chunk_size: usize) -> String {
    let mut key = String::from("dl:chunk_queue:");
    key.push_str(playlist_id);
    key.push(':');
    key.push_str(&chunk_size.to_string());
    key
}
248
249pub fn build_chunks(items: &[SearchItem], chunk_size: usize) -> Vec<Vec<SearchItem>> {
250 let chunk_size = chunk_size.max(1);
251 let mut chunks = Vec::new();
252 let mut start = 0usize;
253 while start < items.len() {
254 let end = (start + chunk_size).min(items.len());
255 chunks.push(items[start..end].to_vec());
256 start = end;
257 }
258 chunks
259}
260
261pub fn apply_playlist_range(
262 playlist_tracks: Vec<Track>,
263 playlist_range: Option<(usize, usize)>,
264) -> Vec<Track> {
265 let Some((start, end)) = playlist_range else {
266 return playlist_tracks;
267 };
268 let start = start.min(playlist_tracks.len());
269 let end = end.min(playlist_tracks.len());
270 if start >= end {
271 return playlist_tracks;
272 }
273 playlist_tracks
274 .into_iter()
275 .skip(start)
276 .take(end - start)
277 .collect()
278}
279
280fn worker_info(
281 worker_id: usize,
282 username_prefix: &str,
283 port_base: u16,
284 run_id_prefix: &str,
285 index: usize,
286) -> WorkerInfo {
287 let worker_number = index + 1;
288 WorkerInfo {
289 id: worker_id,
290 username: format!("{username_prefix}{worker_number}"),
291 port: port_base.saturating_add(index as u16),
292 run_id: format!("{run_id_prefix}-{worker_number}"),
293 started_at_epoch_secs: SystemTime::now()
294 .duration_since(UNIX_EPOCH)
295 .unwrap_or_default()
296 .as_secs(),
297 }
298}
299
300async fn run_worker(config: WorkerRunConfig) {
301 let WorkerRunConfig {
302 worker_config,
303 playlist_id,
304 chunk_size,
305 download_path,
306 redis_pool,
307 db_pool,
308 is_leader,
309 all_items,
310 mut shutdown,
311 } = config;
312
313 let queue_key = chunk_queue_key(&playlist_id, chunk_size);
314 let managers = match Managers::new(
315 worker_config.judge_score_levenshtein,
316 download_path.clone(),
317 worker_config.clone(),
318 db_pool.clone(),
319 redis_pool.clone(),
320 ) {
321 Ok(managers) => Arc::new(managers),
322 Err(err) => {
323 tracing::error!(?err, run_id = %worker_config.run_id, "Worker failed to start managers");
324 return;
325 }
326 };
327
328 loop {
329 if *shutdown.borrow() {
330 tracing::info!(run_id = %worker_config.run_id, "Worker exiting due to shutdown");
331 managers.shutdown();
332 return;
333 }
334
335 let chunk_json: Option<String> = {
336 let mut redis_con = match redis_pool.get() {
337 Ok(con) => con,
338 Err(err) => {
339 tracing::error!(?err, "Worker failed to acquire Redis connection; exiting");
340 return;
341 }
342 };
343 redis_con.lpop(&queue_key, None).ok()
344 };
345 let Some(chunk_json) = chunk_json else {
346 break;
347 };
348
349 let chunk_items: Vec<SearchItem> = match serde_json::from_str(&chunk_json) {
350 Ok(value) => value,
351 Err(err) => {
352 let truncated = chunk_json.chars().take(500).collect::<String>();
353 tracing::error!(
354 %err,
355 payload = %truncated,
356 run_id = %worker_config.run_id,
357 "Worker skipped malformed chunk payload",
358 );
359 continue;
360 }
361 };
362
363 let tracks = chunk_items
364 .into_iter()
365 .map(Track::Query)
366 .collect::<Vec<_>>();
367
368 tokio::select! {
369 _ = managers.run_chunk(tracks) => {}
370 _ = shutdown.changed() => {
371 if *shutdown.borrow() {
372 tracing::info!(run_id = %worker_config.run_id, "Worker shutdown mid-cycle");
373 managers.shutdown();
374 return;
375 }
376 }
377 }
378 }
379
380 if is_leader && !*shutdown.borrow() {
381 let failed_ids: Vec<String> = {
382 let mut redis_con = match redis_pool.get() {
383 Ok(con) => con,
384 Err(_) => return,
385 };
386 redis_con.smembers("dl:failed").unwrap_or_default()
387 };
388 if !failed_ids.is_empty() {
389 let failed_items = all_items
390 .into_iter()
391 .filter(|item| failed_ids.contains(&item.track_id))
392 .map(Track::Query)
393 .collect::<Vec<_>>();
394 if !failed_items.is_empty() {
395 let _ = managers.run_chunk(failed_items).await;
396 }
397 }
398 }
399 managers.shutdown();
400}
401
#[cfg(test)]
mod tests {
    use super::{apply_playlist_range, build_chunks, chunk_queue_key, worker_info};
    use crate::internals::context::context_manager::Track;
    use crate::internals::search::search_manager::SearchItem;

    /// Builds a minimal `SearchItem` fixture from a short id.
    fn item(id: &str) -> SearchItem {
        SearchItem::new(
            id.to_string(),
            format!("track-{id}"),
            "album".to_string(),
            "artist".to_string(),
        )
    }

    /// Wraps a fixture for each id as a `Track::Query` playlist entry.
    fn query_tracks(ids: &[&str]) -> Vec<Track> {
        ids.iter().map(|id| Track::Query(item(id))).collect()
    }

    #[test]
    fn queue_key_includes_playlist_and_chunk_size() {
        assert_eq!(
            chunk_queue_key("playlist", 15),
            "dl:chunk_queue:playlist:15"
        );
    }

    #[test]
    fn chunks_items_by_requested_size() {
        let items = vec![item("1"), item("2"), item("3"), item("4"), item("5")];
        let chunks = build_chunks(&items, 2);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].len(), 2);
        assert_eq!(chunks[1].len(), 2);
        assert_eq!(chunks[2].len(), 1);
    }

    #[test]
    fn exact_multiple_has_no_short_tail() {
        let items = vec![item("1"), item("2"), item("3"), item("4")];
        let chunks = build_chunks(&items, 2);
        assert_eq!(chunks.len(), 2);
        assert!(chunks.iter().all(|chunk| chunk.len() == 2));
    }

    #[test]
    fn empty_items_produce_no_chunks() {
        assert!(build_chunks(&[], 3).is_empty());
    }

    #[test]
    fn chunk_size_zero_is_treated_as_one() {
        let items = vec![item("1"), item("2")];
        let chunks = build_chunks(&items, 0);
        assert_eq!(chunks.len(), 2);
    }

    #[test]
    fn applies_playlist_range_when_valid_after_clamping() {
        let ranged = apply_playlist_range(query_tracks(&["1", "2", "3"]), Some((1, 10)));
        assert_eq!(ranged.len(), 2);
    }

    #[test]
    fn playlist_range_end_is_exclusive() {
        let ranged = apply_playlist_range(query_tracks(&["1", "2", "3", "4"]), Some((1, 3)));
        assert_eq!(ranged.len(), 2);
    }

    #[test]
    fn missing_playlist_range_keeps_everything() {
        let ranged = apply_playlist_range(query_tracks(&["1", "2", "3"]), None);
        assert_eq!(ranged.len(), 3);
    }

    #[test]
    fn empty_playlist_range_falls_back_to_whole_playlist() {
        let inverted = apply_playlist_range(query_tracks(&["1", "2", "3"]), Some((2, 1)));
        assert_eq!(inverted.len(), 3);
        let past_end = apply_playlist_range(query_tracks(&["1", "2", "3"]), Some((5, 9)));
        assert_eq!(past_end.len(), 3);
    }

    #[test]
    fn worker_info_uses_one_based_worker_suffixes() {
        let info = worker_info(7, "worker", 41000, "run", 2);
        assert_eq!(info.id, 7);
        assert_eq!(info.username, "worker3");
        assert_eq!(info.port, 41002);
        assert_eq!(info.run_id, "run-3");
    }

    #[test]
    fn worker_ports_saturate_at_u16_max() {
        let info = worker_info(1, "w", u16::MAX - 1, "r", 5);
        assert_eq!(info.port, u16::MAX);
    }
}