use super::registry::{LabelCategory, LabelDefinition, SemanticRegistry};
use anno_core::SpanCandidate;
/// One scored entry of a handshaking matrix: a token pair plus a label.
///
/// The pair `(i, j)` addresses a position in a `seq_len x seq_len` grid, with
/// `i` as the row and `j` as the column.
#[derive(Debug, Clone, Copy)]
pub struct HandshakingCell {
    /// Row (first) token index of the cell.
    pub i: u32,
    /// Column (second) token index of the cell.
    pub j: u32,
    /// Index of the predicted label; looked up in the registry's label list
    /// when entities are decoded.
    pub label_idx: u16,
    /// Score assigned to this `(i, j, label)` combination.
    pub score: f32,
}
/// Sparse handshaking matrix: only the cells worth keeping are stored,
/// together with the dimensions of the dense grid they were taken from.
///
/// Derives `Debug` and `Clone` so the type can be logged and duplicated like
/// the cells it contains.
#[derive(Debug, Clone)]
pub struct HandshakingMatrix {
    /// Retained cells, in the order they were inserted.
    pub cells: Vec<HandshakingCell>,
    /// Number of tokens along each side of the grid.
    pub seq_len: usize,
    /// Number of labels scored per token pair.
    pub num_labels: usize,
}
impl HandshakingMatrix {
    /// Builds a sparse matrix from a dense, row-major score buffer.
    ///
    /// `scores` is indexed as a `[seq_len, seq_len, num_labels]` tensor: the
    /// score for `(i, j, label)` is read from
    /// `i * seq_len * num_labels + j * num_labels + label`.
    ///
    /// Only the upper triangle (`j >= i`) is scanned, and a cell is kept when
    /// its score is `>= threshold`. Positions that fall outside `scores` are
    /// skipped silently, so a short buffer yields fewer cells instead of a
    /// panic.
    pub fn from_dense(scores: &[f32], seq_len: usize, num_labels: usize, threshold: f32) -> Self {
        // Heuristic pre-allocation: a tenth of the grid, capped at 1000 cells.
        let estimated_capacity = (seq_len * seq_len / 10).min(1000);
        let mut cells = Vec::with_capacity(estimated_capacity);

        for i in 0..seq_len {
            for j in i..seq_len {
                // Offset of the first label score for the pair (i, j).
                let base = i * seq_len * num_labels + j * num_labels;
                for l in 0..num_labels {
                    if let Some(&score) = scores.get(base + l) {
                        if score >= threshold {
                            // NOTE: these casts truncate; token indices are
                            // assumed to fit in `u32` and label indices in `u16`.
                            cells.push(HandshakingCell {
                                i: i as u32,
                                j: j as u32,
                                label_idx: l as u16,
                                score,
                            });
                        }
                    }
                }
            }
        }

        Self {
            cells,
            seq_len,
            num_labels,
        }
    }

