1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
//! A micro-library for downloading the contents of a URL and streaming them directly to disk.
//!
//! # Getting Started
//!
//! ```rust
//! use std::path::Path;
//! use tokio_dl_stream_to_disk::AsyncDownload;
//!
//! #[tokio::main]
//! async fn main() {
//!     if AsyncDownload::new("https://bit.ly/3yWXSOW", &Path::new("/tmp"), "5mb_test.bin").download(&None).await.is_ok() {
//!         println!("File downloaded successfully!");
//!     }
//! }
//! ```

pub mod error;

use std::convert::TryInto;
use std::error::Error;
use std::io::{Error as IOError, ErrorKind};
use std::path::{Path, PathBuf};

use bytes::Bytes;
use futures_util::stream::Stream;
use futures_util::StreamExt;

#[cfg(feature="sha256sum")]
use sha2::{Sha256, Digest};
use tokio_util::io::StreamReader;

use crate::error::{Error as TDSTDError, ErrorKind as TDSTDErrorKind};

type S = dyn Stream<Item = Result<Bytes, IOError>> + Unpin;

/// Streams the body of an HTTP download straight into a file on disk.
///
/// Build one with [`AsyncDownload::new`], then call [`download`](AsyncDownload::download),
/// optionally calling [`get`](AsyncDownload::get) first to learn the download's length before
/// anything is written.
pub struct AsyncDownload {
    /// URL whose contents are downloaded.
    url: String,
    /// Directory the downloaded file is written into.
    dst_path: PathBuf,
    /// Name of the file created inside `dst_path`.
    fname: String,
    /// Parsed `content-length` header of the response, once a request has been made.
    length: Option<u64>,
    /// Body of a response that has been requested but not yet written to disk.
    response_stream: Option<Box<S>>,
}

impl AsyncDownload {
    /// Returns an `AsyncDownload` with the URL, destination directory and file name specified.
    ///
    /// No request is made here; that happens in [`get`](Self::get) or
    /// [`download`](Self::download).
    ///
    /// # Arguments
    ///
    /// * `url` - The URL whose contents should be downloaded
    /// * `dst_path` - The directory the file will be written into
    /// * `fname` - The name of the file to create inside `dst_path`
    pub fn new(url: &str, dst_path: &Path, fname: &str) -> Self {
        Self {
            url: String::from(url),
            dst_path: PathBuf::from(dst_path),
            fname: String::from(fname),
            length: None,
            response_stream: None,
        }
    }

    /// Returns the length of the download in bytes, parsed from the response's
    /// `content-length` header.
    ///
    /// This is `None` until [`get`](Self::get) or [`download`](Self::download) has made the
    /// request, and also when the header is missing or cannot be parsed as a `u64`.
    pub fn length(&self) -> Option<u64> {
        self.length
    }

    /// Requests the URL without downloading the body.
    ///
    /// On success, returns this `AsyncDownload` holding the response stream, so that a later
    /// [`download`](Self::download) reuses it instead of issuing a new request. The download
    /// size is then available through [`length`](Self::length).
    ///
    /// # Errors
    ///
    /// Returns an error if the request fails or the server answers with an error status.
    pub async fn get(mut self) -> Result<AsyncDownload, Box<dyn Error>> {
        self.get_non_consumable().await?;
        Ok(self)
    }

