use super::*;
/// Weights configured for a single backend.
///
/// * `overall` — one scalar weight for the backend as a whole.
/// * `per_type` — an optional table of per-entity-type weights; `None` means no
///   type-specific table is configured. How `overall` and `per_type` are combined
///   is decided by the caller (not visible here).
#[derive(Clone, Copy, Debug)]
pub struct BackendWeight {
    pub overall: f64,
    pub per_type: Option<TypeWeights>,
}

impl Default for BackendWeight {
    /// A middle-of-the-road weight: `overall = 0.5` with no per-type table.
    fn default() -> Self {
        BackendWeight {
            per_type: None,
            overall: 0.5,
        }
    }
}
/// Per-entity-type weights for one backend.
///
/// Five entity types have a dedicated field; every other `EntityType` is
/// weighted by `other`. `Default` sets every field to `0.0`.
#[derive(Clone, Copy, Debug, Default)]
pub struct TypeWeights {
    pub person: f64,
    pub organization: f64,
    pub location: f64,
    pub date: f64,
    pub money: f64,
    pub other: f64,
}

impl TypeWeights {
    /// Looks up the weight for `entity_type`.
    ///
    /// `Person`, `Organization`, `Location`, `Date` and `Money` map to their own
    /// fields; any other variant (e.g. `Email`, `Percent`) falls back to `other`.
    pub(super) fn get(&self, entity_type: &EntityType) -> f64 {
        // Catch-all is intentional: only a handful of types get a dedicated weight.
        match *entity_type {
            EntityType::Person => self.person,
            EntityType::Organization => self.organization,
            EntityType::Location => self.location,
            EntityType::Date => self.date,
            EntityType::Money => self.money,
            _ => self.other,
        }
    }
}
/// Returns the built-in weight table, keyed by backend name.
///
/// Each entry pairs an `overall` weight with an optional per-entity-type table.
/// Several names carry identical weights: `"gliner"`/`"GLiNER-ONNX"`,
/// `"gliner-candle"`/`"GLiNER-Candle"` and `"bert-ner-onnx"`/`"bert-onnx"`.
/// They look like alternate spellings of the same backend (NOTE(review): confirm),
/// so each pair is registered from one shared value instead of two copy-pasted
/// literals that could silently drift apart. Give a name its own binding if it
/// ever needs different weights.
pub(super) fn default_backend_weights() -> HashMap<&'static str, BackendWeight> {
    // Regex: highest per-type weights for dates and money (0.99), lowest for
    // person/organization/location (0.50).
    let regex = BackendWeight {
        overall: 0.98,
        per_type: Some(TypeWeights {
            person: 0.50,
            organization: 0.50,
            location: 0.50,
            date: 0.99,
            money: 0.99,
            other: 0.95,
        }),
    };
    // Shared by "gliner" and "GLiNER-ONNX": highest for persons, lowest for money.
    let gliner = BackendWeight {
        overall: 0.85,
        per_type: Some(TypeWeights {
            person: 0.90,
            organization: 0.85,
            location: 0.80,
            date: 0.75,
            money: 0.70,
            other: 0.75,
        }),
    };
    // Shared by "gliner-candle" and "GLiNER-Candle": overall weight only.
    let gliner_candle = BackendWeight {
        overall: 0.85,
        per_type: None,
    };
    // Shared by "bert-ner-onnx" and "bert-onnx": overall weight only.
    let bert = BackendWeight {
        overall: 0.80,
        per_type: None,
    };
    // Heuristic: lowest overall weight in the table; weakest on dates and money.
    let heuristic = BackendWeight {
        overall: 0.60,
        per_type: Some(TypeWeights {
            person: 0.65,
            organization: 0.70,
            location: 0.55,
            date: 0.40,
            money: 0.40,
            other: 0.50,
        }),
    };

    let entries = [
        ("regex", regex),
        ("gliner", gliner),
        ("GLiNER-ONNX", gliner),
        ("gliner-candle", gliner_candle),
        ("GLiNER-Candle", gliner_candle),
        ("bert-ner-onnx", bert),
        ("bert-onnx", bert),
        ("heuristic", heuristic),
    ];
    let mut weights = HashMap::with_capacity(entries.len());
    // A `for` loop consumes the array by value in every edition (unlike the
    // `entries.into_iter()` method call under the 2015/2018 editions).
    for (name, weight) in entries {
        weights.insert(name, weight);
    }
    weights
}
/// One predicted entity, tagged with the backend it came from.
///
/// * `entity` — the prediction itself.
/// * `source` — the backend name (presumably the same string used as a key in
///   `default_backend_weights`; confirm at the construction site).
/// * `backend_weight` — the weight attached to that backend's prediction.
#[derive(Clone, Debug)]
pub(super) struct Candidate {
    pub(super) entity: Entity,
    pub(super) source: String,
    pub(super) backend_weight: f64,
}
/// Half-open `[start, end)` offset range of an entity, usable as a hash key.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub(super) struct SpanKey {
    pub(super) start: usize,
    pub(super) end: usize,
}

