// loftr 0.1.1
//
// Native Rust/tch implementation of LoFTR feature matching.
use tch::{
    Device, Kind, Tensor,
    nn::{self},
};

use crate::{error::LoftrError, loftr_config::LoftrConfig};

/// Per-call match metadata consumed by `FinePreprocess::forward`.
///
/// The three index tensors are expected to be rank-1 and of equal length
/// (checked by `match_count`); entry `k` of each tensor describes match `k`.
#[derive(Debug)]
pub struct FinePreprocessData {
    /// Shape of the fine-level map. `stride()` divides its first component by
    /// the first component of `hw0_c`. Presumably `(height, width)` — confirm
    /// against the code that fills this struct.
    pub hw0_f: (i64, i64),
    /// Shape of the coarse-level map, in the same component order as `hw0_f`.
    pub hw0_c: (i64, i64),
    /// Batch index of every match; it is scaled by the per-sample token count
    /// to build a flat row index when windows/tokens are gathered.
    pub b_ids: Tensor,
    /// Token index of every match used with the first input
    /// (`fine_map0` / `coarse_tokens0`).
    pub i_ids: Tensor,
    /// Token index of every match used with the second input
    /// (`fine_map1` / `coarse_tokens1`).
    pub j_ids: Tensor,
}

impl FinePreprocessData {
    /// Returns the integer ratio between the first components of `hw0_f` and
    /// `hw0_c`.
    ///
    /// Only the first component takes part in the division; the second
    /// components are merely required to be positive.
    ///
    /// # Errors
    ///
    /// Returns `LoftrError::InvalidConfig` when any shape component is not
    /// positive, or when the integer division yields a non-positive ratio.
    pub fn stride(&self) -> Result<i64, LoftrError> {
        let extents = [self.hw0_f.0, self.hw0_f.1, self.hw0_c.0, self.hw0_c.1];
        if extents.iter().any(|&extent| extent <= 0) {
            return Err(LoftrError::InvalidConfig(format!(
                "FinePreprocessData requires positive fine/coarse shapes; got hw0_f={:?}, hw0_c={:?}",
                self.hw0_f, self.hw0_c
            )));
        }

        // Integer division: a fine map smaller than the coarse map gives 0,
        // which is rejected below.
        match self.hw0_f.0 / self.hw0_c.0 {
            ratio if ratio > 0 => Ok(ratio),
            _ => Err(LoftrError::InvalidConfig(format!(
                "FinePreprocessData requires positive coarse stride; got hw0_f={:?}, hw0_c={:?}",
                self.hw0_f, self.hw0_c
            ))),
        }
    }

    /// Returns the number of matches described by the index tensors.
    ///
    /// # Errors
    ///
    /// Returns `LoftrError::InvalidConfig` when any of `b_ids`, `i_ids`,
    /// `j_ids` is not rank-1, or when their lengths disagree.
    pub fn match_count(&self) -> Result<i64, LoftrError> {
        // Evaluated left to right so the first malformed tensor is reported.
        let (batch_count, i_count, j_count) = (
            first_dim(&self.b_ids, "b_ids")?,
            first_dim(&self.i_ids, "i_ids")?,
            first_dim(&self.j_ids, "j_ids")?,
        );

        if i_count == batch_count && j_count == batch_count {
            return Ok(batch_count);
        }
        Err(LoftrError::InvalidConfig(format!(
            "FinePreprocessData index lengths must match; got b_ids={batch_count}, i_ids={i_count}, j_ids={j_count}"
        )))
    }
}

/// Fine-level preprocessing stage: cuts local windows out of the fine feature
/// maps around each match and optionally merges in projected coarse features.
#[derive(Debug)]
pub struct FinePreprocess {
    /// Copied from `config.fine_concat_coarse_feat`; when `true`, coarse
    /// tokens are projected and merged into every fine window.
    cat_coarse_feat: bool,
    /// Side length of the square window; validated in `new` to be positive
    /// and odd.
    window_size: i64,
    /// Channel count expected on the fine feature maps (`config.fine.d_model`).
    d_model_f: i64,
    /// Linear layer `coarse.d_model -> fine.d_model`; `Some` only when
    /// `cat_coarse_feat` is enabled.
    down_proj: Option<nn::Linear>,
    /// Linear layer `2 * fine.d_model -> fine.d_model`; `Some` only when
    /// `cat_coarse_feat` is enabled.
    merge_feat: Option<nn::Linear>,
}

