use super::confidence::Confidence;
use super::types::MentionType;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
/// Coarse grouping of [`EntityType`]s.
///
/// The category drives the detection predicates in the `impl` below:
/// `requires_ml` and `pattern_detectable`.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum EntityCategory {
/// Category of `EntityType::Person` (and of the custom `GROUP` type produced by label parsing).
Agent,
/// Category of `EntityType::Organization`.
Organization,
/// Category of `EntityType::Location` (and of the custom `BUILDING` type).
Place,
/// Category used for custom creative-work style types such as `CREATIVE_WORK`.
Creative,
/// Category of `EntityType::Date` and `EntityType::Time`.
Temporal,
/// Category of `Money`, `Percent`, `Quantity`, `Cardinal` and `Ordinal`.
Numeric,
/// Category of `Email`, `Url` and `Phone`.
Contact,
/// Relation category; the only one for which `is_relation` returns `true`.
Relation,
/// Catch-all category; neither `requires_ml` nor `pattern_detectable` holds for it.
Misc,
}
impl EntityCategory {
#[must_use]
pub const fn requires_ml(&self) -> bool {
matches!(
self,
EntityCategory::Agent
| EntityCategory::Organization
| EntityCategory::Place
| EntityCategory::Creative
| EntityCategory::Relation
)
}
#[must_use]
pub const fn pattern_detectable(&self) -> bool {
matches!(
self,
EntityCategory::Temporal | EntityCategory::Numeric | EntityCategory::Contact
)
}
#[must_use]
pub const fn is_relation(&self) -> bool {
matches!(self, EntityCategory::Relation)
}
#[must_use]
pub const fn as_str(&self) -> &'static str {
match self {
EntityCategory::Agent => "agent",
EntityCategory::Organization => "organization",
EntityCategory::Place => "place",
EntityCategory::Creative => "creative",
EntityCategory::Temporal => "temporal",
EntityCategory::Numeric => "numeric",
EntityCategory::Contact => "contact",
EntityCategory::Relation => "relation",
EntityCategory::Misc => "misc",
}
}
}
/// Displays the category as the lowercase name returned by `as_str`.
impl std::fmt::Display for EntityCategory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Like the nested `write!("{}")` it replaces, this ignores width/fill flags.
        f.write_str(self.as_str())
    }
}
/// The type of an extracted entity.
///
/// Built-in variants cover the common label sets; anything else is carried as
/// [`EntityType::Custom`]. The short label for each variant is given by
/// `as_label`, and `from_label` parses labels back.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum EntityType {
/// Label `PER`; category `Agent`.
Person,
/// Label `ORG`; category `Organization`.
Organization,
/// Label `LOC`; category `Place`.
Location,
/// Label `DATE`; category `Temporal`.
Date,
/// Label `TIME`; category `Temporal`.
Time,
/// Label `MONEY`; category `Numeric`.
Money,
/// Label `PERCENT`; category `Numeric`.
Percent,
/// Label `QUANTITY`; category `Numeric`.
Quantity,
/// Label `CARDINAL`; category `Numeric`.
Cardinal,
/// Label `ORDINAL`; category `Numeric`.
Ordinal,
/// Label `EMAIL`; category `Contact`.
Email,
/// Label `URL`; category `Contact`.
Url,
/// Label `PHONE`; category `Contact`.
Phone,
/// A type that is not built in.
Custom {
/// Label of the custom type; returned verbatim by `as_label`.
name: String,
/// Category the custom type belongs to; returned by `category`.
category: EntityCategory,
},
}
impl EntityType {
#[must_use]
pub fn category(&self) -> EntityCategory {
match self {
EntityType::Person => EntityCategory::Agent,
EntityType::Organization => EntityCategory::Organization,
EntityType::Location => EntityCategory::Place,
EntityType::Date | EntityType::Time => EntityCategory::Temporal,
EntityType::Money
| EntityType::Percent
| EntityType::Quantity
| EntityType::Cardinal
| EntityType::Ordinal => EntityCategory::Numeric,
EntityType::Email | EntityType::Url | EntityType::Phone => EntityCategory::Contact,
EntityType::Custom { category, .. } => *category,
}
}
#[must_use]
pub fn requires_ml(&self) -> bool {
self.category().requires_ml()
}
#[must_use]
pub fn pattern_detectable(&self) -> bool {
self.category().pattern_detectable()
}
#[must_use]
pub fn as_label(&self) -> &str {
match self {
EntityType::Person => "PER",
EntityType::Organization => "ORG",
EntityType::Location => "LOC",
EntityType::Date => "DATE",
EntityType::Time => "TIME",
EntityType::Money => "MONEY",
EntityType::Percent => "PERCENT",
EntityType::Quantity => "QUANTITY",
EntityType::Cardinal => "CARDINAL",
EntityType::Ordinal => "ORDINAL",
EntityType::Email => "EMAIL",
EntityType::Url => "URL",
EntityType::Phone => "PHONE",
EntityType::Custom { name, .. } => name.as_str(),
}
}
#[must_use]
pub fn from_label(label: &str) -> Self {
let label = label
.strip_prefix("B-")
.or_else(|| label.strip_prefix("I-"))
.or_else(|| label.strip_prefix("E-"))
.or_else(|| label.strip_prefix("S-"))
.unwrap_or(label);
match label.to_uppercase().as_str() {
"PER" | "PERSON" => EntityType::Person,
"ORG" | "ORGANIZATION" | "COMPANY" | "CORPORATION" => EntityType::Organization,
"LOC" | "LOCATION" | "GPE" | "GEO-LOC" => EntityType::Location,
"FACILITY" | "FAC" | "BUILDING" => {
EntityType::custom("BUILDING", EntityCategory::Place)
}
"PRODUCT" | "PROD" => EntityType::custom("PRODUCT", EntityCategory::Misc),
"EVENT" => EntityType::custom("EVENT", EntityCategory::Creative),
"CREATIVE-WORK" | "WORK_OF_ART" | "ART" => {
EntityType::custom("CREATIVE_WORK", EntityCategory::Creative)
}
"GROUP" | "NORP" => EntityType::custom("GROUP", EntityCategory::Agent),
"DATE" => EntityType::Date,
"TIME" => EntityType::Time,
"MONEY" | "CURRENCY" => EntityType::Money,
"PERCENT" | "PERCENTAGE" => EntityType::Percent,
"QUANTITY" => EntityType::Quantity,
"CARDINAL" => EntityType::Cardinal,
"ORDINAL" => EntityType::Ordinal,
"EMAIL" => EntityType::Email,
"URL" | "URI" => EntityType::Url,
"PHONE" | "TELEPHONE" => EntityType::Phone,
"MISC" | "MISCELLANEOUS" | "OTHER" => EntityType::custom("MISC", EntityCategory::Misc),
"DISEASE" | "DISORDER" => EntityType::custom("DISEASE", EntityCategory::Misc),
"CHEMICAL" | "DRUG" => EntityType::custom("CHEMICAL", EntityCategory::Misc),
"GENE" => EntityType::custom("GENE", EntityCategory::Misc),
"PROTEIN" => EntityType::custom("PROTEIN", EntityCategory::Misc),
other => EntityType::custom(other, EntityCategory::Misc),
}
}
#[must_use]
pub fn custom(name: impl Into<String>, category: EntityCategory) -> Self {
EntityType::Custom {
name: name.into(),
category,
}
}
}
/// Displays the type as the label returned by `as_label`.
impl std::fmt::Display for EntityType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Like the nested `write!("{}")` it replaces, this ignores width/fill flags.
        f.write_str(self.as_label())
    }
}
impl std::str::FromStr for EntityType {
type Err = std::convert::Infallible;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
Ok(Self::from_label(s))
}
}
/// Serializes the type as a plain string: the label returned by `as_label`.
///
/// Only the label is written, so a custom type's category is not part of the output.
impl Serialize for EntityType {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let label = self.as_label();
        serializer.serialize_str(label)
    }
}
impl<'de> Deserialize<'de> for EntityType {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct EntityTypeVisitor;
impl<'de> serde::de::Visitor<'de> for EntityTypeVisitor {
type Value = EntityType;
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("a string label or a tagged enum object")
}
fn visit_str<E: serde::de::Error>(self, v: &str) -> Result<EntityType, E> {
Ok(EntityType::from_label(v))
}
fn visit_map<A: serde::de::MapAccess<'de>>(
self,
mut map: A,
) -> Result<EntityType, A::Error> {
let key: String = map
.next_key()?
.ok_or_else(|| serde::de::Error::custom("empty object"))?;
match key.as_str() {
"Custom" => {
#[derive(Deserialize)]
struct CustomFields {
name: String,
category: EntityCategory,
}
let fields: CustomFields = map.next_value()?;
Ok(EntityType::Custom {
name: fields.name,
category: fields.category,
})
}
"Other" => {
let val: String = map.next_value()?;
Ok(EntityType::custom(val, EntityCategory::Misc))
}
variant => {
let _: serde::de::IgnoredAny = map.next_value()?;
Ok(EntityType::from_label(variant))
}
}
}
}
deserializer.deserialize_any(EntityTypeVisitor)
}
}
/// Maps dataset-specific labels onto [`EntityType`]s.
///
/// Lookups are case-insensitive: keys are uppercased on insertion and on query.
#[derive(Debug, Clone, Default)]
pub struct TypeMapper {
// Uppercased source label -> target entity type.
mappings: std::collections::HashMap<String, EntityType>,
}
impl TypeMapper {
    /// Creates a mapper with no mappings.
    #[must_use]
    pub fn new() -> Self {
        Self {
            mappings: std::collections::HashMap::new(),
        }
    }

    /// Preset with movie-domain labels (actors, directors, titles, ...).
    #[must_use]
    pub fn mit_movie() -> Self {
        let table = vec![
            ("ACTOR", EntityType::Person),
            ("DIRECTOR", EntityType::Person),
            ("CHARACTER", EntityType::Person),
            (
                "TITLE",
                EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
            ),
            ("GENRE", EntityType::custom("GENRE", EntityCategory::Misc)),
            ("YEAR", EntityType::Date),
            ("RATING", EntityType::custom("RATING", EntityCategory::Misc)),
            ("PLOT", EntityType::custom("PLOT", EntityCategory::Misc)),
        ];

        let mut mapper = Self::new();
        for (source, target) in table {
            mapper.add(source, target);
        }
        mapper
    }

    /// Preset with restaurant-domain labels (cuisine, dish, price, ...).
    #[must_use]
    pub fn mit_restaurant() -> Self {
        let table = vec![
            ("RESTAURANT_NAME", EntityType::Organization),
            ("LOCATION", EntityType::Location),
            (
                "CUISINE",
                EntityType::custom("CUISINE", EntityCategory::Misc),
            ),
            ("DISH", EntityType::custom("DISH", EntityCategory::Misc)),
            ("PRICE", EntityType::Money),
            (
                "AMENITY",
                EntityType::custom("AMENITY", EntityCategory::Misc),
            ),
            ("HOURS", EntityType::Time),
        ];

        let mut mapper = Self::new();
        for (source, target) in table {
            mapper.add(source, target);
        }
        mapper
    }

    /// Preset with biomedical labels (disease, chemical, gene, ...).
    #[must_use]
    pub fn biomedical() -> Self {
        let table = vec![
            // NOTE(review): DISEASE is the only entry here categorized as `Agent`,
            // while `EntityType::from_label("DISEASE")` yields `Misc` — confirm intent.
            (
                "DISEASE",
                EntityType::custom("DISEASE", EntityCategory::Agent),
            ),
            (
                "CHEMICAL",
                EntityType::custom("CHEMICAL", EntityCategory::Misc),
            ),
            ("DRUG", EntityType::custom("DRUG", EntityCategory::Misc)),
            ("GENE", EntityType::custom("GENE", EntityCategory::Misc)),
            (
                "PROTEIN",
                EntityType::custom("PROTEIN", EntityCategory::Misc),
            ),
            ("DNA", EntityType::custom("DNA", EntityCategory::Misc)),
            ("RNA", EntityType::custom("RNA", EntityCategory::Misc)),
            (
                "cell_line",
                EntityType::custom("CELL_LINE", EntityCategory::Misc),
            ),
            (
                "cell_type",
                EntityType::custom("CELL_TYPE", EntityCategory::Misc),
            ),
        ];

        let mut mapper = Self::new();
        for (source, target) in table {
            mapper.add(source, target);
        }
        mapper
    }

