use numr::dtype::DType;
use serde::{Deserialize, Serialize};
/// Encoder architecture family stored in [`EncoderConfig::arch_family`].
///
/// Serde uses the snake_case variant name on the wire (`bert`, `xlm_roberta`,
/// `nomic_bert`, `gemma_embedding`). When the field is absent from a config,
/// the `Default` impl yields `Bert`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ArchFamily {
    /// Serialized as `bert`; the default variant.
    #[default]
    Bert,
    /// Serialized as `xlm_roberta`.
    XlmRoberta,
    /// Serialized as `nomic_bert`.
    NomicBert,
    /// Serialized as `gemma_embedding`.
    GemmaEmbedding,
}
/// Feed-forward (FFN) variant selector stored in [`EncoderConfig::ffn_variant`].
///
/// Serde uses the snake_case variant name on the wire (`standard`,
/// `gated_silu`, `gated_gelu`). When the field is absent from a config,
/// the `Default` impl yields `Standard`.
// NOTE(review): the gated variants presumably select a GLU-style FFN with the
// named activation — confirm in the model code that consumes this.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FfnVariant {
    /// Serialized as `standard`; the default variant.
    #[default]
    Standard,
    /// Serialized as `gated_silu`.
    GatedSilu,
    /// Serialized as `gated_gelu`.
    GatedGelu,
}
/// Default token budget for a single forward pass: `16 * 1024` (16384).
///
/// NOTE(review): presumably the fallback used when
/// [`EncoderConfig::max_tokens_per_forward`] is `None` — confirm at the call site.
pub const DEFAULT_MAX_TOKENS_PER_FORWARD: usize = 16 * 1024;
/// Encoder model hyper-parameters, (de)serializable with serde.
///
/// Fields without a `#[serde(default...)]` attribute are required when
/// deserializing. `compute_dtype` is skipped by serde entirely, so every
/// deserialized config starts with `DType::F32` there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncoderConfig {
    /// Vocabulary size. Required.
    pub vocab_size: usize,
    /// Hidden (model) dimension. Required. Numerator of [`EncoderConfig::head_dim`].
    pub hidden_size: usize,
    /// Number of hidden layers. Required.
    pub num_hidden_layers: usize,
    /// Number of attention heads. Required. Divisor in
    /// [`EncoderConfig::head_dim`]; also the value returned by
    /// [`EncoderConfig::resolved_num_kv_heads`] when `num_kv_heads` is `0`.
    pub num_attention_heads: usize,
    /// Feed-forward intermediate dimension. Required.
    pub intermediate_size: usize,
    /// Maximum number of position embeddings. Required.
    pub max_position_embeddings: usize,
    /// Layer-norm epsilon; `1e-12` when absent.
    #[serde(default = "default_eps")]
    pub layer_norm_eps: f64,
    /// Hidden activation; `HiddenAct::Gelu` when absent.
    #[serde(default)]
    pub hidden_act: HiddenAct,
    /// Token-type vocabulary size; `0` when absent.
    #[serde(default)]
    pub type_vocab_size: usize,
    /// Architecture family; `ArchFamily::Bert` when absent.
    #[serde(default)]
    pub arch_family: ArchFamily,
    /// Padding token id; `0` when absent.
    #[serde(default)]
    pub padding_token_id: i64,
    /// Compute dtype. Never read from or written by serde; deserialized configs
    /// always hold `DType::F32` here, and any other value has to be assigned
    /// after loading.
    #[serde(skip, default = "default_compute_dtype")]
    pub compute_dtype: DType,
    /// RoPE frequency base; `10000.0` when absent.
    #[serde(default = "default_rope_freq_base")]
    pub rope_freq_base: f32,
    /// Causal flag; `false` when absent.
    // NOTE(review): presumably toggles causal attention masking — confirm in model code.
    #[serde(default)]
    pub causal: bool,
    /// Feed-forward variant; `FfnVariant::Standard` when absent.
    #[serde(default)]
    pub ffn_variant: FfnVariant,
    /// Token-type embedding size; `0` when absent.
    // NOTE(review): relationship to `type_vocab_size` is not visible here — confirm.
    #[serde(default)]
    pub token_type_embed_size: usize,
    /// Key/value head count; `0` when absent, which
    /// [`EncoderConfig::resolved_num_kv_heads`] treats as "same as
    /// `num_attention_heads`".
    #[serde(default)]
    pub num_kv_heads: usize,
    /// Explicit per-head dimension; `None` when absent, in which case
    /// [`EncoderConfig::resolved_head_dim`] uses `hidden_size / num_attention_heads`.
    #[serde(default)]
    pub head_dim_explicit: Option<usize>,
    /// RMS-norm epsilon; `1e-6` when absent.
    #[serde(default = "default_rms_eps")]
    pub rms_eps: f64,
    /// Optional sliding-window size; `None` when absent.
    // NOTE(review): unit presumably tokens — confirm in the attention code.
    #[serde(default)]
    pub sliding_window: Option<usize>,
    /// Embedding-scale flag; `false` when absent.
    // NOTE(review): the actual scale factor is applied elsewhere — confirm.
    #[serde(default)]
    pub embed_scale: bool,
    /// Optional per-forward token cap; `None` when absent.
    // NOTE(review): `DEFAULT_MAX_TOKENS_PER_FORWARD` is presumably the fallback — confirm.
    #[serde(default)]
    pub max_tokens_per_forward: Option<usize>,
}
/// Hidden activation function stored in [`EncoderConfig::hidden_act`].
///
/// Serde uses the lowercase variant name on the wire (`gelu`, `relu`). When
/// the field is absent from a config, the `Default` impl yields `Gelu`.
///
/// Derives `PartialEq`/`Eq`, matching `ArchFamily` and `FfnVariant`, so
/// callers can compare activations with `==` and use `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    /// Serialized as `gelu`; the default variant.
    #[default]
    Gelu,
    /// Serialized as `relu`.
    Relu,
}
/// Serde fallback for [`EncoderConfig::layer_norm_eps`]: `1e-12`.
fn default_eps() -> f64 {
    1.0e-12_f64
}
/// Value placed in [`EncoderConfig::compute_dtype`] on deserialization.
///
/// That field is `#[serde(skip)]`, so this is the only value a freshly
/// deserialized config can carry there.
fn default_compute_dtype() -> DType {
    DType::F32
}
/// Serde fallback for [`EncoderConfig::rope_freq_base`]: `1e4`.
fn default_rope_freq_base() -> f32 {
    1.0e4_f32
}
/// Serde fallback for [`EncoderConfig::rms_eps`]: `1e-6`.
fn default_rms_eps() -> f64 {
    1.0e-6_f64
}
impl EncoderConfig {
    /// Per-head dimension derived from the hidden size alone:
    /// `hidden_size / num_attention_heads` (truncating integer division).
    ///
    /// This ignores [`head_dim_explicit`](Self::head_dim_explicit). Use
    /// [`resolved_head_dim`](Self::resolved_head_dim) to honour an explicit override.
    ///
    /// # Panics
    ///
    /// Panics if `num_attention_heads` is zero (integer division by zero).
    pub fn head_dim(&self) -> usize {
        self.hidden_size / self.num_attention_heads
    }

    /// Per-head dimension. Returns `head_dim_explicit` when it is set, and
    /// otherwise falls back to [`head_dim`](Self::head_dim).
    ///
    /// # Panics
    ///
    /// Panics if `head_dim_explicit` is `None` and `num_attention_heads` is zero.
    pub fn resolved_head_dim(&self) -> usize {
        // Delegate instead of repeating the formula, so both accessors always
        // derive the implicit head size the same way.
        self.head_dim_explicit.unwrap_or_else(|| self.head_dim())
    }

    /// Number of key/value heads. A stored value of `0` (the serde default)
    /// means "same as `num_attention_heads`".
    pub fn resolved_num_kv_heads(&self) -> usize {
        match self.num_kv_heads {
            0 => self.num_attention_heads,
            explicit => explicit,
        }
    }
}
// Private child module whose source lives in `config_gguf.rs`; the `#[path]`
// attribute redirects the default `gguf.rs` lookup.
// NOTE(review): presumably builds `EncoderConfig` from GGUF metadata — confirm.
#[path = "config_gguf.rs"]
mod gguf;