Skip to main content

PrivacyFilterInference

Struct PrivacyFilterInference 

Source
/// The main inference engine: bundles the model, its tokenizer, the Viterbi
/// configuration, and the backend device.
///
/// Built with `load`, which reads these pieces from a model directory.
pub struct PrivacyFilterInference<B: Backend> {
    /// The privacy-filter model, generic over the backend `B`.
    pub model: PrivacyFilterModel<B>,
    /// Tokenizer for the input text (presumably read from `tokenizer.json` — confirm in `load`).
    pub tokenizer: Tokenizer,
    /// Viterbi configuration (presumably taken from the optional
    /// `viterbi_calibration.json` when it is present — confirm in `load`).
    pub viterbi_config: ViterbiConfig,
    /// Backend device (presumably the one passed to `load` — confirm).
    pub device: B::Device,
}
Expand description

The main inference engine.

Fields§

§model: PrivacyFilterModel<B>

§tokenizer: Tokenizer

§viterbi_config: ViterbiConfig

§device: B::Device

Implementations§

Source§

impl<B: Backend> PrivacyFilterInference<B>

Source

pub fn load(model_dir: &Path, device: B::Device) -> Result<Self>

Load the model, tokenizer, and Viterbi config from a model directory.

The directory should contain:

  • config.json
  • model.safetensors
  • tokenizer.json
  • viterbi_calibration.json (optional)
Examples found in repository?
examples/infer.rs (lines 29-32)
20fn main() -> anyhow::Result<()> {
21    let args = Args::parse();
22
23    let n = privacy_filter_rs::init_threads(Some(args.threads));
24    eprintln!("Using {n} threads");
25
26    let device = <Device as Default>::default();
27
28    eprintln!("Loading model...");
29    let engine = privacy_filter_rs::PrivacyFilterInference::<B>::load(
30        &args.model_dir,
31        device,
32    )?;
33
34    let samples = [
35        "My name is Alice Smith and I live at 123 Main Street, Springfield.",
36        "You can reach me at alice.smith@example.com or call 555-0123.",
37        "My account number is 4532-1234-5678-9012 and my password is hunter2.",
38        "Born on January 15, 1990, Alice visited https://secret-site.com/login.",
39        "The weather is nice today and the stock market went up.",
40    ];
41
42    for text in &samples {
43        println!("\n--- Input: {text}");
44        let spans = engine.predict(text)?;
45        if spans.is_empty() {
46            println!("  No PII detected.");
47        } else {
48            for span in &spans {
49                println!(
50                    "  [{:>15}] {:<30} (score: {:.4}, chars {}..{})",
51                    span.entity_group, span.word, span.score, span.start, span.end
52                );
53            }
54        }
55    }
56
57    Ok(())
58}
More examples
Hide additional examples
examples/bench.rs (line 39)
30fn main() -> anyhow::Result<()> {
31    let args = Args::parse();
32    let n = privacy_filter_rs::init_threads(Some(args.threads));
33    eprintln!("Using {n} threads\n");
34
35    let device = <Device as Default>::default();
36
37    eprintln!("Loading model...");
38    let t0 = Instant::now();
39    let engine = PrivacyFilterInference::<B>::load(&args.model_dir, device)?;
40    let load_ms = t0.elapsed().as_secs_f64() * 1000.0;
41    eprintln!("Model loaded in {load_ms:.0} ms\n");
42
43    let samples = [
44        "My name is Alice Smith",
45        "You can reach me at alice.smith@example.com or call 555-0123.",
46        "My account number is 4532-1234-5678-9012 and my password is hunter2.",
47        "Born on January 15, 1990, Alice visited https://secret-site.com/login.",
48        "The weather is nice today and the stock market went up.",
49        "My name is Harry Potter and my email is harry.potter@hogwarts.edu.",
50    ];
51
52    println!(
53        "{:<65} {:>6} {:>8} {:>8} {:>8}",
54        "Text", "Tokens", "Entities", "Avg ms", "Min ms"
55    );
56    println!("{}", "-".repeat(100));
57
58    let mut total_tokens = 0usize;
59    let mut total_avg = 0.0f64;
60    let mut total_min = 0.0f64;
61
62    for text in &samples {
63        // Warmup
64        for _ in 0..args.warmup {
65            let _ = engine.predict(text)?;
66        }
67
68        // Timed runs
69        let mut times = Vec::with_capacity(args.iters);
70        let mut last_spans = Vec::new();
71        let mut last_tokens = 0;
72        for _ in 0..args.iters {
73            let t0 = Instant::now();
74            let spans = engine.predict(text)?;
75            let elapsed = t0.elapsed().as_secs_f64() * 1000.0;
76            times.push(elapsed);
77
78            // Get token count from logits call (only need once)
79            if last_tokens == 0 {
80                let (ids, _) = engine.predict_logits(text)?;
81                last_tokens = ids.len();
82            }
83            last_spans = spans;
84        }
85
86        let avg_ms: f64 = times.iter().sum::<f64>() / times.len() as f64;
87        let min_ms: f64 = times.iter().cloned().fold(f64::INFINITY, f64::min);
88        let n_entities = last_spans.len();
89
90        let display_text = if text.len() > 60 {
91            format!("{}...", &text[..60])
92        } else {
93            text.to_string()
94        };
95
96        println!(
97            "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
98            display_text, last_tokens, n_entities, avg_ms, min_ms
99        );
100
101        total_tokens += last_tokens;
102        total_avg += avg_ms;
103        total_min += min_ms;
104    }
105
106    println!("{}", "-".repeat(100));
107    println!(
108        "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
109        "TOTAL", total_tokens, "", total_avg, total_min
110    );
111    println!(
112        "\nThroughput (avg): {:.0} tokens/sec",
113        total_tokens as f64 / (total_avg / 1000.0)
114    );
115    println!(
116        "Throughput (min): {:.0} tokens/sec",
117        total_tokens as f64 / (total_min / 1000.0)
118    );
119
120    Ok(())
121}
Source

pub fn predict(&self, text: &str) -> Result<Vec<PrivacySpan>>

Run inference on a text string.

Returns detected privacy spans with entity type, confidence, and text.

Examples found in repository?
examples/infer.rs (line 44)
20fn main() -> anyhow::Result<()> {
21    let args = Args::parse();
22
23    let n = privacy_filter_rs::init_threads(Some(args.threads));
24    eprintln!("Using {n} threads");
25
26    let device = <Device as Default>::default();
27
28    eprintln!("Loading model...");
29    let engine = privacy_filter_rs::PrivacyFilterInference::<B>::load(
30        &args.model_dir,
31        device,
32    )?;
33
34    let samples = [
35        "My name is Alice Smith and I live at 123 Main Street, Springfield.",
36        "You can reach me at alice.smith@example.com or call 555-0123.",
37        "My account number is 4532-1234-5678-9012 and my password is hunter2.",
38        "Born on January 15, 1990, Alice visited https://secret-site.com/login.",
39        "The weather is nice today and the stock market went up.",
40    ];
41
42    for text in &samples {
43        println!("\n--- Input: {text}");
44        let spans = engine.predict(text)?;
45        if spans.is_empty() {
46            println!("  No PII detected.");
47        } else {
48            for span in &spans {
49                println!(
50                    "  [{:>15}] {:<30} (score: {:.4}, chars {}..{})",
51                    span.entity_group, span.word, span.score, span.start, span.end
52                );
53            }
54        }
55    }
56
57    Ok(())
58}
More examples
Hide additional examples
examples/bench.rs (line 65)
30fn main() -> anyhow::Result<()> {
31    let args = Args::parse();
32    let n = privacy_filter_rs::init_threads(Some(args.threads));
33    eprintln!("Using {n} threads\n");
34
35    let device = <Device as Default>::default();
36
37    eprintln!("Loading model...");
38    let t0 = Instant::now();
39    let engine = PrivacyFilterInference::<B>::load(&args.model_dir, device)?;
40    let load_ms = t0.elapsed().as_secs_f64() * 1000.0;
41    eprintln!("Model loaded in {load_ms:.0} ms\n");
42
43    let samples = [
44        "My name is Alice Smith",
45        "You can reach me at alice.smith@example.com or call 555-0123.",
46        "My account number is 4532-1234-5678-9012 and my password is hunter2.",
47        "Born on January 15, 1990, Alice visited https://secret-site.com/login.",
48        "The weather is nice today and the stock market went up.",
49        "My name is Harry Potter and my email is harry.potter@hogwarts.edu.",
50    ];
51
52    println!(
53        "{:<65} {:>6} {:>8} {:>8} {:>8}",
54        "Text", "Tokens", "Entities", "Avg ms", "Min ms"
55    );
56    println!("{}", "-".repeat(100));
57
58    let mut total_tokens = 0usize;
59    let mut total_avg = 0.0f64;
60    let mut total_min = 0.0f64;
61
62    for text in &samples {
63        // Warmup
64        for _ in 0..args.warmup {
65            let _ = engine.predict(text)?;
66        }
67
68        // Timed runs
69        let mut times = Vec::with_capacity(args.iters);
70        let mut last_spans = Vec::new();
71        let mut last_tokens = 0;
72        for _ in 0..args.iters {
73            let t0 = Instant::now();
74            let spans = engine.predict(text)?;
75            let elapsed = t0.elapsed().as_secs_f64() * 1000.0;
76            times.push(elapsed);
77
78            // Get token count from logits call (only need once)
79            if last_tokens == 0 {
80                let (ids, _) = engine.predict_logits(text)?;
81                last_tokens = ids.len();
82            }
83            last_spans = spans;
84        }
85
86        let avg_ms: f64 = times.iter().sum::<f64>() / times.len() as f64;
87        let min_ms: f64 = times.iter().cloned().fold(f64::INFINITY, f64::min);
88        let n_entities = last_spans.len();
89
90        let display_text = if text.len() > 60 {
91            format!("{}...", &text[..60])
92        } else {
93            text.to_string()
94        };
95
96        println!(
97            "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
98            display_text, last_tokens, n_entities, avg_ms, min_ms
99        );
100
101        total_tokens += last_tokens;
102        total_avg += avg_ms;
103        total_min += min_ms;
104    }
105
106    println!("{}", "-".repeat(100));
107    println!(
108        "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
109        "TOTAL", total_tokens, "", total_avg, total_min
110    );
111    println!(
112        "\nThroughput (avg): {:.0} tokens/sec",
113        total_tokens as f64 / (total_avg / 1000.0)
114    );
115    println!(
116        "Throughput (min): {:.0} tokens/sec",
117        total_tokens as f64 / (total_min / 1000.0)
118    );
119
120    Ok(())
121}
Source

pub fn predict_logits(&self, text: &str) -> Result<(Vec<u32>, Vec<f32>)>

Run inference and return raw logits (no Viterbi decoding).

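Returns the token ids together with the logits. The logits are a flat Vec<f32> with logical shape [seq_len, num_labels], i.e. seq_len * num_labels values.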

Examples found in repository?
examples/bench.rs (line 80)
30fn main() -> anyhow::Result<()> {
31    let args = Args::parse();
32    let n = privacy_filter_rs::init_threads(Some(args.threads));
33    eprintln!("Using {n} threads\n");
34
35    let device = <Device as Default>::default();
36
37    eprintln!("Loading model...");
38    let t0 = Instant::now();
39    let engine = PrivacyFilterInference::<B>::load(&args.model_dir, device)?;
40    let load_ms = t0.elapsed().as_secs_f64() * 1000.0;
41    eprintln!("Model loaded in {load_ms:.0} ms\n");
42
43    let samples = [
44        "My name is Alice Smith",
45        "You can reach me at alice.smith@example.com or call 555-0123.",
46        "My account number is 4532-1234-5678-9012 and my password is hunter2.",
47        "Born on January 15, 1990, Alice visited https://secret-site.com/login.",
48        "The weather is nice today and the stock market went up.",
49        "My name is Harry Potter and my email is harry.potter@hogwarts.edu.",
50    ];
51
52    println!(
53        "{:<65} {:>6} {:>8} {:>8} {:>8}",
54        "Text", "Tokens", "Entities", "Avg ms", "Min ms"
55    );
56    println!("{}", "-".repeat(100));
57
58    let mut total_tokens = 0usize;
59    let mut total_avg = 0.0f64;
60    let mut total_min = 0.0f64;
61
62    for text in &samples {
63        // Warmup
64        for _ in 0..args.warmup {
65            let _ = engine.predict(text)?;
66        }
67
68        // Timed runs
69        let mut times = Vec::with_capacity(args.iters);
70        let mut last_spans = Vec::new();
71        let mut last_tokens = 0;
72        for _ in 0..args.iters {
73            let t0 = Instant::now();
74            let spans = engine.predict(text)?;
75            let elapsed = t0.elapsed().as_secs_f64() * 1000.0;
76            times.push(elapsed);
77
78            // Get token count from logits call (only need once)
79            if last_tokens == 0 {
80                let (ids, _) = engine.predict_logits(text)?;
81                last_tokens = ids.len();
82            }
83            last_spans = spans;
84        }
85
86        let avg_ms: f64 = times.iter().sum::<f64>() / times.len() as f64;
87        let min_ms: f64 = times.iter().cloned().fold(f64::INFINITY, f64::min);
88        let n_entities = last_spans.len();
89
90        let display_text = if text.len() > 60 {
91            format!("{}...", &text[..60])
92        } else {
93            text.to_string()
94        };
95
96        println!(
97            "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
98            display_text, last_tokens, n_entities, avg_ms, min_ms
99        );
100
101        total_tokens += last_tokens;
102        total_avg += avg_ms;
103        total_min += min_ms;
104    }
105
106    println!("{}", "-".repeat(100));
107    println!(
108        "{:<65} {:>6} {:>8} {:>8.1} {:>8.1}",
109        "TOTAL", total_tokens, "", total_avg, total_min
110    );
111    println!(
112        "\nThroughput (avg): {:.0} tokens/sec",
113        total_tokens as f64 / (total_avg / 1000.0)
114    );
115    println!(
116        "Throughput (min): {:.0} tokens/sec",
117        total_tokens as f64 / (total_min / 1000.0)
118    );
119
120    Ok(())
121}
Source

pub fn predict_argmax(&self, text: &str) -> Result<Vec<String>>

Run inference and return per-token argmax labels (no Viterbi).

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise. Read more
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer. Read more
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V