convert-invert 0.1.0

Orchestrate spotify playlist downloads using soulseek-rs
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
use crate::internals::database::manager::DatabaseManager;
use futures_util::FutureExt;
use serde::{Deserialize, Serialize};
use soulseek_rs::{Client, ClientSettings};
use std::{
    any::Any,
    collections::HashSet,
    future::Future,
    net::TcpListener,
    panic::{AssertUnwindSafe, catch_unwind},
    path::PathBuf,
    sync::Arc,
};
use tokio::sync::{
    Semaphore,
    mpsc::{self, Sender},
};
use tokio::task::JoinSet;
use tracing::instrument;

use anyhow::Context;

use crate::internals::{
    database::db_pool_snapshot,
    download::download_manager::DownloadManager,
    judge::{judge_manager::JudgeManager, judges::levenshtein::Levenshtein},
    query::query_manager::QueryManager,
    search::search_manager::{
        DownloadableFile, JudgeSubmission, SearchExitReason, SearchItem, SearchManager,
    },
    utils::config::config_manager::Config,
};

/// Metadata for a successfully downloaded file.
///
/// This is the payload of [`Track::File`]. It derives `Serialize` and
/// `Deserialize` in addition to `Debug`.
#[derive(Debug, Serialize, Deserialize)]
pub struct DownloadedFile {
    /// Name of the downloaded file.
    // NOTE(review): whether this is a bare file name or a full path is not
    // visible in this file — confirm against the download manager.
    pub filename: String,
    /// The search item associated with this file (presumably the item whose
    /// search produced the download — confirm against the download manager).
    pub track: SearchItem,
}

/// Outcome reported by a task that was spawned through `spawn_managed`.
///
/// `run_chunk` also builds one of these (with the label `"unknown"`) when
/// joining a task fails.
struct ManagedTaskResult {
    /// Static label naming the kind of task, e.g. `"search"` or `"download"`.
    pub label: &'static str,
    /// `None` when the task finished successfully; otherwise a textual
    /// description of the error (or a note that the task panicked).
    pub error: Option<String>,
}

/// Borrowed handles to the state shared by one `run_chunk` invocation.
///
/// `run_chunk` builds one of these per incoming message so that
/// `process_track` receives everything it needs in a single argument.
struct RunCycleShared<'a> {
    /// All service managers; cloned (as an `Arc`) into spawned tasks.
    managers: &'a Arc<Managers>,
    /// Sending half of the chunk's work channel; cloned into spawned tasks so
    /// they can feed follow-up events back into the loop.
    sender: &'a Arc<Sender<Track>>,
    /// Search items for which a download has been scheduled during this chunk.
    /// `process_track` inserts on scheduling a download and removes entries
    /// again while handling retries and rejections.
    state: &'a Arc<tokio::sync::RwLock<HashSet<SearchItem>>>,
    /// Semaphore handed to every search task.
    search_semaphore: &'a Arc<Semaphore>,
    /// Semaphore handed to every download task.
    download_semaphore: &'a Arc<Semaphore>,
}

/// A request to retry a failed search or download.
///
/// This is the payload of [`Track::Retry`].
#[derive(Debug)]
pub struct RetryRequest {
    /// The submission whose track should be searched for again.
    pub request: JudgeSubmission,
    /// Number of retries already performed. `process_track` abandons the
    /// track once this is at least 1.
    pub retry_attempts: u8,
    /// The download candidate that failed.
    // NOTE(review): this field is not read anywhere in this file — presumably
    // kept for diagnostics; confirm against the producer.
    pub failed_download_result: DownloadableFile,
}

