use std::path::PathBuf;
use clap::Parser;
use privacy_filter_rs::backend::{B, Device};
// Command-line arguments for the `privacy-filter` binary. Field `///` comments below
// double as clap `--help` text.
#[derive(Parser, Debug)]
#[command(name = "privacy-filter")]
#[command(about = "OpenAI Privacy Filter — PII detection in Rust")]
struct Args {
    /// Path to the model directory to load
    #[arg(short = 'm', long)]
    model_dir: PathBuf,

    /// Text to analyze; read from stdin when omitted
    #[arg()]
    text: Option<String>,

    /// Requested thread count (the count actually used is reported on stderr)
    #[arg(short = 't', long, default_value = "0")]
    threads: usize,

    /// Output format: spans (JSON array), labels (one per line), or logits (tab-separated)
    // Restricting the values lets clap reject a bad format before the model is loaded.
    #[arg(
        short = 'f',
        long,
        default_value = "spans",
        value_parser = ["spans", "labels", "logits"]
    )]
    format: String,

    /// Operating point name
    // NOTE(review): this value is parsed but never read by `main` in this file — confirm
    // whether it should be forwarded to the inference engine or removed.
    #[arg(long, default_value = "default")]
    operating_point: String,
}
fn main() -> anyhow::Result<()> {
let args = Args::parse();
let n = privacy_filter_rs::init_threads(Some(args.threads));
eprintln!("Using {n} threads");
let device = <Device as Default>::default();
let engine = privacy_filter_rs::PrivacyFilterInference::<B>::load(
&args.model_dir,
device,
)?;
let text = if let Some(t) = args.text {
t
} else {
let mut buf = String::new();
std::io::Read::read_to_string(&mut std::io::stdin(), &mut buf)?;
buf.trim_end().to_string()
};
if text.is_empty() {
eprintln!("No input text provided.");
return Ok(());
}
match args.format.as_str() {
"spans" => {
let spans = engine.predict(&text)?;
print!("[");
for (i, span) in spans.iter().enumerate() {
if i > 0 {
print!(", ");
}
print!(
"{{\"entity_group\": \"{}\", \"score\": {:.6}, \"word\": {}, \"start\": {}, \"end\": {}}}",
span.entity_group,
span.score,
serde_json::to_string(&span.word)?,
span.start,
span.end,
);
}
println!("]");
}
"labels" => {
let labels = engine.predict_argmax(&text)?;
for label in &labels {
println!("{label}");
}
}
"logits" => {
let (ids, logits) = engine.predict_logits(&text)?;
let num_labels = 33;
println!("token_id\tlogits");
for (t, &id) in ids.iter().enumerate() {
let offset = t * num_labels;
let logit_str: Vec<String> = logits[offset..offset + num_labels]
.iter()
.map(|v| format!("{v:.4}"))
.collect();
println!("{id}\t{}", logit_str.join("\t"));
}
}
other => {
anyhow::bail!("Unknown format: {other}. Use 'spans', 'labels', or 'logits'.");
}
}
Ok(())
}