1use crate::internals::database::manager::DatabaseManager;
2use futures_util::FutureExt;
3use serde::{Deserialize, Serialize};
4use soulseek_rs::{Client, ClientSettings};
5use std::{
6 any::Any,
7 collections::HashSet,
8 future::Future,
9 net::TcpListener,
10 panic::{AssertUnwindSafe, catch_unwind},
11 path::PathBuf,
12 sync::Arc,
13};
14use tokio::sync::{
15 Semaphore,
16 mpsc::{self, Sender},
17};
18use tokio::task::JoinSet;
19use tracing::instrument;
20
21use anyhow::Context;
22
23use crate::internals::{
24 database::db_pool_snapshot,
25 download::download_manager::DownloadManager,
26 judge::{judge_manager::JudgeManager, judges::levenshtein::Levenshtein},
27 query::query_manager::QueryManager,
28 search::search_manager::{
29 DownloadableFile, JudgeSubmission, SearchExitReason, SearchItem, SearchManager,
30 },
31 utils::config::config_manager::Config,
32};
33
34#[derive(Debug, Serialize, Deserialize)]
36pub struct DownloadedFile {
37 pub filename: String,
38 pub track: SearchItem,
39}
40
/// What a task started through `spawn_managed` reports back when it finishes.
struct ManagedTaskResult {
    /// Static name of the kind of work the task performed (e.g. `"search"`, `"download"`).
    pub label: &'static str,
    /// `None` when the task succeeded; otherwise a rendered description of the failure.
    pub error: Option<String>,
}
45
46struct RunCycleShared<'a> {
47 managers: &'a Arc<Managers>,
48 sender: &'a Arc<Sender<Track>>,
49 state: &'a Arc<tokio::sync::RwLock<HashSet<SearchItem>>>,
50 search_semaphore: &'a Arc<Semaphore>,
51 download_semaphore: &'a Arc<Semaphore>,
52}
53
54#[derive(Debug)]
56pub struct RetryRequest {
57 pub request: JudgeSubmission,
58 pub retry_attempts: u8,
59 pub failed_download_result: DownloadableFile,
60}
61
62#[derive(Debug)]
64pub enum Track {
65 Query(SearchItem),
67 SearchRetry(SearchItem),
69 Result(JudgeSubmission),
71 Downloadable(JudgeSubmission),
73 File(DownloadedFile),
75 Retry(RetryRequest),
77 Reject(RejectedTrack),
79}
80
81#[derive(Debug, Serialize, Deserialize)]
83pub struct RejectedTrack {
84 track: JudgeSubmission,
85 reason: RejectReason,
86}
87
88#[derive(Debug, Serialize, Deserialize)]
90pub enum RejectReason {
91 AlreadyDownloaded,
93 LowScore(f32),
95 NotMusic(String),
97 Banned(String),
99 AbandonedAttemptingSearch,
101}
102
103impl RejectedTrack {
104 pub fn new(track: JudgeSubmission, reason: RejectReason) -> Self {
105 Self { track, reason }
106 }
107
108 pub fn parts(&self) -> (&JudgeSubmission, &RejectReason) {
109 (&self.track, &self.reason)
110 }
111}
112
113pub async fn send(message: Track, chan: &Sender<Track>) -> anyhow::Result<()> {
115 chan.send(message).await.context("Send to channel")?;
116 Ok(())
117}
118
119pub type RedisPool = diesel::r2d2::Pool<redis::Client>;
120
/// Concurrency and queue sizing knobs for one `run_chunk` invocation.
#[derive(Debug, Clone, Copy)]
pub struct WorkerTuning {
    /// Number of permits in the semaphore handed to search runs.
    pub search_concurrency: usize,
    /// Number of permits in the semaphore handed to download runs.
    pub download_concurrency: usize,
    /// Capacity of the bounded message queue that feeds the run loop.
    pub queue_capacity: usize,
}
133
134impl WorkerTuning {
135 pub fn from_env() -> Self {
137 Self {
138 search_concurrency: env_usize("SEARCH_CONCURRENCY", 4),
139 download_concurrency: env_usize("DOWNLOAD_CONCURRENCY", 7),
140 queue_capacity: env_usize("QUEUE_CAPACITY", 20000),
141 }
142 }
143}
144
/// Reads environment variable `key` as a `usize`.
///
/// Surrounding whitespace (for example a trailing newline picked up from an env file) is
/// ignored. Returns `default` when the variable is unset, is not valid Unicode, or does not
/// parse as `usize`.
fn env_usize(key: &str, default: usize) -> usize {
    std::env::var(key)
        .ok()
        .and_then(|value| value.trim().parse().ok())
        .unwrap_or(default)
}
151
152pub struct Managers {
154 pub client: Arc<Client>,
156 pub download_manager: DownloadManager,
158 pub search_manager: SearchManager,
160 pub query_manager: QueryManager,
162 pub judge_manager: JudgeManager,
164 pub db_pool: crate::internals::database::DbPool,
166 pub redis_pool: RedisPool,
168 pub search_empty_result_cutoff: usize,
169}
170
171impl Managers {
172 pub fn new(
174 score: Option<f32>,
175 path: PathBuf,
176 config: Config,
177 db_pool: crate::internals::database::DbPool,
178 redis_pool: RedisPool,
179 ) -> anyhow::Result<Self> {
180 let listen_port = config.listen_port;
181 TcpListener::bind(format!("0.0.0.0:{listen_port}"))
182 .with_context(|| format!("listener bind preflight failed on port {listen_port}"))?;
183 let client_settings = ClientSettings {
184 username: config.user_name,
185 password: config.user_password,
186 listen_port,
187 ..Default::default()
188 };
189 let search_empty_result_cutoff = config.search_empty_result_cutoff;
190 let mut client = Client::with_settings(client_settings);
191 catch_unwind(AssertUnwindSafe(|| client.connect())).map_err(|payload| {
192 anyhow::anyhow!("listener bind panic: {}", panic_payload(payload))
193 })?;
194 client.login().context("Could not connect")?;
195 let client = Arc::new(client);
196 let download_manager = DownloadManager::new(client.clone(), path);
197 let search_manager = SearchManager::new(client.clone());
198 let judge_threshold =
199 score.unwrap_or(crate::internals::judge::judge_manager::JUDGE_THRESHOLD);
200 let lev_judge = Levenshtein::new(judge_threshold);
201 let judge_manager = JudgeManager::new(Box::new(lev_judge), judge_threshold);
202 let query_manager = QueryManager::new_with_timeout(
203 config.playlist_id,
204 config.client_id,
205 config.client_secret,
206 config.search_timeout_secs,
207 );
208 Ok(Managers {
209 client,
210 download_manager,
211 search_manager,
212 judge_manager,
213 query_manager,
214 db_pool,
215 redis_pool,
216 search_empty_result_cutoff,
217 })
218 }
219
220 #[instrument(name = "run-chunk", skip(self, tracks))]
225 pub async fn run_chunk(
226 self: &Arc<Self>,
227 tracks: impl IntoIterator<Item = Track>,
228 ) -> anyhow::Result<()> {
229 let managers = Arc::clone(self);
230 let tuning = WorkerTuning::from_env();
231 let (sender, mut receiver) = mpsc::channel(tuning.queue_capacity);
232
233 let sender = Arc::new(sender);
234 let state = Arc::new(tokio::sync::RwLock::new(HashSet::new()));
235 let search_semaphore = Arc::new(Semaphore::new(tuning.search_concurrency));
236 let download_semaphore = Arc::new(Semaphore::new(tuning.download_concurrency));
237 let snapshot = db_pool_snapshot(&managers.db_pool);
238 tracing::info!(
239 search_permits = search_semaphore.available_permits(),
240 download_permits = download_semaphore.available_permits(),
241 db_pool_connections = snapshot.connections,
242 db_pool_idle_connections = snapshot.idle_connections,
243 db_pool_in_use_connections = snapshot.in_use_connections(),
244 "Started run chunk",
245 );
246 for track in tracks {
247 sender.send(track).await.context("injecting tracks")?;
248 }
249 let mut tasks = JoinSet::new();
250 let mut first_task_error: Option<String> = None;
251
252 while !receiver.is_empty() || !tasks.is_empty() {
253 tokio::select! {
254 maybe_track = receiver.recv(), if !receiver.is_empty() => {
255 let Some(track) = maybe_track else {
256 break;
257 };
258 tracing::debug!(?track, "Incoming package");
259 if first_task_error.is_some() {
260 tracing::debug!(?track, "Dropping queued work after task failure");
261 continue;
262 }
263 let shared = RunCycleShared {
264 managers: &managers,
265 sender: &sender,
266 state: &state,
267 search_semaphore: &search_semaphore,
268 download_semaphore: &download_semaphore,
269 };
270 process_track(track, shared, &mut tasks).await?;
271 }
272 maybe_result = tasks.join_next(), if !tasks.is_empty() => {
273 let task_result = match maybe_result {
274 Some(Ok(result)) => result,
275 Some(Err(err)) => ManagedTaskResult {
276 label: "unknown",
277 error: Some(format!("managed task join failed: {err}")),
278 },
279 None => continue,
280 };
281 if let Some(error) = task_result.error {
282 let message = format!("{} task failed: {error}", task_result.label);
283 tracing::error!(task_label = task_result.label, %error, "Managed task failed");
284 first_task_error.get_or_insert(message);
285 }
286 }
287 }
288 }
289 drop(sender);
290 if let Some(error) = first_task_error {
291 anyhow::bail!(error);
292 }
293 let snapshot = db_pool_snapshot(&managers.db_pool);
294 tracing::info!(
295 db_pool_connections = snapshot.connections,
296 db_pool_idle_connections = snapshot.idle_connections,
297 db_pool_in_use_connections = snapshot.in_use_connections(),
298 "Run chunk finished",
299 );
300 Ok(())
301 }
302
303 pub fn shutdown(self: Arc<Self>) {
304 tracing::info!(
305 "Managers shutdown requested; soulseek-rs-lib 0.3.0 exposes no public disconnect/logout API",
306 );
307 }
308}
309
310async fn process_track(
311 track: Track,
312 shared: RunCycleShared<'_>,
313 tasks: &mut JoinSet<ManagedTaskResult>,
314) -> anyhow::Result<()> {
315 let RunCycleShared {
316 managers,
317 sender,
318 state,
319 search_semaphore,
320 download_semaphore,
321 } = shared;
322
323 {
324 let mut conn = managers.db_pool.get().map_err(|err| {
325 let snapshot = db_pool_snapshot(&managers.db_pool);
326 tracing::error!(
327 ?err,
328 db_pool_connections = snapshot.connections,
329 db_pool_idle_connections = snapshot.idle_connections,
330 db_pool_in_use_connections = snapshot.in_use_connections(),
331 "DB pool in process_track"
332 );
333 err
334 })?;
335 let mut database_manager = DatabaseManager::new(&mut conn);
336 database_manager
337 .load_item_to_database(&track)
338 .context("Load into database")?;
339 }
340 match track {
341 Track::Query(search_item) => {
342 let managers = Arc::clone(managers);
343 let sender = Arc::clone(sender);
344 let semaphore = search_semaphore.clone();
345 tracing::debug!(?search_item, "Scheduling search");
346 spawn_managed(tasks, "search", async move {
347 let outcome = managers
348 .search_manager
349 .run(
350 search_item.clone(),
351 managers.search_empty_result_cutoff(),
352 managers.query_manager_search_timeout(),
353 false,
354 semaphore,
355 Arc::clone(&sender),
356 )
357 .await
358 .context("returning track")?;
359 if matches!(outcome.exit_reason, SearchExitReason::NoCandidatesFound) {
360 send(Track::SearchRetry(search_item), &sender)
361 .await
362 .context("queue relaxed search")?;
363 }
364 Ok(())
365 });
366 }
367 Track::SearchRetry(search_item) => {
368 let managers = Arc::clone(managers);
369 let sender = Arc::clone(sender);
370 let semaphore = search_semaphore.clone();
371 tracing::debug!(?search_item, "Scheduling relaxed search retry");
372 spawn_managed(tasks, "search_retry", async move {
373 let outcome = managers
374 .search_manager
375 .run(
376 search_item.clone(),
377 managers.search_empty_result_cutoff(),
378 managers.query_manager_search_timeout(),
379 true,
380 semaphore,
381 sender,
382 )
383 .await
384 .context("returning relaxed track")?;
385 if matches!(outcome.exit_reason, SearchExitReason::NoCandidatesFound) {
386 tracing::info!(?search_item, "Relaxed search returned no candidates");
387 }
388 Ok(())
389 });
390 }
391 Track::Result(judge_submission) => {
392 let managers = Arc::clone(managers);
393 let sender = Arc::clone(sender);
394 spawn_managed(tasks, "judge", async move {
395 tracing::debug!(?judge_submission, "Scheduling judge");
396 managers
397 .judge_manager
398 .run(judge_submission, sender)
399 .await
400 .context("running judge")
401 });
402 }
403 Track::Downloadable(judge_submission) => {
404 let is_downloaded = {
405 let mut conn = managers.db_pool.get().map_err(|err| {
406 let snapshot = db_pool_snapshot(&managers.db_pool);
407 tracing::error!(
408 ?err,
409 db_pool_connections = snapshot.connections,
410 db_pool_idle_connections = snapshot.idle_connections,
411 db_pool_in_use_connections = snapshot.in_use_connections(),
412 "DB pool in downloadable check"
413 );
414 err
415 })?;
416 let mut database_manager = DatabaseManager::new(&mut conn);
417 database_manager
418 .is_search_item_downloaded(&judge_submission.track)
419 .context("Check existing downloaded track")?
420 };
421 if is_downloaded {
422 let reject =
423 RejectedTrack::new(judge_submission.clone(), RejectReason::AlreadyDownloaded);
424 send(Track::Reject(reject), sender)
425 .await
426 .context("sending already downloaded rejection")?;
427 return Ok(());
428 }
429 let semaphore = download_semaphore.clone();
430 let managers = Arc::clone(managers);
431 tracing::debug!(?judge_submission, "Scheduling download");
432 let judge_sub = judge_submission.clone();
433 let sender = Arc::clone(sender);
434
435 let mut state_guard = state.write().await;
436 if state_guard.insert(judge_submission.track.clone()) {
437 drop(state_guard);
438
439 let managers_clone = Arc::clone(&managers);
440 spawn_managed(tasks, "download", async move {
441 managers_clone
442 .download_manager
443 .run(
444 judge_sub,
445 semaphore,
446 sender,
447 managers_clone.redis_pool.clone(),
448 managers_clone.db_pool.clone(),
449 )
450 .await
451 .context("Downloading")?;
452 Ok(())
453 });
454 } else {
455 drop(state_guard);
456 let reject =
457 RejectedTrack::new(judge_submission.clone(), RejectReason::AlreadyDownloaded);
458 send(Track::Reject(reject), &sender)
459 .await
460 .context("sending rejected_tracks")?;
461 }
462 }
463 Track::File(downloaded_file) => {
464 tracing::info!(?downloaded_file, "Downloaded file");
465 }
466 Track::Retry(mut retry_request) => {
467 state.write().await.remove(&retry_request.request.track);
468 if retry_request.retry_attempts >= 1 {
469 let reject = RejectedTrack::new(
470 retry_request.request,
471 RejectReason::AbandonedAttemptingSearch,
472 );
473 send(Track::Reject(reject), sender)
474 .await
475 .context("rejecting")?;
476 return Ok(());
477 }
478 retry_request.retry_attempts += 1;
479 let managers = Arc::clone(managers);
480 let semaphore = search_semaphore.clone();
481 let sender = Arc::clone(sender);
482 tracing::info!(?retry_request.request, "Retry requested");
483 let search_item = retry_request.request.clone();
484 spawn_managed(tasks, "retry_search", async move {
485 managers
486 .search_manager
487 .run(
488 search_item.track,
489 managers.search_empty_result_cutoff(),
490 managers.query_manager_search_timeout(),
491 true,
492 semaphore,
493 sender,
494 )
495 .await
496 .context("returning track")?;
497 Ok(())
498 });
499 tracing::debug!(?retry_request, "Retry queued")
500 }
501 Track::Reject(rejected_track) => {
502 state.write().await.remove(&rejected_track.track.track);
503 }
504 }
505 Ok(())
506}
507
/// Renders a caught panic payload as text.
///
/// Payloads of type `&str` and `String` are returned verbatim; anything else yields
/// `"unknown panic payload"`.
fn panic_payload(payload: Box<dyn Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_string())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
517
518impl Managers {
519 fn query_manager_search_timeout(&self) -> u8 {
520 self.query_manager.search_timeout_secs
521 }
522
523 fn search_empty_result_cutoff(&self) -> usize {
524 self.search_empty_result_cutoff
525 }
526}
527
528fn spawn_managed<F>(tasks: &mut JoinSet<ManagedTaskResult>, label: &'static str, future: F)
529where
530 F: Future<Output = anyhow::Result<()>> + Send + 'static,
531{
532 tasks.spawn(async move {
533 let result = AssertUnwindSafe(future).catch_unwind().await;
534 let error = match result {
535 Ok(Ok(())) => None,
536 Ok(Err(err)) => Some(format!("{err:?}")),
537 Err(_) => Some("task panicked".to_string()),
538 };
539 ManagedTaskResult { label, error }
540 });
541}
542
#[cfg(test)]
mod tests {
    use super::{ManagedTaskResult, spawn_managed};
    use tokio::task::JoinSet;

    /// Waits for the next finished task, failing the test if the set is empty or the task
    /// could not be joined.
    async fn next_result(tasks: &mut JoinSet<ManagedTaskResult>) -> ManagedTaskResult {
        let joined = tasks.join_next().await.expect("task result");
        joined.expect("task joined")
    }

    /// A task that returns `Ok(())` reports its label and no error.
    #[tokio::test]
    async fn managed_task_reports_success_by_label() {
        let mut tasks = JoinSet::new();
        spawn_managed(&mut tasks, "search", async { Ok(()) });

        let ManagedTaskResult { label, error } = next_result(&mut tasks).await;

        assert_eq!(label, "search");
        assert_eq!(error, None);
    }

    /// A task that returns an error reports a message containing that error's text.
    #[tokio::test]
    async fn managed_task_captures_errors() {
        let mut tasks = JoinSet::new();
        spawn_managed(&mut tasks, "download", async {
            anyhow::bail!("download failed")
        });

        let ManagedTaskResult { label, error } = next_result(&mut tasks).await;

        assert_eq!(label, "download");
        let message = error.expect("managed task error");
        assert!(message.contains("download failed"));
    }

    /// A task that panics is reported as an ordinary failure with a fixed message.
    #[tokio::test]
    async fn managed_task_converts_panics_to_errors() {
        let mut tasks = JoinSet::new();
        spawn_managed(&mut tasks, "judge", async {
            panic!("judge panic");
            // Unreachable, but it gives the async block its `anyhow::Result<()>` output type.
            #[allow(unreachable_code)]
            Ok(())
        });

        let ManagedTaskResult { label, error } = next_result(&mut tasks).await;

        assert_eq!(label, "judge");
        assert_eq!(error, Some("task panicked".to_string()));
    }
}