// loftr 0.1.1
//
// Native Rust/tch implementation of LoFTR feature matching.
// Documentation
use tch::{
    Tensor,
    nn::{self},
};

use crate::{
    error::LoftrError,
    loftr_config::{LoftrConfig, ResNetFpnConfig},
};

/// Creates a bias-free 1x1 convolution mapping `in_planes` channels to `out_planes`
/// channels with the given `stride` and no padding. Parameters are registered under `vs`.
fn conv1x1(vs: &nn::Path<'_>, in_planes: i64, out_planes: i64, stride: i64) -> nn::Conv2D {
    const KERNEL_SIZE: i64 = 1;
    let config = nn::ConvConfig {
        stride,
        padding: 0,
        bias: false,
        ..Default::default()
    };
    nn::conv2d(vs, in_planes, out_planes, KERNEL_SIZE, config)
}

/// Creates a bias-free 3x3 convolution mapping `in_planes` channels to `out_planes`
/// channels with the given `stride` and a padding of 1. Parameters are registered under `vs`.
fn conv3x3(vs: &nn::Path<'_>, in_planes: i64, out_planes: i64, stride: i64) -> nn::Conv2D {
    const KERNEL_SIZE: i64 = 3;
    let config = nn::ConvConfig {
        stride,
        padding: 1,
        bias: false,
        ..Default::default()
    };
    nn::conv2d(vs, in_planes, out_planes, KERNEL_SIZE, config)
}

/// Residual block made of two 3x3 convolutions, each followed by batch normalization,
/// with an optional projection on the shortcut branch.
#[derive(Debug)]
struct BasicBlock {
    /// First 3x3 convolution; carries the block's stride.
    conv1: nn::Conv2D,
    /// Second 3x3 convolution; always stride 1.
    conv2: nn::Conv2D,
    /// Batch normalization applied after `conv1`.
    bn1: nn::BatchNorm,
    /// Batch normalization applied after `conv2`.
    bn2: nn::BatchNorm,
    /// Shortcut projection (1x1 convolution + batch norm); present only when the
    /// block was built with a stride other than 1.
    downsample: Option<(nn::Conv2D, nn::BatchNorm)>,
}

impl BasicBlock {
    /// Builds a block mapping `in_planes` channels to `planes` channels.
    ///
    /// The shortcut projection is created only when `stride != 1`; its layers live
    /// under the `downsample` / `0` and `downsample` / `1` sub-paths of `vs`.
    // NOTE(review): with `stride == 1` no projection exists, so the residual addition
    // presumably relies on `in_planes == planes` — confirm against callers.
    fn new(vs: &nn::Path<'_>, in_planes: i64, planes: i64, stride: i64) -> Self {
        // The shortcut layers are created before the main-branch layers; keep this
        // order so parameters are registered in the same sequence as before.
        let downsample = (stride != 1).then(|| {
            let projection = conv1x1(&(vs / "downsample" / "0"), in_planes, planes, stride);
            let projection_bn = nn::batch_norm2d(
                &(vs / "downsample" / "1"),
                planes,
                nn::BatchNormConfig::default(),
            );
            (projection, projection_bn)
        });

        let conv1 = conv3x3(&(vs / "conv1"), in_planes, planes, stride);
        let conv2 = conv3x3(&(vs / "conv2"), planes, planes, 1);
        let bn1 = nn::batch_norm2d(&(vs / "bn1"), planes, nn::BatchNormConfig::default());
        let bn2 = nn::batch_norm2d(&(vs / "bn2"), planes, nn::BatchNormConfig::default());

        Self {
            conv1,
            conv2,
            bn1,
            bn2,
            downsample,
        }
    }

    /// Runs the block on `x` and returns `relu(shortcut(x) + main(x))`.
    ///
    /// `train` is forwarded to every batch-norm layer.
    fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
        // Main branch: conv -> bn -> relu -> conv -> bn.
        let main = x.apply(&self.conv1).apply_t(&self.bn1, train).relu();
        let main = main.apply(&self.conv2).apply_t(&self.bn2, train);

        // Shortcut branch: projected when a downsample pair exists, identity otherwise.
        let shortcut = if let Some((projection, projection_bn)) = &self.downsample {
            x.apply(projection).apply_t(projection_bn, train)
        } else {
            x.shallow_clone()
        };

        (shortcut + main).relu()
    }
}

/// A stage of the backbone: two `BasicBlock`s applied one after the other.
#[derive(Debug)]
struct ResNetLayer {
    /// First block; changes the channel count and carries the stage's stride.
    block1: BasicBlock,
    /// Second block; keeps the channel count and uses stride 1.
    block2: BasicBlock,
}

impl ResNetLayer {
    /// Builds the stage under the `0` and `1` sub-paths of `vs`.
    ///
    /// Returns the stage together with its output channel count (`dim`), so the
    /// caller can feed that value to the next stage.
    fn new(vs: &nn::Path<'_>, in_planes: i64, dim: i64, stride: i64) -> (Self, i64) {
        let layer = Self {
            block1: BasicBlock::new(&(vs / "0"), in_planes, dim, stride),
            block2: BasicBlock::new(&(vs / "1"), dim, dim, 1),
        };
        (layer, dim)
    }

