// pexels_sdk/download.rs

use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;

use futures::stream::StreamExt;
use reqwest::header::HeaderMap;
use reqwest::{header, Client, StatusCode};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio::sync::Semaphore;

use crate::models::{Photo, Video};
use crate::PexelsError;

/// Which of a photo's `src` URLs to download.
///
/// Every variant corresponds to exactly one field of the photo's source set.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImageQuality {
    /// `src.original`
    Original,
    /// `src.large2x`
    Large2x,
    /// `src.large`
    Large,
    /// `src.medium`
    Medium,
    /// `src.small`
    Small,
    /// `src.portrait`
    Portrait,
    /// `src.landscape`
    Landscape,
    /// `src.tiny`
    Tiny,
}

/// Which rendition of a video to download.
///
/// `HD` and `SD` pick a video file by its `quality` label. `Tiny` picks a small MP4 by its
/// width/height.
// The upper-case variant names are part of the public API, so they are kept as-is.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VideoQuality {
    /// High-definition rendition.
    HD,
    /// Standard-definition rendition.
    SD,
    /// Smallest MP4 rendition.
    Tiny,
}

/// Progress hook accepted by the batch download methods.
///
/// The first argument is the number of bytes of the current file written so far. The second
/// is that file's expected size. In batch mode the values are per file, not aggregated over
/// the batch, and calls from concurrent downloads interleave. Because this is a plain
/// function pointer, the hook cannot capture state.
pub type ProgressCallback = fn(u64, u64);

