1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3
4use candle_core::{Device, Tensor};
5use candle_nn::VarBuilder;
6use candle_transformers::models::debertav2::{
7 Config as DebertaV2Config, DTYPE, DebertaV2NERModel, Id2Label,
8};
9use hf_hub::{Repo, RepoType, api::sync::Api};
10use serde::Serialize;
11use thiserror::Error;
12use tokenizers::Tokenizer;
13
/// Hugging Face Hub model repository the weights are fetched from when no local file is found.
const MODEL_REPO_ID: &str = "hydroxai/pii_model_weight";
/// File name of the weights inside [`MODEL_REPO_ID`].
const MODEL_WEIGHTS_FILE: &str = "model.safetensors";
/// DeBERTa-v2 model configuration, embedded into the binary at compile time.
const CONFIG_JSON: &str = include_str!("../assets/deberta3base_1024/config.json");
/// Tokenizer definition, embedded into the binary at compile time.
const TOKENIZER_JSON: &[u8] = include_bytes!("../assets/deberta3base_1024/tokenizer.json");
/// Environment variable that, when set, overrides where the weights are loaded from.
const WEIGHTS_ENV_VAR: &str = "PII_MASKER_MODEL_WEIGHTS";
/// Weights location relative to the crate's manifest directory, checked before downloading.
const MODEL_DIR_WEIGHTS_CANDIDATE: &str = "model/model.safetensors";
20
21pub type Result<T> = std::result::Result<T, MaskerError>;
22
23#[derive(Debug, Error)]
24pub enum MaskerError {
25 #[error("failed to parse model config: {0}")]
26 Config(#[from] serde_json::Error),
27 #[error("failed to read model weights: {0}")]
28 Io(#[from] std::io::Error),
29 #[error("tokenizer error: {0}")]
30 Tokenizer(String),
31 #[error("model error: {0}")]
32 Model(String),
33 #[error("missing id2label in model config")]
34 MissingId2Label,
35}
36
37#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
38pub struct PiiEntity {
39 pub label: String,
40 pub start: usize,
41 pub end: usize,
42 pub text: String,
43}
44
45#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
46pub struct MaskResult {
47 pub masked_text: String,
48 pub pii: BTreeMap<String, Vec<String>>,
49}
50
51#[derive(Debug, Clone, Default)]
52pub struct PiiMaskerBuilder {
53 weights_path: Option<PathBuf>,
54}
55
56impl PiiMaskerBuilder {
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 pub fn weights_path(mut self, path: impl Into<PathBuf>) -> Self {
62 self.weights_path = Some(path.into());
63 self
64 }
65
66 pub fn build(self) -> Result<PiiMasker> {
67 let weights_path = match self.weights_path {
68 Some(path) => path,
69 None => default_weights_path()?,
70 };
71
72 PiiMasker::from_weights_path(weights_path)
73 }
74}
75
/// A loaded PII detector: the embedded tokenizer plus a DeBERTa-v2 token-classification model.
///
/// Construct it with [`PiiMasker::new`], [`PiiMasker::builder`] or [`PiiMasker::from_weights_path`].
pub struct PiiMasker {
    /// Tokenizer built from the embedded `tokenizer.json`.
    tokenizer: Tokenizer,
    /// Token-classification model; its per-token label ids are decoded through `id2label`.
    model: DebertaV2NERModel,
    /// Label-id to raw label name table (e.g. `B-NAME_STUDENT`) taken from the embedded config.
    id2label: Id2Label,
    /// Device tensors are created on (`from_weights_path` always uses the CPU).
    device: Device,
    /// Path the weights were loaded from.
    weights_path: PathBuf,
}
83
84impl PiiMasker {
85 pub fn builder() -> PiiMaskerBuilder {
86 PiiMaskerBuilder::new()
87 }
88
89 pub fn new() -> Result<Self> {
90 Self::builder().build()
91 }
92
93 pub fn from_weights_path(path: impl Into<PathBuf>) -> Result<Self> {
94 let weights_path = path.into();
95 let config: DebertaV2Config = serde_json::from_str(CONFIG_JSON)?;
96 let id2label = config
97 .id2label
98 .clone()
99 .ok_or(MaskerError::MissingId2Label)?;
100 let tokenizer = Tokenizer::from_bytes(TOKENIZER_JSON)
101 .map_err(|err| MaskerError::Tokenizer(err.to_string()))?;
102 let device = Device::Cpu;
103 let vb = unsafe {
104 VarBuilder::from_mmaped_safetensors(&[&weights_path], DTYPE, &device)
105 .map_err(|err| MaskerError::Model(err.to_string()))?
106 };
107 let vb = vb.set_prefix("deberta");
108 let model = DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))
109 .map_err(|err| MaskerError::Model(err.to_string()))?;
110
111 Ok(Self {
112 tokenizer,
113 model,
114 id2label,
115 device,
116 weights_path,
117 })
118 }
119
120 pub fn weights_path(&self) -> &Path {
121 &self.weights_path
122 }
123
124 pub fn detect_pii(&self, input: &str) -> Result<Vec<PiiEntity>> {
125 let encoding = self
126 .tokenizer
127 .encode(input, true)
128 .map_err(|err| MaskerError::Tokenizer(err.to_string()))?;
129
130 let input_ids = Tensor::stack(
131 &[Tensor::new(encoding.get_ids(), &self.device)
132 .map_err(|err| MaskerError::Model(err.to_string()))?],
133 0,
134 )
135 .map_err(|err| MaskerError::Model(err.to_string()))?;
136 let attention_mask = Tensor::stack(
137 &[Tensor::new(encoding.get_attention_mask(), &self.device)
138 .map_err(|err| MaskerError::Model(err.to_string()))?],
139 0,
140 )
141 .map_err(|err| MaskerError::Model(err.to_string()))?;
142 let token_type_ids = Tensor::stack(
143 &[Tensor::new(encoding.get_type_ids(), &self.device)
144 .map_err(|err| MaskerError::Model(err.to_string()))?],
145 0,
146 )
147 .map_err(|err| MaskerError::Model(err.to_string()))?;
148
149 let logits = self
150 .model
151 .forward(&input_ids, Some(token_type_ids), Some(attention_mask))
152 .map_err(|err| MaskerError::Model(err.to_string()))?;
153 let predictions = logits
154 .argmax(2)
155 .map_err(|err| MaskerError::Model(err.to_string()))?
156 .to_vec2::<u32>()
157 .map_err(|err| MaskerError::Model(err.to_string()))?;
158
159 let labels = &predictions[0];
160 let special_mask = encoding.get_special_tokens_mask();
161 let offsets = encoding.get_offsets();
162
163 let mut entities = Vec::new();
164 let mut current: Option<(String, usize, usize)> = None;
165
166 for (index, label_id) in labels.iter().enumerate() {
167 if special_mask.get(index).copied().unwrap_or_default() == 1 {
168 continue;
169 }
170
171 let Some(&(start, end)) = offsets.get(index) else {
172 continue;
173 };
174 if start == end {
175 continue;
176 }
177
178 let raw_label = self
179 .id2label
180 .get(label_id)
181 .map(String::as_str)
182 .unwrap_or("O");
183 if raw_label == "O" {
184 flush_entity(&mut entities, &mut current, input);
185 continue;
186 }
187
188 let normalized_label = normalize_label(raw_label);
189 let can_extend = current.as_ref().is_some_and(|(label, _, current_end)| {
190 label == &normalized_label && start <= *current_end + 1
191 });
192
193 if can_extend {
194 if let Some((_, _, current_end)) = current.as_mut() {
195 *current_end = end.max(*current_end);
196 }
197 continue;
198 }
199
200 flush_entity(&mut entities, &mut current, input);
201 current = Some((normalized_label, start, end));
202 }
203
204 flush_entity(&mut entities, &mut current, input);
205 Ok(entities)
206 }
207
208 pub fn mask(&self, input: &str) -> Result<MaskResult> {
209 let (masked_text, pii) = self.mask_pii(input)?;
210 Ok(MaskResult { masked_text, pii })
211 }
212
213 pub fn mask_pii(&self, input: &str) -> Result<(String, BTreeMap<String, Vec<String>>)> {
214 let entities = self.detect_pii(input)?;
215 let mut masked_text = String::with_capacity(input.len());
216 let mut pii = BTreeMap::<String, Vec<String>>::new();
217 let mut cursor = 0;
218
219 for entity in &entities {
220 masked_text.push_str(&input[cursor..entity.start]);
221 masked_text.push('[');
222 masked_text.push_str(&entity.label);
223 masked_text.push(']');
224 cursor = entity.end;
225
226 let values = pii.entry(entity.label.clone()).or_default();
227 if !values.iter().any(|value| value == &entity.text) {
228 values.push(entity.text.clone());
229 }
230 }
231
232 masked_text.push_str(&input[cursor..]);
233 Ok((masked_text, pii))
234 }
235}
236
237fn default_weights_path() -> Result<PathBuf> {
238 if let Ok(path) = std::env::var(WEIGHTS_ENV_VAR) {
239 return Ok(PathBuf::from(path));
240 }
241
242 let local_candidate =
243 PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(MODEL_DIR_WEIGHTS_CANDIDATE);
244 if local_candidate.exists() {
245 return Ok(local_candidate);
246 }
247
248 download_weights_from_hub()
249}
250
251fn download_weights_from_hub() -> Result<PathBuf> {
252 let api = Api::new().map_err(|err| MaskerError::Model(err.to_string()))?;
253 let repo = Repo::new(MODEL_REPO_ID.to_owned(), RepoType::Model);
254 let api = api.repo(repo);
255 api.get(MODEL_WEIGHTS_FILE)
256 .map_err(|err| MaskerError::Model(err.to_string()))
257}
258
/// Maps a raw model label to the public label name.
///
/// A leading `B-` (or, failing that, `I-`) tagging prefix is removed. The dataset-specific names
/// `ID_NUM`, `NAME_STUDENT`, `PHONE_NUM`, `STREET_ADDRESS` and `URL_PERSONAL` are shortened; any
/// other label passes through unchanged.
fn normalize_label(label: &str) -> String {
    let base = ["B-", "I-"]
        .iter()
        .find_map(|prefix| label.strip_prefix(*prefix))
        .unwrap_or(label);

    let public_name = match base {
        "ID_NUM" => "ID",
        "NAME_STUDENT" => "NAME",
        "PHONE_NUM" => "PHONE",
        "STREET_ADDRESS" => "ADDRESS",
        "URL_PERSONAL" => "URL",
        passthrough => passthrough,
    };
    public_name.to_owned()
}
274
275fn flush_entity(
276 entities: &mut Vec<PiiEntity>,
277 current: &mut Option<(String, usize, usize)>,
278 input: &str,
279) {
280 let Some((label, start, end)) = current.take() else {
281 return;
282 };
283
284 let (start, end) = trim_span(input, start, end);
285 if start >= end {
286 return;
287 }
288
289 entities.push(PiiEntity {
290 label,
291 start,
292 end,
293 text: input[start..end].to_string(),
294 });
295}
296
/// Shrinks the byte range `start..end` of `input` so it excludes leading and trailing whitespace.
///
/// Each side is measured independently on the untrimmed span. An all-whitespace span therefore
/// comes back reversed as `(end, start)`; callers treat any `start >= end` result as empty.
///
/// # Panics
///
/// Panics if `start..end` is not a valid char-boundary range of `input`.
fn trim_span(input: &str, start: usize, end: usize) -> (usize, usize) {
    let span = &input[start..end];
    let width = span.len();
    let head = width - span.trim_start().len();
    let tail = width - span.trim_end().len();
    (start + head, end - tail)
}
303
#[cfg(test)]
mod tests {
    use super::{MODEL_DIR_WEIGHTS_CANDIDATE, PiiMaskerBuilder, normalize_label, trim_span};
    use std::path::PathBuf;

