// bench/bench.rs

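//! Benchmark: time Rust inference (`predict`) across a set of sample texts.
//!
//! `-m` points at the model directory; `--warmup` and `--iters` set how many
//! untimed and timed runs are made per sample.
//!
//! ```bash
//! cargo run --example bench --release -- -m data
//! ```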
6
7use std::path::PathBuf;
8use std::time::Instant;
9use clap::Parser;
10use privacy_filter_rs::backend::{B, Device};
11use privacy_filter_rs::PrivacyFilterInference;
12
13#[derive(Parser)]
14struct Args {
15    #[arg(short = 'm', long)]
16    model_dir: PathBuf,
17
18    #[arg(short = 't', long, default_value = "0")]
19    threads: usize,
20
21    /// Number of warmup iterations per sample
22    #[arg(long, default_value = "1")]
23    warmup: usize,
24
25    /// Number of timed iterations per sample
26    #[arg(long, default_value = "5")]
27    iters: usize,
28}
29
30fn main() -> anyhow::Result<()> {
31    let args = Args::parse();
32    let n = privacy_filter_rs::init_threads(Some(args.threads));
33    eprintln!("Using {n} threads\n");
34
35    let device = <Device as Default>::default();
36
37    eprintln!("Loading model...");
38    let t0 = Instant::now();
39    let engine = PrivacyFilterInference::<B>::load(&args.model_dir, device)?;
40    let load_ms = t0.elapsed().as_secs_f64() * 1000.0;
41    eprintln!("Model loaded in {load_ms:.0} ms\n");
42
43    let samples = [
44        "My name is Alice Smith",
45        "You can reach me at alice.smith@example.com or call 555-0123.",
46        "My account number is 4532-1234-5678-9012 and my password is hunter2.",
47        "Born on January 15, 1990, Alice visited https://secret-site.com/login.",
48        "The weather is nice today and the stock market went up.",
49        "My name is Harry Potter and my email is harry.potter@hogwarts.edu.",
50    ];
51
52    println!(
53        "{:<65} {:>6} {:>8} {:>8} {:>8}",
54        "Text", "Tokens", "Entities", "Avg ms", "Min ms"
55    );
56    println!("{}", "-".repeat(100));
57
58    let mut total_tokens = 0usize;
59    let mut total_avg = 0.0f64;
60    let mut total_min = 0.0f64;
61
62    for text in &samples {
63        // Warmup
64        for _ in 0..args.warmup {
65            let _ = engine.predict(text)?;
66        }
67
68        // Timed runs
69        let mut times = Vec::with_capacity(args.iters);
70        let mut last_spans = Vec::new();
71        let mut last_tokens = 0;
72        for _ in 0..args.iters {
73            let t0 = Instant::now();
74            let spans = engine.predict(text)?;
75            let elapsed = t0.elapsed().as_secs_f64() * 1000.0;
76            times.push(elapsed);
77
78            // Get token count from logits call (only need once)
79            if last_tokens == 0 {
80                let (ids, _) = engine.predict_logits(text)?;
81                last_tokens = ids.len();
82            }
83            last_spans = spans;
84        }
85
86        let avg_ms: f64 = times.iter().sum::<f64>() / times.len() as f64;
87        let min_ms: f64 = times.iter().cloned().fold(f64::INFINITY, f64::min);
88        let n_entities = last_spans.len();
89
90        let display_text = if text.len() > 60 {
91            format!("{}...", &text[..60])
92        } else {
93            text.to_string()
94        };
95
96        println!(
97            "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
98            display_text, last_tokens, n_entities, avg_ms, min_ms
99        );
100
101        total_tokens += last_tokens;
102        total_avg += avg_ms;
103        total_min += min_ms;
104    }
105
106    println!("{}", "-".repeat(100));
107    println!(
108        "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
109        "TOTAL", total_tokens, "", total_avg, total_min
110    );
111    println!(
112        "\nThroughput (avg): {:.0} tokens/sec",
113        total_tokens as f64 / (total_avg / 1000.0)
114    );
115    println!(
116        "Throughput (min): {:.0} tokens/sec",
117        total_tokens as f64 / (total_min / 1000.0)
118    );
119
120    Ok(())
121}