    /// Decodes entity spans from the stored cells and drops overlapping ones.
    ///
    /// A cell becomes a candidate when its `label_idx` resolves to a label in
    /// `registry.labels` whose category is [`LabelCategory::Entity`]; cells
    /// with an unknown label index or another category are ignored.
    ///
    /// Cells are accepted in either orientation: the smaller of `(i, j)` is
    /// the span start and the larger is the inclusive last token, giving the
    /// half-open span `[min(i, j), max(i, j) + 1)`. Without this
    /// normalisation the upper-triangle cells produced by
    /// [`HandshakingMatrix::from_dense`] (`i <= j`) would decode to inverted
    /// spans.
    ///
    /// Candidates are ordered by start, then end, then descending score, and
    /// are kept greedily in that order: a candidate is discarded when it
    /// overlaps any span that was already kept. Note that this favours the
    /// earliest span, not necessarily the highest-scoring one.
    ///
    /// Returns `(span, label, score)` triples in ascending position order.
    pub fn decode_entities<'a>(
        &self,
        registry: &'a SemanticRegistry,
    ) -> Vec<(SpanCandidate, &'a LabelDefinition, f32)> {
        let mut entities: Vec<(SpanCandidate, &'a LabelDefinition, f32)> = self
            .cells
            .iter()
            .filter_map(|cell| {
                let label = registry.labels.get(cell.label_idx as usize)?;
                if label.category != LabelCategory::Entity {
                    return None;
                }
                let start = cell.i.min(cell.j);
                let end = cell.i.max(cell.j) + 1;
                // First argument is always 0 here — presumably the document
                // index; confirm against `SpanCandidate::new`.
                Some((SpanCandidate::new(0, start, end), label, cell.score))
            })
            .collect();

        // No stability needed: the comparator already orders by every field
        // that matters (start, end, then score descending).
        entities.sort_unstable_by(|a, b| {
            a.0.start
                .cmp(&b.0.start)
                .then_with(|| a.0.end.cmp(&b.0.end))
                .then_with(|| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal))
        });

        let mut kept: Vec<(SpanCandidate, &'a LabelDefinition, f32)> =
            Vec::with_capacity(entities.len().min(32));
        for (span, label, score) in entities {
            // Two half-open spans overlap unless one ends at or before the
            // start of the other.
            let overlaps = kept
                .iter()
                .any(|(s, _, _)| !(span.end <= s.start || s.end <= span.start));
            if !overlaps {
                kept.push((span, label, score));
            }
        }
        kept
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::inference::registry::ModalityHint;
    use crate::Confidence;
    use std::collections::HashMap;

    /// Cells at or above the threshold are kept, in scan order.
    #[test]
    fn handshaking_from_dense_thresholding() {
        // 2x2 grid, one label, row-major: (0,0) (0,1) (1,0) (1,1).
        let scores = vec![0.1, 0.9, 0.0, 0.5];
        let matrix = HandshakingMatrix::from_dense(&scores, 2, 1, 0.5);
        assert_eq!(matrix.cells.len(), 2);
        assert!((matrix.cells[0].score - 0.9).abs() < 1e-6);
        assert_eq!(matrix.cells[0].i, 0);
        assert_eq!(matrix.cells[0].j, 1);
        assert!((matrix.cells[1].score - 0.5).abs() < 1e-6);
        assert_eq!(matrix.cells[1].i, 1);
        assert_eq!(matrix.cells[1].j, 1);
    }

    /// No cell is produced when every score is below the threshold.
    #[test]
    fn handshaking_empty_when_all_below_threshold() {
        let scores = vec![0.1, 0.2, 0.0, 0.3];
        let matrix = HandshakingMatrix::from_dense(&scores, 2, 1, 0.5);
        assert!(matrix.cells.is_empty());
    }

    /// Of two overlapping entity spans, only the first one survives decoding.
    #[test]
    fn handshaking_decode_nms_removes_overlapping() {
        let registry = SemanticRegistry {
            embeddings: vec![0.0; 4],
            hidden_dim: 4,
            labels: vec![LabelDefinition {
                slug: "PER".to_string(),
                description: "Person".to_string(),
                category: LabelCategory::Entity,
                modality: ModalityHint::TextOnly,
                threshold: Confidence::ZERO,
            }],
            label_index: {
                let mut m = HashMap::new();
                m.insert("PER".to_string(), 0);
                m
            },
        };
        let matrix = HandshakingMatrix {
            cells: vec![
                HandshakingCell {
                    i: 2,
                    j: 0,
                    label_idx: 0,
                    score: 0.9,
                },
                HandshakingCell {
                    i: 3,
                    j: 1,
                    label_idx: 0,
                    score: 0.8,
                },
            ],
            seq_len: 5,
            num_labels: 1,
        };
        let entities = matrix.decode_entities(&registry);
        assert_eq!(entities.len(), 1, "NMS should suppress overlapping span");
        assert_eq!(entities[0].0.start, 0);
        assert_eq!(entities[0].0.end, 3);
        assert!((entities[0].2 - 0.9).abs() < 1e-6);
    }
}