// hot_loop/models/qwen3/qwen3.rs
use std::io::{Read, Seek};
use std::sync::Arc;

use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, Result as CandleResult, Tensor};
use candle_nn::{Embedding, Module};
use candle_transformers::models::with_tracing::QMatMul;
use candle_transformers::quantized_nn::RmsNorm;
use tokenizers::Tokenizer;

use super::super::models_core::mask::mask;
use super::super::models_core::model::ModelWeights;
use super::super::models_core::rotary_embedding::RotaryEmbedding;
use super::transformers::LayerWeights;
use super::ChatFormat;
use crate::session::history::Role;
use crate::utils::gguf::Gguf;
use crate::utils::kv_cache::KvCache;
use crate::Error;

/// Cheaply cloneable handle to a loaded quantized Qwen3 model.
///
/// Wraps the heavyweight inner state (weights, tokenizer, chat format) in an
/// `Arc`, so `Clone` is a refcount bump and the model can be shared across
/// sessions/threads without duplicating tensors.
#[derive(Clone)]
pub struct Qwen3(Arc<Qwen3Inner>);
21
22impl Qwen3 {
23    pub fn load<M, T>(
24        model: &mut M,
25        tokenizer: T,
26        device: Device,
27    ) -> Result<Self, Error>
28    where
29        M: Read + Seek,
30        T: AsRef<[u8]>,
31    {
32        let model = Qwen3Inner::load(model, tokenizer, device)?;
33        Ok(Self(Arc::new(model)))
34    }
35}
36
37impl ModelWeights for Qwen3 {
38    fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut KvCache) -> CandleResult<Tensor> {
39        self.0.forward(input, offset, kv_cache)
40    }
41
42    fn layers_len(&self) -> usize {
43        self.0.layers_len()
44    }
45
46    fn tokenizer(&self) -> Arc<Tokenizer> {
47        self.0.tokenizer()
48    }
49
50    fn device(&self) -> &Device {
51        &self.0.device()
52    }
53
54    fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error> {
55        self.0.fmt_prompt(prompt, role)
56    }
57
58    fn assistant_start_template(&self) -> Vec<u32> {
59        self.0.assistant_start_template()
60    }
61
62    fn eos_token(&self) -> u32 {
63        self.0.eos_token()
64    }
65}
66
/// Owned model state shared behind [`Qwen3`]'s `Arc`.
struct Qwen3Inner {
    // Token-id -> hidden-state embedding table (dequantized at load time).
    embed_tokens: Embedding,
    // Transformer decoder layers, in execution order.
    layers: Vec<LayerWeights>,
    // Final RMS norm applied before the LM head.
    norm: RmsNorm,
    // Projection from hidden states to vocabulary logits.
    lm_head: QMatMul,
    // Device all tensors were loaded onto.
    device: Device,
    // Activation dtype: F16 unless GGUF metadata `general.dtype` says F32.
    dtype: DType,
    // Prompt formatting / special-token rules for this model family.
    chat_format: ChatFormat,
    // Shared tokenizer handle, also held by `chat_format`.
    tokenizer: Arc<Tokenizer>,
}
77
78impl Qwen3Inner {
79    fn load<M, T>(
80        model: &mut M,
81        tokenizer: T,
82        device: Device,
83    ) -> Result<Self, Error>
84    where
85        M: Read + Seek,
86        T: AsRef<[u8]>,
87    {
88        let ct = gguf_file::Content::read(model)?;
89        let tokenizer = Arc::new(Tokenizer::from_bytes(tokenizer)?);
90
91        let mut gg = Gguf::new("qwen3", &ct, model, &device);
92
93        let num_attention_heads = gg.get_with_prefix("attention.head_count")?.to_u32()? as usize;
94        let num_kv_heads = gg.get_with_prefix("attention.head_count_kv")?.to_u32()? as usize;
95        let head_dim = gg.get_with_prefix("attention.key_length")?.to_u32()? as usize;
96        let num_layers = gg.get_with_prefix("block_count")?.to_u32()? as usize;
97        let hidden_size = gg.get_with_prefix("embedding_length")?.to_u32()? as usize;
98        let max_position_embeddings = gg.get_with_prefix("context_length")?.to_u32()? as usize;
99        let rms_norm_eps = gg.get_with_prefix("attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
100        let rope_freq_base = gg.get_with_prefix("rope.freq_base")?.to_f32()? as f64;
101
102        let dtype = match gg.metadata().get("general.dtype") {
103            Some(v) => match v.to_u32() {
104                Ok(0) => DType::F32,
105                Ok(1) => DType::F16,
106                _ => DType::F16,
107            },
108            None => DType::F16,
109        };
110
111        let embed_tensor = gg.tensor("token_embd.weight")?;
112        let embed_tokens = Embedding::new(embed_tensor.dequantize(&device)?, hidden_size);
113
114        let rotary = Arc::new(RotaryEmbedding::new(
115            head_dim,
116            rope_freq_base,
117            max_position_embeddings,
118            dtype,
119            &device,
120        )?);
121
122        let mut layers = Vec::with_capacity(num_layers);
123        for i in 0..num_layers {
124            layers.push(LayerWeights::new(
125                &mut gg,
126                num_attention_heads,
127                num_kv_heads,
128                head_dim,
129                rms_norm_eps,
130                rotary.clone(),
131                i,
132            )?);
133        }
134
135        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
136
137        let lm_head_tensor = match gg.tensor("output.weight") {
138            Ok(tensor) => tensor,
139            Err(_) => gg.tensor("token_embd.weight")?,
140        };
141
142        let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;
143
144        let chat_format = ChatFormat::new(tokenizer.clone())?;
145
146        Ok(Self {
147            embed_tokens,
148            layers,
149            norm,
150            lm_head,
151            device,
152            dtype,
153            chat_format,
154            tokenizer
155        })
156    }
157
158    fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut KvCache) -> CandleResult<Tensor> {
159        let (b, l) = input.dims2()?;
160        let mut h = self.embed_tokens.forward(input)?;
161
162        let causal_mask = if l == 1 {
163            None
164        } else {
165            Some(mask(b, l, offset, None, self.dtype, &self.device)?)
166        };
167
168        for (layer, cache) in self.layers.iter().zip(kv_cache.iter_mut()) {
169            h = layer.forward(&h, causal_mask.as_ref(), offset, cache)?;
170        }
171
172        let h = self.norm.forward(&h)?;
173        let last_hidden = h.narrow(1, l - 1, 1)?;
174        self.lm_head.forward(&last_hidden)?.squeeze(1)
175    }
176
177    fn layers_len(&self) -> usize {
178        self.layers.len()
179    }
180
181    fn tokenizer(&self) -> Arc<Tokenizer> {
182        self.tokenizer.clone()
183    }
184
185    fn device(&self) -> &Device {
186        &self.device
187    }
188
189    fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error> {
190        self.chat_format.fmt_prompt(prompt, role)
191    }
192
193    fn assistant_start_template(&self) -> Vec<u32> {
194        self.chat_format.assistant_start_template()
195    }
196
197    fn eos_token(&self) -> u32 {
198        self.chat_format.eos_token()
199    }
200}