use std::collections::HashMap;
use std::io::{self, IsTerminal};
use std::sync::atomic::{AtomicU8, Ordering};
use anno::{Entity, GroundedDocument, Location, Signal};
#[cfg(feature = "eval")]
use anno::core::grounded::{EvalComparison, EvalMatch};
static COLOR_MODE: AtomicU8 = AtomicU8::new(0);
pub fn set_color_mode(mode: crate::cli::parser::ColorMode) {
use crate::cli::parser::ColorMode;
COLOR_MODE.store(
match mode {
ColorMode::Auto => 0,
ColorMode::Always => 1,
ColorMode::Never => 2,
},
Ordering::Relaxed,
);
}
pub fn log_info(msg: &str, quiet: bool) {
if !quiet {
eprintln!("{}", msg);
}
}
pub fn format_size(bytes: u64) -> String {
const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];
let mut size = bytes as f64;
let mut unit_idx = 0;
while size >= 1024.0 && unit_idx < UNITS.len() - 1 {
size /= 1024.0;
unit_idx += 1;
}
if unit_idx == 0 {
format!("{} {}", bytes, UNITS[unit_idx])
} else {
format!("{:.2} {}", size, UNITS[unit_idx])
}
}
pub fn color(code: &str, text: &str) -> String {
let use_color = match COLOR_MODE.load(Ordering::Relaxed) {
1 => true, 2 => false, _ => io::stdout().is_terminal(), };
if use_color {
format!("\x1b[{}m{}\x1b[0m", code, text)
} else {
text.to_string()
}
}
pub fn type_color(typ: &str) -> &'static str {
match typ.to_lowercase().as_str() {
"person" | "per" => "1;34",
"organization" | "org" => "1;32",
"location" | "loc" | "gpe" => "1;33",
"date" | "time" => "1;35",
"money" | "percent" => "1;36",
"email" | "url" | "phone" => "36",
_ => "1;37",
}
}
pub fn metric_colored(value: f64) -> String {
let code = if value >= 90.0 {
"1;32"
} else if value >= 70.0 {
"1;33"
} else if value >= 50.0 {
"33"
} else {
"1;31"
};
color(code, &format!("{:5.1}", value))
}
pub fn print_signals(doc: &GroundedDocument, text: &str, verbose_level: u8) {
let mut by_type: HashMap<String, Vec<&Signal<Location>>> = HashMap::new();
for s in doc.signals() {
by_type.entry(s.label().to_string()).or_default().push(s);
}
let text_len = text.chars().count();
if by_type.is_empty() {
if verbose_level == 0 {
println!(
"(no entities found - try -v for debugging or --model gliner for zero-shot NER)"
);
} else {
println!("(no entities found)");
}
return;
}
#[derive(Debug)]
struct TypeGroup<'a> {
label: String,
signals: Vec<&'a Signal<Location>>,
min_start: usize,
}
let mut groups: Vec<TypeGroup<'_>> = by_type
.into_iter()
.map(|(label, mut signals)| {
signals.sort_by(|a, b| {
let (a_start, a_end) = a.text_offsets().unwrap_or((usize::MAX, usize::MAX));
let (b_start, b_end) = b.text_offsets().unwrap_or((usize::MAX, usize::MAX));
a_start
.cmp(&b_start)
.then_with(|| a_end.cmp(&b_end))
.then_with(|| a.surface().cmp(b.surface()))
});
let min_start = signals
.iter()
.filter_map(|s| s.text_offsets().map(|(start, _)| start))
.min()
.unwrap_or(usize::MAX);
TypeGroup {
label,
signals,
min_start,
}
})
.collect();
groups.sort_by(|a, b| {
a.min_start
.cmp(&b.min_start)
.then_with(|| a.label.cmp(&b.label))
});
if verbose_level == 0 {
for g in &groups {
let col = type_color(&g.label);
let entities: Vec<String> = g
.signals
.iter()
.map(|s| format!("\"{}\"", s.surface()))
.collect();
println!(
"{}:{} {}",
color(col, &g.label),
g.signals.len(),
entities.join(" ")
);
}
return;
}
for g in &groups {
let col = type_color(&g.label);
println!("{}:{}", color(col, &g.label), g.signals.len());
for s in &g.signals {
let (start, end) = s.text_offsets().unwrap_or((0, 0));
let conf_str = format!("({:.2})", s.confidence);
let neg = if s.negated {
color("31", " [NEG]")
} else {
String::new()
};
let quant = s
.quantifier
.map(|q| color("35", &format!(" [{:?}]", q)))
.unwrap_or_default();
print!(" \"{}\" {}", s.surface(), color("90", &conf_str));
if !neg.is_empty() || !quant.is_empty() {
print!("{}{}", neg, quant);
}
println!();
let ctx_start = start.saturating_sub(30);
let ctx_end = (end + 30).min(text_len);
let before: String = text
.chars()
.skip(ctx_start)
.take(start.saturating_sub(ctx_start))
.collect();
let entity: String = text.chars().skip(start).take(end - start).collect();
let after: String = text.chars().skip(end).take(ctx_end - end).collect();
println!(
" {}{}{}{}{}",
if ctx_start > 0 {
color("90", "...")
} else {
String::new()
},
color("90", &before),
color("1;33", &entity),
color("90", &after),
if ctx_end < text_len {
color("90", "...")
} else {
String::new()
}
);
}
}
let tracks: Vec<_> = doc.tracks().collect();
if verbose_level >= 2 && !tracks.is_empty() {
println!();
println!("{}:", color("1;36", "Coreference"));
for track in &tracks {
let track_type = track
.entity_type
.as_ref()
.map(|t| t.as_str())
.unwrap_or("-");
let mentions: Vec<String> = track
.signals
.iter()
.filter_map(|s| doc.get_signal(s.signal_id))
.map(|sig| format!("\"{}\"", sig.surface()))
.collect();
let identity_link = track
.identity_id
.map(|id| format!(" -> I{}", id))
.unwrap_or_default();
let _cluster_conf = if verbose_level >= 3 {
format!(" (conf:{:.2})", track.cluster_confidence)
} else {
String::new()
};
if mentions.len() > 1 {
println!(
" \"{}\" [{}] → {}",
track.canonical_surface,
track_type,
mentions.join(" ")
);
}
if !identity_link.is_empty() {
println!(" {}", identity_link);
}
}
}
if verbose_level >= 2 {
let stats = doc.stats();
println!();
println!(
"{}: {} entities, {} tracks, {} identities, avg confidence {:.2}",
color("90", "stats"),
stats.signal_count,
stats.track_count,
stats.identity_count,
stats.avg_confidence
);
}
if verbose_level >= 3 {
let identities: Vec<_> = doc.identities().collect();
if !identities.is_empty() {
println!();
println!("{}:", color("1;35", "Identities"));
for identity in &identities {
let kb_info =
if let (Some(kb_name), Some(kb_id)) = (&identity.kb_name, &identity.kb_id) {
format!(" [{}/{}]", kb_name, kb_id)
} else {
String::new()
};
let aliases = if !identity.aliases.is_empty() {
format!(" aliases: {}", identity.aliases.join(", "))
} else {
String::new()
};
let desc = identity
.description
.as_deref()
.map(|d| format!(" desc: \"{}\"", d))
.unwrap_or_default();
println!(
" I{}: \"{}\"{}{}{}",
identity.id, identity.canonical_name, kb_info, aliases, desc
);
}
}
println!();
println!("{}:", color("1;37", "Annotated text"));
print_annotated_signals(text, doc.signals());
}
}
pub fn print_annotated_entities(text: &str, entities: &[Entity]) {
let mut sorted: Vec<&Entity> = entities.iter().collect();
sorted.sort_by_key(|e| e.start());
let chars: Vec<char> = text.chars().collect();
let char_len = chars.len();
let mut result = String::new();
let mut last_end = 0;
for e in sorted {
if e.start() >= char_len || e.end() > char_len || e.start() >= e.end() {
continue;
}
if e.start() < last_end {
continue;
}
if e.start() > last_end {
let before: String = chars[last_end..e.start()].iter().collect();
result.push_str(&before);
}
let col = type_color(e.entity_type.as_label());
let entity_text: String = chars[e.start()..e.end()].iter().collect();
result.push_str(&color(
col,
&format!("[{}: {}]", e.entity_type.as_label(), entity_text),
));
last_end = e.end();
}
if last_end < char_len {
let after: String = chars[last_end..].iter().collect();
result.push_str(&after);
}
println!();
for line in result.lines() {
println!(" {}", line);
}
}
pub fn print_annotated_signals(text: &str, signals: &[Signal<Location>]) {
let mut sorted: Vec<&Signal<Location>> = signals.iter().collect();
sorted.sort_by_key(|s| s.text_offsets().map(|(start, _)| start).unwrap_or(0));
let chars: Vec<char> = text.chars().collect();
let char_len = chars.len();
let mut result = String::new();
let mut last_end = 0;
for s in sorted {
let (start, end) = match s.text_offsets() {
Some((start, end)) => (start, end),
None => continue,
};
if start >= char_len || end > char_len || start >= end {
continue;
}
if start < last_end {
continue;
}
if start > last_end {
let before: String = chars[last_end..start].iter().collect();
result.push_str(&before);
}
let col = type_color(s.label());
let entity_text: String = chars[start..end].iter().collect();
result.push_str(&color(col, &format!("[{}: {}]", s.label(), entity_text)));
last_end = end;
}
if last_end < char_len {
let after: String = chars[last_end..].iter().collect();
result.push_str(&after);
}
println!();
for line in result.lines() {
println!(" {}", line);
}
}
#[cfg(feature = "eval")]
pub fn print_matches(cmp: &EvalComparison, _verbose: bool) {
for m in &cmp.matches {
match m {
EvalMatch::Correct { gold_id, .. } => {
let g = cmp.gold.iter().find(|s| s.id == *gold_id);
println!(
" {} {}: [{}] \"{}\"",
color("32", "+"),
color("32", "correct"),
g.map(|s| s.label.as_str()).unwrap_or("?"),
g.map(|s| s.surface()).unwrap_or("?")
);
}
EvalMatch::TypeMismatch {
gold_id,
gold_label,
pred_label,
..
} => {
let g = cmp.gold.iter().find(|s| s.id == *gold_id);
println!(
" {} {}: \"{}\" ({} -> {})",
color("33", "!"),
color("33", "type mismatch"),
g.map(|s| s.surface()).unwrap_or("?"),
gold_label,
pred_label
);
}
EvalMatch::BoundaryError {
gold_id,
pred_id,
iou,
} => {
let g = cmp.gold.iter().find(|s| s.id == *gold_id);
let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
println!(
" {} {}: gold=\"{}\" pred=\"{}\" (IoU={:.2})",
color("33", "!"),
color("33", "boundary"),
g.map(|s| s.surface()).unwrap_or("?"),
p.map(|s| s.surface()).unwrap_or("?"),
iou
);
}
EvalMatch::Spurious { pred_id } => {
let p = cmp.predicted.iter().find(|s| s.id == *pred_id);
println!(
" {} {}: [{}] \"{}\"",
color("31", "x"),
color("31", "false positive"),
p.map(|s| s.label.as_str()).unwrap_or("?"),
p.map(|s| s.surface()).unwrap_or("?")
);
}
EvalMatch::Missed { gold_id } => {
let g = cmp.gold.iter().find(|s| s.id == *gold_id);
println!(
" {} {}: [{}] \"{}\"",
color("31", "x"),
color("31", "false negative"),
g.map(|s| s.label.as_str()).unwrap_or("?"),
g.map(|s| s.surface()).unwrap_or("?")
);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_type_color() {
assert_eq!(type_color("PER"), "1;34");
assert_eq!(type_color("person"), "1;34");
assert_eq!(type_color("ORG"), "1;32");
assert_eq!(type_color("LOC"), "1;33");
assert_eq!(type_color("UNKNOWN"), "1;37");
}
#[test]
fn test_metric_colored() {
let result = metric_colored(95.0);
assert!(result.contains("95.0"));
let result = metric_colored(75.0);
assert!(result.contains("75.0"));
let result = metric_colored(30.0);
assert!(result.contains("30.0"));
}
#[test]
fn test_color_function() {
let result = color("32", "test");
assert!(result.contains("test"));
}
}