Skip to main content

convert_invert/internals/worker/
worker_manager.rs

1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::Mutex;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use anyhow::Context;
8use redis::Commands;
9use serde::Serialize;
10use tokio::sync::watch;
11use tokio::task::JoinHandle;
12
13use crate::internals::context::context_manager::{Managers, RedisPool, Track};
14use crate::internals::database::DbPool;
15use crate::internals::query::query_manager::QueryManager;
16use crate::internals::search::search_manager::SearchItem;
17use crate::internals::utils::config::config_manager::Config;
18
19#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
20pub struct WorkerInfo {
21    pub id: usize,
22    pub username: String,
23    pub port: u16,
24    pub run_id: String,
25    pub started_at_epoch_secs: u64,
26}
27
28#[derive(Serialize, Clone, Debug, PartialEq, Eq)]
29pub struct WorkerStatus {
30    pub workers: Vec<WorkerInfo>,
31    pub queue_len: usize,
32    pub failed_count: usize,
33}
34
/// Parameters controlling one call to `WorkerSupervisor::start`.
#[derive(Debug, Clone)]
pub struct WorkerStartOptions {
    /// Number of worker tasks to spawn.
    pub worker_count: usize,
    /// User-name prefix; worker `n` (1-based) is configured as `{prefix}{n}`.
    pub username_prefix: String,
    /// Listen port of the first worker; each following worker uses the next port.
    pub port_base: u16,
    /// Run-id prefix; worker `n` (1-based) gets `{prefix}-{n}`.
    pub run_id_prefix: String,
    /// Playlist to fetch and split into queued chunks.
    pub playlist_id: String,
    /// Maximum number of tracks per queued chunk (0 behaves like 1).
    pub chunk_size: usize,
    /// Optional zero-based, half-open `(start, end)` window over the fetched playlist.
    pub playlist_range: Option<(usize, usize)>,
}
45
46pub struct WorkerSupervisor {
47    workers: Mutex<Vec<WorkerHandle>>,
48    next_worker_id: AtomicUsize,
49    queue_key: Mutex<Option<String>>,
50    download_path: PathBuf,
51    db_pool: DbPool,
52    redis_pool: RedisPool,
53    shutdown: watch::Receiver<bool>,
54}
55
56struct WorkerHandle {
57    info: WorkerInfo,
58    handle: JoinHandle<()>,
59}
60
61struct WorkerRunConfig {
62    worker_config: Config,
63    playlist_id: String,
64    chunk_size: usize,
65    download_path: PathBuf,
66    redis_pool: RedisPool,
67    db_pool: DbPool,
68    is_leader: bool,
69    all_items: Vec<SearchItem>,
70    shutdown: watch::Receiver<bool>,
71}
72
73impl WorkerSupervisor {
74    pub fn new(
75        download_path: PathBuf,
76        db_pool: DbPool,
77        redis_pool: RedisPool,
78        shutdown: watch::Receiver<bool>,
79    ) -> Self {
80        Self {
81            workers: Mutex::new(Vec::new()),
82            next_worker_id: AtomicUsize::new(1),
83            queue_key: Mutex::new(None),
84            download_path,
85            db_pool,
86            redis_pool,
87            shutdown,
88        }
89    }
90
91    pub async fn start(
92        &self,
93        options: WorkerStartOptions,
94        base_config: Config,
95        user_password: String,
96    ) -> anyhow::Result<Vec<WorkerInfo>> {
97        let query_manager = QueryManager::new_with_timeout(
98            options.playlist_id.clone(),
99            base_config.client_id.clone(),
100            base_config.client_secret.clone(),
101            base_config.search_timeout_secs,
102        );
103        let playlist_tracks = query_manager
104            .fetch_playlist()
105            .await
106            .context("Fetch worker playlist")?;
107        let playlist_tracks = apply_playlist_range(playlist_tracks, options.playlist_range);
108        let items = playlist_tracks
109            .into_iter()
110            .filter_map(|track| match track {
111                Track::Query(item) => Some(item),
112                _ => None,
113            })
114            .collect::<Vec<_>>();
115
116        self.replace_queue(&options.playlist_id, options.chunk_size, &items)
117            .context("Build worker queue")?;
118
119        let mut spawned = Vec::with_capacity(options.worker_count);
120        let mut guard = self
121            .workers
122            .lock()
123            .map_err(|_| anyhow::anyhow!("worker lock poisoned"))?;
124
125        for index in 0..options.worker_count {
126            let info = worker_info(
127                self.next_worker_id.fetch_add(1, Ordering::Relaxed),
128                &options.username_prefix,
129                options.port_base,
130                &options.run_id_prefix,
131                index,
132            );
133
134            let mut worker_config = base_config.clone();
135            worker_config.user_name = info.username.clone();
136            worker_config.user_password = user_password.clone();
137            worker_config.listen_port = u32::from(info.port);
138            worker_config.run_id = info.run_id.clone();
139
140            let handle = tokio::spawn(run_worker(WorkerRunConfig {
141                worker_config,
142                playlist_id: options.playlist_id.clone(),
143                chunk_size: options.chunk_size,
144                download_path: self.download_path.clone(),
145                redis_pool: self.redis_pool.clone(),
146                db_pool: self.db_pool.clone(),
147                is_leader: index == 0,
148                all_items: items.clone(),
149                shutdown: self.shutdown.clone(),
150            }));
151
152            spawned.push(info.clone());
153            guard.push(WorkerHandle { info, handle });
154        }
155
156        Ok(spawned)
157    }
158
159    pub fn stop(&self, target_ids: Option<&[u32]>) -> anyhow::Result<Vec<u32>> {
160        let mut guard = self
161            .workers
162            .lock()
163            .map_err(|_| anyhow::anyhow!("worker lock poisoned"))?;
164        let mut stopped = Vec::new();
165        let mut remaining = Vec::with_capacity(guard.len());
166
167        for entry in guard.drain(..) {
168            let should_stop = match target_ids {
169                None => true,
170                Some(ids) => ids.contains(&(entry.info.id as u32)),
171            };
172
173            if should_stop {
174                entry.handle.abort();
175                stopped.push(entry.info.id as u32);
176            } else {
177                remaining.push(entry);
178            }
179        }
180
181        *guard = remaining;
182        Ok(stopped)
183    }
184
185    pub fn status(&self) -> anyhow::Result<WorkerStatus> {
186        let mut guard = self
187            .workers
188            .lock()
189            .map_err(|_| anyhow::anyhow!("worker lock poisoned"))?;
190        let mut workers = Vec::with_capacity(guard.len());
191
192        guard.retain_mut(|entry| {
193            if entry.handle.is_finished() {
194                false
195            } else {
196                workers.push(entry.info.clone());
197                true
198            }
199        });
200
201        let queue_key = self
202            .queue_key
203            .lock()
204            .map_err(|_| anyhow::anyhow!("queue key lock poisoned"))?
205            .clone();
206        let (queue_len, failed_count) = if let Some(key) = queue_key {
207            let mut redis_con = self.redis_pool.get().context("Acquire Redis connection")?;
208            let queue_len = redis_con.llen(key).unwrap_or(0);
209            let failed_count = redis_con.scard("dl:failed").unwrap_or(0);
210            (queue_len, failed_count)
211        } else {
212            (0, 0)
213        };
214
215        Ok(WorkerStatus {
216            workers,
217            queue_len,
218            failed_count,
219        })
220    }
221
222    fn replace_queue(
223        &self,
224        playlist_id: &str,
225        chunk_size: usize,
226        items: &[SearchItem],
227    ) -> anyhow::Result<()> {
228        let queue_key = chunk_queue_key(playlist_id, chunk_size);
229        self.queue_key
230            .lock()
231            .map_err(|_| anyhow::anyhow!("queue key lock poisoned"))?
232            .replace(queue_key.clone());
233
234        let chunks = build_chunks(items, chunk_size);
235        let mut redis_con = self.redis_pool.get().context("Acquire Redis connection")?;
236        let _: usize = redis_con.del(queue_key.clone()).unwrap_or(0);
237        for chunk in chunks {
238            let payload = serde_json::to_string(&chunk).context("Serialise worker chunk")?;
239            let _: usize = redis_con.rpush(&queue_key, payload).unwrap_or(0);
240        }
241        Ok(())
242    }
243}
244
/// Redis list key holding the queued chunks of `playlist_id` at `chunk_size`,
/// e.g. `dl:chunk_queue:abc:15`.
pub fn chunk_queue_key(playlist_id: &str, chunk_size: usize) -> String {
    let size = chunk_size.to_string();
    ["dl:chunk_queue", playlist_id, size.as_str()].join(":")
}
248
/// Splits `items` into consecutive, order-preserving chunks of at most
/// `chunk_size` elements; only the final chunk may be shorter.
///
/// A `chunk_size` of 0 is treated as 1, and empty input yields no chunks.
pub fn build_chunks<T: Clone>(items: &[T], chunk_size: usize) -> Vec<Vec<T>> {
    items.chunks(chunk_size.max(1)).map(<[T]>::to_vec).collect()
}
260
/// Restricts `playlist_tracks` to the zero-based, half-open window `start..end`.
///
/// Both bounds are clamped to the list length first. `None`, or a window that
/// is empty or inverted after clamping (e.g. `start` past the end of the list),
/// returns the list unchanged rather than an empty list.
pub fn apply_playlist_range<T>(
    playlist_tracks: Vec<T>,
    playlist_range: Option<(usize, usize)>,
) -> Vec<T> {
    let Some((start, end)) = playlist_range else {
        return playlist_tracks;
    };
    let len = playlist_tracks.len();
    let (start, end) = (start.min(len), end.min(len));
    if start >= end {
        return playlist_tracks;
    }
    // Slice in place: drop the tail first, then the head.
    let mut tracks = playlist_tracks;
    tracks.truncate(end);
    tracks.drain(..start);
    tracks
}
279
280fn worker_info(
281    worker_id: usize,
282    username_prefix: &str,
283    port_base: u16,
284    run_id_prefix: &str,
285    index: usize,
286) -> WorkerInfo {
287    let worker_number = index + 1;
288    WorkerInfo {
289        id: worker_id,
290        username: format!("{username_prefix}{worker_number}"),
291        port: port_base.saturating_add(index as u16),
292        run_id: format!("{run_id_prefix}-{worker_number}"),
293        started_at_epoch_secs: SystemTime::now()
294            .duration_since(UNIX_EPOCH)
295            .unwrap_or_default()
296            .as_secs(),
297    }
298}
299
300async fn run_worker(config: WorkerRunConfig) {
301    let WorkerRunConfig {
302        worker_config,
303        playlist_id,
304        chunk_size,
305        download_path,
306        redis_pool,
307        db_pool,
308        is_leader,
309        all_items,
310        mut shutdown,
311    } = config;
312
313    let queue_key = chunk_queue_key(&playlist_id, chunk_size);
314    let managers = match Managers::new(
315        worker_config.judge_score_levenshtein,
316        download_path.clone(),
317        worker_config.clone(),
318        db_pool.clone(),
319        redis_pool.clone(),
320    ) {
321        Ok(managers) => Arc::new(managers),
322        Err(err) => {
323            tracing::error!(?err, run_id = %worker_config.run_id, "Worker failed to start managers");
324            return;
325        }
326    };
327
328    loop {
329        if *shutdown.borrow() {
330            tracing::info!(run_id = %worker_config.run_id, "Worker exiting due to shutdown");
331            managers.shutdown();
332            return;
333        }
334
335        let chunk_json: Option<String> = {
336            let mut redis_con = match redis_pool.get() {
337                Ok(con) => con,
338                Err(err) => {
339                    tracing::error!(?err, "Worker failed to acquire Redis connection; exiting");
340                    return;
341                }
342            };
343            redis_con.lpop(&queue_key, None).ok()
344        };
345        let Some(chunk_json) = chunk_json else {
346            break;
347        };
348
349        let chunk_items: Vec<SearchItem> = match serde_json::from_str(&chunk_json) {
350            Ok(value) => value,
351            Err(err) => {
352                let truncated = chunk_json.chars().take(500).collect::<String>();
353                tracing::error!(
354                    %err,
355                    payload = %truncated,
356                    run_id = %worker_config.run_id,
357                    "Worker skipped malformed chunk payload",
358                );
359                continue;
360            }
361        };
362
363        let tracks = chunk_items
364            .into_iter()
365            .map(Track::Query)
366            .collect::<Vec<_>>();
367
368        tokio::select! {
369            _ = managers.run_chunk(tracks) => {}
370            _ = shutdown.changed() => {
371                if *shutdown.borrow() {
372                    tracing::info!(run_id = %worker_config.run_id, "Worker shutdown mid-cycle");
373                    managers.shutdown();
374                    return;
375                }
376            }
377        }
378    }
379
380    if is_leader && !*shutdown.borrow() {
381        let failed_ids: Vec<String> = {
382            let mut redis_con = match redis_pool.get() {
383                Ok(con) => con,
384                Err(_) => return,
385            };
386            redis_con.smembers("dl:failed").unwrap_or_default()
387        };
388        if !failed_ids.is_empty() {
389            let failed_items = all_items
390                .into_iter()
391                .filter(|item| failed_ids.contains(&item.track_id))
392                .map(Track::Query)
393                .collect::<Vec<_>>();
394            if !failed_items.is_empty() {
395                let _ = managers.run_chunk(failed_items).await;
396            }
397        }
398    }
399    managers.shutdown();
400}
401
#[cfg(test)]
mod tests {
    use super::{apply_playlist_range, build_chunks, chunk_queue_key, worker_info};
    use crate::internals::context::context_manager::Track;
    use crate::internals::search::search_manager::SearchItem;