/// The various stages and events in a track's lifecycle.
///
/// Values of this type travel through the work channel created by
/// `Managers::run_chunk` and are dispatched one at a time by `process_track`.
#[derive(Debug)]
pub enum Track {
    /// A new search query to be performed.
    ///
    /// Handled by spawning a strict search; if that search ends with no
    /// candidates a [`Track::SearchRetry`] is queued for the same item.
    Query(SearchItem),
    /// A relaxed second-pass search query to be performed.
    SearchRetry(SearchItem),
    /// A candidate submission found for a track.
    ///
    /// Handled by spawning a judge task.
    Result(JudgeSubmission),
    /// A candidate that has been judged and is ready for download.
    ///
    /// Handled by spawning a download task unless the track is already
    /// downloaded or already being downloaded in the current chunk.
    Downloadable(JudgeSubmission),
    /// A file that has been successfully downloaded.
    ///
    /// Handling only logs the file.
    File(DownloadedFile),
    /// A request to retry a failed operation.
    Retry(RetryRequest),
    /// A track that has been rejected for a specific reason.
    Reject(RejectedTrack),
}

/// Metadata for a track that was rejected.
///
/// Both fields are private; construct with [`RejectedTrack::new`] and read
/// with [`RejectedTrack::parts`].
#[derive(Debug, Serialize, Deserialize)]
pub struct RejectedTrack {
    /// The candidate submission that was rejected.
    track: JudgeSubmission,
    /// Why the submission was rejected.
    reason: RejectReason,
}

/// Reasons why a track candidate might be rejected.
#[derive(Debug, Serialize, Deserialize)]
pub enum RejectReason {
    /// The track has already been downloaded successfully.
    ///
    /// `process_track` also uses this reason when a download for the same
    /// track has already been scheduled during the current chunk.
    AlreadyDownloaded,
    /// The candidate's score was below the required threshold.
    // NOTE(review): the payload is presumably the candidate's score — confirm
    // against the judge.
    LowScore(f32),
    /// The candidate was identified as non-music or invalid.
    // NOTE(review): the meaning of the string payload is not visible here.
    NotMusic(String),
    /// The peer providing the file is banned.
    // NOTE(review): the string payload is presumably the peer's name — confirm.
    Banned(String),
    /// All search attempts were exhausted without finding a suitable candidate.
    ///
    /// Emitted by `process_track` when a retry is requested for a track that
    /// has already been retried.
    AbandonedAttemptingSearch,
}

impl RejectedTrack {
    pub fn new(track: JudgeSubmission, reason: RejectReason) -> Self {
        Self { track, reason }
    }

    pub fn parts(&self) -> (&JudgeSubmission, &RejectReason) {
        (&self.track, &self.reason)
    }
}

/// Sends a `Track` event to the provided channel.
///
/// Waits until the bounded channel accepts the message.
///
/// # Errors
///
/// Returns an error (with the context `"Send to channel"`) when the channel
/// rejects the message, i.e. when the receiving half is gone.
pub async fn send(message: Track, chan: &Sender<Track>) -> anyhow::Result<()> {
    let delivery = chan.send(message).await;
    delivery.context("Send to channel")
}

/// Connection pool type for Redis: the `Pool` exposed by `diesel::r2d2`,
/// parameterised over `redis::Client`.
pub type RedisPool = diesel::r2d2::Pool<redis::Client>;

/// Tuning parameters for a worker's execution.
///
/// Values produced by [`WorkerTuning::from_env`] are always at least 1, so
/// they can be passed directly to a semaphore or a bounded channel.
#[derive(Debug, Clone, Copy)]
pub struct WorkerTuning {
    /// Max in-flight search requests against Soulseek. Soulseek is rate-sensitive;
    /// raise carefully.
    pub search_concurrency: usize,
    /// Max in-flight downloads. Keep this below the host's network and file
    /// descriptor budget.
    pub download_concurrency: usize,
    /// Capacity of the work-distribution channel.
    pub queue_capacity: usize,
}

impl WorkerTuning {
    /// Loads tuning parameters from environment variables with sensible defaults.
    ///
    /// | Variable               | Default |
    /// |------------------------|---------|
    /// | `SEARCH_CONCURRENCY`   | 4       |
    /// | `DOWNLOAD_CONCURRENCY` | 7       |
    /// | `QUEUE_CAPACITY`       | 20000   |
    ///
    /// A variable that is unset or not a valid `usize` falls back to its
    /// default. A value of `0` is raised to `1`: a zero-permit semaphore would
    /// stall every task waiting on it, and a zero-capacity bounded channel
    /// cannot be created.
    pub fn from_env() -> Self {
        Self {
            search_concurrency: env_usize("SEARCH_CONCURRENCY", 4).max(1),
            download_concurrency: env_usize("DOWNLOAD_CONCURRENCY", 7).max(1),
            queue_capacity: env_usize("QUEUE_CAPACITY", 20000).max(1),
        }
    }
}