    /// Preset with social-media style labels (person, corporation, product, ...).
    #[must_use]
    pub fn social_media() -> Self {
        let table = vec![
            ("person", EntityType::Person),
            ("corporation", EntityType::Organization),
            ("location", EntityType::Location),
            ("group", EntityType::Organization),
            (
                "product",
                EntityType::custom("PRODUCT", EntityCategory::Misc),
            ),
            (
                "creative_work",
                EntityType::custom("WORK_OF_ART", EntityCategory::Creative),
            ),
            // NOTE(review): `EntityType::from_label("EVENT")` uses `Creative`; here it is `Misc`.
            ("event", EntityType::custom("EVENT", EntityCategory::Misc)),
        ];

        let mut mapper = Self::new();
        for (source, target) in table {
            mapper.add(source, target);
        }
        mapper
    }

    /// Preset that expands abbreviated manufacturing labels (`MATE`, `MANP`, ...)
    /// into descriptive custom types, all in category `Misc`.
    #[must_use]
    pub fn manufacturing() -> Self {
        let abbreviations = [
            ("MATE", "MATERIAL"),
            ("MANP", "PROCESS"),
            ("MACEQ", "MACHINE"),
            ("APPL", "APPLICATION"),
            ("FEAT", "FEATURE"),
            ("PARA", "PARAMETER"),
            ("PRO", "PROPERTY"),
            ("CHAR", "CHARACTERISTIC"),
            ("ENAT", "ENABLING_TECHNOLOGY"),
            ("CONPRI", "CONCEPT_PRINCIPLE"),
            ("BIOP", "BIO_PROCESS"),
            ("MANS", "MAN_STANDARD"),
        ];

        let mut mapper = Self::new();
        for (source, expanded) in abbreviations.iter() {
            mapper.add(*source, EntityType::custom(*expanded, EntityCategory::Misc));
        }
        mapper
    }

    /// Registers a mapping from `source` (stored uppercased) to `target`,
    /// replacing any existing mapping for the same label.
    pub fn add(&mut self, source: impl Into<String>, target: EntityType) {
        let key = source.into().to_uppercase();
        self.mappings.insert(key, target);
    }

    /// Looks up the mapping for `label`, ignoring case.
    #[must_use]
    pub fn map(&self, label: &str) -> Option<&EntityType> {
        let key = label.to_uppercase();
        self.mappings.get(&key)
    }

    /// Returns the mapped type for `label`, falling back to
    /// `EntityType::from_label` when no mapping is registered.
    #[must_use]
    pub fn normalize(&self, label: &str) -> EntityType {
        match self.map(label) {
            Some(mapped) => mapped.clone(),
            None => EntityType::from_label(label),
        }
    }

    /// Returns `true` if a mapping is registered for `label` (case-insensitive).
    #[must_use]
    pub fn contains(&self, label: &str) -> bool {
        self.map(label).is_some()
    }

    /// Iterates over the registered source labels in their stored (uppercased)
    /// form, in unspecified order.
    pub fn labels(&self) -> impl Iterator<Item = &String> {
        self.mappings.keys()
    }
}
/// How an entity was extracted.
///
/// The method determines how a confidence value should be read; see
/// `confidence_interpretation` and `is_calibrated` in the `impl` below.
/// The default is [`ExtractionMethod::Neural`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[non_exhaustive]
pub enum ExtractionMethod {
/// Pattern-based extraction; confidence is interpreted as `"binary"`.
Pattern,
/// Neural extraction (the default); confidence is interpreted as `"probability"`
/// and is the only method reported as calibrated.
#[default]
Neural,
/// Agreement between several sources; confidence is interpreted as `"agreement_ratio"`.
Consensus,
/// Heuristic extraction; confidence is interpreted as `"heuristic_score"`.
Heuristic,
/// Method not known; used when an entity has no provenance.
Unknown,
}
impl ExtractionMethod {
#[must_use]
pub const fn is_calibrated(&self) -> bool {
match self {
ExtractionMethod::Neural => true,
ExtractionMethod::Pattern => false,
ExtractionMethod::Consensus => false,
ExtractionMethod::Heuristic => false,
ExtractionMethod::Unknown => false,
}
}
#[must_use]
pub const fn confidence_interpretation(&self) -> &'static str {
match self {
ExtractionMethod::Neural => "probability",
ExtractionMethod::Pattern => "binary",
ExtractionMethod::Heuristic => "heuristic_score",
ExtractionMethod::Consensus => "agreement_ratio",
ExtractionMethod::Unknown => "unknown",
}
}
}
impl std::fmt::Display for ExtractionMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExtractionMethod::Pattern => write!(f, "pattern"),
ExtractionMethod::Neural => write!(f, "neural"),
ExtractionMethod::Consensus => write!(f, "consensus"),
ExtractionMethod::Heuristic => write!(f, "heuristic"),
ExtractionMethod::Unknown => write!(f, "unknown"),
}
}
}
/// A lookup table from surface text to an entity type and confidence.
///
/// Implementors must be `Send + Sync`.
pub trait Lexicon: Send + Sync {
/// Returns the entity type and confidence recorded for `text`, if any.
fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)>;
/// Returns `true` if `lookup` finds an entry for `text`.
fn contains(&self, text: &str) -> bool {
self.lookup(text).is_some()
}
/// Returns an identifier for where this lexicon came from.
fn source(&self) -> &str;
/// Returns the number of entries.
fn len(&self) -> usize;
/// Returns `true` if the lexicon has no entries.
fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// A [`Lexicon`] backed by an in-memory `HashMap`.
///
/// Keys are stored and matched exactly as given (no case folding).
#[derive(Debug, Clone)]
pub struct HashMapLexicon {
// Surface text -> (entity type, confidence).
entries: std::collections::HashMap<String, (EntityType, Confidence)>,
// Identifier returned by `Lexicon::source`.
source: String,
}
impl HashMapLexicon {
    /// Creates an empty lexicon tagged with the given source identifier.
    #[must_use]
    pub fn new(source: impl Into<String>) -> Self {
        let entries = std::collections::HashMap::new();
        Self {
            entries,
            source: source.into(),
        }
    }

    /// Adds or replaces the entry for `text`.
    pub fn insert(
        &mut self,
        text: impl Into<String>,
        entity_type: EntityType,
        confidence: impl Into<Confidence>,
    ) {
        let key = text.into();
        let value = (entity_type, confidence.into());
        self.entries.insert(key, value);
    }

    /// Builds a lexicon from `(text, entity_type, confidence)` triples.
    ///
    /// When the same text appears more than once, the last triple wins.
    pub fn from_iter<I, S, C>(source: impl Into<String>, entries: I) -> Self
    where
        I: IntoIterator<Item = (S, EntityType, C)>,
        S: Into<String>,
        C: Into<Confidence>,
    {
        let mut lexicon = Self::new(source);
        entries
            .into_iter()
            .for_each(|(text, entity_type, conf)| lexicon.insert(text, entity_type, conf));
        lexicon
    }

    /// Iterates over all entries as `(text, entity_type, confidence)`, in
    /// unspecified order.
    pub fn entries(&self) -> impl Iterator<Item = (&str, &EntityType, Confidence)> {
        self.entries
            .iter()
            .map(|(text, (entity_type, confidence))| (text.as_str(), entity_type, *confidence))
    }
}
/// Exact-match lookups against the backing map.
impl Lexicon for HashMapLexicon {
    fn lookup(&self, text: &str) -> Option<(EntityType, Confidence)> {
        let (entity_type, confidence) = self.entries.get(text)?;
        Some((entity_type.clone(), *confidence))
    }

    fn source(&self) -> &str {
        self.source.as_str()
    }

    fn len(&self) -> usize {
        self.entries.len()
    }
}
/// Records where an entity came from and how it was extracted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct Provenance {
/// Name of the producing component (e.g. `"pattern"` or a model name).
pub source: Cow<'static, str>,
/// Extraction method used.
pub method: ExtractionMethod,
/// Name of the pattern that matched, when one was involved.
pub pattern: Option<Cow<'static, str>>,
/// Confidence as reported by the source, if any.
pub raw_confidence: Option<Confidence>,
/// Optional model version; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_version: Option<Cow<'static, str>>,
/// Optional timestamp string; omitted from serialized output when `None`.
/// The format is not constrained here.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub timestamp: Option<String>,
}
impl Provenance {
    /// Provenance for a pattern match: source `"pattern"`, method
    /// `Pattern`, the pattern name recorded, and raw confidence `Confidence::ONE`.
    #[must_use]
    pub fn pattern(pattern_name: &'static str) -> Self {
        Self {
            source: Cow::Borrowed("pattern"),
            method: ExtractionMethod::Pattern,
            pattern: Some(Cow::Borrowed(pattern_name)),
            raw_confidence: Some(Confidence::ONE),
            ..Self::default()
        }
    }

    /// Provenance for a model prediction: the model name becomes the source,
    /// the method is `Neural`, and `confidence` is stored as the raw confidence.
    #[must_use]
    pub fn ml(model_name: impl Into<Cow<'static, str>>, confidence: impl Into<Confidence>) -> Self {
        let source = model_name.into();
        let raw_confidence = Some(confidence.into());
        Self {
            source,
            method: ExtractionMethod::Neural,
            raw_confidence,
            ..Self::default()
        }
    }

    /// Provenance for a combined result: `sources` is stored as the source,
    /// the method is `Consensus`, and no raw confidence is recorded.
    #[must_use]
    pub fn ensemble(sources: &'static str) -> Self {
        Self {
            source: Cow::Borrowed(sources),
            method: ExtractionMethod::Consensus,
            ..Self::default()
        }
    }

    /// Returns the provenance with `model_version` set.
    #[must_use]
    pub fn with_version(self, version: &'static str) -> Self {
        Self {
            model_version: Some(Cow::Borrowed(version)),
            ..self
        }
    }