    /// Applies `block1` then `block2` to `x`; `train` is passed through to both blocks.
    fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
        self.block2
            .forward_t(&self.block1.forward_t(x, train), train)
    }
}

/// Output head used after merging pyramid levels:
/// 3x3 conv -> batch norm -> leaky ReLU -> 3x3 conv.
#[derive(Debug)]
struct FpnHead {
    /// First 3x3 convolution (`in_planes` -> `hidden_planes`).
    conv1: nn::Conv2D,
    /// Batch normalization over the hidden channels.
    bn: nn::BatchNorm,
    /// Second 3x3 convolution (`hidden_planes` -> `out_planes`).
    conv2: nn::Conv2D,
}

impl FpnHead {
    /// Builds the head; its layers are registered under the `0`, `1` and `3`
    /// sub-paths of `vs` (no parameters are created under `2`).
    fn new(vs: &nn::Path<'_>, in_planes: i64, hidden_planes: i64, out_planes: i64) -> Self {
        let conv1 = conv3x3(&(vs / "0"), in_planes, hidden_planes, 1);
        let bn = nn::batch_norm2d(&(vs / "1"), hidden_planes, nn::BatchNormConfig::default());
        let conv2 = conv3x3(&(vs / "3"), hidden_planes, out_planes, 1);
        Self { conv1, bn, conv2 }
    }

    /// Runs the head on `x`; `train` is forwarded to the batch-norm layer.
    fn forward_t(&self, x: &Tensor, train: bool) -> Tensor {
        let hidden = x.apply(&self.conv1).apply_t(&self.bn, train);
        // Leaky ReLU with the library's default negative slope.
        let activated = hidden.leaky_relu();
        activated.apply(&self.conv2)
    }
}

/// ResNet-style feature pyramid backbone producing a coarse and a fine feature map.
///
/// The stem and the last two stages each use stride 2, so the coarse output is taken
/// after three stride-2 steps and the fine output after one.
#[derive(Debug)]
pub struct ResNetFpn8_2 {
    /// Stem: 7x7, stride-2 convolution on the single-channel input.
    conv1: nn::Conv2D,
    /// Batch normalization following the stem convolution.
    bn1: nn::BatchNorm,
    /// First stage (stride 1).
    layer1: ResNetLayer,
    /// Second stage (stride 2).
    layer2: ResNetLayer,
    /// Third stage (stride 2).
    layer3: ResNetLayer,
    /// 1x1 convolution producing the coarse output from the third stage.
    layer3_outconv: nn::Conv2D,
    /// 1x1 lateral convolution on the second stage.
    layer2_outconv: nn::Conv2D,
    /// Head applied after merging the second-stage lateral with the upsampled coarse map.
    layer2_outconv2: FpnHead,
    /// 1x1 lateral convolution on the first stage.
    layer1_outconv: nn::Conv2D,
    /// Head applied after merging the first-stage lateral with the upsampled second level.
    layer1_outconv2: FpnHead,
}

impl ResNetFpn8_2 {
    /// Builds the backbone under `vs` from `config.initial_dim` and `config.block_dims`.
    ///
    /// # Errors
    ///
    /// Returns `LoftrError::InvalidConfig` when `initial_dim` or any entry of
    /// `block_dims` is not strictly positive.
    // NOTE(review): `layer1` is built with stride 1, so its first block has no shortcut
    // projection; this presumably requires `initial_dim == block_dims[0]` — confirm.
    pub fn new(vs: &nn::Path<'_>, config: &ResNetFpnConfig) -> Result<Self, LoftrError> {
        let initial_dim = config.initial_dim;
        let block_dims = config.block_dims;

        let dims_are_positive = initial_dim > 0 && block_dims.iter().all(|&dim| dim > 0);
        if !dims_are_positive {
            return Err(LoftrError::InvalidConfig(format!(
                "ResNetFpn8_2 requires positive dimensions; got initial_dim={initial_dim}, block_dims={block_dims:?}"
            )));
        }
        let (dim1, dim2, dim3) = (block_dims[0], block_dims[1], block_dims[2]);

        // Layers are created in a fixed order (stem, stages, then FPN convolutions) so
        // that parameters are registered in the same sequence on every construction.
        let stem_config = nn::ConvConfig {
            stride: 2,
            padding: 3,
            bias: false,
            ..Default::default()
        };
        let conv1 = nn::conv2d(&(vs / "conv1"), 1, initial_dim, 7, stem_config);
        let bn1 = nn::batch_norm2d(&(vs / "bn1"), initial_dim, nn::BatchNormConfig::default());

        // Each stage reports its output width, which becomes the next stage's input width.
        let (layer1, planes) = ResNetLayer::new(&(vs / "layer1"), initial_dim, dim1, 1);
        let (layer2, planes) = ResNetLayer::new(&(vs / "layer2"), planes, dim2, 2);
        let (layer3, _) = ResNetLayer::new(&(vs / "layer3"), planes, dim3, 2);

        let layer3_outconv = conv1x1(&(vs / "layer3_outconv"), dim3, dim3, 1);
        let layer2_outconv = conv1x1(&(vs / "layer2_outconv"), dim2, dim3, 1);
        let layer2_outconv2 = FpnHead::new(&(vs / "layer2_outconv2"), dim3, dim3, dim2);
        let layer1_outconv = conv1x1(&(vs / "layer1_outconv"), dim1, dim2, 1);
        let layer1_outconv2 = FpnHead::new(&(vs / "layer1_outconv2"), dim2, dim2, dim1);

        Ok(Self {
            conv1,
            bn1,
            layer1,
            layer2,
            layer3,
            layer3_outconv,
            layer2_outconv,
            layer2_outconv2,
            layer1_outconv,
            layer1_outconv2,
        })
    }

