ibdl_core/async_queue/
mod.rs

1//! Queue used specifically to download, filter and save posts found by an [`Extractor`](ibdl-extractors::websites).
2//!
3//! # Example usage
4//!
5//! Conveniently using the same example from [here](ibdl-extractors::websites)
6//!
7//! ```rust
8//! use imageboard_downloader::*;
9//! use std::path::PathBuf;
10//!
11//! async fn download_posts() {
12//!     let tags = ["umbreon", "espeon"]; // The tags to search
13//!     
14//!     let safe_mode = false; // Setting this to true will skip NSFW posts when searching
15//!
16//!     let disable_blacklist = false; // Will filter all items according to what's set in GBL
17//!
18//!     let mut unit = DanbooruExtractor::new(&tags, safe_mode, disable_blacklist); // Initialize
19//!
20//!     let prompt = true; // If true, will ask the user to input their username and API key.
21//!
22//!     unit.auth(prompt).await.unwrap(); // Try to authenticate
23//!
24//!     let start_page = Some(1); // Start searching from the first page
25//!
26//!     let limit = Some(50); // Max number of posts to download
27//!
28//!     let posts = unit.full_search(start_page, limit).await.unwrap(); // and then, finally search
29//!
30//!     let sd = 10; // Number of simultaneous downloads.
31//!
32//!     let limit = Some(1000); // Max number of posts to download
33//!
34//!     let cbz = false; // Set to true to download everything into a .cbz file
35//!
36//!     let mut qw = Queue::new( // Initialize the queue
37//!         ImageBoards::Danbooru,
38//!         posts,
39//!         sd,
40//!         Some(unit.client()), // Re-use the client from the extractor
41//!         limit,
42//!         cbz,
43//!     );
44//!
45//!     let output = Some(PathBuf::from("./")); // Where to save the downloaded files or .cbz file
46//!
47//!     let id = true; // Save file with their ID as the filename instead of MD5
48//!
49//!     qw.download(output, id).await.unwrap(); // Start downloading
50//! }
51//! ```
52
53mod cbz;
54mod folder;
55
56use crate::error::QueueError;
57use crate::progress_bars::ProgressCounter;
58use ibdl_common::log::debug;
59use ibdl_common::post::error::PostError;
60use ibdl_common::post::{NameType, Post};
61use ibdl_common::reqwest::Client;
62use ibdl_common::tokio::spawn;
63use ibdl_common::tokio::sync::mpsc::{channel, Receiver, UnboundedReceiver};
64use ibdl_common::tokio::task::JoinHandle;
65use ibdl_common::{client, tokio};
66use ibdl_extractors::extractor_config::ServerConfig;
67use once_cell::sync::OnceCell;
68use std::path::{Path, PathBuf};
69use std::sync::atomic::{AtomicU64, Ordering};
70use std::sync::Arc;
71use tokio::fs::{create_dir_all, OpenOptions};
72use tokio::io::AsyncWriteExt;
73use tokio_stream::wrappers::UnboundedReceiverStream;
74
75static PROGRESS_COUNTERS: OnceCell<ProgressCounter> = OnceCell::new();
76
77pub(crate) fn get_counters() -> &'static ProgressCounter {
78    PROGRESS_COUNTERS.get().unwrap()
79}
80
/// Output layout chosen for a download run.
///
/// `Queue::new` selects the variant from its `save_as_cbz` and
/// `pool_download` flags.
#[derive(Debug, Copy, Clone)]
enum DownloadFormat {
    /// `save_as_cbz` set, `pool_download` unset.
    Cbz,
    /// Both `save_as_cbz` and `pool_download` set.
    CbzPool,
    /// Neither flag set.
    Folder,
    /// `pool_download` set, `save_as_cbz` unset.
    FolderPool,
}

