use tch::{Kind, Tensor};
use crate::{
error::LoftrError,
loftr_config::{MatchCoarseConfig, MatchType},
numeric::{i64_pair_ratio_to_f64, i64_to_f64},
};
/// Magnitude used (negated) to blank out invalid similarity entries before the softmax.
///
/// It is deliberately finite rather than `f64::INFINITY`: a row or column that is masked
/// entirely still softmaxes to finite values instead of `NaN`.
const INF: f64 = 1.0e9;
/// Grid geometry and optional per-batch scales consumed by `CoarseMatching::forward`.
#[derive(Debug)]
pub struct CoarseMatchingData {
    // `(h, w)` of the first image at (presumably) input resolution. Only `.0` is read here:
    // `hw0_i.0 / hw0_c.0` is the factor that maps coarse cells of image 0 back to this size.
    pub hw0_i: (i64, i64),
    // Same as `hw0_i`, for the second image (scales the image-1 keypoints).
    pub hw1_i: (i64, i64),
    // Coarse grid `(h, w)` of the first image; `h * w` must equal the `L` dimension of
    // `feat_c0`, and `w` is the row stride used to turn a flat cell index into `(x, y)`.
    pub hw0_c: (i64, i64),
    // Coarse grid `(h, w)` of the second image; `h * w` must equal the `S` dimension of
    // `feat_c1`.
    pub hw1_c: (i64, i64),
    // Optional per-batch scale for image 0: gathered with the match batch ids and multiplied
    // into the `hw0_i.0 / hw0_c.0` factor (see `scaled_factors` for the accepted layout).
    pub scale0: Option<Tensor>,
    // Optional per-batch scale for image 1, used the same way as `scale0`.
    pub scale1: Option<Tensor>,
}
/// Result of `CoarseMatching::forward`.
#[derive(Debug)]
pub struct CoarseMatchingOutput {
    // Dual-softmax confidence, shape `[N, L, S]` (from `einsum("nlc,nsc->nls")`).
    pub conf_matrix: Tensor,
    // Batch index of every accepted match (above threshold, off-border, mutual max).
    pub b_ids: Tensor,
    // Flat coarse-cell index in image 0 of every accepted match.
    pub i_ids: Tensor,
    // Flat coarse-cell index in image 1 of every accepted match.
    pub j_ids: Tensor,
    // `b_ids` restricted to matches whose confidence is non-zero.
    pub m_bids: Tensor,
    // `[M, 2]` float `(x, y)` coarse keypoints in image 0, scaled and restricted to
    // non-zero-confidence matches.
    pub mkpts0_c: Tensor,
    // `[M, 2]` float `(x, y)` coarse keypoints in image 1, filtered the same way.
    pub mkpts1_c: Tensor,
    // Confidence of each kept (non-zero) match, aligned with `m_bids` / `mkpts*_c`.
    pub mconf: Tensor,
}
/// Coarse-level matcher built from a `MatchCoarseConfig`.
///
/// Handles `MatchType::DualSoftmax`; `forward` reads `dsmax_temperature`, `thr` and
/// `border_rm` from the stored configuration.
#[derive(Debug, Clone)]
pub struct CoarseMatching {
    // Owned copy of the configuration passed to `CoarseMatching::new`.
    config: MatchCoarseConfig,
}
impl CoarseMatching {
    /// Creates a coarse matcher that keeps its own copy of `config`.
    ///
    /// The exhaustive `match` on `config.match_type` makes the compiler flag this
    /// constructor whenever a new [`MatchType`] variant is added.
    pub fn new(config: &MatchCoarseConfig) -> Self {
        match config.match_type {
            MatchType::DualSoftmax => Self {
                config: config.clone(),
            },
        }
    }