/// Reads `key` from the environment and parses it as a `usize`.
///
/// Surrounding whitespace is ignored. Returns `default` when the variable is
/// unset, is not valid Unicode, or does not parse as a `usize`.
fn env_usize(key: &str, default: usize) -> usize {
    std::env::var(key)
        .ok()
        .and_then(|value| value.trim().parse().ok())
        .unwrap_or(default)
}

/// The central container for all service managers and shared state.
///
/// Built once by [`Managers::new`] and then shared behind an `Arc` with every
/// task spawned during a chunk run.
pub struct Managers {
    /// The shared Soulseek client.
    pub client: Arc<Client>,
    /// Manager for file downloads.
    pub download_manager: DownloadManager,
    /// Manager for Soulseek searches.
    pub search_manager: SearchManager,
    /// Manager for Spotify playlist queries.
    pub query_manager: QueryManager,
    /// Manager for candidate judging.
    pub judge_manager: JudgeManager,
    /// The PostgreSQL connection pool.
    pub db_pool: crate::internals::database::DbPool,
    /// The Redis connection pool.
    pub redis_pool: RedisPool,
    /// Copied from `Config::search_empty_result_cutoff` and forwarded to every
    /// search started by `process_track`.
    // NOTE(review): presumably the number of empty results after which a
    // search gives up — confirm against `SearchManager::run`.
    pub search_empty_result_cutoff: usize,
}

impl Managers {
    /// Creates a new `Managers` instance with the provided configuration.
    pub fn new(
        score: Option<f32>,
        path: PathBuf,
        config: Config,
        db_pool: crate::internals::database::DbPool,
        redis_pool: RedisPool,
    ) -> anyhow::Result<Self> {
        let listen_port = config.listen_port;
        TcpListener::bind(format!("0.0.0.0:{listen_port}"))
            .with_context(|| format!("listener bind preflight failed on port {listen_port}"))?;
        let client_settings = ClientSettings {
            username: config.user_name,
            password: config.user_password,
            listen_port,
            ..Default::default()
        };
        let search_empty_result_cutoff = config.search_empty_result_cutoff;
        let mut client = Client::with_settings(client_settings);
        catch_unwind(AssertUnwindSafe(|| client.connect())).map_err(|payload| {
            anyhow::anyhow!("listener bind panic: {}", panic_payload(payload))
        })?;
        client.login().context("Could not connect")?;
        let client = Arc::new(client);
        let download_manager = DownloadManager::new(client.clone(), path);
        let search_manager = SearchManager::new(client.clone());
        let judge_threshold =
            score.unwrap_or(crate::internals::judge::judge_manager::JUDGE_THRESHOLD);
        let lev_judge = Levenshtein::new(judge_threshold);
        let judge_manager = JudgeManager::new(Box::new(lev_judge), judge_threshold);
        let query_manager = QueryManager::new_with_timeout(
            config.playlist_id,
            config.client_id,
            config.client_secret,
            config.search_timeout_secs,
        );
        Ok(Managers {
            client,
            download_manager,
            search_manager,
            judge_manager,
            query_manager,
            db_pool,
            redis_pool,
            search_empty_result_cutoff,
        })
    }

