use std::collections::HashMap;
use std::fs;
use std::io::{self, IsTerminal, Read};
use anno::{GroundedDocument, Identity, IdentityId, SignalId, TrackId};
// Build-time guard: the CLI is unusable without at least one HTML extractor,
// so fail compilation early with an actionable message instead of at runtime.
#[cfg(not(any(feature = "extractor-readability", feature = "extractor-html2text")))]
compile_error!(
"anno-cli requires at least one HTML extractor feature: \
`extractor-readability` (default, best quality) or \
`extractor-html2text` (no MPL-2.0 transitives, lower quality)"
);
/// Extract readable content from raw HTML via the readability backend.
///
/// `url`, when provided, is passed through to the extractor (presumably used
/// to resolve relative links/base URL — confirm against `deformat` docs).
#[cfg(feature = "extractor-readability")]
pub(crate) fn extract_html(html: &str, url: Option<&str>) -> deformat::Extracted {
    deformat::extract_readable(html, url)
}
/// Extract content from raw HTML via the `html2text` backend (compiled only
/// when the readability extractor is disabled). The source URL is unused by
/// this backend. The `80` is presumably an output wrap width — confirm
/// against `deformat::extract_html2text` docs.
#[cfg(all(
    not(feature = "extractor-readability"),
    feature = "extractor-html2text"
))]
pub(crate) fn extract_html(html: &str, _url: Option<&str>) -> deformat::Extracted {
    deformat::extract_html2text(html, 80)
}
/// Convenience wrapper around [`extract_html`] that keeps only the plain
/// text, discarding any other extracted metadata.
pub(crate) fn extract_html_to_text(html: &str, url: Option<&str>) -> String {
    let extracted = extract_html(html, url);
    extracted.text
}
/// Resolve the input text for a CLI invocation.
///
/// Sources are tried in precedence order: explicit `-t` text, `-f` file,
/// positional arguments (joined with spaces), and finally piped stdin.
/// Stdin is only consumed when it is NOT an interactive terminal; detected
/// HTML on stdin is converted to plain text (with a note on stderr).
///
/// Returns an error string when no source yields any input.
pub fn get_input_text(
    text: &Option<String>,
    file: Option<&str>,
    positional: &[String],
) -> Result<String, String> {
    if let Some(inline) = text.as_deref() {
        return Ok(sanitize_input(inline));
    }
    if let Some(path) = file {
        return read_input_file(path);
    }
    if !positional.is_empty() {
        return Ok(sanitize_input(&positional.join(" ")));
    }
    // Only read stdin when it is a pipe/redirect, never an interactive tty.
    if !io::stdin().is_terminal() {
        let mut buf = String::new();
        io::stdin()
            .read_to_string(&mut buf)
            .map_err(|e| format_error("read stdin", &e.to_string()))?;
        if !buf.is_empty() {
            return if deformat::detect::is_html(&buf) {
                eprintln!("note: detected HTML content on stdin, converting to text");
                Ok(extract_html_to_text(&buf, None))
            } else {
                Ok(sanitize_input(&buf))
            };
        }
    }
    Err("No input text provided. Use -t 'text' or -f file or pipe via stdin".to_string())
}
/// Strip common shell/debugging artifacts (cargo invocations, verbosity
/// flags) from pasted input, then normalize whitespace.
///
/// A path/command pattern is removed only when it sits on a word boundary:
/// at the very start of the string, or preceded and followed by a space
/// (or the end of the string).
fn sanitize_input(input: &str) -> String {
    let mut result = input.to_string();
    // NOTE(review): these look like regexes, but `str::find` below matches
    // them *literally*, so the `[^/]+` entries can never match a real path.
    // Kept as-is to avoid a regex dependency; the plain entries
    // (`cargo run`, `target/debug/`, ...) still work as intended.
    let path_patterns = [
        r"/Users/[^/]+/Documents/",
        r"/home/[^/]+/",
        r"/tmp/[^/]+/",
        r"cargo run",
        r"cargo test",
        r"target/debug/",
        r"target/release/",
    ];
    for pattern in &path_patterns {
        if let Some(pos) = result.find(pattern) {
            // Word-boundary check on the left side. The previous code used
            // `result[..pos].trim_end().ends_with(' ')`, which is always
            // false (trim_end removes the very space being tested), so
            // mid-string matches were never stripped.
            if pos == 0 || result[..pos].ends_with(' ') {
                let end_pos = pos + pattern.len();
                // Word-boundary check on the right side.
                if end_pos >= result.len() || result[end_pos..].starts_with(' ') {
                    result.replace_range(pos..end_pos, "");
                }
            }
        }
    }
    // Drop verbosity flags; longer variants first so `-vvv` isn't mangled.
    let flag_patterns = [
        (" -vvv ", " "),
        (" -vv ", " "),
        (" --verbose ", " "),
        (" -v ", " "),
    ];
    for (pattern, replacement) in &flag_patterns {
        result = result.replace(pattern, replacement);
    }
    // Collapse runs of spaces. Each pass shrinks the string ("  " -> " "),
    // so the loop terminates; matching a single space would spin forever.
    while result.contains("  ") {
        result = result.replace("  ", " ");
    }
    result.trim().to_string()
}
pub fn read_input_file(path: &str) -> Result<String, String> {
let is_pdf_ext = path.ends_with(".pdf");
let is_pdf_magic = if !is_pdf_ext {
fs::read(path)
.ok()
.map(|bytes| bytes.starts_with(b"%PDF"))
.unwrap_or(false)
} else {
false
};
if is_pdf_ext || is_pdf_magic {
#[cfg(feature = "pdf")]
{
eprintln!("note: extracting text from PDF '{}'", path);
let result = deformat::pdf::extract_file(std::path::Path::new(path))
.map_err(|e| format_error("extract PDF text", &format!("{}: {}", path, e)))?;
return Ok(result.text);
}
#[cfg(not(feature = "pdf"))]
{
return Err(format!(
"File '{}' appears to be a PDF, but the `pdf` feature is not enabled. \
Rebuild with `--features pdf` to enable PDF extraction.",
path
));
}
}
let content = fs::read_to_string(path)
.map_err(|e| format_error("read file", &format!("{}: {}", path, e)))?;
let is_html_ext = path.ends_with(".html") || path.ends_with(".htm") || path.ends_with(".xhtml");
if is_html_ext || deformat::detect::is_html(&content) {
eprintln!(
"note: detected HTML content in '{}', converting to text",
path
);
Ok(extract_html_to_text(&content, None))
} else {
Ok(content)
}
}
/// Deserialize a [`GroundedDocument`] from a JSON string, mapping serde
/// errors into the CLI's uniform error format.
pub fn parse_grounded_document(json: &str) -> Result<GroundedDocument, String> {
    match serde_json::from_str(json) {
        Ok(doc) => Ok(doc),
        Err(e) => Err(format_error("parse GroundedDocument JSON", &e.to_string())),
    }
}
/// Build the CLI's uniform `"Failed to <operation>: <details>"` error string.
pub fn format_error(operation: &str, details: &str) -> String {
    let mut message = String::with_capacity(12 + operation.len() + details.len());
    message.push_str("Failed to ");
    message.push_str(operation);
    message.push_str(": ");
    message.push_str(details);
    message
}
/// Print a green check-marked success line to stderr, unless `quiet` is set.
pub fn log_success(msg: &str, quiet: bool) {
    if quiet {
        return;
    }
    use super::output::color;
    eprintln!("{} {}", color("32", "✓"), msg);
}
/// One gold-standard entity annotation: the surface text, its label, and a
/// character span (`start`/`end` offsets as used by the evaluation code).
#[derive(Debug, Clone)]
pub struct GoldSpec {
    pub text: String,
    pub label: String,
    pub start: usize,
    pub end: usize,
}