    /// Matches two flattened coarse feature maps with dual softmax and mutual nearest
    /// neighbours.
    ///
    /// Steps:
    /// 1. Both feature tensors are divided by `sqrt(C)`; their similarity
    ///    `einsum("nlc,nsc->nls")` is divided by `dsmax_temperature`.
    /// 2. Every pair whose left cell is invalid in `mask_c0` or whose right cell is invalid
    ///    in `mask_c1` is filled with `-INF`. Either mask may be supplied on its own; a
    ///    missing mask means "all cells valid".
    /// 3. `conf_matrix = softmax(sim, dim 1) * softmax(sim, dim 2)`.
    /// 4. A pair is a match when its confidence is `> thr`, it lies outside the `border_rm`
    ///    border of both coarse grids, and it is the maximum of both its row and column.
    /// 5. Matched cells become `(x, y)` coarse keypoints, scaled by `scaled_factors`.
    ///
    /// `b_ids` / `i_ids` / `j_ids` list every match; `m_bids`, `mkpts*_c` and `mconf` keep
    /// only matches with non-zero confidence.
    ///
    /// # Arguments
    /// * `feat_c0`, `feat_c1` - `[N, L, C]` and `[N, S, C]` features (shapes are validated).
    /// * `data` - grid geometry; `hw0_c` must cover `L` cells and `hw1_c` must cover `S`.
    /// * `mask_c0`, `mask_c1` - optional `[N, L]` / `[N, S]` masks; non-zero means valid.
    ///
    /// # Errors
    /// Returns an error when input validation fails, a device/dtype conversion fails, or
    /// the channel size / scale ratios cannot be converted to `f64`.
    pub fn forward(
        &self,
        feat_c0: &Tensor,
        feat_c1: &Tensor,
        data: &CoarseMatchingData,
        mask_c0: Option<&Tensor>,
        mask_c1: Option<&Tensor>,
    ) -> Result<CoarseMatchingOutput, LoftrError> {
        validate_coarse_matching_inputs(feat_c0, feat_c1, data, mask_c0, mask_c1)?;
        let feat_c0 = feat_c0 / i64_to_f64(feat_c0.size()[2], "feat_c0 channel dim")?.sqrt();
        let feat_c1 = feat_c1 / i64_to_f64(feat_c1.size()[2], "feat_c1 channel dim")?.sqrt();
        let sim_matrix = match self.config.match_type {
            MatchType::DualSoftmax => {
                Tensor::einsum("nlc,nsc->nls", &[&feat_c0, &feat_c1], None::<&[i64]>)
                    / self.config.dsmax_temperature
            }
        };
        let sim_matrix = mask_invalid_pairs(sim_matrix, mask_c0, mask_c1)?;
        let conf_matrix = sim_matrix.softmax(1, Kind::Float) * sim_matrix.softmax(2, Kind::Float);
        let threshold_mask =
            confidence_threshold_mask(&conf_matrix, data, self.config.thr, self.config.border_rm);
        // Mutual nearest neighbour: the entry must be the maximum of its row AND its column.
        let mutual_mask = conf_matrix
            .f_eq_tensor(&conf_matrix.max_dim(2, true).0)?
            .logical_and(&conf_matrix.f_eq_tensor(&conf_matrix.max_dim(1, true).0)?);
        let matches = threshold_mask.logical_and(&mutual_mask);
        let match_indices = matches.nonzero();
        let (b_ids, i_ids, j_ids, mconf) = if match_indices.size()[0] == 0 {
            empty_match_outputs(&conf_matrix)
        } else {
            let b_ids = match_indices.select(1, 0);
            let i_ids = match_indices.select(1, 1);
            let j_ids = match_indices.select(1, 2);
            // Gather confidences through a flat row-major index into `[N, L, S]`.
            let linear_ids = &b_ids * (conf_matrix.size()[1] * conf_matrix.size()[2])
                + &i_ids * conf_matrix.size()[2]
                + &j_ids;
            let mconf = conf_matrix.reshape([-1]).index_select(0, &linear_ids);
            (b_ids, i_ids, j_ids, mconf)
        };
        let pred_mask = mconf.ne(0.0);
        let m_bids = b_ids.masked_select(&pred_mask);
        let mkpts0_c = coarse_points(
            &i_ids,
            data.hw0_c.1,
            scaled_factors(data, &b_ids, true, conf_matrix.device())?,
        );
        let mkpts1_c = coarse_points(
            &j_ids,
            data.hw1_c.1,
            scaled_factors(data, &b_ids, false, conf_matrix.device())?,
        );
        Ok(CoarseMatchingOutput {
            conf_matrix,
            b_ids,
            i_ids,
            j_ids,
            m_bids,
            mkpts0_c: mkpts0_c
                .masked_select(&pred_mask.unsqueeze(1))
                .reshape([-1, 2])
                .to_kind(Kind::Float),
            mkpts1_c: mkpts1_c
                .masked_select(&pred_mask.unsqueeze(1))
                .reshape([-1, 2])
                .to_kind(Kind::Float),
            mconf: mconf.masked_select(&pred_mask),
        })
    }
}

