1use std::path::PathBuf;
8use std::time::Instant;
9use clap::Parser;
10use privacy_filter_rs::backend::{B, Device};
11use privacy_filter_rs::PrivacyFilterInference;
12
13#[derive(Parser)]
14struct Args {
15 #[arg(short = 'm', long)]
16 model_dir: PathBuf,
17
18 #[arg(short = 't', long, default_value = "0")]
19 threads: usize,
20
21 #[arg(long, default_value = "1")]
23 warmup: usize,
24
25 #[arg(long, default_value = "5")]
27 iters: usize,
28}
29
30fn main() -> anyhow::Result<()> {
31 let args = Args::parse();
32 let n = privacy_filter_rs::init_threads(Some(args.threads));
33 eprintln!("Using {n} threads\n");
34
35 let device = <Device as Default>::default();
36
37 eprintln!("Loading model...");
38 let t0 = Instant::now();
39 let engine = PrivacyFilterInference::<B>::load(&args.model_dir, device)?;
40 let load_ms = t0.elapsed().as_secs_f64() * 1000.0;
41 eprintln!("Model loaded in {load_ms:.0} ms\n");
42
43 let samples = [
44 "My name is Alice Smith",
45 "You can reach me at alice.smith@example.com or call 555-0123.",
46 "My account number is 4532-1234-5678-9012 and my password is hunter2.",
47 "Born on January 15, 1990, Alice visited https://secret-site.com/login.",
48 "The weather is nice today and the stock market went up.",
49 "My name is Harry Potter and my email is harry.potter@hogwarts.edu.",
50 ];
51
52 println!(
53 "{:<65} {:>6} {:>8} {:>8} {:>8}",
54 "Text", "Tokens", "Entities", "Avg ms", "Min ms"
55 );
56 println!("{}", "-".repeat(100));
57
58 let mut total_tokens = 0usize;
59 let mut total_avg = 0.0f64;
60 let mut total_min = 0.0f64;
61
62 for text in &samples {
63 for _ in 0..args.warmup {
65 let _ = engine.predict(text)?;
66 }
67
68 let mut times = Vec::with_capacity(args.iters);
70 let mut last_spans = Vec::new();
71 let mut last_tokens = 0;
72 for _ in 0..args.iters {
73 let t0 = Instant::now();
74 let spans = engine.predict(text)?;
75 let elapsed = t0.elapsed().as_secs_f64() * 1000.0;
76 times.push(elapsed);
77
78 if last_tokens == 0 {
80 let (ids, _) = engine.predict_logits(text)?;
81 last_tokens = ids.len();
82 }
83 last_spans = spans;
84 }
85
86 let avg_ms: f64 = times.iter().sum::<f64>() / times.len() as f64;
87 let min_ms: f64 = times.iter().cloned().fold(f64::INFINITY, f64::min);
88 let n_entities = last_spans.len();
89
90 let display_text = if text.len() > 60 {
91 format!("{}...", &text[..60])
92 } else {
93 text.to_string()
94 };
95
96 println!(
97 "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
98 display_text, last_tokens, n_entities, avg_ms, min_ms
99 );
100
101 total_tokens += last_tokens;
102 total_avg += avg_ms;
103 total_min += min_ms;
104 }
105
106 println!("{}", "-".repeat(100));
107 println!(
108 "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
109 "TOTAL", total_tokens, "", total_avg, total_min
110 );
111 println!(
112 "\nThroughput (avg): {:.0} tokens/sec",
113 total_tokens as f64 / (total_avg / 1000.0)
114 );
115 println!(
116 "Throughput (min): {:.0} tokens/sec",
117 total_tokens as f64 / (total_min / 1000.0)
118 );
119
120 Ok(())
121}