use anyhow::{Result, bail};
use serde::{Deserialize, Serialize};
/// A labeled span of text together with the score it was extracted with.
///
/// The field descriptions below reflect how `find_spans` in this file
/// populates the struct; other producers may fill it differently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
/// The slice of the input text covered by the span (`text[start..end]`).
pub text: String,
/// The label the span was scored for.
pub label: String,
/// The raw score of the span for `label`.
pub confidence: f32,
/// `start` and `end` are the offsets used to slice the input `&str`
/// (so byte offsets): `start` is inclusive, `end` is exclusive.
pub start: usize, pub end: usize, }
/// Decodes a flattened span-score array into entities, one greedy
/// non-overlap selection per label.
///
/// `scores_flat` is read as a row-major `[num_entities, l, max_width]` array:
/// the value at `p * l * max_width + i * max_width + j` is the score of label
/// `p` for the span that starts at token `i` and ends at token `i + j`
/// (inclusive). Spans whose end token would be `>= l` are ignored.
///
/// Every span with `score >= threshold` becomes a candidate [`Entity`] whose
/// text is `text[start_offsets[i]..end_offsets[i + j]]`. Candidates of each
/// label are filtered independently through [`greedy_select`], so spans of
/// *different* labels may overlap in the result. Results are grouped by label
/// in `labels` order.
///
/// Only the first `min(labels.len(), num_entities)` labels are decoded.
///
/// # Errors
///
/// Returns an error (instead of panicking) when:
/// - `scores_flat.len()` differs from `num_entities * l * max_width`, or that
///   product overflows `usize`;
/// - a span above the threshold refers to a token that has no entry in
///   `start_offsets` / `end_offsets`;
/// - the offsets of such a span are not a valid slice of `text` (out of
///   range, reversed, or not on a UTF-8 character boundary).
#[allow(clippy::too_many_arguments)] // flat signature is part of the public API
pub fn find_spans(
    scores_flat: &[f32],
    num_entities: usize,
    l: usize,
    max_width: usize,
    threshold: f32,
    labels: &[&str],
    text: &str,
    start_offsets: &[usize],
    end_offsets: &[usize],
) -> Result<Vec<Entity>> {
    let expected = match num_entities
        .checked_mul(l)
        .and_then(|n| n.checked_mul(max_width))
    {
        Some(n) => n,
        None => bail!(
            "find_spans: score shape {} * {} * {} overflows usize",
            num_entities,
            l,
            max_width
        ),
    };
    if scores_flat.len() != expected {
        bail!(
            "find_spans: expected {} scores ({} * {} * {}), got {}",
            expected,
            num_entities,
            l,
            max_width,
            scores_flat.len()
        );
    }

    let mut out = Vec::new();
    for (p, &label) in labels.iter().enumerate().take(num_entities) {
        let mut per_label = Vec::new();
        // In bounds: `p < num_entities` and the length check above passed.
        let base = p * (l * max_width);
        for i in 0..l {
            // Limiting the width to `l - i` skips spans that would end past
            // the last token.
            for j in 0..max_width.min(l - i) {
                let conf = scores_flat[base + i * max_width + j];
                // A NaN score never satisfies `>=`, so it is never emitted.
                if conf >= threshold {
                    let end_token_idx = i + j;
                    // Offsets are validated lazily, only for spans that are
                    // actually emitted, so inputs without any hit keep working
                    // even with short offset slices.
                    let (char_start, char_end) =
                        match (start_offsets.get(i), end_offsets.get(end_token_idx)) {
                            (Some(&s), Some(&e)) => (s, e),
                            _ => bail!(
                                "find_spans: no offsets for token span {}..={} \
                                 (start_offsets: {}, end_offsets: {})",
                                i,
                                end_token_idx,
                                start_offsets.len(),
                                end_offsets.len()
                            ),
                        };
                    // `str::get` rejects out-of-range, reversed and
                    // non-char-boundary ranges instead of panicking.
                    let text_val = match text.get(char_start..char_end) {
                        Some(slice) => slice.to_string(),
                        None => bail!(
                            "find_spans: offsets {}..{} are not a valid slice of the text (len {})",
                            char_start,
                            char_end,
                            text.len()
                        ),
                    };
                    per_label.push(Entity {
                        text: text_val,
                        label: label.to_string(),
                        confidence: conf,
                        start: char_start,
                        end: char_end,
                    });
                }
            }
        }
        out.extend(greedy_select(per_label));
    }
    Ok(out)
}
/// Candle front-end for [`find_spans`].
///
/// Reads `scores` as a rank-3 tensor `(num_entities, l, max_width)`, flattens
/// it into a `Vec<f32>` and delegates the decoding to [`find_spans`] with the
/// remaining arguments passed through unchanged.
///
/// # Errors
///
/// Any error reported by the tensor operations (`dims3`, `flatten_all`,
/// `to_vec1::<f32>`) is returned with its display text as the message, as is
/// any error produced by [`find_spans`].
#[cfg(feature = "candle")]
pub fn find_spans_tensor(
    scores: &candle_core::Tensor,
    threshold: f32,
    labels: &[&str],
    text: &str,
    start_offsets: &[usize],
    end_offsets: &[usize],
) -> anyhow::Result<Vec<Entity>> {
    // Backend errors are re-wrapped by their display text only.
    fn stringify<E: std::fmt::Display>(err: E) -> anyhow::Error {
        anyhow::anyhow!("{err}")
    }

    let (num_entities, l, max_width) = scores.dims3().map_err(stringify)?;
    let flattened = scores.flatten_all().map_err(stringify)?;
    let scores_v = flattened.to_vec1::<f32>().map_err(stringify)?;

    find_spans(
        &scores_v,
        num_entities,
        l,
        max_width,
        threshold,
        labels,
        text,
        start_offsets,
        end_offsets,
    )
}
/// `tch` front-end for [`find_spans`].
///
/// Reads `scores` as a rank-3 tensor `(num_entities, l, max_width)`, copies
/// its elements into a `Vec<f32>` and delegates the decoding to
/// [`find_spans`] with the remaining arguments passed through unchanged.
///
/// # Errors
///
/// Returns an error when `scores` does not have exactly three dimensions,
/// when a dimension does not fit in `usize`, when the tensor data cannot be
/// copied out as `f32`, or when [`find_spans`] itself fails.
#[cfg(feature = "tch")]
pub fn find_spans_tch_tensor(
    scores: &tch::Tensor,
    threshold: f32,
    labels: &[&str],
    text: &str,
    start_offsets: &[usize],
    end_offsets: &[usize],
) -> Result<Vec<Entity>> {
    let sz = scores.size();
    let (num_entities, l, max_width) = match sz.as_slice() {
        // Checked conversion instead of `as`, so a bogus (negative) dimension
        // becomes an error rather than a huge wrapped value.
        &[d0, d1, d2] => (
            usize::try_from(d0)?,
            usize::try_from(d1)?,
            usize::try_from(d2)?,
        ),
        _ => bail!(
            "find_spans_tch_tensor: expected 3D scores, got {} dims",
            sz.len()
        ),
    };

    let n = num_entities * l * max_width;
    let mut scores_v = vec![0f32; n];
    // Use the fallible copy so a failed read is reported to the caller
    // instead of panicking.
    // NOTE(review): presumably this fails when the tensor's element kind is
    // not f32 — confirm against the tch version in use.
    scores
        .flatten(0, 2)
        .f_copy_data(&mut scores_v, n)
        .map_err(|e| anyhow::anyhow!("find_spans_tch_tensor: cannot read scores as f32: {e}"))?;

    find_spans(
        &scores_v,
        num_entities,
        l,
        max_width,
        threshold,
        labels,
        text,
        start_offsets,
        end_offsets,
    )
}
/// Greedy non-overlap selection over `entities`.
///
/// Entities are visited in order of descending `confidence` (ties keep their
/// input order) and an entity is kept only if its `[start, end)` range does
/// not intersect the range of any entity kept so far. Ranges that merely
/// touch (`a.end == b.start`) do not count as overlapping. The `label` field
/// is not consulted, so callers that want per-label selection must call this
/// once per label.
///
/// Entities whose confidence is NaN are ranked after every other entity
/// rather than causing a panic.
///
/// Returns the kept entities in the order they were selected.
pub fn greedy_select(mut entities: Vec<Entity>) -> Vec<Entity> {
    // `partial_cmp` alone is not a total order because of NaN; ranking NaN
    // last makes the comparator total, which `sort_by` requires.
    entities.sort_by(
        |a, b| match (a.confidence.is_nan(), b.confidence.is_nan()) {
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
            // Neither side is NaN here, so `partial_cmp` always yields `Some`.
            (false, false) => b
                .confidence
                .partial_cmp(&a.confidence)
                .unwrap_or(std::cmp::Ordering::Equal),
        },
    );

    let mut selected: Vec<Entity> = Vec::new();
    for entity in entities {
        let overlaps = selected
            .iter()
            .any(|kept| entity.end > kept.start && entity.start < kept.end);
        if !overlaps {
            selected.push(entity);
        }
    }
    selected
}
#[cfg(test)]
mod tests {
    use super::{Entity, greedy_select};

    /// Shorthand for building an `Entity` fixture.
    fn entity(text: &str, label: &str, confidence: f32, start: usize, end: usize) -> Entity {
        Entity {
            text: text.into(),
            label: label.into(),
            confidence,
            start,
            end,
        }
    }

    /// Two overlapping spans with different labels: a single selection over
    /// both keeps only the higher-confidence one, while selecting each label
    /// on its own keeps both.
    #[test]
    fn per_label_greedy_keeps_overlapping_spans_for_different_labels() {
        let a = entity("foo", "A", 0.9, 0, 5);
        let b = entity("bar", "B", 0.5, 2, 7);

        let global = greedy_select(vec![a.clone(), b.clone()]);
        assert_eq!(global.len(), 1, "global NMS drops lower-confidence overlap");

        let per_label: Vec<Entity> = vec![a, b]
            .into_iter()
            .flat_map(|single| greedy_select(vec![single]))
            .collect();
        assert_eq!(per_label.len(), 2, "per-label NMS matches GliNER2 engine");
    }
}