    /// Runs a single chunk of tracks through the search-judge-download pipeline.
    ///
    /// This method orchestrates the task lifecycle using a `JoinSet` and an internal
    /// message channel. It ensures that all spawned tasks are completed before returning.
    ///
    /// The initial `tracks` are queued on the channel, after which the loop
    /// alternates between dispatching queued events through `process_track`
    /// and collecting the results of finished tasks. The loop ends once the
    /// queue is empty and no task is left running.
    ///
    /// # Errors
    ///
    /// * If `process_track` itself fails, the error is returned immediately;
    ///   the `JoinSet` is dropped, which aborts tasks that are still running.
    /// * If a spawned task fails or panics, the first such failure is
    ///   remembered, further queued events are discarded, the remaining tasks
    ///   are allowed to finish, and the failure is then returned.
    #[instrument(name = "run-chunk", skip(self, tracks))]
    pub async fn run_chunk(
        self: &Arc<Self>,
        tracks: impl IntoIterator<Item = Track>,
    ) -> anyhow::Result<()> {
        let managers = Arc::clone(self);
        let tuning = WorkerTuning::from_env();

        // Nothing drains the channel while the initial tracks are injected, so
        // the channel must be able to hold all of them; otherwise the
        // injection below would wait forever. A capacity of zero is also
        // rejected by `mpsc::channel`, hence the final clamp.
        let initial: Vec<Track> = tracks.into_iter().collect();
        let queue_capacity = tuning.queue_capacity.max(initial.len()).max(1);
        let (sender, mut receiver) = mpsc::channel(queue_capacity);

        let sender = Arc::new(sender);
        let state = Arc::new(tokio::sync::RwLock::new(HashSet::new()));
        // A semaphore without permits would stall every task waiting on it.
        let search_semaphore = Arc::new(Semaphore::new(tuning.search_concurrency.max(1)));
        let download_semaphore = Arc::new(Semaphore::new(tuning.download_concurrency.max(1)));
        let snapshot = db_pool_snapshot(&managers.db_pool);
        tracing::info!(
            search_permits = search_semaphore.available_permits(),
            download_permits = download_semaphore.available_permits(),
            db_pool_connections = snapshot.connections,
            db_pool_idle_connections = snapshot.idle_connections,
            db_pool_in_use_connections = snapshot.in_use_connections(),
            "Started run chunk",
        );
        for track in initial {
            sender.send(track).await.context("injecting tracks")?;
        }
        let mut tasks = JoinSet::new();
        let mut first_task_error: Option<String> = None;

        while !receiver.is_empty() || !tasks.is_empty() {
            tokio::select! {
                // Deliberately unguarded: events sent by running tasks must be
                // picked up while those tasks are still in flight. With a
                // guard on queue emptiness the loop would wait for some task
                // to finish before looking at the queue again, which delays
                // the pipeline and can deadlock once senders block on a full
                // channel. Waiting here cannot hang: the loop condition
                // guarantees that either an event is already queued or a task
                // is running and will complete the other branch.
                maybe_track = receiver.recv() => {
                    // `sender` is still alive in this scope, so the channel
                    // cannot report closure; the `else` arm is defensive.
                    let Some(track) = maybe_track else {
                        break;
                    };
                    tracing::debug!(?track, "Incoming package");
                    if first_task_error.is_some() {
                        tracing::debug!(?track, "Dropping queued work after task failure");
                        continue;
                    }
                    let shared = RunCycleShared {
                        managers: &managers,
                        sender: &sender,
                        state: &state,
                        search_semaphore: &search_semaphore,
                        download_semaphore: &download_semaphore,
                    };
                    process_track(track, shared, &mut tasks).await?;
                }
                maybe_result = tasks.join_next(), if !tasks.is_empty() => {
                    let task_result = match maybe_result {
                        Some(Ok(result)) => result,
                        // The task could not be joined, so its label is unknown.
                        Some(Err(err)) => ManagedTaskResult {
                            label: "unknown",
                            error: Some(format!("managed task join failed: {err}")),
                        },
                        None => continue,
                    };
                    if let Some(error) = task_result.error {
                        let message = format!("{} task failed: {error}", task_result.label);
                        tracing::error!(task_label = task_result.label, %error, "Managed task failed");
                        // Only the first failure is reported to the caller.
                        first_task_error.get_or_insert(message);
                    }
                }
            }
        }
        drop(sender);
        if let Some(error) = first_task_error {
            anyhow::bail!(error);
        }
        let snapshot = db_pool_snapshot(&managers.db_pool);
        tracing::info!(
            db_pool_connections = snapshot.connections,
            db_pool_idle_connections = snapshot.idle_connections,
            db_pool_in_use_connections = snapshot.in_use_connections(),
            "Run chunk finished",
        );
        Ok(())
    }

