// loftr 0.1.1
//
// Native Rust/tch implementation of LoFTR feature matching.
use tch::{
    Tensor,
    nn::{self},
};

use crate::{
    error::LoftrError,
    linear_attention::{FullAttention, LinearAttention},
    loftr_config::{AttentionType, TransformerConfig, TransformerLayer},
};

/// Attention module used by a [`LoFTREncoderLayer`], chosen from an
/// [`AttentionType`] when the layer is constructed.
#[derive(Debug)]
enum AttentionKind {
    /// Wraps a [`LinearAttention`] module.
    Linear(LinearAttention),
    /// Wraps a [`FullAttention`] module.
    Full(FullAttention),
}

impl AttentionKind {
    fn forward(
        &self,
        queries: &Tensor,
        keys: &Tensor,
        values: &Tensor,
        q_mask: Option<&Tensor>,
        kv_mask: Option<&Tensor>,
    ) -> Result<Tensor, LoftrError> {
        match self {
            Self::Linear(attention) => attention.forward(queries, keys, values, q_mask, kv_mask),
            Self::Full(attention) => attention.forward(queries, keys, values, q_mask, kv_mask),
        }
    }
}

/// One LoFTR encoder layer: multi-head attention from a query sequence `x` to a
/// `source` sequence, followed by an MLP whose output is added back onto `x`.
#[derive(Debug)]
pub struct LoFTREncoderLayer {
    /// Per-head feature size, `d_model / nhead`.
    dim: i64,
    /// Number of attention heads.
    nhead: i64,
    /// Query projection, `d_model -> d_model`, without bias.
    q_proj: nn::Linear,
    /// Key projection, `d_model -> d_model`, without bias.
    k_proj: nn::Linear,
    /// Value projection, `d_model -> d_model`, without bias.
    v_proj: nn::Linear,
    /// Attention applied to the projected, head-split queries/keys/values.
    attention: AttentionKind,
    /// Projection applied to the re-merged attention heads, without bias.
    merge: nn::Linear,
    /// `Linear(2*d_model -> 2*d_model) -> ReLU -> Linear(2*d_model -> d_model)`,
    /// applied to `x` concatenated with the attention message.
    mlp: nn::Sequential,
    /// Layer norm over the merged attention message.
    norm1: nn::LayerNorm,
    /// Layer norm over the MLP output.
    norm2: nn::LayerNorm,
}

impl LoFTREncoderLayer {
    pub fn new(
        vs: &nn::Path<'_>,
        d_model: i64,
        nhead: i64,
        attention: AttentionType,
    ) -> Result<Self, LoftrError> {
        if d_model <= 0 || nhead <= 0 || d_model % nhead != 0 {
            return Err(LoftrError::InvalidConfig(format!(
                "LoFTREncoderLayer requires positive d_model and nhead, with d_model divisible by nhead; got d_model={d_model}, nhead={nhead}"
            )));
        }

        let attention = match attention {
            AttentionType::Linear => AttentionKind::Linear(LinearAttention::default()),
            AttentionType::Full => AttentionKind::Full(FullAttention::default()),
        };
        let linear_config = nn::LinearConfig {
            bias: false,
            ..Default::default()
        };
        let mlp = nn::seq()
            .add(nn::linear(
                vs / "mlp" / "0",
                d_model * 2,
                d_model * 2,
                linear_config,
            ))
            .add_fn(Tensor::relu)
            .add(nn::linear(
                vs / "mlp" / "2",
                d_model * 2,
                d_model,
                linear_config,
            ));
        let layer_norm_config = nn::LayerNormConfig::default();

        Ok(Self {
            dim: d_model / nhead,
            nhead,
            q_proj: nn::linear(vs / "q_proj", d_model, d_model, linear_config),
            k_proj: nn::linear(vs / "k_proj", d_model, d_model, linear_config),
            v_proj: nn::linear(vs / "v_proj", d_model, d_model, linear_config),
            attention,
            merge: nn::linear(vs / "merge", d_model, d_model, linear_config),
            mlp,
            norm1: nn::layer_norm(vs / "norm1", vec![d_model], layer_norm_config),
            norm2: nn::layer_norm(vs / "norm2", vec![d_model], layer_norm_config),
        })
    }

