impl AprV2Model {
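    /// Resolve the RoPE theta and rope-type for this model, falling back to
    /// architecture-specific defaults from the `gguf` module when the metadata
    /// omits them.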
fn resolve_rope_params(&self) -> (f32, u32) {
let arch = self.metadata.architecture.as_deref().unwrap_or("unknown");
        let rope_theta = self.metadata.rope_theta.unwrap_or_else(|| {
            crate::gguf::default_rope_theta_for_architecture(arch)
        });
        let rope_type = self.metadata.rope_type.unwrap_or_else(|| {
            crate::gguf::infer_rope_type(arch)
        });
(rope_theta, rope_type)
}
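    /// Look up the embedding row for each token id and concatenate them into a
    /// flat `[seq_len * hidden_dim]` buffer. Out-of-range token ids are mapped
    /// to an all-zero embedding rather than indexing out of bounds.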
fn embed_tokens(
&self,
token_ids: &[u32],
hidden_dim: usize,
) -> Result<Vec<f32>> {
        let embed_name = self.find_tensor_name(&[
            "model.embed_tokens.weight",
            "embed_tokens.weight",
            "transformer.wte.weight",
            "embeddings.word_embeddings.weight",
            "tok_embeddings.weight",
            "token_embd.weight",
        ])?;
let embeddings = self.get_tensor_f32(&embed_name)?;
let mut hidden = Vec::with_capacity(token_ids.len() * hidden_dim);
for &token_id in token_ids {
let offset = (token_id as usize) * hidden_dim;
if offset + hidden_dim <= embeddings.len() {
hidden.extend_from_slice(&embeddings[offset..offset + hidden_dim]);
} else {
hidden.extend(std::iter::repeat_n(0.0, hidden_dim));
}
}
Ok(hidden)
}
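    /// Multiply the last position's hidden state by the LM head to produce
    /// vocabulary logits. The tied-embedding branch indexes the weights as
    /// `[hidden_dim, vocab_size]`; a dedicated LM head is indexed as row-major
    /// `[vocab_size, hidden_dim]`.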
fn project_lm_head(
last_hidden: &[f32],
lm_head: &[f32],
hidden_dim: usize,
vocab_size: usize,
is_tied: bool,
) -> Vec<f32> {
let mut logits = vec![0.0; vocab_size];
if is_tied && lm_head.len() == hidden_dim * vocab_size {
for (i, logit) in logits.iter_mut().enumerate() {
for (j, &h) in last_hidden.iter().enumerate() {
*logit += h * lm_head.get(j * vocab_size + i).copied().unwrap_or(0.0);
}
}
} else {
for (i, logit) in logits.iter_mut().enumerate() {
for (j, &h) in last_hidden.iter().enumerate() {
*logit += h * lm_head.get(i * hidden_dim + j).copied().unwrap_or(0.0);
}
}
}
logits
}
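    /// Apply rotary position embeddings to the Q and K activations in place,
    /// one position at a time. Q carries `num_heads * head_dim` (== `hidden_dim`)
    /// values per token; K carries `num_kv_heads * head_dim`, which is smaller
    /// under grouped-query attention.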
fn apply_rope_to_qk(
q: &mut [f32],
k: &mut [f32],
seq_len: usize,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
hidden_dim: usize,
rope_theta: f32,
rope_type: u32,
) {
let q_dim_per_token = hidden_dim;
let k_dim_per_token = num_kv_heads * head_dim;
for pos in 0..seq_len {
let q_start = pos * q_dim_per_token;
let q_end = q_start + q_dim_per_token;
if q_end <= q.len() {
apply_rope_norm(&mut q[q_start..q_end], num_heads, head_dim, pos, rope_theta, rope_type);
}
let k_start = pos * k_dim_per_token;
let k_end = k_start + k_dim_per_token;
if k_end <= k.len() {
apply_rope_norm(&mut k[k_start..k_end], num_kv_heads, head_dim, pos, rope_theta, rope_type);
}
}
}
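    /// Pull the core model dimensions out of the metadata, with conservative
    /// defaults (e.g. `intermediate_size = 4 * hidden_size`,
    /// `rms_norm_eps = 1e-6`) for fields the file does not provide.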
fn resolve_model_dims(&self) -> (usize, usize, usize, usize, usize, usize, f32) {
let hidden_dim = self.metadata.hidden_size.unwrap_or(0);
let num_layers = self.metadata.num_layers.unwrap_or(0);
let num_heads = self.metadata.num_heads.unwrap_or(1);
let num_kv_heads = self.metadata.num_kv_heads.unwrap_or(num_heads);
let vocab_size = self.metadata.vocab_size.unwrap_or(0);
let intermediate_dim = self.metadata.intermediate_size.unwrap_or(hidden_dim * 4);
let eps = self.metadata.rms_norm_eps.unwrap_or(1e-6);
(hidden_dim, num_layers, num_heads, num_kv_heads, vocab_size, intermediate_dim, eps)
}
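    /// Run a full forward pass over `token_ids` and return the logits for the
    /// last position. This path recomputes attention over the whole sequence on
    /// every call; there is no KV cache.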
pub fn forward(&self, token_ids: &[u32]) -> Result<Vec<f32>> {
if token_ids.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Token sequence cannot be empty".to_string(),
});
}
if !self.metadata.is_transformer() {
return Err(RealizarError::FormatError {
reason: "Model is not a transformer (missing config)".to_string(),
});
}
let (hidden_dim, num_layers, num_heads, num_kv_heads, vocab_size, intermediate_dim, eps) =
self.resolve_model_dims();
let mut hidden = self.embed_tokens(token_ids, hidden_dim)?;
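        // Each decoder layer applies pre-norm self-attention and a pre-norm
        // SwiGLU feed-forward block, each followed by a residual add.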
for layer_idx in 0..num_layers {
            let attn_norm_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.input_layernorm.weight"),
                &format!("layers.{layer_idx}.input_layernorm.weight"),
                &format!("transformer.h.{layer_idx}.ln_1.weight"),
                &format!("layers.{layer_idx}.attention_norm.weight"),
                &format!("blk.{layer_idx}.attn_norm.weight"),
            ])?;
            let q_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.self_attn.q_proj.weight"),
                &format!("layers.{layer_idx}.self_attn.q_proj.weight"),
                &format!("transformer.h.{layer_idx}.attn.q_proj.weight"),
                &format!("layers.{layer_idx}.attention.wq.weight"),
                &format!("blk.{layer_idx}.attn_q.weight"),
            ])?;
            let k_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.self_attn.k_proj.weight"),
                &format!("layers.{layer_idx}.self_attn.k_proj.weight"),
                &format!("transformer.h.{layer_idx}.attn.k_proj.weight"),
                &format!("layers.{layer_idx}.attention.wk.weight"),
                &format!("blk.{layer_idx}.attn_k.weight"),
            ])?;
            let v_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.self_attn.v_proj.weight"),
                &format!("layers.{layer_idx}.self_attn.v_proj.weight"),
                &format!("transformer.h.{layer_idx}.attn.v_proj.weight"),
                &format!("layers.{layer_idx}.attention.wv.weight"),
                &format!("blk.{layer_idx}.attn_v.weight"),
            ])?;
            let o_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.self_attn.o_proj.weight"),
                &format!("layers.{layer_idx}.self_attn.o_proj.weight"),
                &format!("transformer.h.{layer_idx}.attn.out_proj.weight"),
                &format!("layers.{layer_idx}.attention.wo.weight"),
                &format!("blk.{layer_idx}.attn_output.weight"),
            ])?;
let norm_weight = self.get_tensor_f32(&attn_norm_name)?;
let q_weight = self.get_tensor_f32(&q_name)?;
let k_weight = self.get_tensor_f32(&k_name)?;
let v_weight = self.get_tensor_f32(&v_name)?;
let o_weight = self.get_tensor_f32(&o_name)?;
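            // Attention block: normalize the residual stream, then project to
            // Q/K/V (K/V are narrower than Q when num_kv_heads < num_heads).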
let normed = rms_norm(&hidden, &norm_weight, eps);
let seq_len = token_ids.len();
let head_dim = hidden_dim / num_heads;
let mut q = matmul(&normed, &q_weight, seq_len, hidden_dim, hidden_dim);
let mut k = matmul(
&normed,
&k_weight,
seq_len,
hidden_dim,
num_kv_heads * head_dim,
);
let v = matmul(
&normed,
&v_weight,
seq_len,
hidden_dim,
num_kv_heads * head_dim,
);
let (rope_theta, rope_type) = self.resolve_rope_params();
Self::apply_rope_to_qk(
&mut q, &mut k, seq_len, num_heads, num_kv_heads, head_dim,
hidden_dim, rope_theta, rope_type,
);
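            // Self-attention over the full sequence (no KV cache), then the
            // output projection and residual add back into `hidden`.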
let attn_out = simple_attention(&q, &k, &v, seq_len, num_heads, num_kv_heads, head_dim);
let attn_proj = matmul(&attn_out, &o_weight, seq_len, hidden_dim, hidden_dim);
for (h, &a) in hidden.iter_mut().zip(attn_proj.iter()) {
*h += a;
}
            let ffn_norm_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.post_attention_layernorm.weight"),
                &format!("layers.{layer_idx}.post_attention_layernorm.weight"),
                &format!("transformer.h.{layer_idx}.ln_2.weight"),
                &format!("layers.{layer_idx}.ffn_norm.weight"),
                &format!("blk.{layer_idx}.ffn_norm.weight"),
            ])?;
            let gate_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.mlp.gate_proj.weight"),
                &format!("layers.{layer_idx}.mlp.gate_proj.weight"),
                &format!("transformer.h.{layer_idx}.mlp.gate_proj.weight"),
                &format!("layers.{layer_idx}.feed_forward.w1.weight"),
                &format!("blk.{layer_idx}.ffn_gate.weight"),
            ])?;
            let up_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.mlp.up_proj.weight"),
                &format!("layers.{layer_idx}.mlp.up_proj.weight"),
                &format!("transformer.h.{layer_idx}.mlp.up_proj.weight"),
                &format!("layers.{layer_idx}.feed_forward.w3.weight"),
                &format!("blk.{layer_idx}.ffn_up.weight"),
            ])?;
            let down_name = self.find_tensor_name(&[
                &format!("model.layers.{layer_idx}.mlp.down_proj.weight"),
                &format!("layers.{layer_idx}.mlp.down_proj.weight"),
                &format!("transformer.h.{layer_idx}.mlp.down_proj.weight"),
                &format!("layers.{layer_idx}.feed_forward.w2.weight"),
                &format!("blk.{layer_idx}.ffn_down.weight"),
            ])?;
let ffn_norm = self.get_tensor_f32(&ffn_norm_name)?;
let gate = self.get_tensor_f32(&gate_name)?;
let up = self.get_tensor_f32(&up_name)?;
let down = self.get_tensor_f32(&down_name)?;
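            // Feed-forward block: RMSNorm, then SwiGLU (silu(gate) * up) and the
            // down projection, added back as a residual.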
let normed = rms_norm(&hidden, &ffn_norm, eps);
let gate_out = matmul(&normed, &gate, seq_len, hidden_dim, intermediate_dim);
let up_out = matmul(&normed, &up, seq_len, hidden_dim, intermediate_dim);
let mut ffn_hidden = Vec::with_capacity(seq_len * intermediate_dim);
for (g, u) in gate_out.iter().zip(up_out.iter()) {
let silu = g * (1.0 / (1.0 + (-g).exp()));
ffn_hidden.push(silu * u);
}
let ffn_out = matmul(&ffn_hidden, &down, seq_len, intermediate_dim, hidden_dim);
for (h, &f) in hidden.iter_mut().zip(ffn_out.iter()) {
*h += f;
}
}
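        // Final RMSNorm, then project only the last position through the
        // (possibly tied) LM head to obtain vocabulary logits.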
        let final_norm_name = self.find_tensor_name(&[
            "model.norm.weight",
            "norm.weight",
            "transformer.ln_f.weight",
            "output_norm.weight",
        ])?;
let final_norm = self.get_tensor_f32(&final_norm_name)?;
let hidden = rms_norm(&hidden, &final_norm, eps);
        let lm_head_name = self.find_tensor_name(&[
            "lm_head.weight",
            "output.weight",
            "model.embed_tokens.weight",
            "embed_tokens.weight",
            "token_embd.weight",
        ])?;
let lm_head = self.get_tensor_f32(&lm_head_name)?;
let last_hidden = &hidden[hidden.len() - hidden_dim..];
let is_tied = lm_head_name == "token_embd.weight"
|| lm_head_name.ends_with("embed_tokens.weight");
let logits = Self::project_lm_head(last_hidden, &lm_head, hidden_dim, vocab_size, is_tied);
Ok(logits)
}
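    /// Greedy decoding: repeatedly run `forward` on the growing sequence and
    /// append the argmax token, stopping at `max_new_tokens`, the EOS token, or
    /// an out-of-vocabulary prediction.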
pub fn generate(
&self,
input_tokens: &[u32],
max_new_tokens: usize,
eos_token_id: Option<u32>,
) -> Result<Vec<u32>> {
if input_tokens.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Input tokens cannot be empty".to_string(),
});
}
let mut tokens = input_tokens.to_vec();
let vocab_size = self.metadata.vocab_size.unwrap_or(0);
for _ in 0..max_new_tokens {
let logits = self.forward(&tokens)?;
let next_token = logits
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map_or(0, |(idx, _)| idx as u32);
if let Some(eos) = eos_token_id {
if next_token == eos {
break;
}
}
            // Stop rather than emit an out-of-vocabulary id.
            if vocab_size > 0 && (next_token as usize) >= vocab_size {
break;
}
tokens.push(next_token);
}
Ok(tokens)
}
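    /// Return the first candidate name that resolves to a tensor in this model;
    /// callers pass synonyms for the same weight across naming conventions
    /// (e.g. `model.layers.*`, `transformer.h.*`, `blk.*`).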
pub fn find_tensor_name(&self, candidates: &[&str]) -> Result<String> {
for &name in candidates {
if self.get_tensor(name).is_some() {
return Ok(name.to_string());
}
}
Err(RealizarError::FormatError {
            reason: format!("No matching tensor found. Tried: {candidates:?}"),
})
}
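    /// Try to load a vocabulary from a `tokenizer.json` next to the model file.
    /// Returns the id-ordered vocab together with any BOS/EOS ids discovered in
    /// `added_tokens`, or `None` if the file is missing or unparseable.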
pub fn load_tokenizer_from_sibling(
model_path: &Path,
) -> Option<(Vec<String>, Option<u32>, Option<u32>)> {
let tokenizer_path = find_sibling_file(model_path, "tokenizer.json")?;
let content = fs::read_to_string(&tokenizer_path).ok()?;
let json: serde_json::Value = serde_json::from_str(&content).ok()?;
let vocab_obj = json.get("model")?.get("vocab")?;
let vocab_map = vocab_obj.as_object()?;
let mut vocab_vec: Vec<(String, u32)> = vocab_map
.iter()
.filter_map(|(token, id)| Some((token.clone(), id.as_u64()? as u32)))
.collect();
vocab_vec.sort_by_key(|(_, id)| *id);
let mut bos_id = None;
let mut eos_id = None;
if let Some(added_tokens) = json.get("added_tokens").and_then(|v| v.as_array()) {
for token in added_tokens {
let content = token.get("content").and_then(|v| v.as_str());
let id = token
.get("id")
.and_then(serde_json::Value::as_u64)
.map(|v| v as u32);
if let (Some(content), Some(id)) = (content, id) {
if content == "<|endoftext|>" || content == "</s>" || content == "<eos>" {
eos_id = Some(id);
}
if content == "<s>" || content == "<bos>" {
bos_id = Some(id);
}
vocab_vec.push((content.to_string(), id));
}
}
}
vocab_vec.sort_by_key(|(_, id)| *id);
let vocab: Vec<String> = vocab_vec.into_iter().map(|(token, _)| token).collect();
Some((vocab, bos_id, eos_id))
}
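    /// Decode token ids back into text, translating byte-level BPE whitespace
    /// markers (`Ġ` -> space, `Ċ` -> newline, `ĉ` -> tab) and rendering unknown
    /// ids as `[id]`.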
pub fn decode_tokens(vocab: &[String], token_ids: &[u32]) -> String {
let mut result = String::new();
for &id in token_ids {
if let Some(token) = vocab.get(id as usize) {
let decoded = token
.replace("Ġ", " ")
.replace("Ċ", "\n")
.replace("ĉ", "\t");
result.push_str(&decoded);
} else {
                result.push_str(&format!("[{id}]"));
}
}
result
}
}