1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3use std::time::Duration;
4
5use futures::stream::StreamExt;
6use reqwest::header::HeaderMap;
7use reqwest::{header, Client};
8use tokio::fs;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::Semaphore;
11
12use crate::models::{Photo, Video};
13use crate::PexelsError;
14
/// Which rendition of a photo to download.
///
/// Each variant picks the URL of the same name from the photo's `src` set.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ImageQuality {
    /// `src.original`.
    Original,
    /// `src.large2x`.
    Large2x,
    /// `src.large`.
    Large,
    /// `src.medium`.
    Medium,
    /// `src.small`.
    Small,
    /// `src.portrait`.
    Portrait,
    /// `src.landscape`.
    Landscape,
    /// `src.tiny`.
    Tiny,
}
27
/// Which of a video's files to download.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum VideoQuality {
    /// A file whose `quality` tag is `hd`.
    HD,
    /// A file whose `quality` tag is `sd`.
    SD,
    /// A `video/mp4` file that is at most 640 px wide or at most 360 px tall.
    ///
    /// Files with unknown dimensions also qualify.
    Tiny,
}
35
/// Progress hook for the batch download methods.
///
/// It is called after every chunk written, with the bytes of the current file now on disk
/// and that file's expected size. The expected size is derived from the server's
/// `Content-Length`, so it can be inaccurate when the server does not report one.
pub type ProgressCallback = fn(downloaded: u64, total: u64);
38
39type Result<T> = std::result::Result<T, PexelsError>;
41
42pub struct DownloadManager {
43 client: Client,
44 max_concurrent: usize,
45}
46
47impl DownloadManager {
48 pub fn new(max_concurrent: usize) -> Self {
54 let client = Client::builder()
55 .timeout(Duration::from_secs(60))
56 .pool_max_idle_per_host(20)
57 .build()
58 .unwrap_or_default();
59
60 Self {
61 client,
62 max_concurrent,
63 }
64 }
65
66 pub fn with_client(client: Client, max_concurrent: usize) -> Self {
68 Self {
69 client,
70 max_concurrent,
71 }
72 }
73
74 pub async fn download_photo<P: AsRef<Path>>(
85 &self,
86 photo: &Photo,
87 output_dir: P,
88 quality: ImageQuality,
89 ) -> Result<PathBuf> {
90 let url = self.get_photo_url(photo, quality);
91 let file_name = format!("photo_{}.jpg", photo.id);
92 self.download_file(&url, output_dir, &file_name).await
93 }
94
95 pub async fn download_video<P: AsRef<Path>>(
106 &self,
107 video: &Video,
108 output_dir: P,
109 quality: VideoQuality,
110 ) -> Result<PathBuf> {
111 let url = self.get_video_url(video, quality);
112 let file_name = format!("video_{}.mp4", video.id);
113 self.download_file(&url, output_dir, &file_name).await
114 }
115
116 pub async fn batch_download_photos<P: AsRef<Path>>(
127 &self,
128 photos: &[Photo],
129 output_dir: P,
130 quality: ImageQuality,
131 progress_callback: Option<ProgressCallback>,
132 ) -> Result<Vec<PathBuf>> {
133 let output_dir = output_dir.as_ref().to_path_buf();
134 let semaphore = Arc::new(Semaphore::new(self.max_concurrent));
135
136 let mut handles = Vec::with_capacity(photos.len());
137
138 for photo in photos {
139 let permit = Arc::clone(&semaphore).acquire_owned();
140 let photo = photo.clone();
141 let dir = output_dir.clone();
142 let client = self.client.clone();
143 let callback = progress_callback;
144
145 let handle = tokio::spawn(async move {
146 let _permit = permit.await.map_err(|_| PexelsError::AsyncError)?;
147
148 let url = match quality {
149 ImageQuality::Original => &photo.src.original,
150 ImageQuality::Large2x => &photo.src.large2x,
151 ImageQuality::Large => &photo.src.large,
152 ImageQuality::Medium => &photo.src.medium,
153 ImageQuality::Small => &photo.src.small,
154 ImageQuality::Portrait => &photo.src.portrait,
155 ImageQuality::Landscape => &photo.src.landscape,
156 ImageQuality::Tiny => &photo.src.tiny,
157 };
158
159 let file_name = format!("photo_{}.jpg", photo.id);
160 let path = dir.join(&file_name);
161
162 if !dir.exists() {
164 fs::create_dir_all(&dir).await?;
165 }
166
167 let mut headers = HeaderMap::new();
169 let mut range_start = 0;
170
171 if path.exists() {
172 if let Ok(metadata) = fs::metadata(&path).await {
173 range_start = metadata.len();
174 headers.insert(
175 header::RANGE,
176 format!("bytes={range_start}-").parse().unwrap(),
177 );
178 }
179 }
180
181 let response = client.get(url).headers(headers).send().await?;
183
184 if !response.status().is_success() {
185 return Err(PexelsError::DownloadError(format!(
186 "Failed to download file: {}",
187 response.status()
188 )));
189 }
190
191 let total_size = response.content_length().unwrap_or(0) + range_start;
193
194 let mut file = if range_start > 0 {
195 fs::OpenOptions::new().append(true).open(&path).await?
196 } else {
197 fs::File::create(&path).await?
198 };
199
200 let mut stream = response.bytes_stream();
201 let mut downloaded = range_start;
202
203 while let Some(chunk) = stream.next().await {
204 let chunk = chunk?;
205 file.write_all(&chunk).await?;
206
207 downloaded += chunk.len() as u64;
208
209 if let Some(cb) = callback {
211 cb(downloaded, total_size);
212 }
213 }
214
215 Ok::<PathBuf, PexelsError>(path)
216 });
217
218 handles.push(handle);
219 }
220
221 let results = futures::future::join_all(handles).await;
223
224 let mut successful_downloads = Vec::new();
226 for result in results {
227 match result {
228 Ok(Ok(path)) => successful_downloads.push(path),
229 Ok(Err(e)) => eprintln!("Download error: {e}"),
230 Err(e) => eprintln!("Task join error: {e}"),
231 }
232 }
233
234 Ok(successful_downloads)
235 }
236
237 pub async fn batch_download_videos<P: AsRef<Path>>(
248 &self,
249 videos: &[Video],
250 output_dir: P,
251 quality: VideoQuality,
252 progress_callback: Option<ProgressCallback>,
253 ) -> Result<Vec<PathBuf>> {
254 let output_dir = output_dir.as_ref().to_path_buf();
255 let semaphore = Arc::new(Semaphore::new(self.max_concurrent));
256
257 let mut handles = Vec::with_capacity(videos.len());
258
259 for video in videos {
260 let permit = Arc::clone(&semaphore).acquire_owned();
261 let video = video.clone();
262 let dir = output_dir.clone();
263 let client = self.client.clone();
264 let callback = progress_callback;
265
266 let handle = tokio::spawn(async move {
267 let _permit = permit.await.map_err(|_| PexelsError::AsyncError)?;
268
269 let video_file = video
271 .video_files
272 .iter()
273 .find(|file| match quality {
274 VideoQuality::HD => file.quality == "hd" || file.quality == "HD",
275 VideoQuality::SD => file.quality == "sd",
276 VideoQuality::Tiny => {
277 file.file_type == "video/mp4"
278 && (file.width.unwrap_or(0) <= 640
279 || file.height.unwrap_or(0) <= 360)
280 }
281 })
282 .ok_or_else(|| {
283 PexelsError::DownloadError("No suitable video file found".to_string())
284 })?;
285
286 let url = &video_file.link;
287 let file_name = format!("video_{}.mp4", video.id);
288 let path = dir.join(&file_name);
289
290 if !dir.exists() {
292 fs::create_dir_all(&dir).await?;
293 }
294
295 let mut headers = HeaderMap::new();
297 let mut range_start = 0;
298
299 if path.exists() {
300 if let Ok(metadata) = fs::metadata(&path).await {
301 range_start = metadata.len();
302 headers.insert(
303 header::RANGE,
304 format!("bytes={range_start}-").parse().unwrap(),
305 );
306 }
307 }
308
309 let response = client.get(url).headers(headers).send().await?;
311
312 if !response.status().is_success() {
313 return Err(PexelsError::DownloadError(format!(
314 "Failed to download file: {}",
315 response.status()
316 )));
317 }
318
319 let total_size = response.content_length().unwrap_or(0) + range_start;
321
322 let mut file = if range_start > 0 {
323 fs::OpenOptions::new().append(true).open(&path).await?
324 } else {
325 fs::File::create(&path).await?
326 };
327
328 let mut stream = response.bytes_stream();
329 let mut downloaded = range_start;
330
331 while let Some(chunk) = stream.next().await {
332 let chunk = chunk?;
333 file.write_all(&chunk).await?;
334
335 downloaded += chunk.len() as u64;
336
337 if let Some(cb) = callback {
339 cb(downloaded, total_size);
340 }
341 }
342
343 Ok::<PathBuf, PexelsError>(path)
344 });
345
346 handles.push(handle);
347 }
348
349 let results = futures::future::join_all(handles).await;
351
352 let mut successful_downloads = Vec::new();
354 for result in results {
355 match result {
356 Ok(Ok(path)) => successful_downloads.push(path),
357 Ok(Err(e)) => eprintln!("Download error: {e}"),
358 Err(e) => eprintln!("Task join error: {e}"),
359 }
360 }
361
362 Ok(successful_downloads)
363 }
364
365 async fn download_file<P: AsRef<Path>>(
375 &self,
376 url: &str,
377 output_dir: P,
378 file_name: &str,
379 ) -> Result<PathBuf> {
380 let output_dir = output_dir.as_ref().to_path_buf();
381 let path = output_dir.join(file_name);
382
383 if !output_dir.exists() {
385 fs::create_dir_all(&output_dir).await?;
386 }
387
388 let mut headers = HeaderMap::new();
390 let mut range_start = 0;
391
392 if path.exists() {
393 if let Ok(metadata) = fs::metadata(&path).await {
394 range_start = metadata.len();
395 headers.insert(
396 header::RANGE,
397 format!("bytes={range_start}-").parse().unwrap(),
398 );
399 }
400 }
401
402 let response = self.client.get(url).headers(headers).send().await?;
404
405 if !response.status().is_success() {
406 return Err(PexelsError::DownloadError(format!(
407 "Failed to download file: {}",
408 response.status()
409 )));
410 }
411
412 let _total_size = response.content_length().unwrap_or(0) + range_start;
414
415 let mut file = if range_start > 0 {
416 fs::OpenOptions::new().append(true).open(&path).await?
417 } else {
418 fs::File::create(&path).await?
419 };
420
421 let mut stream = response.bytes_stream();
422
423 while let Some(chunk) = stream.next().await {
424 let chunk = chunk?;
425 file.write_all(&chunk).await?;
426 }
427
428 Ok(path)
429 }
430
431 fn get_photo_url(&self, photo: &Photo, quality: ImageQuality) -> String {
433 match quality {
434 ImageQuality::Original => photo.src.original.clone(),
435 ImageQuality::Large2x => photo.src.large2x.clone(),
436 ImageQuality::Large => photo.src.large.clone(),
437 ImageQuality::Medium => photo.src.medium.clone(),
438 ImageQuality::Small => photo.src.small.clone(),
439 ImageQuality::Portrait => photo.src.portrait.clone(),
440 ImageQuality::Landscape => photo.src.landscape.clone(),
441 ImageQuality::Tiny => photo.src.tiny.clone(),
442 }
443 }
444
445 fn get_video_url(&self, video: &Video, quality: VideoQuality) -> String {
447 let video_file = video
448 .video_files
449 .iter()
450 .find(|file| match quality {
451 VideoQuality::HD => file.quality == "hd" || file.quality == "HD",
452 VideoQuality::SD => file.quality == "sd",
453 VideoQuality::Tiny => {
454 file.file_type == "video/mp4"
455 && (file.width.unwrap_or(0) <= 640 || file.height.unwrap_or(0) <= 360)
456 }
457 })
458 .unwrap_or_else(|| {
459 video.video_files.first().unwrap_or_else(|| {
461 panic!("No video files available for video ID: {}", video.id)
462 })
463 });
464
465 video_file.link.clone()
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::PhotoSources;
    // Shadows the built-in `#[test]` so the async tests below run on a tokio runtime.
    use tokio::test;

    /// Builds a photo whose source URLs are all distinct, so every `ImageQuality` can be
    /// told apart.
    fn mock_photo() -> Photo {
        Photo {
            id: 1,
            width: 800,
            height: 600,
            url: "https://www.pexels.com/photo/1".to_string(),
            photographer: "Test Photographer".to_string(),
            photographer_url: Some("https://www.pexels.com/photographer".to_string()),
            photographer_id: Some(1),
            avg_color: Some("#FFFFFF".to_string()),
            src: PhotoSources {
                original: "https://images.pexels.com/photos/1/original.jpg".to_string(),
                large2x: "https://images.pexels.com/photos/1/large2x.jpg".to_string(),
                large: "https://images.pexels.com/photos/1/large.jpg".to_string(),
                medium: "https://images.pexels.com/photos/1/medium.jpg".to_string(),
                small: "https://images.pexels.com/photos/1/small.jpg".to_string(),
                portrait: "https://images.pexels.com/photos/1/portrait.jpg".to_string(),
                landscape: "https://images.pexels.com/photos/1/landscape.jpg".to_string(),
                tiny: "https://images.pexels.com/photos/1/tiny.jpg".to_string(),
            },
            alt: Some("Test Photo".to_string()),
        }
    }

    /// Every quality must map to its own `src` field, not only the first two.
    #[test]
    async fn test_get_photo_url() {
        let manager = DownloadManager::new(5);
        let photo = mock_photo();

        let cases = [
            (ImageQuality::Original, "original"),
            (ImageQuality::Large2x, "large2x"),
            (ImageQuality::Large, "large"),
            (ImageQuality::Medium, "medium"),
            (ImageQuality::Small, "small"),
            (ImageQuality::Portrait, "portrait"),
            (ImageQuality::Landscape, "landscape"),
            (ImageQuality::Tiny, "tiny"),
        ];

        for &(quality, name) in &cases {
            assert_eq!(
                manager.get_photo_url(&photo, quality),
                format!("https://images.pexels.com/photos/1/{name}.jpg"),
                "wrong URL for {quality:?}"
            );
        }
    }

    /// An empty batch spawns no tasks and touches neither the network nor the disk.
    #[test]
    async fn test_batch_download_photos_empty_input() {
        let manager = DownloadManager::new(5);
        let paths = manager
            .batch_download_photos(&[], std::env::temp_dir(), ImageQuality::Original, None)
            .await
            .expect("an empty batch performs no I/O and cannot fail");
        assert!(paths.is_empty());
    }
}