// convert_invert/internals/context/context_manager.rs

1use crate::internals::database::manager::DatabaseManager;
2use futures_util::FutureExt;
3use serde::{Deserialize, Serialize};
4use soulseek_rs::{Client, ClientSettings};
5use std::{
6    any::Any,
7    collections::HashSet,
8    future::Future,
9    net::TcpListener,
10    panic::{AssertUnwindSafe, catch_unwind},
11    path::PathBuf,
12    sync::Arc,
13};
14use tokio::sync::{
15    Semaphore,
16    mpsc::{self, Sender},
17};
18use tokio::task::JoinSet;
19use tracing::instrument;
20
21use anyhow::Context;
22
23use crate::internals::{
24    database::db_pool_snapshot,
25    download::download_manager::DownloadManager,
26    judge::{judge_manager::JudgeManager, judges::levenshtein::Levenshtein},
27    query::query_manager::QueryManager,
28    search::search_manager::{
29        DownloadableFile, JudgeSubmission, SearchExitReason, SearchItem, SearchManager,
30    },
31    utils::config::config_manager::Config,
32};
33
/// Metadata for a successfully downloaded file.
///
/// Carried by the `Track::File` pipeline event.
#[derive(Debug, Serialize, Deserialize)]
pub struct DownloadedFile {
    /// Name of the downloaded file (NOTE(review): local path vs. remote name is
    /// decided by the producer — confirm in `DownloadManager`).
    pub filename: String,
    /// The search item this file was downloaded for.
    pub track: SearchItem,
}
40
/// Outcome reported by every task spawned through `spawn_managed`.
struct ManagedTaskResult {
    /// Failure description, or `None` when the task completed successfully.
    pub error: Option<String>,
    /// Static label naming the kind of work (e.g. `"search"`, `"download"`).
    pub label: &'static str,
}
45
46struct RunCycleShared<'a> {
47    managers: &'a Arc<Managers>,
48    sender: &'a Arc<Sender<Track>>,
49    state: &'a Arc<tokio::sync::RwLock<HashSet<SearchItem>>>,
50    search_semaphore: &'a Arc<Semaphore>,
51    download_semaphore: &'a Arc<Semaphore>,
52}
53
/// A request to retry a failed search or download.
///
/// Handled by `process_track`: the first request re-runs a search for the track,
/// and a request arriving with `retry_attempts >= 1` abandons the track instead.
#[derive(Debug)]
pub struct RetryRequest {
    /// The submission being retried (presumably the candidate whose download
    /// failed — confirm against the producer).
    pub request: JudgeSubmission,
    /// Number of retries already attempted for this submission.
    pub retry_attempts: u8,
    /// The file result that failed to download (not read by this module).
    pub failed_download_result: DownloadableFile,
}
61
/// The various stages and events in a track's lifecycle.
///
/// Every event that flows through the `run_chunk` work channel is one of these
/// variants. `process_track` persists each event with
/// `DatabaseManager::load_item_to_database` before dispatching on the variant.
#[derive(Debug)]
pub enum Track {
    /// A new search query to be performed.
    ///
    /// If the search finds no candidates, a [`Track::SearchRetry`] is queued.
    Query(SearchItem),
    /// A relaxed second-pass search query to be performed; no further retry follows.
    SearchRetry(SearchItem),
    /// A candidate submission found for a track; it is handed to the judge.
    Result(JudgeSubmission),
    /// A candidate that has been judged and is ready for download.
    Downloadable(JudgeSubmission),
    /// A file that has been successfully downloaded.
    File(DownloadedFile),
    /// A request to retry a failed operation; handled by searching the track again
    /// in the same mode as [`Track::SearchRetry`], at most once.
    Retry(RetryRequest),
    /// A track that has been rejected for a specific reason.
    Reject(RejectedTrack),
}
80
/// Metadata for a track that was rejected.
///
/// Built with [`RejectedTrack::new`] and read back with [`RejectedTrack::parts`].
#[derive(Debug, Serialize, Deserialize)]
pub struct RejectedTrack {
    /// The candidate submission that was rejected.
    track: JudgeSubmission,
    /// Why the candidate was rejected.
    reason: RejectReason,
}
87
/// Reasons why a track candidate might be rejected.
#[derive(Debug, Serialize, Deserialize)]
pub enum RejectReason {
    /// The track has already been downloaded successfully, or another candidate
    /// for the same track has already claimed the download in the current chunk.
    AlreadyDownloaded,
    /// The candidate's score was below the required threshold (presumably the
    /// payload is that score — confirm in the judge).
    LowScore(f32),
    /// The candidate was identified as non-music or invalid; the payload is a
    /// producer-supplied description (contents not visible here — confirm).
    NotMusic(String),
    /// The peer providing the file is banned (payload presumably identifies the
    /// peer — confirm with the producer).
    Banned(String),
    /// All search attempts were exhausted without finding a suitable candidate;
    /// issued when a retry request arrives after its single retry was spent.
    AbandonedAttemptingSearch,
}
102
103impl RejectedTrack {
104    pub fn new(track: JudgeSubmission, reason: RejectReason) -> Self {
105        Self { track, reason }
106    }
107
108    pub fn parts(&self) -> (&JudgeSubmission, &RejectReason) {
109        (&self.track, &self.reason)
110    }
111}
112
113/// Sends a `Track` event to the provided channel.
114pub async fn send(message: Track, chan: &Sender<Track>) -> anyhow::Result<()> {
115    chan.send(message).await.context("Send to channel")?;
116    Ok(())
117}
118
/// Pooled Redis connections, using the `r2d2` pool re-exported by diesel.
pub type RedisPool = diesel::r2d2::Pool<redis::Client>;
120
/// Tuning parameters for a worker's execution.
///
/// Loaded at the start of every `Managers::run_chunk` call via
/// `WorkerTuning::from_env`.
#[derive(Debug, Clone, Copy)]
pub struct WorkerTuning {
    /// Max in-flight search requests against Soulseek. Soulseek is rate-sensitive;
    /// raise carefully.
    pub search_concurrency: usize,
    /// Max in-flight downloads. Keep this below the host's network and file
    /// descriptor budget.
    pub download_concurrency: usize,
    /// Capacity of the bounded work-distribution channel that carries `Track`
    /// events between tasks and the run loop.
    pub queue_capacity: usize,
}
133
134impl WorkerTuning {
135    /// Loads tuning parameters from environment variables with sensible defaults.
136    pub fn from_env() -> Self {
137        Self {
138            search_concurrency: env_usize("SEARCH_CONCURRENCY", 4),
139            download_concurrency: env_usize("DOWNLOAD_CONCURRENCY", 7),
140            queue_capacity: env_usize("QUEUE_CAPACITY", 20000),
141        }
142    }
143}
144
/// Reads `key` from the environment as a `usize`, falling back to `default`.
///
/// Surrounding whitespace is ignored, so values such as `"8\n"` or `" 8"` (common
/// when variables come from `.env` files or shell here-docs) still parse. A
/// missing, non-UTF-8, or unparseable value yields `default`.
fn env_usize(key: &str, default: usize) -> usize {
    std::env::var(key)
        .ok()
        .and_then(|value| value.trim().parse().ok())
        .unwrap_or(default)
}
151
/// The central container for all service managers and shared state.
///
/// Built by [`Managers::new`]; `run_chunk` is called on an `Arc<Managers>` and
/// clones that `Arc` into every task it spawns.
pub struct Managers {
    /// The shared Soulseek client.
    pub client: Arc<Client>,
    /// Manager for file downloads.
    pub download_manager: DownloadManager,
    /// Manager for Soulseek searches.
    pub search_manager: SearchManager,
    /// Manager for Spotify playlist queries.
    pub query_manager: QueryManager,
    /// Manager for candidate judging.
    pub judge_manager: JudgeManager,
    /// The PostgreSQL connection pool.
    pub db_pool: crate::internals::database::DbPool,
    /// The Redis connection pool.
    pub redis_pool: RedisPool,
    /// Copied from `Config::search_empty_result_cutoff` and passed to every
    /// `SearchManager::run` call (presumably how many empty result rounds end a
    /// search — confirm in `SearchManager`).
    pub search_empty_result_cutoff: usize,
}
170
171impl Managers {
172    /// Creates a new `Managers` instance with the provided configuration.
173    pub fn new(
174        score: Option<f32>,
175        path: PathBuf,
176        config: Config,
177        db_pool: crate::internals::database::DbPool,
178        redis_pool: RedisPool,
179    ) -> anyhow::Result<Self> {
180        let listen_port = config.listen_port;
181        TcpListener::bind(format!("0.0.0.0:{listen_port}"))
182            .with_context(|| format!("listener bind preflight failed on port {listen_port}"))?;
183        let client_settings = ClientSettings {
184            username: config.user_name,
185            password: config.user_password,
186            listen_port,
187            ..Default::default()
188        };
189        let search_empty_result_cutoff = config.search_empty_result_cutoff;
190        let mut client = Client::with_settings(client_settings);
191        catch_unwind(AssertUnwindSafe(|| client.connect())).map_err(|payload| {
192            anyhow::anyhow!("listener bind panic: {}", panic_payload(payload))
193        })?;
194        client.login().context("Could not connect")?;
195        let client = Arc::new(client);
196        let download_manager = DownloadManager::new(client.clone(), path);
197        let search_manager = SearchManager::new(client.clone());
198        let judge_threshold =
199            score.unwrap_or(crate::internals::judge::judge_manager::JUDGE_THRESHOLD);
200        let lev_judge = Levenshtein::new(judge_threshold);
201        let judge_manager = JudgeManager::new(Box::new(lev_judge), judge_threshold);
202        let query_manager = QueryManager::new_with_timeout(
203            config.playlist_id,
204            config.client_id,
205            config.client_secret,
206            config.search_timeout_secs,
207        );
208        Ok(Managers {
209            client,
210            download_manager,
211            search_manager,
212            judge_manager,
213            query_manager,
214            db_pool,
215            redis_pool,
216            search_empty_result_cutoff,
217        })
218    }
219
220    /// Runs a single chunk of tracks through the search-judge-download pipeline.
221    ///
222    /// This method orchestrates the task lifecycle using a `JoinSet` and an internal
223    /// message channel. It ensures that all spawned tasks are completed before returning.
224    #[instrument(name = "run-chunk", skip(self, tracks))]
225    pub async fn run_chunk(
226        self: &Arc<Self>,
227        tracks: impl IntoIterator<Item = Track>,
228    ) -> anyhow::Result<()> {
229        let managers = Arc::clone(self);
230        let tuning = WorkerTuning::from_env();
231        let (sender, mut receiver) = mpsc::channel(tuning.queue_capacity);
232
233        let sender = Arc::new(sender);
234        let state = Arc::new(tokio::sync::RwLock::new(HashSet::new()));
235        let search_semaphore = Arc::new(Semaphore::new(tuning.search_concurrency));
236        let download_semaphore = Arc::new(Semaphore::new(tuning.download_concurrency));
237        let snapshot = db_pool_snapshot(&managers.db_pool);
238        tracing::info!(
239            search_permits = search_semaphore.available_permits(),
240            download_permits = download_semaphore.available_permits(),
241            db_pool_connections = snapshot.connections,
242            db_pool_idle_connections = snapshot.idle_connections,
243            db_pool_in_use_connections = snapshot.in_use_connections(),
244            "Started run chunk",
245        );
246        for track in tracks {
247            sender.send(track).await.context("injecting tracks")?;
248        }
249        let mut tasks = JoinSet::new();
250        let mut first_task_error: Option<String> = None;
251
252        while !receiver.is_empty() || !tasks.is_empty() {
253            tokio::select! {
254                maybe_track = receiver.recv(), if !receiver.is_empty() => {
255                    let Some(track) = maybe_track else {
256                        break;
257                    };
258                    tracing::debug!(?track, "Incoming package");
259                    if first_task_error.is_some() {
260                        tracing::debug!(?track, "Dropping queued work after task failure");
261                        continue;
262                    }
263                    let shared = RunCycleShared {
264                        managers: &managers,
265                        sender: &sender,
266                        state: &state,
267                        search_semaphore: &search_semaphore,
268                        download_semaphore: &download_semaphore,
269                    };
270                    process_track(track, shared, &mut tasks).await?;
271                }
272                maybe_result = tasks.join_next(), if !tasks.is_empty() => {
273                    let task_result = match maybe_result {
274                        Some(Ok(result)) => result,
275                        Some(Err(err)) => ManagedTaskResult {
276                            label: "unknown",
277                            error: Some(format!("managed task join failed: {err}")),
278                        },
279                        None => continue,
280                    };
281                    if let Some(error) = task_result.error {
282                        let message = format!("{} task failed: {error}", task_result.label);
283                        tracing::error!(task_label = task_result.label, %error, "Managed task failed");
284                        first_task_error.get_or_insert(message);
285                    }
286                }
287            }
288        }
289        drop(sender);
290        if let Some(error) = first_task_error {
291            anyhow::bail!(error);
292        }
293        let snapshot = db_pool_snapshot(&managers.db_pool);
294        tracing::info!(
295            db_pool_connections = snapshot.connections,
296            db_pool_idle_connections = snapshot.idle_connections,
297            db_pool_in_use_connections = snapshot.in_use_connections(),
298            "Run chunk finished",
299        );
300        Ok(())
301    }
302
303    pub fn shutdown(self: Arc<Self>) {
304        tracing::info!(
305            "Managers shutdown requested; soulseek-rs-lib 0.3.0 exposes no public disconnect/logout API",
306        );
307    }
308}
309
310async fn process_track(
311    track: Track,
312    shared: RunCycleShared<'_>,
313    tasks: &mut JoinSet<ManagedTaskResult>,
314) -> anyhow::Result<()> {
315    let RunCycleShared {
316        managers,
317        sender,
318        state,
319        search_semaphore,
320        download_semaphore,
321    } = shared;
322
323    {
324        let mut conn = managers.db_pool.get().map_err(|err| {
325            let snapshot = db_pool_snapshot(&managers.db_pool);
326            tracing::error!(
327                ?err,
328                db_pool_connections = snapshot.connections,
329                db_pool_idle_connections = snapshot.idle_connections,
330                db_pool_in_use_connections = snapshot.in_use_connections(),
331                "DB pool in process_track"
332            );
333            err
334        })?;
335        let mut database_manager = DatabaseManager::new(&mut conn);
336        database_manager
337            .load_item_to_database(&track)
338            .context("Load into database")?;
339    }
340    match track {
341        Track::Query(search_item) => {
342            let managers = Arc::clone(managers);
343            let sender = Arc::clone(sender);
344            let semaphore = search_semaphore.clone();
345            tracing::debug!(?search_item, "Scheduling search");
346            spawn_managed(tasks, "search", async move {
347                let outcome = managers
348                    .search_manager
349                    .run(
350                        search_item.clone(),
351                        managers.search_empty_result_cutoff(),
352                        managers.query_manager_search_timeout(),
353                        false,
354                        semaphore,
355                        Arc::clone(&sender),
356                    )
357                    .await
358                    .context("returning track")?;
359                if matches!(outcome.exit_reason, SearchExitReason::NoCandidatesFound) {
360                    send(Track::SearchRetry(search_item), &sender)
361                        .await
362                        .context("queue relaxed search")?;
363                }
364                Ok(())
365            });
366        }
367        Track::SearchRetry(search_item) => {
368            let managers = Arc::clone(managers);
369            let sender = Arc::clone(sender);
370            let semaphore = search_semaphore.clone();
371            tracing::debug!(?search_item, "Scheduling relaxed search retry");
372            spawn_managed(tasks, "search_retry", async move {
373                let outcome = managers
374                    .search_manager
375                    .run(
376                        search_item.clone(),
377                        managers.search_empty_result_cutoff(),
378                        managers.query_manager_search_timeout(),
379                        true,
380                        semaphore,
381                        sender,
382                    )
383                    .await
384                    .context("returning relaxed track")?;
385                if matches!(outcome.exit_reason, SearchExitReason::NoCandidatesFound) {
386                    tracing::info!(?search_item, "Relaxed search returned no candidates");
387                }
388                Ok(())
389            });
390        }
391        Track::Result(judge_submission) => {
392            let managers = Arc::clone(managers);
393            let sender = Arc::clone(sender);
394            spawn_managed(tasks, "judge", async move {
395                tracing::debug!(?judge_submission, "Scheduling judge");
396                managers
397                    .judge_manager
398                    .run(judge_submission, sender)
399                    .await
400                    .context("running judge")
401            });
402        }
403        Track::Downloadable(judge_submission) => {
404            let is_downloaded = {
405                let mut conn = managers.db_pool.get().map_err(|err| {
406                    let snapshot = db_pool_snapshot(&managers.db_pool);
407                    tracing::error!(
408                        ?err,
409                        db_pool_connections = snapshot.connections,
410                        db_pool_idle_connections = snapshot.idle_connections,
411                        db_pool_in_use_connections = snapshot.in_use_connections(),
412                        "DB pool in downloadable check"
413                    );
414                    err
415                })?;
416                let mut database_manager = DatabaseManager::new(&mut conn);
417                database_manager
418                    .is_search_item_downloaded(&judge_submission.track)
419                    .context("Check existing downloaded track")?
420            };
421            if is_downloaded {
422                let reject =
423                    RejectedTrack::new(judge_submission.clone(), RejectReason::AlreadyDownloaded);
424                send(Track::Reject(reject), sender)
425                    .await
426                    .context("sending already downloaded rejection")?;
427                return Ok(());
428            }
429            let semaphore = download_semaphore.clone();
430            let managers = Arc::clone(managers);
431            tracing::debug!(?judge_submission, "Scheduling download");
432            let judge_sub = judge_submission.clone();
433            let sender = Arc::clone(sender);
434
435            let mut state_guard = state.write().await;
436            if state_guard.insert(judge_submission.track.clone()) {
437                drop(state_guard);
438
439                let managers_clone = Arc::clone(&managers);
440                spawn_managed(tasks, "download", async move {
441                    managers_clone
442                        .download_manager
443                        .run(
444                            judge_sub,
445                            semaphore,
446                            sender,
447                            managers_clone.redis_pool.clone(),
448                            managers_clone.db_pool.clone(),
449                        )
450                        .await
451                        .context("Downloading")?;
452                    Ok(())
453                });
454            } else {
455                drop(state_guard);
456                let reject =
457                    RejectedTrack::new(judge_submission.clone(), RejectReason::AlreadyDownloaded);
458                send(Track::Reject(reject), &sender)
459                    .await
460                    .context("sending rejected_tracks")?;
461            }
462        }
463        Track::File(downloaded_file) => {
464            tracing::info!(?downloaded_file, "Downloaded file");
465        }
466        Track::Retry(mut retry_request) => {
467            state.write().await.remove(&retry_request.request.track);
468            if retry_request.retry_attempts >= 1 {
469                let reject = RejectedTrack::new(
470                    retry_request.request,
471                    RejectReason::AbandonedAttemptingSearch,
472                );
473                send(Track::Reject(reject), sender)
474                    .await
475                    .context("rejecting")?;
476                return Ok(());
477            }
478            retry_request.retry_attempts += 1;
479            let managers = Arc::clone(managers);
480            let semaphore = search_semaphore.clone();
481            let sender = Arc::clone(sender);
482            tracing::info!(?retry_request.request, "Retry requested");
483            let search_item = retry_request.request.clone();
484            spawn_managed(tasks, "retry_search", async move {
485                managers
486                    .search_manager
487                    .run(
488                        search_item.track,
489                        managers.search_empty_result_cutoff(),
490                        managers.query_manager_search_timeout(),
491                        true,
492                        semaphore,
493                        sender,
494                    )
495                    .await
496                    .context("returning track")?;
497                Ok(())
498            });
499            tracing::debug!(?retry_request, "Retry queued")
500        }
501        Track::Reject(rejected_track) => {
502            state.write().await.remove(&rejected_track.track.track);
503        }
504    }
505    Ok(())
506}
507
/// Renders a panic payload captured by `catch_unwind` as readable text.
///
/// `panic!` payloads are normally a `&'static str` (literal message) or a
/// `String` (formatted message); any other payload type is reported generically.
fn panic_payload(payload: Box<dyn Any + Send>) -> String {
    let payload = match payload.downcast::<String>() {
        Ok(owned) => return *owned,
        Err(other) => other,
    };
    payload
        .downcast_ref::<&str>()
        .map_or_else(|| "unknown panic payload".to_string(), |text| (*text).to_string())
}
517
518impl Managers {
519    fn query_manager_search_timeout(&self) -> u8 {
520        self.query_manager.search_timeout_secs
521    }
522
523    fn search_empty_result_cutoff(&self) -> usize {
524        self.search_empty_result_cutoff
525    }
526}
527
528fn spawn_managed<F>(tasks: &mut JoinSet<ManagedTaskResult>, label: &'static str, future: F)
529where
530    F: Future<Output = anyhow::Result<()>> + Send + 'static,
531{
532    tasks.spawn(async move {
533        let result = AssertUnwindSafe(future).catch_unwind().await;
534        let error = match result {
535            Ok(Ok(())) => None,
536            Ok(Err(err)) => Some(format!("{err:?}")),
537            Err(_) => Some("task panicked".to_string()),
538        };
539        ManagedTaskResult { label, error }
540    });
541}
542
#[cfg(test)]
mod tests {
    use super::{ManagedTaskResult, spawn_managed};
    use tokio::task::JoinSet;

    /// Awaits the next finished managed task; panics if none was pending or the
    /// task could not be joined.
    async fn next_result(tasks: &mut JoinSet<ManagedTaskResult>) -> ManagedTaskResult {
        let joined = tasks.join_next().await.expect("task result");
        joined.expect("task joined")
    }

    #[tokio::test]
    async fn managed_task_reports_success_by_label() {
        let mut tasks = JoinSet::new();
        spawn_managed(&mut tasks, "search", async { Ok(()) });

        let ManagedTaskResult { label, error } = next_result(&mut tasks).await;
        assert_eq!(label, "search");
        assert!(error.is_none());
    }

    #[tokio::test]
    async fn managed_task_captures_errors() {
        let mut tasks = JoinSet::new();
        spawn_managed(&mut tasks, "download", async {
            anyhow::bail!("download failed")
        });

        let ManagedTaskResult { label, error } = next_result(&mut tasks).await;
        assert_eq!(label, "download");
        let message = error.expect("managed task error");
        assert!(message.contains("download failed"));
    }

    #[tokio::test]
    async fn managed_task_converts_panics_to_errors() {
        let mut tasks = JoinSet::new();
        spawn_managed(&mut tasks, "judge", async {
            panic!("judge panic");
            // Never reached; it only pins the block's output to `anyhow::Result<()>`.
            #[allow(unreachable_code)]
            Ok(())
        });

        let ManagedTaskResult { label, error } = next_result(&mut tasks).await;
        assert_eq!(label, "judge");
        assert_eq!(error.as_deref(), Some("task panicked"));
    }
}