use crate::backends::inference::{HandshakingCell, HandshakingMatrix};
use crate::{EntityCategory, EntityType};
/// One decoded, possibly discontinuous entity: `(label, segments, score)`.
///
/// As produced by `decode_discontinuous_from_matrix`, every segment is a
/// `(start, end)` token range whose `end` is exclusive, and `score` is the
/// score of the THW cell that delimited the entity.
pub type DiscontinuousDecodeRow = (String, Vec<(usize, usize)>, f64);
/// Word-word relation labels used by the grid decoders in this file.
///
/// The numeric encoding is defined by [`W2NERRelation::from_index`] and
/// [`W2NERRelation::to_index`]: `0` is `None`, `1` is `NNW`, `2` is `THW`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum W2NERRelation {
    /// Link between two words of the same entity; the discontinuous decoder
    /// uses it to keep consecutive tokens inside one segment.
    NNW,
    /// Tail-to-head link: the decoders read the cell's row as the entity's
    /// last token and its column as the entity's first token.
    THW,
    /// No relation between the two words.
    None,
}

impl W2NERRelation {
    /// Relations listed at the position of their label index.
    const BY_INDEX: [Self; 3] = [Self::None, Self::NNW, Self::THW];

    /// Converts a raw label index into a relation.
    ///
    /// `1` yields `NNW` and `2` yields `THW`; every other value, including
    /// `0` and anything out of range, yields `None`.
    #[must_use]
    pub fn from_index(idx: usize) -> Self {
        Self::BY_INDEX.get(idx).copied().unwrap_or(Self::None)
    }