    fn item(id: &str) -> SearchItem {
        SearchItem::new(
            id.to_string(),
            format!("track-{id}"),
            "album".to_string(),
            "artist".to_string(),
        )
    }

    /// Builds one `Track::Query` per id, in order.
    fn tracks(ids: &[&str]) -> Vec<Track> {
        ids.iter().map(|id| Track::Query(item(id))).collect()
    }

    #[test]
    fn queue_key_includes_playlist_and_chunk_size() {
        assert_eq!(
            chunk_queue_key("playlist", 15),
            "dl:chunk_queue:playlist:15"
        );
    }

    #[test]
    fn chunks_items_by_requested_size() {
        let items = vec![item("1"), item("2"), item("3"), item("4"), item("5")];
        let chunks = build_chunks(&items, 2);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].len(), 2);
        assert_eq!(chunks[1].len(), 2);
        assert_eq!(chunks[2].len(), 1);
    }

    #[test]
    fn chunk_size_zero_is_treated_as_one() {
        let items = vec![item("1"), item("2")];
        let chunks = build_chunks(&items, 0);
        assert_eq!(chunks.len(), 2);
    }

    #[test]
    fn empty_items_produce_no_chunks() {
        let items: Vec<SearchItem> = Vec::new();
        assert!(build_chunks(&items, 3).is_empty());
    }

    #[test]
    fn applies_playlist_range_when_valid_after_clamping() {
        let ranged = apply_playlist_range(tracks(&["1", "2", "3"]), Some((1, 10)));
        assert_eq!(ranged.len(), 2);
    }

    #[test]
    fn playlist_range_is_half_open() {
        let ranged = apply_playlist_range(tracks(&["1", "2", "3", "4"]), Some((1, 3)));
        assert_eq!(ranged.len(), 2);
    }

    #[test]
    fn no_playlist_range_keeps_every_track() {
        let ranged = apply_playlist_range(tracks(&["1", "2", "3"]), None);
        assert_eq!(ranged.len(), 3);
    }

    #[test]
    fn empty_or_inverted_range_falls_back_to_full_playlist() {
        let inverted = apply_playlist_range(tracks(&["1", "2", "3"]), Some((2, 1)));
        assert_eq!(inverted.len(), 3);
        let past_end = apply_playlist_range(tracks(&["1", "2", "3"]), Some((5, 9)));
        assert_eq!(past_end.len(), 3);
    }

    #[test]
    fn worker_info_uses_one_based_worker_suffixes() {
        let info = worker_info(7, "worker", 41000, "run", 2);
        assert_eq!(info.id, 7);
        assert_eq!(info.username, "worker3");
        assert_eq!(info.port, 41002);
        assert_eq!(info.run_id, "run-3");
    }

    #[test]
    fn first_worker_uses_port_base() {
        let info = worker_info(1, "worker", 41000, "run", 0);
        assert_eq!(info.port, 41000);
        assert_eq!(info.username, "worker1");
        assert_eq!(info.run_id, "run-1");
    }
}