use clap::Parser;
use std::fs;
use std::io::{self, Read};
use super::super::output::{color, print_signals};
use super::super::parser::OutputFormat;
use super::super::utils::{link_tracks_to_kb, log_success, resolve_coreference};
#[cfg(feature = "graph")]
use anno::graph::{GraphDocument, GraphExportFormat};
use anno::{GroundedDocument, SignalId};
/// Enhance a grounded document with coreference resolution and KB linking.
///
/// The doc comments on the fields below double as the `--help` text that clap
/// generates for each argument.
#[derive(Parser, Debug)]
pub struct EnhanceArgs {
    /// Path to a GroundedDocument JSON file, or "-" to read the JSON from stdin
    #[arg(value_name = "FILE")]
    pub input: String,
    /// Apply coreference resolution over the document's signals
    #[arg(long)]
    pub coref: bool,
    /// Apply knowledge-base linking to the document's tracks
    #[arg(long)]
    pub link_kb: bool,
    /// Write the enhanced document as JSON to this path (missing parent directories are created)
    #[arg(short, long, value_name = "PATH")]
    pub export: Option<String>,
    /// Shape of the exported JSON: full, signals, or minimal (only used with --export)
    #[arg(long, default_value = "full", value_name = "FORMAT")]
    pub export_format: String,
    /// Format of the document printed to stdout (a human-readable listing or JSON)
    #[arg(long, default_value = "human")]
    pub format: OutputFormat,
    /// Suppress status messages and the summary header of the human-readable output
    #[arg(short, long)]
    pub quiet: bool,
    /// Also print the document as a graph: neo4j (cypher), networkx (nx), or jsonld (json-ld); needs the 'graph' feature
    #[arg(long, value_name = "FORMAT")]
    pub export_graph: Option<String>,
}
pub fn run(args: EnhanceArgs) -> Result<(), String> {
let json_content = if args.input == "-" {
let mut buf = String::new();
io::stdin()
.read_to_string(&mut buf)
.map_err(|e| format!("Failed to read stdin: {}", e))?;
buf
} else {
fs::read_to_string(&args.input)
.map_err(|e| format!("Failed to read {}: {}", args.input, e))?
};
let mut doc: GroundedDocument = serde_json::from_str(&json_content)
.map_err(|e| format!("Failed to parse GroundedDocument JSON: {}", e))?;
let signal_ids: Vec<SignalId> = doc.signals().iter().map(|s| s.id).collect();
if args.coref {
let text = doc.text().to_owned();
resolve_coreference(&mut doc, &text, &signal_ids);
log_success("Applied coreference resolution", args.quiet);
}
if args.link_kb {
link_tracks_to_kb(&mut doc);
log_success("Applied KB linking", args.quiet);
}
if let Some(export_path) = args.export {
let export_data = match args.export_format.as_str() {
"full" => serde_json::to_value(&doc)
.map_err(|e| format!("Failed to serialize GroundedDocument: {}", e))?,
"signals" => {
let signals: Vec<_> = doc.signals().to_vec();
serde_json::json!({
"id": doc.id(),
"text": doc.text(),
"signals": signals
})
}
"minimal" => {
let signals: Vec<_> = doc
.signals()
.iter()
.map(|s| {
let (start, end) = s.text_offsets().unwrap_or((0, 0));
serde_json::json!({
"surface": s.surface(),
"label": s.label(),
"start": start,
"end": end,
"confidence": s.confidence
})
})
.collect();
serde_json::json!({
"id": doc.id(),
"text": doc.text(),
"signals": signals
})
}
_ => {
return Err(format!(
"Invalid export format '{}'. Use: full, signals, or minimal",
args.export_format
));
}
};
let json = serde_json::to_string_pretty(&export_data)
.map_err(|e| format!("Failed to serialize export data: {}", e))?;
if let Some(parent) = std::path::Path::new(&export_path).parent() {
if !parent.exists() {
fs::create_dir_all(parent).map_err(|e| {
format!(
"Failed to create directory for export file '{}': {}",
export_path, e
)
})?;
}
}
fs::write(&export_path, json)
.map_err(|e| format!("Failed to write export file '{}': {}", export_path, e))?;
if !args.quiet {
eprintln!(
"{} Exported {} format to {}",
color("32", "✓"),
args.export_format,
export_path
);
}
}
match args.format {
OutputFormat::Grounded | OutputFormat::Json => {
println!("{}", serde_json::to_string_pretty(&doc).unwrap_or_default());
}
OutputFormat::Human => {
if !args.quiet {
let stats = doc.stats();
println!();
println!("{}", color("1;36", "Enhanced Document"));
println!(" Signals: {}", stats.signal_count);
println!(" Tracks: {}", stats.track_count);
println!(" Identities: {}", stats.identity_count);
println!();
}
print_signals(&doc, doc.text(), 0);
}
_ => {
return Err(format!(
"Format {:?} not supported for enhance command",
args.format
));
}
}
if let Some(graph_format_str) = args.export_graph {
#[cfg(not(feature = "graph"))]
{
let _ = graph_format_str;
return Err("Graph export requires the 'graph' feature to be enabled.".to_string());
}
#[cfg(feature = "graph")]
{
let graph_format = match graph_format_str.to_lowercase().as_str() {
"neo4j" | "cypher" => GraphExportFormat::Cypher,
"networkx" | "nx" => GraphExportFormat::NetworkXJson,
"jsonld" | "json-ld" => GraphExportFormat::JsonLd,
_ => {
return Err(format!(
"Invalid graph format '{}'. Use: neo4j, networkx, or jsonld",
graph_format_str
));
}
};
let graph = anno::graph::grounded_to_graph_document(&doc);
let graph_output = graph.export(graph_format);
if !args.quiet {
eprintln!(
"{} Exported graph ({} nodes, {} edges) in {} format",
color("32", "✓"),
graph.node_count(),
graph.edge_count(),
graph_format_str
);
}
println!("{}", graph_output);
}
}
Ok(())
}