    /// Returns the provenance with `timestamp` set.
    #[must_use]
    pub fn with_timestamp(self, timestamp: impl Into<String>) -> Self {
        Self {
            timestamp: Some(timestamp.into()),
            ..self
        }
    }
}
/// Location of an entity: in text, in a visual layout, or both.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Span {
/// A range of text offsets, from `start` up to (not including) `end`.
Text {
/// Start offset.
start: usize,
/// End offset (exclusive; the span length is `end - start`).
end: usize,
},
/// A rectangular region, optionally tied to a page.
// NOTE(review): coordinate units and origin are not defined in this file.
BoundingBox {
/// Horizontal position of the box.
x: f32,
/// Vertical position of the box.
y: f32,
/// Box width.
width: f32,
/// Box height.
height: f32,
/// Page the box is on, if known.
page: Option<u32>,
},
/// Text offsets together with a visual region.
Hybrid {
/// Start offset.
start: usize,
/// End offset (exclusive).
end: usize,
/// The visual part. Presumably a `BoundingBox`, but the type does not enforce it.
bbox: Box<Span>,
},
}
impl Span {
#[must_use]
pub const fn text(start: usize, end: usize) -> Self {
Self::Text { start, end }
}
#[must_use]
pub fn bbox(x: f32, y: f32, width: f32, height: f32) -> Self {
Self::BoundingBox {
x,
y,
width,
height,
page: None,
}
}
#[must_use]
pub fn bbox_on_page(x: f32, y: f32, width: f32, height: f32, page: u32) -> Self {
Self::BoundingBox {
x,
y,
width,
height,
page: Some(page),
}
}
#[must_use]
pub const fn is_text(&self) -> bool {
matches!(self, Self::Text { .. } | Self::Hybrid { .. })
}
#[must_use]
pub const fn is_visual(&self) -> bool {
matches!(self, Self::BoundingBox { .. } | Self::Hybrid { .. })
}
#[must_use]
pub const fn text_offsets(&self) -> Option<(usize, usize)> {
match self {
Self::Text { start, end } => Some((*start, *end)),
Self::Hybrid { start, end, .. } => Some((*start, *end)),
Self::BoundingBox { .. } => None,
}
}
#[must_use]
pub fn len(&self) -> usize {
match self {
Self::Text { start, end } => end.saturating_sub(*start),
Self::Hybrid { start, end, .. } => end.saturating_sub(*start),
Self::BoundingBox { .. } => 0,
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// A span made of one or more non-contiguous ranges.
///
/// The constructors keep `segments` normalized: non-empty, sorted by start,
/// with overlapping or touching ranges merged.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct DiscontinuousSpan {
// Normalized segments; private so the invariant above cannot be broken from outside.
segments: Vec<std::ops::Range<usize>>,
}
impl<'de> serde::Deserialize<'de> for DiscontinuousSpan {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(serde::Deserialize)]
struct Raw {
segments: Vec<std::ops::Range<usize>>,
}
let raw = Raw::deserialize(deserializer)?;
Ok(Self::new(raw.segments))
}
}
impl DiscontinuousSpan {
    /// Builds a span from arbitrary ranges.
    ///
    /// Empty or inverted ranges are dropped, the rest are sorted by start, and
    /// ranges that overlap or touch are merged into one.
    #[must_use]
    pub fn new(mut segments: Vec<std::ops::Range<usize>>) -> Self {
        segments.retain(|range| range.start < range.end);
        segments.sort_by_key(|range| range.start);

        let mut merged: Vec<std::ops::Range<usize>> = Vec::with_capacity(segments.len());
        for segment in segments {
            match merged.last_mut() {
                // Overlapping or adjacent: grow the previous segment instead of adding one.
                Some(previous) if segment.start <= previous.end => {
                    previous.end = previous.end.max(segment.end);
                }
                _ => merged.push(segment),
            }
        }
        Self { segments: merged }
    }

    /// Builds a span with at most one segment covering the two offsets.
    ///
    /// The offsets may be given in either order; equal offsets yield an empty span.
    #[must_use]
    // A one-element `vec![lo..hi]` is intended here, not a vector of the range's values.
    #[allow(clippy::single_range_in_vec_init)]
    pub fn contiguous(start: usize, end: usize) -> Self {
        let lo = start.min(end);
        let hi = start.max(end);
        let segments = if lo == hi { Vec::new() } else { vec![lo..hi] };
        Self { segments }
    }

    /// Number of segments after normalization.
    #[must_use]
    pub fn num_segments(&self) -> usize {
        self.segments.len()
    }

    /// Returns `true` if the span has more than one segment.
    #[must_use]
    pub fn is_discontinuous(&self) -> bool {
        self.num_segments() > 1
    }

    /// Returns `true` if the span has zero or one segment.
    #[must_use]
    pub fn is_contiguous(&self) -> bool {
        !self.is_discontinuous()
    }

    /// The normalized segments, sorted by start.
    #[must_use]
    pub fn segments(&self) -> &[std::ops::Range<usize>] {
        self.segments.as_slice()
    }

    /// Smallest single range covering every segment, or `None` for an empty span.
    #[must_use]
    pub fn bounding_range(&self) -> Option<std::ops::Range<usize>> {
        match (self.segments.first(), self.segments.last()) {
            (Some(first), Some(last)) => Some(first.start..last.end),
            _ => None,
        }
    }

    /// Sum of the lengths of all segments (gaps excluded).
    #[must_use]
    pub fn total_len(&self) -> usize {
        let mut total = 0;
        for segment in &self.segments {
            total += segment.end - segment.start;
        }
        total
    }

    /// Extracts the text of each segment and joins the pieces with `separator`.
    ///
    /// Segment offsets are applied as character offsets into `text`; offsets
    /// past the end of `text` simply yield fewer (or no) characters.
    #[must_use]
    pub fn extract_text(&self, text: &str, separator: &str) -> String {
        let mut pieces: Vec<String> = Vec::with_capacity(self.segments.len());
        for segment in &self.segments {
            let width = segment.end.saturating_sub(segment.start);
            let piece: String = text.chars().skip(segment.start).take(width).collect();
            pieces.push(piece);
        }
        pieces.join(separator)
    }

    /// Returns `true` if `pos` lies inside any segment.
    #[must_use]
    pub fn contains(&self, pos: usize) -> bool {
        self.segments
            .iter()
            .any(|segment| segment.start <= pos && pos < segment.end)
    }

    /// Converts the bounding range into a `Span::Text`, or `None` for an empty span.
    #[must_use]
    pub fn to_span(&self) -> Option<Span> {
        let std::ops::Range { start, end } = self.bounding_range()?;
        Some(Span::Text { start, end })
    }
}
impl From<std::ops::Range<usize>> for DiscontinuousSpan {
fn from(range: std::ops::Range<usize>) -> Self {
Self::contiguous(range.start, range.end)
}
}
impl Default for Span {
fn default() -> Self {
Self::Text { start: 0, end: 0 }
}
}
/// Confidence broken down into three separate scores.
///
/// `combined` in the `impl` below folds them into one value using the
/// geometric mean.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct HierarchicalConfidence {
/// Linkage score.
pub linkage: Confidence,
/// Score for the assigned entity type.
pub type_score: Confidence,
/// Score for the span boundaries.
pub boundary: Confidence,
}
impl HierarchicalConfidence {
    /// Creates a hierarchical confidence from its three component scores.
    #[must_use]
    pub fn new(
        linkage: impl Into<Confidence>,
        type_score: impl Into<Confidence>,
        boundary: impl Into<Confidence>,
    ) -> Self {
        let linkage = linkage.into();
        let type_score = type_score.into();
        let boundary = boundary.into();
        Self {
            linkage,
            type_score,
            boundary,
        }
    }

    /// Uses the same confidence for all three components.
    #[must_use]
    pub fn from_single(confidence: impl Into<Confidence>) -> Self {
        let shared: Confidence = confidence.into();
        Self {
            linkage: shared,
            type_score: shared,
            boundary: shared,
        }
    }

    /// Geometric mean of the three component scores.
    #[must_use]
    pub fn combined(&self) -> Confidence {
        let linkage = self.linkage.value();
        let type_score = self.type_score.value();
        let boundary = self.boundary.value();
        let geometric_mean = (linkage * type_score * boundary).powf(1.0 / 3.0);
        Confidence::new(geometric_mean)
    }

    /// The combined confidence as a plain `f64`.
    #[must_use]
    pub fn as_f64(&self) -> f64 {
        let combined = self.combined();
        combined.value()
    }

    /// Returns `true` only if every component meets its own minimum.
    #[must_use]
    pub fn passes_threshold(&self, linkage_min: f64, type_min: f64, boundary_min: f64) -> bool {
        let linkage_ok = self.linkage >= linkage_min;
        let type_ok = self.type_score >= type_min;
        let boundary_ok = self.boundary >= boundary_min;
        linkage_ok && type_ok && boundary_ok
    }
}
impl Default for HierarchicalConfidence {
fn default() -> Self {
Self {
linkage: Confidence::ONE,
type_score: Confidence::ONE,
boundary: Confidence::ONE,
}
}
}
impl From<f64> for HierarchicalConfidence {
fn from(confidence: f64) -> Self {
Self::from_single(confidence)
}
}
impl From<f32> for HierarchicalConfidence {
fn from(confidence: f32) -> Self {
Self::from_single(confidence)
}
}
impl From<Confidence> for HierarchicalConfidence {
fn from(confidence: Confidence) -> Self {
Self::from_single(confidence)
}
}
/// A batch of variable-length token sequences stored back to back, without padding.
#[derive(Debug, Clone)]
pub struct RaggedBatch {
/// All token ids of all documents, concatenated in document order.
pub token_ids: Vec<u32>,
/// Offsets into `token_ids`: document `i` occupies
/// `cumulative_offsets[i]..cumulative_offsets[i + 1]`. Holds one more entry
/// than there are documents.
pub cumulative_offsets: Vec<u32>,
/// Length of the longest document in the batch.
pub max_seq_len: usize,
}
impl RaggedBatch {
    /// Packs `sequences` into a single ragged batch.
    ///
    /// Offsets are stored as `u32`. If the running token count exceeds
    /// `u32::MAX`, a warning is logged and the offset is clamped to
    /// `u32::MAX`, so documents past that point report empty ranges.
    pub fn from_sequences(sequences: &[Vec<u32>]) -> Self {
        let total_tokens: usize = sequences.iter().map(|s| s.len()).sum();
        let mut token_ids = Vec::with_capacity(total_tokens);
        let mut cumulative_offsets = Vec::with_capacity(sequences.len() + 1);
        let mut max_seq_len = 0;
        // The first document always starts at offset 0.
        cumulative_offsets.push(0);
        for seq in sequences {
            token_ids.extend_from_slice(seq);
            let len = token_ids.len();
            if len > u32::MAX as usize {
                log::warn!(
                    "Token count {} exceeds u32::MAX, truncating to {}",
                    len,
                    u32::MAX
                );
                cumulative_offsets.push(u32::MAX);
            } else {
                cumulative_offsets.push(len as u32);
            }
            max_seq_len = max_seq_len.max(seq.len());
        }
        Self {
            token_ids,
            cumulative_offsets,
            max_seq_len,
        }
    }

    /// Number of documents in the batch.
    #[must_use]
    pub fn batch_size(&self) -> usize {
        self.cumulative_offsets.len().saturating_sub(1)
    }

    /// Total number of tokens across all documents.
    #[must_use]
    pub fn total_tokens(&self) -> usize {
        self.token_ids.len()
    }

    /// Range of `token_ids` occupied by document `doc_idx`, or `None` if the
    /// index is out of range.
    #[must_use]
    pub fn doc_range(&self, doc_idx: usize) -> Option<std::ops::Range<usize>> {
        // `checked_add` keeps `usize::MAX` from overflowing; it is simply out of range.
        let next_idx = doc_idx.checked_add(1)?;
        let start = *self.cumulative_offsets.get(doc_idx)? as usize;
        let end = *self.cumulative_offsets.get(next_idx)? as usize;
        Some(start..end)
    }

    /// Tokens of document `doc_idx`.
    ///
    /// Returns `None` if the index is out of range, or if the public fields
    /// were set to offsets that do not describe a valid slice of `token_ids`.
    #[must_use]
    pub fn doc_tokens(&self, doc_idx: usize) -> Option<&[u32]> {
        // The fields are public, so offsets are not trusted: `get` avoids a panic.
        self.doc_range(doc_idx)
            .and_then(|range| self.token_ids.get(range))
    }

