// quickfetch/lib.rs

1#![doc = include_str!("../README.md")]
2#[macro_use]
3extern crate log;
4use anyhow::Result;
5pub use bincode;
6use bytes::Bytes;
7use futures::future::join_all;
8use futures::StreamExt;
9use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
10use notify::{Config as NConfig, Event, EventKind, RecommendedWatcher, RecursiveMode, Watcher};
11use package::{Config, Mode};
12pub use pretty_env_logger;
13pub use quickfetch_traits as traits;
14use quickfetch_traits::{Entry, EntryKey, EntryValue};
15use reqwest::{Client, Response};
16use serde::Deserialize;
17use sled::Db;
18use std::path::{Path, PathBuf};
19use std::sync::Arc;
20use tokio::fs::create_dir;
21use tokio::sync::mpsc::{channel, Receiver};
22use tokio::sync::Mutex;
23use url::Url;
24/// Provides different types of packages that can be used
25pub mod package;
26/// Provides structures that can be used as a Key and Value for Fetcher
27pub mod val;
28
29/// Provides all the common types to use with Fetcher
30pub mod prelude {
31    pub use crate::package::{Config, GHPackage, Mode, SimplePackage};
32    pub use crate::traits::{Entry, EntryKey, EntryValue};
33    pub use crate::val::{GHValue, SimpleValue};
34    pub use crate::Fetcher;
35}
36
37/// Returns the path to the home directory with the sub directory appended
38pub fn home_plus<P: AsRef<Path>>(sub_dir: P) -> PathBuf {
39    dirs::home_dir().unwrap().join(sub_dir)
40}
41
42/// Returns the path to the config directory with the sub directory appended
43pub fn config_plus<P: AsRef<Path>>(sub_dir: P) -> PathBuf {
44    dirs::config_dir().unwrap().join(sub_dir)
45}
46
47/// Returns the path to the cache directory with the sub directory appended
48pub fn cache_plus<P: AsRef<Path>>(sub_dir: P) -> PathBuf {
49    dirs::cache_dir().unwrap().join(sub_dir)
50}
51
/// `ResponseMethod` enum to specify the method of fetching the response
///
/// - `Bytes`: Fetch the full response using the `bytes` method
/// - `Chunk`: Fetch the response in chunks using the `chunk` method
/// - `BytesStream`: Fetch the response in a stream of bytes using the `bytes_stream` method
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum ResponseMethod {
    /// Buffer the whole body at once via `Response::bytes` (no progress reporting)
    Bytes,
    /// Pull the body chunk-by-chunk via `Response::chunk` (drives a progress bar)
    Chunk,
    /// Consume the body as a stream of `Bytes` via `Response::bytes_stream` (drives a progress bar)
    BytesStream,
}
63
/// `NotifyMethod` enum to specify how fetch activity is reported to the user
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum NotifyMethod {
    /// Emit `log` records (cache hit / caching messages)
    Log,
    /// Render a progress bar per download (requires `Chunk` or `BytesStream`)
    ProgressBar,
    /// Emit nothing
    Silent,
}
70
71impl Default for NotifyMethod {
72    fn default() -> Self {
73        Self::Log
74    }
75}
76
/// `FetchMethod` enum to specify the method of fetching the response
///
/// - `Async`: Fetch the response asynchronously using `tokio::spawn`
/// - `Watch`: Watch the config file and re-fetch whenever it changes
/// - `Sync` (feature `unstable`): placeholder variant; use `sync_fetch` directly instead
#[derive(Debug, Copy, Clone)]
pub enum FetchMethod {
    Async,
    Watch,
    #[cfg(feature = "unstable")]
    Sync,
}
88
89impl Default for FetchMethod {
90    fn default() -> Self {
91        Self::Async
92    }
93}
94
95impl Default for ResponseMethod {
96    fn default() -> Self {
97        Self::Bytes
98    }
99}
100
/// Fetcher struct that will be used to fetch and cache data
///
/// - `entries`: List of entries to fetch
/// - `db`: sled db to cache the fetched data
/// - `client`: reqwest client to fetch the data
/// - `response_method`: Method of fetching the response
/// - `notify_method`: Method of notifying the user about fetch activity
#[derive(Debug, Clone)]
pub struct Fetcher<E: Entry> {
    /// List of entries to fetch
    entries: Arc<Vec<E>>,
    /// Path to the config file
    config_path: PathBuf,
    /// Config struct to hold the configuration
    config: Config<E>,
    /// Type of config file (json or toml)
    config_type: Mode,
    /// sled db to cache the fetched data
    db: Db,
    /// Path to the db file
    db_path: PathBuf,
    /// reqwest client to fetch the data
    client: Client,
    /// Method of fetching the response
    response_method: ResponseMethod,
    /// Method of notifying the user
    notify_method: NotifyMethod,
    /// Multi progress bar to show multiple progress bars
    multi_pb: Arc<MultiProgress>,
}
131
// Constructor and Setup Methods
impl<E: Entry + Clone + Send + Sync + 'static + for<'de> Deserialize<'de>> Fetcher<E> {
    /// Create a new `Fetcher` instance with list of urls and db path
    ///
    /// Parses the config file at `config_path` according to `config_type`,
    /// opens (or creates) the sled db at `db_path`, and builds a reqwest
    /// client with brotli decompression enabled.
    ///
    /// # Errors
    ///
    /// Fails if the HTTP client cannot be built, the config file cannot be
    /// read or parsed, or the db cannot be opened.
    pub async fn new<P: AsRef<Path> + Send + Sync>(
        config_path: P,
        config_type: Mode,
        db_path: P,
    ) -> Result<Self> {
        let client = Client::builder()
            .brotli(true) // by default enable brotli decompression
            .build()?;

        let config = Config::from_file(&config_path, config_type).await?;
        let entries = config.packages_owned();

        Ok(Self {
            entries: Arc::new(entries),
            db: sled::open(&db_path)?,
            db_path: PathBuf::from(db_path.as_ref()),
            config,
            config_path: config_path.as_ref().to_path_buf(),
            config_type,
            client,
            response_method: ResponseMethod::default(),
            notify_method: NotifyMethod::Log,
            multi_pb: Arc::new(MultiProgress::new()),
        })
    }

    #[cfg(feature = "unstable")]
    /// Create a new `Fetcher` instance with list of urls and db path synchronously
    ///
    /// Blocks the current thread on [`Self::new`].
    pub fn new_sync<P: AsRef<Path> + Send + Sync>(
        config_path: P,
        config_type: Mode,
        db_path: P,
    ) -> Result<Self> {
        futures::executor::block_on(Self::new(config_path, config_type, db_path))
    }

    /// Set the client to be used for fetching the data
    ///
    /// This is useful when you want to use a custom client with custom settings
    /// when using `Client::builder()`
    pub fn set_client(&mut self, client: Client) {
        self.client = client;
    }

    /// Set the response method to be used for fetching the response
    ///
    /// By default `self.response_method = ResponseMethod::Bytes`
    ///
    /// - `Bytes`: Fetch the full response using the `bytes` method
    /// - `Chunk`: Fetch the response in chunks using the `chunk` method
    /// - `BytesStream`: Fetch the response in a stream of bytes using the `bytes_stream` method
    ///
    /// Note: this setter only records the choice; it does NOT switch
    /// `notify_method` to `ProgressBar`. Call [`Self::set_notify_method`]
    /// if you want a progress bar for `Chunk`/`BytesStream`.
    pub fn set_response_method(&mut self, response_method: ResponseMethod) {
        self.response_method = response_method;
    }

    /// Set the notify method to be used for notifying the user
    /// By default `self.notify_method = NotifyMethod::Log`
    ///
    /// # Panics
    ///
    /// Panics if `notify_method` is `ProgressBar` while `response_method` is
    /// `Bytes` — a progress bar needs the incremental `Chunk`/`BytesStream`
    /// response methods to have anything to report.
    pub fn set_notify_method(&mut self, notify_method: NotifyMethod) {
        self.notify_method = notify_method;
        if notify_method == NotifyMethod::ProgressBar {
            assert!(
                (self.response_method == ResponseMethod::BytesStream)
                    || (self.response_method == ResponseMethod::Chunk)
            )
        }
    }
}
205
206// Database Operations
207impl<E: Entry + Clone + Send + Sync + 'static> Fetcher<E> {
208    /// Removes the `db` directory (use with caution as this uses `tokio::fs::remove_dir_all`)
209    pub async fn remove_db_dir(&self) -> Result<()> {
210        tokio::fs::remove_dir_all(&self.db_path).await?;
211        Ok(())
212    }
213
214    /// Remove the db and all its trees
215    pub fn remove_db_trees(&self) -> Result<()> {
216        let trees = self.db.tree_names();
217        for tree in trees {
218            self.db.drop_tree(tree)?;
219        }
220        Ok(())
221    }
222
223    /// Remove current tree
224    pub fn remove_tree(&self) -> Result<()> {
225        let tree = self.db.name();
226        self.db.drop_tree(tree)?;
227        Ok(())
228    }
229
230    /// Export the db to a vector of key value pairs and an iterator of values
231    ///
232    /// > Useful when needing to migrate the db from an older version to a newer version
233    pub fn export(&self) -> Vec<(Vec<u8>, Vec<u8>, impl Iterator<Item = Vec<Vec<u8>>> + Sized)> {
234        self.db.export()
235    }
236
237    /// Import the db from a vector of key value pairs and an iterator of values
238    ///
239    /// > Useful when needing to migrate the db from an older version to a newer version
240    pub fn import(
241        &self,
242        export: Vec<(Vec<u8>, Vec<u8>, impl Iterator<Item = Vec<Vec<u8>>> + Sized)>,
243    ) {
244        self.db.import(export)
245    }
246
247    pub fn clear(&self) -> Result<()> {
248        self.db.clear()?;
249        Ok(())
250    }
251}
252
253// Handles and Fetching Entries
254impl<E: Entry + Clone + Send + Sync + 'static + for<'de> Deserialize<'de>> Fetcher<E> {
255    async fn resp_bytes(&self, response: Response, name: String) -> Result<Bytes> {
256        let len = response.content_length().unwrap_or(0);
257        let style = ProgressStyle::default_spinner()
258            .template("[{msg}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
259            .unwrap()
260            .progress_chars("#>-");
261
262        match &self.response_method {
263            ResponseMethod::Bytes => {
264                let bytes = response.bytes().await?;
265                Ok(bytes)
266            }
267            ResponseMethod::BytesStream => {
268                let mut stream = response.bytes_stream();
269                let mut bytes = bytes::BytesMut::new();
270
271                let pb = self.multi_pb.add(ProgressBar::new(len));
272                pb.set_style(style.clone());
273                pb.set_message(name.clone());
274                let mut downloaded: u64 = 0;
275
276                while let Some(item) = stream.next().await {
277                    let b = item?;
278                    downloaded += b.len() as u64;
279                    bytes.extend_from_slice(&b);
280                    pb.set_position(downloaded);
281                }
282                pb.finish();
283
284                Ok(bytes.freeze())
285            }
286            ResponseMethod::Chunk => {
287                let mut bytes = bytes::BytesMut::new();
288                let mut response = response;
289
290                let pb = self.multi_pb.add(ProgressBar::new(len));
291                pb.set_style(style.clone());
292                pb.set_message(name.clone());
293                let mut downloaded: u64 = 0;
294                while let Some(chunk) = response.chunk().await? {
295                    downloaded += chunk.len() as u64;
296                    bytes.extend_from_slice(&chunk);
297                    pb.set_position(downloaded);
298                }
299                pb.finish();
300                Ok(bytes.freeze())
301            }
302        }
303    }
304
305    /// Enables to fetch packages in a watching state from a config file
306    /// The config file is watched for changes and the packages are fetched
307    ///
308    /// The fetching method is `Async` and the notification method is `log`
309    pub async fn watching(&mut self) {
310        info!("Watching {}", &self.config_path.display());
311        if let Err(e) = self.watch().await {
312            error!("Error: {:?}", e)
313        }
314    }
315
316    async fn watcher() -> notify::Result<(RecommendedWatcher, Receiver<notify::Result<Event>>)> {
317        let (tx, rx) = channel(1);
318        let watcher = RecommendedWatcher::new(
319            move |res| {
320                let tx = tx.clone();
321                tokio::spawn(async move {
322                    if let Err(e) = tx.clone().send(res).await {
323                        eprintln!("Error sending event: {}", e);
324                    }
325                });
326            },
327            NConfig::default(),
328        )?;
329        Ok((watcher, rx))
330    }
331
332    async fn watch(&mut self) -> notify::Result<()> {
333        self.notify_method = NotifyMethod::Log;
334        let (mut watcher, mut rx) = Self::watcher().await?;
335        watcher.watch(&self.config_path, RecursiveMode::Recursive)?;
336
337        while let Some(res) = rx.recv().await {
338            match res {
339                Ok(event) => self.handle_event(event).await.expect("Event handle error"),
340                Err(e) => error!("Watch error: {:?}", e),
341            }
342        }
343        Ok(())
344    }
345
346    async fn handle_event(&mut self, event: Event) -> anyhow::Result<()> {
347        info!("Event: {:?}", event.kind);
348        match event.kind {
349            EventKind::Modify(_) => {
350                self.config = Config::from_file(&self.config_path, self.config_type).await?;
351                self.async_fetch().await?;
352            }
353            EventKind::Remove(_) => {
354                info!("Removed {}", &self.config_path.display());
355                info!("Clearing DB");
356                self.db.clear().unwrap();
357            }
358            _ => debug!("Other event type"),
359        }
360        Ok(())
361    }
362
363    async fn handle_entry(&self, entry: E) -> Result<()> {
364        let key = entry.key();
365        let mut value = entry.value();
366
367        if let Some(curr_val) = self.db.get(key.bytes())? {
368            let cv_bytes = curr_val.to_vec();
369            let cv = E::Value::from_ivec(curr_val);
370            if !value.is_same(&cv) {
371                if self.notify_method == NotifyMethod::Log {
372                    key.log_caching();
373                }
374                let response = self.client.get(value.url()).send().await?;
375                let bytes = self.resp_bytes(response, key.to_string()).await?;
376                value.set_response(&bytes);
377                let _ =
378                    self.db
379                        .compare_and_swap(key.bytes(), Some(cv_bytes), Some(value.bytes()))?;
380            } else if self.notify_method == NotifyMethod::Log {
381                key.log_cache();
382            }
383        } else {
384            if self.notify_method == NotifyMethod::Log {
385                key.log_caching();
386            }
387            let response = self.client.get(value.url()).send().await?;
388            let bytes = self.resp_bytes(response, key.to_string()).await?;
389            value.set_response(bytes.as_ref());
390            let _ = self.db.insert(key.bytes(), value.bytes())?;
391        }
392        Ok(())
393    }
394
395    #[cfg(feature = "unstable")]
396    fn handle_entry_sync(&self, entry: E) -> Result<()> {
397        futures::executor::block_on(self.handle_entry(entry))?;
398        Ok(())
399    }
400
401    /// Fetches and stores all results to the db
402    pub async fn async_fetch(&mut self) -> Result<()> {
403        let mut tasks = Vec::new();
404        for entry in (*self.entries).clone() {
405            let fetcher = self.clone();
406            tasks.push(tokio::spawn(async move {
407                fetcher.handle_entry(entry.clone()).await
408            }));
409        }
410
411        join_all(tasks).await.into_iter().try_for_each(|x| x?)?;
412
413        Ok(())
414    }
415
    #[cfg(feature = "unstable")]
    /// Fetches and stores all results to the db synchronously and in parallel
    ///
    /// Each entry blocks on its own fetch; the first error (if any) is returned.
    ///
    /// NOTE(review): `par_iter` implies rayon's traits are in scope under the
    /// `unstable` feature — no rayon import is visible in this file; confirm.
    pub fn sync_fetch(&mut self) -> Result<()> {
        let entries = self.entries.clone();

        let results: Vec<Result<()>> = entries
            .par_iter()
            .map(|entry| self.handle_entry_sync(entry.clone()))
            .collect();

        results.into_iter().try_for_each(|x| x)?;

        Ok(())
    }
430
431    pub async fn fetch(&mut self, method: FetchMethod) -> Result<()> {
432        match method {
433            FetchMethod::Async => self.async_fetch().await?,
434            FetchMethod::Watch => self.watching().await,
435            #[cfg(feature = "unstable")]
436            FetchMethod::Sync => {
437                println!("Please use sync_fetch() for synchronous operations.")
438                // Honestly not sure how you would get here but here's a message I guess
439            }
440        }
441        Ok(())
442    }
443
444    /// Returns all entries in the db as a vector of key-value pairs
445    pub fn pairs<K: EntryKey, V: EntryValue>(&self) -> Result<Vec<(K, V)>> {
446        self.db
447            .iter()
448            .map(|x| {
449                let (key_iv, value_iv) = x.unwrap();
450                let key = K::from_ivec(key_iv);
451                Ok((key, V::from_ivec(value_iv)))
452            })
453            .collect()
454    }
455
456    /// Gets an entry from the db by key
457    pub fn get<K: EntryKey, V: EntryValue>(&self, key: K) -> Result<Option<V>> {
458        if let Some(value_iv) = self.db.get(key.bytes())? {
459            let value = V::from_ivec(value_iv);
460            Ok(Some(value))
461        } else {
462            Ok(None)
463        }
464    }
465
466    /// Removes an entry from the db by key
467    pub fn remove<K: EntryKey>(&self, key: K) -> Result<()> {
468        self.db.remove(key.bytes())?;
469        Ok(())
470    }
471
    /// Updates an entry in the db by key and new value
    ///
    /// The write only happens when the stored value differs from `value`;
    /// a missing key is a silent no-op (nothing is inserted).
    ///
    /// NOTE(review): a lost compare-and-swap race is ignored via `let _` —
    /// confirm that is intended.
    pub fn update<K: EntryKey, V: EntryValue>(&self, key: K, value: V) -> Result<()> {
        if let Some(curr_val) = self.db.get(key.bytes())? {
            // Keep the raw bytes around for the CAS comparison below.
            let cv_bytes = curr_val.to_vec();
            let cv = V::from_ivec(curr_val);
            if !value.is_same(&cv) {
                let _ =
                    self.db
                        .compare_and_swap(key.bytes(), Some(cv_bytes), Some(value.bytes()))?;
            }
        }
        Ok(())
    }
485
486    /// Writes all the fetched data to the specified directory
487    pub async fn write_all(&self, dir: PathBuf) -> Result<()> {
488        let total_entries = self.entries.len();
489        let progress_bar = Arc::new(Mutex::new(ProgressBar::new(total_entries as u64)));
490        progress_bar.lock().await.set_style(
491            ProgressStyle::default_bar()
492                .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {wide_msg}")
493                .unwrap()
494                .progress_chars("##-"),
495        );
496
497        let mut tasks = Vec::new();
498        for entry in (*self.entries).clone() {
499            let key = entry.key();
500            let value_vec = self.db.get(key.bytes())?.unwrap().to_vec();
501            let value: E::Value = E::Value::from_bytes(&value_vec);
502            let resp = value.response();
503            let file_name = Url::parse(&value.url())?
504                .path_segments()
505                .unwrap()
506                .last()
507                .unwrap()
508                .to_string();
509            let path = dir.join(&file_name);
510            if !dir.exists() {
511                create_dir(&dir).await?;
512            }
513            let bytes = resp.to_vec();
514            let pb_clone = Arc::clone(&progress_bar);
515            tasks.push(tokio::spawn(async move {
516                pb_clone
517                    .lock()
518                    .await
519                    .set_message(format!("Writing: {}", file_name));
520                let result = tokio::fs::write(&path, bytes).await;
521                pb_clone.lock().await.inc(1);
522                result
523            }));
524        }
525
526        let results = join_all(tasks).await;
527        progress_bar
528            .lock()
529            .await
530            .finish_with_message("All files written");
531
532        results.into_iter().try_for_each(|x| x?)?;
533        Ok(())
534    }
535}