impl SpanKey {
    /// Copies the `start()`/`end()` offsets of `e`.
    pub(super) fn from_entity(e: &Entity) -> Self {
        let (start, end) = (e.start(), e.end());
        SpanKey { start, end }
    }

    /// Reports whether two spans substantially overlap.
    ///
    /// Returns `true` only when the shared region is strictly longer than half of
    /// the shorter span. Disjoint spans, spans that merely touch (`a.end ==
    /// b.start`), and zero-length spans never overlap. The test is symmetric.
    pub(super) fn overlaps(&self, other: &SpanKey) -> bool {
        let lo = self.start.max(other.start);
        let hi = self.end.min(other.end);
        // `None` means the spans are disjoint; `Some(0)` means they only touch.
        let shared = match hi.checked_sub(lo) {
            Some(len) if len > 0 => len,
            _ => return false,
        };
        // A positive shared length implies both spans are well-formed
        // (start < end), so these subtractions cannot underflow.
        let shorter = usize::min(self.end - self.start, other.end - other.start);
        shared as f64 / shorter as f64 > 0.5
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn backend_weight_default_is_neutral() {
        let w = BackendWeight::default();
        assert!(
            (w.overall - 0.5).abs() < f64::EPSILON,
            "default overall should be 0.5"
        );
        assert!(w.per_type.is_none(), "default per_type should be None");
    }

    #[test]
    fn type_weights_get_returns_matching_field() {
        let tw = TypeWeights {
            person: 0.1,
            organization: 0.2,
            location: 0.3,
            date: 0.4,
            money: 0.5,
            other: 0.6,
        };
        assert!((tw.get(&EntityType::Person) - 0.1).abs() < f64::EPSILON);
        assert!((tw.get(&EntityType::Organization) - 0.2).abs() < f64::EPSILON);
        assert!((tw.get(&EntityType::Location) - 0.3).abs() < f64::EPSILON);
        assert!((tw.get(&EntityType::Date) - 0.4).abs() < f64::EPSILON);
        assert!((tw.get(&EntityType::Money) - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn type_weights_get_falls_back_to_other() {
        let tw = TypeWeights {
            other: 0.77,
            ..TypeWeights::default()
        };
        assert!(
            (tw.get(&EntityType::Email) - 0.77).abs() < f64::EPSILON,
            "Email should fall back to `other`"
        );
        assert!(
            (tw.get(&EntityType::Percent) - 0.77).abs() < f64::EPSILON,
            "Percent should fall back to `other`"
        );
    }

    /// Every name registered by `default_backend_weights` must be present,
    /// including the alternate spellings `"GLiNER-Candle"` and `"bert-onnx"`.
    #[test]
    fn default_weights_contain_all_known_backends() {
        let w = default_backend_weights();
        let expected = [
            "regex",
            "gliner",
            "GLiNER-ONNX",
            "gliner-candle",
            "GLiNER-Candle",
            "bert-ner-onnx",
            "bert-onnx",
            "heuristic",
        ];
        for name in expected {
            assert!(
                w.contains_key(name),
                "missing default weight for backend '{}'",
                name
            );
        }
    }

    #[test]
    fn default_weights_are_in_unit_range() {
        let weights = default_backend_weights();
        for (name, bw) in &weights {
            assert!(
                (0.0..=1.0).contains(&bw.overall),
                "overall weight for '{}' out of range: {}",
                name,
                bw.overall
            );
            if let Some(ref tw) = bw.per_type {
                for (label, val) in [
                    ("person", tw.person),
                    ("organization", tw.organization),
                    ("location", tw.location),
                    ("date", tw.date),
                    ("money", tw.money),
                    ("other", tw.other),
                ] {
                    assert!(
                        (0.0..=1.0).contains(&val),
                        "type weight '{}' for '{}' out of range: {}",
                        label,
                        name,
                        val
                    );
                }
            }
        }
    }

    #[test]
    fn span_key_from_entity_round_trips() {
        let e = Entity::new("hello", EntityType::Person, 3, 8, 0.9);
        let sk = SpanKey::from_entity(&e);
        assert_eq!(sk.start, 3);
        assert_eq!(sk.end, 8);
    }

    #[test]
    fn span_key_no_overlap_when_disjoint() {
        let a = SpanKey { start: 0, end: 5 };
        let b = SpanKey { start: 10, end: 15 };
        assert!(!a.overlaps(&b));
        assert!(!b.overlaps(&a));
    }

    #[test]
    fn span_key_no_overlap_when_touching() {
        // Half-open ranges that share only an endpoint share no length.
        let a = SpanKey { start: 0, end: 5 };
        let b = SpanKey { start: 5, end: 10 };
        assert!(!a.overlaps(&b));
        assert!(!b.overlaps(&a));
    }

    #[test]
    fn span_key_zero_width_never_overlaps() {
        let empty = SpanKey { start: 5, end: 5 };
        let wide = SpanKey { start: 0, end: 10 };
        assert!(!empty.overlaps(&wide));
        assert!(!wide.overlaps(&empty));
    }

    #[test]
    fn span_key_overlap_threshold_boundary() {
        let a = SpanKey { start: 0, end: 10 };
        let b = SpanKey { start: 5, end: 15 };
        assert!(
            !a.overlaps(&b),
            "exactly 50% overlap should be below the >0.5 threshold"
        );
        assert!(!b.overlaps(&a), "boundary result should be symmetric");
    }

    #[test]
    fn span_key_overlap_just_above_threshold() {
        let a = SpanKey { start: 0, end: 10 };
        let b = SpanKey { start: 4, end: 14 };
        assert!(
            a.overlaps(&b),
            "60% overlap should be above the >0.5 threshold"
        );
        assert!(b.overlaps(&a), "overlap should be symmetric");
    }

    #[test]
    fn span_key_containment_overlaps() {
        // The ratio is taken against the shorter span, so full containment is 100%.
        let outer = SpanKey { start: 0, end: 10 };
        let inner = SpanKey { start: 2, end: 4 };
        assert!(outer.overlaps(&inner));
        assert!(inner.overlaps(&outer));
    }

    #[test]
    fn candidate_holds_source_and_weight() {
        let e = Entity::new("ACME", EntityType::Organization, 0, 4, 0.85);
        let c = Candidate {
            entity: e,
            source: "test-backend".to_string(),
            backend_weight: 0.75,
        };
        assert_eq!(c.source, "test-backend");
        assert!((c.backend_weight - 0.75).abs() < f64::EPSILON);
        assert_eq!(c.entity.text, "ACME");
    }
}