    /// Returns the label index of this relation (the inverse of
    /// [`W2NERRelation::from_index`] for the indices `0..=2`).
    #[must_use]
    pub fn to_index(self) -> usize {
        match self {
            Self::THW => 2,
            Self::NNW => 1,
            Self::None => 0,
        }
    }
}
/// Decodes contiguous entity spans from the THW cells of `matrix`.
///
/// A cell contributes a span when its relation is `THW`, its score is at
/// least `threshold`, and its column (first token) is not past its row (last
/// token), with both inside `tokens`. Each result is `(start, end, score)`
/// with an exclusive `end`.
///
/// Results are ordered by `start` ascending, longer spans first on ties.
/// When `allow_nested` is `false` the ordered list is additionally passed
/// through `remove_nested`. `entity_type_idx` is accepted but does not
/// influence the result.
#[must_use]
pub fn decode_from_matrix(
    matrix: &HandshakingMatrix,
    tokens: &[&str],
    entity_type_idx: usize,
    threshold: f32,
    allow_nested: bool,
) -> Vec<(usize, usize, f64)> {
    // Kept for signature compatibility; decoding ignores it.
    let _ = entity_type_idx;

    let token_count = tokens.len();
    let mut spans: Vec<(usize, usize, f64)> = matrix
        .cells
        .iter()
        .filter(|cell| {
            W2NERRelation::from_index(cell.label_idx as usize) == W2NERRelation::THW
                && cell.score >= threshold
        })
        .filter_map(|cell| {
            // Row index is the entity's last token, column index its first.
            let (last, first) = (cell.i as usize, cell.j as usize);
            if first <= last && first < token_count && last < token_count {
                Some((first, last + 1, cell.score as f64))
            } else {
                None
            }
        })
        .collect();

    // Earliest start first; among equal starts the longest span wins.
    spans.sort_unstable_by(|lhs, rhs| {
        lhs.0
            .cmp(&rhs.0)
            .then_with(|| (rhs.1 - rhs.0).cmp(&(lhs.1 - lhs.0)))
    });

    if allow_nested {
        spans
    } else {
        remove_nested(&spans)
    }
}
/// Decodes entities that may consist of several token segments.
///
/// Every `THW` cell scoring at least `threshold` whose column (first token)
/// is not past its row (last token), with the row inside `tokens`, delimits
/// one entity. Walking from the first to the last token, a new segment is
/// started wherever two consecutive tokens are not connected by an `NNW`
/// cell scoring at least `threshold`; NNW links are treated as symmetric.
///
/// Each row is `(label, segments, score)` where segments are `(start, end)`
/// with an exclusive `end`. The label is `first_label`, or `"ENTITY"` when
/// `first_label` is empty. Rows are ordered by the start of their first
/// segment, with the entity covering more tokens first on ties.
#[must_use]
pub fn decode_discontinuous_from_matrix(
    matrix: &HandshakingMatrix,
    tokens: &[&str],
    threshold: f32,
    first_label: &str,
) -> Vec<DiscontinuousDecodeRow> {
    let token_count = tokens.len();
    let mut boundaries: Vec<(usize, usize, f64)> = Vec::new();
    let mut linked: std::collections::HashSet<(usize, usize)> = std::collections::HashSet::new();

    // One pass over the confident cells: THW cells delimit entities, NNW
    // cells record which token pairs are connected.
    for cell in matrix.cells.iter().filter(|c| c.score >= threshold) {
        let (row, col) = (cell.i as usize, cell.j as usize);
        match W2NERRelation::from_index(cell.label_idx as usize) {
            W2NERRelation::THW => {
                if col <= row && row < token_count {
                    boundaries.push((col, row, cell.score as f64));
                }
            }
            W2NERRelation::NNW => {
                linked.insert((row, col));
                linked.insert((col, row));
            }
            W2NERRelation::None => {}
        }
    }

    let label = if first_label.is_empty() {
        "ENTITY"
    } else {
        first_label
    };

    let mut rows: Vec<DiscontinuousDecodeRow> = boundaries
        .into_iter()
        .map(|(first, last, score)| {
            let mut pieces: Vec<(usize, usize)> = Vec::new();
            let mut open = first;
            for next in (first + 1)..=last {
                // A missing link between `next - 1` and `next` closes the
                // current segment right before `next`.
                if !linked.contains(&(next - 1, next)) {
                    pieces.push((open, next));
                    open = next;
                }
            }
            pieces.push((open, last + 1));
            (label.to_string(), pieces, score)
        })
        .collect();

    // (start of first segment, number of covered tokens) for one row.
    let sort_key = |row: &DiscontinuousDecodeRow| {
        let start = row.1.first().map_or(usize::MAX, |seg| seg.0);
        let covered: usize = row.1.iter().map(|&(s, e)| e - s).sum();
        (start, covered)
    };
    rows.sort_unstable_by(|lhs, rhs| {
        let (lhs_start, lhs_len) = sort_key(lhs);
        let (rhs_start, rhs_len) = sort_key(rhs);
        lhs_start.cmp(&rhs_start).then_with(|| rhs_len.cmp(&lhs_len))
    });
    rows
}
/// Converts a flat score grid into a sparse [`HandshakingMatrix`].
///
/// `grid` is indexed as `i * seq_len * num_relations + j * num_relations + rel`.
/// A cell is emitted for every `(i, j, rel)` with `rel > 0` whose score is at
/// least `threshold`; relation index `0` is never emitted. Positions that
/// fall outside `grid` are skipped. Cells appear in ascending `(i, j, rel)`
/// order.
#[must_use]
pub fn grid_to_matrix(
    grid: &[f32],
    seq_len: usize,
    num_relations: usize,
    threshold: f32,
) -> HandshakingMatrix {
    let mut cells = Vec::new();
    for row in 0..seq_len {
        for col in 0..seq_len {
            for rel in 0..num_relations {
                let flat = row * seq_len * num_relations + col * num_relations + rel;
                match grid.get(flat) {
                    Some(&score) if rel > 0 && score >= threshold => {
                        cells.push(HandshakingCell {
                            i: row as u32,
                            j: col as u32,
                            label_idx: rel as u16,
                            score,
                        });
                    }
                    _ => {}
                }
            }
        }
    }
    HandshakingMatrix {
        cells,
        seq_len,
        num_labels: num_relations,
    }
}
/// Greedily drops spans that start before the end of the last kept span.
///
/// Spans are `(start, end, score)` and are visited in the given order; a span
/// is kept only when its `start` is at or after the `end` of the most
/// recently kept span (initially `0`). Callers are expected to pass spans
/// already ordered by start, as `decode_from_matrix` does.
pub(crate) fn remove_nested(entities: &[(usize, usize, f64)]) -> Vec<(usize, usize, f64)> {
    let mut frontier = 0;
    entities
        .iter()
        .copied()
        .filter(|&(start, end, _)| {
            if start < frontier {
                return false;
            }
            frontier = end;
            true
        })
        .collect()
}
/// Maps a model label string to an [`EntityType`].
///
/// Matching is done on the upper-cased label, so it is case-insensitive.
/// Recognised labels map to the built-in variants (`PER`/`PERSON`,
/// `ORG`/`ORGANIZATION`, `LOC`/`LOCATION`/`GPE`, `DATE`, `TIME`, `MONEY`,
/// `PERCENT`). `MISC` becomes a custom type named `"MISC"`; any other label
/// becomes a custom type carrying the label exactly as passed in. Both custom
/// cases use `EntityCategory::Misc`.
#[must_use]
pub fn map_label_to_entity_type(label: &str) -> EntityType {
    let canonical = label.to_uppercase();
    // Built-in types return immediately; what remains is the name to use
    // for a custom type.
    let custom_name = match canonical.as_str() {
        "PER" | "PERSON" => return EntityType::Person,
        "ORG" | "ORGANIZATION" => return EntityType::Organization,
        "LOC" | "LOCATION" | "GPE" => return EntityType::Location,
        "DATE" => return EntityType::Date,
        "TIME" => return EntityType::Time,
        "MONEY" => return EntityType::Money,
        "PERCENT" => return EntityType::Percent,
        "MISC" => "MISC",
        _ => label,
    };
    EntityType::custom(custom_name, EntityCategory::Misc)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backends::inference::{HandshakingCell, HandshakingMatrix};

    /// Builds one grid cell carrying `rel` at row `i`, column `j`.
    fn cell(i: u32, j: u32, rel: W2NERRelation, score: f32) -> HandshakingCell {
        let label_idx = rel.to_index() as u16;
        HandshakingCell {
            i,
            j,
            label_idx,
            score,
        }
    }

    /// Wraps `cells` in a matrix that declares three labels.
    fn mat(cells: Vec<HandshakingCell>, seq_len: usize) -> HandshakingMatrix {
        HandshakingMatrix {
            cells,
            seq_len,
            num_labels: 3,
        }
    }

    #[test]
    fn decode_single_contiguous_entity() {
        let words = ["New", "York", "City"];
        // Row 2 (last token) linked to column 0 (first token).
        let matrix = mat(vec![cell(2, 0, W2NERRelation::THW, 0.9)], 3);

        let spans = decode_from_matrix(&matrix, &words, 0, 0.5, true);

        assert_eq!(spans.len(), 1);
        let (start, end, _score) = spans[0];
        assert_eq!(start, 0);
        assert_eq!(end, 3);
    }

    #[test]
    fn decode_removes_nested_when_disabled() {
        let words = ["The", "University", "of", "California"];
        let outer = cell(3, 0, W2NERRelation::THW, 0.8);
        let inner = cell(3, 1, W2NERRelation::THW, 0.9);
        let matrix = mat(vec![outer, inner], 4);

        let with_nested = decode_from_matrix(&matrix, &words, 0, 0.5, true);
        assert_eq!(with_nested.len(), 2, "should keep both when nested=true");

        let without_nested = decode_from_matrix(&matrix, &words, 0, 0.5, false);
        assert_eq!(
            without_nested.len(),
            1,
            "should keep only outer when nested=false"
        );
    }

    #[test]
    fn decode_discontinuous_splits_on_nnw_gap() {
        let words = ["severe", "pain", "in", "abdomen"];
        // Tokens 0-1 and 2-3 are linked; nothing links 1 to 2.
        let matrix = mat(
            vec![
                cell(3, 0, W2NERRelation::THW, 0.8),
                cell(0, 1, W2NERRelation::NNW, 0.8),
                cell(2, 3, W2NERRelation::NNW, 0.8),
            ],
            4,
        );

        let rows = decode_discontinuous_from_matrix(&matrix, &words, 0.5, "SYMPTOM");

        assert_eq!(rows.len(), 1);
        let (label, spans, _score) = &rows[0];
        assert_eq!(label, "SYMPTOM");
        assert_eq!(
            spans.len(),
            2,
            "expected 2 disjoint segments; got {}",
            spans.len()
        );
        assert_eq!(spans[0], (0, 2));
        assert_eq!(spans[1], (2, 4));
    }

    #[test]
    fn grid_to_matrix_filters_none_and_below_threshold() {
        let mut grid = vec![0.0f32; 2 * 2 * 3];
        // Flat index 4 is (i=0, j=1, rel=1); index 5 is (i=0, j=1, rel=2).
        grid[4] = 0.2;
        grid[5] = 0.9;

        let matrix = grid_to_matrix(&grid, 2, 3, 0.5);

        assert_eq!(matrix.cells.len(), 1);
        assert_eq!(matrix.cells[0].label_idx, 2);
    }

    #[test]
    fn map_label_person_org_loc() {
        use crate::EntityType;

        assert_eq!(map_label_to_entity_type("PER"), EntityType::Person);
        assert_eq!(map_label_to_entity_type("ORG"), EntityType::Organization);
        assert_eq!(map_label_to_entity_type("GPE"), EntityType::Location);

        let custom = map_label_to_entity_type("CUSTOM");
        assert!(matches!(custom, EntityType::Custom { .. }));
    }
}