use anyhow::{anyhow, Result};
use clap::{Args, Subcommand, ValueEnum};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Semaphore;

use crate::client::RommClient;
use crate::core::download::{download_directory, extract_zip_archive, unique_zip_path};
use crate::core::interrupt::{cancelled_error, is_cancelled_error, InterruptContext};
use crate::core::utils;
use crate::endpoints::roms::GetRoms;
use crate::services::{PlatformService, RomService};
use crate::types::Platform;
15
16const DEFAULT_CONCURRENCY: usize = 4;
18
19#[derive(Args, Debug)]
21pub struct DownloadCommand {
22 pub rom_id: Option<u64>,
24
25 #[command(subcommand)]
26 pub action: Option<DownloadAction>,
27
28 #[arg(short, long, global = true)]
30 pub output: Option<PathBuf>,
31
32 #[arg(long, global = true)]
34 pub platform: Option<String>,
35
36 #[arg(long, global = true)]
38 pub search_term: Option<String>,
39
40 #[arg(long, default_value_t = DEFAULT_CONCURRENCY, global = true)]
42 pub jobs: usize,
43
44 #[arg(long, global = true)]
46 pub extract: bool,
47
48 #[arg(long, value_enum, default_value_t = ExtractLayout::Platform, global = true)]
50 pub extract_layout: ExtractLayout,
51
52 #[arg(long, global = true)]
54 pub delete_zip_after_extract: bool,
55}
56
57#[derive(Subcommand, Debug)]
58pub enum DownloadAction {
59 #[command(visible_alias = "all")]
61 Batch,
62}
63
64#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
65pub enum ExtractLayout {
66 Platform,
68 Flat,
70 Rom,
72}
73
74fn make_progress_style() -> ProgressStyle {
75 ProgressStyle::with_template(
76 "[{elapsed_precise}] {bar:40.cyan/blue} {bytes}/{total_bytes} ({eta}) {msg}",
77 )
78 .unwrap()
79 .progress_chars("#>-")
80}
81
82async fn download_one(
83 client: &RommClient,
84 rom_id: u64,
85 name: &str,
86 save_path: &std::path::Path,
87 pb: ProgressBar,
88) -> Result<()> {
89 pb.set_message(name.to_string());
90
91 client
92 .download_rom(rom_id, save_path, {
93 let pb = pb.clone();
94 move |received, total| {
95 if pb.length() != Some(total) {
96 pb.set_length(total);
97 }
98 pb.set_position(received);
99 }
100 })
101 .await?;
102
103 pb.finish_with_message(format!("✓ {name}"));
104 Ok(())
105}
106
107pub async fn handle(
108 cmd: DownloadCommand,
109 client: &RommClient,
110 interrupt: Option<InterruptContext>,
111) -> Result<()> {
112 let interrupt = interrupt.unwrap_or_default();
113 let output_dir = cmd.output.unwrap_or_else(download_directory);
114
115 tokio::fs::create_dir_all(&output_dir)
117 .await
118 .map_err(|e| anyhow!("create download dir {:?}: {e}", output_dir))?;
119
120 let is_batch = matches!(cmd.action, Some(DownloadAction::Batch));
122
123 if is_batch {
124 if cmd.platform.is_none() && cmd.search_term.is_none() {
126 return Err(anyhow!(
127 "Batch download requires at least --platform or --search-term to scope the download"
128 ));
129 }
130 let resolved_platform_id = resolve_platform_id(client, cmd.platform.as_deref()).await?;
131
132 let ep = GetRoms {
133 search_term: cmd.search_term.clone(),
134 platform_id: resolved_platform_id,
135 collection_id: None,
136 smart_collection_id: None,
137 virtual_collection_id: None,
138 limit: Some(9999),
139 offset: None,
140 ..Default::default()
141 };
142
143 let service = RomService::new(client);
144 let results = service.search_roms(&ep).await?;
145
146 if results.items.is_empty() {
147 println!("No ROMs found matching the given filters.");
148 return Ok(());
149 }
150
151 println!(
152 "Found {} ROM(s). Starting download with {} concurrent connections...",
153 results.items.len(),
154 cmd.jobs
155 );
156
157 let mp = MultiProgress::new();
158 let semaphore = Arc::new(Semaphore::new(cmd.jobs));
159 let mut handles = Vec::new();
160
161 'enqueue: for rom in results.items {
162 if interrupt.is_cancelled() {
163 break 'enqueue;
164 }
165 let permit = semaphore.clone().acquire_owned().await.unwrap();
166 let client = client.clone();
167 let dir = output_dir.clone();
168 let interrupt = interrupt.clone();
169 let pb = mp.add(ProgressBar::new(0));
170 pb.set_style(make_progress_style());
171
172 let name = rom.name.clone();
173 let rom_id = rom.id;
174 let platform_slug = rom
175 .platform_fs_slug
176 .clone()
177 .or_else(|| rom.platform_slug.clone())
178 .unwrap_or_else(|| format!("platform-{}", rom.platform_id));
179 let base = utils::sanitize_filename(&rom.fs_name);
180 let stem = base
181 .rsplit_once('.')
182 .map(|(s, _)| s.to_string())
183 .unwrap_or(base.clone());
184 let save_path = unique_zip_path(&dir, &stem);
185 let extract = cmd.extract;
186 let extract_layout = cmd.extract_layout;
187 let delete_zip_after_extract = cmd.delete_zip_after_extract;
188
189 handles.push(tokio::spawn(async move {
190 let mut progress = {
191 let pb = pb.clone();
192 move |received, total| {
193 if pb.length() != Some(total) {
194 pb.set_length(total);
195 }
196 pb.set_position(received);
197 }
198 };
199 let mut result = client
200 .download_rom_with_cancel(
201 rom_id,
202 &save_path,
203 |_, _| interrupt.is_cancelled(),
204 &mut progress,
205 )
206 .await
207 .map(|_| {
208 pb.finish_with_message(format!("✓ {name}"));
209 });
210
211 if result.is_ok() && extract {
212 let extract_dir =
213 extraction_target_dir(&dir, &platform_slug, &stem, extract_layout);
214 if let Err(err) = tokio::fs::create_dir_all(&extract_dir).await {
215 result = Err(anyhow!(
216 "failed to create extraction directory {:?}: {}",
217 extract_dir,
218 err
219 ));
220 } else if let Err(err) = extract_zip_archive(&save_path, &extract_dir) {
221 result = Err(anyhow!(
222 "failed to extract {:?} to {:?}: {}",
223 save_path,
224 extract_dir,
225 err
226 ));
227 } else if delete_zip_after_extract {
228 tokio::fs::remove_file(&save_path).await.map_err(|err| {
229 anyhow!(
230 "failed to delete zip {:?} after extraction: {}",
231 save_path,
232 err
233 )
234 })?;
235 }
236 }
237
238 drop(permit);
239 if let Err(e) = &result {
240 if !is_cancelled_error(e) {
241 eprintln!("error downloading {name} (id={rom_id}): {e}");
242 }
243 }
244 result
245 }));
246 }
247
248 let mut successes = 0u32;
249 let mut failures = 0u32;
250 let mut cancelled = 0u32;
251 for handle in handles {
252 let task_result = tokio::select! {
253 res = handle => res,
254 _ = interrupt.cancelled() => {
255 cancelled += 1;
256 continue;
257 }
258 };
259 match task_result {
260 Ok(Ok(())) => successes += 1,
261 Ok(Err(e)) if is_cancelled_error(&e) => cancelled += 1,
262 _ => failures += 1,
263 }
264 }
265
266 if interrupt.is_cancelled() {
267 println!("\nInterrupted by user.");
268 }
269 println!(
270 "\nBatch complete: {successes} succeeded, {failures} failed, {cancelled} cancelled."
271 );
272 } else {
273 let rom_id = cmd.rom_id.ok_or_else(|| {
275 anyhow!(
276 "ROM ID is required (e.g. 'download 123' or 'download batch --search-term ...')"
277 )
278 })?;
279
280 let save_path = output_dir.join(format!("rom_{rom_id}.zip"));
281
282 let mp = MultiProgress::new();
283 let pb = mp.add(ProgressBar::new(0));
284 pb.set_style(make_progress_style());
285
286 if interrupt.is_cancelled() {
287 return Err(cancelled_error());
288 }
289 download_one(client, rom_id, &format!("ROM {rom_id}"), &save_path, pb).await?;
290
291 println!("Saved to {:?}", save_path);
292 }
293
294 Ok(())
295}
296
297async fn resolve_platform_id(
298 client: &RommClient,
299 platform_query: Option<&str>,
300) -> Result<Option<u64>> {
301 let Some(query) = platform_query.map(str::trim).filter(|q| !q.is_empty()) else {
302 return Ok(None);
303 };
304 let service = PlatformService::new(client);
305 let platforms = service.list_platforms().await?;
306 resolve_platform_query(query, &platforms).map(Some)
307}
308
309fn resolve_platform_query(query: &str, platforms: &[Platform]) -> Result<u64> {
310 let normalized = query.trim().to_ascii_lowercase();
311
312 if let Some(platform) = platforms.iter().find(|p| {
313 p.slug.eq_ignore_ascii_case(&normalized) || p.fs_slug.eq_ignore_ascii_case(&normalized)
314 }) {
315 return Ok(platform.id);
316 }
317
318 let exact_name_matches: Vec<&Platform> = platforms
319 .iter()
320 .filter(|p| {
321 p.name.eq_ignore_ascii_case(&normalized)
322 || p.display_name
323 .as_deref()
324 .is_some_and(|name| name.eq_ignore_ascii_case(&normalized))
325 || p.custom_name
326 .as_deref()
327 .is_some_and(|name| name.eq_ignore_ascii_case(&normalized))
328 })
329 .collect();
330
331 match exact_name_matches.len() {
332 1 => Ok(exact_name_matches[0].id),
333 0 => Err(anyhow!(
334 "No platform found for '{}'. Use 'romm-cli platforms list' to inspect available values.",
335 query
336 )),
337 _ => {
338 let names = exact_name_matches
339 .iter()
340 .map(|p| format!("{} ({})", p.name, p.id))
341 .collect::<Vec<_>>()
342 .join(", ");
343 Err(anyhow!(
344 "Platform '{}' is ambiguous. Matches: {}. Please use a more specific --platform value.",
345 query,
346 names
347 ))
348 }
349 }
350}
351
352fn extraction_target_dir(
353 output_dir: &std::path::Path,
354 platform_slug: &str,
355 rom_stem: &str,
356 layout: ExtractLayout,
357) -> PathBuf {
358 let platform = utils::sanitize_filename(platform_slug);
359 let rom = utils::sanitize_filename(rom_stem);
360 match layout {
361 ExtractLayout::Platform => output_dir.join(platform),
362 ExtractLayout::Flat => output_dir.to_path_buf(),
363 ExtractLayout::Rom => output_dir.join(platform).join(rom),
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    use crate::commands::{Cli, Commands};
    use crate::types::Firmware;

    #[test]
    fn parse_download_batch_with_extract_flags() {
        let cli = Cli::parse_from([
            "romm-cli",
            "download",
            "batch",
            "--search-term",
            "Super Mario",
            "--extract",
            "--extract-layout",
            "platform",
            "--delete-zip-after-extract",
            "--jobs",
            "8",
        ]);

        let Commands::Download(cmd) = cli.command else {
            panic!("expected download command");
        };

        assert!(matches!(cmd.action, Some(DownloadAction::Batch)));
        assert_eq!(cmd.search_term.as_deref(), Some("Super Mario"));
        assert!(cmd.extract);
        assert_eq!(cmd.extract_layout, ExtractLayout::Platform);
        assert!(cmd.delete_zip_after_extract);
        assert_eq!(cmd.jobs, 8);
    }

    #[test]
    fn parse_download_batch_extract_defaults() {
        let cli = Cli::parse_from(["romm-cli", "download", "batch", "--search-term", "Metroid"]);

        let Commands::Download(cmd) = cli.command else {
            panic!("expected download command");
        };

        assert!(matches!(cmd.action, Some(DownloadAction::Batch)));
        assert!(!cmd.extract);
        assert_eq!(cmd.extract_layout, ExtractLayout::Platform);
        assert!(!cmd.delete_zip_after_extract);
    }

    #[test]
    fn parse_download_single_rom_id() {
        let cli = Cli::parse_from(["romm-cli", "download", "42"]);

        let Commands::Download(cmd) = cli.command else {
            panic!("expected download command");
        };

        assert_eq!(cmd.rom_id, Some(42));
        assert!(cmd.action.is_none());
        assert_eq!(cmd.jobs, DEFAULT_CONCURRENCY);
    }

    #[test]
    fn parse_download_batch_with_platform_alias() {
        let cli = Cli::parse_from([
            "romm-cli",
            "download",
            "batch",
            "--platform",
            "3ds",
            "--search-term",
            "Mario",
        ]);

        let Commands::Download(cmd) = cli.command else {
            panic!("expected download command");
        };

        assert_eq!(cmd.platform.as_deref(), Some("3ds"));
    }

    #[test]
    fn parse_download_batch_rejects_platform_id_flag() {
        let parsed = Cli::try_parse_from([
            "romm-cli",
            "download",
            "batch",
            "--platform",
            "3ds",
            "--platform-id",
            "3",
        ]);
        assert!(parsed.is_err(), "expected clap parse failure");
    }

    #[test]
    fn extraction_target_dir_platform_layout() {
        let dir = PathBuf::from("/tmp/out");
        let target = extraction_target_dir(
            &dir,
            "Nintendo Switch",
            "Mario (USA)",
            ExtractLayout::Platform,
        );
        assert_eq!(target, PathBuf::from("/tmp/out/Nintendo Switch"));
    }

    #[test]
    fn extraction_target_dir_rom_layout() {
        let dir = PathBuf::from("/tmp/out");
        let target = extraction_target_dir(&dir, "SNES", "Super Mario World", ExtractLayout::Rom);
        assert_eq!(target, PathBuf::from("/tmp/out/SNES/Super Mario World"));
    }

    #[test]
    fn extraction_target_dir_flat_layout() {
        let dir = PathBuf::from("/tmp/out");
        let target = extraction_target_dir(&dir, "SNES", "Super Mario World", ExtractLayout::Flat);
        assert_eq!(target, dir);
    }

    #[test]
    fn resolve_platform_query_matches_slug_first() {
        let platforms = vec![platform_fixture(
            3,
            "3ds",
            "3ds",
            "Nintendo 3DS",
            None,
            None,
        )];
        let id = resolve_platform_query("3ds", &platforms).expect("slug should resolve");
        assert_eq!(id, 3);
    }

    #[test]
    fn resolve_platform_query_matches_fs_slug_and_trims() {
        let platforms = vec![platform_fixture(
            5,
            "nintendo-64",
            "n64",
            "Nintendo 64",
            None,
            None,
        )];
        let id = resolve_platform_query("  N64 ", &platforms).expect("fs_slug should resolve");
        assert_eq!(id, 5);
    }

    #[test]
    fn resolve_platform_query_prefers_slug_over_name() {
        let platforms = vec![
            platform_fixture(1, "arcade", "arcade", "Coin-op", None, None),
            platform_fixture(2, "other", "other", "Arcade", None, None),
        ];
        let id = resolve_platform_query("Arcade", &platforms).expect("slug should win");
        assert_eq!(id, 1);
    }

    #[test]
    fn resolve_platform_query_matches_name_case_insensitive() {
        let platforms = vec![platform_fixture(
            4,
            "nintendo-3ds",
            "3ds",
            "Nintendo 3DS",
            None,
            None,
        )];
        let id = resolve_platform_query("nintendo 3ds", &platforms).expect("name should resolve");
        assert_eq!(id, 4);
    }

    #[test]
    fn resolve_platform_query_matches_display_and_custom_names() {
        let platforms = vec![
            platform_fixture(10, "sfc", "sfc", "Super Famicom", Some("Super Nintendo"), None),
            platform_fixture(11, "md", "md", "Mega Drive", None, Some("My Genesis")),
        ];
        let display = resolve_platform_query("super nintendo", &platforms)
            .expect("display name should resolve");
        assert_eq!(display, 10);
        let custom =
            resolve_platform_query("MY GENESIS", &platforms).expect("custom name should resolve");
        assert_eq!(custom, 11);
    }

    #[test]
    fn resolve_platform_query_errors_when_ambiguous() {
        let platforms = vec![
            platform_fixture(7, "foo-a", "foo-a", "Arcade", None, None),
            platform_fixture(8, "foo-b", "foo-b", "Arcade", None, None),
        ];
        let err = resolve_platform_query("Arcade", &platforms).expect_err("should be ambiguous");
        assert!(
            err.to_string().contains("ambiguous"),
            "unexpected error: {err:#}"
        );
    }

    #[test]
    fn resolve_platform_query_errors_when_missing() {
        let platforms = vec![platform_fixture(
            2,
            "gba",
            "gba",
            "Game Boy Advance",
            None,
            None,
        )];
        let err = resolve_platform_query("3ds", &platforms).expect_err("should not match");
        assert!(
            err.to_string().contains("No platform found"),
            "unexpected error: {err:#}"
        );
    }

    /// Builds a minimal `Platform` with the matching-relevant fields set and every
    /// other field empty/defaulted.
    fn platform_fixture(
        id: u64,
        slug: &str,
        fs_slug: &str,
        name: &str,
        display_name: Option<&str>,
        custom_name: Option<&str>,
    ) -> Platform {
        Platform {
            id,
            slug: slug.to_string(),
            fs_slug: fs_slug.to_string(),
            rom_count: 0,
            name: name.to_string(),
            igdb_slug: None,
            moby_slug: None,
            hltb_slug: None,
            custom_name: custom_name.map(ToString::to_string),
            igdb_id: None,
            sgdb_id: None,
            moby_id: None,
            launchbox_id: None,
            ss_id: None,
            ra_id: None,
            hasheous_id: None,
            tgdb_id: None,
            flashpoint_id: None,
            category: None,
            generation: None,
            family_name: None,
            family_slug: None,
            url: None,
            url_logo: None,
            firmware: Vec::<Firmware>::new(),
            aspect_ratio: None,
            created_at: "2026-01-01T00:00:00Z".to_string(),
            updated_at: "2026-01-01T00:00:00Z".to_string(),
            fs_size_bytes: 0,
            is_unidentified: false,
            is_identified: true,
            missing_from_fs: false,
            display_name: display_name.map(ToString::to_string),
        }
    }
}