#![allow(dead_code)]
use std::path::Path;
use burn::nn::{Embedding, EmbeddingConfig, RmsNorm, RmsNormConfig};
use burn::prelude::*;
use super::attention::{FullAttention, FullAttentionConfig};
use super::deltanet::{GatedDeltaNet, GatedDeltaNetConfig};
use super::ffn::{FeedForward, FeedForwardConfig};
use crate::error::SyaraError;
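/// Configuration for a Qwen3-Next-style hybrid text model: most layers use
/// linear attention (gated DeltaNet), with full attention on every
/// `full_attention_interval`-th layer.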
#[derive(Config, Debug)]
pub struct Qwen3Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub head_dim: usize,
pub linear_num_key_heads: usize,
pub linear_num_value_heads: usize,
pub linear_key_head_dim: usize,
pub linear_value_head_dim: usize,
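    /// Kernel size of the convolution inside each DeltaNet layer.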
#[config(default = 4)]
pub linear_conv_kernel_dim: usize,
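    /// Every `full_attention_interval`-th layer uses full attention instead
    /// of the linear DeltaNet mixer.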
#[config(default = 4)]
pub full_attention_interval: usize,
#[config(default = 4096)]
pub max_position_embeddings: usize,
#[config(default = 10_000_000.0)]
pub rope_theta: f32,
#[config(default = 0.25)]
pub partial_rotary_factor: f64,
#[config(default = 1e-6)]
pub rms_norm_eps: f64,
#[config(default = true)]
pub tie_word_embeddings: bool,
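    /// Token id that terminates generation.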
#[config(default = 248044)]
pub eos_token_id: usize,
}
impl Qwen3Config {
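    /// Instantiates the model with freshly initialized weights on `device`.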
pub fn init<B: Backend>(&self, device: &B::Device) -> Qwen3TextModel<B> {
let embed_tokens = EmbeddingConfig::new(self.vocab_size, self.hidden_size).init(device);
let layers: Vec<TransformerBlock<B>> = (0..self.num_hidden_layers)
.map(|i| self.init_block(i, device))
.collect();
let final_norm = RmsNormConfig::new(self.hidden_size)
.with_epsilon(self.rms_norm_eps)
.init(device);
Qwen3TextModel {
embed_tokens,
layers,
final_norm,
tie_word_embeddings: self.tie_word_embeddings,
hidden_size: self.hidden_size,
vocab_size: self.vocab_size,
}
}
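    /// Builds one decoder block. Layers whose 1-based index is a multiple of
    /// `full_attention_interval` get full attention; all others get the
    /// linear gated-DeltaNet mixer.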
fn init_block<B: Backend>(&self, layer_idx: usize, device: &B::Device) -> TransformerBlock<B> {
let is_full_attn = (layer_idx + 1).is_multiple_of(self.full_attention_interval);
let hybrid = if is_full_attn {
HybridBlock::Full(
FullAttentionConfig {
d_model: self.hidden_size,
n_heads: self.num_attention_heads,
n_kv_heads: self.num_key_value_heads,
head_dim: self.head_dim,
max_seq_len: self.max_position_embeddings,
qk_norm: true,
partial_rotary_factor: self.partial_rotary_factor,
rope_theta: self.rope_theta,
rms_norm_eps: self.rms_norm_eps,
}
.init(device),
)
} else {
HybridBlock::Linear(
GatedDeltaNetConfig {
d_model: self.hidden_size,
num_heads: self.linear_num_key_heads,
key_head_dim: self.linear_key_head_dim,
value_head_dim: self.linear_value_head_dim,
conv_kernel_size: self.linear_conv_kernel_dim,
rms_norm_eps: self.rms_norm_eps,
}
.init(device),
)
};
let norm = |size| {
RmsNormConfig::new(size)
.with_epsilon(self.rms_norm_eps)
.init(device)
};
let ffn = FeedForwardConfig {
d_model: self.hidden_size,
d_intermediate: self.intermediate_size,
}
.init(device);
TransformerBlock {
input_layernorm: norm(self.hidden_size),
hybrid,
post_attention_layernorm: norm(self.hidden_size),
mlp: ffn,
}
}
}
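/// The token mixer for one layer: either linear attention (gated DeltaNet)
/// or standard full attention.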
#[derive(Module, Debug)]
#[allow(clippy::large_enum_variant)]
pub enum HybridBlock<B: Backend> {
Linear(GatedDeltaNet<B>),
Full(FullAttention<B>),
}
impl<B: Backend> HybridBlock<B> {
fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
match self {
HybridBlock::Linear(deltanet) => deltanet.forward(x),
HybridBlock::Full(attn) => attn.forward(x),
}
}
}
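/// One pre-norm decoder block: RMSNorm -> token mixer -> residual add,
/// then RMSNorm -> feed-forward -> residual add.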
#[derive(Module, Debug)]
pub struct TransformerBlock<B: Backend> {
pub(crate) input_layernorm: RmsNorm<B>,
pub(crate) hybrid: HybridBlock<B>,
pub(crate) post_attention_layernorm: RmsNorm<B>,
pub(crate) mlp: FeedForward<B>,
}
impl<B: Backend> TransformerBlock<B> {
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let residual = x.clone();
let h = self.input_layernorm.forward(x);
let h = self.hybrid.forward(h);
let h = residual + h;
let residual = h.clone();
let h = self.post_attention_layernorm.forward(h);
let h = self.mlp.forward(h);
residual + h
}
}
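/// The Qwen3-Next text decoder: token embeddings (reused as the output
/// projection), a stack of hybrid blocks, and a final RMSNorm.
///
/// A minimal usage sketch; the `NdArray` backend and the model path are
/// placeholders, not requirements:
///
/// ```ignore
/// use burn::backend::NdArray;
/// let device = Default::default();
/// let config = load_qwen3_config(Path::new("path/to/model"))?;
/// let model = config.init::<NdArray<f32>>(&device);
/// let input_ids = Tensor::<NdArray<f32>, 2, Int>::zeros([1, 4], &device);
/// let logits = model.forward(input_ids); // [1, 4, vocab_size]
/// ```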
#[derive(Module, Debug)]
pub struct Qwen3TextModel<B: Backend> {
pub(crate) embed_tokens: Embedding<B>,
pub(crate) layers: Vec<TransformerBlock<B>>,
pub(crate) final_norm: RmsNorm<B>,
pub(crate) tie_word_embeddings: bool,
pub(crate) hidden_size: usize,
pub(crate) vocab_size: usize,
}
impl<B: Backend> Qwen3TextModel<B> {
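    /// Maps `[batch, seq]` token ids to `[batch, seq, vocab]` logits.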
pub fn forward(&self, input_ids: Tensor<B, 2, Int>) -> Tensor<B, 3> {
let mut h = self.embed_tokens.forward(input_ids);
for layer in &self.layers {
h = layer.forward(h);
}
let h = self.final_norm.forward(h);
        // Project hidden states to vocabulary logits with the tied embedding
        // matrix: [vocab, hidden] -> [1, hidden, vocab], broadcast over batch.
        let weight = self.embed_tokens.weight.val();
        let weight = weight.unsqueeze_dim::<3>(0);
        let weight = weight.transpose();
        h.matmul(weight)
    }
pub fn num_layers(&self) -> usize {
self.layers.len()
}
}
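/// On-disk layout of a Hugging Face-style `config.json`, where the text
/// model's hyperparameters are nested under `text_config`.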
#[derive(serde::Deserialize)]
struct RawModelConfig {
text_config: RawTextConfig,
}
#[derive(serde::Deserialize)]
struct RawTextConfig {
vocab_size: usize,
hidden_size: usize,
intermediate_size: usize,
num_hidden_layers: usize,
num_attention_heads: usize,
num_key_value_heads: usize,
head_dim: usize,
linear_num_key_heads: usize,
linear_num_value_heads: usize,
linear_key_head_dim: usize,
linear_value_head_dim: usize,
#[serde(default = "default_conv_kernel")]
linear_conv_kernel_dim: usize,
#[serde(default = "default_full_attn_interval")]
full_attention_interval: usize,
#[serde(default = "default_max_pos")]
max_position_embeddings: usize,
#[serde(default = "default_rms_norm_eps")]
rms_norm_eps: f64,
#[serde(default = "default_tie")]
tie_word_embeddings: bool,
#[serde(default = "default_eos")]
eos_token_id: usize,
#[serde(default)]
rope_parameters: Option<RopeParameters>,
}
#[derive(serde::Deserialize)]
struct RopeParameters {
#[serde(default = "default_rope_theta")]
rope_theta: f32,
#[serde(default = "default_partial_rotary")]
partial_rotary_factor: f64,
}
fn default_conv_kernel() -> usize { 4 }
fn default_full_attn_interval() -> usize { 4 }
fn default_max_pos() -> usize { 4096 }
fn default_rms_norm_eps() -> f64 { 1e-6 }
fn default_tie() -> bool { true }
fn default_eos() -> usize { 248044 }
fn default_rope_theta() -> f32 { 10_000_000.0 }
fn default_partial_rotary() -> f64 { 0.25 }
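/// Reads `config.json` from `model_dir` and converts it into a [`Qwen3Config`].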
pub fn load_qwen3_config(model_dir: &Path) -> Result<Qwen3Config, SyaraError> {
let config_path = model_dir.join("config.json");
let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
SyaraError::LlmError(format!("failed to read {}: {e}", config_path.display()))
})?;
let raw: RawModelConfig = serde_json::from_str(&config_str).map_err(|e| {
SyaraError::LlmError(format!("failed to parse config.json: {e}"))
})?;
let tc = raw.text_config;
    // `rope_parameters` may be absent from config.json; fall back to defaults.
    let rope = tc.rope_parameters.unwrap_or_else(|| RopeParameters {
        rope_theta: default_rope_theta(),
        partial_rotary_factor: default_partial_rotary(),
    });
Ok(Qwen3Config {
vocab_size: tc.vocab_size,
hidden_size: tc.hidden_size,
intermediate_size: tc.intermediate_size,
num_hidden_layers: tc.num_hidden_layers,
num_attention_heads: tc.num_attention_heads,
num_key_value_heads: tc.num_key_value_heads,
head_dim: tc.head_dim,
linear_num_key_heads: tc.linear_num_key_heads,
linear_num_value_heads: tc.linear_num_value_heads,
linear_key_head_dim: tc.linear_key_head_dim,
linear_value_head_dim: tc.linear_value_head_dim,
linear_conv_kernel_dim: tc.linear_conv_kernel_dim,
full_attention_interval: tc.full_attention_interval,
max_position_embeddings: tc.max_position_embeddings,
rope_theta: rope.rope_theta,
partial_rotary_factor: rope.partial_rotary_factor,
rms_norm_eps: tc.rms_norm_eps,
tie_word_embeddings: tc.tie_word_embeddings,
eos_token_id: tc.eos_token_id,
})
}
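// Expose the model through the crate-wide `ForwardModel` trait; the call
// below resolves to the inherent `forward`, not to this trait method.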
impl<B: Backend> super::ForwardModel<B> for Qwen3TextModel<B> {
fn forward(&self, input_ids: Tensor<B, 2, Int>) -> Tensor<B, 3> {
self.forward(input_ids)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
type B = NdArray<f32>;
fn tiny_config() -> Qwen3Config {
Qwen3Config {
vocab_size: 256,
hidden_size: 64,
intermediate_size: 128,
num_hidden_layers: 2,
num_attention_heads: 4,
num_key_value_heads: 2,
head_dim: 16,
linear_num_key_heads: 4,
linear_num_value_heads: 4,
linear_key_head_dim: 16,
linear_value_head_dim: 16,
linear_conv_kernel_dim: 4,
full_attention_interval: 2,
max_position_embeddings: 128,
rope_theta: 10_000.0,
partial_rotary_factor: 0.25,
rms_norm_eps: 1e-6,
tie_word_embeddings: true,
eos_token_id: 0,
}
}
#[test]
fn forward_produces_logits() {
let device = Default::default();
let model = tiny_config().init::<B>(&device);
let input_ids = Tensor::<B, 2, Int>::zeros([1, 4], &device);
let logits = model.forward(input_ids);
assert_eq!(logits.dims(), [1, 4, 256]);
}
#[test]
fn hybrid_dispatch_correct() {
let cfg = tiny_config();
let device = Default::default();
let model = cfg.init::<B>(&device);
assert_eq!(model.num_layers(), 2);
assert!(
matches!(model.layers[0].hybrid, HybridBlock::Linear(_)),
"layer 0 should be DeltaNet"
);
assert!(
matches!(model.layers[1].hybrid, HybridBlock::Full(_)),
"layer 1 should be FullAttention"
);
}
#[test]
fn forward_single_token() {
let device = Default::default();
let model = tiny_config().init::<B>(&device);
let input_ids = Tensor::<B, 2, Int>::zeros([1, 1], &device);
let logits = model.forward(input_ids);
assert_eq!(logits.dims(), [1, 1, 256]);
}
#[test]
fn forward_batch() {
let device = Default::default();
let model = tiny_config().init::<B>(&device);
let input_ids = Tensor::<B, 2, Int>::zeros([3, 8], &device);
let logits = model.forward(input_ids);
assert_eq!(logits.dims(), [3, 8, 256]);
}
#[test]
fn tied_weights_produce_vocab_sized_logits() {
let device = Default::default();
let model = tiny_config().init::<B>(&device);
let input_ids = Tensor::<B, 2, Int>::from_data([[1, 2, 3]], &device);
let logits = model.forward(input_ids);
assert_eq!(logits.dims()[2], 256);
}
#[test]
fn load_config_from_fixture() {
let fixture_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("tests/fixtures/tiny-qwen");
let config = load_qwen3_config(&fixture_dir).unwrap();
assert_eq!(config.vocab_size, 256);
assert_eq!(config.hidden_size, 64);
assert_eq!(config.num_hidden_layers, 2);
assert_eq!(config.head_dim, 16);
assert_eq!(config.full_attention_interval, 2);
assert_eq!(config.eos_token_id, 0);
}
}