/// Returns `sim_matrix` with every pair that touches an invalid cell set to `-INF`.
///
/// `mask_c0` (`[N, L]`) invalidates rows and `mask_c1` (`[N, S]`) invalidates columns. Each
/// is moved to `sim_matrix`'s device, converted to `bool` (non-zero = valid) and broadcast
/// against `[N, L, S]`. Masks are applied independently, so supplying only one of them still
/// takes effect. With no masks the input is returned unchanged.
fn mask_invalid_pairs(
    sim_matrix: Tensor,
    mask_c0: Option<&Tensor>,
    mask_c1: Option<&Tensor>,
) -> Result<Tensor, LoftrError> {
    let device = sim_matrix.device();
    let rows = match mask_c0 {
        Some(mask) => Some(mask.f_to_device(device)?.f_to_kind(Kind::Bool)?.unsqueeze(-1)),
        None => None,
    };
    let cols = match mask_c1 {
        Some(mask) => Some(mask.f_to_device(device)?.f_to_kind(Kind::Bool)?.unsqueeze(1)),
        None => None,
    };
    let valid = match (rows, cols) {
        (Some(rows), Some(cols)) => rows.logical_and(&cols),
        (Some(only), None) | (None, Some(only)) => only,
        (None, None) => return Ok(sim_matrix),
    };
    Ok(sim_matrix.f_masked_fill(&valid.logical_not(), -INF)?)
}
/// Checks that the inputs to `CoarseMatching::forward` have mutually consistent shapes.
///
/// Requirements:
/// * `feat_c0` is `[N, L, C]` and `feat_c1` is `[N, S, C]` (same `N` and `C`);
/// * every component of `hw0_c` / `hw1_c` is positive, and `hw0_c.0 * hw0_c.1 == L`,
///   `hw1_c.0 * hw1_c.1 == S`;
/// * masks, when given, are `[N, L]` / `[N, S]`.
///
/// Rejecting non-positive grid sizes here matters: they would otherwise reach non-fallible
/// tch calls (`reshape` with negative sizes, row/column max over an empty axis) that panic
/// instead of returning an error.
///
/// # Errors
/// Returns [`LoftrError::InvalidConfig`] describing the first violated requirement.
fn validate_coarse_matching_inputs(
    feat_c0: &Tensor,
    feat_c1: &Tensor,
    data: &CoarseMatchingData,
    mask_c0: Option<&Tensor>,
    mask_c1: Option<&Tensor>,
) -> Result<(), LoftrError> {
    let left_dims = feat_c0.size();
    let right_dims = feat_c1.size();
    if left_dims.len() != 3 || right_dims.len() != 3 {
        return Err(LoftrError::InvalidConfig(format!(
            "CoarseMatching expects [N,L,C] tensors; got feat_c0={left_dims:?}, feat_c1={right_dims:?}"
        )));
    }
    if left_dims[0] != right_dims[0] || left_dims[2] != right_dims[2] {
        return Err(LoftrError::InvalidConfig(format!(
            "CoarseMatching batch/channel mismatch: feat_c0={left_dims:?}, feat_c1={right_dims:?}"
        )));
    }
    let left_cells = coarse_cell_count(data.hw0_c, "hw0_c")?;
    let right_cells = coarse_cell_count(data.hw1_c, "hw1_c")?;
    if left_dims[1] != left_cells || right_dims[1] != right_cells {
        return Err(LoftrError::InvalidConfig(format!(
            "CoarseMatching hw*_c mismatch: feat_c0={left_dims:?}, feat_c1={right_dims:?}, hw0_c={:?}, hw1_c={:?}",
            data.hw0_c, data.hw1_c
        )));
    }
    check_mask_shape(mask_c0, left_dims[0], left_dims[1], "mask_c0")?;
    check_mask_shape(mask_c1, right_dims[0], right_dims[1], "mask_c1")?;
    Ok(())
}

/// Returns `h * w` for a coarse grid `(h, w)`, rejecting non-positive or overflowing sizes.
fn coarse_cell_count(hw: (i64, i64), label: &str) -> Result<i64, LoftrError> {
    let (h, w) = hw;
    if h <= 0 || w <= 0 {
        return Err(LoftrError::InvalidConfig(format!(
            "CoarseMatching {label} must have positive components; got {hw:?}"
        )));
    }
    h.checked_mul(w).ok_or_else(|| {
        LoftrError::InvalidConfig(format!(
            "CoarseMatching {label} cell count overflows i64; got {hw:?}"
        ))
    })
}

