use std::collections::HashMap;
use std::fs;
use std::io::{self, Read, Write};
use std::path::{Path, PathBuf};
use serde::Deserialize;
use tracing::info;
use crate::{BUNDLE_ADJUSTMENT_DATA_DIR, ODOMETRY_DATA_DIR};
// Contents of `../datasets.toml` (resolved relative to this source file), embedded into the
// binary at compile time; `DatasetRegistry::load` parses this text.
const DATASETS_TOML: &str = include_str!("../datasets.toml");
/// A single odometry dataset entry, deserialized from the dataset registry.
#[derive(Debug, Clone, Deserialize)]
pub struct OdometryEntry {
/// URL the dataset file is downloaded from.
pub url: String,
/// File name the dataset is stored under, inside its category directory.
pub filename: String,
/// Name of the subdirectory (below the odometry data directory) that holds the file.
/// The unit tests in this file expect the value to be `"2d"` or `"3d"`.
pub category: String,
}
/// A bundle-adjustment dataset family, deserialized from the dataset registry.
#[derive(Debug, Clone, Deserialize)]
pub struct BaEntry {
/// URL prefix to which `/problem-<cameras>-<points>-pre.txt.bz2` is appended.
pub url_prefix: String,
/// Available problems as pairs of numbers.
/// NOTE(review): presumably each pair is `[cameras, points]`, matching the
/// arguments of `problem_url` — confirm against `datasets.toml`.
pub problems: Vec<[u32; 2]>,
}
impl BaEntry {
    /// Returns the last problem listed for this dataset, or `None` when the list is empty.
    ///
    /// NOTE(review): this is only the "largest" problem if the registry lists problems in
    /// ascending size order; nothing here enforces that ordering.
    pub fn largest(&self) -> Option<[u32; 2]> {
        self.problems.split_last().map(|(tail, _)| *tail)
    }

    /// Builds the download URL of the bzip2-compressed problem file for the given
    /// camera and point counts.
    pub fn problem_url(&self, cameras: u32, points: u32) -> String {
        let prefix = &self.url_prefix;
        format!("{prefix}/problem-{cameras}-{points}-pre.txt.bz2")
    }
}
/// The complete dataset registry, deserialized from the embedded `datasets.toml`.
#[derive(Debug, Deserialize)]
pub struct DatasetRegistry {
/// Odometry datasets, keyed by dataset name.
pub odometry: HashMap<String, OdometryEntry>,
/// Bundle-adjustment dataset families, keyed by dataset name.
pub bundle_adjustment: HashMap<String, BaEntry>,
}
impl DatasetRegistry {
    /// Parses the registry embedded in the binary.
    ///
    /// # Errors
    /// Returns an [`io::ErrorKind::InvalidData`] error wrapping the TOML parse error.
    pub fn load() -> io::Result<Self> {
        match toml::from_str::<Self>(DATASETS_TOML) {
            Ok(registry) => Ok(registry),
            Err(parse_err) => Err(io::Error::new(io::ErrorKind::InvalidData, parse_err)),
        }
    }

    /// Returns the local path `<odometry dir>/<category>/<filename>` for the odometry
    /// dataset `name`, or `None` if the registry has no such dataset.
    ///
    /// The path is only computed; the file is neither checked nor downloaded.
    pub fn odometry_path(&self, name: &str) -> Option<std::path::PathBuf> {
        let entry = self.odometry.get(name)?;
        let mut path = std::path::PathBuf::from(crate::ODOMETRY_DATA_DIR);
        path.push(&entry.category);
        path.push(&entry.filename);
        Some(path)
    }

    /// Lists every odometry dataset whose category equals `category`, sorted by name.
    pub fn odometry_by_category(&self, category: &str) -> Vec<(&str, &OdometryEntry)> {
        let mut matching = Vec::new();
        for (name, entry) in &self.odometry {
            if entry.category == category {
                matching.push((name.as_str(), entry));
            }
        }
        // The map iterates in arbitrary order, so sort for a deterministic result.
        matching.sort_by(|a, b| a.0.cmp(b.0));
        matching
    }

