1use std::collections::HashMap;
27use std::fs;
28use std::io::{self, Read, Write};
29use std::path::{Path, PathBuf};
30
31use serde::Deserialize;
32use tracing::info;
33
34use crate::{BUNDLE_ADJUSTMENT_DATA_DIR, ODOMETRY_DATA_DIR};
35
/// Dataset registry in TOML form, embedded into the binary at compile time
/// and parsed by `DatasetRegistry::load`.
const DATASETS_TOML: &str = include_str!("../datasets.toml");
38
/// One odometry (pose-graph) dataset; the value type of
/// `DatasetRegistry::odometry`, which is keyed by dataset name.
#[derive(Debug, Clone, Deserialize)]
pub struct OdometryEntry {
    /// URL that `ensure_odometry_dataset` downloads from when the file is
    /// not present locally.
    pub url: String,
    /// File name of the dataset inside its category directory.
    pub filename: String,
    /// Subdirectory under `ODOMETRY_DATA_DIR` holding the file; the tests
    /// expect either `"2d"` or `"3d"`.
    pub category: String,
}
53
/// One bundle-adjustment dataset; the value type of
/// `DatasetRegistry::bundle_adjustment`, which is keyed by dataset name.
#[derive(Debug, Clone, Deserialize)]
pub struct BaEntry {
    /// Base URL; `BaEntry::problem_url` appends
    /// `/problem-{cameras}-{points}-pre.txt.bz2` to it.
    pub url_prefix: String,
    /// Available problems as two-element arrays.
    // NOTE(review): presumably `[cameras, points]` in ascending size order,
    // since `BaEntry::largest` returns the last element — confirm against
    // datasets.toml.
    pub problems: Vec<[u32; 2]>,
}
62
63impl BaEntry {
64 pub fn largest(&self) -> Option<[u32; 2]> {
66 self.problems.last().copied()
67 }
68
69 pub fn problem_url(&self, cameras: u32, points: u32) -> String {
71 format!(
72 "{}/problem-{}-{}-pre.txt.bz2",
73 self.url_prefix, cameras, points
74 )
75 }
76}
77
/// All datasets known to this crate, deserialized from the embedded
/// `datasets.toml` by `DatasetRegistry::load`.
#[derive(Debug, Deserialize)]
pub struct DatasetRegistry {
    /// Odometry datasets keyed by name.
    pub odometry: HashMap<String, OdometryEntry>,
    /// Bundle-adjustment datasets keyed by name.
    pub bundle_adjustment: HashMap<String, BaEntry>,
}
86
87impl DatasetRegistry {
88 pub fn load() -> io::Result<Self> {
94 toml::from_str(DATASETS_TOML).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
95 }
96
97 pub fn odometry_path(&self, name: &str) -> Option<std::path::PathBuf> {
114 self.odometry.get(name).map(|e| {
115 std::path::PathBuf::from(crate::ODOMETRY_DATA_DIR)
116 .join(&e.category)
117 .join(&e.filename)
118 })
119 }
120
121 pub fn odometry_by_category(&self, category: &str) -> Vec<(&str, &OdometryEntry)> {
124 let mut entries: Vec<_> = self
125 .odometry
126 .iter()
127 .filter(|(_, e)| e.category == category)
128 .map(|(name, entry)| (name.as_str(), entry))
129 .collect();
130 entries.sort_by_key(|(name, _)| *name);
131 entries
132 }
133
134 pub fn ba_path(&self, name: &str, cameras: u32, points: u32) -> Option<std::path::PathBuf> {
141 self.bundle_adjustment.get(name).map(|_| {
142 std::path::PathBuf::from(crate::BUNDLE_ADJUSTMENT_DATA_DIR)
143 .join(name)
144 .join(format!("problem-{cameras}-{points}-pre.txt"))
145 })
146 }
147
148 pub fn ba_sorted(&self) -> Vec<(&str, &BaEntry)> {
150 let mut entries: Vec<_> = self
151 .bundle_adjustment
152 .iter()
153 .map(|(name, entry)| (name.as_str(), entry))
154 .collect();
155 entries.sort_by_key(|(name, _)| *name);
156 entries
157 }
158}
159
160pub fn ensure_odometry_dataset(name: &str) -> io::Result<PathBuf> {
173 let registry = DatasetRegistry::load()?;
174
175 let entry = registry.odometry.get(name).ok_or_else(|| {
176 io::Error::other(format!(
177 "Dataset '{name}' not found in registry. \
178 Available: {}",
179 {
180 let mut names: Vec<_> = registry.odometry.keys().map(String::as_str).collect();
181 names.sort();
182 names.join(", ")
183 }
184 ))
185 })?;
186
187 let path = PathBuf::from(ODOMETRY_DATA_DIR)
188 .join(&entry.category)
189 .join(&entry.filename);
190 if path.exists() {
191 return Ok(path);
192 }
193
194 info!("Downloading {name} ({}) ...", entry.filename);
195 download_file(&entry.url, &path)
196 .map_err(|e| io::Error::other(format!("Failed to download {name}: {e}")))?;
197 info!("Saved to {}", path.display());
198 Ok(path)
199}
200
201pub fn ensure_ba_dataset(name: &str, cameras: u32, points: u32) -> io::Result<PathBuf> {
210 let txt_path = PathBuf::from(BUNDLE_ADJUSTMENT_DATA_DIR)
211 .join(name)
212 .join(format!("problem-{cameras}-{points}-pre.txt"));
213
214 if txt_path.exists() {
215 return Ok(txt_path);
216 }
217
218 let registry = DatasetRegistry::load()?;
219 let entry = registry.bundle_adjustment.get(name).ok_or_else(|| {
220 io::Error::other(format!(
221 "BA dataset '{name}' not found in registry. \
222 Available: {}",
223 {
224 let mut names: Vec<_> = registry
225 .bundle_adjustment
226 .keys()
227 .map(String::as_str)
228 .collect();
229 names.sort();
230 names.join(", ")
231 }
232 ))
233 })?;
234
235 let url = entry.problem_url(cameras, points);
236 let bz2_path = txt_path.with_extension("txt.bz2");
237
238 info!("Downloading {name}/problem-{cameras}-{points} ...");
239 download_file(&url, &bz2_path)
240 .map_err(|e| io::Error::other(format!("Failed to download {name}: {e}")))?;
241
242 decompress_bzip2(&bz2_path, &txt_path)
243 .map_err(|e| io::Error::other(format!("Failed to decompress: {e}")))?;
244
245 let _ = fs::remove_file(&bz2_path); info!("Saved to {}", txt_path.display());
247 Ok(txt_path)
248}
249
250pub fn download_file(url: &str, dest: &Path) -> io::Result<()> {
259 if let Some(parent) = dest.parent() {
260 fs::create_dir_all(parent)?;
261 }
262
263 let response = ureq::get(url)
264 .call()
265 .map_err(|e| io::Error::other(format!("HTTP request failed for {url}: {e}")))?;
266
267 let mut buf = Vec::new();
268 response
269 .into_reader()
270 .read_to_end(&mut buf)
271 .map_err(|e| io::Error::other(format!("Failed to read response body: {e}")))?;
272
273 let mut file = fs::File::create(dest)?;
274 file.write_all(&buf)?;
275 Ok(())
276}
277
278pub fn decompress_bzip2(src: &Path, dest: &Path) -> io::Result<()> {
284 use bzip2::read::BzDecoder;
285
286 if let Some(parent) = dest.parent() {
287 fs::create_dir_all(parent)?;
288 }
289
290 let compressed = fs::File::open(src)?;
291 let mut decoder = BzDecoder::new(compressed);
292 let mut decompressed = Vec::new();
293 decoder.read_to_end(&mut decompressed)?;
294
295 let mut out = fs::File::create(dest)?;
296 out.write_all(&decompressed)?;
297 Ok(())
298}
299
#[cfg(test)]
mod tests {
    // Registry tests read the registry embedded via `DATASETS_TOML`; the
    // `decompress_bzip2` tests work inside a temporary directory. The
    // `ensure_*` tests use unknown names, so they fail before any download.
    use super::*;

