Skip to main content

privacy_filter_rs/
config.rs

//! Model configuration parsed from HuggingFace `config.json`.
2
3use std::collections::HashMap;
4use std::path::Path;
5
6// ── ModelConfig ──────────────────────────────────────────────────────────────
7
8#[derive(Debug, Clone, serde::Deserialize)]
9pub struct ModelConfig {
10    pub vocab_size: usize,
11    pub hidden_size: usize,
12    pub intermediate_size: usize,
13    pub num_hidden_layers: usize,
14    pub num_attention_heads: usize,
15    pub num_key_value_heads: usize,
16    pub head_dim: usize,
17    pub sliding_window: usize,
18    pub num_local_experts: usize,
19    pub num_experts_per_tok: usize,
20    pub rms_norm_eps: f64,
21    pub max_position_embeddings: usize,
22    #[serde(default = "default_attention_bias")]
23    pub attention_bias: bool,
24    #[serde(default)]
25    pub pad_token_id: Option<usize>,
26    pub rope_parameters: RopeParameters,
27    #[serde(default)]
28    pub id2label: Option<HashMap<String, String>>,
29    #[serde(default)]
30    pub label2id: Option<HashMap<String, usize>>,
31}
32
/// Serde fallback for `ModelConfig::attention_bias`: a config that omits the
/// key is treated as having attention bias enabled.
fn default_attention_bias() -> bool {
    let enabled_when_missing = true;
    enabled_when_missing
}
36
37#[derive(Debug, Clone, serde::Deserialize)]
38pub struct RopeParameters {
39    #[serde(default = "default_rope_type")]
40    pub rope_type: String,
41    #[serde(default = "default_rope_theta")]
42    pub rope_theta: f64,
43    #[serde(default = "default_factor")]
44    pub factor: f64,
45    #[serde(default = "default_beta_fast")]
46    pub beta_fast: f64,
47    #[serde(default = "default_beta_slow")]
48    pub beta_slow: f64,
49    #[serde(default = "default_original_max_pos")]
50    pub original_max_position_embeddings: usize,
51    #[serde(default)]
52    pub truncate: bool,
53}
54
/// Serde fallback for `RopeParameters::rope_type`.
fn default_rope_type() -> String {
    String::from("yarn")
}

/// Serde fallback for `RopeParameters::rope_theta`.
fn default_rope_theta() -> f64 {
    150_000.0
}

/// Serde fallback for `RopeParameters::factor`.
fn default_factor() -> f64 {
    32.0
}

/// Serde fallback for `RopeParameters::beta_fast`.
fn default_beta_fast() -> f64 {
    32.0
}

/// Serde fallback for `RopeParameters::beta_slow`.
fn default_beta_slow() -> f64 {
    1.0
}

/// Serde fallback for `RopeParameters::original_max_position_embeddings`.
fn default_original_max_pos() -> usize {
    4_096
}
61
62impl ModelConfig {
63    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
64        let text = std::fs::read_to_string(path)?;
65        let cfg: Self = serde_json::from_str(&text)?;
66        Ok(cfg)
67    }
68
69    /// Number of query heads per KV head group.
70    pub fn num_key_value_groups(&self) -> usize {
71        self.num_attention_heads / self.num_key_value_heads
72    }
73
74    /// Total number of output labels (33 for BIOES over 8 categories + O).
75    pub fn num_labels(&self) -> usize {
76        if let Some(ref id2l) = self.id2label {
77            id2l.len()
78        } else {
79            33
80        }
81    }
82}
83
84// ── ViterbiConfig ────────────────────────────────────────────────────────────
85
/// The six named transition-bias values consumed by Viterbi decoding.
///
/// `ViterbiConfig::default()` sets every bias to `0.0`. The derived
/// `Default` is exactly equivalent to spelling the zeros out by hand, since
/// `f64::default()` is `0.0`.
#[derive(Debug, Clone, Default)]
pub struct ViterbiConfig {
    pub transition_bias_background_stay: f64,
    pub transition_bias_background_to_start: f64,
    pub transition_bias_inside_to_continue: f64,
    pub transition_bias_inside_to_end: f64,
    pub transition_bias_end_to_background: f64,
    pub transition_bias_end_to_start: f64,
}
108
109#[derive(Debug, serde::Deserialize)]
110struct ViterbiCalibrationFile {
111    operating_points: HashMap<String, ViterbiOperatingPoint>,
112}
113
114#[derive(Debug, serde::Deserialize)]
115struct ViterbiOperatingPoint {
116    biases: ViterbiBiases,
117}
118
119#[derive(Debug, serde::Deserialize)]
120struct ViterbiBiases {
121    transition_bias_background_stay: f64,
122    transition_bias_background_to_start: f64,
123    transition_bias_inside_to_continue: f64,
124    transition_bias_inside_to_end: f64,
125    transition_bias_end_to_background: f64,
126    transition_bias_end_to_start: f64,
127}
128
129impl ViterbiConfig {
130    pub fn from_file(path: &Path, operating_point: &str) -> anyhow::Result<Self> {
131        let text = std::fs::read_to_string(path)?;
132        let cal: ViterbiCalibrationFile = serde_json::from_str(&text)?;
133        let op = cal.operating_points.get(operating_point)
134            .ok_or_else(|| anyhow::anyhow!("operating point '{}' not found", operating_point))?;
135        Ok(Self {
136            transition_bias_background_stay: op.biases.transition_bias_background_stay,
137            transition_bias_background_to_start: op.biases.transition_bias_background_to_start,
138            transition_bias_inside_to_continue: op.biases.transition_bias_inside_to_continue,
139            transition_bias_inside_to_end: op.biases.transition_bias_inside_to_end,
140            transition_bias_end_to_background: op.biases.transition_bias_end_to_background,
141            transition_bias_end_to_start: op.biases.transition_bias_end_to_start,
142        })
143    }
144}
145
146// ── Label helpers ────────────────────────────────────────────────────────────
147
/// The eight span categories.
///
/// Order is significant: `build_label_list` and `label_to_category` derive
/// label indices from a category's position in this slice.
pub const SPAN_LABELS: &[&str] = &[
    "account_number",
    "private_address",
    "private_date",
    "private_email",
    "private_person",
    "private_phone",
    "private_url",
    "secret",
];
159
/// The per-category tag prefixes, in the order they are laid out inside each
/// category's run of label indices. The background tag `O` is not part of
/// this list; it is handled separately as label index 0.
pub const BIOES_PREFIXES: &[&str] = &[
    "B",
    "I",
    "E",
    "S",
];
162
163/// Build the full 33-label list: O, B-account_number, I-account_number, E-account_number, S-account_number, ...
164pub fn build_label_list() -> Vec<String> {
165    let mut labels = vec!["O".to_string()];
166    for &cat in SPAN_LABELS {
167        for &prefix in BIOES_PREFIXES {
168            labels.push(format!("{}-{}", prefix, cat));
169        }
170    }
171    labels
172}
173
174/// Get the span category from a label index (0 = O, 1-4 = account_number, etc.)
175pub fn label_to_category(label_idx: usize) -> Option<&'static str> {
176    if label_idx == 0 {
177        return None; // O (background)
178    }
179    let cat_idx = (label_idx - 1) / 4;
180    SPAN_LABELS.get(cat_idx).copied()
181}
182
183/// Get the BIOES prefix for a label index.
184pub fn label_to_prefix(label_idx: usize) -> Option<&'static str> {
185    if label_idx == 0 {
186        return Some("O");
187    }
188    let prefix_idx = (label_idx - 1) % 4;
189    BIOES_PREFIXES.get(prefix_idx).copied()
190}