    /// Issues the GET request and stores the response body stream and content length in `self`.
    ///
    /// On an error status nothing is stored and the error is returned.
    async fn get_non_consumable(&mut self) -> Result<(), Box<dyn Error>> {
        let response = reqwest::get(self.url.as_str()).await?;
        // A missing or malformed header simply leaves the length unknown.
        let content_length = response
            .headers()
            .get("content-length")
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.parse::<u64>().ok());
        let stream = response
            .error_for_status()?
            .bytes_stream()
            // StreamReader needs io::Error items, so wrap reqwest's error type.
            .map(|result| result.map_err(|e| IOError::new(ErrorKind::Other, e)));
        self.response_stream = Some(Box::new(stream));
        self.length = content_length;
        Ok(())
    }

    /// Downloads the body to `dst_path/fname`, optionally reporting progress.
    ///
    /// If no response stream is held (because [`get`](Self::get) was not called, or a previous
    /// download consumed it), the request is issued first.
    ///
    /// Arguments:
    /// * `cb` - An optional progress callback. It is called after every read with the total
    ///   number of bytes received so far, including once more with the final total when the
    ///   body ends.
    ///
    /// # Errors
    ///
    /// * `InvalidResponse` if the request fails or returns an error status
    /// * `FileExists` if the destination file already exists
    /// * `DirectoryMissing` if `dst_path` is not a directory
    /// * any I/O error raised while reading the body or writing the file
    pub async fn download(&mut self, cb: &Option<Box<dyn Fn(u64) -> ()>>) -> Result<(), TDSTDError> {
        self.stream_to_disk(cb, |_| {}).await
    }

    #[cfg(feature="sha256sum")]
    /// Downloads the body to `dst_path/fname` like [`download`](Self::download) and returns the
    /// SHA-256 digest of the downloaded bytes.
    ///
    /// If no response stream is held, the request is issued first, exactly as in `download`.
    ///
    /// Arguments:
    /// * `cb` - An optional progress callback. It is called after every read with the total
    ///   number of bytes received so far.
    ///
    /// # Errors
    ///
    /// Same as [`download`](Self::download).
    pub async fn download_and_return_sha256sum(&mut self, cb: &Option<Box<dyn Fn(u64) -> ()>>) -> Result<Vec<u8>, TDSTDError> {
        let mut hasher = Sha256::new();
        self.stream_to_disk(cb, |chunk| hasher.update(chunk)).await?;
        Ok(hasher.finalize().to_vec())
    }

    /// Shared body of the download methods. It streams the response into a newly created
    /// `dst_path/fname`, passing every chunk written to `on_chunk`.
    async fn stream_to_disk<F>(
        &mut self,
        cb: &Option<Box<dyn Fn(u64) -> ()>>,
        mut on_chunk: F,
    ) -> Result<(), TDSTDError>
    where
        F: FnMut(&[u8]),
    {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        if self.response_stream.is_none() {
            self.get_non_consumable()
                .await
                .map_err(|_| TDSTDError::new(TDSTDErrorKind::InvalidResponse))?;
        }

        let dest_file = self.dst_path.join(&self.fname);
        if dest_file.is_file() {
            return Err(TDSTDError::new(TDSTDErrorKind::FileExists));
        }
        if !self.dst_path.is_dir() {
            return Err(TDSTDError::new(TDSTDErrorKind::DirectoryMissing));
        }

        // Take the stream only after the checks above pass, so a failed attempt keeps it.
        let stream = self
            .response_stream
            .take()
            .ok_or_else(|| TDSTDError::new(TDSTDErrorKind::InvalidResponse))?;
        let mut http_async_reader = StreamReader::new(stream);

        // `create_new` refuses to overwrite a file that appeared after the `is_file` check.
        let mut dest = tokio::fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&dest_file)
            .await
            .map_err(|e| {
                if e.kind() == ErrorKind::AlreadyExists {
                    TDSTDError::new(TDSTDErrorKind::FileExists)
                } else {
                    TDSTDError::from(e)
                }
            })?;

        let mut buf = [0u8; 8 * 1024];
        let mut num_bytes_total: usize = 0;
        loop {
            let num_bytes = http_async_reader.read(&mut buf).await?;
            if let Some(cb) = cb {
                num_bytes_total += num_bytes;
                cb(num_bytes_total.try_into().expect("usize always fits in u64"));
            }
            if num_bytes == 0 {
                break;
            }
            let chunk = &buf[..num_bytes];
            // `write` may accept only part of the buffer; `write_all` loops until all is written.
            dest.write_all(chunk).await?;
            on_chunk(chunk);
        }
        // tokio::fs::File finishes writes in the background. Flushing surfaces the final
        // write's error here instead of losing it when `dest` is dropped.
        dest.flush().await?;
        Ok(())
    }
}