1use candle_transformers::models::with_tracing::QMatMul;
2use candle_transformers::{quantized_nn::RmsNorm};
3use candle_core::quantized::{gguf_file};
4use candle_core::{DType, Device, Result as CandleResult, Tensor};
5use candle_nn::{Embedding, Module};
6use std::io::{Read, Seek};
7use std::sync::Arc;
8use crate::Error;
9use crate::utils::kv_cache::KvCache;
10use tokenizers::Tokenizer;
11use super::ChatFormat;
12use crate::session::history::Role;
13use super::super::models_core::model::ModelWeights;
14use super::super::models_core::rotary_embedding::RotaryEmbedding;
15use super::super::models_core::mask::mask;
16use crate::utils::gguf::Gguf;
17use super::transformers::LayerWeights;
18
/// Cheaply-cloneable handle to a loaded Qwen3 model.
///
/// Wraps [`Qwen3Inner`] in an [`Arc`], so `Clone` only bumps a refcount and
/// all clones share the same weights and tokenizer.
#[derive(Clone)]
pub struct Qwen3(Arc<Qwen3Inner>);
21
22impl Qwen3 {
23 pub fn load<M, T>(
24 model: &mut M,
25 tokenizer: T,
26 device: Device,
27 ) -> Result<Self, Error>
28 where
29 M: Read + Seek,
30 T: AsRef<[u8]>,
31 {
32 let model = Qwen3Inner::load(model, tokenizer, device)?;
33 Ok(Self(Arc::new(model)))
34 }
35}
36
37impl ModelWeights for Qwen3 {
38 fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut KvCache) -> CandleResult<Tensor> {
39 self.0.forward(input, offset, kv_cache)
40 }
41
42 fn layers_len(&self) -> usize {
43 self.0.layers_len()
44 }
45
46 fn tokenizer(&self) -> Arc<Tokenizer> {
47 self.0.tokenizer()
48 }
49
50 fn device(&self) -> &Device {
51 &self.0.device()
52 }
53
54 fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error> {
55 self.0.fmt_prompt(prompt, role)
56 }
57
58 fn assistant_start_template(&self) -> Vec<u32> {
59 self.0.assistant_start_template()
60 }
61
62 fn eos_token(&self) -> u32 {
63 self.0.eos_token()
64 }
65}
66
/// Owned model state: weights, device placement, tokenizer, and chat format.
///
/// Built once by `Qwen3Inner::load` and shared behind an `Arc` by the public
/// [`Qwen3`] wrapper.
struct Qwen3Inner {
    // Token-id -> vector lookup, dequantized from "token_embd.weight".
    embed_tokens: Embedding,
    // One transformer block per GGUF "block_count" entry.
    layers: Vec<LayerWeights>,
    // Final RMS norm applied before the LM head.
    norm: RmsNorm,
    // Vocabulary projection; tied to the embedding weights when the GGUF
    // file has no separate "output.weight" tensor.
    lm_head: QMatMul,
    // Device all tensors were loaded onto.
    device: Device,
    // Compute dtype derived from "general.dtype" metadata (defaults to F16).
    dtype: DType,
    // Prompt/EOS formatting rules for the chat template.
    chat_format: ChatFormat,
    tokenizer: Arc<Tokenizer>,
}
77
78impl Qwen3Inner {
79 fn load<M, T>(
80 model: &mut M,
81 tokenizer: T,
82 device: Device,
83 ) -> Result<Self, Error>
84 where
85 M: Read + Seek,
86 T: AsRef<[u8]>,
87 {
88 let ct = gguf_file::Content::read(model)?;
89 let tokenizer = Arc::new(Tokenizer::from_bytes(tokenizer)?);
90
91 let mut gg = Gguf::new("qwen3", &ct, model, &device);
92
93 let num_attention_heads = gg.get_with_prefix("attention.head_count")?.to_u32()? as usize;
94 let num_kv_heads = gg.get_with_prefix("attention.head_count_kv")?.to_u32()? as usize;
95 let head_dim = gg.get_with_prefix("attention.key_length")?.to_u32()? as usize;
96 let num_layers = gg.get_with_prefix("block_count")?.to_u32()? as usize;
97 let hidden_size = gg.get_with_prefix("embedding_length")?.to_u32()? as usize;
98 let max_position_embeddings = gg.get_with_prefix("context_length")?.to_u32()? as usize;
99 let rms_norm_eps = gg.get_with_prefix("attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
100 let rope_freq_base = gg.get_with_prefix("rope.freq_base")?.to_f32()? as f64;
101
102 let dtype = match gg.metadata().get("general.dtype") {
103 Some(v) => match v.to_u32() {
104 Ok(0) => DType::F32,
105 Ok(1) => DType::F16,
106 _ => DType::F16,
107 },
108 None => DType::F16,
109 };
110
111 let embed_tensor = gg.tensor("token_embd.weight")?;
112 let embed_tokens = Embedding::new(embed_tensor.dequantize(&device)?, hidden_size);
113
114 let rotary = Arc::new(RotaryEmbedding::new(
115 head_dim,
116 rope_freq_base,
117 max_position_embeddings,
118 dtype,
119 &device,
120 )?);
121
122 let mut layers = Vec::with_capacity(num_layers);
123 for i in 0..num_layers {
124 layers.push(LayerWeights::new(
125 &mut gg,
126 num_attention_heads,
127 num_kv_heads,
128 head_dim,
129 rms_norm_eps,
130 rotary.clone(),
131 i,
132 )?);
133 }
134
135 let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
136
137 let lm_head_tensor = match gg.tensor("output.weight") {
138 Ok(tensor) => tensor,
139 Err(_) => gg.tensor("token_embd.weight")?,
140 };
141
142 let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;
143
144 let chat_format = ChatFormat::new(tokenizer.clone())?;
145
146 Ok(Self {
147 embed_tokens,
148 layers,
149 norm,
150 lm_head,
151 device,
152 dtype,
153 chat_format,
154 tokenizer
155 })
156 }
157
158 fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut KvCache) -> CandleResult<Tensor> {
159 let (b, l) = input.dims2()?;
160 let mut h = self.embed_tokens.forward(input)?;
161
162 let causal_mask = if l == 1 {
163 None
164 } else {
165 Some(mask(b, l, offset, None, self.dtype, &self.device)?)
166 };
167
168 for (layer, cache) in self.layers.iter().zip(kv_cache.iter_mut()) {
169 h = layer.forward(&h, causal_mask.as_ref(), offset, cache)?;
170 }
171
172 let h = self.norm.forward(&h)?;
173 let last_hidden = h.narrow(1, l - 1, 1)?;
174 self.lm_head.forward(&last_hidden)?.squeeze(1)
175 }
176
177 fn layers_len(&self) -> usize {
178 self.layers.len()
179 }
180
181 fn tokenizer(&self) -> Arc<Tokenizer> {
182 self.tokenizer.clone()
183 }
184
185 fn device(&self) -> &Device {
186 &self.device
187 }
188
189 fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error> {
190 self.chat_format.fmt_prompt(prompt, role)
191 }
192
193 fn assistant_start_template(&self) -> Vec<u32> {
194 self.chat_format.assistant_start_template()
195 }
196
197 fn eos_token(&self) -> u32 {
198 self.chat_format.eos_token()
199 }
200}