    pub fn forward(
        &self,
        x: &Tensor,
        source: &Tensor,
        x_mask: Option<&Tensor>,
        source_mask: Option<&Tensor>,
    ) -> Result<Tensor, LoftrError> {
        validate_sequence_tensor(x, "x")?;
        validate_sequence_tensor(source, "source")?;
        if x.size()[0] != source.size()[0] || x.size()[2] != source.size()[2] {
            return Err(LoftrError::InvalidConfig(format!(
                "LoFTREncoderLayer source mismatch: x={:?}, source={:?}",
                x.size(),
                source.size()
            )));
        }

        let batch_size = x.size()[0];
        let query = x
            .apply(&self.q_proj)
            .view([batch_size, -1, self.nhead, self.dim]);
        let key = source
            .apply(&self.k_proj)
            .view([batch_size, -1, self.nhead, self.dim]);
        let value = source
            .apply(&self.v_proj)
            .view([batch_size, -1, self.nhead, self.dim]);
        let message = self
            .attention
            .forward(&query, &key, &value, x_mask, source_mask)?
            .view([batch_size, -1, self.nhead * self.dim])
            .apply(&self.merge)
            .apply(&self.norm1);
        let message = Tensor::cat(&[x, &message], 2)
            .apply(&self.mlp)
            .apply(&self.norm2);
        Ok(x + message)
    }
}

/// Stack of [`LoFTREncoderLayer`]s applied jointly to a pair of feature sequences.
#[derive(Debug)]
pub struct LocalFeatureTransformer {
    /// Channel dimension both input sequences must have; checked in `forward`.
    d_model: i64,
    /// Kind of each layer (self- or cross-attention), parallel to `layers`.
    layer_types: Vec<TransformerLayer>,
    /// Encoder layers, one per entry of `layer_types`.
    layers: Vec<LoFTREncoderLayer>,
}

impl LocalFeatureTransformer {
    /// Creates one [`LoFTREncoderLayer`] per entry of `config.layers`.
    ///
    /// Layer `i` registers its parameters under `vs / "layers" / i`. Every layer
    /// uses `config.d_model`, `config.nhead` and `config.attention`. The layer
    /// kinds are retained so `forward` knows how to wire each layer.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`LoFTREncoderLayer::new`].
    pub fn new(vs: &nn::Path<'_>, config: &TransformerConfig) -> Result<Self, LoftrError> {
        let layers = (0..config.layers.len())
            .map(|position| {
                let layer_path = vs / "layers" / position.to_string();
                LoFTREncoderLayer::new(
                    &layer_path,
                    config.d_model,
                    config.nhead,
                    config.attention,
                )
            })
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self {
            d_model: config.d_model,
            layer_types: config.layers.clone(),
            layers,
        })
    }

    /// Runs every layer in order over the two feature sequences and returns the
    /// transformed pair.
    ///
    /// Both inputs must be rank-3 `[N, L, C]` with `C == d_model`. In a
    /// self-attention layer each sequence attends to itself. In a cross-attention
    /// layer `feat0` first attends to `feat1`, and then `feat1` attends to the
    /// *already updated* `feat0`.
    ///
    /// # Errors
    ///
    /// - [`LoftrError::InvalidInput`] if an input is not rank 3.
    /// - [`LoftrError::InvalidConfig`] if a channel dim differs from `d_model`.
    /// - Any error returned by an encoder layer.
    pub fn forward(
        &self,
        feat0: &Tensor,
        feat1: &Tensor,
        mask0: Option<&Tensor>,
        mask1: Option<&Tensor>,
    ) -> Result<(Tensor, Tensor), LoftrError> {
        validate_sequence_tensor(feat0, "feat0")?;
        validate_sequence_tensor(feat1, "feat1")?;
        let (dims0, dims1) = (feat0.size(), feat1.size());
        let channels_match = dims0[2] == self.d_model && dims1[2] == self.d_model;
        if !channels_match {
            return Err(LoftrError::InvalidConfig(format!(
                "LocalFeatureTransformer expected feature dim {}, got feat0={:?}, feat1={:?}",
                self.d_model, dims0, dims1
            )));
        }

        let mut current0 = feat0.shallow_clone();
        let mut current1 = feat1.shallow_clone();
        for (layer, kind) in self.layers.iter().zip(&self.layer_types) {
            match kind {
                TransformerLayer::SelfAttention => {
                    current0 = layer.forward(&current0, &current0, mask0, mask0)?;
                    current1 = layer.forward(&current1, &current1, mask1, mask1)?;
                }
                TransformerLayer::CrossAttention => {
                    // Sequential on purpose: the second call sees the updated `current0`.
                    current0 = layer.forward(&current0, &current1, mask0, mask1)?;
                    current1 = layer.forward(&current1, &current0, mask1, mask0)?;
                }
            }
        }

        Ok((current0, current1))
    }
}

/// Checks that `tensor` is rank 3, i.e. laid out as a `[N, L, C]` sequence.
///
/// # Errors
///
/// Returns [`LoftrError::InvalidInput`] naming `label` and the actual shape when
/// the rank is not 3.
fn validate_sequence_tensor(tensor: &Tensor, label: &str) -> Result<(), LoftrError> {
    let shape = tensor.size();
    match shape.as_slice() {
        [_, _, _] => Ok(()),
        dims => Err(LoftrError::InvalidInput(format!(
            "LocalFeatureTransformer `{label}` expects [N,L,C]; got {dims:?}"
        ))),
    }
}

#[cfg(test)]
mod tests;