privacy_filter_rs/
config.rs1use std::collections::HashMap;
4use std::path::Path;
5
6#[derive(Debug, Clone, serde::Deserialize)]
9pub struct ModelConfig {
10 pub vocab_size: usize,
11 pub hidden_size: usize,
12 pub intermediate_size: usize,
13 pub num_hidden_layers: usize,
14 pub num_attention_heads: usize,
15 pub num_key_value_heads: usize,
16 pub head_dim: usize,
17 pub sliding_window: usize,
18 pub num_local_experts: usize,
19 pub num_experts_per_tok: usize,
20 pub rms_norm_eps: f64,
21 pub max_position_embeddings: usize,
22 #[serde(default = "default_attention_bias")]
23 pub attention_bias: bool,
24 #[serde(default)]
25 pub pad_token_id: Option<usize>,
26 pub rope_parameters: RopeParameters,
27 #[serde(default)]
28 pub id2label: Option<HashMap<String, String>>,
29 #[serde(default)]
30 pub label2id: Option<HashMap<String, usize>>,
31}
32
/// Serde default for `ModelConfig::attention_bias` when the key is missing from the JSON.
fn default_attention_bias() -> bool {
    // Absent key means bias is enabled.
    true
}
36
37#[derive(Debug, Clone, serde::Deserialize)]
38pub struct RopeParameters {
39 #[serde(default = "default_rope_type")]
40 pub rope_type: String,
41 #[serde(default = "default_rope_theta")]
42 pub rope_theta: f64,
43 #[serde(default = "default_factor")]
44 pub factor: f64,
45 #[serde(default = "default_beta_fast")]
46 pub beta_fast: f64,
47 #[serde(default = "default_beta_slow")]
48 pub beta_slow: f64,
49 #[serde(default = "default_original_max_pos")]
50 pub original_max_position_embeddings: usize,
51 #[serde(default)]
52 pub truncate: bool,
53}
54
/// Serde default for `RopeParameters::rope_type`.
fn default_rope_type() -> String {
    String::from("yarn")
}

/// Serde default for `RopeParameters::rope_theta` (150 000).
fn default_rope_theta() -> f64 {
    1.5e5
}

/// Serde default for `RopeParameters::factor`.
fn default_factor() -> f64 {
    32.0
}

/// Serde default for `RopeParameters::beta_fast`.
fn default_beta_fast() -> f64 {
    32.0
}

/// Serde default for `RopeParameters::beta_slow`.
fn default_beta_slow() -> f64 {
    1.0
}

/// Serde default for `RopeParameters::original_max_position_embeddings`.
fn default_original_max_pos() -> usize {
    4_096
}
61
62impl ModelConfig {
63 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
64 let text = std::fs::read_to_string(path)?;
65 let cfg: Self = serde_json::from_str(&text)?;
66 Ok(cfg)
67 }
68
69 pub fn num_key_value_groups(&self) -> usize {
71 self.num_attention_heads / self.num_key_value_heads
72 }
73
74 pub fn num_labels(&self) -> usize {
76 if let Some(ref id2l) = self.id2label {
77 id2l.len()
78 } else {
79 33
80 }
81 }
82}
83
/// Transition biases for the Viterbi span decoder.
///
/// NOTE(review): the per-field meanings below are read from the field names only;
/// confirm them against the decoder that consumes this struct.
#[derive(Clone, Debug)]
pub struct ViterbiConfig {
    /// Bias for background → background.
    pub transition_bias_background_stay: f64,
    /// Bias for background → span start.
    pub transition_bias_background_to_start: f64,
    /// Bias for inside → inside (span continues).
    pub transition_bias_inside_to_continue: f64,
    /// Bias for inside → span end.
    pub transition_bias_inside_to_end: f64,
    /// Bias for span end → background.
    pub transition_bias_end_to_background: f64,
    /// Bias for span end → new span start.
    pub transition_bias_end_to_start: f64,
}
95
96impl Default for ViterbiConfig {
97 fn default() -> Self {
98 Self {
99 transition_bias_background_stay: 0.0,
100 transition_bias_background_to_start: 0.0,
101 transition_bias_inside_to_continue: 0.0,
102 transition_bias_inside_to_end: 0.0,
103 transition_bias_end_to_background: 0.0,
104 transition_bias_end_to_start: 0.0,
105 }
106 }
107}
108
109#[derive(Debug, serde::Deserialize)]
110struct ViterbiCalibrationFile {
111 operating_points: HashMap<String, ViterbiOperatingPoint>,
112}
113
114#[derive(Debug, serde::Deserialize)]
115struct ViterbiOperatingPoint {
116 biases: ViterbiBiases,
117}
118
119#[derive(Debug, serde::Deserialize)]
120struct ViterbiBiases {
121 transition_bias_background_stay: f64,
122 transition_bias_background_to_start: f64,
123 transition_bias_inside_to_continue: f64,
124 transition_bias_inside_to_end: f64,
125 transition_bias_end_to_background: f64,
126 transition_bias_end_to_start: f64,
127}
128
129impl ViterbiConfig {
130 pub fn from_file(path: &Path, operating_point: &str) -> anyhow::Result<Self> {
131 let text = std::fs::read_to_string(path)?;
132 let cal: ViterbiCalibrationFile = serde_json::from_str(&text)?;
133 let op = cal.operating_points.get(operating_point)
134 .ok_or_else(|| anyhow::anyhow!("operating point '{}' not found", operating_point))?;
135 Ok(Self {
136 transition_bias_background_stay: op.biases.transition_bias_background_stay,
137 transition_bias_background_to_start: op.biases.transition_bias_background_to_start,
138 transition_bias_inside_to_continue: op.biases.transition_bias_inside_to_continue,
139 transition_bias_inside_to_end: op.biases.transition_bias_inside_to_end,
140 transition_bias_end_to_background: op.biases.transition_bias_end_to_background,
141 transition_bias_end_to_start: op.biases.transition_bias_end_to_start,
142 })
143 }
144}
145
/// Span categories in label-index order.
///
/// Order is significant. `build_label_list` emits one label per BIOES prefix for each
/// category, so category `i` owns label indices `1 + 4*i ..= 4 + 4*i`. Index 0 is `"O"`.
pub const SPAN_LABELS: &[&str] = &[
    "account_number",
    "private_address",
    "private_date",
    "private_email",
    "private_person",
    "private_phone",
    "private_url",
    "secret",
];
159
/// Tag prefixes applied to each span category, in label-index order.
///
/// These are the conventional BIOES tags B, I, E and S. The outside tag `"O"` is not part
/// of this list; it is always label index 0.
pub const BIOES_PREFIXES: &[&str] = &[
    "B",
    "I",
    "E",
    "S",
];
162
163pub fn build_label_list() -> Vec<String> {
165 let mut labels = vec!["O".to_string()];
166 for &cat in SPAN_LABELS {
167 for &prefix in BIOES_PREFIXES {
168 labels.push(format!("{}-{}", prefix, cat));
169 }
170 }
171 labels
172}
173
174pub fn label_to_category(label_idx: usize) -> Option<&'static str> {
176 if label_idx == 0 {
177 return None; }
179 let cat_idx = (label_idx - 1) / 4;
180 SPAN_LABELS.get(cat_idx).copied()
181}
182
183pub fn label_to_prefix(label_idx: usize) -> Option<&'static str> {
185 if label_idx == 0 {
186 return Some("O");
187 }
188 let prefix_idx = (label_idx - 1) % 4;
189 BIOES_PREFIXES.get(prefix_idx).copied()
190}