    /// Consumes this handle to the managers and logs that shutdown was requested.
    ///
    /// No disconnect or logout is performed; per the log message below, the
    /// Soulseek client library in use exposes no public API for that. The
    /// `Arc` is simply dropped when this function returns.
    pub fn shutdown(self: Arc<Self>) {
        tracing::info!(
            "Managers shutdown requested; soulseek-rs-lib 0.3.0 exposes no public disconnect/logout API",
        );
    }
}

/// Handles one pipeline event taken from the work channel.
///
/// The event is first recorded in the database; afterwards it is dispatched
/// on its variant:
///
/// * `Query` / `SearchRetry` — spawn a strict / relaxed search task. A strict
///   search that finds no candidates queues a `SearchRetry` for the same item.
/// * `Result` — spawn a judge task.
/// * `Downloadable` — reject the candidate as `AlreadyDownloaded` when the
///   database already lists the track as downloaded or when a download for
///   the track was already scheduled in this chunk; otherwise mark the track
///   as in flight and spawn a download task.
/// * `File` — log the downloaded file.
/// * `Retry` — clear the in-flight marker, then either abandon the track
///   (when it was already retried) or spawn a relaxed search.
/// * `Reject` — clear the in-flight marker, except for `AlreadyDownloaded`
///   rejections.
///
/// Long-running work is spawned onto `tasks` through `spawn_managed`; this
/// function itself only waits on database access, the state lock and channel
/// sends.
///
/// # Errors
///
/// Returns an error when a database connection cannot be obtained, when a
/// database operation fails, or when a follow-up event cannot be sent.
async fn process_track(
    track: Track,
    shared: RunCycleShared<'_>,
    tasks: &mut JoinSet<ManagedTaskResult>,
) -> anyhow::Result<()> {
    let RunCycleShared {
        managers,
        sender,
        state,
        search_semaphore,
        download_semaphore,
    } = shared;

    // Record the event before acting on it. The connection is scoped so that
    // it returns to the pool before any task is spawned.
    {
        let mut conn = managers.db_pool.get().map_err(|err| {
            let snapshot = db_pool_snapshot(&managers.db_pool);
            tracing::error!(
                ?err,
                db_pool_connections = snapshot.connections,
                db_pool_idle_connections = snapshot.idle_connections,
                db_pool_in_use_connections = snapshot.in_use_connections(),
                "DB pool in process_track"
            );
            err
        })?;
        let mut database_manager = DatabaseManager::new(&mut conn);
        database_manager
            .load_item_to_database(&track)
            .context("Load into database")?;
    }
    match track {
        Track::Query(search_item) => {
            let managers = Arc::clone(managers);
            let sender = Arc::clone(sender);
            let semaphore = search_semaphore.clone();
            tracing::debug!(?search_item, "Scheduling search");
            spawn_managed(tasks, "search", async move {
                let outcome = managers
                    .search_manager
                    .run(
                        search_item.clone(),
                        managers.search_empty_result_cutoff(),
                        managers.query_manager_search_timeout(),
                        false,
                        semaphore,
                        Arc::clone(&sender),
                    )
                    .await
                    .context("returning track")?;
                // Nothing found with the strict query: queue the relaxed pass.
                if matches!(outcome.exit_reason, SearchExitReason::NoCandidatesFound) {
                    send(Track::SearchRetry(search_item), &sender)
                        .await
                        .context("queue relaxed search")?;
                }
                Ok(())
            });
        }
        Track::SearchRetry(search_item) => {
            let managers = Arc::clone(managers);
            let sender = Arc::clone(sender);
            let semaphore = search_semaphore.clone();
            tracing::debug!(?search_item, "Scheduling relaxed search retry");
            spawn_managed(tasks, "search_retry", async move {
                let outcome = managers
                    .search_manager
                    .run(
                        search_item.clone(),
                        managers.search_empty_result_cutoff(),
                        managers.query_manager_search_timeout(),
                        true,
                        semaphore,
                        sender,
                    )
                    .await
                    .context("returning relaxed track")?;
                if matches!(outcome.exit_reason, SearchExitReason::NoCandidatesFound) {
                    tracing::info!(?search_item, "Relaxed search returned no candidates");
                }
                Ok(())
            });
        }
        Track::Result(judge_submission) => {
            let managers = Arc::clone(managers);
            let sender = Arc::clone(sender);
            spawn_managed(tasks, "judge", async move {
                tracing::debug!(?judge_submission, "Scheduling judge");
                managers
                    .judge_manager
                    .run(judge_submission, sender)
                    .await
                    .context("running judge")
            });
        }
        Track::Downloadable(judge_submission) => {
            let is_downloaded = {
                let mut conn = managers.db_pool.get().map_err(|err| {
                    let snapshot = db_pool_snapshot(&managers.db_pool);
                    tracing::error!(
                        ?err,
                        db_pool_connections = snapshot.connections,
                        db_pool_idle_connections = snapshot.idle_connections,
                        db_pool_in_use_connections = snapshot.in_use_connections(),
                        "DB pool in downloadable check"
                    );
                    err
                })?;
                let mut database_manager = DatabaseManager::new(&mut conn);
                database_manager
                    .is_search_item_downloaded(&judge_submission.track)
                    .context("Check existing downloaded track")?
            };
            if is_downloaded {
                let reject =
                    RejectedTrack::new(judge_submission, RejectReason::AlreadyDownloaded);
                send(Track::Reject(reject), sender)
                    .await
                    .context("sending already downloaded rejection")?;
                return Ok(());
            }
            tracing::debug!(?judge_submission, "Scheduling download");

            // `insert` returns `true` only for the first candidate of a track,
            // which makes it the owner of the download. The write guard is a
            // temporary and is released at the end of this statement.
            let newly_claimed = state
                .write()
                .await
                .insert(judge_submission.track.clone());
            if newly_claimed {
                let managers = Arc::clone(managers);
                let sender = Arc::clone(sender);
                let semaphore = download_semaphore.clone();
                spawn_managed(tasks, "download", async move {
                    managers
                        .download_manager
                        .run(
                            judge_submission,
                            semaphore,
                            sender,
                            managers.redis_pool.clone(),
                            managers.db_pool.clone(),
                        )
                        .await
                        .context("Downloading")?;
                    Ok(())
                });
            } else {
                // Another candidate for the same track is already downloading.
                let reject =
                    RejectedTrack::new(judge_submission, RejectReason::AlreadyDownloaded);
                send(Track::Reject(reject), sender)
                    .await
                    .context("sending rejected_tracks")?;
            }
        }
        Track::File(downloaded_file) => {
            tracing::info!(?downloaded_file, "Downloaded file");
        }
        Track::Retry(mut retry_request) => {
            // The failed attempt no longer owns the track.
            state.write().await.remove(&retry_request.request.track);
            if retry_request.retry_attempts >= 1 {
                let reject = RejectedTrack::new(
                    retry_request.request,
                    RejectReason::AbandonedAttemptingSearch,
                );
                send(Track::Reject(reject), sender)
                    .await
                    .context("rejecting")?;
                return Ok(());
            }
            // NOTE(review): the incremented counter is only visible in the
            // debug log below; it is not handed to the spawned search.
            retry_request.retry_attempts += 1;
            let managers = Arc::clone(managers);
            let semaphore = search_semaphore.clone();
            let sender = Arc::clone(sender);
            tracing::info!(?retry_request.request, "Retry requested");
            // Only the search item is needed by the new search.
            let search_item = retry_request.request.track.clone();
            spawn_managed(tasks, "retry_search", async move {
                managers
                    .search_manager
                    .run(
                        search_item,
                        managers.search_empty_result_cutoff(),
                        managers.query_manager_search_timeout(),
                        true,
                        semaphore,
                        sender,
                    )
                    .await
                    .context("returning track")?;
                Ok(())
            });
            tracing::debug!(?retry_request, "Retry queued")
        }
        Track::Reject(rejected_track) => {
            // An `AlreadyDownloaded` rejection is issued precisely because the
            // track is already downloaded or currently being downloaded.
            // Clearing the marker in that case would let the next candidate
            // start a second, concurrent download of the same track.
            // NOTE(review): rejections with other reasons still clear the
            // marker even if a different candidate of the same track is
            // downloading — confirm whether that is intended.
            if !matches!(rejected_track.reason, RejectReason::AlreadyDownloaded) {
                state.write().await.remove(&rejected_track.track.track);
            }
        }
    }
    Ok(())
}