    /// Bilinearly resizes `source` (with `align_corners = true`) to the height and
    /// width of `reference`, read from dimensions 2 and 3 of its size.
    fn upsample_to_match(source: &Tensor, reference: &Tensor) -> Tensor {
        let target = reference.size();
        source.upsample_bilinear2d([target[2], target[3]], true, None, None)
    }

    /// Runs the backbone on `x` and returns `(coarse, fine)` feature maps.
    ///
    /// `train` is forwarded to every batch-norm layer.
    ///
    /// # Errors
    ///
    /// Returns `LoftrError::InvalidInput` unless `x` is 4-dimensional with a size of 1
    /// along dimension 1.
    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<(Tensor, Tensor), LoftrError> {
        let dims = x.size();
        // `&&` short-circuits, so `dims[1]` is only read once the rank is known to be 4.
        let is_grayscale_batch = dims.len() == 4 && dims[1] == 1;
        if !is_grayscale_batch {
            return Err(LoftrError::InvalidInput(format!(
                "ResNetFpn8_2 expects grayscale [N,1,H,W] input; got {dims:?}"
            )));
        }

        // Bottom-up pass through the stem and the three stages.
        let stem = x.apply(&self.conv1).apply_t(&self.bn1, train).relu();
        let stage1 = self.layer1.forward_t(&stem, train);
        let stage2 = self.layer2.forward_t(&stage1, train);
        let stage3 = self.layer3.forward_t(&stage2, train);

        // Top-down pass: upsample the deeper level to the lateral's spatial size, add
        // the two, then refine the sum with the level's head.
        let coarse = stage3.apply(&self.layer3_outconv);
        let lateral2 = stage2.apply(&self.layer2_outconv);
        let coarse_up = Self::upsample_to_match(&coarse, &lateral2);
        let level2 = self
            .layer2_outconv2
            .forward_t(&(lateral2 + coarse_up), train);

        let lateral1 = stage1.apply(&self.layer1_outconv);
        let level2_up = Self::upsample_to_match(&level2, &lateral1);
        let fine = self
            .layer1_outconv2
            .forward_t(&(lateral1 + level2_up), train);

        Ok((coarse, fine))
    }
}

/// The set of backbone architectures this module can build.
#[derive(Debug)]
pub enum Backbone {
    /// ResNet-FPN backbone returning a coarse and a fine feature map.
    ResNetFpn8_2(ResNetFpn8_2),
}

impl Backbone {
    /// Runs the wrapped backbone on `x`, returning the pair of feature maps it
    /// produces or the error it reports. `train` is passed through unchanged.
    pub fn forward_t(&self, x: &Tensor, train: bool) -> Result<(Tensor, Tensor), LoftrError> {
        // Exhaustive match: adding a variant forces this dispatcher to be updated.
        match self {
            Self::ResNetFpn8_2(net) => net.forward_t(x, train),
        }
    }
}

/// Builds the backbone selected by `config`, registering its parameters under the
/// `backbone` sub-path of `vs`.
///
/// # Errors
///
/// Returns `LoftrError::InvalidConfig` when `config.resolution` is not `(8, 2)`, and
/// propagates any error from `ResNetFpn8_2::new`.
pub fn build_backbone(vs: &nn::Path<'_>, config: &LoftrConfig) -> Result<Backbone, LoftrError> {
    match config.backbone_type {
        crate::loftr_config::BackboneType::ResNetFpn => {
            // Reject unsupported resolutions before touching the variable store.
            let other = config.resolution;
            if !matches!(other, (8, 2)) {
                return Err(LoftrError::InvalidConfig(format!(
                    "Unsupported LoFTR backbone resolution {other:?}; only (8, 2) is ported so far"
                )));
            }

            let net = ResNetFpn8_2::new(&(vs / "backbone"), &config.resnetfpn)?;
            Ok(Backbone::ResNetFpn8_2(net))
        }
    }
}

// Unit tests are declared in a separate `tests` module and compiled only for test builds.
#[cfg(test)]
mod tests;