use clap::Parser;
use std::fs;
use super::super::output::{color, print_signals};
use super::super::parser::{ModelBackend, OutputFormat};
use super::super::utils::{link_tracks_to_kb, resolve_coreference};
#[cfg(feature = "eval")]
use anno::{Entity, EntityType};
use anno::{GroundedDocument, Signal, SignalId};
// Command-line arguments for the pipeline subcommand: entity extraction
// over one or more inputs, with optional coreference resolution, KB
// linking, and cross-document clustering.
//
// NOTE(review): plain `//` comments are used deliberately here — `///`
// doc comments on clap-derive fields are emitted as `--help` text and
// would change the CLI's user-visible output.
#[derive(Parser, Debug)]
pub struct PipelineArgs {
    // Positional input text(s); each becomes its own document
    // with ids "text1", "text2", ...
    pub text: Vec<String>,
    // Files to read as input documents (doc id = file stem).
    #[arg(short, long, value_name = "PATH", visible_alias = "input")]
    pub files: Vec<String>,
    // Directory whose .txt / .md files are read as input documents.
    #[arg(short, long, value_name = "DIR", visible_alias = "input-dir")]
    pub dir: Option<String>,
    // Extraction model backend.
    #[arg(short, long, default_value = "stacked")]
    pub model: ModelBackend,
    // Run within-document coreference resolution after extraction.
    #[arg(long)]
    pub coref: bool,
    // Link resolved tracks to the knowledge base.
    #[arg(long)]
    pub link_kb: bool,
    // Cluster entities across documents (requires the "eval" feature).
    #[arg(long)]
    pub cross_doc: bool,
    // Minimum similarity used for cross-document clustering.
    #[arg(long, default_value = "0.6")]
    pub threshold: f64,
    // Output rendering format.
    #[arg(long, default_value = "human")]
    pub format: OutputFormat,
    // Write output to this path instead of stdout.
    #[arg(short, long, value_name = "PATH")]
    pub output: Option<String>,
    // Show a progress bar while processing (suppressed by --quiet).
    #[arg(long)]
    pub progress: bool,
    // Suppress non-essential output.
    #[arg(short, long)]
    pub quiet: bool,
}
/// Executes the extraction pipeline end-to-end.
///
/// Steps:
/// 1. Gather `(doc_id, text)` inputs from positional TEXT, `--files`,
///    and `--dir` (see `collect_inputs`).
/// 2. Extract entities per document with the selected model backend,
///    optionally applying coreference resolution (`--coref`) and KB
///    linking (`--link-kb`).
/// 3. Render results: cross-document clusters when `--cross-doc` is set
///    (requires the `eval` feature), otherwise per-document output, as
///    JSON / tree / human-readable text, to stdout or `--output`.
///
/// # Errors
/// Returns a descriptive message when no input is supplied, a file or
/// directory cannot be read, model creation or extraction fails, or
/// serializing / writing the output fails.
pub fn run(args: PipelineArgs) -> Result<(), String> {
    let texts = collect_inputs(&args)?;
    if texts.is_empty() {
        return Err(
            "No input provided. Use positional TEXT, --files/--input, or --dir/--input-dir."
                .to_string(),
        );
    }
    let model = args.model.create_model()?;
    let mut documents: Vec<GroundedDocument> = Vec::new();
    // Progress bar is opt-in and suppressed by --quiet.
    let pb = if args.progress && !args.quiet {
        use indicatif::{ProgressBar, ProgressStyle};
        let pb = ProgressBar::new(texts.len() as u64);
        let style = ProgressStyle::default_bar()
            .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} {msg}")
            .expect("Progress bar template should be valid");
        pb.set_style(style.progress_chars("#>-"));
        Some(pb)
    } else {
        None
    };
    for (doc_id, text) in &texts {
        if let Some(ref pb) = pb {
            pb.set_message(format!("Processing {}", doc_id));
        }
        let entities = model
            .extract_entities(text, None)
            .map_err(|e| format!("Extraction failed for {}: {}", doc_id, e))?;
        let mut doc = GroundedDocument::new(doc_id, text);
        // Keep the ids of the added signals so coreference resolution
        // can reference them.
        let mut signal_ids: Vec<SignalId> = Vec::new();
        for e in &entities {
            let id = doc.add_signal(Signal::from(e));
            signal_ids.push(id);
        }
        if args.coref {
            resolve_coreference(&mut doc, text, &signal_ids);
        }
        if args.link_kb {
            link_tracks_to_kb(&mut doc);
        }
        documents.push(doc);
        if let Some(ref pb) = pb {
            pb.inc(1);
        }
    }
    if let Some(ref pb) = pb {
        pb.finish_with_message("Processing complete");
    }
    if args.cross_doc {
        #[cfg(feature = "eval")]
        {
            use anno_eval::cdcr::{CDCRConfig, CDCRResolver, Document};
            // Re-shape each grounded document into the CDCR input type.
            // Uses the cfg-gated `Entity`/`EntityType` imports from the
            // top of the file (previously imported but unused).
            let cdcr_docs: Vec<Document> = documents
                .iter()
                .map(|doc| {
                    let entities: Vec<_> = doc
                        .signals()
                        .iter()
                        .map(|s| {
                            // Signals without text offsets fall back to (0, 0).
                            let (start, end) = s.text_offsets().unwrap_or((0, 0));
                            Entity::new(
                                s.surface(),
                                EntityType::from_label(s.label()),
                                start,
                                end,
                                f64::from(s.confidence),
                            )
                        })
                        .collect();
                    Document::new(doc.id(), doc.text()).with_entities(entities)
                })
                .collect();
            let config = CDCRConfig {
                min_similarity: args.threshold,
                require_type_match: false,
                ..Default::default()
            };
            let resolver = CDCRResolver::with_config(config);
            let clusters = resolver.resolve(&cdcr_docs);
            match args.format {
                OutputFormat::Json | OutputFormat::Grounded => {
                    let output = serde_json::to_string_pretty(&clusters)
                        .map_err(|e| format!("Failed to serialize clusters: {}", e))?;
                    if let Some(output_path) = &args.output {
                        fs::write(output_path, output)
                            .map_err(|e| format!("Failed to write output: {}", e))?;
                    } else {
                        println!("{}", output);
                    }
                }
                OutputFormat::Tree => {
                    // Index documents by id so each mention can be resolved
                    // back to its surface text.
                    let doc_index: std::collections::HashMap<_, _> =
                        cdcr_docs.iter().map(|doc| (doc.id.clone(), doc)).collect();
                    for cluster in &clusters {
                        println!("Cluster {}: {}", cluster.id, cluster.canonical_name);
                        for (doc_id, entity_idx) in &cluster.mentions {
                            let mention_text = doc_index
                                .get(doc_id.as_str())
                                .and_then(|doc| doc.entities.get(*entity_idx))
                                .map(|e| e.text.clone())
                                .unwrap_or_else(|| format!("entity_{}", entity_idx));
                            println!(" - {} (doc: {})", mention_text, doc_id);
                        }
                        println!();
                    }
                }
                _ => {
                    // Human-readable summary of cluster sizes.
                    println!();
                    println!(
                        "{} Cross-document clusters: {}",
                        color("1;36", "Found"),
                        clusters.len()
                    );
                    for cluster in &clusters {
                        println!(
                            " {}: {} mentions across {} documents",
                            cluster.canonical_name,
                            cluster.mentions.len(),
                            cluster.doc_count()
                        );
                    }
                }
            }
        }
        #[cfg(not(feature = "eval"))]
        {
            return Err("Cross-document clustering requires 'eval' feature".to_string());
        }
    } else {
        match args.format {
            OutputFormat::Json | OutputFormat::Grounded => {
                let output = serde_json::to_string_pretty(&documents)
                    .map_err(|e| format!("Failed to serialize documents: {}", e))?;
                if let Some(output_path) = &args.output {
                    fs::write(output_path, output)
                        .map_err(|e| format!("Failed to write output: {}", e))?;
                } else {
                    println!("{}", output);
                }
            }
            _ => {
                for doc in &documents {
                    println!();
                    println!("{}", color("1;36", &format!("Document: {}", doc.id())));
                    print_signals(doc, doc.text(), 0);
                }
            }
        }
    }
    Ok(())
}

