use std::collections::HashMap;
use std::path::Path;
/// Model hyper-parameters deserialized from a JSON config file (see `ModelConfig::from_file`).
///
/// Fields without a `#[serde(default ...)]` attribute must be present in the JSON, otherwise
/// deserialization fails. JSON keys not listed here are ignored (the struct does not use
/// `deny_unknown_fields`).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct ModelConfig {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
/// Must be non-zero: `num_key_value_groups` divides `num_attention_heads` by this value.
pub num_key_value_heads: usize,
pub head_dim: usize,
// NOTE(review): presumably the sliding-window attention span in tokens — confirm in the
// attention implementation.
pub sliding_window: usize,
// NOTE(review): the next two fields presumably configure a mixture-of-experts layer
// (experts per layer / experts selected per token) — confirm against the MoE code.
pub num_local_experts: usize,
pub num_experts_per_tok: usize,
pub rms_norm_eps: f64,
pub max_position_embeddings: usize,
/// Defaults to `true` when the key is absent (see `default_attention_bias`).
#[serde(default = "default_attention_bias")]
pub attention_bias: bool,
/// `None` when the key is absent or `null`.
#[serde(default)]
pub pad_token_id: Option<usize>,
/// Required JSON object (no serde default at this level); individual keys inside it fall
/// back to the `RopeParameters` defaults.
pub rope_parameters: RopeParameters,
/// Optional label map; when present, its length is what `num_labels` reports.
/// NOTE(review): keys are presumably stringified label indices (JSON object keys must be
/// strings) — confirm against the config producer.
#[serde(default)]
pub id2label: Option<HashMap<String, String>>,
/// Optional map from label name to label index; `None` when absent.
#[serde(default)]
pub label2id: Option<HashMap<String, usize>>,
}
/// Serde fallback for `ModelConfig::attention_bias` when the JSON omits the key.
///
/// Always returns `true`.
fn default_attention_bias() -> bool {
    const ATTENTION_BIAS_WHEN_MISSING: bool = true;
    ATTENTION_BIAS_WHEN_MISSING
}
/// Rotary-position-embedding parameters, read from the `rope_parameters` object of the
/// model config.
///
/// Every key is optional in the JSON. Missing keys take the defaults noted on each field;
/// those defaults come from the `default_*` functions defined after this struct.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct RopeParameters {
/// Scaling scheme name; defaults to `"yarn"`.
#[serde(default = "default_rope_type")]
pub rope_type: String,
/// Defaults to `150_000.0`. NOTE(review): presumably the RoPE base frequency.
#[serde(default = "default_rope_theta")]
pub rope_theta: f64,
/// Defaults to `32.0`. NOTE(review): presumably the context-extension scaling factor.
#[serde(default = "default_factor")]
pub factor: f64,
/// Defaults to `32.0`. NOTE(review): presumably YaRN's `beta_fast` correction bound.
#[serde(default = "default_beta_fast")]
pub beta_fast: f64,
/// Defaults to `1.0`. NOTE(review): presumably YaRN's `beta_slow` correction bound.
#[serde(default = "default_beta_slow")]
pub beta_slow: f64,
/// Defaults to `4096`. NOTE(review): presumably the pre-scaling context length.
#[serde(default = "default_original_max_pos")]
pub original_max_position_embeddings: usize,
/// Defaults to `false`. Semantics are defined by the RoPE implementation, which is not shown here.
#[serde(default)]
pub truncate: bool,
}
/// Serde fallback for `RopeParameters::rope_type`: `"yarn"`.
fn default_rope_type() -> String {
    String::from("yarn")
}

/// Serde fallback for `RopeParameters::rope_theta`: `150_000.0`.
fn default_rope_theta() -> f64 {
    1.5e5
}

/// Serde fallback for `RopeParameters::factor`: `32.0`.
fn default_factor() -> f64 {
    32f64
}

/// Serde fallback for `RopeParameters::beta_fast`: `32.0`.
fn default_beta_fast() -> f64 {
    32f64
}

/// Serde fallback for `RopeParameters::beta_slow`: `1.0`.
fn default_beta_slow() -> f64 {
    1f64
}

