use crate::autograd::Tensor;
use crate::demo::Qwen2Config;
use crate::nn::{GroupedQueryAttention, Linear, Module, RMSNorm, RotaryPositionEmbedding};
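/// Token embedding table: maps token ids to `hidden_size`-dimensional vectors
/// stored as a `[vocab_size, hidden_size]` weight tensor.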
#[derive(Debug)]
pub struct Embedding {
weight: Tensor,
vocab_size: usize,
hidden_size: usize,
}
impl Embedding {
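    /// Builds an embedding table with deterministic pseudo-weights: entry `i`
    /// is `sin(i * 1e-4) * 0.02`, so construction involves no randomness.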
#[must_use]
pub fn new(vocab_size: usize, hidden_size: usize) -> Self {
        let data: Vec<f32> = (0..vocab_size * hidden_size)
            .map(|i| (i as f32 * 0.0001).sin() * 0.02)
            .collect();
Self {
weight: Tensor::from_vec(data, &[vocab_size, hidden_size]),
vocab_size,
hidden_size,
}
}
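    /// Builds an embedding whose weight is a single-element stub tensor;
    /// real weights are expected to be installed later via [`Self::set_weight`].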
#[must_use]
pub fn placeholder(vocab_size: usize, hidden_size: usize) -> Self {
Self {
weight: Tensor::new(&[0.0], &[1]),
vocab_size,
hidden_size,
}
}
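    /// Copies each token's embedding row into `output`, which must hold at least
    /// `input_ids.len() * hidden_size` floats. Token ids outside the vocabulary are
    /// skipped with a warning, leaving the caller's pre-filled values in place
    /// (the [`Self::forward`] path zero-fills the buffer first, hence "zeros emitted").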
pub fn forward_into(&self, input_ids: &[u32], output: &mut [f32]) {
for (s, &token_id) in input_ids.iter().enumerate() {
let token_idx = token_id as usize;
if token_idx >= self.vocab_size {
eprintln!(
"Warning: token_id {token_id} >= vocab_size {} (N-09 OOB escape, zeros emitted)",
self.vocab_size
);
continue;
}
let src_offset = token_idx * self.hidden_size;
let dst_offset = s * self.hidden_size;
output[dst_offset..dst_offset + self.hidden_size]
.copy_from_slice(&self.weight.data()[src_offset..src_offset + self.hidden_size]);
}
}
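    /// Embeds `input_ids` by table lookup and returns a `[1, seq_len, hidden_size]` tensor.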
#[provable_contracts_macros::contract("embedding-algebra-v1", equation = "embedding_lookup")]
#[must_use]
    #[allow(unused_variables)]
    pub fn forward(&self, input_ids: &[u32]) -> Tensor {
contract_pre_embedding_lookup!(input_ids);
contract_pre_inference_determinism!();
let batch_size = 1;
let mut output = vec![0.0f32; batch_size * input_ids.len() * self.hidden_size];
self.forward_into(input_ids, &mut output);
let result = Tensor::new(&output, &[batch_size, input_ids.len(), self.hidden_size]);
contract_post_embedding_lookup!(result.data());
contract_post_inference_determinism!(result.data());
result
}
pub fn set_weight(&mut self, weight: Tensor) {
self.weight = weight;
}
#[must_use]
pub fn weight(&self) -> &Tensor {
contract_pre_q_projection!();
contract_pre_kv_projection!();
&self.weight
}
}
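/// Qwen2 feed-forward block with gate, up, and down projections, combined as SwiGLU.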
#[derive(Debug)]
#[allow(clippy::struct_field_names)]
pub struct Qwen2MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
}
impl Qwen2MLP {
#[must_use]
pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
Self {
gate_proj: Linear::new(hidden_size, intermediate_size),
up_proj: Linear::new(hidden_size, intermediate_size),
down_proj: Linear::new(intermediate_size, hidden_size),
}
}
#[must_use]
pub fn placeholder(hidden_size: usize, intermediate_size: usize) -> Self {
Self {
gate_proj: Linear::placeholder(hidden_size, intermediate_size),
up_proj: Linear::placeholder(hidden_size, intermediate_size),
down_proj: Linear::placeholder(intermediate_size, hidden_size),
}
}
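    /// SwiGLU forward pass: `down_proj(silu(gate_proj(x)) * up_proj(x))`.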
#[must_use]
pub fn forward(&self, x: &Tensor) -> Tensor {
contract_pre_swiglu!(x.data());
let gate = self.gate_proj.forward(x);
let gate_activated = silu(&gate);
let up = self.up_proj.forward(x);
let hidden = elementwise_mul(&gate_activated, &up);
let result = self.down_proj.forward(&hidden);
contract_post_swiglu!(result.data());
result
}
pub fn gate_proj_mut(&mut self) -> &mut Linear {
&mut self.gate_proj
}
pub fn up_proj_mut(&mut self) -> &mut Linear {
&mut self.up_proj
}
pub fn down_proj_mut(&mut self) -> &mut Linear {
&mut self.down_proj
}
}
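/// A single pre-norm Qwen2 decoder block: RMSNorm, grouped-query self-attention,
/// and a residual add, followed by RMSNorm, the SwiGLU MLP, and a second residual add.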
#[derive(Debug)]
pub struct Qwen2DecoderLayer {
self_attn: GroupedQueryAttention,
mlp: Qwen2MLP,
input_layernorm: RMSNorm,
post_attention_layernorm: RMSNorm,
}
impl Qwen2DecoderLayer {
#[must_use]
pub fn new(config: &Qwen2Config) -> Self {
Self {
self_attn: GroupedQueryAttention::new(
config.hidden_size,
config.num_attention_heads,
config.num_kv_heads,
),
mlp: Qwen2MLP::new(config.hidden_size, config.intermediate_size),
input_layernorm: RMSNorm::new(&[config.hidden_size]),
post_attention_layernorm: RMSNorm::new(&[config.hidden_size]),
}
}
#[must_use]
pub fn placeholder(config: &Qwen2Config) -> Self {
Self {
self_attn: GroupedQueryAttention::placeholder(
config.hidden_size,
config.num_attention_heads,
config.num_kv_heads,
),
mlp: Qwen2MLP::placeholder(config.hidden_size, config.intermediate_size),
input_layernorm: RMSNorm::placeholder(&[config.hidden_size]),
post_attention_layernorm: RMSNorm::placeholder(&[config.hidden_size]),
}
}
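    /// Runs the decoder block on `hidden_states`. The `position_ids`, `rope`, and
    /// `attention_mask` arguments are accepted for interface compatibility but are
    /// currently unused by this path, as their underscore prefixes indicate.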
#[must_use]
pub fn forward(
&self,
hidden_states: &Tensor,
_position_ids: &[usize],
_rope: &RotaryPositionEmbedding,
_attention_mask: Option<&Tensor>,
) -> Tensor {
contract_pre_residual!(hidden_states.data());
let residual = hidden_states.clone();
let hidden = self.input_layernorm.forward(hidden_states);
let (attn_output, _attn_weights) = self.self_attn.forward_self(&hidden, None);
let hidden = add_tensors(&residual, &attn_output);
let residual = hidden.clone();
let hidden = self.post_attention_layernorm.forward(&hidden);
let mlp_output = self.mlp.forward(&hidden);
add_tensors(&residual, &mlp_output)
}
pub fn self_attn_mut(&mut self) -> &mut GroupedQueryAttention {
&mut self.self_attn
}
pub fn mlp_mut(&mut self) -> &mut Qwen2MLP {
&mut self.mlp
}
pub fn input_layernorm_mut(&mut self) -> &mut RMSNorm {
&mut self.input_layernorm
}
pub fn post_attention_layernorm_mut(&mut self) -> &mut RMSNorm {
&mut self.post_attention_layernorm
}
}
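/// Per-layer key/value cache: `keys[i]` and `values[i]` hold layer `i`'s cached
/// tensors, and `cached_len` tracks how many positions have been cached so far.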
#[derive(Debug)]
pub struct KVCache {
pub keys: Vec<Option<Tensor>>,
pub values: Vec<Option<Tensor>>,
pub cached_len: usize,
}
impl KVCache {
#[must_use]
pub fn new(num_layers: usize) -> Self {
Self {
keys: vec![None; num_layers],
values: vec![None; num_layers],
cached_len: 0,
}
}
pub fn clear(&mut self) {
for k in &mut self.keys {
*k = None;
}
for v in &mut self.values {
*v = None;
}
self.cached_len = 0;
}
}
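/// The full Qwen2 model: token embedding, a stack of decoder layers, a final
/// RMSNorm, and an LM head, plus configuration, rotary embedding, an optional
/// KV cache, and a training flag.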
#[derive(Debug)]
pub struct Qwen2Model {
embed_tokens: Embedding,
layers: Vec<Qwen2DecoderLayer>,
norm: RMSNorm,
lm_head: Linear,
#[allow(dead_code)]
rope: RotaryPositionEmbedding,
config: Qwen2Config,
kv_cache: Option<KVCache>,
training: bool,
}
#[cfg(test)]
#[path = "tests_embedding_contract.rs"]
mod tests_embedding_contract;
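// The remaining items are included textually from sibling files, which are expected
// to supply the `Qwen2Model` constructors and the `silu`, `elementwise_mul`, and
// `add_tensors` helpers used above.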
include!("constructors.rs");
include!("element-wise.rs");