hibp_downloader/
tasks.rs

1use std::{
2    io,
3    path::PathBuf,
4    sync::{atomic, Arc},
5    time::Duration,
6};
7
8use tokio::{
9    sync::{mpsc::Receiver, mpsc::Sender, Semaphore},
10    task::JoinSet,
11};
12use tracing::Span;
13use tracing_indicatif::span_ext::IndicatifSpanExt;
14
15use crate::{
16    buffered_string_writer::BufferedStringWriter,
17    consts::{BEGIN, END},
18    download::download_prefix,
19    stats::DOWNLOADED,
20};
21
22pub async fn writer_task(
23    mut rx: Receiver<(u32, String)>,
24    output_file: PathBuf,
25) -> Result<(), io::Error> {
26    let mut file = BufferedStringWriter::from_file(&output_file).await?;
27
28    while let Some(rows) = rx.recv().await {
29        file.add_file(rows).await?;
30    }
31
32    file.flush(false).await?;
33    file.inner_flush().await?;
34    Ok::<(), io::Error>(())
35}
36
37pub async fn progress_task() {
38    let span = Span::current();
39    loop {
40        tokio::time::sleep(Duration::from_millis(100)).await;
41        span.pb_set_position(DOWNLOADED.load(atomic::Ordering::Acquire));
42    }
43}
44
45pub async fn download_task(
46    client: reqwest::Client,
47    concurrent_requests: usize,
48    tx: Sender<(u32, String)>,
49    ntlm: bool,
50) -> anyhow::Result<()> {
51    let mut handles = JoinSet::new();
52    let semaphore = Arc::new(Semaphore::new(concurrent_requests));
53    for n in BEGIN..=END {
54        let client = client.clone();
55        let tx = tx.clone();
56        let semaphore = Arc::clone(&semaphore);
57
58        handles.spawn(async move {
59            let _permit = semaphore.acquire().await?;
60            tx.send(download_prefix(&client, n, ntlm).await?).await?;
61            Ok::<(), anyhow::Error>(())
62        });
63    }
64    drop(tx);
65    while let Some(res) = handles.join_next().await {
66        res??;
67    }
68
69    Ok(())
70}