1use std::path::Path;
8use std::io::{Read, Seek};
9use std::sync::Arc;
10
11use candle_core::{DType, Device as CandleDevice, Tensor};
12use candle_core::quantized::gguf_file;
13
14type Result<T> = candle_core::Result<T>;
15use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
16use candle_transformers::models::quantized_qwen3::{Gguf, RotaryEmbedding};
17
18use candle_transformers::quantized_nn::RmsNorm;
19use candle_transformers::models::with_tracing::QMatMul;
20use candle_transformers::utils::repeat_kv;
21use crate::{Device, InferenceError};
22use crate::backend::candle::to_candle_device_pub;
23
/// SwiGLU feed-forward weights for one transformer block, loaded from GGUF.
#[derive(Debug, Clone)]
struct MlpWeights {
    /// Gate branch projection (`ffn_gate`).
    gate_proj: QMatMul,
    /// Up branch projection (`ffn_up`).
    up_proj: QMatMul,
    /// Output projection back to the hidden size (`ffn_down`).
    down_proj: QMatMul,
    /// Activation applied to the gate branch (SiLU — see `MlpWeights::new`).
    act_fn: Activation,
}
39
40impl MlpWeights {
41 fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
42 Ok(Self {
43 gate_proj: gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?,
44 up_proj: gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?,
45 down_proj: gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?,
46 act_fn: Activation::Silu,
47 })
48 }
49}
50
51impl Module for MlpWeights {
52 fn forward(&self, x: &Tensor) -> Result<Tensor> {
53 let gate = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
54 let up = self.up_proj.forward(x)?;
55 self.down_proj.forward(&(gate * up)?)
56 }
57}
58
/// Self-attention weights and state (KV cache) for one transformer block.
#[derive(Debug, Clone)]
struct AttentionWeights {
    q_proj: QMatMul,
    k_proj: QMatMul,
    v_proj: QMatMul,
    o_proj: QMatMul,
    /// Per-head RMSNorm applied to queries before RoPE (Qwen3-specific).
    q_norm: RmsNorm,
    /// Per-head RMSNorm applied to keys before RoPE (Qwen3-specific).
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    /// Grouped-query attention: query heads served by each KV head
    /// (`num_heads / num_kv_heads`).
    num_kv_groups: usize,
    head_dim: usize,
    /// Rotary tables shared across all layers.
    rotary_emb: Arc<RotaryEmbedding>,
    /// Append-only KV cache; concatenates along the sequence axis.
    kv_cache: ConcatKvCache,
}
74
impl AttentionWeights {
    /// Loads the attention weights of one block from the GGUF reader; `prefix`
    /// names the tensors (e.g. "blk.0").
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary_emb: Arc<RotaryEmbedding>,
        prefix: &str,
    ) -> Result<Self> {
        Ok(Self {
            q_proj: gg.qmatmul(&format!("{prefix}.attn_q.weight"))?,
            k_proj: gg.qmatmul(&format!("{prefix}.attn_k.weight"))?,
            v_proj: gg.qmatmul(&format!("{prefix}.attn_v.weight"))?,
            o_proj: gg.qmatmul(&format!("{prefix}.attn_output.weight"))?,
            q_norm: gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?,
            k_norm: gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?,
            num_heads,
            num_kv_heads,
            num_kv_groups: num_heads / num_kv_heads,
            head_dim,
            rotary_emb,
            // dim=2 is the sequence axis of the (b, kv_heads, seq, head_dim)
            // tensors appended below, so the cache grows along seq.
            kv_cache: ConcatKvCache::new(2),
        })
    }

    /// Scaled-dot-product attention with per-head Q/K norm, RoPE, and a KV cache.
    ///
    /// `x` is (batch, seq_len, hidden). `offset` is the number of positions
    /// already cached, so RoPE is applied at the correct absolute positions.
    /// `attn_mask`, when present, is broadcast-added to the raw scores
    /// (additive mask: 0 = attend, -inf = blocked).
    fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let (b, l, _) = x.dims3()?;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        // (b, l, heads * head_dim) -> (b, heads, l, head_dim)
        let q = q.reshape((b, l, self.num_heads, self.head_dim))?.transpose(1, 2)?;
        let k = k.reshape((b, l, self.num_kv_heads, self.head_dim))?.transpose(1, 2)?;
        let v = v.reshape((b, l, self.num_kv_heads, self.head_dim))?.transpose(1, 2)?;

        // Per-head RMSNorm: flatten to (b*heads*l, head_dim) so the norm runs
        // over head_dim only, then restore the 4-D layout.
        let q = self.q_norm.forward(&q.flatten(0, 2)?)?.reshape((b, self.num_heads, l, self.head_dim))?;
        let k = self.k_norm.forward(&k.flatten(0, 2)?)?.reshape((b, self.num_kv_heads, l, self.head_dim))?;

        // RoPE must be applied BEFORE the cache append: cached keys already
        // carry their positional rotation.
        let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
        let (k, v) = self.kv_cache.append(&k, &v)?;

        // Expand KV heads to match the query head count (grouped-query attention).
        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        if let Some(m) = attn_mask {
            // The mask may have been built in a different dtype than the scores.
            let mask = if m.dtype() != scores.dtype() { m.to_dtype(scores.dtype())? } else { m.clone() };
            scores = scores.broadcast_add(&mask)?;
        }
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let ctx = probs.matmul(&v)?;
        // (b, heads, l, head_dim) -> (b, l, heads * head_dim)
        let out = ctx.transpose(1, 2)?.reshape((b, l, self.num_heads * self.head_dim))?;
        self.o_proj.forward(&out)
    }

    /// Drops all cached keys/values; call before starting a new sequence.
    fn clear_kv_cache(&mut self) {
        self.kv_cache.reset();
    }
}
137
/// One pre-norm transformer block: attention + MLP with residual connections.
#[derive(Debug, Clone)]
struct LayerWeights {
    self_attn: AttentionWeights,
    mlp: MlpWeights,
    /// RMSNorm applied before attention (`attn_norm`).
    ln1: RmsNorm,
    /// RMSNorm applied before the MLP (`ffn_norm`).
    ln2: RmsNorm,
}
145
146impl LayerWeights {
147 fn new<R: Read + Seek>(
148 gg: &mut Gguf<R>,
149 num_heads: usize,
150 num_kv_heads: usize,
151 head_dim: usize,
152 rms_norm_eps: f64,
153 rotary: Arc<RotaryEmbedding>,
154 layer_idx: usize,
155 ) -> Result<Self> {
156 let prefix = format!("blk.{layer_idx}");
157 Ok(Self {
158 ln1: gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?,
159 ln2: gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?,
160 self_attn: AttentionWeights::new(gg, num_heads, num_kv_heads, head_dim, rms_norm_eps, rotary, &prefix)?,
161 mlp: MlpWeights::new(gg, &prefix)?,
162 })
163 }
164
165 fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
166 let h = self.self_attn.forward(&self.ln1.forward(x)?, mask, offset)?;
167 let x = (x + h)?;
168 let h2 = self.ln2.forward(&x)?;
169 let h2 = h2.apply(&self.mlp)?;
170 x + h2
171 }
172
173 fn clear_kv_cache(&mut self) {
174 self.self_attn.clear_kv_cache();
175 }
176}
177
/// Qwen3-architecture transformer used as a text-embedding model, loaded from GGUF.
#[derive(Debug, Clone)]
pub struct EmbeddingModelWeights {
    /// Dequantized token-embedding table.
    embed_tokens: Embedding,
    /// Transformer blocks in order; each owns its own KV cache.
    layers: Vec<LayerWeights>,
    /// Final RMSNorm applied after the last block.
    norm: RmsNorm,
    /// Device all tensors live on; used when building the causal mask.
    device: CandleDevice,
    /// Dtype for the rotary tables and causal mask (F32 or F16).
    dtype: DType,
    /// Width of the hidden states / output embedding vectors.
    hidden_size: usize,
}
191
192impl EmbeddingModelWeights {
193 pub fn from_gguf<R: Read + Seek>(
195 ct: gguf_file::Content,
196 reader: &mut R,
197 device: &CandleDevice,
198 ) -> Result<Self> {
199 let mut gg = Gguf::new(ct, reader, device.clone());
200 let md_get = |s: &str| match gg.metadata().get(s) {
201 None => candle_core::bail!("cannot find {s} in metadata"),
202 Some(v) => Ok(v),
203 };
204
205 let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
206 let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
207 let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
208 let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
209 let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
210 let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
211 let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
212 let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;
213
214 let dtype = match gg.metadata().get("general.dtype") {
215 Some(v) => match v.to_u32() {
216 Ok(0) => DType::F32,
217 Ok(1) => DType::F16,
218 _ => DType::F16,
219 },
220 None => DType::F16,
221 };
222
223 let embed_tensor = gg.tensor("token_embd.weight")?;
224 let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);
225
226 let rotary = Arc::new(RotaryEmbedding::new(
227 dtype, head_dim, max_position_embeddings, rope_freq_base, device,
228 )?);
229
230 let mut layers = Vec::with_capacity(num_layers);
231 for i in 0..num_layers {
232 layers.push(LayerWeights::new(
233 &mut gg, num_attention_heads, num_kv_heads, head_dim,
234 rms_norm_eps, rotary.clone(), i,
235 )?);
236 }
237
238 let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
239 Ok(Self { embed_tokens, layers, norm, device: device.clone(), dtype, hidden_size })
242 }
243
244 pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
252 let (b, l) = input.dims2()?;
253 assert!(b == 1, "EmbeddingModelWeights only supports batch_size=1, got {b}");
254 let mut h = self.embed_tokens.forward(input)?;
255
256 let causal_mask = if l == 1 {
257 None
258 } else {
259 Some(self.causal_mask(b, l, offset)?)
260 };
261
262 for layer in &mut self.layers {
263 h = layer.forward(&h, causal_mask.as_ref(), offset)?;
264 }
265 let h = self.norm.forward(&h)?;
266
267 h.narrow(1, l - 1, 1)?.squeeze(1)?.squeeze(0)
270 }
271
272 pub fn clear_kv_cache(&mut self) {
273 for layer in &mut self.layers {
274 layer.clear_kv_cache();
275 }
276 }
277
278 pub fn hidden_size(&self) -> usize {
279 self.hidden_size
280 }
281
282 fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> Result<Tensor> {
283 let minf = f32::NEG_INFINITY;
284 let mask: Vec<_> = (0..tgt)
285 .flat_map(|i| {
286 (0..(tgt + offset)).map(move |j| {
287 if j <= i + offset { 0. } else { minf }
288 })
289 })
290 .collect();
291 Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
292 }
293}
294
/// High-level embedding engine: quantized model weights plus a tokenizer.
pub struct EmbeddingBackend {
    /// The transformer; methods take `&mut self` because of its KV cache.
    pub model: EmbeddingModelWeights,
    /// HuggingFace tokenizer loaded from `tokenizer.json`.
    pub tokenizer: tokenizers::Tokenizer,
    /// Device inference runs on; input tensors are created here.
    pub device: CandleDevice,
}
303
304impl EmbeddingBackend {
305 pub fn load(model_dir: &Path, device: Device) -> std::result::Result<Self, InferenceError> {
307 let candle_device = to_candle_device_pub(device)?;
308
309 let model_path = model_dir.join("model.gguf");
310 let mut file = std::fs::File::open(&model_path)
311 .map_err(|e| InferenceError::InferenceFailed(format!("open embedding model: {e}")))?;
312 let gguf = gguf_file::Content::read(&mut file)
313 .map_err(|e| InferenceError::InferenceFailed(format!("read gguf: {e}")))?;
314
315 let model = EmbeddingModelWeights::from_gguf(gguf, &mut file, &candle_device)
316 .map_err(|e| InferenceError::InferenceFailed(format!("load embedding weights: {e}")))?;
317
318 let tokenizer_path = model_dir.join("tokenizer.json");
319 let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
320 .map_err(|e| InferenceError::TokenizationError(format!("load tokenizer: {e}")))?;
321
322 Ok(Self { model, tokenizer, device: candle_device })
323 }
324
325 pub fn encode(&self, text: &str) -> std::result::Result<Vec<u32>, InferenceError> {
327 let encoding = self.tokenizer
328 .encode(text, true)
329 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
330 Ok(encoding.get_ids().to_vec())
331 }
332
333 pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, InferenceError> {
335 self.model.clear_kv_cache();
336
337 let tokens = self.encode(text)?;
338 if tokens.is_empty() {
339 return Ok(vec![0.0; self.model.hidden_size()]);
340 }
341
342 let input = Tensor::new(&tokens[..], &self.device)
343 .map_err(|e| InferenceError::InferenceFailed(format!("tensor: {e}")))?
344 .unsqueeze(0)
345 .map_err(|e| InferenceError::InferenceFailed(format!("unsqueeze: {e}")))?;
346
347 let hidden = self.model.forward(&input, 0)
348 .map_err(|e| InferenceError::InferenceFailed(format!("forward: {e}")))?;
349
350 let embedding: Vec<f32> = hidden
351 .to_dtype(candle_core::DType::F32)
352 .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
353 .to_vec1()
354 .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;
355
356 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
358 Ok(if norm > 0.0 {
359 embedding.iter().map(|x| x / norm).collect()
360 } else {
361 embedding
362 })
363 }
364
365 pub fn embed_batch(&mut self, texts: &[String]) -> std::result::Result<Vec<Vec<f32>>, InferenceError> {
367 texts.iter().map(|t| self.embed_one(t)).collect()
368 }
369
370 pub fn embed_query(&mut self, text: &str, instruction: &str) -> std::result::Result<Vec<f32>, InferenceError> {
373 let formatted = format!("Instruct: {instruction}\nQuery: {text}");
374 self.embed_one(&formatted)
375 }
376}