39/// Result type alias
40type Result<T> = std::result::Result<T, PexelsError>;
41
42pub struct DownloadManager {
43    client: Client,
44    max_concurrent: usize,
45}
46
47impl DownloadManager {
48    /// Create a new 'DownloadManager' and specify the maximum number of concurrent downloads
49    /// The default timeout is set to 60 seconds
50    ///
51    /// # Arguments
52    /// * `max_concurrent` - Maximum number of concurrent downloads
53    pub fn new(max_concurrent: usize) -> Self {
54        let client = Client::builder()
55            .timeout(Duration::from_secs(60))
56            .pool_max_idle_per_host(20)
57            .build()
58            .unwrap_or_default();
59
60        Self {
61            client,
62            max_concurrent,
63        }
64    }
65
66    /// Create a 'DownloadManager' with a custom 'Client'
67    pub fn with_client(client: Client, max_concurrent: usize) -> Self {
68        Self {
69            client,
70            max_concurrent,
71        }
72    }
73
74    /// Download the photos from the given URL and save to the specified output directory
75    /// Resumable upload is supported
76    ///
77    /// # Arguments
78    /// * `photo` - Photos to download
79    /// * `output_dir` - Output directory
80    /// * `quality` - Download quality
81    ///
82    /// # Returns
83    /// The path to download the file
84    pub async fn download_photo<P: AsRef<Path>>(
85        &self,
86        photo: &Photo,
87        output_dir: P,
88        quality: ImageQuality,
89    ) -> Result<PathBuf> {
90        let url = self.get_photo_url(photo, quality);
91        let file_name = format!("photo_{}.jpg", photo.id);
92        self.download_file(&url, output_dir, &file_name).await
93    }
94
95    /// Download the video from the given URL and save to the specified output directory
96    /// Resumable upload is supported
97    ///
98    /// # Arguments
99    /// * `video` - Video to download
100    /// * `output_dir` - Output directory
101    /// * `quality` - Download quality
102    ///
103    /// # Returns
104    /// The path to download the file
105    pub async fn download_video<P: AsRef<Path>>(
106        &self,
107        video: &Video,
108        output_dir: P,
109        quality: VideoQuality,
110    ) -> Result<PathBuf> {
111        let url = self.get_video_url(video, quality);
112        let file_name = format!("video_{}.mp4", video.id);
113        self.download_file(&url, output_dir, &file_name).await
114    }
115
116    /// Download photos in batches
117    ///
118    /// # Arguments
119    /// * `photos` - A list of photos to download
120    /// * `output_dir` - Output directory
121    /// * `quality` - Download quality
122    /// * `progress_callback` - Optional progress callback function
123    ///
124    /// # Returns
125    /// A list of files that have been successfully downloaded
126    pub async fn batch_download_photos<P: AsRef<Path>>(
127        &self,
128        photos: &[Photo],
129        output_dir: P,
130        quality: ImageQuality,
131        progress_callback: Option<ProgressCallback>,
132    ) -> Result<Vec<PathBuf>> {
133        let output_dir = output_dir.as_ref().to_path_buf();
134        let semaphore = Arc::new(Semaphore::new(self.max_concurrent));
135
136        let mut handles = Vec::with_capacity(photos.len());
137
138        for photo in photos {
139            let permit = Arc::clone(&semaphore).acquire_owned();
140            let photo = photo.clone();
141            let dir = output_dir.clone();
142            let client = self.client.clone();
143            let callback = progress_callback;
144
145            let handle = tokio::spawn(async move {
146                let _permit = permit.await.map_err(|_| PexelsError::AsyncError)?;
147
148                let url = match quality {
149                    ImageQuality::Original => &photo.src.original,
150                    ImageQuality::Large2x => &photo.src.large2x,
151                    ImageQuality::Large => &photo.src.large,
152                    ImageQuality::Medium => &photo.src.medium,
153                    ImageQuality::Small => &photo.src.small,
154                    ImageQuality::Portrait => &photo.src.portrait,
155                    ImageQuality::Landscape => &photo.src.landscape,
156                    ImageQuality::Tiny => &photo.src.tiny,
157                };
158
159                let file_name = format!("photo_{}.jpg", photo.id);
160                let path = dir.join(&file_name);
161
162                // Make sure the directory exists
163                if !dir.exists() {
164                    fs::create_dir_all(&dir).await?;
165                }
166
167                // Resumable upload logic
168                let mut headers = HeaderMap::new();
169                let mut range_start = 0;
170
171                if path.exists() {
172                    if let Ok(metadata) = fs::metadata(&path).await {
173                        range_start = metadata.len();
174                        headers.insert(
175                            header::RANGE,
176                            format!("bytes={range_start}-").parse().unwrap(),
177                        );
178                    }
179                }
180
181                // Download the file
182                let response = client.get(url).headers(headers).send().await?;
183
184                if !response.status().is_success() {
185                    return Err(PexelsError::DownloadError(format!(
186                        "Failed to download file: {}",
187                        response.status()
188                    )));
189                }
190
191                // Get the file size
192                let total_size = response.content_length().unwrap_or(0) + range_start;
193
194                let mut file = if range_start > 0 {
195                    fs::OpenOptions::new().append(true).open(&path).await?
196                } else {
197                    fs::File::create(&path).await?
198                };
199
200                let mut stream = response.bytes_stream();
201                let mut downloaded = range_start;
202
203                while let Some(chunk) = stream.next().await {
204                    let chunk = chunk?;
205                    file.write_all(&chunk).await?;
206
207                    downloaded += chunk.len() as u64;
208
209                    // Call progress callback (if provided)
210                    if let Some(cb) = callback {
211                        cb(downloaded, total_size);
212                    }
213                }
214
215                Ok::<PathBuf, PexelsError>(path)
216            });
217
218            handles.push(handle);
219        }
220
221        // wait for all downloads to complete
222        let results = futures::future::join_all(handles).await;
223
224        // Process the results
225        let mut successful_downloads = Vec::new();
226        for result in results {
227            match result {
228                Ok(Ok(path)) => successful_downloads.push(path),
229                Ok(Err(e)) => eprintln!("Download error: {e}"),
230                Err(e) => eprintln!("Task join error: {e}"),
231            }
232        }
233
234        Ok(successful_downloads)
235    }
236
237    /// Download videos in batches
238    ///
239    /// # Arguments
240    /// * `videos` - A list of videos to download
241    /// * `output_dir` - Output directory
242    /// * `quality` - Download quality
243    /// * `progress_callback` - Optional progress callback function
244    ///
245    /// # Returns
246    /// A list of files that have been successfully downloaded
247    pub async fn batch_download_videos<P: AsRef<Path>>(
248        &self,
249        videos: &[Video],
250        output_dir: P,
251        quality: VideoQuality,
252        progress_callback: Option<ProgressCallback>,
253    ) -> Result<Vec<PathBuf>> {
254        let output_dir = output_dir.as_ref().to_path_buf();
255        let semaphore = Arc::new(Semaphore::new(self.max_concurrent));
256
257        let mut handles = Vec::with_capacity(videos.len());
258
259        for video in videos {
260            let permit = Arc::clone(&semaphore).acquire_owned();
261            let video = video.clone();
262            let dir = output_dir.clone();
263            let client = self.client.clone();
264            let callback = progress_callback;
265
266            let handle = tokio::spawn(async move {
267                let _permit = permit.await.map_err(|_| PexelsError::AsyncError)?;
268
269                // 获取对应质量的视频 URL
270                let video_file = video
271                    .video_files
272                    .iter()
273                    .find(|file| match quality {
274                        VideoQuality::HD => file.quality == "hd" || file.quality == "HD",
275                        VideoQuality::SD => file.quality == "sd",
276                        VideoQuality::Tiny => {
277                            file.file_type == "video/mp4"
278                                && (file.width.unwrap_or(0) <= 640
279                                    || file.height.unwrap_or(0) <= 360)
280                        }
281                    })
282                    .ok_or_else(|| {
283                        PexelsError::DownloadError("No suitable video file found".to_string())
284                    })?;
285
286                let url = &video_file.link;
287                let file_name = format!("video_{}.mp4", video.id);
288                let path = dir.join(&file_name);
289
290                // Make sure the directory exists
291                if !dir.exists() {
292                    fs::create_dir_all(&dir).await?;
293                }
294
295                // Resumable upload logic
296                let mut headers = HeaderMap::new();
297                let mut range_start = 0;
298
299                if path.exists() {
300                    if let Ok(metadata) = fs::metadata(&path).await {
301                        range_start = metadata.len();
302                        headers.insert(
303                            header::RANGE,
304                            format!("bytes={range_start}-").parse().unwrap(),
305                        );
306                    }
307                }
308
309                // Download the file
310                let response = client.get(url).headers(headers).send().await?;
311
312                if !response.status().is_success() {
313                    return Err(PexelsError::DownloadError(format!(
314                        "Failed to download file: {}",
315                        response.status()
316                    )));
317                }
318
319                // Get the file size
320                let total_size = response.content_length().unwrap_or(0) + range_start;
321
322                let mut file = if range_start > 0 {
323                    fs::OpenOptions::new().append(true).open(&path).await?
324                } else {
325                    fs::File::create(&path).await?
326                };
327
328                let mut stream = response.bytes_stream();
329                let mut downloaded = range_start;
330
331                while let Some(chunk) = stream.next().await {
332                    let chunk = chunk?;
333                    file.write_all(&chunk).await?;
334
335                    downloaded += chunk.len() as u64;
336
337                    // Call progress callback (if provided)
338                    if let Some(cb) = callback {
339                        cb(downloaded, total_size);
340                    }
341                }
342
343                Ok::<PathBuf, PexelsError>(path)
344            });
345
346            handles.push(handle);
347        }
348
349        // Wait for all downloads to complete
350        let results = futures::future::join_all(handles).await;
351
352        // Process the results
353        let mut successful_downloads = Vec::new();
354        for result in results {
355            match result {
356                Ok(Ok(path)) => successful_downloads.push(path),
357                Ok(Err(e)) => eprintln!("Download error: {e}"),
358                Err(e) => eprintln!("Task join error: {e}"),
359            }
360        }
361
362        Ok(successful_downloads)
363    }
364
365    /// Download a single file
366    ///
367    /// # Arguments
368    /// * `url` - File URL
369    /// * `output_dir` - Output directory
370    /// * `file_name` - Filename
371    ///
372    /// # Returns
373    /// The path to download the file
374    async fn download_file<P: AsRef<Path>>(
375        &self,
376        url: &str,
377        output_dir: P,
378        file_name: &str,
379    ) -> Result<PathBuf> {
380        let output_dir = output_dir.as_ref().to_path_buf();
381        let path = output_dir.join(file_name);
382
383        // Make sure the directory exists
384        if !output_dir.exists() {
385            fs::create_dir_all(&output_dir).await?;
386        }
387
388        // Resumable upload logic
389        let mut headers = HeaderMap::new();
390        let mut range_start = 0;
391
392        if path.exists() {
393            if let Ok(metadata) = fs::metadata(&path).await {
394                range_start = metadata.len();
395                headers.insert(
396                    header::RANGE,
397                    format!("bytes={range_start}-").parse().unwrap(),
398                );
399            }
400        }
401
402        // Send a request
403        let response = self.client.get(url).headers(headers).send().await?;
404
405        if !response.status().is_success() {
406            return Err(PexelsError::DownloadError(format!(
407                "Failed to download file: {}",
408                response.status()
409            )));
410        }
411
412        // Get the file size
413        let _total_size = response.content_length().unwrap_or(0) + range_start;
414
415        let mut file = if range_start > 0 {
416            fs::OpenOptions::new().append(true).open(&path).await?
417        } else {
418            fs::File::create(&path).await?
419        };
420
421        let mut stream = response.bytes_stream();
422
423        while let Some(chunk) = stream.next().await {
424            let chunk = chunk?;
425            file.write_all(&chunk).await?;
426        }
427
428        Ok(path)
429    }
430
431    /// Get the photo URL
432    fn get_photo_url(&self, photo: &Photo, quality: ImageQuality) -> String {
433        match quality {
434            ImageQuality::Original => photo.src.original.clone(),
435            ImageQuality::Large2x => photo.src.large2x.clone(),
436            ImageQuality::Large => photo.src.large.clone(),
437            ImageQuality::Medium => photo.src.medium.clone(),
438            ImageQuality::Small => photo.src.small.clone(),
439            ImageQuality::Portrait => photo.src.portrait.clone(),
440            ImageQuality::Landscape => photo.src.landscape.clone(),
441            ImageQuality::Tiny => photo.src.tiny.clone(),
442        }
443    }
444
445    /// Get the video URL
446    fn get_video_url(&self, video: &Video, quality: VideoQuality) -> String {
447        let video_file = video
448            .video_files
449            .iter()
450            .find(|file| match quality {
451                VideoQuality::HD => file.quality == "hd" || file.quality == "HD",
452                VideoQuality::SD => file.quality == "sd",
453                VideoQuality::Tiny => {
454                    file.file_type == "video/mp4"
455                        && (file.width.unwrap_or(0) <= 640 || file.height.unwrap_or(0) <= 360)
456                }
457            })
458            .unwrap_or_else(|| {
459                // If you can't find the specified quality, return the first video file
460                video.video_files.first().unwrap_or_else(|| {
461                    panic!("No video files available for video ID: {}", video.id)
462                })
463            });
464
465        video_file.link.clone()
466    }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::PhotoSources;
    use tokio::test;

    /// Builds a `Photo` whose source URLs end in the name of their `src` field, so every
    /// `ImageQuality` variant can be checked against the field it should select.
    fn mock_photo() -> Photo {
        Photo {
            id: 1,
            width: 800,
            height: 600,
            url: "https://www.pexels.com/photo/1".to_string(),
            photographer: "Test Photographer".to_string(),
            photographer_url: Some("https://www.pexels.com/photographer".to_string()),
            photographer_id: Some(1),
            avg_color: Some("#FFFFFF".to_string()),
            src: PhotoSources {
                original: "https://images.pexels.com/photos/1/original.jpg".to_string(),
                large2x: "https://images.pexels.com/photos/1/large2x.jpg".to_string(),
                large: "https://images.pexels.com/photos/1/large.jpg".to_string(),
                medium: "https://images.pexels.com/photos/1/medium.jpg".to_string(),
                small: "https://images.pexels.com/photos/1/small.jpg".to_string(),
                portrait: "https://images.pexels.com/photos/1/portrait.jpg".to_string(),
                landscape: "https://images.pexels.com/photos/1/landscape.jpg".to_string(),
                tiny: "https://images.pexels.com/photos/1/tiny.jpg".to_string(),
            },
            alt: Some("Test Photo".to_string()),
        }
    }

    /// Every `ImageQuality` variant must map to its own `src` field.
    #[test]
    async fn test_get_photo_url() {
        let manager = DownloadManager::new(5);
        let photo = mock_photo();

        let cases = [
            (ImageQuality::Original, "original"),
            (ImageQuality::Large2x, "large2x"),
            (ImageQuality::Large, "large"),
            (ImageQuality::Medium, "medium"),
            (ImageQuality::Small, "small"),
            (ImageQuality::Portrait, "portrait"),
            (ImageQuality::Landscape, "landscape"),
            (ImageQuality::Tiny, "tiny"),
        ];

        for (quality, field) in cases {
            assert_eq!(
                manager.get_photo_url(&photo, quality),
                format!("https://images.pexels.com/photos/1/{field}.jpg"),
                "wrong URL for {quality:?}"
            );
        }
    }
}