use crate::decode::EntitySpan;
/// Entities detected in a single window, plus the range that window is
/// responsible for emitting.
///
/// [`stitch`] keeps an entity from this window only when the midpoint of its
/// span falls inside the half-open range `emit_start..emit_end`, so the emit
/// bounds are expressed in the same coordinates as `EntitySpan::start` and
/// `EntitySpan::end`.
#[derive(Debug, Clone, PartialEq)]
pub struct WindowEntities {
    /// Entities produced for this window.
    pub entities: Vec<EntitySpan>,
    /// Inclusive lower bound of the range this window may emit.
    pub emit_start: usize,
    /// Exclusive upper bound of the range this window may emit.
    pub emit_end: usize,
}
/// Merges the entities of several windows into a single list.
///
/// Each entity is kept only for a window whose emit range
/// (`emit_start..emit_end`, upper bound exclusive) contains the midpoint of
/// the entity's span. The surviving entities are then handed to
/// `deduplicate`, which resolves overlapping spans that share a label and
/// returns the result ordered by `start`, then `end`, then `label`.
///
/// A malformed span whose `end` is smaller than its `start` is treated as a
/// zero-length span located at `start` instead of underflowing, matching the
/// saturating length computation used during deduplication.
pub fn stitch(windows: Vec<WindowEntities>) -> Vec<EntitySpan> {
    let kept: Vec<EntitySpan> = windows
        .into_iter()
        .flat_map(|window| {
            let emit_range = window.emit_start..window.emit_end;
            window.entities.into_iter().filter(move |entity| {
                // `start + len / 2` cannot overflow for well-formed spans, and
                // `saturating_sub` keeps reversed spans from panicking.
                let midpoint = entity.start + entity.end.saturating_sub(entity.start) / 2;
                emit_range.contains(&midpoint)
            })
        })
        .collect();
    deduplicate(kept)
}
/// Collapses overlapping entities that share a label into a single winner.
///
/// Entities are grouped per label into clusters: a cluster starts at one
/// entity and absorbs every following entity (in `start`/`end` order) whose
/// `start` lies before the furthest `end` seen so far in that cluster. From
/// each cluster the longest span wins; equal lengths are resolved in favour of
/// the smaller `start`, then the smaller `end`, and fully identical spans
/// resolve to the one that came last in the input.
///
/// The returned list is ordered by `start`, then `end`, then `label`.
fn deduplicate(mut entities: Vec<EntitySpan>) -> Vec<EntitySpan> {
    // Zero or one entity cannot contain duplicates.
    if entities.len() < 2 {
        return entities;
    }

    // Group by label first so that every cluster is a contiguous run.
    entities.sort_by(|lhs, rhs| {
        (&lhs.label, lhs.start, lhs.end).cmp(&(&rhs.label, rhs.start, rhs.end))
    });

    let mut survivors: Vec<EntitySpan> = Vec::new();
    let mut remaining: &[EntitySpan] = &entities;
    while let Some((head, tail)) = remaining.split_first() {
        // Count how many of the following entities belong to `head`'s cluster,
        // extending the cluster's reach as longer members are absorbed.
        let mut reach = head.end;
        let mut followers = 0;
        for candidate in tail {
            if candidate.label != head.label || candidate.start >= reach {
                break;
            }
            reach = reach.max(candidate.end);
            followers += 1;
        }

        let (cluster, rest) = remaining.split_at(followers + 1);
        let best = cluster
            .iter()
            .max_by(|lhs, rhs| {
                let lhs_len = lhs.end.saturating_sub(lhs.start);
                let rhs_len = rhs.end.saturating_sub(rhs.start);
                // Operands are swapped in the tie-break so that the smaller
                // start (then the smaller end) ranks as the greater element.
                lhs_len
                    .cmp(&rhs_len)
                    .then_with(|| (rhs.start, rhs.end).cmp(&(lhs.start, lhs.end)))
            })
            .expect("cluster always contains at least one entity");
        survivors.push(best.clone());
        remaining = rest;
    }

    // Winners were collected label by label; restore positional order.
    survivors.sort_by(|lhs, rhs| {
        (lhs.start, lhs.end, &lhs.label).cmp(&(rhs.start, rhs.end, &rhs.label))
    });
    survivors
}
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use super::*;
    use crate::decode::EntitySpan;

    /// Builds an `EntitySpan` with the given label, bounds and text and a
    /// fixed score of `1.0`.
    fn span(label: &str, start: usize, end: usize, text: &str) -> EntitySpan {
        let label = Cow::Owned(label.to_owned());
        let text = text.to_owned();
        EntitySpan {
            label,
            start,
            end,
            text,
            score: 1.0,
        }
    }

    /// Builds a `WindowEntities` that emits `emit_start..emit_end`.
    fn window(entities: Vec<EntitySpan>, emit_start: usize, emit_end: usize) -> WindowEntities {
        WindowEntities {
            entities,
            emit_start,
            emit_end,
        }
    }

    #[test]
    fn single_window_passthrough() {
        let expected = vec![span("PER", 0, 8, "John Doe"), span("LOC", 13, 18, "Paris")];
        let stitched = stitch(vec![window(expected.clone(), 0, 100)]);
        assert_eq!(stitched, expected);
    }

    #[test]
    fn entity_assigned_to_correct_window_by_midpoint() {
        // Midpoint is 45, so only the second window may emit the entity.
        let name = span("PER", 40, 50, "some name");
        let windows = vec![
            window(vec![name.clone()], 0, 42),
            window(vec![name.clone()], 42, 100),
        ];
        assert_eq!(stitch(windows), vec![name]);
    }

    #[test]
    fn duplicate_entity_from_two_windows_deduplicated() {
        // Emit ranges overlap, so both windows emit the same span.
        let company = span("ORG", 50, 66, "Microsoft Spain");
        let windows = vec![
            window(vec![company.clone()], 0, 60),
            window(vec![company.clone()], 55, 120),
        ];
        assert_eq!(stitch(windows), vec![company]);
    }

    #[test]
    fn overlapping_same_label_keeps_longest() {
        let longer = span("PER", 10, 22, "John Doe III");
        let windows = vec![
            window(vec![span("PER", 10, 18, "John Doe")], 0, 50),
            window(vec![longer.clone()], 0, 50),
        ];
        assert_eq!(stitch(windows), vec![longer]);
    }

    #[test]
    fn overlapping_same_label_keeps_longest_even_if_it_starts_later() {
        let longer = span("PER", 12, 24, "hn Doe Senior");
        let windows = vec![
            window(vec![span("PER", 10, 16, "John D")], 0, 50),
            window(vec![longer.clone()], 0, 50),
        ];
        assert_eq!(stitch(windows), vec![longer]);
    }

    #[test]
    fn different_labels_can_overlap() {
        let company = span("ORG", 0, 10, "Apple Inc.");
        let place = span("LOC", 6, 15, "Inc. Park");
        let stitched = stitch(vec![window(vec![company.clone(), place.clone()], 0, 100)]);
        assert_eq!(stitched, vec![company, place]);
    }

    #[test]
    fn output_sorted_by_byte_offset() {
        let earlier = span("PER", 10, 18, "John Doe");
        let later = span("LOC", 50, 55, "Paris");
        // Windows are deliberately supplied in reverse order.
        let windows = vec![
            window(vec![later.clone()], 40, 100),
            window(vec![earlier.clone()], 0, 40),
        ];
        assert_eq!(stitch(windows), vec![earlier, later]);
    }

    #[test]
    fn empty_windows_produce_no_entities() {
        let windows = vec![window(Vec::new(), 0, 50), window(Vec::new(), 50, 100)];
        assert!(stitch(windows).is_empty());
    }

    #[test]
    fn entity_outside_emit_region_is_dropped() {
        // Midpoint is 5, which lies before the emit range 10..50.
        let stray = span("PER", 0, 10, "John Doe");
        let stitched = stitch(vec![window(vec![stray], 10, 50)]);
        assert!(stitched.is_empty());
    }
}