    /// Returns the local path `<BA dir>/<name>/problem-<cameras>-<points>-pre.txt`, or
    /// `None` if the registry has no bundle-adjustment dataset called `name`.
    ///
    /// Only the dataset name is validated; the camera/point pair is not checked against
    /// the dataset's problem list.
    pub fn ba_path(&self, name: &str, cameras: u32, points: u32) -> Option<std::path::PathBuf> {
        if !self.bundle_adjustment.contains_key(name) {
            return None;
        }
        let file_name = format!("problem-{cameras}-{points}-pre.txt");
        let dataset_dir = std::path::PathBuf::from(crate::BUNDLE_ADJUSTMENT_DATA_DIR).join(name);
        Some(dataset_dir.join(file_name))
    }

    /// Lists every bundle-adjustment dataset, sorted by name.
    pub fn ba_sorted(&self) -> Vec<(&str, &BaEntry)> {
        let mut all = Vec::with_capacity(self.bundle_adjustment.len());
        for (name, entry) in &self.bundle_adjustment {
            all.push((name.as_str(), entry));
        }
        all.sort_by(|a, b| a.0.cmp(b.0));
        all
    }
}
/// Makes sure the odometry dataset `name` exists on disk and returns its path.
///
/// The dataset is looked up in the registry; if its file is already present the path is
/// returned immediately, otherwise the file is downloaded from the registered URL first.
///
/// # Errors
/// Fails if the registry cannot be parsed, if `name` is not a registered odometry dataset
/// (the message lists the available names), or if the download fails.
pub fn ensure_odometry_dataset(name: &str) -> io::Result<PathBuf> {
    let registry = DatasetRegistry::load()?;

    let Some(entry) = registry.odometry.get(name) else {
        // Sort the known names so the error message is stable across runs.
        let mut known: Vec<&str> = registry.odometry.keys().map(String::as_str).collect();
        known.sort();
        return Err(io::Error::other(format!(
            "Dataset '{name}' not found in registry. Available: {}",
            known.join(", ")
        )));
    };

    let mut path = PathBuf::from(ODOMETRY_DATA_DIR);
    path.push(&entry.category);
    path.push(&entry.filename);

    if !path.exists() {
        info!("Downloading {name} ({}) ...", entry.filename);
        if let Err(e) = download_file(&entry.url, &path) {
            return Err(io::Error::other(format!("Failed to download {name}: {e}")));
        }
        info!("Saved to {}", path.display());
    }
    Ok(path)
}
/// Makes sure the bundle-adjustment problem `problem-<cameras>-<points>-pre.txt` of dataset
/// `name` exists on disk and returns its path.
///
/// If the text file is already present it is returned without consulting the registry.
/// Otherwise the `.bz2` archive is downloaded next to it, decompressed, and the archive is
/// deleted afterwards.
///
/// # Errors
/// Fails if the registry cannot be parsed, if `name` is not a registered bundle-adjustment
/// dataset (the message lists the available names), or if downloading or decompressing
/// fails.
pub fn ensure_ba_dataset(name: &str, cameras: u32, points: u32) -> io::Result<PathBuf> {
    let mut txt_path = PathBuf::from(BUNDLE_ADJUSTMENT_DATA_DIR);
    txt_path.push(name);
    txt_path.push(format!("problem-{cameras}-{points}-pre.txt"));
    if txt_path.exists() {
        return Ok(txt_path);
    }

    let registry = DatasetRegistry::load()?;
    let Some(entry) = registry.bundle_adjustment.get(name) else {
        // Sort the known names so the error message is stable across runs.
        let mut known: Vec<&str> = registry
            .bundle_adjustment
            .keys()
            .map(String::as_str)
            .collect();
        known.sort();
        return Err(io::Error::other(format!(
            "BA dataset '{name}' not found in registry. Available: {}",
            known.join(", ")
        )));
    };

    // `with_extension` swaps the trailing `txt` for `txt.bz2`, i.e. appends `.bz2`.
    let bz2_path = txt_path.with_extension("txt.bz2");
    let url = entry.problem_url(cameras, points);

    info!("Downloading {name}/problem-{cameras}-{points} ...");
    if let Err(e) = download_file(&url, &bz2_path) {
        return Err(io::Error::other(format!("Failed to download {name}: {e}")));
    }
    if let Err(e) = decompress_bzip2(&bz2_path, &txt_path) {
        return Err(io::Error::other(format!("Failed to decompress: {e}")));
    }

    // Best effort: a leftover archive is harmless, so a failed removal is ignored.
    let _ = fs::remove_file(&bz2_path);
    info!("Saved to {}", txt_path.display());
    Ok(txt_path)
}
pub fn download_file(url: &str, dest: &Path) -> io::Result<()> {
if let Some(parent) = dest.parent() {
fs::create_dir_all(parent)?;
}
let response = ureq::get(url)
.call()
.map_err(|e| io::Error::other(format!("HTTP request failed for {url}: {e}")))?;
let mut buf = Vec::new();
response
.into_reader()
.read_to_end(&mut buf)
.map_err(|e| io::Error::other(format!("Failed to read response body: {e}")))?;
let mut file = fs::File::create(dest)?;
file.write_all(&buf)?;
Ok(())
}
/// Decompresses the bzip2 file `src` into `dest`.
///
/// Missing parent directories of `dest` are created. The data is streamed from the decoder
/// into a temporary sibling file, which is renamed onto `dest` only after decompression
/// has finished. `dest` therefore never holds a truncated result, and the decompressed
/// payload is not held in memory as a whole.
///
/// # Errors
/// Returns an error if the directories cannot be created, `src` cannot be opened or
/// decoded, or the output cannot be written or renamed. On failure the temporary file is
/// removed (best effort) and `dest` is left untouched.
pub fn decompress_bzip2(src: &Path, dest: &Path) -> io::Result<()> {
    use bzip2::read::BzDecoder;
    if let Some(parent) = dest.parent() {
        fs::create_dir_all(parent)?;
    }
    let mut decoder = BzDecoder::new(fs::File::open(src)?);

    // Temporary sibling (same directory, so the rename stays on one filesystem); the
    // process id keeps concurrent processes from clobbering each other's partial file.
    let mut tmp_name = dest.as_os_str().to_owned();
    tmp_name.push(format!(".{}.part", std::process::id()));
    let tmp_path = PathBuf::from(tmp_name);

    let outcome = fs::File::create(&tmp_path)
        .and_then(|mut out| io::copy(&mut decoder, &mut out).map(drop))
        .and_then(|()| fs::rename(&tmp_path, dest));
    if outcome.is_err() {
        // Best effort cleanup; the original error is what the caller needs to see.
        let _ = fs::remove_file(&tmp_path);
    }
    outcome
}
#[cfg(test)]
mod tests {
    use super::*;