    /// Fraction of a padded `batch_size x max_seq_len` layout that this ragged
    /// layout avoids storing; `0.0` for an empty batch.
    #[must_use]
    pub fn padding_savings(&self) -> f64 {
        let padded_size = self.batch_size() * self.max_seq_len;
        if padded_size == 0 {
            return 0.0;
        }
        1.0 - (self.total_tokens() as f64 / padded_size as f64)
    }
}
/// A candidate span inside one document of a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SpanCandidate {
/// Index of the document within the batch.
pub doc_idx: u32,
/// Start position within the document.
pub start: u32,
/// End position within the document (exclusive; width is `end - start`).
pub end: u32,
}
impl SpanCandidate {
#[must_use]
pub const fn new(doc_idx: u32, start: u32, end: u32) -> Self {
Self {
doc_idx,
start,
end,
}
}
#[must_use]
pub const fn width(&self) -> u32 {
self.end.saturating_sub(self.start)
}
}
pub fn generate_span_candidates(batch: &RaggedBatch, max_width: usize) -> Vec<SpanCandidate> {
let mut candidates = Vec::new();
for doc_idx in 0..batch.batch_size() {
if let Some(range) = batch.doc_range(doc_idx) {
let doc_len = range.len();
for start in 0..doc_len {
let max_end = (start + max_width).min(doc_len);
for end in (start + 1)..=max_end {
candidates.push(SpanCandidate::new(doc_idx as u32, start as u32, end as u32));
}
}
}
}
candidates
}
pub fn generate_filtered_candidates(
batch: &RaggedBatch,
max_width: usize,
linkage_mask: &[f32],
threshold: f32,
) -> Vec<SpanCandidate> {
let mut candidates = Vec::new();
let mut mask_idx = 0;
for doc_idx in 0..batch.batch_size() {
if let Some(range) = batch.doc_range(doc_idx) {
let doc_len = range.len();
for start in 0..doc_len {
let max_end = (start + max_width).min(doc_len);
for end in (start + 1)..=max_end {
if mask_idx < linkage_mask.len() && linkage_mask[mask_idx] >= threshold {
candidates.push(SpanCandidate::new(
doc_idx as u32,
start as u32,
end as u32,
));
}
mask_idx += 1;
}
}
}
}
candidates
}
/// An extracted entity: its text, type, position and confidence, plus
/// optional metadata.
///
/// `start` and `end` are private; constructors swap them when given in the
/// wrong order. They are used as character offsets when extracting text.
/// Optional fields are omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize)]
pub struct Entity {
/// Surface text of the entity.
pub text: String,
/// Type assigned to the entity.
pub entity_type: EntityType,
// Start offset; read via `start()`, changed via `set_start()`.
start: usize,
// End offset (exclusive); read via `end()`, changed via `set_end()`.
end: usize,
/// Overall confidence.
pub confidence: Confidence,
/// Normalized form of the text, if one was computed.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub normalized: Option<String>,
/// Where the entity came from and how it was extracted.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub provenance: Option<Provenance>,
/// Knowledge-base identifier, set once the entity is linked.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub kb_id: Option<String>,
/// Canonical identifier; its presence is what `has_coreference` reports.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub canonical_id: Option<super::types::CanonicalId>,
/// Per-component confidence, when available.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub hierarchical_confidence: Option<HierarchicalConfidence>,
/// Visual location of the entity, when available.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub visual_span: Option<Span>,
/// Segments of an entity that is not one contiguous range.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub discontinuous_span: Option<DiscontinuousSpan>,
/// Kind of mention, when known.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub mention_type: Option<MentionType>,
}
impl<'de> Deserialize<'de> for Entity {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
#[derive(Deserialize)]
struct EntityHelper {
text: String,
entity_type: EntityType,
start: usize,
end: usize,
confidence: Confidence,
#[serde(default)]
normalized: Option<String>,
#[serde(default)]
provenance: Option<Provenance>,
#[serde(default)]
kb_id: Option<String>,
#[serde(default)]
canonical_id: Option<super::types::CanonicalId>,
#[serde(default)]
hierarchical_confidence: Option<HierarchicalConfidence>,
#[serde(default)]
visual_span: Option<Span>,
#[serde(default)]
discontinuous_span: Option<DiscontinuousSpan>,
#[serde(default)]
mention_type: Option<MentionType>,
}
let h = EntityHelper::deserialize(deserializer)?;
let mut entity = Entity::new(h.text, h.entity_type, h.start, h.end, h.confidence);
entity.normalized = h.normalized;
entity.provenance = h.provenance;
entity.kb_id = h.kb_id;
entity.canonical_id = h.canonical_id;
entity.hierarchical_confidence = h.hierarchical_confidence;
entity.visual_span = h.visual_span;
entity.discontinuous_span = h.discontinuous_span;
entity.mention_type = h.mention_type;
Ok(entity)
}
}
impl Entity {
#[must_use]
pub fn new(
text: impl Into<String>,
entity_type: EntityType,
start: usize,
end: usize,
confidence: impl Into<Confidence>,
) -> Self {
let (start, end) = if start > end {
(end, start)
} else {
(start, end)
};
Self {
text: text.into(),
entity_type,
start,
end,
confidence: confidence.into(),
normalized: None,
provenance: None,
kb_id: None,
canonical_id: None,
hierarchical_confidence: None,
visual_span: None,
discontinuous_span: None,
mention_type: None,
}
}
#[inline]
#[must_use]
pub fn start(&self) -> usize {
self.start
}
#[inline]
#[must_use]
pub fn end(&self) -> usize {
self.end
}
#[inline]
pub fn set_start(&mut self, start: usize) {
debug_assert!(
start <= self.end,
"set_start({start}) would invert span (end={})",
self.end
);
self.start = start;
}
#[inline]
pub fn set_end(&mut self, end: usize) {
debug_assert!(
end >= self.start,
"set_end({end}) would invert span (start={})",
self.start
);
self.end = end;
}
#[must_use]
pub fn with_provenance(
text: impl Into<String>,
entity_type: EntityType,
start: usize,
end: usize,
confidence: impl Into<Confidence>,
provenance: Provenance,
) -> Self {
let (start, end) = if start > end {
(end, start)
} else {
(start, end)
};
Self {
text: text.into(),
entity_type,
start,
end,
confidence: confidence.into(),
normalized: None,
provenance: Some(provenance),
kb_id: None,
canonical_id: None,
hierarchical_confidence: None,
visual_span: None,
discontinuous_span: None,
mention_type: None,
}
}
#[must_use]
pub fn with_hierarchical_confidence(
text: impl Into<String>,
entity_type: EntityType,
start: usize,
end: usize,
confidence: HierarchicalConfidence,
) -> Self {
let (start, end) = if start > end {
(end, start)
} else {
(start, end)
};
Self {
text: text.into(),
entity_type,
start,
end,
confidence: Confidence::new(confidence.as_f64()),
normalized: None,
provenance: None,
kb_id: None,
canonical_id: None,
hierarchical_confidence: Some(confidence),
visual_span: None,
discontinuous_span: None,
mention_type: None,
}
}
#[must_use]
pub fn from_visual(
text: impl Into<String>,
entity_type: EntityType,
bbox: Span,
confidence: impl Into<Confidence>,
) -> Self {
Self {
text: text.into(),
entity_type,
start: 0,
end: 0,
confidence: confidence.into(),
normalized: None,
provenance: None,
kb_id: None,
canonical_id: None,
hierarchical_confidence: None,
visual_span: Some(bbox),
discontinuous_span: None,
mention_type: None,
}
}
#[must_use]
pub fn with_type(
text: impl Into<String>,
entity_type: EntityType,
start: usize,
end: usize,
) -> Self {
Self::new(text, entity_type, start, end, 1.0)
}
pub fn link_to_kb(&mut self, kb_id: impl Into<String>) {
self.kb_id = Some(kb_id.into());
}
pub fn set_canonical(&mut self, canonical_id: impl Into<super::types::CanonicalId>) {
self.canonical_id = Some(canonical_id.into());
}
#[must_use]
pub fn with_canonical_id(mut self, canonical_id: impl Into<super::types::CanonicalId>) -> Self {
self.canonical_id = Some(canonical_id.into());
self
}
#[must_use]
pub fn is_linked(&self) -> bool {
self.kb_id.is_some()
}
#[must_use]
pub fn has_coreference(&self) -> bool {
self.canonical_id.is_some()
}
#[must_use]
pub fn is_discontinuous(&self) -> bool {
self.discontinuous_span
.as_ref()
.map(|s| s.is_discontinuous())
.unwrap_or(false)
}
#[must_use]
pub fn discontinuous_segments(&self) -> Option<Vec<std::ops::Range<usize>>> {
self.discontinuous_span
.as_ref()
.filter(|s| s.is_discontinuous())
.map(|s| s.segments().to_vec())
}
pub fn set_discontinuous_span(&mut self, span: DiscontinuousSpan) {
if let Some(bounding) = span.bounding_range() {
self.start = bounding.start;
self.end = bounding.end;
}
self.discontinuous_span = Some(span);
}
#[must_use]
pub fn total_len(&self) -> usize {
if let Some(ref span) = self.discontinuous_span {
span.segments().iter().map(|r| r.end - r.start).sum()
} else {
self.end.saturating_sub(self.start)
}
}
pub fn set_normalized(&mut self, normalized: impl Into<String>) {
self.normalized = Some(normalized.into());
}
#[must_use]
pub fn normalized_or_text(&self) -> &str {
self.normalized.as_deref().unwrap_or(&self.text)
}
#[must_use]
pub fn method(&self) -> ExtractionMethod {
self.provenance
.as_ref()
.map_or(ExtractionMethod::Unknown, |p| p.method)
}
#[must_use]
pub fn source(&self) -> Option<&str> {
self.provenance.as_ref().map(|p| p.source.as_ref())
}
#[must_use]
pub fn category(&self) -> EntityCategory {
self.entity_type.category()
}
#[must_use]
pub fn is_structured(&self) -> bool {
self.entity_type.pattern_detectable()
}
#[must_use]
pub fn is_named(&self) -> bool {
self.entity_type.requires_ml()
}
#[must_use]
pub fn overlaps(&self, other: &Entity) -> bool {
!(self.end <= other.start || other.end <= self.start)
}
#[must_use]
pub fn overlap_ratio(&self, other: &Entity) -> f64 {
let intersection_start = self.start.max(other.start);
let intersection_end = self.end.min(other.end);
if intersection_start >= intersection_end {
return 0.0;
}
let intersection = (intersection_end - intersection_start) as f64;
let union = ((self.end - self.start) + (other.end - other.start)
- (intersection_end - intersection_start)) as f64;
if union == 0.0 {
return 1.0;
}
intersection / union
}
pub fn set_hierarchical_confidence(&mut self, confidence: HierarchicalConfidence) {
self.confidence = Confidence::new(confidence.as_f64());
self.hierarchical_confidence = Some(confidence);
}
#[must_use]
pub fn linkage_confidence(&self) -> Confidence {
self.hierarchical_confidence
.map_or(self.confidence, |h| h.linkage)
}
#[must_use]
pub fn type_confidence(&self) -> Confidence {
self.hierarchical_confidence
.map_or(self.confidence, |h| h.type_score)
}
#[must_use]
pub fn boundary_confidence(&self) -> Confidence {
self.hierarchical_confidence
.map_or(self.confidence, |h| h.boundary)
}
#[must_use]
pub fn is_visual(&self) -> bool {
self.visual_span.is_some()
}
#[must_use]
pub const fn text_span(&self) -> (usize, usize) {
(self.start, self.end)
}
#[must_use]
pub const fn span_len(&self) -> usize {
self.end.saturating_sub(self.start)
}
pub fn set_visual_span(&mut self, span: Span) {
self.visual_span = Some(span);
}
#[must_use]
pub fn extract_text(&self, source_text: &str) -> String {
let char_count = source_text.chars().count();
self.extract_text_with_len(source_text, char_count)
}
#[must_use]
pub fn extract_text_with_len(&self, source_text: &str, text_char_count: usize) -> String {
if self.start >= text_char_count || self.end > text_char_count || self.start >= self.end {
return String::new();
}
source_text
.chars()
.skip(self.start)
.take(self.end - self.start)
.collect()
}
#[must_use]
pub fn builder(text: impl Into<String>, entity_type: EntityType) -> EntityBuilder {
EntityBuilder::new(text, entity_type)
}
#[must_use]
pub fn validate(&self, source_text: &str) -> Vec<ValidationIssue> {
let char_count = source_text.chars().count();
self.validate_with_len(source_text, char_count)
}
#[must_use]
pub fn validate_with_len(
&self,
source_text: &str,
text_char_count: usize,
) -> Vec<ValidationIssue> {
let mut issues = Vec::new();
if self.start >= self.end {
issues.push(ValidationIssue::InvalidSpan {
start: self.start,
end: self.end,
reason: "start must be less than end".to_string(),
});
}
if self.end > text_char_count {
issues.push(ValidationIssue::SpanOutOfBounds {
end: self.end,
text_len: text_char_count,
});
}
if self.start < self.end && self.end <= text_char_count {
let actual = self.extract_text_with_len(source_text, text_char_count);
if actual != self.text {
issues.push(ValidationIssue::TextMismatch {
expected: self.text.clone(),
actual,
start: self.start,
end: self.end,
});
}
}
if let EntityType::Custom { ref name, .. } = self.entity_type {
if name.is_empty() {
issues.push(ValidationIssue::InvalidType {
reason: "Custom entity type has empty name".to_string(),
});
}
}
if let Some(ref disc_span) = self.discontinuous_span {
for (i, seg) in disc_span.segments().iter().enumerate() {
if seg.start >= seg.end {
issues.push(ValidationIssue::InvalidSpan {
start: seg.start,
end: seg.end,
reason: format!("discontinuous segment {} is invalid", i),
});
}
if seg.end > text_char_count {
issues.push(ValidationIssue::SpanOutOfBounds {
end: seg.end,
text_len: text_char_count,
});
}
}
}
issues
}
/// Returns `true` when `validate` reports no issues for `source_text`.
#[must_use]
pub fn is_valid(&self, source_text: &str) -> bool {
    let issues = self.validate(source_text);
    issues.is_empty()
}
/// Validates every entity in `entities` against the same `source_text`.
///
/// Returns a map from the entity's index in `entities` to its issues. Entities
/// with no issues are omitted, so an empty map means the whole batch is valid.
///
/// The `char` count of `source_text` is computed once for the whole batch and
/// passed to `validate_with_len`, instead of being recomputed per entity.
#[must_use]
pub fn validate_batch(
    entities: &[Entity],
    source_text: &str,
) -> std::collections::HashMap<usize, Vec<ValidationIssue>> {
    // Counting chars walks the whole string; it is loop-invariant, so hoist it.
    let text_char_count = source_text.chars().count();
    entities
        .iter()
        .enumerate()
        .filter_map(|(idx, entity)| {
            let issues = entity.validate_with_len(source_text, text_char_count);
            if issues.is_empty() {
                None
            } else {
                Some((idx, issues))
            }
        })
        .collect()
}
}
/// A single problem found while validating an entity.
///
/// The `Display` implementation renders each variant as a one-line,
/// human-readable message.
#[derive(Debug, Clone, PartialEq)]
pub enum ValidationIssue {
    /// A span whose bounds are not acceptable; `reason` explains why.
    InvalidSpan {
        /// Start offset of the offending span.
        start: usize,
        /// End offset of the offending span.
        end: usize,
        /// Free-form explanation of what is wrong.
        reason: String,
    },
    /// A span whose end lies beyond the length of the text.
    SpanOutOfBounds {
        /// End offset of the offending span.
        end: usize,
        /// Length of the text the span was checked against.
        text_len: usize,
    },
    /// The text stored on the entity differs from the text found at its span.
    TextMismatch {
        /// Text the entity claims to cover.
        expected: String,
        /// Text actually found at the span.
        actual: String,
        /// Start offset of the span.
        start: usize,
        /// End offset of the span.
        end: usize,
    },
    /// A confidence value outside `[0.0, 1.0]`.
    InvalidConfidence {
        /// The offending value.
        value: f64,
    },
    /// An entity type that is not acceptable; `reason` explains why.
    InvalidType {
        /// Free-form explanation of what is wrong.
        reason: String,
    },
}