/// Gathers `(doc_id, text)` input pairs from positional text arguments,
/// `--files`, and `--dir`, in that order.
///
/// # Errors
/// Returns a descriptive message when a file or directory cannot be read.
fn collect_inputs(args: &PipelineArgs) -> Result<Vec<(String, String)>, String> {
    let mut texts: Vec<(String, String)> = Vec::new();
    // Positional texts become "text1", "text2", ... (the original wrapped
    // this loop in a redundant is_empty guard; iterating an empty Vec is
    // already a no-op).
    for (idx, text) in args.text.iter().enumerate() {
        texts.push((format!("text{}", idx + 1), text.clone()));
    }
    // Explicit files: doc id is the file stem, falling back to the full
    // path when the stem is absent or not valid UTF-8.
    for file_path in &args.files {
        let text = fs::read_to_string(file_path)
            .map_err(|e| format!("Failed to read {}: {}", file_path, e))?;
        let doc_id = std::path::Path::new(file_path)
            .file_stem()
            .and_then(|s| s.to_str())
            .map(|s| s.to_string())
            .unwrap_or_else(|| file_path.clone());
        texts.push((doc_id, text));
    }
    // Directory scan: pick up .txt/.md files. Paths are sorted because
    // read_dir yields entries in an unspecified, platform-dependent
    // order; sorting makes processing order — and therefore the "doc{n}"
    // fallback ids — deterministic.
    if let Some(dir) = &args.dir {
        let entries = fs::read_dir(std::path::Path::new(dir))
            .map_err(|e| format!("Failed to read directory {}: {}", dir, e))?;
        let mut paths: Vec<std::path::PathBuf> = Vec::new();
        for entry in entries {
            let entry = entry.map_err(|e| format!("Failed to read entry: {}", e))?;
            paths.push(entry.path());
        }
        paths.sort();
        for path in paths {
            if !path.is_file() {
                continue;
            }
            // Bug fix: the original compared extensions case-sensitively
            // (`ext == "txt"`), silently skipping e.g. NOTES.TXT or README.MD.
            let is_text_file = path
                .extension()
                .map_or(false, |ext| {
                    ext.eq_ignore_ascii_case("txt") || ext.eq_ignore_ascii_case("md")
                });
            if !is_text_file {
                continue;
            }
            let text = fs::read_to_string(&path)
                .map_err(|e| format!("Failed to read {}: {}", path.display(), e))?;
            let doc_id = path
                .file_stem()
                .and_then(|s| s.to_str())
                .map(|s| s.to_string())
                .unwrap_or_else(|| format!("doc{}", texts.len()));
            texts.push((doc_id, text));
        }
    }
    Ok(texts)
}