/// Extracts a human-readable message from a panic payload.
///
/// Payloads that are a `&str` or a `String` yield their text; any other
/// payload type yields `"unknown panic payload"`.
fn panic_payload(payload: Box<dyn Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| String::from(*text))
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| String::from("unknown panic payload"))
}

/// Private accessors used by `process_track` when starting searches.
impl Managers {
    /// Returns the `search_timeout_secs` value stored on the query manager.
    fn query_manager_search_timeout(&self) -> u8 {
        self.query_manager.search_timeout_secs
    }

    /// Returns the configured `search_empty_result_cutoff`.
    fn search_empty_result_cutoff(&self) -> usize {
        self.search_empty_result_cutoff
    }
}

fn spawn_managed<F>(tasks: &mut JoinSet<ManagedTaskResult>, label: &'static str, future: F)
where
    F: Future<Output = anyhow::Result<()>> + Send + 'static,
{
    tasks.spawn(async move {
        let result = AssertUnwindSafe(future).catch_unwind().await;
        let error = match result {
            Ok(Ok(())) => None,
            Ok(Err(err)) => Some(format!("{err:?}")),
            Err(_) => Some("task panicked".to_string()),
        };
        ManagedTaskResult { label, error }
    });
}

/// Unit tests for `spawn_managed`: success, error and panic outcomes.
#[cfg(test)]
mod tests {
    use super::{ManagedTaskResult, spawn_managed};
    use tokio::task::JoinSet;

    /// Awaits the next finished task, panicking if the set is empty or the
    /// task could not be joined.
    async fn next_result(tasks: &mut JoinSet<ManagedTaskResult>) -> ManagedTaskResult {
        tasks
            .join_next()
            .await
            .expect("task result")
            .expect("task joined")
    }

    /// A future returning `Ok(())` yields its label and no error.
    #[tokio::test]
    async fn managed_task_reports_success_by_label() {
        let mut tasks = JoinSet::new();

        spawn_managed(&mut tasks, "search", async { Ok(()) });

        let result = next_result(&mut tasks).await;
        assert_eq!(result.label, "search");
        assert!(result.error.is_none());
    }

    /// A future returning `Err` yields an error string containing the message.
    #[tokio::test]
    async fn managed_task_captures_errors() {
        let mut tasks = JoinSet::new();

        spawn_managed(&mut tasks, "download", async {
            anyhow::bail!("download failed")
        });

        let result = next_result(&mut tasks).await;
        assert_eq!(result.label, "download");
        assert!(
            result
                .error
                .expect("managed task error")
                .contains("download failed")
        );
    }

    /// A panicking future is reported as the fixed string "task panicked"
    /// rather than as a join failure.
    #[tokio::test]
    async fn managed_task_converts_panics_to_errors() {
        let mut tasks = JoinSet::new();

        spawn_managed(&mut tasks, "judge", async {
            panic!("judge panic");
            // The trailing `Ok(())` gives the async block its
            // `anyhow::Result<()>` output type; it is never reached.
            #[allow(unreachable_code)]
            Ok(())
        });

        let result = next_result(&mut tasks).await;
        assert_eq!(result.label, "judge");
        assert_eq!(result.error.as_deref(), Some("task panicked"));
    }
}