Skip to main content

pii_masker/
lib.rs

1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3
4use candle_core::{Device, Tensor};
5use candle_nn::VarBuilder;
6use candle_transformers::models::debertav2::{
7    Config as DebertaV2Config, DTYPE, DebertaV2NERModel, Id2Label,
8};
9use hf_hub::{Repo, RepoType, api::sync::Api};
10use serde::Serialize;
11use thiserror::Error;
12use tokenizers::Tokenizer;
13
/// Hugging Face Hub repository the weights are fetched from when no local copy is found.
const MODEL_REPO_ID: &str = "hydroxai/pii_model_weight";
/// Name of the safetensors weights file inside [`MODEL_REPO_ID`].
const MODEL_WEIGHTS_FILE: &str = "model.safetensors";
/// Model configuration JSON, embedded into the binary at compile time.
const CONFIG_JSON: &str = include_str!("../assets/deberta3base_1024/config.json");
/// Tokenizer definition (Hugging Face `tokenizers` JSON), embedded at compile time.
const TOKENIZER_JSON: &[u8] = include_bytes!("../assets/deberta3base_1024/tokenizer.json");
/// Environment variable consulted by `default_weights_path` to override the weights location.
const WEIGHTS_ENV_VAR: &str = "PII_MASKER_MODEL_WEIGHTS";
/// Weights path, relative to this crate's manifest directory, checked before downloading.
const MODEL_DIR_WEIGHTS_CANDIDATE: &str = "model/model.safetensors";

