privacy_filter_rs/
inference.rs1use std::path::Path;
6use std::time::Instant;
7
8use burn::prelude::*;
9
10use crate::config::{ModelConfig, ViterbiConfig};
11use crate::model::privacy_filter::PrivacyFilterModel;
12use crate::viterbi::{self, PrivacySpan};
13use crate::weights;
14
15pub struct PrivacyFilterInference<B: Backend> {
17 pub model: PrivacyFilterModel<B>,
18 pub tokenizer: tokenizers::Tokenizer,
19 pub viterbi_config: ViterbiConfig,
20 pub device: B::Device,
21}
22
23impl<B: Backend> PrivacyFilterInference<B> {
24 pub fn load(model_dir: &Path, device: B::Device) -> anyhow::Result<Self> {
32 let config_path = model_dir.join("config.json");
33 let weights_path = model_dir.join("model.safetensors");
34 let tokenizer_path = model_dir.join("tokenizer.json");
35 let viterbi_path = model_dir.join("viterbi_calibration.json");
36
37 let config = ModelConfig::from_file(&config_path)?;
39 eprintln!("Model config: {} layers, {} hidden, {} experts (top-{})",
40 config.num_hidden_layers, config.hidden_size,
41 config.num_local_experts, config.num_experts_per_tok);
42
43 let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
45 .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;
46 eprintln!("Tokenizer loaded ({} tokens)", tokenizer.get_vocab_size(false));
47
48 let viterbi_config = if viterbi_path.exists() {
50 ViterbiConfig::from_file(&viterbi_path, "default")
51 .unwrap_or_default()
52 } else {
53 ViterbiConfig::default()
54 };
55
56 let t0 = Instant::now();
58 let model = weights::load_model(
59 &config,
60 weights_path.to_str().unwrap(),
61 &device,
62 )?;
63 eprintln!("Weights loaded in {:.1}s", t0.elapsed().as_secs_f64());
64
65 Ok(Self {
66 model,
67 tokenizer,
68 viterbi_config,
69 device,
70 })
71 }
72
73 pub fn predict(&self, text: &str) -> anyhow::Result<Vec<PrivacySpan>> {
77 let t0 = Instant::now();
78
79 let encoding = self.tokenizer.encode(text, false)
81 .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
82 let input_ids = encoding.get_ids();
83 let tokens: Vec<String> = encoding.get_tokens().iter().map(|s| s.to_string()).collect();
84 let offsets: Vec<(usize, usize)> = encoding.get_offsets().to_vec();
85 let seq_len = input_ids.len();
86
87 eprintln!("Tokenized: {} tokens in {:.1}ms",
88 seq_len, t0.elapsed().as_secs_f64() * 1000.0);
89
90 let t1 = Instant::now();
92 let logits = self.model.forward(input_ids, &self.device);
93 let logits_data: Vec<f32> = logits.to_data().convert::<f32>().to_vec::<f32>().unwrap();
96 eprintln!("Forward pass: {:.1}ms", t1.elapsed().as_secs_f64() * 1000.0);
97
98 let t2 = Instant::now();
100 let label_path = viterbi::viterbi_decode(&logits_data, seq_len, &self.viterbi_config);
101 eprintln!("Viterbi decode: {:.1}ms", t2.elapsed().as_secs_f64() * 1000.0);
102
103 let spans = viterbi::extract_spans(&label_path, &logits_data, &tokens, &offsets, text);
105
106 eprintln!("Total: {:.1}ms, {} spans detected",
107 t0.elapsed().as_secs_f64() * 1000.0, spans.len());
108
109 Ok(spans)
110 }
111
112 pub fn predict_logits(&self, text: &str) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
116 let encoding = self.tokenizer.encode(text, false)
117 .map_err(|e| anyhow::anyhow!("Tokenization failed: {}", e))?;
118 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
119
120 let logits = self.model.forward(&input_ids, &self.device);
121 let logits_data: Vec<f32> = logits.to_data().convert::<f32>().to_vec::<f32>().unwrap();
122
123 Ok((input_ids, logits_data))
124 }
125
126 pub fn predict_argmax(&self, text: &str) -> anyhow::Result<Vec<String>> {
128 let (_, logits_data) = self.predict_logits(text)?;
129 let labels = crate::config::build_label_list();
130 let num_labels = labels.len();
131 let seq_len = logits_data.len() / num_labels;
132
133 let mut result = Vec::with_capacity(seq_len);
134 for t in 0..seq_len {
135 let offset = t * num_labels;
136 let mut best_idx = 0;
137 let mut best_val = f32::NEG_INFINITY;
138 for l in 0..num_labels {
139 if logits_data[offset + l] > best_val {
140 best_val = logits_data[offset + l];
141 best_idx = l;
142 }
143 }
144 result.push(labels[best_idx].clone());
145 }
146
147 Ok(result)
148 }
149}