impl FinePreprocess {
    pub fn new(vs: &nn::Path<'_>, config: &LoftrConfig) -> Result<Self, LoftrError> {
        if config.fine_window_size <= 0 || config.fine_window_size % 2 == 0 {
            return Err(LoftrError::InvalidConfig(format!(
                "FinePreprocess requires a positive odd fine_window_size; got {}",
                config.fine_window_size
            )));
        }

        let d_model_c = config.coarse.d_model;
        let d_model_f = config.fine.d_model;
        let linear_config = nn::LinearConfig {
            ws_init: nn::init::DEFAULT_KAIMING_NORMAL,
            ..Default::default()
        };

        let (down_proj, merge_feat) = if config.fine_concat_coarse_feat {
            let down_proj = nn::linear(vs / "down_proj", d_model_c, d_model_f, linear_config);
            let merge_feat = nn::linear(vs / "merge_feat", 2 * d_model_f, d_model_f, linear_config);
            (Some(down_proj), Some(merge_feat))
        } else {
            (None, None)
        };

        Ok(Self {
            cat_coarse_feat: config.fine_concat_coarse_feat,
            window_size: config.fine_window_size,
            d_model_f,
            down_proj,
            merge_feat,
        })
    }

    pub fn forward(
        &self,
        fine_map0: &Tensor,
        fine_map1: &Tensor,
        coarse_tokens0: &Tensor,
        coarse_tokens1: &Tensor,
        data: &FinePreprocessData,
    ) -> Result<(Tensor, Tensor), LoftrError> {
        validate_fine_map(fine_map0, "fine_map0", self.d_model_f)?;
        validate_fine_map(fine_map1, "fine_map1", self.d_model_f)?;
        validate_coarse_sequence(coarse_tokens0, "coarse_tokens0")?;
        validate_coarse_sequence(coarse_tokens1, "coarse_tokens1")?;

        let match_count = data.match_count()?;
        let stride = data.stride()?;
        if match_count == 0 {
            let empty = Tensor::empty(
                [0, self.window_size * self.window_size, self.d_model_f],
                (Kind::Float, fine_map0.device()),
            );
            return Ok((empty.shallow_clone(), empty));
        }

        let fine_windows0 = unfold_local_windows(fine_map0, self.window_size, stride)?;
        let fine_windows1 = unfold_local_windows(fine_map1, self.window_size, stride)?;

        let fine_windows0 = select_unfold_windows(&fine_windows0, &data.b_ids, &data.i_ids)?;
        let fine_windows1 = select_unfold_windows(&fine_windows1, &data.b_ids, &data.j_ids)?;

        if !self.cat_coarse_feat {
            return Ok((fine_windows0, fine_windows1));
        }

        let down_proj = self.down_proj.as_ref().ok_or_else(|| {
            LoftrError::InvalidConfig(String::from(
                "FinePreprocess missing down_proj while fine_concat_coarse_feat is enabled",
            ))
        })?;
        let merge_feat = self.merge_feat.as_ref().ok_or_else(|| {
            LoftrError::InvalidConfig(String::from(
                "FinePreprocess missing merge_feat while fine_concat_coarse_feat is enabled",
            ))
        })?;

        let coarse_context0 = select_sequence_tokens(coarse_tokens0, &data.b_ids, &data.i_ids)?;
        let coarse_context1 = select_sequence_tokens(coarse_tokens1, &data.b_ids, &data.j_ids)?;
        let coarse_context = Tensor::cat(&[coarse_context0, coarse_context1], 0).apply(down_proj);
        let fine_windows = Tensor::cat(
            &[fine_windows0.shallow_clone(), fine_windows1.shallow_clone()],
            0,
        );
        let coarse_context =
            coarse_context
                .unsqueeze(1)
                .repeat([1, self.window_size * self.window_size, 1]);
        let merged = Tensor::cat(&[fine_windows, coarse_context], -1).apply(merge_feat);
        let chunks = merged.chunk(2, 0);
        Ok((chunks[0].shallow_clone(), chunks[1].shallow_clone()))
    }
}

