Skip to main content

privacy_filter_rs/
inference.rs

//! Top-level inference API.
//!
//! Pipeline: text → tokenize → model forward → Viterbi decode → span extraction
4
5use std::path::Path;
6use std::time::Instant;
7
8use burn::prelude::*;
9
10use crate::config::{ModelConfig, ViterbiConfig};
11use crate::model::privacy_filter::PrivacyFilterModel;
12use crate::viterbi::{self, PrivacySpan};
13use crate::weights;
14
15/// The main inference engine.
16pub struct PrivacyFilterInference<B: Backend> {
17    pub model: PrivacyFilterModel<B>,
18    pub tokenizer: tokenizers::Tokenizer,
19    pub viterbi_config: ViterbiConfig,
20    pub device: B::Device,
21}
22
23impl<B: Backend> PrivacyFilterInference<B> {
24    /// Load the model, tokenizer, and Viterbi config from a model directory.
25    ///
26    /// The directory should contain:
27    ///   - config.json
28    ///   - model.safetensors
29    ///   - tokenizer.json
30    ///   - viterbi_calibration.json (optional)
31    pub fn load(model_dir: &Path, device: B::Device) -> anyhow::Result<Self> {
32        let config_path = model_dir.join("config.json");
33        let weights_path = model_dir.join("model.safetensors");
34        let tokenizer_path = model_dir.join("tokenizer.json");
35        let viterbi_path = model_dir.join("viterbi_calibration.json");
36
37        // Load config
38        let config = ModelConfig::from_file(&config_path)?;
39        eprintln!("Model config: {} layers, {} hidden, {} experts (top-{})",
40            config.num_hidden_layers, config.hidden_size,
41            config.num_local_experts, config.num_experts_per_tok);
42
43        // Load tokenizer
44        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
45            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
46        eprintln!("Tokenizer loaded ({} tokens)", tokenizer.get_vocab_size(false));
47
48        // Load Viterbi config (optional, defaults to all-zero biases)
49        let viterbi_config = if viterbi_path.exists() {
50            ViterbiConfig::from_file(&viterbi_path, "default")
51                .unwrap_or_default()
52        } else {
53            ViterbiConfig::default()
54        };
55
56        // Load model weights
57        let t0 = Instant::now();
58        let model = weights::load_model(
59            &config,
60            weights_path.to_str().unwrap(),
61            &device,
62        )?;
63        eprintln!("Weights loaded in {:.1}s", t0.elapsed().as_secs_f64());
64
65        Ok(Self {
66            model,
67            tokenizer,
68            viterbi_config,
69            device,
70        })
71    }
72
73    /// Run inference on a text string.
74    ///
75    /// Returns detected privacy spans with entity type, confidence, and text.
76    pub fn predict(&self, text: &str) -> anyhow::Result<Vec<PrivacySpan>> {
77        let t0 = Instant::now();
78
79        // 1. Tokenize
80        let encoding = self.tokenizer.encode(text, false)
81            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
82        let input_ids = encoding.get_ids();
83        let tokens: Vec<String> = encoding.get_tokens().iter().map(|s| s.to_string()).collect();
84        let offsets: Vec<(usize, usize)> = encoding.get_offsets().to_vec();
85        let seq_len = input_ids.len();
86
87        eprintln!("Tokenized: {} tokens in {:.1}ms",
88            seq_len, t0.elapsed().as_secs_f64() * 1000.0);
89
90        // 2. Model forward pass
91        let t1 = Instant::now();
92        let logits = self.model.forward(input_ids, &self.device);
93        // [1, seq_len, num_labels]
94
95        let logits_data: Vec<f32> = logits.to_data().convert::<f32>().to_vec::<f32>().unwrap();
96        eprintln!("Forward pass: {:.1}ms", t1.elapsed().as_secs_f64() * 1000.0);
97
98        // 3. Viterbi decode
99        let t2 = Instant::now();
100        let label_path = viterbi::viterbi_decode(&logits_data, seq_len, &self.viterbi_config);
101        eprintln!("Viterbi decode: {:.1}ms", t2.elapsed().as_secs_f64() * 1000.0);
102
103        // 4. Extract spans
104        let spans = viterbi::extract_spans(&label_path, &logits_data, &tokens, &offsets, text);
105
106        eprintln!("Total: {:.1}ms, {} spans detected",
107            t0.elapsed().as_secs_f64() * 1000.0, spans.len());
108
109        Ok(spans)
110    }
111
112    /// Run inference and return raw logits (no Viterbi decoding).
113    ///
114    /// Returns logits as Vec<f32> of shape [seq_len, num_labels].
115    pub fn predict_logits(&self, text: &str) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
116        let encoding = self.tokenizer.encode(text, false)
117            .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
118        let input_ids: Vec<u32> = encoding.get_ids().to_vec();
119
120        let logits = self.model.forward(&input_ids, &self.device);
121        let logits_data: Vec<f32> = logits.to_data().convert::<f32>().to_vec::<f32>().unwrap();
122
123        Ok((input_ids, logits_data))
124    }
125
126    /// Run inference and return per-token argmax labels (no Viterbi).
127    pub fn predict_argmax(&self, text: &str) -> anyhow::Result<Vec<String>> {
128        let (_, logits_data) = self.predict_logits(text)?;
129        let labels = crate::config::build_label_list();
130        let num_labels = labels.len();
131        let seq_len = logits_data.len() / num_labels;
132
133        let mut result = Vec::with_capacity(seq_len);
134        for t in 0..seq_len {
135            let offset = t * num_labels;
136            let mut best_idx = 0;
137            let mut best_val = f32::NEG_INFINITY;
138            for l in 0..num_labels {
139                if logits_data[offset + l] > best_val {
140                    best_val = logits_data[offset + l];
141                    best_idx = l;
142                }
143            }
144            result.push(labels[best_idx].clone());
145        }
146
147        Ok(result)
148    }
149}