Skip to main content

gdown_core/
download.rs

1//! Core download logic for Google Drive files
2
3use crate::error::{GdownError, Result};
4use crate::url::{parse_url, build_download_url, FileId};
5use futures_util::stream::StreamExt;
6use reqwest::Client;
7use std::path::{Path, PathBuf};
8use std::time::Duration;
9
/// User-Agent header used when the caller does not supply one: a desktop
/// Chrome 120 string for 64-bit Windows 10.
const DEFAULT_USER_AGENT: &str =
    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36";
12
/// Boxed progress callback, invoked as `callback(bytes_downloaded, total_bytes)`.
///
/// `total_bytes` may be `None` when the expected size is not known.
pub type ProgressCallback = Box<dyn Fn(u64, Option<u64>) + Send + 'static>;

/// Options for download operations.
pub struct DownloadOptions {
    /// Speed limit in bytes per second (`None` = unlimited).
    pub speed_limit: Option<u64>,
    /// Continue into an existing partial output file instead of overwriting it.
    pub resume: bool,
    /// Export format for Google Docs/Sheets/Slides, sent as the `format` query
    /// parameter of the export URL.
    pub format: Option<String>,
    /// Progress callback, invoked with `(bytes_downloaded, total_bytes)`.
    pub progress_callback: Option<ProgressCallback>,
}
25
26impl Clone for DownloadOptions {
27    fn clone(&self) -> Self {
28        Self {
29            speed_limit: self.speed_limit,
30            resume: self.resume,
31            format: self.format.clone(),
32            progress_callback: None,  // Cannot clone fn pointers
33        }
34    }
35}
36
37impl std::fmt::Debug for DownloadOptions {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("DownloadOptions")
40            .field("speed_limit", &self.speed_limit)
41            .field("resume", &self.resume)
42            .field("format", &self.format)
43            .field("progress_callback", &"...")
44            .finish()
45    }
46}
47
/// Downloader client for Google Drive.
///
/// Configure it with the builder-style setters (`proxy`, `user_agent`, `verify_ssl`,
/// `cookies_path`); HTTP clients are built from these settings on demand.
#[derive(Clone, Debug)]
pub struct Downloader {
    /// Optional proxy URL for outgoing requests.
    proxy: Option<String>,
    /// `User-Agent` header sent with every request.
    user_agent: String,
    /// When `false`, invalid TLS certificates are accepted.
    verify_ssl: bool,
    /// Path of the cookies file.
    // NOTE(review): not read anywhere in download.rs, and the default's `~` prefix is not
    // expanded — confirm how (and whether) it is consumed elsewhere.
    cookies_path: PathBuf,
}
56
57impl Downloader {
58    /// Create a new Downloader with default settings
59    pub fn new() -> Self {
60        Self {
61            proxy: None,
62            user_agent: DEFAULT_USER_AGENT.to_string(),
63            verify_ssl: true,
64            cookies_path: PathBuf::from("~/.cache/gdown/cookies.txt"),
65        }
66    }
67
68    /// Set proxy URL
69    pub fn proxy(mut self, proxy: &str) -> Self {
70        self.proxy = Some(proxy.to_string());
71        self
72    }
73
74    /// Set user agent
75    pub fn user_agent(mut self, ua: &str) -> Self {
76        self.user_agent = ua.to_string();
77        self
78    }
79
80    /// Set SSL verification
81    pub fn verify_ssl(mut self, verify: bool) -> Self {
82        self.verify_ssl = verify;
83        self
84    }
85
86    /// Set cookies path
87    pub fn cookies_path(mut self, path: &Path) -> Self {
88        self.cookies_path = path.to_path_buf();
89        self
90    }
91
92    /// Build the client with current settings
93    pub fn build_client(&self) -> Client {
94        let mut builder = Client::builder()
95            .user_agent(&self.user_agent)
96            .timeout(Duration::from_secs(60));
97
98        if !self.verify_ssl {
99            builder = builder.danger_accept_invalid_certs(true);
100        }
101
102        if let Some(proxy) = &self.proxy {
103            if proxy.starts_with("socks5://") {
104                builder = builder.proxy(reqwest::Proxy::all(proxy).unwrap());
105            } else {
106                builder = builder.proxy(reqwest::Proxy::http(proxy).unwrap());
107            }
108        }
109
110        builder.build().unwrap_or_else(|_| {
111            Client::builder()
112                .user_agent(&self.user_agent)
113                .build()
114                .expect("Failed to create HTTP client")
115        })
116    }
117
118    /// Download a file from URL
119    pub async fn download(
120        &self,
121        url: &str,
122        output: &Path,
123        options: DownloadOptions,
124    ) -> Result<u64> {
125        let client = self.build_client();
126
127        // Parse URL to get file ID
128        let (file_id, is_download_link) = parse_url(url)?;
129        let file_id = file_id.ok_or_else(|| GdownError::InvalidUrl("No file ID found".into()))?;
130
131        // Build initial request URL
132        let request_url = if is_download_link {
133            build_download_url(&file_id)
134        } else {
135            // Check if it's a Google Doc type and needs export
136            if let Some(format) = options.format.clone() {
137                if url.contains("document") {
138                    return self.download_doc_export(&file_id, &format, output, options).await;
139                } else if url.contains("spreadsheet") {
140                    return self.download_sheet_export(&file_id, &format, output, options).await;
141                } else if url.contains("presentation") {
142                    return self.download_slides_export(&file_id, &format, output, options).await;
143                }
144            }
145            build_download_url(&file_id)
146        };
147
148        // Make initial request
149        let response = client.get(&request_url).send().await.map_err(|e| GdownError::Download(e.to_string()))?;
150
151        // Check content type
152        let content_type = response
153            .headers()
154            .get("Content-Type")
155            .and_then(|v| v.to_str().ok())
156            .unwrap_or("");
157
158        if content_type.contains("text/html") {
159            // Need to handle confirmation page
160            let html = response.text().await.map_err(|e| GdownError::Download(e.to_string()))?;
161            let actual_url = self.extract_confirmation_url(&html).await?;
162
163            // Resume if requested and partial file exists
164            if options.resume && output.exists() {
165                return self.resume_download(&actual_url, output, options).await;
166            }
167
168            return self.download_file(&actual_url, output, options).await;
169        }
170
171        // Resume if requested and partial file exists
172        if options.resume && output.exists() {
173            return self.resume_download(&request_url, output, options).await;
174        }
175
176        // Download directly
177        self.download_file(&request_url, output, options).await
178    }
179
180    /// Download Google Document (Docs) with export format
181    async fn download_doc_export(
182        &self,
183        file_id: &FileId,
184        format: &str,
185        output: &Path,
186        options: DownloadOptions,
187    ) -> Result<u64> {
188        let url = format!(
189            "https://docs.google.com/document/d/{}/export?format={}",
190            file_id, format
191        );
192        self.download_file(&url, output, options).await
193    }
194
195    /// Download Google Spreadsheet with export format
196    async fn download_sheet_export(
197        &self,
198        file_id: &FileId,
199        format: &str,
200        output: &Path,
201        options: DownloadOptions,
202    ) -> Result<u64> {
203        let url = format!(
204            "https://docs.google.com/spreadsheets/d/{}/export?format={}",
205            file_id, format
206        );
207        self.download_file(&url, output, options).await
208    }
209
210    /// Download Google Slides with export format
211    async fn download_slides_export(
212        &self,
213        file_id: &FileId,
214        format: &str,
215        output: &Path,
216        options: DownloadOptions,
217    ) -> Result<u64> {
218        let url = format!(
219            "https://docs.google.com/presentation/d/{}/export?format={}",
220            file_id, format
221        );
222        self.download_file(&url, output, options).await
223    }
224
225    /// Extract actual download URL from confirmation page HTML
226    async fn extract_confirmation_url(&self, html: &str) -> Result<String> {
227        use regex::Regex;
228
229        // Try to find form action URL
230        let form_regex = Regex::new(r#"action="([^"]+)""#).unwrap();
231        if let Some(caps) = form_regex.captures(html) {
232            let action = caps.get(1).unwrap().as_str();
233
234            // Build POST request to confirmation URL
235            if action.contains("confirm") {
236                let client = self.build_client();
237
238                // Try to find confirmation token
239                let token_regex = Regex::new(r#"name="confirm".*?value="([^"]+)""#).unwrap();
240                let token = token_regex.captures(html).and_then(|c| c.get(1)).map(|m| m.as_str());
241
242                let mut request = client.post(action);
243                if let Some(t) = token {
244                    request = request.form(&[("confirm", t)]);
245                }
246
247                let response = request.send().await.map_err(|e| GdownError::Download(e.to_string()))?;
248
249                if let Some(location) = response.headers().get("Location") {
250                    return Ok(location.to_str().unwrap_or(action).to_string());
251                }
252            }
253
254            return Ok(action.to_string());
255        }
256
257        // Fallback: try to find downloadUrl in JavaScript
258        let download_url_regex = Regex::new(r#"downloadUrl\s*:\s*"([^"]+)""#).unwrap();
259        if let Some(caps) = download_url_regex.captures(html) {
260            return Ok(caps.get(1).unwrap().as_str().to_string());
261        }
262
263        Err(GdownError::FileUrlRetrieval("Could not find download URL in confirmation page".into()))
264    }
265
266    /// Download file content to output path
267    async fn download_file(
268        &self,
269        url: &str,
270        output: &Path,
271        options: DownloadOptions,
272    ) -> Result<u64> {
273        use tokio::io::AsyncWriteExt;
274
275        let client = self.build_client();
276        let response = client.get(url).send().await.map_err(|e| GdownError::Download(e.to_string()))?;
277
278        let total_size = response.content_length();
279        let mut file = tokio::fs::File::create(output).await?;
280        let mut downloaded: u64 = 0;
281
282        let mut stream = response.bytes_stream();
283        while let Some(chunk_result) = stream.next().await {
284            let chunk = chunk_result.map_err(|e| GdownError::Download(e.to_string()))?;
285            file.write_all(&chunk).await?;
286            downloaded += chunk.len() as u64;
287
288            // Call progress callback
289            if let Some(ref cb) = options.progress_callback {
290                cb(downloaded, total_size);
291            }
292
293            // Speed limiting
294            if let Some(limit) = options.speed_limit {
295                let expected_time = (downloaded as f64 / limit as f64 * 1000.0) as u64;
296                tokio::time::sleep(std::time::Duration::from_millis(expected_time)).await;
297            }
298        }
299
300        Ok(downloaded)
301    }
302
303    /// Resume a partially downloaded file
304    async fn resume_download(
305        &self,
306        url: &str,
307        output: &Path,
308        options: DownloadOptions,
309    ) -> Result<u64> {
310        use tokio::io::AsyncWriteExt;
311
312        let existing_size = tokio::fs::metadata(output).await?.len();
313        let client = self.build_client();
314
315        let response = client
316            .get(url)
317            .header("Range", format!("bytes={}-", existing_size))
318            .send()
319            .await.map_err(|e| GdownError::Download(e.to_string()))?;
320
321        let mut file = tokio::fs::OpenOptions::new()
322            .append(true)
323            .open(output)
324            .await?;
325
326        let mut downloaded = existing_size;
327        let mut stream = response.bytes_stream();
328
329        while let Some(chunk_result) = stream.next().await {
330            let chunk = chunk_result.map_err(|e| GdownError::Download(e.to_string()))?;
331            file.write_all(&chunk).await?;
332            downloaded += chunk.len() as u64;
333
334            if let Some(ref cb) = options.progress_callback {
335                cb(downloaded, None);
336            }
337        }
338
339        Ok(downloaded)
340    }
341
342    /// Get filename from Content-Disposition header
343    pub fn get_filename_from_disposition(disposition: &str) -> Option<String> {
344        // Try filename* first (UTF-8 encoded)
345        if let Some(start) = disposition.find("filename*=UTF-8''") {
346            let remainder = &disposition[start + 17..];
347            if let Some(end) = remainder.find(';') {
348                return Some(remainder[..end].to_string());
349            }
350            return Some(remainder.to_string());
351        }
352
353        // Try simple filename
354        if let Some(start) = disposition.find("filename=\"") {
355            let remainder = &disposition[start + 10..];
356            if let Some(end) = remainder.find('"') {
357                return Some(remainder[..end].to_string());
358            }
359        }
360
361        None
362    }
363}
364
365impl Default for Downloader {
366    fn default() -> Self {
367        Self::new()
368    }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the `Content-Disposition` parser under test.
    fn parse(disposition: &str) -> Option<String> {
        Downloader::get_filename_from_disposition(disposition)
    }

    /// Builds the expected `Some(String)` for a filename.
    fn some(name: &str) -> Option<String> {
        Some(name.to_owned())
    }

    #[test]
    fn test_downloader_creation() {
        // Certificate verification is on by default.
        let downloader = Downloader::new();
        assert!(downloader.verify_ssl);
    }

    #[test]
    fn test_filename_from_disposition() {
        // When both forms are present, the `filename*` value wins.
        assert_eq!(
            parse(r#"attachment; filename="test.txt"; filename*=UTF-8''test%20file.txt"#),
            some("test%20file.txt")
        );
    }

    #[test]
    fn test_filename_simple() {
        assert_eq!(parse(r#"attachment; filename="test.txt""#), some("test.txt"));
    }

    #[test]
    fn test_filename_from_disposition_empty() {
        assert_eq!(parse("attachment"), None);
    }

    #[test]
    fn test_filename_from_disposition_only_filename_star() {
        // Only `filename*` (UTF-8), no plain `filename`.
        assert_eq!(
            parse(r#"attachment; filename*=UTF-8''test%20file.txt"#),
            some("test%20file.txt")
        );
    }

    #[test]
    fn test_filename_from_disposition_with_spaces() {
        assert_eq!(
            parse(r#"attachment; filename="test file with spaces.txt""#),
            some("test file with spaces.txt")
        );
    }

    #[test]
    fn test_filename_from_disposition_no_quotes() {
        // The plain form is only recognised when quoted.
        assert_eq!(parse(r#"attachment; filename=test.txt"#), None);
    }

    #[test]
    fn test_filename_from_disposition_rfc5987_chars() {
        // Percent-encoded UTF-8 is returned as-is, without decoding.
        assert_eq!(
            parse(r#"attachment; filename*=UTF-8''%E6%96%87%E4%BB%B6.txt"#),
            some("%E6%96%87%E4%BB%B6.txt")
        );
    }
}