Skip to main content

apex_io/
utils.rs

1//! Dataset utilities — registry, download helpers, and on-demand ensure functions.
2//!
3//! All dataset metadata (names, URLs, categories) lives in `datasets.toml`, which is
4//! embedded at compile time. No URLs are hardcoded in Rust source.
5//!
6//! # Usage in tests
7//!
8//! ```no_run
9//! use apex_io::ensure_odometry_dataset;
10//!
11//! let path = ensure_odometry_dataset("sphere2500").expect("failed to fetch dataset");
12//! // path == "data/odometry/3d/sphere2500.g2o"
13//! ```
14//!
15//! # Usage in the download binary
16//!
17//! ```no_run
18//! use apex_io::utils::DatasetRegistry;
19//!
20//! let registry = DatasetRegistry::load().unwrap();
21//! for (name, entry) in registry.odometry_by_category("3d") {
22//!     println!("{name}: {}", entry.url);
23//! }
24//! ```
25
26use std::collections::HashMap;
27use std::fs;
28use std::io::{self, Read, Write};
29use std::path::{Path, PathBuf};
30
31use serde::Deserialize;
32use tracing::info;
33
34use crate::{BUNDLE_ADJUSTMENT_DATA_DIR, ODOMETRY_DATA_DIR};
35
// Compile-time embed of the dataset registry: `include_str!` reads
// `../datasets.toml` (resolved relative to this source file) while compiling,
// so the registry text is part of the binary and needs no file lookup at runtime.
const DATASETS_TOML: &str = include_str!("../datasets.toml");
38
39// ---------------------------------------------------------------------------
40// Registry types
41// ---------------------------------------------------------------------------
42
43/// Metadata for a single odometry (pose graph) dataset.
44#[derive(Debug, Clone, Deserialize)]
45pub struct OdometryEntry {
46    /// Direct download URL for the `.g2o` file.
47    pub url: String,
48    /// Filename on disk (saved to `data/odometry/<filename>`).
49    pub filename: String,
50    /// Pose graph dimensionality: `"2d"` or `"3d"`.
51    pub category: String,
52}
53
54/// Metadata for a bundle adjustment (BAL) dataset collection.
55#[derive(Debug, Clone, Deserialize)]
56pub struct BaEntry {
57    /// URL prefix; full URL = `{url_prefix}/problem-{cameras}-{points}-pre.txt.bz2`.
58    pub url_prefix: String,
59    /// All available (cameras, points) problem sizes in this collection.
60    pub problems: Vec<[u32; 2]>,
61}
62
63impl BaEntry {
64    /// Returns the largest problem (most cameras) in this collection.
65    pub fn largest(&self) -> Option<[u32; 2]> {
66        self.problems.last().copied()
67    }
68
69    /// Constructs the download URL for a specific problem size.
70    pub fn problem_url(&self, cameras: u32, points: u32) -> String {
71        format!(
72            "{}/problem-{}-{}-pre.txt.bz2",
73            self.url_prefix, cameras, points
74        )
75    }
76}
77
78/// The complete dataset registry, parsed from `datasets.toml`.
79#[derive(Debug, Deserialize)]
80pub struct DatasetRegistry {
81    /// Odometry datasets keyed by short name (e.g. `"sphere2500"`, `"intel"`).
82    pub odometry: HashMap<String, OdometryEntry>,
83    /// Bundle adjustment datasets keyed by collection name (e.g. `"ladybug"`).
84    pub bundle_adjustment: HashMap<String, BaEntry>,
85}
86
87impl DatasetRegistry {
88    /// Load the registry from the compile-time embedded `datasets.toml`.
89    ///
90    /// # Errors
91    /// Returns an error only if `datasets.toml` is malformed TOML — a
92    /// developer error that should never occur with the bundled file.
93    pub fn load() -> io::Result<Self> {
94        toml::from_str(DATASETS_TOML).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
95    }
96
97    /// Returns the on-disk path for an odometry dataset, including its category subdirectory.
98    ///
99    /// Returns `None` if `name` is not in the registry.
100    ///
101    /// # Example
102    /// ```
103    /// use apex_io::DatasetRegistry;
104    /// # fn main() -> std::io::Result<()> {
105    /// let reg = DatasetRegistry::load()?;
106    /// assert_eq!(
107    ///     reg.odometry_path("intel"),
108    ///     Some(std::path::Path::new("data/odometry").join("2d").join("intel.g2o"))
109    /// );
110    /// # Ok(())
111    /// # }
112    /// ```
113    pub fn odometry_path(&self, name: &str) -> Option<std::path::PathBuf> {
114        self.odometry.get(name).map(|e| {
115            std::path::PathBuf::from(crate::ODOMETRY_DATA_DIR)
116                .join(&e.category)
117                .join(&e.filename)
118        })
119    }
120
121    /// Returns all odometry entries with the given category (`"2d"` or `"3d"`),
122    /// sorted alphabetically by name for deterministic output.
123    pub fn odometry_by_category(&self, category: &str) -> Vec<(&str, &OdometryEntry)> {
124        let mut entries: Vec<_> = self
125            .odometry
126            .iter()
127            .filter(|(_, e)| e.category == category)
128            .map(|(name, entry)| (name.as_str(), entry))
129            .collect();
130        entries.sort_by_key(|(name, _)| *name);
131        entries
132    }
133
134    /// Returns the on-disk path for a specific BA problem file.
135    ///
136    /// The path follows the same layout the downloader creates:
137    /// `data/bundle_adjustment/{name}/problem-{cameras}-{points}-pre.txt`
138    ///
139    /// Returns `None` if `name` is not in the registry.
140    pub fn ba_path(&self, name: &str, cameras: u32, points: u32) -> Option<std::path::PathBuf> {
141        self.bundle_adjustment.get(name).map(|_| {
142            std::path::PathBuf::from(crate::BUNDLE_ADJUSTMENT_DATA_DIR)
143                .join(name)
144                .join(format!("problem-{cameras}-{points}-pre.txt"))
145        })
146    }
147
148    /// Returns all bundle adjustment entries sorted alphabetically by name.
149    pub fn ba_sorted(&self) -> Vec<(&str, &BaEntry)> {
150        let mut entries: Vec<_> = self
151            .bundle_adjustment
152            .iter()
153            .map(|(name, entry)| (name.as_str(), entry))
154            .collect();
155        entries.sort_by_key(|(name, _)| *name);
156        entries
157    }
158}
159
160// ---------------------------------------------------------------------------
161// Public ensure API (used by tests and binaries)
162// ---------------------------------------------------------------------------
163
164/// Ensure an odometry `.g2o` dataset is present at `data/odometry/{name}.g2o`.
165///
166/// If the file already exists it is returned immediately (no network access).
167/// Otherwise it is looked up in the dataset registry and downloaded.
168///
169/// # Errors
170/// Returns an error if the dataset name is not in the registry, the download
171/// fails, or the file cannot be written.
172pub fn ensure_odometry_dataset(name: &str) -> io::Result<PathBuf> {
173    let registry = DatasetRegistry::load()?;
174
175    let entry = registry.odometry.get(name).ok_or_else(|| {
176        io::Error::other(format!(
177            "Dataset '{name}' not found in registry. \
178             Available: {}",
179            {
180                let mut names: Vec<_> = registry.odometry.keys().map(String::as_str).collect();
181                names.sort();
182                names.join(", ")
183            }
184        ))
185    })?;
186
187    let path = PathBuf::from(ODOMETRY_DATA_DIR)
188        .join(&entry.category)
189        .join(&entry.filename);
190    if path.exists() {
191        return Ok(path);
192    }
193
194    info!("Downloading {name} ({}) ...", entry.filename);
195    download_file(&entry.url, &path)
196        .map_err(|e| io::Error::other(format!("Failed to download {name}: {e}")))?;
197    info!("Saved to {}", path.display());
198    Ok(path)
199}
200
201/// Ensure a BAL bundle-adjustment file is present at
202/// `data/bundle_adjustment/{name}/problem-{cameras}-{points}-pre.txt`.
203///
204/// If the file already exists it is returned immediately. Otherwise the
205/// `.bz2` archive is downloaded, decompressed, and the `.bz2` is cleaned up.
206///
207/// # Errors
208/// Returns an error if the download, decompression, or disk write fails.
209pub fn ensure_ba_dataset(name: &str, cameras: u32, points: u32) -> io::Result<PathBuf> {
210    let txt_path = PathBuf::from(BUNDLE_ADJUSTMENT_DATA_DIR)
211        .join(name)
212        .join(format!("problem-{cameras}-{points}-pre.txt"));
213
214    if txt_path.exists() {
215        return Ok(txt_path);
216    }
217
218    let registry = DatasetRegistry::load()?;
219    let entry = registry.bundle_adjustment.get(name).ok_or_else(|| {
220        io::Error::other(format!(
221            "BA dataset '{name}' not found in registry. \
222             Available: {}",
223            {
224                let mut names: Vec<_> = registry
225                    .bundle_adjustment
226                    .keys()
227                    .map(String::as_str)
228                    .collect();
229                names.sort();
230                names.join(", ")
231            }
232        ))
233    })?;
234
235    let url = entry.problem_url(cameras, points);
236    let bz2_path = txt_path.with_extension("txt.bz2");
237
238    info!("Downloading {name}/problem-{cameras}-{points} ...");
239    download_file(&url, &bz2_path)
240        .map_err(|e| io::Error::other(format!("Failed to download {name}: {e}")))?;
241
242    decompress_bzip2(&bz2_path, &txt_path)
243        .map_err(|e| io::Error::other(format!("Failed to decompress: {e}")))?;
244
245    let _ = fs::remove_file(&bz2_path); // clean up; ignore errors
246    info!("Saved to {}", txt_path.display());
247    Ok(txt_path)
248}
249
250// ---------------------------------------------------------------------------
251// Low-level download helpers (pub so the download_datasets binary can use them)
252// ---------------------------------------------------------------------------
253
254/// Download a URL to a local file, creating parent directories as needed.
255///
256/// # Errors
257/// Returns an error if the HTTP request fails or the file cannot be written.
258pub fn download_file(url: &str, dest: &Path) -> io::Result<()> {
259    if let Some(parent) = dest.parent() {
260        fs::create_dir_all(parent)?;
261    }
262
263    let response = ureq::get(url)
264        .call()
265        .map_err(|e| io::Error::other(format!("HTTP request failed for {url}: {e}")))?;
266
267    let mut buf = Vec::new();
268    response
269        .into_reader()
270        .read_to_end(&mut buf)
271        .map_err(|e| io::Error::other(format!("Failed to read response body: {e}")))?;
272
273    let mut file = fs::File::create(dest)?;
274    file.write_all(&buf)?;
275    Ok(())
276}
277
278/// Decompress a `.bz2` file to `dest`.
279///
280/// # Errors
281/// Returns an error if the file cannot be read or the decompressed data
282/// cannot be written.
283pub fn decompress_bzip2(src: &Path, dest: &Path) -> io::Result<()> {
284    use bzip2::read::BzDecoder;
285
286    if let Some(parent) = dest.parent() {
287        fs::create_dir_all(parent)?;
288    }
289
290    let compressed = fs::File::open(src)?;
291    let mut decoder = BzDecoder::new(compressed);
292    let mut decompressed = Vec::new();
293    decoder.read_to_end(&mut decompressed)?;
294
295    let mut out = fs::File::create(dest)?;
296    out.write_all(&decompressed)?;
297    Ok(())
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// The embedded registry deserializes and both of its tables are populated.
    #[test]
    fn registry_parses_without_panic() -> io::Result<()> {
        let DatasetRegistry {
            odometry,
            bundle_adjustment,
        } = DatasetRegistry::load()?;
        assert!(!odometry.is_empty(), "odometry section must not be empty");
        assert!(
            !bundle_adjustment.is_empty(),
            "bundle_adjustment section must not be empty"
        );
        Ok(())
    }

    /// A few well-known odometry datasets are registered.
    #[test]
    fn registry_contains_expected_odometry_datasets() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        for name in ["sphere2500", "parking-garage", "intel", "M3500"] {
            assert!(
                reg.odometry.contains_key(name),
                "missing expected dataset: {name}"
            );
        }
        Ok(())
    }

    /// All five BAL collections are registered.
    #[test]
    fn registry_contains_expected_ba_datasets() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        for name in ["ladybug", "trafalgar", "dubrovnik", "venice", "final"] {
            assert!(
                reg.bundle_adjustment.contains_key(name),
                "missing expected BA dataset: {name}"
            );
        }
        Ok(())
    }

    /// Every odometry entry is categorized as either 2D or 3D.
    #[test]
    fn odometry_entries_have_valid_categories() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        for (name, entry) in reg.odometry.iter() {
            assert!(
                matches!(entry.category.as_str(), "2d" | "3d"),
                "dataset '{name}' has invalid category: '{}'",
                entry.category
            );
        }
        Ok(())
    }

    /// No BA collection has an empty problem list.
    #[test]
    fn ba_entries_have_at_least_one_problem() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        for (name, entry) in reg.bundle_adjustment.iter() {
            assert!(
                !entry.problems.is_empty(),
                "BA dataset '{name}' has no problems listed"
            );
        }
        Ok(())
    }

    /// `problem_url` joins the prefix and the problem file name as expected.
    #[test]
    fn ba_problem_url_format_is_correct() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let Some(ladybug) = reg.bundle_adjustment.get("ladybug") else {
            return Err(io::Error::other("ladybug dataset not found"));
        };
        assert_eq!(
            ladybug.problem_url(49, 7776),
            "https://grail.cs.washington.edu/projects/bal/data/ladybug/problem-49-7776-pre.txt.bz2"
        );
        Ok(())
    }

    /// Filtering by `"3d"` yields a non-empty list of 3D entries only.
    #[test]
    fn odometry_by_category_returns_only_3d() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let three_d = reg.odometry_by_category("3d");
        three_d
            .iter()
            .for_each(|(_, entry)| assert_eq!(entry.category, "3d"));
        assert!(!three_d.is_empty());
        Ok(())
    }

    /// The sphere2500 dataset is served from a GitHub URL.
    #[test]
    fn sphere2500_uses_github_url() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let Some(sphere) = reg.odometry.get("sphere2500") else {
            return Err(io::Error::other("sphere2500 must exist"));
        };
        assert!(
            sphere.url.contains("github"),
            "sphere2500 should use the GitHub URL, got: {}",
            sphere.url
        );
        Ok(())
    }

    /// The Vertigo datasets are registered.
    #[test]
    fn registry_contains_new_vertigo_datasets() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        for name in ["manhattanOlson3500", "ring", "ring_city", "city10000"] {
            assert!(
                reg.odometry.contains_key(name),
                "missing expected dataset: {name}"
            );
        }
        Ok(())
    }

    /// Odometry paths contain the dataset's category as a path component.
    #[test]
    fn odometry_path_includes_category_subdir() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let has_component =
            |path: &Path, wanted: &str| path.components().any(|c| c.as_os_str() == wanted);

        let Some(path_3d) = reg.odometry_path("sphere2500") else {
            return Err(io::Error::other("sphere2500 path not found"));
        };
        let Some(path_2d) = reg.odometry_path("intel") else {
            return Err(io::Error::other("intel path not found"));
        };

        assert!(
            has_component(&path_3d, "3d"),
            "3D path should contain '3d' component, got: {}",
            path_3d.display()
        );
        assert!(
            has_component(&path_2d, "2d"),
            "2D path should contain '2d' component, got: {}",
            path_2d.display()
        );
        Ok(())
    }

    /// The retired `sphere_bignoise` entry must stay out of the registry.
    #[test]
    fn sphere_bignoise_removed_from_registry() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let still_present = reg.odometry.contains_key("sphere_bignoise");
        assert!(
            !still_present,
            "sphere_bignoise should have been removed (merged into sphere2500)"
        );
        Ok(())
    }

    /// BA paths contain the data directory, the collection, and the problem file name.
    #[test]
    fn ba_path_returns_correct_structure() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let Some(path) = reg.ba_path("ladybug", 49, 7776) else {
            return Err(io::Error::other("ladybug ba_path not found"));
        };
        let has_component = |wanted: &str| path.components().any(|c| c.as_os_str() == wanted);

        assert!(
            has_component("bundle_adjustment"),
            "path should contain 'bundle_adjustment', got: {}",
            path.display()
        );
        assert!(
            has_component("ladybug"),
            "path should contain 'ladybug', got: {}",
            path.display()
        );

        let file_name = path.file_name();
        assert!(
            file_name.is_some_and(|f| f == "problem-49-7776-pre.txt"),
            "filename should be 'problem-49-7776-pre.txt', got: {}",
            path.display()
        );
        Ok(())
    }

    /// Unknown collection names produce no BA path.
    #[test]
    fn ba_path_returns_none_for_unknown() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let lookup = reg.ba_path("nonexistent_ba_xyz", 1, 1);
        assert!(lookup.is_none(), "unknown BA name should return None");
        Ok(())
    }

    /// `ba_sorted` yields collections in non-decreasing name order.
    #[test]
    fn ba_sorted_returns_alphabetical_order() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let sorted = reg.ba_sorted();
        assert!(!sorted.is_empty(), "ba_sorted should not be empty");
        for (earlier, later) in sorted.iter().zip(sorted.iter().skip(1)) {
            assert!(
                earlier.0 <= later.0,
                "ba_sorted is not sorted: '{}' > '{}'",
                earlier.0,
                later.0
            );
        }
        Ok(())
    }

    /// For the bundled ladybug collection, `largest()` is the last listed problem.
    #[test]
    fn ba_entry_largest_returns_last_problem() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let Some(ladybug) = reg.bundle_adjustment.get("ladybug") else {
            return Err(io::Error::other("ladybug not found"));
        };
        let biggest = ladybug.largest();
        let final_listed = ladybug.problems.last().copied();
        assert!(biggest.is_some(), "ladybug should have a largest problem");
        assert_eq!(
            biggest, final_listed,
            "largest() should equal the last problem"
        );
        Ok(())
    }

    /// A collection without problems has no largest problem.
    #[test]
    fn ba_entry_largest_empty_returns_none() {
        let empty = BaEntry {
            url_prefix: String::from("https://example.com"),
            problems: Vec::new(),
        };
        assert!(
            empty.largest().is_none(),
            "empty problems should return None"
        );
    }

    /// Filtering by `"2d"` yields a non-empty list of 2D entries only.
    #[test]
    fn odometry_by_category_returns_only_2d() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let two_d = reg.odometry_by_category("2d");
        assert!(!two_d.is_empty(), "should have at least one 2d dataset");
        two_d
            .iter()
            .for_each(|(_, entry)| assert_eq!(entry.category, "2d"));
        Ok(())
    }

    /// `odometry_by_category` yields entries in non-decreasing name order.
    #[test]
    fn odometry_by_category_is_sorted() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        let three_d = reg.odometry_by_category("3d");
        for (earlier, later) in three_d.iter().zip(three_d.iter().skip(1)) {
            assert!(
                earlier.0 <= later.0,
                "odometry_by_category is not sorted: '{}' > '{}'",
                earlier.0,
                later.0
            );
        }
        Ok(())
    }

    /// Every odometry entry names both a URL and a file.
    #[test]
    fn odometry_entries_have_nonempty_url_and_filename() -> io::Result<()> {
        let reg = DatasetRegistry::load()?;
        for (name, OdometryEntry { url, filename, .. }) in reg.odometry.iter() {
            assert!(!url.is_empty(), "dataset '{name}' has empty url");
            assert!(!filename.is_empty(), "dataset '{name}' has empty filename");
        }
        Ok(())
    }

    /// Compressing and then decompressing reproduces the original bytes.
    #[test]
    fn decompress_bzip2_roundtrip() -> io::Result<()> {
        use bzip2::Compression;
        use bzip2::write::BzEncoder;
        use std::io::Write as _;

        let original = b"hello bzip2 roundtrip test data";

        // Produce a compressed fixture inside a scratch directory.
        let scratch = tempfile::tempdir()?;
        let archive = scratch.path().join("test.txt.bz2");
        let output = scratch.path().join("test.txt");

        let mut encoder = BzEncoder::new(fs::File::create(&archive)?, Compression::fast());
        encoder.write_all(original)?;
        // `finish` hands back the file, which is closed when the value is dropped here.
        encoder.finish()?;

        decompress_bzip2(&archive, &output)?;

        let decompressed = fs::read(&output)?;
        assert_eq!(
            decompressed, original,
            "decompressed content must match original"
        );
        Ok(())
    }

    /// Requesting an unregistered odometry dataset fails with a registry error.
    #[test]
    fn ensure_odometry_dataset_unknown_name_errors() -> io::Result<()> {
        let Err(err) = ensure_odometry_dataset("nonexistent_dataset_xyz_abc") else {
            return Err(io::Error::other("expected Err but got Ok"));
        };
        let message = err.to_string();
        assert!(
            message.contains("not found in registry"),
            "error message should mention registry, got: {err}"
        );
        Ok(())
    }

    /// Requesting an unregistered BA collection fails with a registry error.
    #[test]
    fn ensure_ba_dataset_unknown_name_errors() -> io::Result<()> {
        let Err(err) = ensure_ba_dataset("nonexistent_ba_xyz_abc", 1, 1) else {
            return Err(io::Error::other("expected Err but got Ok"));
        };
        let message = err.to_string();
        assert!(
            message.contains("not found in registry"),
            "error message should mention registry, got: {err}"
        );
        Ok(())
    }
}