/// Crate-wide result type whose error is [`MaskerError`].
pub type Result<T> = std::result::Result<T, MaskerError>;
22
/// Errors produced while loading the PII model or running detection/masking.
#[derive(Debug, Error)]
pub enum MaskerError {
    /// JSON deserialization failed, e.g. while parsing the embedded model config.
    #[error("failed to parse model config: {0}")]
    Config(#[from] serde_json::Error),
    /// An I/O error; reachable through `?` on any `std::io::Error`.
    #[error("failed to read model weights: {0}")]
    Io(#[from] std::io::Error),
    /// The tokenizer could not be loaded or failed to encode input; holds the error text.
    #[error("tokenizer error: {0}")]
    Tokenizer(String),
    /// Failure from the model stack (weight loading, tensor ops, inference) or from the
    /// Hugging Face Hub download; holds the error text.
    #[error("model error: {0}")]
    Model(String),
    /// The embedded model config has no `id2label` mapping.
    #[error("missing id2label in model config")]
    MissingId2Label,
}
36
/// One detected PII span in the analysed text.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct PiiEntity {
    /// Normalized label with any `B-`/`I-` prefix removed (e.g. `NAME`, `ADDRESS`).
    pub label: String,
    /// Byte offset of the start of the span in the input string.
    pub start: usize,
    /// Byte offset one past the end of the span (exclusive).
    pub end: usize,
    /// The input text `input[start..end]`, with surrounding whitespace already trimmed.
    pub text: String,
}
44
/// Result of masking a piece of text.
#[derive(Debug, Clone, Serialize, PartialEq, Eq)]
pub struct MaskResult {
    /// The input with detected spans replaced by `[LABEL]` placeholders.
    pub masked_text: String,
    /// Distinct detected texts grouped by label, each list in first-seen order.
    pub pii: BTreeMap<String, Vec<String>>,
}
50
51#[derive(Debug, Clone, Default)]
52pub struct PiiMaskerBuilder {
53    weights_path: Option<PathBuf>,
54}
55
56impl PiiMaskerBuilder {
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    pub fn weights_path(mut self, path: impl Into<PathBuf>) -> Self {
62        self.weights_path = Some(path.into());
63        self
64    }
65
66    pub fn build(self) -> Result<PiiMasker> {
67        let weights_path = match self.weights_path {
68            Some(path) => path,
69            None => default_weights_path()?,
70        };
71
72        PiiMasker::from_weights_path(weights_path)
73    }
74}
75
76pub struct PiiMasker {
77    tokenizer: Tokenizer,
78    model: DebertaV2NERModel,
79    id2label: Id2Label,
80    device: Device,
81    weights_path: PathBuf,
82}
83
84impl PiiMasker {
85    pub fn builder() -> PiiMaskerBuilder {
86        PiiMaskerBuilder::new()
87    }
88
89    pub fn new() -> Result<Self> {
90        Self::builder().build()
91    }
92
93    pub fn from_weights_path(path: impl Into<PathBuf>) -> Result<Self> {
94        let weights_path = path.into();
95        let config: DebertaV2Config = serde_json::from_str(CONFIG_JSON)?;
96        let id2label = config
97            .id2label
98            .clone()
99            .ok_or(MaskerError::MissingId2Label)?;
100        let tokenizer = Tokenizer::from_bytes(TOKENIZER_JSON)
101            .map_err(|err| MaskerError::Tokenizer(err.to_string()))?;
102        let device = Device::Cpu;
103        let vb = unsafe {
104            VarBuilder::from_mmaped_safetensors(&[&weights_path], DTYPE, &device)
105                .map_err(|err| MaskerError::Model(err.to_string()))?
106        };
107        let vb = vb.set_prefix("deberta");
108        let model = DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))
109            .map_err(|err| MaskerError::Model(err.to_string()))?;
110
111        Ok(Self {
112            tokenizer,
113            model,
114            id2label,
115            device,
116            weights_path,
117        })
118    }
119
120    pub fn weights_path(&self) -> &Path {
121        &self.weights_path
122    }
123
124    pub fn detect_pii(&self, input: &str) -> Result<Vec<PiiEntity>> {
125        let encoding = self
126            .tokenizer
127            .encode(input, true)
128            .map_err(|err| MaskerError::Tokenizer(err.to_string()))?;
129
130        let input_ids = Tensor::stack(
131            &[Tensor::new(encoding.get_ids(), &self.device)
132                .map_err(|err| MaskerError::Model(err.to_string()))?],
133            0,
134        )
135        .map_err(|err| MaskerError::Model(err.to_string()))?;
136        let attention_mask = Tensor::stack(
137            &[Tensor::new(encoding.get_attention_mask(), &self.device)
138                .map_err(|err| MaskerError::Model(err.to_string()))?],
139            0,
140        )
141        .map_err(|err| MaskerError::Model(err.to_string()))?;
142        let token_type_ids = Tensor::stack(
143            &[Tensor::new(encoding.get_type_ids(), &self.device)
144                .map_err(|err| MaskerError::Model(err.to_string()))?],
145            0,
146        )
147        .map_err(|err| MaskerError::Model(err.to_string()))?;
148
149        let logits = self
150            .model
151            .forward(&input_ids, Some(token_type_ids), Some(attention_mask))
152            .map_err(|err| MaskerError::Model(err.to_string()))?;
153        let predictions = logits
154            .argmax(2)
155            .map_err(|err| MaskerError::Model(err.to_string()))?
156            .to_vec2::<u32>()
157            .map_err(|err| MaskerError::Model(err.to_string()))?;
158
159        let labels = &predictions[0];
160        let special_mask = encoding.get_special_tokens_mask();
161        let offsets = encoding.get_offsets();
162
163        let mut entities = Vec::new();
164        let mut current: Option<(String, usize, usize)> = None;
165
166        for (index, label_id) in labels.iter().enumerate() {
167            if special_mask.get(index).copied().unwrap_or_default() == 1 {
168                continue;
169            }
170
171            let Some(&(start, end)) = offsets.get(index) else {
172                continue;
173            };
174            if start == end {
175                continue;
176            }
177
178            let raw_label = self
179                .id2label
180                .get(label_id)
181                .map(String::as_str)
182                .unwrap_or("O");
183            if raw_label == "O" {
184                flush_entity(&mut entities, &mut current, input);
185                continue;
186            }
187
188            let normalized_label = normalize_label(raw_label);
189            let can_extend = current.as_ref().is_some_and(|(label, _, current_end)| {
190                label == &normalized_label && start <= *current_end + 1
191            });
192
193            if can_extend {
194                if let Some((_, _, current_end)) = current.as_mut() {
195                    *current_end = end.max(*current_end);
196                }
197                continue;
198            }
199
200            flush_entity(&mut entities, &mut current, input);
201            current = Some((normalized_label, start, end));
202        }
203
204        flush_entity(&mut entities, &mut current, input);
205        Ok(entities)
206    }
207
208    pub fn mask(&self, input: &str) -> Result<MaskResult> {
209        let (masked_text, pii) = self.mask_pii(input)?;
210        Ok(MaskResult { masked_text, pii })
211    }
212
213    pub fn mask_pii(&self, input: &str) -> Result<(String, BTreeMap<String, Vec<String>>)> {
214        let entities = self.detect_pii(input)?;
215        let mut masked_text = String::with_capacity(input.len());
216        let mut pii = BTreeMap::<String, Vec<String>>::new();
217        let mut cursor = 0;
218
219        for entity in &entities {
220            masked_text.push_str(&input[cursor..entity.start]);
221            masked_text.push('[');
222            masked_text.push_str(&entity.label);
223            masked_text.push(']');
224            cursor = entity.end;
225
226            let values = pii.entry(entity.label.clone()).or_default();
227            if !values.iter().any(|value| value == &entity.text) {
228                values.push(entity.text.clone());
229            }
230        }
231
232        masked_text.push_str(&input[cursor..]);
233        Ok((masked_text, pii))
234    }
235}
236
237fn default_weights_path() -> Result<PathBuf> {
238    if let Ok(path) = std::env::var(WEIGHTS_ENV_VAR) {
239        return Ok(PathBuf::from(path));
240    }
241
242    let local_candidate =
243        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(MODEL_DIR_WEIGHTS_CANDIDATE);
244    if local_candidate.exists() {
245        return Ok(local_candidate);
246    }
247
248    download_weights_from_hub()
249}
250
251fn download_weights_from_hub() -> Result<PathBuf> {
252    let api = Api::new().map_err(|err| MaskerError::Model(err.to_string()))?;
253    let repo = Repo::new(MODEL_REPO_ID.to_owned(), RepoType::Model);
254    let api = api.repo(repo);
255    api.get(MODEL_WEIGHTS_FILE)
256        .map_err(|err| MaskerError::Model(err.to_string()))
257}
258
/// Maps a raw model label to the public label name.
///
/// A single leading `B-` (checked first) or `I-` tag is removed. A few dataset-specific
/// names are then shortened (`ID_NUM` → `ID`, `NAME_STUDENT` → `NAME`, `PHONE_NUM` →
/// `PHONE`, `STREET_ADDRESS` → `ADDRESS`, `URL_PERSONAL` → `URL`). Anything else passes
/// through unchanged.
fn normalize_label(label: &str) -> String {
    let entity = ["B-", "I-"]
        .into_iter()
        .find_map(|tag| label.strip_prefix(tag))
        .unwrap_or(label);

    let public_name = match entity {
        "ID_NUM" => "ID",
        "NAME_STUDENT" => "NAME",
        "PHONE_NUM" => "PHONE",
        "STREET_ADDRESS" => "ADDRESS",
        "URL_PERSONAL" => "URL",
        unchanged => unchanged,
    };
    public_name.to_owned()
}
274
275fn flush_entity(
276    entities: &mut Vec<PiiEntity>,
277    current: &mut Option<(String, usize, usize)>,
278    input: &str,
279) {
280    let Some((label, start, end)) = current.take() else {
281        return;
282    };
283
284    let (start, end) = trim_span(input, start, end);
285    if start >= end {
286        return;
287    }
288
289    entities.push(PiiEntity {
290        label,
291        start,
292        end,
293        text: input[start..end].to_string(),
294    });
295}
296
/// Shrinks the byte range `start..end` of `input` so it excludes leading and trailing
/// Unicode whitespace.
///
/// If the range is entirely whitespace, the result is inverted (`start > end`), so callers
/// must check `start < end` before slicing.
///
/// # Panics
/// Panics if `start..end` is out of bounds or not on char boundaries of `input`.
fn trim_span(input: &str, start: usize, end: usize) -> (usize, usize) {
    let span = &input[start..end];
    let skipped_front = span.len() - span.trim_start().len();
    let skipped_back = span.len() - span.trim_end().len();
    (start + skipped_front, end - skipped_back)
}
303
#[cfg(test)]
mod tests {
    use super::{MODEL_DIR_WEIGHTS_CANDIDATE, PiiMaskerBuilder, normalize_label, trim_span};
    use std::path::PathBuf;

