use super::{ArchFamily, EncoderConfig, FfnVariant, HiddenAct};
use numr::dtype::DType;
impl EncoderConfig {
pub fn from_gguf_metadata(
metadata: &crate::format::GgufMetadata,
) -> crate::error::Result<Self> {
use crate::error::Error;
if metadata.get_string("general.architecture") == Some("gemma-embedding") {
return Self::from_gguf_metadata_gemma(metadata);
}
if metadata.get_string("general.architecture") == Some("nomic-bert") {
return Self::from_gguf_metadata_nomic(metadata);
}
let hidden_size =
metadata
.get_u32("bert.embedding_length")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing bert.embedding_length".into(),
})? as usize;
let intermediate_size = metadata
.get_u32("bert.feed_forward_length")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing bert.feed_forward_length".into(),
})? as usize;
let num_attention_heads =
metadata
.get_u32("bert.attention.head_count")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing bert.attention.head_count".into(),
})? as usize;
let num_hidden_layers =
metadata
.get_u32("bert.block_count")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing bert.block_count".into(),
})? as usize;
let max_position_embeddings =
metadata.get_u32("bert.context_length").unwrap_or(512) as usize;
let vocab_size = metadata
.get_array("tokenizer.ggml.tokens")
.map(|a| a.len())
.unwrap_or(30522);
let tokenizer_model = metadata
.get_string("tokenizer.ggml.model")
.unwrap_or("bert");
let (arch_family, padding_token_id) = if tokenizer_model == "t5" {
(ArchFamily::XlmRoberta, 1i64)
} else {
(ArchFamily::Bert, 0i64)
};
Ok(Self {
vocab_size,
hidden_size,
num_hidden_layers,
num_attention_heads,
intermediate_size,
max_position_embeddings,
layer_norm_eps: 1e-12,
hidden_act: HiddenAct::Gelu,
type_vocab_size: 0,
arch_family,
padding_token_id,
compute_dtype: DType::F32,
rope_freq_base: 10000.0,
causal: false,
ffn_variant: FfnVariant::Standard,
token_type_embed_size: 0,
num_kv_heads: 0,
head_dim_explicit: None,
rms_eps: 1e-6,
sliding_window: None,
embed_scale: false,
max_tokens_per_forward: None,
})
}
fn from_gguf_metadata_nomic(
metadata: &crate::format::GgufMetadata,
) -> crate::error::Result<Self> {
use crate::error::Error;
use crate::format::GgufValue;
let hidden_size = metadata
.get_u32("nomic-bert.embedding_length")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing nomic-bert.embedding_length".into(),
})? as usize;
let intermediate_size = metadata
.get_u32("nomic-bert.feed_forward_length")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing nomic-bert.feed_forward_length".into(),
})? as usize;
let num_attention_heads = metadata
.get_u32("nomic-bert.attention.head_count")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing nomic-bert.attention.head_count".into(),
})? as usize;
let num_hidden_layers =
metadata
.get_u32("nomic-bert.block_count")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing nomic-bert.block_count".into(),
})? as usize;
let max_position_embeddings = metadata
.get_u32("nomic-bert.context_length")
.unwrap_or(2048) as usize;
let layer_norm_eps = metadata
.get_f32("nomic-bert.attention.layer_norm_epsilon")
.map(|v| v as f64)
.unwrap_or(1e-12);
let rope_freq_base = metadata
.get_f32("nomic-bert.rope.freq_base")
.unwrap_or(10000.0);
let causal = metadata
.get("nomic-bert.attention.causal")
.and_then(|v| match v {
GgufValue::Bool(b) => Some(*b),
_ => None,
})
.unwrap_or(false);
if let Some(pt) = metadata.get_u32("nomic-bert.pooling_type")
&& pt != 1
{
return Err(Error::ModelError {
reason: format!(
"nomic-bert.pooling_type = {pt}; only mean pooling (1) is supported"
),
});
}
let vocab_size = metadata
.get_array("tokenizer.ggml.tokens")
.map(|a| a.len())
.unwrap_or(30522);
Ok(Self {
vocab_size,
hidden_size,
num_hidden_layers,
num_attention_heads,
intermediate_size,
max_position_embeddings,
layer_norm_eps,
hidden_act: HiddenAct::Gelu,
type_vocab_size: 2,
arch_family: ArchFamily::NomicBert,
padding_token_id: 0,
compute_dtype: DType::F32,
rope_freq_base,
causal,
ffn_variant: FfnVariant::GatedSilu,
token_type_embed_size: 2,
num_kv_heads: 0,
head_dim_explicit: None,
rms_eps: 1e-6,
sliding_window: None,
embed_scale: false,
max_tokens_per_forward: None,
})
}
fn from_gguf_metadata_gemma(
metadata: &crate::format::GgufMetadata,
) -> crate::error::Result<Self> {
use crate::error::Error;
let hidden_size = metadata
.get_u32("gemma-embedding.embedding_length")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing gemma-embedding.embedding_length".into(),
})? as usize;
let intermediate_size = metadata
.get_u32("gemma-embedding.feed_forward_length")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing gemma-embedding.feed_forward_length".into(),
})? as usize;
let num_attention_heads = metadata
.get_u32("gemma-embedding.attention.head_count")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing gemma-embedding.attention.head_count".into(),
})? as usize;
let num_kv_heads = metadata
.get_u32("gemma-embedding.attention.head_count_kv")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing gemma-embedding.attention.head_count_kv".into(),
})? as usize;
let head_dim_explicit = metadata
.get_u32("gemma-embedding.attention.key_length")
.map(|v| v as usize);
let num_hidden_layers =
metadata
.get_u32("gemma-embedding.block_count")
.ok_or_else(|| Error::ModelError {
reason: "GGUF missing gemma-embedding.block_count".into(),
})? as usize;
let max_position_embeddings = metadata
.get_u32("gemma-embedding.context_length")
.unwrap_or(8192) as usize;
let rms_eps = metadata
.get_f32("gemma-embedding.attention.layer_norm_rms_epsilon")
.map(|v| v as f64)
.unwrap_or(1e-6);
let sliding_window = metadata
.get_u32("gemma-embedding.attention.sliding_window")
.map(|v| v as usize);
let rope_freq_base = metadata
.get_f32("gemma-embedding.rope.freq_base")
.unwrap_or(10000.0);
if let Some(pt) = metadata.get_u32("gemma-embedding.pooling_type")
&& pt != 1
{
return Err(Error::ModelError {
reason: format!(
"gemma-embedding.pooling_type = {pt}; only mean pooling (1) is supported"
),
});
}
let vocab_size = metadata
.get_array("tokenizer.ggml.tokens")
.map(|a| a.len())
.unwrap_or(256000);
Ok(Self {
vocab_size,
hidden_size,
num_hidden_layers,
num_attention_heads,
intermediate_size,
max_position_embeddings,
layer_norm_eps: rms_eps,
hidden_act: HiddenAct::Gelu,
type_vocab_size: 0,
arch_family: ArchFamily::GemmaEmbedding,
padding_token_id: 0,
compute_dtype: DType::F32,
rope_freq_base,
causal: false,
ffn_variant: FfnVariant::GatedGelu,
token_type_embed_size: 0,
num_kv_heads,
head_dim_explicit,
rms_eps,
sliding_window,
embed_scale: true,
max_tokens_per_forward: None,
})
}
}