/// Returns the length of a rank-1 tensor.
///
/// `label` names the tensor in the error raised for any other rank.
fn first_dim(tensor: &Tensor, label: &str) -> Result<i64, LoftrError> {
    let dims = tensor.size();
    match dims.as_slice() {
        [length] => Ok(*length),
        _ => Err(LoftrError::InvalidConfig(format!(
            "FinePreprocessData `{label}` must be rank-1; got {dims:?}"
        ))),
    }
}

/// Checks that `tensor` is rank-4 and that its second dimension equals
/// `expected_channels`.
///
/// The rank is checked first, so a wrong rank is reported even when the
/// channel count would also be wrong. `label` names the tensor in the error.
fn validate_fine_map(
    tensor: &Tensor,
    label: &str,
    expected_channels: i64,
) -> Result<(), LoftrError> {
    let dims = tensor.size();
    let message = match dims.as_slice() {
        [_, channels, _, _] if *channels == expected_channels => return Ok(()),
        [_, channels, _, _] => format!(
            "FinePreprocess `{label}` expected {} channels; got {}",
            expected_channels, channels
        ),
        _ => format!("FinePreprocess `{label}` expects [N,C,H,W]; got {dims:?}"),
    };
    Err(LoftrError::InvalidConfig(message))
}

/// Checks that `tensor` is rank-3; `label` names the tensor in the error.
fn validate_coarse_sequence(tensor: &Tensor, label: &str) -> Result<(), LoftrError> {
    let dims = tensor.size();
    if dims.len() == 3 {
        return Ok(());
    }
    Err(LoftrError::InvalidConfig(format!(
        "FinePreprocess `{label}` expects [N,L,C]; got {dims:?}"
    )))
}

/// Unfolds `feat` into square local windows and returns them as
/// `[N, L, window_size^2, C]`.
///
/// `im2col` is called with kernel `window_size`, dilation 1, padding
/// `window_size / 2` and the given `stride` on both spatial axes. Its rank-3
/// output is then regrouped as `[N, C, window_area, L]` and permuted so the
/// window location comes second and the channel last.
///
/// # Errors
///
/// Returns `LoftrError::InvalidConfig` when the `im2col` output is not rank-3,
/// when its second dimension is not a multiple of the window area, or when
/// the implied channel count differs from the input's second dimension.
fn unfold_local_windows(
    feat: &Tensor,
    window_size: i64,
    stride: i64,
) -> Result<Tensor, LoftrError> {
    let input_dims = feat.size();
    let half_window = window_size / 2;
    let unfolded = feat.im2col(
        [window_size, window_size],
        [1, 1],
        [half_window, half_window],
        [stride, stride],
    );

    let unfolded_dims = unfolded.size();
    let (flat_channels, location_count) = match unfolded_dims.as_slice() {
        &[_, flat, locations] => (flat, locations),
        _ => {
            return Err(LoftrError::InvalidConfig(format!(
                "FinePreprocess im2col expected [N,C*W*W,L]; got {unfolded_dims:?}"
            )));
        }
    };

    let window_area = window_size * window_size;
    if flat_channels % window_area != 0 {
        return Err(LoftrError::InvalidConfig(format!(
            "FinePreprocess im2col channel area mismatch: {unfolded_dims:?} with window_area={window_area}"
        )));
    }

    let channels = flat_channels / window_area;
    let expected_channels = input_dims[1];
    if channels != expected_channels {
        return Err(LoftrError::InvalidConfig(format!(
            "FinePreprocess im2col changed channel count unexpectedly: expected {}, got {}",
            expected_channels, channels
        )));
    }

    // [N, C*WW, L] -> [N, C, WW, L] -> [N, L, WW, C]
    let grouped = unfolded.reshape([input_dims[0], channels, window_area, location_count]);
    Ok(grouped.permute([0, 3, 2, 1]))
}