impl std::fmt::Display for ValidationIssue {
    /// Writes the one-line message for this issue.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidSpan {
                start: lo,
                end: hi,
                reason: why,
            } => f.write_fmt(format_args!("Invalid span [{}, {}): {}", lo, hi, why)),
            Self::SpanOutOfBounds {
                end: hi,
                text_len: limit,
            } => f.write_fmt(format_args!(
                "Span end {} exceeds text length {}",
                hi, limit
            )),
            Self::TextMismatch {
                expected: want,
                actual: got,
                start: lo,
                end: hi,
            } => f.write_fmt(format_args!(
                "Text mismatch at [{}, {}): expected '{}', got '{}'",
                lo, hi, want, got
            )),
            Self::InvalidConfidence { value: score } => {
                f.write_fmt(format_args!("Confidence {} outside [0.0, 1.0]", score))
            }
            Self::InvalidType { reason: why } => {
                f.write_fmt(format_args!("Invalid entity type: {}", why))
            }
        }
    }
}
/// Fluent builder that accumulates the fields of an `Entity`.
///
/// Created with `EntityBuilder::new`; each setter consumes and returns the
/// builder, and `build` produces the final `Entity`.
#[derive(Debug, Clone)]
pub struct EntityBuilder {
/// Surface text of the entity.
text: String,
/// Entity type assigned to the entity.
entity_type: EntityType,
/// Span start; defaults to 0 in `new`.
start: usize,
/// Span end; defaults to the `char` count of `text` in `new`.
end: usize,
/// Overall confidence; defaults to `Confidence::ONE` in `new`.
confidence: Confidence,
/// Optional normalized form of the text.
normalized: Option<String>,
/// Optional provenance record.
provenance: Option<Provenance>,
/// Optional knowledge-base identifier.
kb_id: Option<String>,
/// Optional canonical id (presumably a coreference/cluster id; confirm against `CanonicalId`).
canonical_id: Option<super::types::CanonicalId>,
/// Optional per-level confidence breakdown.
hierarchical_confidence: Option<HierarchicalConfidence>,
/// Optional visual span.
visual_span: Option<Span>,
/// Optional discontinuous span; setting it also updates `start`/`end` when it has a bounding range.
discontinuous_span: Option<DiscontinuousSpan>,
/// Optional mention type.
mention_type: Option<MentionType>,
}
impl EntityBuilder {
    /// Starts a builder for `text` with the given entity type.
    ///
    /// The span defaults to `[0, n)` where `n` is the number of `char`s in
    /// `text`, the confidence defaults to `Confidence::ONE`, and every
    /// optional field starts out unset.
    #[must_use]
    pub fn new(text: impl Into<String>, entity_type: EntityType) -> Self {
        let text: String = text.into();
        let char_len = text.chars().count();
        Self {
            start: 0,
            end: char_len,
            text,
            entity_type,
            confidence: Confidence::ONE,
            normalized: None,
            provenance: None,
            kb_id: None,
            canonical_id: None,
            hierarchical_confidence: None,
            visual_span: None,
            discontinuous_span: None,
            mention_type: None,
        }
    }

    /// Sets the span bounds. No ordering check happens here; `build` swaps
    /// the bounds if they are inverted.
    #[must_use]
    pub const fn span(mut self, start: usize, end: usize) -> Self {
        self.end = end;
        self.start = start;
        self
    }

    /// Sets the overall confidence.
    #[must_use]
    pub fn confidence(self, confidence: impl Into<Confidence>) -> Self {
        Self {
            confidence: confidence.into(),
            ..self
        }
    }

    /// Sets the hierarchical confidence and overwrites the overall confidence
    /// with `Confidence::new(confidence.as_f64())`.
    ///
    /// Because it overwrites the scalar value, a later call to `confidence`
    /// wins over this one, and vice versa.
    #[must_use]
    pub fn hierarchical_confidence(self, confidence: HierarchicalConfidence) -> Self {
        let overall = Confidence::new(confidence.as_f64());
        Self {
            confidence: overall,
            hierarchical_confidence: Some(confidence),
            ..self
        }
    }

    /// Sets the normalized form of the entity text.
    #[must_use]
    pub fn normalized(self, normalized: impl Into<String>) -> Self {
        Self {
            normalized: Some(normalized.into()),
            ..self
        }
    }

    /// Sets the provenance record.
    #[must_use]
    pub fn provenance(self, provenance: Provenance) -> Self {
        Self {
            provenance: Some(provenance),
            ..self
        }
    }

    /// Sets the knowledge-base identifier.
    #[must_use]
    pub fn kb_id(self, kb_id: impl Into<String>) -> Self {
        Self {
            kb_id: Some(kb_id.into()),
            ..self
        }
    }

    /// Sets the canonical id, wrapping the raw value in `CanonicalId`.
    #[must_use]
    pub const fn canonical_id(mut self, canonical_id: u64) -> Self {
        let wrapped = super::types::CanonicalId::new(canonical_id);
        self.canonical_id = Some(wrapped);
        self
    }

    /// Sets the visual span.
    #[must_use]
    pub fn visual_span(self, span: Span) -> Self {
        Self {
            visual_span: Some(span),
            ..self
        }
    }

    /// Sets the discontinuous span.
    ///
    /// When the span reports a bounding range, `start` and `end` are replaced
    /// with that range's bounds; otherwise they are left untouched.
    #[must_use]
    pub fn discontinuous_span(mut self, span: DiscontinuousSpan) -> Self {
        if let Some(envelope) = span.bounding_range() {
            self.end = envelope.end;
            self.start = envelope.start;
        }
        self.discontinuous_span = Some(span);
        self
    }

    /// Sets the mention type.
    #[must_use]
    pub fn mention_type(self, mention_type: MentionType) -> Self {
        Self {
            mention_type: Some(mention_type),
            ..self
        }
    }

    /// Consumes the builder and produces the `Entity`.
    ///
    /// If `start > end`, the two bounds are swapped so the resulting entity
    /// always has `start <= end`.
    #[must_use]
    pub fn build(self) -> Entity {
        let Self {
            text,
            entity_type,
            start: raw_start,
            end: raw_end,
            confidence,
            normalized,
            provenance,
            kb_id,
            canonical_id,
            hierarchical_confidence,
            visual_span,
            discontinuous_span,
            mention_type,
        } = self;
        Entity {
            text,
            entity_type,
            start: raw_start.min(raw_end),
            end: raw_start.max(raw_end),
            confidence,
            normalized,
            provenance,
            kb_id,
            canonical_id,
            hierarchical_confidence,
            visual_span,
            discontinuous_span,
            mention_type,
        }
    }
}
/// A typed, directed relation between two entities.
///
/// `as_triple` renders it as `(head, relation_type, tail)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
/// Entity in the head position of the relation.
pub head: Entity,
/// Entity in the tail position of the relation.
pub tail: Entity,
/// Label of the relation (tests use values such as "WORKED_AT").
pub relation_type: String,
/// Optional `(start, end)` span of the trigger; `None` when built with `Relation::new`.
pub trigger_span: Option<(usize, usize)>,
/// Confidence assigned to the relation.
pub confidence: Confidence,
}
impl Relation {
    /// Creates a relation between `head` and `tail` with no trigger span.
    #[must_use]
    pub fn new(
        head: Entity,
        tail: Entity,
        relation_type: impl Into<String>,
        confidence: impl Into<Confidence>,
    ) -> Self {
        let relation_type: String = relation_type.into();
        let confidence: Confidence = confidence.into();
        Self {
            head,
            tail,
            relation_type,
            trigger_span: None,
            confidence,
        }
    }

    /// Creates a relation and records the trigger span
    /// `(trigger_start, trigger_end)`. The bounds are stored as given.
    #[must_use]
    pub fn with_trigger(
        head: Entity,
        tail: Entity,
        relation_type: impl Into<String>,
        trigger_start: usize,
        trigger_end: usize,
        confidence: impl Into<Confidence>,
    ) -> Self {
        let mut relation = Self::new(head, tail, relation_type, confidence);
        relation.trigger_span = Some((trigger_start, trigger_end));
        relation
    }

    /// Renders the relation as `"(head text, relation type, tail text)"`.
    #[must_use]
    pub fn as_triple(&self) -> String {
        let head_text = &self.head.text;
        let tail_text = &self.tail.text;
        format!("({}, {}, {})", head_text, self.relation_type, tail_text)
    }