/// Verifies that an optional mask has shape `[batch, cells]`; `None` is always accepted.
fn check_mask_shape(
    mask: Option<&Tensor>,
    batch: i64,
    cells: i64,
    name: &str,
) -> Result<(), LoftrError> {
    if let Some(mask) = mask {
        let got = mask.size();
        if got != [batch, cells] {
            return Err(LoftrError::InvalidConfig(format!(
                "CoarseMatching {name} expects [{batch}, {cells}]; got {got:?}"
            )));
        }
    }
    Ok(())
}
/// Boolean `[N, L, S]` mask of pairs whose confidence exceeds `threshold` and whose cells
/// both lie outside a `border_rm`-wide border of their coarse grids.
///
/// The flat `L` / `S` axes are temporarily unfolded to `(h0, w0)` / `(h1, w1)` using
/// `data.hw0_c` / `data.hw1_c` so the per-axis border mask can broadcast over them.
fn confidence_threshold_mask(
    conf_matrix: &Tensor,
    data: &CoarseMatchingData,
    threshold: f64,
    border_rm: i64,
) -> Tensor {
    let (h0, w0) = data.hw0_c;
    let (h1, w1) = data.hw1_c;
    let batch = conf_matrix.size()[0];
    let device = conf_matrix.device();
    let grid_shape = [batch, h0, w0, h1, w1];
    let flat_shape = [batch, h0 * w0, h1 * w1];

    let above = conf_matrix.gt(threshold);
    let inside_border = border_validity_mask(h0, w0, h1, w1, border_rm, device);
    above
        .reshape(grid_shape)
        .logical_and(&inside_border)
        .reshape(flat_shape)
}
/// Boolean mask broadcastable to `[1, h0, w0, h1, w1]`.
///
/// The mask is `true` where all four grid coordinates are at least `border_rm` cells away
/// from their respective edges. A non-positive `border_rm` yields an all-`true`
/// `[1, h0, w0, h1, w1]` tensor.
fn border_validity_mask(
    h0: i64,
    w0: i64,
    h1: i64,
    w1: i64,
    border_rm: i64,
    device: tch::Device,
) -> Tensor {
    // A non-positive border excludes nothing, so every cell pair stays valid.
    if border_rm < 1 {
        return Tensor::ones([1, h0, w0, h1, w1], (Kind::Bool, device));
    }
    // Builds the 1-D keep-mask for one axis and lays it out along `shape`, so the four
    // axis masks broadcast against each other into the 5-D grid.
    let along = |len: i64, shape: [i64; 5]| border_axis_mask(len, border_rm, device).view(shape);
    let rows0 = along(h0, [1, h0, 1, 1, 1]);
    let cols0 = along(w0, [1, 1, w0, 1, 1]);
    let rows1 = along(h1, [1, 1, 1, h1, 1]);
    let cols1 = along(w1, [1, 1, 1, 1, w1]);
    rows0
        .logical_and(&cols0)
        .logical_and(&rows1)
        .logical_and(&cols1)
}
/// 1-D boolean mask of length `length` that is `true` for positions in
/// `border_rm..length - border_rm`, i.e. away from both ends of the axis.
fn border_axis_mask(length: i64, border_rm: i64, device: tch::Device) -> Tensor {
    let positions = Tensor::arange(length, (Kind::Int64, device));
    let past_leading_edge = positions.ge(border_rm);
    let before_trailing_edge = positions.lt(length - border_rm);
    past_leading_edge.logical_and(&before_trailing_edge)
}
fn empty_match_outputs(conf_matrix: &Tensor) -> (Tensor, Tensor, Tensor, Tensor) {
let device = conf_matrix.device();
(
Tensor::zeros([0], (Kind::Int64, device)),
Tensor::zeros([0], (Kind::Int64, device)),
Tensor::zeros([0], (Kind::Int64, device)),
Tensor::zeros([0], (Kind::Float, device)),
)
}
fn scaled_factors(
data: &CoarseMatchingData,
b_ids: &Tensor,
use_left: bool,
device: tch::Device,
) -> Result<Tensor, LoftrError> {
let scale = if use_left {
i64_pair_ratio_to_f64(data.hw0_i.0, data.hw0_c.0, "left scale")?
} else {
i64_pair_ratio_to_f64(data.hw1_i.0, data.hw1_c.0, "right scale")?
};
let per_batch = if use_left { &data.scale0 } else { &data.scale1 };
match per_batch {
Some(per_batch) => Ok(per_batch
.f_to_device(device)?
.f_to_kind(Kind::Float)?
.index_select(0, &b_ids.f_to_device(device)?.f_to_kind(Kind::Int64)?)
.unsqueeze(1)
* scale),
None => Ok(Tensor::from(scale).to_device(device).reshape([1, 1])),
}
}
/// Converts flat row-major cell indices into scaled `(x, y)` coarse coordinates.
///
/// `x = index % width` and `y = index / width` (floor), stacked as columns `[x, y]` and
/// multiplied by `scale`, which must broadcast against the resulting `[M, 2]` tensor.
fn coarse_points(indices: &Tensor, width: i64, scale: Tensor) -> Tensor {
    let flat = indices.to_kind(Kind::Int64);
    let (column, row) = (
        flat.remainder(width).to_kind(Kind::Float),
        flat.floor_divide_scalar(width).to_kind(Kind::Float),
    );
    let xy = Tensor::stack(&[column, row], 1);
    xy * scale
}
#[cfg(test)]
mod tests;