    // Optional override pointing the model-backed test at a specific local weights file.
    const TEST_WEIGHTS_ENV_VAR: &str = "PII_MASKER_TEST_MODEL_WEIGHTS";

    #[test]
    fn normalizes_model_labels() {
        assert_eq!(normalize_label("B-NAME_STUDENT"), "NAME");
        assert_eq!(normalize_label("I-STREET_ADDRESS"), "ADDRESS");
        assert_eq!(normalize_label("B-EMAIL"), "EMAIL");
    }

    #[test]
    fn trims_surrounding_whitespace() {
        let input = " hello ";
        assert_eq!(trim_span(input, 0, input.len()), (1, 6));
    }

    // End-to-end check against real weights. Returns early (and so passes) when no local
    // weights are available; it always passes an explicit path, so it never downloads.
    #[test]
    fn masks_with_local_model_weights() {
        let Some(weights) = optional_test_weights() else {
            eprintln!("Skipping model-backed test because no test weights were configured.");
            return;
        };

        let masker = PiiMaskerBuilder::new()
            .weights_path(weights)
            .build()
            .expect("load local model");

        let result = masker
            .mask("John Doe lives at 1234 Elm St.")
            .expect("mask text");
        // These expectations reflect the trained weights' output for this sentence (only the
        // street address is flagged). NOTE(review): revisit if the weights change.
        assert_eq!(result.masked_text, "John Doe lives at [ADDRESS].");
        assert_eq!(
            result.pii.get("ADDRESS").expect("address label"),
            &vec!["1234 Elm St".to_string()]
        );
    }

    // Finds a local weights file for the model-backed test. It checks the path in
    // `PII_MASKER_TEST_MODEL_WEIGHTS` first (only if that file exists), then the repository's
    // `model/model.safetensors`.
    fn optional_test_weights() -> Option<PathBuf> {
        if let Ok(path) = std::env::var(TEST_WEIGHTS_ENV_VAR) {
            let path = PathBuf::from(path);
            if path.exists() {
                return Some(path);
            }
        }

        let repo_local =
            PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(MODEL_DIR_WEIGHTS_CANDIDATE);
        if repo_local.exists() {
            return Some(repo_local);
        }

        None
    }
}