impl DownloadFormat {
    /// Returns `true` for the CBZ variants (`Cbz`, `CbzPool`).
    #[inline]
    pub const fn download_cbz(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Cbz | Self::CbzPool => true,
            Self::Folder | Self::FolderPool => false,
        }
    }

    /// Returns `true` for the pool variants (`CbzPool`, `FolderPool`).
    #[inline]
    pub const fn download_pool(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::CbzPool | Self::FolderPool => true,
            Self::Cbz | Self::Folder => false,
        }
    }
}
110
111/// Struct where all the downloading will take place
112pub struct Queue {
113    imageboard: ServerConfig,
114    sim_downloads: u8,
115    client: Client,
116    download_fmt: DownloadFormat,
117    name_type: NameType,
118    annotate: bool,
119}
120
121impl Queue {
122    /// Set up the queue for download
123    pub fn new(
124        imageboard: ServerConfig,
125        sim_downloads: u8,
126        custom_client: Option<Client>,
127        save_as_cbz: bool,
128        pool_download: bool,
129        name_type: NameType,
130        annotate: bool,
131    ) -> Self {
132        let client = if let Some(cli) = custom_client {
133            cli
134        } else {
135            client!(imageboard)
136        };
137
138        let download_fmt = if save_as_cbz && pool_download {
139            DownloadFormat::CbzPool
140        } else if save_as_cbz {
141            DownloadFormat::Cbz
142        } else if pool_download {
143            DownloadFormat::FolderPool
144        } else {
145            DownloadFormat::Folder
146        };
147
148        Self {
149            download_fmt,
150            imageboard,
151            sim_downloads,
152            annotate,
153            client,
154            name_type,
155        }
156    }
157
158    pub fn setup_async_downloader(
159        self,
160        output_dir: PathBuf,
161        post_counter: Arc<AtomicU64>,
162        channel_rx: UnboundedReceiver<Post>,
163        length_rx: Receiver<u64>,
164    ) -> JoinHandle<Result<u64, QueueError>> {
165        spawn(async move {
166            debug!("Async Downloader thread initialized");
167
168            let counters = PROGRESS_COUNTERS.get_or_init(|| {
169                ProgressCounter::initialize(
170                    post_counter.load(Ordering::Relaxed),
171                    self.imageboard.server,
172                )
173            });
174
175            self.create_out(&output_dir).await?;
176
177            let post_channel = UnboundedReceiverStream::new(channel_rx);
178            let (progress_sender, progress_channel) = channel(self.sim_downloads as usize);
179
180            counters.init_length_updater(length_rx).await;
181            counters.init_download_counter(progress_channel).await;
182
183            if self.download_fmt.download_cbz() {
184                self.cbz_path(
185                    output_dir,
186                    progress_sender,
187                    post_channel,
188                    self.download_fmt.download_pool(),
189                )
190                .await?;
191            } else {
192                self.download_channel(
193                    post_channel,
194                    progress_sender,
195                    output_dir,
196                    self.download_fmt.download_pool(),
197                )
198                .await;
199            }
200
201            counters.main.finish_and_clear();
202
203            let tot = counters.downloaded_mtx.load(Ordering::SeqCst);
204
205            Ok(tot)
206        })
207    }
208
209    async fn create_out(&self, dir: &Path) -> Result<(), QueueError> {
210        if self.download_fmt.download_cbz() {
211            let output_file = dir.parent().unwrap().to_path_buf();
212
213            match create_dir_all(&output_file).await {
214                Ok(_) => (),
215                Err(error) => {
216                    return Err(QueueError::DirCreationError {
217                        message: error.to_string(),
218                    })
219                }
220            };
221            return Ok(());
222        }
223
224        debug!("Target dir: {}", dir.display());
225        match create_dir_all(&dir).await {
226            Ok(_) => (),
227            Err(error) => {
228                return Err(QueueError::DirCreationError {
229                    message: error.to_string(),
230                })
231            }
232        };
233
234        Ok(())
235    }
236
237    async fn write_caption(
238        post: &Post,
239        name_type: NameType,
240        output: &Path,
241    ) -> Result<(), PostError> {
242        let outpath = output.join(format!("{}.txt", post.name(name_type)));
243        let mut prompt_file = OpenOptions::new()
244            .create(true)
245            .write(true)
246            .open(outpath)
247            .await?;
248
249        let tag_list = Vec::from_iter(
250            post.tags
251                .iter()
252                .filter(|t| t.is_prompt_tag())
253                .map(|tag| tag.tag()),
254        );
255
256        let prompt = tag_list.join(", ");
257
258        let f1 = prompt.replace('_', " ");
259
260        prompt_file.write_all(f1.as_bytes()).await?;
261        debug!("Wrote caption file for {}", post.file_name(name_type));
262        Ok(())
263    }
264}