    #[test]
    fn registry_parses_without_panic() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        assert!(
            !registry.odometry.is_empty(),
            "odometry section must not be empty"
        );
        assert!(
            !registry.bundle_adjustment.is_empty(),
            "bundle_adjustment section must not be empty"
        );
        Ok(())
    }

    #[test]
    fn registry_contains_expected_odometry_datasets() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for name in &["sphere2500", "parking-garage", "intel", "M3500"] {
            assert!(
                registry.odometry.contains_key(*name),
                "missing expected dataset: {name}"
            );
        }
        Ok(())
    }

    #[test]
    fn registry_contains_expected_ba_datasets() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for name in &["ladybug", "trafalgar", "dubrovnik", "venice", "final"] {
            assert!(
                registry.bundle_adjustment.contains_key(*name),
                "missing expected BA dataset: {name}"
            );
        }
        Ok(())
    }

    #[test]
    fn odometry_entries_have_valid_categories() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for (name, entry) in &registry.odometry {
            assert!(
                entry.category == "2d" || entry.category == "3d",
                "dataset '{name}' has invalid category: '{}'",
                entry.category
            );
        }
        Ok(())
    }

    #[test]
    fn ba_entries_have_at_least_one_problem() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for (name, entry) in &registry.bundle_adjustment {
            assert!(
                !entry.problems.is_empty(),
                "BA dataset '{name}' has no problems listed"
            );
        }
        Ok(())
    }

    #[test]
    fn ba_problem_url_format_is_correct() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let ladybug = registry
            .bundle_adjustment
            .get("ladybug")
            .ok_or_else(|| io::Error::other("ladybug dataset not found"))?;
        let url = ladybug.problem_url(49, 7776);
        assert_eq!(
            url,
            "https://grail.cs.washington.edu/projects/bal/data/ladybug/problem-49-7776-pre.txt.bz2"
        );
        Ok(())
    }

    #[test]
    fn odometry_by_category_returns_only_3d() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.odometry_by_category("3d");
        for (_, entry) in &entries {
            assert_eq!(entry.category, "3d");
        }
        assert!(!entries.is_empty());
        Ok(())
    }

    #[test]
    fn sphere2500_uses_github_url() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entry = registry
            .odometry
            .get("sphere2500")
            .ok_or_else(|| io::Error::other("sphere2500 must exist"))?;
        assert!(
            entry.url.contains("github"),
            "sphere2500 should use the GitHub URL, got: {}",
            entry.url
        );
        Ok(())
    }

    #[test]
    fn registry_contains_new_vertigo_datasets() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for name in &["manhattanOlson3500", "ring", "ring_city", "city10000"] {
            assert!(
                registry.odometry.contains_key(*name),
                "missing expected dataset: {name}"
            );
        }
        Ok(())
    }

    #[test]
    fn odometry_path_includes_category_subdir() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let path_3d = registry
            .odometry_path("sphere2500")
            .ok_or_else(|| io::Error::other("sphere2500 path not found"))?;
        let path_2d = registry
            .odometry_path("intel")
            .ok_or_else(|| io::Error::other("intel path not found"))?;
        assert!(
            path_3d.components().any(|c| c.as_os_str() == "3d"),
            "3D path should contain '3d' component, got: {}",
            path_3d.display()
        );
        assert!(
            path_2d.components().any(|c| c.as_os_str() == "2d"),
            "2D path should contain '2d' component, got: {}",
            path_2d.display()
        );
        Ok(())
    }

    #[test]
    fn sphere_bignoise_removed_from_registry() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        assert!(
            !registry.odometry.contains_key("sphere_bignoise"),
            "sphere_bignoise should have been removed (merged into sphere2500)"
        );
        Ok(())
    }

    #[test]
    fn ba_path_returns_correct_structure() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let path = registry
            .ba_path("ladybug", 49, 7776)
            .ok_or_else(|| io::Error::other("ladybug ba_path not found"))?;
        assert!(
            path.components()
                .any(|c| c.as_os_str() == "bundle_adjustment"),
            "path should contain 'bundle_adjustment', got: {}",
            path.display()
        );
        assert!(
            path.components().any(|c| c.as_os_str() == "ladybug"),
            "path should contain 'ladybug', got: {}",
            path.display()
        );
        assert!(
            path.file_name()
                .is_some_and(|f| f == "problem-49-7776-pre.txt"),
            "filename should be 'problem-49-7776-pre.txt', got: {}",
            path.display()
        );
        Ok(())
    }

    #[test]
    fn ba_path_returns_none_for_unknown() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        assert!(
            registry.ba_path("nonexistent_ba_xyz", 1, 1).is_none(),
            "unknown BA name should return None"
        );
        Ok(())
    }

    #[test]
    fn ba_sorted_returns_alphabetical_order() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.ba_sorted();
        assert!(!entries.is_empty(), "ba_sorted should not be empty");
        for window in entries.windows(2) {
            assert!(
                window[0].0 <= window[1].0,
                "ba_sorted is not sorted: '{}' > '{}'",
                window[0].0,
                window[1].0
            );
        }
        Ok(())
    }

    #[test]
    fn ba_entry_largest_returns_last_problem() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let ladybug = registry
            .bundle_adjustment
            .get("ladybug")
            .ok_or_else(|| io::Error::other("ladybug not found"))?;
        let largest = ladybug.largest();
        assert!(largest.is_some(), "ladybug should have a largest problem");
        assert_eq!(
            largest,
            ladybug.problems.last().copied(),
            "largest() should equal the last problem"
        );
        Ok(())
    }

    #[test]
    fn ba_entry_largest_empty_returns_none() {
        let entry = BaEntry {
            url_prefix: "https://example.com".to_string(),
            problems: vec![],
        };
        assert!(
            entry.largest().is_none(),
            "empty problems should return None"
        );
    }

    #[test]
    fn odometry_by_category_returns_only_2d() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.odometry_by_category("2d");
        assert!(!entries.is_empty(), "should have at least one 2d dataset");
        for (_, entry) in &entries {
            assert_eq!(entry.category, "2d");
        }
        Ok(())
    }

    #[test]
    fn odometry_by_category_is_sorted() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.odometry_by_category("3d");
        for window in entries.windows(2) {
            assert!(
                window[0].0 <= window[1].0,
                "odometry_by_category is not sorted: '{}' > '{}'",
                window[0].0,
                window[1].0
            );
        }
        Ok(())
    }

    #[test]
    fn odometry_entries_have_nonempty_url_and_filename() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for (name, entry) in &registry.odometry {
            assert!(!entry.url.is_empty(), "dataset '{name}' has empty url");
            assert!(
                !entry.filename.is_empty(),
                "dataset '{name}' has empty filename"
            );
        }
        Ok(())
    }

    #[test]
    fn decompress_bzip2_roundtrip() -> io::Result<()> {
        use bzip2::Compression;
        use bzip2::write::BzEncoder;
        use std::io::Write as _;

        let original = b"hello bzip2 roundtrip test data";

        let tmp_dir = tempfile::tempdir()?;
        let bz2_path = tmp_dir.path().join("test.txt.bz2");
        let txt_path = tmp_dir.path().join("test.txt");

        {
            let file = fs::File::create(&bz2_path)?;
            let mut encoder = BzEncoder::new(file, Compression::fast());
            encoder.write_all(original)?;
            encoder.finish()?;
        }

        decompress_bzip2(&bz2_path, &txt_path)?;

        let decompressed = fs::read(&txt_path)?;
        assert_eq!(
            decompressed, original,
            "decompressed content must match original"
        );
        Ok(())
    }

    #[test]
    fn decompress_bzip2_leaves_no_partial_file() -> io::Result<()> {
        use bzip2::Compression;
        use bzip2::write::BzEncoder;
        use std::io::Write as _;

        let tmp_dir = tempfile::tempdir()?;
        let bz2_path = tmp_dir.path().join("clean.txt.bz2");
        let txt_path = tmp_dir.path().join("clean.txt");

        {
            let file = fs::File::create(&bz2_path)?;
            let mut encoder = BzEncoder::new(file, Compression::fast());
            encoder.write_all(b"no leftovers")?;
            encoder.finish()?;
        }

        decompress_bzip2(&bz2_path, &txt_path)?;

        assert!(txt_path.exists(), "destination file must exist");
        assert!(bz2_path.exists(), "source archive must be left untouched");
        let leftovers: Vec<_> = fs::read_dir(tmp_dir.path())?
            .filter_map(Result::ok)
            .map(|e| e.file_name())
            .filter(|n| n != "clean.txt" && n != "clean.txt.bz2")
            .collect();
        assert!(
            leftovers.is_empty(),
            "unexpected files left behind: {leftovers:?}"
        );
        Ok(())
    }

    #[test]
    fn decompress_bzip2_rejects_corrupt_input_without_creating_dest() -> io::Result<()> {
        let tmp_dir = tempfile::tempdir()?;
        let bz2_path = tmp_dir.path().join("corrupt.txt.bz2");
        let txt_path = tmp_dir.path().join("corrupt.txt");
        fs::write(&bz2_path, b"this is not bzip2 data")?;

        assert!(
            decompress_bzip2(&bz2_path, &txt_path).is_err(),
            "corrupt input must be rejected"
        );
        assert!(
            !txt_path.exists(),
            "no destination file may be left after a failed decompression"
        );
        Ok(())
    }

    #[test]
    fn ensure_odometry_dataset_unknown_name_errors() -> io::Result<()> {
        let err = ensure_odometry_dataset("nonexistent_dataset_xyz_abc")
            .err()
            .ok_or_else(|| io::Error::other("expected Err but got Ok"))?;
        assert!(
            err.to_string().contains("not found in registry"),
            "error message should mention registry, got: {err}"
        );
        Ok(())
    }

    #[test]
    fn ensure_ba_dataset_unknown_name_errors() -> io::Result<()> {
        let err = ensure_ba_dataset("nonexistent_ba_xyz_abc", 1, 1)
            .err()
            .ok_or_else(|| io::Error::other("expected Err but got Ok"))?;
        assert!(
            err.to_string().contains("not found in registry"),
            "error message should mention registry, got: {err}"
        );
        Ok(())
    }
}