// loftr 0.1.1
//
// Native Rust/tch implementation of LoFTR feature matching.
use std::{
    collections::HashMap,
    error::Error,
    fs,
    path::{Path, PathBuf},
};

use loftr::{LoftrConfig, LoftrModel};
use tch::{Device, Tensor};

/// Remediation hint appended to the "missing test fixture" error produced by
/// [`fixture_path`], telling the developer how to generate the fixture files.
const FIXTURE_SETUP_HINT: &str =
    "run `./scripts/prepare_test_fixtures.sh` from the workspace root and retry";

/// Reference tensors loaded from a fixture pair in `tests/data`:
/// `<stem>.json` (metadata) and `<stem>.safetensors` (named tensors).
pub(crate) struct ReferenceFixture {
    /// Tensors keyed by the names stored in the safetensors file.
    tensors: HashMap<String, Tensor>,
}

impl ReferenceFixture {
    /// Loads the fixture identified by `stem` from `tests/data`.
    ///
    /// The `<stem>.json` metadata file is only checked for being well-formed JSON;
    /// its contents are not otherwise used here.
    ///
    /// # Errors
    ///
    /// Returns an error if either fixture file is missing (see [`fixture_path`]),
    /// if the metadata cannot be read or is not valid JSON, or if the safetensors
    /// file cannot be read. Every error message names the offending file.
    pub(crate) fn load(stem: &str) -> Result<Self, Box<dyn Error>> {
        let metadata_path = fixture_path(&format!("{stem}.json"))?;
        let metadata = fs::read_to_string(&metadata_path).map_err(|err| {
            format!("failed to read fixture metadata `{}`: {err}", metadata_path.display())
        })?;
        // Validation only: the parsed value is deliberately discarded.
        serde_json::from_str::<serde_json::Value>(&metadata).map_err(|err| {
            format!("fixture metadata `{}` is not valid JSON: {err}", metadata_path.display())
        })?;

        let tensor_path = fixture_path(&format!("{stem}.safetensors"))?;
        let tensors = Tensor::read_safetensors(&tensor_path)
            .map_err(|err| {
                format!("failed to read reference tensors `{}`: {err}", tensor_path.display())
            })?
            .into_iter()
            .collect::<HashMap<_, _>>();
        Ok(Self { tensors })
    }

    /// Returns the reference tensor stored under `name`.
    ///
    /// The result is produced with `Tensor::shallow_clone`, so the fixture keeps
    /// its own handle and can be queried again.
    ///
    /// # Errors
    ///
    /// Returns an error if no tensor named `name` exists in the fixture.
    pub(crate) fn tensor(&self, name: &str) -> Result<Tensor, Box<dyn Error>> {
        self.tensors
            .get(name)
            .map(Tensor::shallow_clone)
            .ok_or_else(|| format!("missing reference tensor `{name}`").into())
    }
}

/// Resolves `name` to a file inside this crate's `tests/data` directory.
///
/// # Errors
///
/// Fails when no regular file exists at the resolved location. The message
/// includes the full expected path and a hint on how to generate the fixtures.
pub(crate) fn fixture_path(name: &str) -> Result<PathBuf, Box<dyn Error>> {
    let mut resolved = Path::new(env!("CARGO_MANIFEST_DIR")).to_path_buf();
    resolved.push("tests/data");
    resolved.push(name);

    if !resolved.is_file() {
        let message = format!(
            "missing test fixture `{name}` at `{}`; {FIXTURE_SETUP_HINT}",
            resolved.display()
        );
        return Err(message.into());
    }
    Ok(resolved)
}

/// Builds a CPU [`LoftrModel`] from `config` and loads the weights fixture
/// named `weights_name` into it.
///
/// The weights path is resolved and checked for existence before the model is
/// constructed.
///
/// # Errors
///
/// Propagates any error from [`fixture_path`], `LoftrModel::new`, or
/// `LoftrModel::load_weights`.
pub(crate) fn load_model(
    weights_name: &str,
    config: LoftrConfig,
) -> Result<LoftrModel, Box<dyn Error>> {
    let weights_path = fixture_path(weights_name)?;
    let device = Device::Cpu;
    let mut loftr = LoftrModel::new(device, config)?;
    loftr.load_weights(&weights_path)?;
    Ok(loftr)
}