#![allow(missing_docs)]
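//! Gold-standard NER dataset loading for JSON, JSONL, and CoNLL-2003
//! inputs, plus span utilities for comparing predicted entities against
//! gold annotations.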
use anno::{Error, Result};
use anno_core::EntityType;
use serde::{Deserialize, Serialize};
use std::path::Path;
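/// A ground-truth (gold) entity annotation.
///
/// `start` and `end` are character offsets (not byte offsets) into the
/// source text, with `end` exclusive.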
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GoldEntity {
    pub text: String,
    pub entity_type: EntityType,
    pub original_label: String,
    pub start: usize,
    pub end: usize,
}
impl GoldEntity {
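    /// Creates a gold entity whose `end` offset is derived from `start`
    /// plus the number of characters in `text`.
    ///
    /// A minimal sketch (`ignore`d, since the import path is an assumption):
    ///
    /// ```ignore
    /// let e = GoldEntity::new("John Smith", EntityType::Person, 0);
    /// assert_eq!((e.start, e.end), (0, 10));
    /// ```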
    #[must_use]
    pub fn new(text: impl Into<String>, entity_type: EntityType, start: usize) -> Self {
        let text = text.into();
        let end = start + text.chars().count();
        Self {
            text,
            entity_type,
            original_label: String::new(),
            start,
            end,
        }
    }
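    /// Creates a gold entity with an explicit `[start, end)` character span.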
    #[must_use]
    pub fn with_span(
        text: impl Into<String>,
        entity_type: EntityType,
        start: usize,
        end: usize,
    ) -> Self {
        Self {
            text: text.into(),
            entity_type,
            original_label: String::new(),
            start,
            end,
        }
    }
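    /// Creates a gold entity that also records the dataset's original label
    /// string; `end` is derived from the character count of `text`.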
    #[must_use]
    pub fn with_label(
        text: impl Into<String>,
        entity_type: EntityType,
        original_label: impl Into<String>,
        start: usize,
    ) -> Self {
        let text = text.into();
        let end = start + text.chars().count();
        Self {
            text,
            entity_type,
            original_label: original_label.into(),
            start,
            end,
        }
    }
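    /// Creates a gold entity with every field specified explicitly.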
    #[must_use]
    pub fn full(
        text: impl Into<String>,
        entity_type: EntityType,
        original_label: impl Into<String>,
        start: usize,
        end: usize,
    ) -> Self {
        Self {
            text: text.into(),
            entity_type,
            original_label: original_label.into(),
            start,
            end,
        }
    }
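    /// Returns `true` if the two half-open character spans intersect.
    ///
    /// A sketch (`ignore`d, since the import path is an assumption):
    ///
    /// ```ignore
    /// let a = GoldEntity::with_span("John", EntityType::Person, 0, 4);
    /// let b = GoldEntity::with_span("hn S", EntityType::Person, 2, 6);
    /// assert!(a.overlaps(&b));
    /// ```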
    #[must_use]
    pub fn overlaps(&self, other: &Self) -> bool {
        self.start < other.end && other.start < self.end
    }
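    /// Extracts this entity's span from `source_text` by character offsets,
    /// returning an empty string for out-of-range or inverted spans.
    ///
    /// A sketch (`ignore`d, since the import path is an assumption):
    ///
    /// ```ignore
    /// let e = GoldEntity::new("Smith", EntityType::Person, 5);
    /// assert_eq!(e.extract_text("John Smith"), "Smith");
    /// ```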
    #[must_use]
    pub fn extract_text(&self, source_text: &str) -> String {
        let char_count = source_text.chars().count();
        if self.start >= char_count || self.end > char_count || self.start >= self.end {
            return String::new();
        }
        source_text
            .chars()
            .skip(self.start)
            .take(self.end - self.start)
            .collect()
    }
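    /// Returns `true` if both entities cover exactly the same span,
    /// regardless of entity type.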
    #[must_use]
    pub fn span_matches(&self, other: &Self) -> bool {
        self.start == other.start && self.end == other.end
    }
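    /// Returns `true` if both the span and the entity type match.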
    #[must_use]
    pub fn exact_matches(&self, other: &Self) -> bool {
        self.span_matches(other) && self.entity_type == other.entity_type
    }
}
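/// One example in the JSON/JSONL dataset format: a text plus its annotated
/// entities. The expected shape mirrors the test fixtures below:
///
/// ```json
/// {"text": "John Smith works.",
///  "entities": [{"text": "John Smith", "label": "PER", "start": 0, "end": 10}]}
/// ```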
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONNERExample {
    pub text: String,
    pub entities: Vec<JSONEntity>,
}
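/// A single entity annotation as it appears in JSON/JSONL input;
/// `confidence` defaults to `None` when the field is absent.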
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JSONEntity {
    pub text: String,
    pub label: String,
    pub start: usize,
    pub end: usize,
    #[serde(default)]
    pub confidence: Option<f64>,
}
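/// JSONL rows use the same per-example shape as the JSON array format.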
pub type JSONLNERExample = JSONNERExample;
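/// Loads a JSON (array of examples) or JSONL (one example per line) NER
/// dataset, validating every gold span against its text.
///
/// A usage sketch (`ignore`d; the file path is an assumption):
///
/// ```ignore
/// let cases = load_json_ner_dataset("gold.jsonl")?;
/// for (text, entities) in &cases {
///     println!("{}: {} entities", text, entities.len());
/// }
/// ```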
pub fn load_json_ner_dataset<P: AsRef<Path>>(path: P) -> Result<Vec<(String, Vec<GoldEntity>)>> {
    let content = std::fs::read_to_string(path.as_ref()).map_err(Error::Io)?;
    // Format detection: a JSON dataset is an array and starts with '[';
    // anything else is treated as JSONL (one object per line, blank lines
    // tolerated). Line numbers in errors are 1-based and count blank lines.
    let examples: Vec<JSONNERExample> = if content.trim_start().starts_with('[') {
        serde_json::from_str(&content)
            .map_err(|e| Error::Parse(format!("Failed to parse JSON: {}", e)))?
    } else {
        content
            .lines()
            .enumerate()
            .map(|(line_num, line)| (line_num, line.trim()))
            .filter(|(_, line)| !line.is_empty())
            .map(|(line_num, line)| {
                serde_json::from_str(line).map_err(|e| {
                    Error::Parse(format!("Failed to parse JSONL line {}: {}", line_num + 1, e))
                })
            })
            .collect::<Result<Vec<_>>>()?
    };
    let mut test_cases = Vec::new();
    for example in examples {
        let entities: Vec<GoldEntity> = example
            .entities
            .into_iter()
            .map(|e| {
                let entity_type = map_label_to_entity_type(&e.label);
                GoldEntity::full(e.text, entity_type, &e.label, e.start, e.end)
            })
            .collect();
        // Reject datasets whose gold spans do not line up with their text.
        let validation = crate::eval::validation::validate_ground_truth_entities(
            &example.text,
            &entities,
            false,
        );
        if !validation.is_valid {
            return Err(Error::InvalidInput(format!(
                "Invalid entities in dataset: {}",
                validation.errors.join("; ")
            )));
        }
        test_cases.push((example.text, entities));
    }
    Ok(test_cases)
}
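/// Loads a Hugging Face-style (`hf`) NER export; currently a thin wrapper
/// over [`load_json_ner_dataset`], since the expected shape is the same
/// JSON/JSONL format.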
pub fn load_hf_ner_dataset<P: AsRef<Path>>(path: P) -> Result<Vec<(String, Vec<GoldEntity>)>> {
    load_json_ner_dataset(path)
}
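/// Maps a raw dataset label (e.g. "PER", "ORG") to a canonical
/// [`EntityType`] via the schema table; unrecognized labels map to
/// `EntityType::Custom`.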
fn map_label_to_entity_type(label: &str) -> EntityType {
    anno::schema::map_to_canonical(label, None)
}
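/// Loads an NER dataset, auto-detecting the format: `.json`/`.jsonl` go to
/// the JSON loader; everything else tries CoNLL-2003 first and falls back
/// to JSON/JSONL.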
pub fn load_ner_dataset<P: AsRef<Path>>(path: P) -> Result<Vec<(String, Vec<GoldEntity>)>> {
    let path = path.as_ref();
    let extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("")
        .to_lowercase();
    match extension.as_str() {
        "json" | "jsonl" => load_json_ner_dataset(path),
        // CoNLL-style extensions ("conll", "conll2003", "txt") and anything
        // unrecognized: try CoNLL first, then fall back to JSON/JSONL.
        _ => load_conll_2003_dataset_internal(path).or_else(|_| load_json_ner_dataset(path)),
    }
}
/// Delegates CoNLL-2003 parsing to the shared eval loader.
fn load_conll_2003_dataset_internal<P: AsRef<Path>>(
    path: P,
) -> Result<Vec<(String, Vec<GoldEntity>)>> {
    crate::eval::load_conll2003(path)
}
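/// Summary metadata describing a loaded dataset.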
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetMetadata {
    pub name: String,
    pub format: String,
    pub language: Option<String>,
    pub entity_types: Vec<String>,
    pub num_examples: usize,
    pub source: Option<String>,
    pub year: Option<u32>,
}
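/// Builds [`DatasetMetadata`] from loaded examples, collecting the set of
/// entity-type names that actually occur.
///
/// A sketch (`ignore`d; `cases` is assumed from a prior load):
///
/// ```ignore
/// let meta = extract_dataset_metadata(&cases, "my-dataset");
/// assert_eq!(meta.num_examples, cases.len());
/// ```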
pub fn extract_dataset_metadata(
    examples: &[(String, Vec<GoldEntity>)],
    name: &str,
) -> DatasetMetadata {
    let mut entity_types = std::collections::HashSet::new();
    for (_, entities) in examples {
        for entity in entities {
            let type_str = crate::eval::entity_type_to_string(&entity.entity_type);
            entity_types.insert(type_str);
        }
    }
    DatasetMetadata {
        name: name.to_string(),
        format: "auto-detected".to_string(),
        language: None,
        entity_types: entity_types.into_iter().collect(),
        num_examples: examples.len(),
        source: None,
        year: None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    #[test]
    fn test_load_json_ner_dataset() {
        let json_content = r#"[
            {
                "text": "John Smith works at Acme Corp.",
                "entities": [
                    {"text": "John Smith", "label": "PER", "start": 0, "end": 10},
                    {"text": "Acme Corp", "label": "ORG", "start": 20, "end": 29}
                ]
            }
        ]"#;
        let temp_dir = std::env::temp_dir();
        let file_path = temp_dir.join("test_ner.json");
        let mut file = File::create(&file_path).expect("should create test file");
        file.write_all(json_content.as_bytes())
            .expect("should write test file");
        file.flush().expect("should flush test file");
        let result = load_json_ner_dataset(&file_path).expect("should load test dataset");
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].0, "John Smith works at Acme Corp.");
        assert_eq!(result[0].1.len(), 2);
        std::fs::remove_file(&file_path).ok();
    }
    #[test]
    fn test_load_jsonl_ner_dataset() {
        let jsonl_content = r#"{"text": "John Smith works.", "entities": [{"text": "John Smith", "label": "PER", "start": 0, "end": 10}]}
{"text": "Acme Corp is hiring.", "entities": [{"text": "Acme Corp", "label": "ORG", "start": 0, "end": 9}]}
"#;
        let temp_dir = std::env::temp_dir();
        let file_path = temp_dir.join("test_ner.jsonl");
        let mut file = File::create(&file_path).expect("should create test file");
        file.write_all(jsonl_content.as_bytes())
            .expect("should write test file");
        file.flush().expect("should flush test file");
        let result = load_json_ner_dataset(&file_path).expect("should load test dataset");
        assert_eq!(result.len(), 2);
        std::fs::remove_file(&file_path).ok();
    }
    #[test]
    fn test_map_label_to_entity_type() {
        assert!(matches!(
            map_label_to_entity_type("PER"),
            EntityType::Person
        ));
        assert!(matches!(
            map_label_to_entity_type("ORG"),
            EntityType::Organization
        ));
        assert!(matches!(
            map_label_to_entity_type("LOC"),
            EntityType::Location
        ));
        assert!(matches!(
            map_label_to_entity_type("MISC"),
            EntityType::Custom { .. }
        ));
        assert!(matches!(
            map_label_to_entity_type("ANIM"),
            EntityType::Custom { .. }
        ));
    }
    #[test]
    fn test_load_ner_dataset_auto_detect() {
        let json_content = r#"[{"text": "Test", "entities": []}]"#;
        let temp_dir = std::env::temp_dir();
        let file_path = temp_dir.join("test_auto.json");
        let mut file = File::create(&file_path).expect("should create test file");
        file.write_all(json_content.as_bytes())
            .expect("should write test file");
        file.flush().expect("should flush test file");
        let result = load_ner_dataset(&file_path);
        assert!(result.is_ok());
        std::fs::remove_file(&file_path).ok();
    }
}