/// Parse a `text:label:start:end` spec string.
///
/// The split runs right-to-left so the text portion may itself contain
/// colons (e.g. URLs). Returns `None` when fewer than four fields are
/// present or an offset fails to parse as `usize`.
pub fn parse_gold_spec(s: &str) -> Option<GoldSpec> {
    let mut fields = s.rsplitn(4, ':');
    let end: usize = fields.next()?.parse().ok()?;
    let start: usize = fields.next()?.parse().ok()?;
    let label = fields.next()?.to_string();
    let text = fields.next()?.to_string();
    Some(GoldSpec {
        text,
        label,
        start,
        end,
    })
}
/// Load gold annotations from a JSONL file (one JSON object per line).
///
/// Blank lines are skipped. Each object may carry an `entities` array; an
/// entity's label comes from `type` (falling back to `label`, then `UNK`).
/// Entities missing `start`/`end` default those offsets to 0 and emit a
/// warning on stderr. Returns an error if the file cannot be read or any
/// non-blank line is invalid JSON.
pub fn load_gold_from_file(path: &str) -> Result<Vec<GoldSpec>, String> {
    let content =
        fs::read_to_string(path).map_err(|e| format!("Failed to read {}: {}", path, e))?;
    let mut gold = Vec::new();
    let mut warnings = Vec::new();
    for (idx, raw) in content.lines().enumerate() {
        if raw.trim().is_empty() {
            continue;
        }
        let line_no = idx + 1;
        let entry: serde_json::Value = serde_json::from_str(raw)
            .map_err(|e| format!("Invalid JSON in gold file at line {}: {}", line_no, e))?;
        let entities = match entry["entities"].as_array() {
            Some(arr) => arr,
            None => continue,
        };
        for (i, ent) in entities.iter().enumerate() {
            // Fetch a numeric offset field; warn and default to 0 when absent.
            let mut offset = |field: &str| -> usize {
                ent[field].as_u64().map(|v| v as usize).unwrap_or_else(|| {
                    warnings.push(format!(
                        "{}:{}: entity[{}] missing '{}' field, defaulting to 0",
                        path, line_no, i, field
                    ));
                    0
                })
            };
            let start = offset("start");
            let end = offset("end");
            gold.push(GoldSpec {
                text: ent["text"].as_str().unwrap_or("").to_string(),
                label: ent["type"]
                    .as_str()
                    .or(ent["label"].as_str())
                    .unwrap_or("UNK")
                    .to_string(),
                start,
                end,
            });
        }
    }
    for warning in &warnings {
        use super::output::color;
        eprintln!("{} {}", color("33", "warning:"), warning);
    }
    Ok(gold)
}
/// Resolve coreference over `text` and record the results on `doc`.
///
/// Strategy, first success wins:
/// 1. With the `onnx` feature: an f-coref model from a local directory
///    (`FCOREF_MODEL_PATH`, then `$HOME/.cache/anno/models/fcoref`, then the
///    platform cache dir), then a `from_pretrained` fetch of
///    `biu-nlp/f-coref`.
/// 2. Fallback: the built-in mention-ranking resolver.
///
/// ONNX-path failures are logged at debug level and fall through; a fallback
/// failure prints a warning to stderr. `_signal_ids` is currently unused.
pub fn resolve_coreference(doc: &mut GroundedDocument, text: &str, _signal_ids: &[SignalId]) {
    #[cfg(feature = "onnx")]
    {
        use anno::backends::coref::fcoref::FCoref;
        // Candidate local model directories, in priority order. Entries that
        // fail to resolve (unset env vars, no cache dir) are None and skipped.
        let fcoref_paths = [
            std::env::var("FCOREF_MODEL_PATH").ok(),
            std::env::var("HOME")
                .ok()
                .map(|h| format!("{h}/.cache/anno/models/fcoref")),
            dirs::cache_dir().map(|d| d.join("anno/models/fcoref").to_string_lossy().into_owned()),
        ];
        for path in fcoref_paths.iter().flatten() {
            // Presence of encoder.onnx is the cheap "model is installed" probe.
            if std::path::Path::new(path).join("encoder.onnx").exists() {
                match FCoref::from_path(path) {
                    Ok(model) => match model.resolve(text) {
                        Ok(clusters) => {
                            fcoref_clusters_to_tracks(doc, &clusters, text);
                            return;
                        }
                        Err(e) => {
                            log::debug!("f-coref inference failed: {e}");
                        }
                    },
                    Err(e) => {
                        log::debug!("f-coref load failed from {path}: {e}");
                    }
                }
            }
        }
        // No usable local model: try downloading the pretrained one.
        match FCoref::from_pretrained("biu-nlp/f-coref") {
            Ok(model) => match model.resolve(text) {
                Ok(clusters) => {
                    fcoref_clusters_to_tracks(doc, &clusters, text);
                    return;
                }
                Err(e) => {
                    log::debug!("f-coref inference failed: {e}");
                }
            },
            Err(e) => {
                log::debug!("f-coref not available: {e}");
            }
        }
    }
    // Heuristic fallback (also the only path without the `onnx` feature).
    let coref = anno::backends::coref::mention_ranking::MentionRankingCoref::new();
    if let Err(e) = coref.resolve_into_document(text, doc) {
        use super::output::color;
        eprintln!("{} coref failed: {}", color("33", "warning:"), e);
    }
}
/// Convert f-coref clusters into signals and tracks on `doc`.
///
/// Clusters with fewer than two mentions, or whose `spans` and `mentions`
/// lists disagree in length, are skipped as uninteresting/malformed.
#[cfg(feature = "onnx")]
fn fcoref_clusters_to_tracks(
    doc: &mut GroundedDocument,
    clusters: &[anno::backends::coref::resolve::CorefCluster],
    _text: &str,
) {
    for cluster in clusters {
        if cluster.mentions.len() < 2 || cluster.spans.len() != cluster.mentions.len() {
            continue;
        }
        // One COREF_mention signal per mention; spans carry char offsets.
        // NOTE(review): the 0u64 looks like a placeholder signal id —
        // presumably reassigned by `add_signals`; confirm in anno's API.
        let signals = cluster.mentions.iter().zip(cluster.spans.iter()).map(
            |(mention, &(char_start, char_end))| {
                anno::Signal::new(
                    0u64,
                    anno::Location::text(char_start, char_end),
                    mention.clone(),
                    anno::TypeLabel::from("COREF_mention"),
                    1.0,
                )
            },
        );
        let new_ids = doc.add_signals(signals);
        doc.create_track_from_signals(cluster.canonical.clone(), &new_ids);
    }
}
/// Link every track in `doc` to a demo knowledge-base identity.
///
/// Tracks whose lowercased canonical surface matches the hard-coded demo
/// table get that entry's KB id, with the description stored as an alias;
/// everything else gets a synthesized `demo:<surface_with_underscores>` id.
/// Every track ends up linked — there is no "unlinked" outcome.
pub fn link_tracks_to_kb(doc: &mut GroundedDocument) {
    // Tiny in-memory demo KB: lowercased surface -> (kb id, description).
    let known_entities: HashMap<&str, (&str, &str)> = [
        (
            "barack obama",
            ("demo:barack_obama", "44th President of the United States"),
        ),
        (
            "angela merkel",
            ("demo:angela_merkel", "Chancellor of Germany 2005-2021"),
        ),
        ("berlin", ("demo:berlin", "Capital of Germany")),
        ("nato", ("demo:nato", "North Atlantic Treaty Organization")),
        (
            "donald trump",
            ("demo:donald_trump", "45th President of the United States"),
        ),
        (
            "joe biden",
            ("demo:joe_biden", "46th President of the United States"),
        ),
        (
            "vladimir putin",
            ("demo:vladimir_putin", "President of Russia"),
        ),
        (
            "emmanuel macron",
            ("demo:emmanuel_macron", "President of France"),
        ),
        ("lynn conway", ("demo:lynn_conway", "VLSI pioneer")),
        ("sophie wilson", ("demo:sophie_wilson", "Co-designed ARM")),
        (
            "albert einstein",
            ("demo:albert_einstein", "Theoretical physicist"),
        ),
        ("new york", ("demo:new_york", "City in New York State")),
        ("london", ("demo:london", "Capital of the United Kingdom")),
        ("paris", ("demo:paris", "Capital of France")),
        ("google", ("demo:google", "American technology company")),
        ("apple", ("demo:apple", "American technology company")),
        (
            "microsoft",
            ("demo:microsoft", "American technology company"),
        ),
        (
            "united nations",
            ("demo:united_nations", "International organization"),
        ),
        (
            "european union",
            ("demo:european_union", "Political and economic union"),
        ),
    ]
    .into_iter()
    .collect();
    // Collect ids first so we can mutate `doc` while iterating.
    let track_ids: Vec<TrackId> = doc.tracks().map(|t| t.id).collect();
    for track_id in track_ids {
        let (canonical, entity_type) = {
            let track = match doc.get_track(track_id) {
                Some(t) => t,
                None => continue,
            };
            (track.canonical_surface.clone(), track.entity_type.clone())
        };
        let canonical_lower = canonical.to_lowercase();
        if let Some(&(kb_id, description)) = known_entities.get(canonical_lower.as_str()) {
            // Known entity: reuse the curated id and keep the description
            // as an alias on the identity.
            let mut identity = Identity::from_kb(IdentityId::ZERO, &canonical, "demo", kb_id);
            identity.aliases.push(description.to_string());
            if let Some(etype) = &entity_type {
                identity.entity_type = Some(etype.clone());
            }
            let identity_id = doc.add_identity(identity);
            doc.link_track_to_identity(track_id, identity_id);
        } else {
            // Unknown entity: synthesize a demo id from the surface form.
            let kb_id = format!("demo:{}", canonical_lower.replace(' ', "_"));
            let identity = Identity::from_kb(IdentityId::ZERO, &canonical, "demo", &kb_id);
            let identity_id = doc.add_identity(identity);
            doc.link_track_to_identity(track_id, identity_id);
        }
    }
}
/// Suggest up to three model names similar to `query` (case-insensitive).
///
/// Scoring: 0.9 when one string is a prefix of the other, 0.7 when one
/// contains the other, 0.5 when first characters match, otherwise excluded.
/// Results are ordered best-first; ties keep the candidates' input order.
pub fn find_similar_models(query: &str, candidates: &[&str]) -> Vec<String> {
    let needle = query.to_lowercase();
    let mut scored: Vec<(f64, &str)> = Vec::new();
    for &candidate in candidates {
        let lowered = candidate.to_lowercase();
        let score = if lowered.starts_with(&needle) || needle.starts_with(&lowered) {
            Some(0.9)
        } else if lowered.contains(&needle) || needle.contains(&lowered) {
            Some(0.7)
        } else if lowered.chars().next() == needle.chars().next() {
            Some(0.5)
        } else {
            None
        };
        if let Some(s) = score {
            scored.push((s, candidate));
        }
    }
    // Stable sort keeps input order within equal scores.
    scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
    scored
        .into_iter()
        .take(3)
        .map(|(_, name)| name.to_string())
        .collect()
}
/// Return anno's on-disk cache directory.
///
/// With the `eval` feature, the platform cache directory plus `anno` is used
/// and created if needed; without it (or when no platform dir exists) a
/// relative `.anno-cache` path is returned without touching the filesystem.
pub fn get_cache_dir() -> Result<std::path::PathBuf, String> {
    #[cfg(feature = "eval")]
    {
        if let Some(base) = dirs::cache_dir() {
            let cache = base.join("anno");
            fs::create_dir_all(&cache)
                .map_err(|e| format!("Failed to create cache directory: {}", e))?;
            return Ok(cache);
        }
        Ok(std::path::PathBuf::from(".anno-cache"))
    }
    #[cfg(not(feature = "eval"))]
    {
        Ok(std::path::PathBuf::from(".anno-cache"))
    }
}
/// Return anno's configuration directory.
///
/// The `ANNO_CONFIG_DIR` environment variable always wins (and is created
/// if needed). Otherwise, with the `eval` feature the platform config dir
/// plus `anno` is used; the final fallback is a relative `.anno-config`.
pub fn get_config_dir() -> Result<std::path::PathBuf, String> {
    if let Ok(dir) = std::env::var("ANNO_CONFIG_DIR") {
        let path = std::path::PathBuf::from(dir);
        fs::create_dir_all(&path)
            .map_err(|e| format!("Failed to create config directory: {}", e))?;
        return Ok(path);
    }
    #[cfg(feature = "eval")]
    {
        if let Some(base) = dirs::config_dir() {
            let config = base.join("anno");
            fs::create_dir_all(&config)
                .map_err(|e| format!("Failed to create config directory: {}", e))?;
            return Ok(config);
        }
        Ok(std::path::PathBuf::from(".anno-config"))
    }
    #[cfg(not(feature = "eval"))]
    {
        Ok(std::path::PathBuf::from(".anno-config"))
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for gold-spec parsing, model-name suggestions, relation
    //! pairing (eval feature), and file/HTML/PDF input handling.
    use super::*;
    #[cfg(feature = "eval")]
    use anno::{Entity, EntityType};

    // --- parse_gold_spec ---

    #[test]
    fn test_parse_gold_spec_basic() {
        let spec = parse_gold_spec("John:PER:0:4").unwrap();
        assert_eq!(spec.text, "John");
        assert_eq!(spec.label, "PER");
        assert_eq!(spec.start, 0);
        assert_eq!(spec.end, 4);
    }

    #[test]
    fn test_parse_gold_spec_with_colon_in_text() {
        // Right-to-left splitting keeps colons inside the text portion.
        let spec = parse_gold_spec("http://example.com:URL:5:22").unwrap();
        assert_eq!(spec.text, "http://example.com");
        assert_eq!(spec.label, "URL");
        assert_eq!(spec.start, 5);
        assert_eq!(spec.end, 22);
    }

    #[test]
    fn test_parse_gold_spec_invalid() {
        assert!(parse_gold_spec("invalid").is_none());
        assert!(parse_gold_spec("only:two").is_none());
        assert!(parse_gold_spec("text:label:notanumber:4").is_none());
    }

    // --- find_similar_models ---

    #[test]
    fn test_find_similar_models() {
        let candidates = &["gliner", "gliner-candle", "heuristic", "pattern", "stacked"];
        let matches = find_similar_models("gli", candidates);
        assert!(matches.contains(&"gliner".to_string()));
        assert!(matches.contains(&"gliner-candle".to_string()));
        let matches = find_similar_models("pattern", candidates);
        assert!(matches.contains(&"pattern".to_string()));
        let matches = find_similar_models("xyz", candidates);
        assert!(matches.is_empty() || matches.len() <= 3);
    }

    // --- relation pairing (eval feature only) ---

    #[cfg(feature = "eval")]
    use anno_eval::eval::relation::create_entity_pair_relations;

    #[cfg(feature = "eval")]
    #[test]
    fn test_create_entity_pair_relations_founded() {
        let entities = vec![
            Entity::new("Steve Jobs", EntityType::Person, 0, 10, 0.9),
            Entity::new("Apple", EntityType::Organization, 30, 35, 0.9),
        ];
        let text = "Steve Jobs, who founded Apple in 1976, changed the world.";
        let relation_types = &["FOUNDED", "WORKS_FOR"];
        let relations = create_entity_pair_relations(&entities, text, relation_types);
        assert_eq!(relations.len(), 1);
        assert_eq!(relations[0].relation_type, "FOUNDED");
    }

    #[cfg(feature = "eval")]
    #[test]
    fn test_create_entity_pair_relations_unknown_fallback() {
        let entities = vec![
            Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
            Entity::new("Bob", EntityType::Person, 15, 18, 0.9),
        ];
        let text = "Alice met Bob yesterday.";
        let relation_types = &["FOUNDED", "WORKS_FOR"];
        let relations = create_entity_pair_relations(&entities, text, relation_types);
        assert_eq!(relations.len(), 1);
        assert_eq!(
            relations[0].relation_type, "UNKNOWN",
            "Unknown relations should use UNKNOWN, not first gold type"
        );
    }

    #[cfg(feature = "eval")]
    #[test]
    fn test_create_entity_pair_relations_max_distance() {
        // Entities 300 chars apart should be too distant to pair.
        let entities = vec![
            Entity::new("Alice", EntityType::Person, 0, 5, 0.9),
            Entity::new("Bob", EntityType::Person, 300, 303, 0.9),
        ];
        let text = &"x".repeat(400);
        let relation_types = &["RELATED"];
        let relations = create_entity_pair_relations(&entities, text, relation_types);
        assert!(relations.is_empty());
    }

    #[cfg(feature = "eval")]
    #[test]
    fn test_create_entity_pair_relations_overlapping_entities() {
        // Overlapping spans must not produce a relation with themselves.
        let entities = vec![
            Entity::new("New York", EntityType::Location, 0, 8, 0.9),
            Entity::new("New York City", EntityType::Location, 0, 13, 0.9),
        ];
        let text = "New York City is great.";
        let relation_types = &["LOCATED_IN"];
        let relations = create_entity_pair_relations(&entities, text, relation_types);
        assert!(relations.is_empty());
    }

    // --- read_input_file ---

    #[test]
    fn read_input_file_html_strips_nav_footer() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.html");
        fs::write(
            &path,
            r#"<!DOCTYPE html>
<html><head><title>Test</title></head>
<body>
<nav>Menu</nav>
<p>Angela Merkel met Emmanuel Macron in Berlin.</p>
<footer>Copyright</footer>
</body></html>"#,
        )
        .unwrap();
        let text = read_input_file(path.to_str().unwrap()).unwrap();
        assert!(
            text.contains("Angela Merkel"),
            "should extract person names, got: {}",
            text
        );
        assert!(
            text.contains("Berlin"),
            "should extract location, got: {}",
            text
        );
        assert!(
            !text.contains("Menu"),
            "nav content should be stripped, got: {}",
            text
        );
        assert!(
            !text.contains("Copyright"),
            "footer content should be stripped, got: {}",
            text
        );
    }

    #[test]
    fn read_input_file_plain_text_passthrough() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.txt");
        fs::write(&path, "Tim Cook met Sundar Pichai in Seattle.").unwrap();
        let text = read_input_file(path.to_str().unwrap()).unwrap();
        assert_eq!(text, "Tim Cook met Sundar Pichai in Seattle.");
    }

    #[test]
    fn read_input_file_html_detected_by_content() {
        // HTML sniffing must work even without an .html extension.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("page.dat");
        fs::write(&path, "<html><body><p>Hello World</p></body></html>").unwrap();
        let text = read_input_file(path.to_str().unwrap()).unwrap();
        assert!(text.contains("Hello World"));
        assert!(!text.contains("<p>"));
    }

    #[test]
    fn read_input_file_nonexistent_errors() {
        let result = read_input_file("/tmp/anno-does-not-exist-12345.txt");
        assert!(result.is_err());
    }

    #[cfg(feature = "pdf")]
    #[test]
    fn read_input_file_pdf_extension_detected() {
        // A .pdf extension routes to the PDF extractor, which should fail
        // on non-PDF bytes.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("fake.pdf");
        fs::write(&path, "not a real pdf").unwrap();
        let result = read_input_file(path.to_str().unwrap());
        assert!(result.is_err());
    }

    // --- HTML detection / extraction ---

    #[cfg(feature = "extractor-readability")]
    #[test]
    fn html_content_detected_for_stripping() {
        let html = r#"<!DOCTYPE html>
<html><head><title>Test</title></head>
<body>
<nav>Skip this</nav>
<article><p>Marie Curie discovered radium at the University of Paris.</p></article>
<footer>Copyright</footer>
</body></html>"#;
        assert!(
            deformat::detect::is_html(html),
            "should detect HTML content"
        );
        let text = super::extract_html(html, None).text;
        assert!(text.contains("Marie Curie"), "should extract article text");
        assert!(
            text.contains("University of Paris"),
            "should extract org name"
        );
        assert!(!text.contains("<nav>"), "should not contain HTML tags");
        assert!(
            !text.contains("Skip this"),
            "nav content should be stripped"
        );
    }

    #[test]
    fn plain_text_not_detected_as_html() {
        let text = "Angela Merkel met Emmanuel Macron in Berlin on Tuesday.";
        assert!(
            !deformat::detect::is_html(text),
            "plain text should not trigger HTML detection"
        );
    }

    #[cfg(not(feature = "pdf"))]
    #[test]
    fn read_input_file_pdf_without_feature_errors() {
        // %PDF magic without the pdf feature must produce an actionable error.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("test.pdf");
        fs::write(&path, "%PDF-1.4 fake").unwrap();
        let result = read_input_file(path.to_str().unwrap());
        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("pdf"),
            "should mention pdf feature"
        );
    }
}