Skip to main content

hot_loop/models/qwen3/
qwen3.rs

1use candle_transformers::models::with_tracing::QMatMul;
2use candle_transformers::{quantized_nn::RmsNorm};
3use candle_core::quantized::{gguf_file};
4use candle_core::{DType, Device, Result as CandleResult, Tensor};
5use candle_nn::{Embedding, Module};
6use std::io::{Read, Seek};
7use std::sync::Arc;
8use crate::{ModelWeights, KvCache, Error, Role};
9use tokenizers::Tokenizer;
10use super::ChatTemplate;
11
12use super::{
13    transformers::{
14        LayerWeights,
15        Gguf,
16        RotaryEmbedding
17    }
18};
19
20#[derive(Clone)]
21pub struct Qwen3 {
22    embed_tokens: Embedding,
23    layers: Vec<LayerWeights>,
24    norm: RmsNorm,
25    lm_head: QMatMul,
26    device: Device,
27    dtype: DType,
28    chat_template: ChatTemplate
29}
30
31impl Qwen3 {
32    pub fn load<M, T>(
33        model: &mut M,
34        tokenizer: T,
35        device: &Device,
36    ) -> Result<Self, Error>
37    where
38        M: Read + Seek,
39        T: AsRef<[u8]>,
40    {
41        let ct = gguf_file::Content::read(model)?;
42
43        let tokenizer = Tokenizer::from_bytes(tokenizer)?;
44
45        let mut gg = Gguf::new(ct, model, device.clone());
46        let md_get = |s: &str| match gg.metadata().get(s) {
47            None => candle_core::bail!("cannot find {s} in metadata"),
48            Some(v) => Ok(v),
49        };
50
51        let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
52        let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
53        let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
54        let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
55        let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
56        let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
57        let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
58        let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;
59
60        let dtype = match gg.metadata().get("general.dtype") {
61            Some(v) => match v.to_u32() {
62                Ok(0) => DType::F32,
63                Ok(1) => DType::F16,
64                _ => DType::F16,
65            },
66            None => DType::F16,
67        };
68
69        let embed_tensor = gg.tensor("token_embd.weight")?;
70        let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);
71
72        let rotary = Arc::new(RotaryEmbedding::new(
73            dtype,
74            head_dim,
75            max_position_embeddings,
76            rope_freq_base,
77            device,
78        )?);
79
80        let mut layers = Vec::with_capacity(num_layers);
81        for i in 0..num_layers {
82            layers.push(LayerWeights::new(
83                &mut gg,
84                num_attention_heads,
85                num_kv_heads,
86                head_dim,
87                rms_norm_eps,
88                rotary.clone(),
89                i,
90            )?);
91        }
92
93        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
94        // Load output projection tensor, falling back to tied embeddings like gemma3
95        let lm_head_tensor = match gg.tensor("output.weight") {
96            Ok(tensor) => tensor,
97            Err(_) => gg.tensor("token_embd.weight")?,
98        };
99        let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;
100
101        let chat_template = ChatTemplate::new(tokenizer)?;
102
103        Ok(Self {
104            embed_tokens,
105            layers,
106            norm,
107            lm_head,
108            device: device.clone(),
109            dtype,
110            chat_template
111        })
112    }
113
114    fn causal_mask(
115        &self,
116        b: usize,
117        tgt: usize,
118        offset: usize,
119        sw: Option<usize>,
120    ) -> CandleResult<Tensor> {
121        let minf = f32::NEG_INFINITY;
122        let mask: Vec<_> = (0..tgt)
123            .flat_map(|i| {
124                (0..(tgt + offset)).map(move |j| {
125                    let past_ok = j <= i + offset;
126                    let sw_ok = match sw {
127                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
128                        None => true,
129                    };
130                    if past_ok && sw_ok {
131                        0.
132                    } else {
133                        minf
134                    }
135                })
136            })
137            .collect();
138        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
139    }
140}
141
142impl ModelWeights for Qwen3 {
143    fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut Vec<KvCache>) -> CandleResult<Tensor> {
144        let (b, l) = input.dims2()?;
145        let mut h = self.embed_tokens.forward(input)?;
146        let causal_mask = if l == 1 {
147            None
148        } else {
149            Some(self.causal_mask(b, l, offset, None)?)
150        };
151        
152        for (layer, cache) in self.layers.iter().zip(kv_cache.iter_mut()) {
153            h = layer.forward(&h, causal_mask.as_ref(), offset, cache)?;
154        }
155
156        let h = self.norm.forward(&h)?;
157        let last_hidden = h.narrow(1, l - 1, 1)?;
158        self.lm_head.forward(&last_hidden)?.squeeze(1)
159    }
160
161    fn create_kv_cache(&self) -> Vec<KvCache> {
162        let mut kv_cache = Vec::with_capacity(self.layers.len());
163
164        for _ in 0..self.layers.len() {
165            kv_cache.push(KvCache::new(2));
166        }
167
168        kv_cache
169    }
170
171    fn tokenizer(&self) -> &Tokenizer {
172        &self.chat_template.tokenizer()
173    }
174
175    fn current_device(&self) -> &Device {
176        &self.device
177    }
178
179    fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error> {
180        self.chat_template.fmt_prompt(prompt, role)
181    }
182
183    fn assistant_start_template(&self) -> Vec<u32> {
184        self.chat_template.assistant_start_template()
185    }
186    
187    fn eos_token(&self) -> u32 {
188        self.chat_template.eos_token()
189    }
190}