/// Serde fallback for `RopeParameters::original_max_position_embeddings`: `4096`.
fn default_original_max_pos() -> usize {
    4_096
}
impl ModelConfig {
pub fn from_file(path: &Path) -> anyhow::Result<Self> {
let text = std::fs::read_to_string(path)?;
let cfg: Self = serde_json::from_str(&text)?;
Ok(cfg)
}
pub fn num_key_value_groups(&self) -> usize {
self.num_attention_heads / self.num_key_value_heads
}
pub fn num_labels(&self) -> usize {
if let Some(ref id2l) = self.id2label {
id2l.len()
} else {
33
}
}
}
/// Transition biases used by the Viterbi decoder, one per class of tag transition.
///
/// Each field is named after the transition it applies to. NOTE(review): "background" presumably
/// means the `O` tag, and each bias is presumably added to that transition's score during
/// decoding — confirm against the decoder.
///
/// `Default` (derived) sets every bias to `0.0`. A calibrated set can be loaded with
/// `ViterbiConfig::from_file`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ViterbiConfig {
    pub transition_bias_background_stay: f64,
    pub transition_bias_background_to_start: f64,
    pub transition_bias_inside_to_continue: f64,
    pub transition_bias_inside_to_end: f64,
    pub transition_bias_end_to_background: f64,
    pub transition_bias_end_to_start: f64,
}
/// Top level of a Viterbi calibration JSON file, as read by `ViterbiConfig::from_file`.
///
/// Expected shape (other keys at any level are ignored by serde):
///
/// ```text
/// { "operating_points": { "<name>": { "biases": { "transition_bias_background_stay": 0.0, ... } } } }
/// ```
#[derive(Debug, serde::Deserialize)]
struct ViterbiCalibrationFile {
operating_points: HashMap<String, ViterbiOperatingPoint>,
}
/// One named operating point; only its `biases` object is read.
#[derive(Debug, serde::Deserialize)]
struct ViterbiOperatingPoint {
biases: ViterbiBiases,
}
/// The six transition biases of an operating point. All six are required (no serde
/// defaults). They map one-to-one onto the fields of `ViterbiConfig`.
#[derive(Debug, serde::Deserialize)]
struct ViterbiBiases {
transition_bias_background_stay: f64,
transition_bias_background_to_start: f64,
transition_bias_inside_to_continue: f64,
transition_bias_inside_to_end: f64,
transition_bias_end_to_background: f64,
transition_bias_end_to_start: f64,
}
impl ViterbiConfig {
pub fn from_file(path: &Path, operating_point: &str) -> anyhow::Result<Self> {
let text = std::fs::read_to_string(path)?;
let cal: ViterbiCalibrationFile = serde_json::from_str(&text)?;
let op = cal.operating_points.get(operating_point)
.ok_or_else(|| anyhow::anyhow!("operating point '{}' not found", operating_point))?;
Ok(Self {
transition_bias_background_stay: op.biases.transition_bias_background_stay,
transition_bias_background_to_start: op.biases.transition_bias_background_to_start,
transition_bias_inside_to_continue: op.biases.transition_bias_inside_to_continue,
transition_bias_inside_to_end: op.biases.transition_bias_inside_to_end,
transition_bias_end_to_background: op.biases.transition_bias_end_to_background,
transition_bias_end_to_start: op.biases.transition_bias_end_to_start,
})
}
}
/// Span categories, in label-index order.
///
/// The order is load-bearing. Label index `1 + BIOES_PREFIXES.len() * c + p` encodes category
/// `SPAN_LABELS[c]` with prefix `BIOES_PREFIXES[p]` (see [`build_label_list`]). Reordering or
/// inserting entries therefore changes the meaning of every label index after the change.
pub const SPAN_LABELS: &[&str] = &[
    "account_number",
    "private_address",
    "private_date",
    "private_email",
    "private_person",
    "private_phone",
    "private_url",
    "secret",
];

/// Per-category tag prefixes (Begin, Inside, End, Single), in label-index order.
///
/// The `O` tag is not listed here: it is a single shared label at index 0.
pub const BIOES_PREFIXES: &[&str] = &["B", "I", "E", "S"];

/// Builds the full label list: `"O"` first, then `"{prefix}-{category}"` for every category in
/// [`SPAN_LABELS`] and every prefix in [`BIOES_PREFIXES`], with the prefix varying fastest.
///
/// With the current tables this yields 33 labels:
/// `["O", "B-account_number", "I-account_number", ..., "S-secret"]`.
pub fn build_label_list() -> Vec<String> {
    let mut labels = Vec::with_capacity(1 + SPAN_LABELS.len() * BIOES_PREFIXES.len());
    labels.push("O".to_string());
    labels.extend(SPAN_LABELS.iter().flat_map(|category| {
        BIOES_PREFIXES
            .iter()
            .map(move |prefix| format!("{}-{}", prefix, category))
    }));
    labels
}

/// Splits a span label index into its `(prefix, category)` parts.
///
/// Returns `None` for index 0 (`O`) and for indices past the end of the label list. This is
/// the single place that knows the index layout used by [`build_label_list`].
fn decode_span_label(label_idx: usize) -> Option<(&'static str, &'static str)> {
    // Index 0 is the shared `O` label; span labels start at 1.
    let offset = label_idx.checked_sub(1)?;
    let per_category = BIOES_PREFIXES.len();
    let category = *SPAN_LABELS.get(offset / per_category)?;
    let prefix = *BIOES_PREFIXES.get(offset % per_category)?;
    Some((prefix, category))
}

/// Returns the span category for a label index.
///
/// Returns `None` for `O` (index 0) and for indices outside the label list.
pub fn label_to_category(label_idx: usize) -> Option<&'static str> {
    decode_span_label(label_idx).map(|(_, category)| category)
}

/// Returns the tag prefix for a label index: `"O"` for index 0, otherwise one of
/// [`BIOES_PREFIXES`].
///
/// Returns `None` for indices outside the label list, consistent with [`label_to_category`].
/// Without this check, modular arithmetic would report a prefix for labels that do not exist.
pub fn label_to_prefix(label_idx: usize) -> Option<&'static str> {
    if label_idx == 0 {
        return Some("O");
    }
    decode_span_label(label_idx).map(|(prefix, _)| prefix)
}