    // The embedded registry parses and both sections contain entries.
    #[test]
    fn registry_parses_without_panic() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        assert!(
            !registry.odometry.is_empty(),
            "odometry section must not be empty"
        );
        assert!(
            !registry.bundle_adjustment.is_empty(),
            "bundle_adjustment section must not be empty"
        );
        Ok(())
    }

    // A handful of well-known odometry datasets are registered.
    #[test]
    fn registry_contains_expected_odometry_datasets() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for name in &["sphere2500", "parking-garage", "intel", "M3500"] {
            assert!(
                registry.odometry.contains_key(*name),
                "missing expected dataset: {name}"
            );
        }
        Ok(())
    }

    // All expected bundle-adjustment dataset families are registered.
    #[test]
    fn registry_contains_expected_ba_datasets() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for name in &["ladybug", "trafalgar", "dubrovnik", "venice", "final"] {
            assert!(
                registry.bundle_adjustment.contains_key(*name),
                "missing expected BA dataset: {name}"
            );
        }
        Ok(())
    }

    // Every odometry entry is categorised as either "2d" or "3d".
    #[test]
    fn odometry_entries_have_valid_categories() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for (name, entry) in &registry.odometry {
            assert!(
                entry.category == "2d" || entry.category == "3d",
                "dataset '{name}' has invalid category: '{}'",
                entry.category
            );
        }
        Ok(())
    }

    // Every bundle-adjustment entry lists at least one problem.
    #[test]
    fn ba_entries_have_at_least_one_problem() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for (name, entry) in &registry.bundle_adjustment {
            assert!(
                !entry.problems.is_empty(),
                "BA dataset '{name}' has no problems listed"
            );
        }
        Ok(())
    }

    // `problem_url` joins the registered prefix and the problem file name.
    #[test]
    fn ba_problem_url_format_is_correct() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let ladybug = registry
            .bundle_adjustment
            .get("ladybug")
            .ok_or_else(|| io::Error::other("ladybug dataset not found"))?;
        let url = ladybug.problem_url(49, 7776);
        assert_eq!(
            url,
            "https://grail.cs.washington.edu/projects/bal/data/ladybug/problem-49-7776-pre.txt.bz2"
        );
        Ok(())
    }

    // Filtering by "3d" yields only 3d entries, and at least one of them.
    #[test]
    fn odometry_by_category_returns_only_3d() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.odometry_by_category("3d");
        for (_, entry) in &entries {
            assert_eq!(entry.category, "3d");
        }
        assert!(!entries.is_empty());
        Ok(())
    }

    // The sphere2500 entry points at a GitHub-hosted file.
    #[test]
    fn sphere2500_uses_github_url() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entry = registry
            .odometry
            .get("sphere2500")
            .ok_or_else(|| io::Error::other("sphere2500 must exist"))?;
        assert!(
            entry.url.contains("github"),
            "sphere2500 should use the GitHub URL, got: {}",
            entry.url
        );
        Ok(())
    }

    // The additional vertigo datasets are registered.
    #[test]
    fn registry_contains_new_vertigo_datasets() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for name in &["manhattanOlson3500", "ring", "ring_city", "city10000"] {
            assert!(
                registry.odometry.contains_key(*name),
                "missing expected dataset: {name}"
            );
        }
        Ok(())
    }

    // Odometry paths contain the entry's category as a path component.
    #[test]
    fn odometry_path_includes_category_subdir() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let path_3d = registry
            .odometry_path("sphere2500")
            .ok_or_else(|| io::Error::other("sphere2500 path not found"))?;
        let path_2d = registry
            .odometry_path("intel")
            .ok_or_else(|| io::Error::other("intel path not found"))?;
        assert!(
            path_3d.components().any(|c| c.as_os_str() == "3d"),
            "3D path should contain '3d' component, got: {}",
            path_3d.display()
        );
        assert!(
            path_2d.components().any(|c| c.as_os_str() == "2d"),
            "2D path should contain '2d' component, got: {}",
            path_2d.display()
        );
        Ok(())
    }

    // The retired sphere_bignoise entry must not reappear.
    #[test]
    fn sphere_bignoise_removed_from_registry() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        assert!(
            !registry.odometry.contains_key("sphere_bignoise"),
            "sphere_bignoise should have been removed (merged into sphere2500)"
        );
        Ok(())
    }

    // `ba_path` yields `<...>/bundle_adjustment/<...>/ladybug/problem-49-7776-pre.txt`.
    #[test]
    fn ba_path_returns_correct_structure() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let path = registry
            .ba_path("ladybug", 49, 7776)
            .ok_or_else(|| io::Error::other("ladybug ba_path not found"))?;
        assert!(
            path.components()
                .any(|c| c.as_os_str() == "bundle_adjustment"),
            "path should contain 'bundle_adjustment', got: {}",
            path.display()
        );
        assert!(
            path.components().any(|c| c.as_os_str() == "ladybug"),
            "path should contain 'ladybug', got: {}",
            path.display()
        );
        assert!(
            path.file_name()
                .is_some_and(|f| f == "problem-49-7776-pre.txt"),
            "filename should be 'problem-49-7776-pre.txt', got: {}",
            path.display()
        );
        Ok(())
    }

    // Unknown bundle-adjustment names have no path.
    #[test]
    fn ba_path_returns_none_for_unknown() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        assert!(
            registry.ba_path("nonexistent_ba_xyz", 1, 1).is_none(),
            "unknown BA name should return None"
        );
        Ok(())
    }

    // `ba_sorted` returns names in non-decreasing order.
    #[test]
    fn ba_sorted_returns_alphabetical_order() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.ba_sorted();
        assert!(!entries.is_empty(), "ba_sorted should not be empty");
        for window in entries.windows(2) {
            assert!(
                window[0].0 <= window[1].0,
                "ba_sorted is not sorted: '{}' > '{}'",
                window[0].0,
                window[1].0
            );
        }
        Ok(())
    }

    // `largest` is defined as the last listed problem.
    #[test]
    fn ba_entry_largest_returns_last_problem() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let ladybug = registry
            .bundle_adjustment
            .get("ladybug")
            .ok_or_else(|| io::Error::other("ladybug not found"))?;
        let largest = ladybug.largest();
        assert!(largest.is_some(), "ladybug should have a largest problem");
        assert_eq!(
            largest,
            ladybug.problems.last().copied(),
            "largest() should equal the last problem"
        );
        Ok(())
    }

    // `largest` on an entry without problems is `None`.
    #[test]
    fn ba_entry_largest_empty_returns_none() {
        let entry = BaEntry {
            url_prefix: "https://example.com".to_string(),
            problems: vec![],
        };
        assert!(
            entry.largest().is_none(),
            "empty problems should return None"
        );
    }

    // Filtering by "2d" yields only 2d entries, and at least one of them.
    #[test]
    fn odometry_by_category_returns_only_2d() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.odometry_by_category("2d");
        assert!(!entries.is_empty(), "should have at least one 2d dataset");
        for (_, entry) in &entries {
            assert_eq!(entry.category, "2d");
        }
        Ok(())
    }

    // `odometry_by_category` returns names in non-decreasing order.
    #[test]
    fn odometry_by_category_is_sorted() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        let entries = registry.odometry_by_category("3d");
        for window in entries.windows(2) {
            assert!(
                window[0].0 <= window[1].0,
                "odometry_by_category is not sorted: '{}' > '{}'",
                window[0].0,
                window[1].0
            );
        }
        Ok(())
    }

    // No odometry entry has an empty URL or file name.
    #[test]
    fn odometry_entries_have_nonempty_url_and_filename() -> io::Result<()> {
        let registry = DatasetRegistry::load()?;
        for (name, entry) in &registry.odometry {
            assert!(!entry.url.is_empty(), "dataset '{name}' has empty url");
            assert!(
                !entry.filename.is_empty(),
                "dataset '{name}' has empty filename"
            );
        }
        Ok(())
    }

    // Compressing and then calling `decompress_bzip2` reproduces the original bytes.
    #[test]
    fn decompress_bzip2_roundtrip() -> io::Result<()> {
        use bzip2::Compression;
        use bzip2::write::BzEncoder;
        use std::io::Write as _;
        let original = b"hello bzip2 roundtrip test data";
        let tmp_dir = tempfile::tempdir()?;
        let bz2_path = tmp_dir.path().join("test.txt.bz2");
        let txt_path = tmp_dir.path().join("test.txt");
        {
            let file = fs::File::create(&bz2_path)?;
            let mut encoder = BzEncoder::new(file, Compression::fast());
            encoder.write_all(original)?;
            encoder.finish()?;
        }
        decompress_bzip2(&bz2_path, &txt_path)?;
        let decompressed = fs::read(&txt_path)?;
        assert_eq!(
            decompressed, original,
            "decompressed content must match original"
        );
        Ok(())
    }

    // Unknown odometry names are rejected before any download is attempted.
    #[test]
    fn ensure_odometry_dataset_unknown_name_errors() -> io::Result<()> {
        let err = ensure_odometry_dataset("nonexistent_dataset_xyz_abc")
            .err()
            .ok_or_else(|| io::Error::other("expected Err but got Ok"))?;
        assert!(
            err.to_string().contains("not found in registry"),
            "error message should mention registry, got: {err}"
        );
        Ok(())
    }

    // Unknown bundle-adjustment names are rejected before any download is attempted.
    #[test]
    fn ensure_ba_dataset_unknown_name_errors() -> io::Result<()> {
        let err = ensure_ba_dataset("nonexistent_ba_xyz_abc", 1, 1)
            .err()
            .ok_or_else(|| io::Error::other("expected Err but got Ok"))?;
        assert!(
            err.to_string().contains("not found in registry"),
            "error message should mention registry, got: {err}"
        );
        Ok(())
    }
}