loftr 0.1.1

Native Rust/tch implementation of LoFTR feature matching
Documentation
use tch::{Device, Kind, Tensor};

use crate::{error::LoftrError, numeric::i64_to_f64};

/// Fixed 2-D sinusoidal positional encoding that is added to a feature map.
///
/// The encoding table is cached in `pe` with shape `[1, d_model, H_max, W_max]`
/// and is rebuilt on demand by [`PositionEncodingSine::forward`] when an input
/// exceeds the cached spatial extent.
#[derive(Debug)]
pub struct PositionEncodingSine {
    /// Channel count of the encoding; inputs to `forward` must match it exactly.
    d_model: i64,
    /// Selects which exponent formula is used when (re)building the table.
    temp_bug_fix: bool,
    /// Cached encoding table, `[1, d_model, H_max, W_max]`.
    pe: Tensor,
}

impl PositionEncodingSine {
    /// Builds an encoder whose table covers `max_shape = (height, width)` on `device`.
    ///
    /// # Errors
    /// Propagates any error from building the table, e.g.
    /// [`LoftrError::InvalidConfig`] when `d_model` is not a positive multiple of 4
    /// or either dimension of `max_shape` is not positive.
    pub fn new(
        d_model: i64,
        max_shape: (i64, i64),
        temp_bug_fix: bool,
        device: Device,
    ) -> Result<Self, LoftrError> {
        create_position_encoding(d_model, max_shape, temp_bug_fix, device).map(|pe| Self {
            d_model,
            temp_bug_fix,
            pe,
        })
    }

    /// Rebuilds the cached table so it covers `max_shape = (height, width)`.
    ///
    /// `d_model`, the exponent variant and the device of the current table are
    /// reused. If rebuilding fails, the previous table is left in place.
    ///
    /// # Errors
    /// Propagates any error from building the new table.
    pub fn update_position_encoding_size(
        &mut self,
        max_shape: (i64, i64),
    ) -> Result<(), LoftrError> {
        let device = self.pe.device();
        let rebuilt =
            create_position_encoding(self.d_model, max_shape, self.temp_bug_fix, device)?;
        self.pe = rebuilt;
        Ok(())
    }

    /// Returns `x` plus the positional encoding cropped to `x`'s spatial size.
    ///
    /// `x` must be `[N, d_model, H, W]`. If `H` or `W` exceeds the cached table,
    /// the table is first rebuilt to the larger of the cached and requested
    /// extents in each dimension.
    ///
    /// # Errors
    /// Returns [`LoftrError::InvalidInput`] if `x` is not 4-D or its channel count
    /// differs from `d_model`; propagates any error from rebuilding the table.
    pub fn forward(&mut self, x: &Tensor) -> Result<Tensor, LoftrError> {
        let dims = x.size();
        let (channels, height, width) = match dims.as_slice() {
            &[_, c, h, w] => (c, h, w),
            _ => {
                return Err(LoftrError::InvalidInput(format!(
                    "PositionEncodingSine expects [N,C,H,W]; got {dims:?}"
                )))
            }
        };
        if channels != self.d_model {
            return Err(LoftrError::InvalidInput(format!(
                "PositionEncodingSine d_model mismatch: expected {}, got {}",
                self.d_model, channels
            )));
        }

        // Read the cached table's extent once; the table is always 4-D.
        let cached = self.pe.size();
        let (cached_h, cached_w) = (cached[2], cached[3]);
        if height > cached_h || width > cached_w {
            self.update_position_encoding_size((height.max(cached_h), width.max(cached_w)))?;
        }

        // Crop to the input's spatial extent; the size-1 leading axis broadcasts over N.
        let window = self.pe.slice(2, 0, height, 1).slice(3, 0, width, 1);
        Ok(x + window)
    }
}

/// Builds the sinusoidal positional-encoding table used by `PositionEncodingSine`.
///
/// The result has shape `[1, d_model, max_shape.0, max_shape.1]`, dtype `Float`,
/// and lives on `device`. For `k` in `0..d_model / 4`:
/// * channel `4k` holds `sin(x * w_k)`;
/// * channel `4k + 1` holds `cos(x * w_k)`;
/// * channel `4k + 2` holds `sin(y * w_k)`;
/// * channel `4k + 3` holds `cos(y * w_k)`.
///
/// Here `x` and `y` are 1-based column and row indices, and `w_k = exp(2k * base)`.
///
/// `base` depends on `temp_bug_fix`:
/// * `true`: `-ln(10000) / (d_model / 2)`, with integer division of `d_model`.
/// * `false`: `floor(-ln(10000) / d_model / 2)`. The floor makes `base` a negative
///   integer (`-1.0` whenever `d_model >= 8`). NOTE(review): presumably kept only
///   for compatibility with weights built with this legacy formula — confirm.
///
/// # Errors
/// Returns [`LoftrError::InvalidConfig`] if `d_model` is not a positive multiple of 4
/// or either dimension of `max_shape` is not positive. Also propagates errors from
/// `i64_to_f64`.
fn create_position_encoding(
    d_model: i64,
    max_shape: (i64, i64),
    temp_bug_fix: bool,
    device: Device,
) -> Result<Tensor, LoftrError> {
    if d_model <= 0 || d_model % 4 != 0 {
        return Err(LoftrError::InvalidConfig(format!(
            "PositionEncodingSine requires d_model > 0 and divisible by 4; got {d_model}"
        )));
    }
    if max_shape.0 <= 0 || max_shape.1 <= 0 {
        return Err(LoftrError::InvalidConfig(format!(
            "PositionEncodingSine requires positive max_shape; got {max_shape:?}"
        )));
    }

    let options = (Kind::Float, device);
    let (height, width) = max_shape;
    let pe = Tensor::zeros([d_model, height, width], options);
    // 1-based row / column indices, each shaped [1, H, W] so they broadcast
    // against the per-frequency [d_model / 4, 1, 1] term below.
    let y_position = Tensor::ones([height, width], options)
        .cumsum(0, Kind::Float)
        .unsqueeze(0);
    let x_position = Tensor::ones([height, width], options)
        .cumsum(1, Kind::Float)
        .unsqueeze(0);
    let div_base = if temp_bug_fix {
        -(10000.0_f64.ln()) / i64_to_f64(d_model / 2, "position encoding half d_model")?
    } else {
        // Legacy formula: the floor is intentional (see the doc comment above).
        (-(10000.0_f64.ln()) / i64_to_f64(d_model, "position encoding d_model")? / 2.0).floor()
    };
    // arange(0, d_model / 2, 2) yields d_model / 4 exponents, one per channel group.
    let div_term = (Tensor::arange_start_step(0, d_model / 2, 2, options) * div_base)
        .exp()
        .unsqueeze(1)
        .unsqueeze(2);

    // Each phase product is computed once and shared by its sin/cos pair. The x pair
    // is scoped so its [d_model / 4, H, W] temporary is released before the y pair.
    {
        let x_phase = &x_position * &div_term;
        pe.slice(0, 0, d_model, 4).copy_(&x_phase.sin());
        pe.slice(0, 1, d_model, 4).copy_(&x_phase.cos());
    }
    let y_phase = &y_position * &div_term;
    pe.slice(0, 2, d_model, 4).copy_(&y_phase.sin());
    pe.slice(0, 3, d_model, 4).copy_(&y_phase.cos());
    Ok(pe.unsqueeze(0))
}

// Unit tests for this module, declared out-of-line and compiled only for test builds.
#[cfg(test)]
mod tests;