/// Gathers one window per match from `windows` (`[N, L, WW, C]`).
///
/// Each match is addressed by the flat row `b_ids[k] * L + token_ids[k]` of
/// the `[N * L, WW, C]` view, so the result has shape `[M, WW, C]`.
///
/// NOTE(review): a `token_ids` value `>= L` that still yields a row below
/// `N * L` is not detected and reads a window from a later batch element.
///
/// # Errors
///
/// Returns `LoftrError::InvalidConfig` when `windows` is not rank-4 or an
/// index tensor is not rank-1, and propagates the backend error (instead of
/// panicking) when the gather itself fails, e.g. on an out-of-range row.
fn select_unfold_windows(
    windows: &Tensor,
    b_ids: &Tensor,
    token_ids: &Tensor,
) -> Result<Tensor, LoftrError> {
    let dims = windows.size();
    let (batch, tokens, window_area, channels) = match dims.as_slice() {
        &[n, l, ww, c] => (n, l, ww, c),
        _ => {
            return Err(LoftrError::InvalidConfig(format!(
                "FinePreprocess windows expect [N,L,WW,C]; got {dims:?}"
            )));
        }
    };
    let batch_offsets = normalize_index_tensor(b_ids, windows.device())? * tokens;
    let token_ids = normalize_index_tensor(token_ids, windows.device())?;
    let linear_ids = batch_offsets + token_ids;
    // Fallible gather: caller-supplied indices must not be able to panic here.
    Ok(windows
        .reshape([batch * tokens, window_area, channels])
        .f_index_select(0, &linear_ids)?)
}

/// Gathers one token per match from `sequence` (`[N, L, C]`).
///
/// Each match is addressed by the flat row `b_ids[k] * L + token_ids[k]` of
/// the `[N * L, C]` view, so the result has shape `[M, C]`.
///
/// NOTE(review): a `token_ids` value `>= L` that still yields a row below
/// `N * L` is not detected and reads a token from a later batch element.
///
/// # Errors
///
/// Returns `LoftrError::InvalidConfig` when `sequence` is not rank-3 or an
/// index tensor is not rank-1, and propagates the backend error (instead of
/// panicking) when the gather itself fails, e.g. on an out-of-range row.
fn select_sequence_tokens(
    sequence: &Tensor,
    b_ids: &Tensor,
    token_ids: &Tensor,
) -> Result<Tensor, LoftrError> {
    let dims = sequence.size();
    let (batch, tokens, channels) = match dims.as_slice() {
        &[n, l, c] => (n, l, c),
        _ => {
            return Err(LoftrError::InvalidConfig(format!(
                "FinePreprocess coarse sequence expects [N,L,C]; got {dims:?}"
            )));
        }
    };
    let batch_offsets = normalize_index_tensor(b_ids, sequence.device())? * tokens;
    let token_ids = normalize_index_tensor(token_ids, sequence.device())?;
    let linear_ids = batch_offsets + token_ids;
    // Fallible gather: caller-supplied indices must not be able to panic here.
    Ok(sequence
        .reshape([batch * tokens, channels])
        .f_index_select(0, &linear_ids)?)
}

/// Returns `indexes` moved to `device` and converted to `Kind::Int64`.
///
/// # Errors
///
/// Returns `LoftrError::InvalidConfig` when `indexes` is not rank-1, and
/// propagates any error from the device transfer or dtype conversion.
fn normalize_index_tensor(indexes: &Tensor, device: Device) -> Result<Tensor, LoftrError> {
    let dims = indexes.size();
    if dims.len() == 1 {
        let on_device = indexes.f_to_device(device)?;
        let as_int64 = on_device.f_to_kind(Kind::Int64)?;
        return Ok(as_int64);
    }
    Err(LoftrError::InvalidConfig(format!(
        "FinePreprocess indexes must be rank-1; got {dims:?}"
    )))
}

#[cfg(test)]
mod tests;