    /// Optional override pointing the model-backed test at a specific weights file.
    const TEST_WEIGHTS_ENV_VAR: &str = "PII_MASKER_TEST_MODEL_WEIGHTS";

    #[test]
    fn normalizes_model_labels() {
        let cases = [
            ("B-NAME_STUDENT", "NAME"),
            ("I-STREET_ADDRESS", "ADDRESS"),
            ("B-EMAIL", "EMAIL"),
        ];
        for (raw, expected) in cases {
            assert_eq!(normalize_label(raw), expected);
        }
    }

    #[test]
    fn trims_surrounding_whitespace() {
        let padded = " hello ";
        let trimmed = trim_span(padded, 0, padded.len());
        assert_eq!(trimmed, (1, 6));
    }

    /// End-to-end check; silently passes when no weights are available locally.
    #[test]
    fn masks_with_local_model_weights() {
        let weights = match optional_test_weights() {
            Some(path) => path,
            None => {
                eprintln!("Skipping model-backed test because no test weights were configured.");
                return;
            }
        };

        let masker = PiiMaskerBuilder::new()
            .weights_path(weights)
            .build()
            .expect("load local model");
        let result = masker
            .mask("John Doe lives at 1234 Elm St.")
            .expect("mask text");

        assert_eq!(result.masked_text, "John Doe lives at [ADDRESS].");
        let addresses = result.pii.get("ADDRESS").expect("address label");
        assert_eq!(addresses, &vec!["1234 Elm St".to_string()]);
    }

    /// Returns an existing weights file from the test env var, else from the repo-local
    /// candidate path, else `None`.
    fn optional_test_weights() -> Option<PathBuf> {
        let from_env = std::env::var(TEST_WEIGHTS_ENV_VAR)
            .ok()
            .map(PathBuf::from)
            .filter(|path| path.exists());

        from_env.or_else(|| {
            let repo_local =
                PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(MODEL_DIR_WEIGHTS_CANDIDATE);
            repo_local.exists().then_some(repo_local)
        })
    }
}