    /// Returns the gap between the head and tail spans.
    ///
    /// This is the distance from the end of whichever span comes first to the
    /// start of the other; it is `0` when the spans touch or overlap.
    #[must_use]
    pub fn span_distance(&self) -> usize {
        // Head entirely before tail: `checked_sub` succeeds exactly when
        // `head.end <= tail.start`.
        if let Some(gap) = self.tail.start.checked_sub(self.head.end) {
            return gap;
        }
        // Tail entirely before head.
        if let Some(gap) = self.head.start.checked_sub(self.tail.end) {
            return gap;
        }
        0
    }
}
// Unit tests for the entity, relation, span, confidence, ragged-batch and
// provenance types defined in this file.
#[cfg(test)]
mod tests {
#![allow(clippy::unwrap_used)] use super::*;
// --- Entity construction and serde ---
// An inverted span (start > end) given to `Entity::new` must come back ordered.
#[test]
fn entity_new_swaps_inverted_span() {
let e = Entity::new("test", EntityType::Person, 10, 5, 0.9);
assert_eq!(e.start(), 5);
assert_eq!(e.end(), 10);
}
// The same ordering must hold for an inverted span read from JSON.
#[test]
fn entity_deserialize_swaps_inverted_span() {
let json = r#"{"text":"test","entity_type":"PER","start":10,"end":5,"confidence":0.9}"#;
let e: Entity = serde_json::from_str(json).unwrap();
assert_eq!(e.start(), 5);
assert_eq!(e.end(), 10);
}
#[test]
fn entity_serde_round_trip() {
let original = Entity::new("Berlin", EntityType::Location, 10, 16, 0.95);
let json = serde_json::to_string(&original).unwrap();
let restored: Entity = serde_json::from_str(&json).unwrap();
assert_eq!(restored.text, original.text);
assert_eq!(restored.entity_type, original.entity_type);
assert_eq!(restored.start(), original.start());
assert_eq!(restored.end(), original.end());
assert!((restored.confidence.value() - original.confidence.value()).abs() < f64::EPSILON);
}
// `as_label` followed by `from_label` must give back the same built-in type.
#[test]
fn test_entity_type_roundtrip() {
let types = [
EntityType::Person,
EntityType::Organization,
EntityType::Location,
EntityType::Date,
EntityType::Money,
EntityType::Percent,
];
for t in types {
let label = t.as_label();
let parsed = EntityType::from_label(label);
assert_eq!(t, parsed);
}
}
// Disjoint spans [0,4) and [5,10) do not overlap; each overlaps the enclosing [0,10).
#[test]
fn test_entity_overlap() {
let e1 = Entity::new("John", EntityType::Person, 0, 4, 0.9);
let e2 = Entity::new("Smith", EntityType::Person, 5, 10, 0.9);
let e3 = Entity::new("John Smith", EntityType::Person, 0, 10, 0.9);
assert!(!e1.overlaps(&e2)); assert!(e1.overlaps(&e3)); assert!(e3.overlaps(&e2)); }
// Confidence values outside [0, 1] are expected to be clamped to the nearest bound.
#[test]
fn test_confidence_clamping() {
let e1 = Entity::new("test", EntityType::Person, 0, 4, 1.5);
assert!((e1.confidence - 1.0).abs() < f64::EPSILON);
let e2 = Entity::new("test", EntityType::Person, 0, 4, -0.5);
assert!(e2.confidence.abs() < f64::EPSILON);
}
// --- Entity types and categories ---
#[test]
fn test_entity_categories() {
assert_eq!(EntityType::Person.category(), EntityCategory::Agent);
assert_eq!(
EntityType::Organization.category(),
EntityCategory::Organization
);
assert_eq!(EntityType::Location.category(), EntityCategory::Place);
assert!(EntityType::Person.requires_ml());
assert!(!EntityType::Person.pattern_detectable());
assert_eq!(EntityType::Date.category(), EntityCategory::Temporal);
assert_eq!(EntityType::Time.category(), EntityCategory::Temporal);
assert!(EntityType::Date.pattern_detectable());
assert!(!EntityType::Date.requires_ml());
assert_eq!(EntityType::Money.category(), EntityCategory::Numeric);
assert_eq!(EntityType::Percent.category(), EntityCategory::Numeric);
assert!(EntityType::Money.pattern_detectable());
assert_eq!(EntityType::Email.category(), EntityCategory::Contact);
assert_eq!(EntityType::Url.category(), EntityCategory::Contact);
assert_eq!(EntityType::Phone.category(), EntityCategory::Contact);
assert!(EntityType::Email.pattern_detectable());
}
// Label round-trip for the built-in types not covered by `test_entity_type_roundtrip`.
#[test]
fn test_new_types_roundtrip() {
let types = [
EntityType::Time,
EntityType::Email,
EntityType::Url,
EntityType::Phone,
EntityType::Quantity,
EntityType::Cardinal,
EntityType::Ordinal,
];
for t in types {
let label = t.as_label();
let parsed = EntityType::from_label(label);
assert_eq!(t, parsed, "Roundtrip failed for {}", label);
}
}
// A custom type reports its own name as label and inherits ML/pattern flags from its category.
#[test]
fn test_custom_entity_type() {
let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
assert_eq!(disease.as_label(), "DISEASE");
assert!(disease.requires_ml());
let product_id = EntityType::custom("PRODUCT_ID", EntityCategory::Misc);
assert_eq!(product_id.as_label(), "PRODUCT_ID");
assert!(!product_id.requires_ml());
assert!(!product_id.pattern_detectable());
}
// --- Entity helpers ---
// `normalized_or_text` falls back to the raw text until a normalized form is set.
#[test]
fn test_entity_normalization() {
let mut e = Entity::new("Jan 15", EntityType::Date, 0, 6, 0.95);
assert!(e.normalized.is_none());
assert_eq!(e.normalized_or_text(), "Jan 15");
e.set_normalized("2024-01-15");
assert_eq!(e.normalized.as_deref(), Some("2024-01-15"));
assert_eq!(e.normalized_or_text(), "2024-01-15");
}
#[test]
fn test_entity_helpers() {
let named = Entity::new("John", EntityType::Person, 0, 4, 0.9);
assert!(named.is_named());
assert!(!named.is_structured());
assert_eq!(named.category(), EntityCategory::Agent);
let structured = Entity::new("$100", EntityType::Money, 0, 4, 0.95);
assert!(!structured.is_named());
assert!(structured.is_structured());
assert_eq!(structured.category(), EntityCategory::Numeric);
}
// Linking sets `kb_id`; setting a canonical id makes `has_coreference` true.
#[test]
fn test_knowledge_linking() {
let mut entity = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
assert!(!entity.is_linked());
assert!(!entity.has_coreference());
entity.link_to_kb("Q7186"); assert!(entity.is_linked());
assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
entity.set_canonical(42);
assert!(entity.has_coreference());
assert_eq!(
entity.canonical_id,
Some(crate::core::types::CanonicalId::new(42))
);
}
// --- Relation ---
#[test]
fn test_relation_creation() {
let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
let relation = Relation::new(head.clone(), tail.clone(), "WORKED_AT", 0.85);
assert_eq!(relation.relation_type, "WORKED_AT");
assert_eq!(relation.as_triple(), "(Marie Curie, WORKED_AT, Sorbonne)");
assert!(relation.trigger_span.is_none());
let relation2 = Relation::with_trigger(head, tail, "EMPLOYMENT", 13, 19, 0.85);
assert_eq!(relation2.trigger_span, Some((13, 19)));
}
// Head ends at 11 and tail starts at 24, so the gap is 24 - 11 = 13.
#[test]
fn test_relation_span_distance() {
let head = Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95);
let tail = Entity::new("Sorbonne", EntityType::Organization, 24, 32, 0.90);
let relation = Relation::new(head, tail, "WORKED_AT", 0.85);
assert_eq!(relation.span_distance(), 13);
}
#[test]
fn test_relation_category() {
let rel_type = EntityType::custom("CEO_OF", EntityCategory::Relation);
assert_eq!(rel_type.category(), EntityCategory::Relation);
assert!(rel_type.category().is_relation());
assert!(rel_type.requires_ml()); }
// --- Span ---
#[test]
fn test_span_text() {
let span = Span::text(10, 20);
assert!(span.is_text());
assert!(!span.is_visual());
assert_eq!(span.text_offsets(), Some((10, 20)));
assert_eq!(span.len(), 10);
assert!(!span.is_empty());
}
// A bounding-box span has no text offsets and reports length 0.
#[test]
fn test_span_bbox() {
let span = Span::bbox(0.1, 0.2, 0.3, 0.4);
assert!(!span.is_text());
assert!(span.is_visual());
assert_eq!(span.text_offsets(), None);
assert_eq!(span.len(), 0); }
#[test]
fn test_span_bbox_with_page() {
let span = Span::bbox_on_page(0.1, 0.2, 0.3, 0.4, 5);
if let Span::BoundingBox { page, .. } = span {
assert_eq!(page, Some(5));
} else {
panic!("Expected BoundingBox");
}
}
// A hybrid span counts as both text and visual.
#[test]
fn test_span_hybrid() {
let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
let hybrid = Span::Hybrid {
start: 10,
end: 20,
bbox: Box::new(bbox),
};
assert!(hybrid.is_text());
assert!(hybrid.is_visual());
assert_eq!(hybrid.text_offsets(), Some((10, 20)));
assert_eq!(hybrid.len(), 10);
}
// --- HierarchicalConfidence ---
#[test]
fn test_hierarchical_confidence_new() {
let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
assert!((hc.linkage - 0.9).abs() < f64::EPSILON);
assert!((hc.type_score - 0.8).abs() < f64::EPSILON);
assert!((hc.boundary - 0.7).abs() < f64::EPSILON);
}
#[test]
fn test_hierarchical_confidence_clamping() {
let hc = HierarchicalConfidence::new(1.5, -0.5, 0.5);
assert_eq!(hc.linkage, 1.0);
assert_eq!(hc.type_score, 0.0);
assert_eq!(hc.boundary, 0.5);
}
#[test]
fn test_hierarchical_confidence_from_single() {
let hc = HierarchicalConfidence::from_single(0.8);
assert!((hc.linkage - 0.8).abs() < f64::EPSILON);
assert!((hc.type_score - 0.8).abs() < f64::EPSILON);
assert!((hc.boundary - 0.8).abs() < f64::EPSILON);
}
// With three equal components, `combined` is expected to equal that common value.
#[test]
fn test_hierarchical_confidence_combined() {
let hc = HierarchicalConfidence::new(1.0, 1.0, 1.0);
assert!((hc.combined() - 1.0).abs() < f64::EPSILON);
let hc2 = HierarchicalConfidence::new(0.8, 0.8, 0.8);
assert!((hc2.combined() - 0.8).abs() < 0.001);
let hc3 = HierarchicalConfidence::new(0.5, 0.5, 0.5);
assert!((hc3.combined() - 0.5).abs() < 0.001);
}
// Thresholds equal to the scores pass; raising any single threshold above its score fails.
#[test]
fn test_hierarchical_confidence_threshold() {
let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
assert!(hc.passes_threshold(0.5, 0.5, 0.5));
assert!(hc.passes_threshold(0.9, 0.8, 0.7));
assert!(!hc.passes_threshold(0.95, 0.8, 0.7)); assert!(!hc.passes_threshold(0.9, 0.85, 0.7)); }
#[test]
fn test_hierarchical_confidence_from_f64() {
let hc: HierarchicalConfidence = 0.85_f64.into();
assert!((hc.linkage - 0.85).abs() < 0.001);
}
// --- RaggedBatch and span candidates ---
// Sequences of length 3, 2 and 4 give 9 tokens and cumulative offsets [0, 3, 5, 9].
#[test]
fn test_ragged_batch_from_sequences() {
let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
let batch = RaggedBatch::from_sequences(&seqs);
assert_eq!(batch.batch_size(), 3);
assert_eq!(batch.total_tokens(), 9);
assert_eq!(batch.max_seq_len, 4);
assert_eq!(batch.cumulative_offsets, vec![0, 3, 5, 9]);
}
#[test]
fn test_ragged_batch_doc_range() {
let seqs = vec![vec![1, 2, 3], vec![4, 5]];
let batch = RaggedBatch::from_sequences(&seqs);
assert_eq!(batch.doc_range(0), Some(0..3));
assert_eq!(batch.doc_range(1), Some(3..5));
assert_eq!(batch.doc_range(2), None);
}
#[test]
fn test_ragged_batch_doc_tokens() {
let seqs = vec![vec![1, 2, 3], vec![4, 5]];
let batch = RaggedBatch::from_sequences(&seqs);
assert_eq!(batch.doc_tokens(0), Some(&[1, 2, 3][..]));
assert_eq!(batch.doc_tokens(1), Some(&[4, 5][..]));
}
// NOTE(review): 0.25 presumably is 1 - 9 / (3 * 4), i.e. tokens saved versus padding to max length; confirm against `padding_savings`.
#[test]
fn test_ragged_batch_padding_savings() {
let seqs = vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]];
let batch = RaggedBatch::from_sequences(&seqs);
let savings = batch.padding_savings();
assert!((savings - 0.25).abs() < 0.001);
}
#[test]
fn test_span_candidate() {
let sc = SpanCandidate::new(0, 5, 10);
assert_eq!(sc.doc_idx, 0);
assert_eq!(sc.start, 5);
assert_eq!(sc.end, 10);
assert_eq!(sc.width(), 5);
}
// One 3-token document with max width 2 is expected to yield 5 candidates, all inside the document.
#[test]
fn test_generate_span_candidates() {
let seqs = vec![vec![1, 2, 3]]; let batch = RaggedBatch::from_sequences(&seqs);
let candidates = generate_span_candidates(&batch, 2);
assert_eq!(candidates.len(), 5);
for c in &candidates {
assert_eq!(c.doc_idx, 0);
assert!(c.end as usize <= 3);
assert!(c.width() as usize <= 2);
}
}
// Only the two mask entries above the 0.5 threshold survive filtering.
#[test]
fn test_generate_filtered_candidates() {
let seqs = vec![vec![1, 2, 3]];
let batch = RaggedBatch::from_sequences(&seqs);
let mask = vec![0.9, 0.9, 0.1, 0.1, 0.1];
let candidates = generate_filtered_candidates(&batch, 2, &mask, 0.5);
assert_eq!(candidates.len(), 2);
}
// --- EntityBuilder ---
#[test]
fn test_entity_builder_basic() {
let entity = Entity::builder("John", EntityType::Person)
.span(0, 4)
.confidence(0.95)
.build();
assert_eq!(entity.text, "John");
assert_eq!(entity.entity_type, EntityType::Person);
assert_eq!(entity.start(), 0);
assert_eq!(entity.end(), 4);
assert!((entity.confidence - 0.95).abs() < f64::EPSILON);
}
#[test]
fn test_entity_builder_full() {
let entity = Entity::builder("Marie Curie", EntityType::Person)
.span(0, 11)
.confidence(0.95)
.kb_id("Q7186")
.canonical_id(42)
.normalized("Marie Salomea Skłodowska Curie")
.provenance(Provenance::ml("bert", 0.95))
.build();
assert_eq!(entity.text, "Marie Curie");
assert_eq!(entity.kb_id.as_deref(), Some("Q7186"));
assert_eq!(
entity.canonical_id,
Some(crate::core::types::CanonicalId::new(42))
);
assert_eq!(
entity.normalized.as_deref(),
Some("Marie Salomea Skłodowska Curie")
);
assert!(entity.provenance.is_some());
}
#[test]
fn test_entity_builder_hierarchical() {
let hc = HierarchicalConfidence::new(0.9, 0.8, 0.7);
let entity = Entity::builder("test", EntityType::Person)
.span(0, 4)
.hierarchical_confidence(hc)
.build();
assert!(entity.hierarchical_confidence.is_some());
assert!((entity.linkage_confidence() - 0.9).abs() < 0.001);
assert!((entity.type_confidence() - 0.8).abs() < 0.001);
assert!((entity.boundary_confidence() - 0.7).abs() < 0.001);
}
#[test]
fn test_entity_builder_visual() {
let bbox = Span::bbox(0.1, 0.2, 0.3, 0.4);
let entity = Entity::builder("receipt item", EntityType::Money)
.visual_span(bbox)
.confidence(0.9)
.build();
assert!(entity.is_visual());
assert!(entity.visual_span.is_some());
}
// Without a hierarchical breakdown the per-level accessors return the scalar confidence.
#[test]
fn test_entity_hierarchical_confidence_helpers() {
let mut entity = Entity::new("test", EntityType::Person, 0, 4, 0.8);
assert!((entity.linkage_confidence() - 0.8).abs() < 0.001);
assert!((entity.type_confidence() - 0.8).abs() < 0.001);
assert!((entity.boundary_confidence() - 0.8).abs() < 0.001);
entity.set_hierarchical_confidence(HierarchicalConfidence::new(0.95, 0.85, 0.75));
assert!((entity.linkage_confidence() - 0.95).abs() < 0.001);
assert!((entity.type_confidence() - 0.85).abs() < 0.001);
assert!((entity.boundary_confidence() - 0.75).abs() < 0.001);
}
// A visual-only entity is expected to have an empty text span (0, 0).
#[test]
fn test_entity_from_visual() {
let entity = Entity::from_visual(
"receipt total",
EntityType::Money,
Span::bbox(0.5, 0.8, 0.2, 0.05),
0.92,
);
assert!(entity.is_visual());
assert_eq!(entity.start(), 0);
assert_eq!(entity.end(), 0);
assert!((entity.confidence - 0.92).abs() < f64::EPSILON);
}
#[test]
fn test_entity_span_helpers() {
let entity = Entity::new("test", EntityType::Person, 10, 20, 0.9);
assert_eq!(entity.text_span(), (10, 20));
assert_eq!(entity.span_len(), 10);
}
// --- Provenance ---
#[test]
fn test_provenance_pattern() {
let prov = Provenance::pattern("EMAIL");
assert_eq!(prov.method, ExtractionMethod::Pattern);
assert_eq!(prov.pattern.as_deref(), Some("EMAIL"));
assert_eq!(prov.raw_confidence, Some(Confidence::new(1.0))); }
#[test]
fn test_provenance_ml() {
let prov = Provenance::ml("bert-ner", 0.87);
assert_eq!(prov.method, ExtractionMethod::Neural);
assert_eq!(prov.source.as_ref(), "bert-ner");
assert_eq!(prov.raw_confidence, Some(Confidence::new(0.87)));
}
#[test]
fn test_provenance_with_version() {
let prov = Provenance::ml("gliner", 0.92).with_version("v2.1.0");
assert_eq!(prov.model_version.as_deref(), Some("v2.1.0"));
assert_eq!(prov.source.as_ref(), "gliner");
}
#[test]
fn test_provenance_with_timestamp() {
let prov = Provenance::pattern("DATE").with_timestamp("2024-01-15T10:30:00Z");
assert_eq!(prov.timestamp.as_deref(), Some("2024-01-15T10:30:00Z"));
}
#[test]
fn test_provenance_builder_chain() {
let prov = Provenance::ml("modernbert-ner", 0.95)
.with_version("v1.0.0")
.with_timestamp("2024-11-27T12:00:00Z");
assert_eq!(prov.method, ExtractionMethod::Neural);
assert_eq!(prov.source.as_ref(), "modernbert-ner");
assert_eq!(prov.raw_confidence, Some(Confidence::new(0.95)));
assert_eq!(prov.model_version.as_deref(), Some("v1.0.0"));
assert_eq!(prov.timestamp.as_deref(), Some("2024-11-27T12:00:00Z"));
}
#[test]
fn test_provenance_serialization() {
let prov = Provenance::ml("test", 0.9)
.with_version("v1.0")
.with_timestamp("2024-01-01");
let json = serde_json::to_string(&prov).unwrap();
assert!(json.contains("model_version"));
assert!(json.contains("v1.0"));
let restored: Provenance = serde_json::from_str(&json).unwrap();
assert_eq!(restored.model_version.as_deref(), Some("v1.0"));
assert_eq!(restored.timestamp.as_deref(), Some("2024-01-01"));
}
// --- Serde field compatibility ---
// Serialized entities must not contain the `valid_from`, `valid_until` or `phi_features` keys.
#[test]
fn entity_serde_roundtrip_no_temporal_fields() {
let entity = Entity::new("Berlin", EntityType::Location, 0, 6, 0.95);
let json = serde_json::to_string(&entity).unwrap();
assert!(!json.contains("valid_from"));
assert!(!json.contains("valid_until"));
assert!(!json.contains("phi_features"));
let recovered: Entity = serde_json::from_str(&json).unwrap();
assert_eq!(recovered.text, "Berlin");
assert_eq!(recovered.start(), 0);
assert_eq!(recovered.end(), 6);
}
// JSON carrying those extra keys must still deserialize.
#[test]
fn entity_deserialize_ignores_unknown_fields() {
let json = r#"{"text":"Berlin","entity_type":"LOC","start":0,"end":6,"confidence":0.95,"valid_from":null,"phi_features":null}"#;
let entity: Entity = serde_json::from_str(json).unwrap();
assert_eq!(entity.text, "Berlin");
}
}
// Property-based tests (inside `proptest!`) followed by plain unit tests for
// `EntityCategory` and for `EntityType` serialization.
#[cfg(test)]
mod proptests {
#![allow(clippy::unwrap_used)] use super::*;
use proptest::prelude::*;
proptest! {
// Any input confidence in [-10, 10) must end up inside [0, 1].
#[test]
fn confidence_always_clamped(conf in -10.0f64..10.0) {
let e = Entity::new("test", EntityType::Person, 0, 4, conf);
prop_assert!(e.confidence >= 0.0);
prop_assert!(e.confidence <= 1.0);
}
// Parsing an arbitrary upper-case label twice must be stable unless the result is a custom type.
#[test]
fn entity_type_roundtrip(label in "[A-Z]{3,10}") {
let et = EntityType::from_label(&label);
let back = EntityType::from_label(et.as_label());
let is_custom = matches!(back, EntityType::Custom { .. });
prop_assert!(is_custom || back == et);
}
#[test]
fn overlap_is_symmetric(
s1 in 0usize..100,
len1 in 1usize..50,
s2 in 0usize..100,
len2 in 1usize..50,
) {
let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
prop_assert_eq!(e1.overlaps(&e2), e2.overlaps(&e1));
}
#[test]
fn overlap_ratio_bounded(
s1 in 0usize..100,
len1 in 1usize..50,
s2 in 0usize..100,
len2 in 1usize..50,
) {
let e1 = Entity::new("a", EntityType::Person, s1, s1 + len1, 1.0);
let e2 = Entity::new("b", EntityType::Person, s2, s2 + len2, 1.0);
let ratio = e1.overlap_ratio(&e2);
prop_assert!(ratio >= 0.0);
prop_assert!(ratio <= 1.0);
}
#[test]
fn self_overlap_ratio_is_one(s in 0usize..100, len in 1usize..50) {
let e = Entity::new("test", EntityType::Person, s, s + len, 1.0);
let ratio = e.overlap_ratio(&e);
prop_assert!((ratio - 1.0).abs() < 1e-10);
}
#[test]
fn hierarchical_confidence_always_clamped(
linkage in -2.0f32..2.0,
type_score in -2.0f32..2.0,
boundary in -2.0f32..2.0,
) {
let hc = HierarchicalConfidence::new(linkage, type_score, boundary);
prop_assert!(hc.linkage >= 0.0 && hc.linkage <= 1.0);
prop_assert!(hc.type_score >= 0.0 && hc.type_score <= 1.0);
prop_assert!(hc.boundary >= 0.0 && hc.boundary <= 1.0);
prop_assert!(hc.combined() >= 0.0 && hc.combined() <= 1.0);
}
// `actual_end` is forced to be >= `start` so the candidate is never inverted.
#[test]
fn span_candidate_width_consistent(
doc in 0u32..10,
start in 0u32..100,
end in 1u32..100,
) {
let actual_end = start.max(end);
let sc = SpanCandidate::new(doc, start, actual_end);
prop_assert_eq!(sc.width(), actual_end.saturating_sub(start));
}
// Builds sequences of consecutive integers so every token is unique, then checks each document reads back unchanged.
#[test]
fn ragged_batch_preserves_tokens(
seq_lens in proptest::collection::vec(1usize..10, 1..5),
) {
let mut counter = 0u32;
let seqs: Vec<Vec<u32>> = seq_lens.iter().map(|&len| {
let seq: Vec<u32> = (counter..counter + len as u32).collect();
counter += len as u32;
seq
}).collect();
let batch = RaggedBatch::from_sequences(&seqs);
prop_assert_eq!(batch.batch_size(), seqs.len());
prop_assert_eq!(batch.total_tokens(), seq_lens.iter().sum::<usize>());
for (i, seq) in seqs.iter().enumerate() {
let doc_tokens = batch.doc_tokens(i).unwrap();
prop_assert_eq!(doc_tokens, seq.as_slice());
}
}
#[test]
fn span_text_offsets_consistent(start in 0usize..100, len in 0usize..50) {
let end = start + len;
let span = Span::text(start, end);
let (s, e) = span.text_offsets().unwrap();
prop_assert_eq!(s, start);
prop_assert_eq!(e, end);
prop_assert_eq!(span.len(), len);
}
// The source text is exactly `end` chars long, so a well-formed span can never be invalid or out of bounds.
#[test]
fn entity_span_validity(
start in 0usize..10000,
len in 1usize..500,
conf in 0.0f64..=1.0,
) {
let end = start + len;
let text_content: String = "x".repeat(end);
let entity_text: String = text_content.chars().skip(start).take(len).collect();
let e = Entity::new(&entity_text, EntityType::Person, start, end, conf);
let issues = e.validate(&text_content);
for issue in &issues {
match issue {
ValidationIssue::InvalidSpan { .. } => {
prop_assert!(false, "start < end should never produce InvalidSpan");
}
ValidationIssue::SpanOutOfBounds { .. } => {
prop_assert!(false, "span within text should never produce SpanOutOfBounds");
}
_ => {} }
}
}
// `idx` ranges over all 13 entries of `standard_types` below.
#[test]
fn entity_type_label_roundtrip_standard(
idx in 0usize..13,
) {
let standard_types = [
EntityType::Person,
EntityType::Organization,
EntityType::Location,
EntityType::Date,
EntityType::Time,
EntityType::Money,
EntityType::Percent,
EntityType::Quantity,
EntityType::Cardinal,
EntityType::Ordinal,
EntityType::Email,
EntityType::Url,
EntityType::Phone,
];
let et = &standard_types[idx];
let label = et.as_label();
let roundtripped = EntityType::from_label(label);
prop_assert_eq!(&roundtripped, et,
"from_label(as_label()) must roundtrip for {:?} (label={:?})", et, label);
}
// `b_start` is placed inside span a; the overlap assertion only runs when b also ends inside a.
#[test]
fn span_containment_property(
a_start in 0usize..5000,
a_len in 1usize..5000,
b_offset in 0usize..5000,
b_len in 1usize..5000,
) {
let a_end = a_start + a_len;
let b_start = a_start + (b_offset % a_len); let b_end_candidate = b_start + b_len;
if b_start >= a_start && b_end_candidate <= a_end {
prop_assert!(a_start <= b_start);
prop_assert!(a_end >= b_end_candidate);
let ea = Entity::new("a", EntityType::Person, a_start, a_end, 1.0);
let eb = Entity::new("b", EntityType::Person, b_start, b_end_candidate, 1.0);
prop_assert!(ea.overlaps(&eb),
"containing span must overlap contained span");
}
}
// `type_idx` ranges over the 5 entries of `types` below.
#[test]
fn entity_serde_roundtrip(
start in 0usize..10000,
len in 1usize..500,
conf in 0.0f64..=1.0,
type_idx in 0usize..5,
) {
let end = start + len;
let types = [
EntityType::Person,
EntityType::Organization,
EntityType::Location,
EntityType::Date,
EntityType::Email,
];
let et = types[type_idx].clone();
let text = format!("entity_{}", start);
let e = Entity::new(&text, et, start, end, conf);
let json = serde_json::to_string(&e).unwrap();
let e2: Entity = serde_json::from_str(&json).unwrap();
prop_assert_eq!(&e.text, &e2.text);
prop_assert_eq!(&e.entity_type, &e2.entity_type);
prop_assert_eq!(e.start(), e2.start());
prop_assert_eq!(e.end(), e2.end());
prop_assert!((e.confidence - e2.confidence).abs() < 1e-10,
"confidence roundtrip: {} vs {}", e.confidence, e2.confidence);
prop_assert_eq!(&e.normalized, &e2.normalized);
prop_assert_eq!(&e.kb_id, &e2.kb_id);
}
// After construction, `total_len` must equal the summed segment lengths and adjacent segments must not overlap.
#[test]
fn discontinuous_span_total_length(
segments in proptest::collection::vec(
(0usize..5000, 1usize..500),
1..6
),
) {
let ranges: Vec<std::ops::Range<usize>> = segments.iter()
.map(|&(start, len)| start..start + len)
.collect();
let span = DiscontinuousSpan::new(ranges);
let expected_sum: usize = span.segments().iter().map(|r| r.end - r.start).sum();
prop_assert_eq!(span.total_len(), expected_sum,
"total_len must equal sum of merged segment lengths");
for w in span.segments().windows(2) {
prop_assert!(w[0].end <= w[1].start,
"segments must not overlap: {:?} vs {:?}", w[0], w[1]);
}
}
}
// --- EntityCategory predicates (plain unit tests) ---
#[test]
fn test_entity_category_requires_ml() {
assert!(EntityCategory::Agent.requires_ml());
assert!(EntityCategory::Organization.requires_ml());
assert!(EntityCategory::Place.requires_ml());
assert!(EntityCategory::Creative.requires_ml());
assert!(EntityCategory::Relation.requires_ml());
assert!(!EntityCategory::Temporal.requires_ml());
assert!(!EntityCategory::Numeric.requires_ml());
assert!(!EntityCategory::Contact.requires_ml());
assert!(!EntityCategory::Misc.requires_ml());
}
#[test]
fn test_entity_category_pattern_detectable() {
assert!(EntityCategory::Temporal.pattern_detectable());
assert!(EntityCategory::Numeric.pattern_detectable());
assert!(EntityCategory::Contact.pattern_detectable());
assert!(!EntityCategory::Agent.pattern_detectable());
assert!(!EntityCategory::Organization.pattern_detectable());
assert!(!EntityCategory::Place.pattern_detectable());
assert!(!EntityCategory::Creative.pattern_detectable());
assert!(!EntityCategory::Relation.pattern_detectable());
assert!(!EntityCategory::Misc.pattern_detectable());
}
#[test]
fn test_entity_category_is_relation() {
assert!(EntityCategory::Relation.is_relation());
assert!(!EntityCategory::Agent.is_relation());
assert!(!EntityCategory::Organization.is_relation());
assert!(!EntityCategory::Place.is_relation());
assert!(!EntityCategory::Temporal.is_relation());
assert!(!EntityCategory::Numeric.is_relation());
assert!(!EntityCategory::Contact.is_relation());
assert!(!EntityCategory::Creative.is_relation());
assert!(!EntityCategory::Misc.is_relation());
}
#[test]
fn test_entity_category_as_str() {
assert_eq!(EntityCategory::Agent.as_str(), "agent");
assert_eq!(EntityCategory::Organization.as_str(), "organization");
assert_eq!(EntityCategory::Place.as_str(), "place");
assert_eq!(EntityCategory::Creative.as_str(), "creative");
assert_eq!(EntityCategory::Temporal.as_str(), "temporal");
assert_eq!(EntityCategory::Numeric.as_str(), "numeric");
assert_eq!(EntityCategory::Contact.as_str(), "contact");
assert_eq!(EntityCategory::Relation.as_str(), "relation");
assert_eq!(EntityCategory::Misc.as_str(), "misc");
}
#[test]
fn test_entity_category_display() {
assert_eq!(format!("{}", EntityCategory::Agent), "agent");
assert_eq!(format!("{}", EntityCategory::Temporal), "temporal");
assert_eq!(format!("{}", EntityCategory::Relation), "relation");
}
// --- EntityType serde format ---
// Built-in types serialize to a bare JSON string holding their label.
#[test]
fn test_entity_type_serializes_to_flat_string() {
assert_eq!(
serde_json::to_string(&EntityType::Person).unwrap(),
r#""PER""#
);
assert_eq!(
serde_json::to_string(&EntityType::Organization).unwrap(),
r#""ORG""#
);
assert_eq!(
serde_json::to_string(&EntityType::Location).unwrap(),
r#""LOC""#
);
assert_eq!(
serde_json::to_string(&EntityType::Date).unwrap(),
r#""DATE""#
);
assert_eq!(
serde_json::to_string(&EntityType::Money).unwrap(),
r#""MONEY""#
);
}
// Custom types serialize to their name only; the category is not written.
#[test]
fn test_custom_entity_type_serializes_flat() {
let misc = EntityType::custom("MISC", EntityCategory::Misc);
assert_eq!(serde_json::to_string(&misc).unwrap(), r#""MISC""#);
let disease = EntityType::custom("DISEASE", EntityCategory::Agent);
assert_eq!(serde_json::to_string(&disease).unwrap(), r#""DISEASE""#);
}
#[test]
fn test_entity_type_deserializes_from_flat_string() {
let per: EntityType = serde_json::from_str(r#""PER""#).unwrap();
assert_eq!(per, EntityType::Person);
let org: EntityType = serde_json::from_str(r#""ORG""#).unwrap();
assert_eq!(org, EntityType::Organization);
let misc: EntityType = serde_json::from_str(r#""MISC""#).unwrap();
assert_eq!(misc, EntityType::custom("MISC", EntityCategory::Misc));
}
// The structured `{"Custom": {...}}` object form must still deserialize.
#[test]
fn test_entity_type_deserializes_backward_compat_custom() {
let json = r#"{"Custom":{"name":"MISC","category":"Misc"}}"#;
let et: EntityType = serde_json::from_str(json).unwrap();
assert_eq!(et, EntityType::custom("MISC", EntityCategory::Misc));
}
// The `{"Other": "..."}` object form maps to a custom type in the `Misc` category.
#[test]
fn test_entity_type_deserializes_backward_compat_other() {
let json = r#"{"Other":"foo"}"#;
let et: EntityType = serde_json::from_str(json).unwrap();
assert_eq!(et, EntityType::custom("foo", EntityCategory::Misc));
}
// Compares labels rather than whole values, so a custom type's category is not part of this check.
#[test]
fn test_entity_type_serde_roundtrip() {
let types = vec![
EntityType::Person,
EntityType::Organization,
EntityType::Location,
EntityType::Date,
EntityType::Time,
EntityType::Money,
EntityType::Percent,
EntityType::Quantity,
EntityType::Cardinal,
EntityType::Ordinal,
EntityType::Email,
EntityType::Url,
EntityType::Phone,
EntityType::custom("MISC", EntityCategory::Misc),
EntityType::custom("DISEASE", EntityCategory::Agent),
];
for t in &types {
let json = serde_json::to_string(t).unwrap();
let back: EntityType = serde_json::from_str(&json).unwrap();
assert_eq!(
t.as_label(),
back.as_label(),
"roundtrip failed for {:?}",
t
);
}
}
}