hot_loop/models/qwen3/qwen3.rs

use candle_transformers::models::with_tracing::QMatMul;
use candle_transformers::quantized_nn::RmsNorm;
use candle_core::quantized::gguf_file;
use candle_core::{DType, Device, Result as CandleResult, Tensor};
use candle_nn::{Embedding, Module};
use std::io::{Read, Seek};
use std::sync::Arc;
use crate::{ModelWeights, KvCache, Error, Role};
use tokenizers::Tokenizer;

use super::ChatTemplate;
use super::transformers::{Gguf, LayerWeights, RotaryEmbedding};

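/// A GGUF-quantized Qwen3 model wired for incremental, KV-cached decoding on
/// top of candle. The token embedding is dequantized at load time, and the LM
/// head falls back to the tied token embedding when the checkpoint ships no
/// separate `output.weight` tensor.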
#[derive(Clone)]
pub struct Qwen3 {
    embed_tokens: Embedding,
    layers: Vec<LayerWeights>,
    norm: RmsNorm,
    lm_head: QMatMul,
    device: Device,
    dtype: DType,
    chat_template: ChatTemplate,
}

impl Qwen3 {
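    /// Loads a Qwen3 checkpoint from a GGUF stream plus serialized tokenizer
    /// bytes, reading the hyperparameters from the `qwen3.*` metadata keys.
    ///
    /// A sketch of the intended call pattern; the `hot_loop` module paths and
    /// the `sample` helper are illustrative assumptions, not part of this file:
    ///
    /// ```ignore
    /// use candle_core::{Device, Tensor};
    /// use hot_loop::models::qwen3::Qwen3;
    /// use hot_loop::{ModelWeights, Role};
    ///
    /// let mut file = std::fs::File::open("qwen3.gguf")?;
    /// let tokenizer_bytes = std::fs::read("tokenizer.json")?;
    /// let device = Device::Cpu;
    /// let model = Qwen3::load(&mut file, tokenizer_bytes, &device)?;
    ///
    /// // Feed the whole prompt once, then decode one token at a time.
    /// let mut kv_cache = model.create_kv_cache();
    /// let mut tokens = model.fmt_prompt("Hello!", Role::User)?;
    /// let mut offset = 0;
    /// loop {
    ///     let input = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?;
    ///     let logits = model.forward(&input, offset, &mut kv_cache)?;
    ///     offset += tokens.len();
    ///     let next = sample(&logits)?; // hypothetical sampling helper
    ///     if next == model.eos_token() {
    ///         break;
    ///     }
    ///     tokens = vec![next];
    /// }
    /// ```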
    pub fn load<M, T>(
        model: &mut M,
        tokenizer: T,
        device: &Device,
    ) -> Result<Self, Error>
    where
        M: Read + Seek,
        T: AsRef<[u8]>,
    {
        let ct = gguf_file::Content::read(model)?;
        let tokenizer = Tokenizer::from_bytes(tokenizer)?;

        let mut gg = Gguf::new(ct, model, device.clone());
        // Look up a required GGUF metadata key, erroring if it is missing.
        let md_get = |s: &str| match gg.metadata().get(s) {
            None => candle_core::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        // Model hyperparameters, all stored under the `qwen3.` prefix.
        let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
        let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
        let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
        let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
        let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
        let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
        let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
        let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;

        // Activation dtype: only an explicit F32 marker is honoured; anything
        // else (including a missing key) falls back to F16.
        let dtype = match gg.metadata().get("general.dtype") {
            Some(v) => match v.to_u32() {
                Ok(0) => DType::F32,
                Ok(1) => DType::F16,
                _ => DType::F16,
            },
            None => DType::F16,
        };

        // The token embedding is dequantized up front; it also doubles as the
        // LM head for tied-embedding checkpoints below.
        let embed_tensor = gg.tensor("token_embd.weight")?;
        let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);

        // Rotary position-embedding tables are computed once and shared by
        // every layer.
        let rotary = Arc::new(RotaryEmbedding::new(
            dtype,
            head_dim,
            max_position_embeddings,
            rope_freq_base,
            device,
        )?);

        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(LayerWeights::new(
                &mut gg,
                num_attention_heads,
                num_kv_heads,
                head_dim,
                rms_norm_eps,
                rotary.clone(),
                i,
            )?);
        }

        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
        // Some GGUF exports omit `output.weight` and tie the LM head to the
        // token embedding instead.
        let lm_head_tensor = match gg.tensor("output.weight") {
            Ok(tensor) => tensor,
            Err(_) => gg.tensor("token_embd.weight")?,
        };
        let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;

        let chat_template = ChatTemplate::new(tokenizer)?;

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            device: device.clone(),
            dtype,
            chat_template,
        })
    }

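    /// Builds the additive causal attention mask: `0.0` where a query may
    /// attend, `-inf` where it may not, shaped `(b, 1, tgt, tgt + offset)` so
    /// it broadcasts over heads. `offset` is the number of positions already
    /// in the KV cache and `sw` optionally restricts attention to a sliding
    /// window of past positions.
    ///
    /// For example, with `tgt = 2`, `offset = 3`, and no window, each
    /// `(tgt, tgt + offset)` slice is:
    ///
    /// ```text
    /// query at position 3:  0  0  0  0  -inf
    /// query at position 4:  0  0  0  0   0
    /// ```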
    fn causal_mask(
        &self,
        b: usize,
        tgt: usize,
        offset: usize,
        sw: Option<usize>,
    ) -> CandleResult<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<_> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| {
                    // The query at absolute position `i + offset` may attend to
                    // key `j` only if the key is not in the future...
                    let past_ok = j <= i + offset;
                    // ...and, given a sliding window of width `w`, no more than
                    // `w` positions in the past.
                    let sw_ok = match sw {
                        Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
                        None => true,
                    };
                    if past_ok && sw_ok {
                        0.
                    } else {
                        minf
                    }
                })
            })
            .collect();
        // `mask` holds a single (tgt, tgt + offset) slice, so build it once and
        // expand it across the batch dimension rather than copying it `b` times
        // (the previous shape of (b, 1, ..) only matched the data when b == 1).
        Tensor::from_slice(&mask, (1, 1, tgt, tgt + offset), &self.device)?
            .to_dtype(self.dtype)?
            .expand((b, 1, tgt, tgt + offset))
    }
}

impl ModelWeights for Qwen3 {
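    /// One forward step: `input` holds token ids shaped `(b, l)`, `offset` is
    /// the number of positions already in `kv_cache`, and the result is the
    /// logits for the final position, shaped `(b, vocab_size)`.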
    fn forward(&self, input: &Tensor, offset: usize, kv_cache: &mut Vec<KvCache>) -> CandleResult<Tensor> {
        let (b, l) = input.dims2()?;
        let mut h = self.embed_tokens.forward(input)?;
        // A single query position may attend to everything in the cache, so
        // the mask is only needed when more than one token is processed.
        let causal_mask = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset, None)?)
        };

        for (layer, cache) in self.layers.iter().zip(kv_cache.iter_mut()) {
            h = layer.forward(&h, causal_mask.as_ref(), offset, cache)?;
        }

        // Only the last position is needed to predict the next token.
        let h = self.norm.forward(&h)?;
        let last_hidden = h.narrow(1, l - 1, 1)?;
        self.lm_head.forward(&last_hidden)?.squeeze(1)
    }

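    /// Allocates one empty cache per transformer layer; the layers fill and
    /// grow them as positions are consumed.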
    fn create_kv_cache(&self) -> Vec<KvCache> {
        let mut kv_cache = Vec::with_capacity(self.layers.len());
        for _ in 0..self.layers.len() {
            kv_cache.push(KvCache::new(2));
        }
        kv_cache
    }

    fn tokenizer(&self) -> &Tokenizer {
        self.chat_template.tokenizer()
    }

    fn current_device(&self) -> &Device {
        &self.device
    }

    fn fmt_prompt(&self, prompt: &str, role: Role) -> Result<Vec<u32>, Error> {
        self.chat_template.fmt_prompt(prompt, role)
    }

    fn assistant_start_template(&self) -> Vec<u32> {
        self.chat_template.assistant_start_template()
    }

    fn eos_token(&self) -> u32 {
        self.chat_template.eos_token()
    }
}