
car_inference/backend/embedding.rs

//! Embedding backend — loads Qwen3-Embedding GGUF, returns hidden states (not logits).
//!
//! Mirrors the quantized_qwen3 architecture but omits lm_head and returns
//! the last-token hidden state from the final norm layer, matching the
//! Qwen3-Embedding last-token pooling strategy.
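//!
//! A minimal usage sketch (the model directory path and `Device` variant are
//! assumptions, not fixed values):
//! ```ignore
//! let mut backend = EmbeddingBackend::load(Path::new("models/qwen3-embedding"), Device::Cpu)?;
//! let vec = backend.embed_one("hello world")?;
//! assert_eq!(vec.len(), backend.model.hidden_size());
//! ```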

use std::path::Path;
use std::io::{Read, Seek};
use std::sync::Arc;

use candle_core::{DType, Device as CandleDevice, Tensor};
use candle_core::quantized::gguf_file;

type Result<T> = candle_core::Result<T>;
use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
use candle_transformers::models::quantized_qwen3::{Gguf, RotaryEmbedding};

use candle_transformers::quantized_nn::RmsNorm;
use candle_transformers::models::with_tracing::QMatMul;
use candle_transformers::utils::repeat_kv;
use crate::{Device, InferenceError};
use crate::backend::candle::to_candle_device_pub;

// --- Layer internals (mirrored from quantized_qwen3, which keeps these private) ---
// TODO(upstream): candle-transformers quantized_qwen3 keeps MlpWeights, AttentionWeights,
// LayerWeights private. We duplicate ~150 lines because ModelWeights::forward() applies
// lm_head unconditionally and we need pre-lm_head hidden states.
// Track: if candle-transformers exposes a forward_hidden() method or makes these public,
// delete this duplication and use the upstream types directly.
// Pinned to candle-transformers 0.9.2 — test parity on upgrade.

#[derive(Debug, Clone)]
struct MlpWeights {
    gate_proj: QMatMul,
    up_proj: QMatMul,
    down_proj: QMatMul,
    act_fn: Activation,
}

impl MlpWeights {
    fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
        Ok(Self {
            gate_proj: gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?,
            up_proj: gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?,
            down_proj: gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?,
            act_fn: Activation::Silu,
        })
    }
}

impl Module for MlpWeights {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // SwiGLU feed-forward: down_proj(silu(gate_proj(x)) * up_proj(x)).
        let gate = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
        let up = self.up_proj.forward(x)?;
        self.down_proj.forward(&(gate * up)?)
    }
}

#[derive(Debug, Clone)]
struct AttentionWeights {
    q_proj: QMatMul,
    k_proj: QMatMul,
    v_proj: QMatMul,
    o_proj: QMatMul,
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: ConcatKvCache,
}

impl AttentionWeights {
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary_emb: Arc<RotaryEmbedding>,
        prefix: &str,
    ) -> Result<Self> {
        Ok(Self {
            q_proj: gg.qmatmul(&format!("{prefix}.attn_q.weight"))?,
            k_proj: gg.qmatmul(&format!("{prefix}.attn_k.weight"))?,
            v_proj: gg.qmatmul(&format!("{prefix}.attn_v.weight"))?,
            o_proj: gg.qmatmul(&format!("{prefix}.attn_output.weight"))?,
            q_norm: gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?,
            k_norm: gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?,
            num_heads,
            num_kv_heads,
            num_kv_groups: num_heads / num_kv_heads,
            head_dim,
            rotary_emb,
            kv_cache: ConcatKvCache::new(2),
        })
    }

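    /// Self-attention with grouped-query attention and a concat KV cache.
    /// Shape flow (b = batch, l = new tokens, t = cached + new tokens):
    /// x: (b, l, hidden) → q: (b, num_heads, l, head_dim),
    /// k/v: (b, num_kv_heads, t, head_dim) → repeat_kv → (b, num_heads, t, head_dim),
    /// output: (b, l, num_heads * head_dim) → o_proj → (b, l, hidden).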
    fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let (b, l, _) = x.dims3()?;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q.reshape((b, l, self.num_heads, self.head_dim))?.transpose(1, 2)?;
        let k = k.reshape((b, l, self.num_kv_heads, self.head_dim))?.transpose(1, 2)?;
        let v = v.reshape((b, l, self.num_kv_heads, self.head_dim))?.transpose(1, 2)?;

        // Qwen3 applies RMSNorm per head to q and k before RoPE.
        let q = self.q_norm.forward(&q.flatten(0, 2)?)?.reshape((b, self.num_heads, l, self.head_dim))?;
        let k = self.k_norm.forward(&k.flatten(0, 2)?)?.reshape((b, self.num_kv_heads, l, self.head_dim))?;

        let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
        let (k, v) = self.kv_cache.append(&k, &v)?;

        // GQA: expand the kv heads to match the number of query heads.
        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        // Scaled dot-product attention with an additive causal mask.
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        if let Some(m) = attn_mask {
            let mask = if m.dtype() != scores.dtype() { m.to_dtype(scores.dtype())? } else { m.clone() };
            scores = scores.broadcast_add(&mask)?;
        }
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let ctx = probs.matmul(&v)?;
        let out = ctx.transpose(1, 2)?.reshape((b, l, self.num_heads * self.head_dim))?;
        self.o_proj.forward(&out)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache.reset();
    }
}

#[derive(Debug, Clone)]
struct LayerWeights {
    self_attn: AttentionWeights,
    mlp: MlpWeights,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl LayerWeights {
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary: Arc<RotaryEmbedding>,
        layer_idx: usize,
    ) -> Result<Self> {
        let prefix = format!("blk.{layer_idx}");
        Ok(Self {
            ln1: gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?,
            ln2: gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?,
            self_attn: AttentionWeights::new(gg, num_heads, num_kv_heads, head_dim, rms_norm_eps, rotary, &prefix)?,
            mlp: MlpWeights::new(gg, &prefix)?,
        })
    }

    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        // Pre-norm residual block: x + attn(ln1(x)), then + mlp(ln2(x)).
        let h = self.self_attn.forward(&self.ln1.forward(x)?, mask, offset)?;
        let x = (x + h)?;
        let h2 = self.ln2.forward(&x)?;
        let h2 = h2.apply(&self.mlp)?;
        x + h2
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
    }
}

// --- Embedding model (no lm_head) ---

/// Embedding model weights. Same as quantized_qwen3::ModelWeights but without lm_head.
/// Returns last-token hidden states for embedding.
#[derive(Debug, Clone)]
pub struct EmbeddingModelWeights {
    embed_tokens: Embedding,
    layers: Vec<LayerWeights>,
    norm: RmsNorm,
    device: CandleDevice,
    dtype: DType,
    hidden_size: usize,
}

impl EmbeddingModelWeights {
    /// Load from a GGUF file. Same format as generative Qwen3 models.
    pub fn from_gguf<R: Read + Seek>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &CandleDevice,
    ) -> Result<Self> {
        let mut gg = Gguf::new(ct, reader, device.clone());
        let md_get = |s: &str| match gg.metadata().get(s) {
            None => candle_core::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
        let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
        let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
        let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
        let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
        let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
        let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
        let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;

        let dtype = match gg.metadata().get("general.dtype") {
            Some(v) => match v.to_u32() {
                Ok(0) => DType::F32,
                Ok(1) => DType::F16,
                _ => DType::F16,
            },
            None => DType::F16,
        };

        let embed_tensor = gg.tensor("token_embd.weight")?;
        let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);

        let rotary = Arc::new(RotaryEmbedding::new(
            dtype, head_dim, max_position_embeddings, rope_freq_base, device,
        )?);

        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(LayerWeights::new(
                &mut gg, num_attention_heads, num_kv_heads, head_dim,
                rms_norm_eps, rotary.clone(), i,
            )?);
        }

        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
        // Intentionally skip lm_head — we want hidden states, not logits

        Ok(Self { embed_tokens, layers, norm, device: device.clone(), dtype, hidden_size })
    }

    /// Forward pass returning the last-token hidden state (1024-dim for the 0.6B model).
    ///
    /// Uses a causal attention mask (same as the generative model) since
    /// Qwen3-Embedding is architecturally Qwen3ForCausalLM fine-tuned for embedding.
    /// Only supports batch_size=1 (the causal mask is not broadcast for larger
    /// batches), so the batch dimension is squeezed and the returned shape is
    /// (hidden_size,).
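    ///
    /// A minimal shape sketch mirroring `embed_one` (the token IDs and `device`
    /// are placeholders):
    /// ```ignore
    /// let ids: Vec<u32> = vec![9707, 1879];
    /// let input = Tensor::new(&ids[..], &device)?.unsqueeze(0)?; // (1, 2)
    /// let hidden = model.forward(&input, 0)?;                    // (hidden_size,)
    /// ```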
    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let (b, l) = input.dims2()?;
        assert!(b == 1, "EmbeddingModelWeights only supports batch_size=1, got {b}");
        let mut h = self.embed_tokens.forward(input)?;

        let causal_mask = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset)?)
        };

        for layer in &mut self.layers {
            h = layer.forward(&h, causal_mask.as_ref(), offset)?;
        }
        let h = self.norm.forward(&h)?;

        // Last-token pooling: extract hidden state at position l-1
        // narrow(1, l-1, 1) → (1, 1, hidden), squeeze(1) → (1, hidden), squeeze(0) → (hidden,)
        h.narrow(1, l - 1, 1)?.squeeze(1)?.squeeze(0)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in &mut self.layers {
            layer.clear_kv_cache();
        }
    }

    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

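    /// Builds an additive causal mask: 0.0 where position j is visible from
    /// position i (j <= i + offset), f32::NEG_INFINITY otherwise. For tgt = 3
    /// and offset = 0 the mask is:
    /// ```text
    /// [[0, -inf, -inf],
    ///  [0,    0, -inf],
    ///  [0,    0,    0]]
    /// ```
    /// The data buffer holds tgt * (tgt + offset) values, so the
    /// (b, 1, tgt, tgt + offset) shape is only valid for b == 1, which
    /// `forward` asserts.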
    fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<_> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| {
                    if j <= i + offset { 0. } else { minf }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
    }
}

// --- EmbeddingBackend wraps EmbeddingModelWeights + tokenizer ---

/// A loaded embedding model ready for inference.
pub struct EmbeddingBackend {
    pub model: EmbeddingModelWeights,
    pub tokenizer: tokenizers::Tokenizer,
    pub device: CandleDevice,
}

impl EmbeddingBackend {
    /// Load a Qwen3-Embedding GGUF model + tokenizer from a directory.
    pub fn load(model_dir: &Path, device: Device) -> std::result::Result<Self, InferenceError> {
        let candle_device = to_candle_device_pub(device)?;

        let model_path = model_dir.join("model.gguf");
        let mut file = std::fs::File::open(&model_path)
            .map_err(|e| InferenceError::InferenceFailed(format!("open embedding model: {e}")))?;
        let gguf = gguf_file::Content::read(&mut file)
            .map_err(|e| InferenceError::InferenceFailed(format!("read gguf: {e}")))?;

        let model = EmbeddingModelWeights::from_gguf(gguf, &mut file, &candle_device)
            .map_err(|e| InferenceError::InferenceFailed(format!("load embedding weights: {e}")))?;

        let tokenizer_path = model_dir.join("tokenizer.json");
        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(format!("load tokenizer: {e}")))?;

        Ok(Self { model, tokenizer, device: candle_device })
    }

    /// Encode text to token IDs.
    pub fn encode(&self, text: &str) -> std::result::Result<Vec<u32>, InferenceError> {
        let encoding = self.tokenizer
            .encode(text, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }

    /// Embed a single text. Returns L2-normalized hidden state vector.
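    ///
    /// Because results are unit-length, cosine similarity between two outputs
    /// reduces to a dot product. A minimal sketch (texts are placeholders):
    /// ```ignore
    /// let a = backend.embed_one("brake pad replacement")?;
    /// let b = backend.embed_one("how to change brake pads")?;
    /// let cosine: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
    /// ```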
    pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, InferenceError> {
        self.model.clear_kv_cache();

        let tokens = self.encode(text)?;
        if tokens.is_empty() {
            return Ok(vec![0.0; self.model.hidden_size()]);
        }

        let input = Tensor::new(&tokens[..], &self.device)
            .map_err(|e| InferenceError::InferenceFailed(format!("tensor: {e}")))?
            .unsqueeze(0)
            .map_err(|e| InferenceError::InferenceFailed(format!("unsqueeze: {e}")))?;

        let hidden = self.model.forward(&input, 0)
            .map_err(|e| InferenceError::InferenceFailed(format!("forward: {e}")))?;

        let embedding: Vec<f32> = hidden
            .to_dtype(candle_core::DType::F32)
            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
            .to_vec1()
            .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;

        // L2 normalize
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        Ok(if norm > 0.0 {
            embedding.iter().map(|x| x / norm).collect()
        } else {
            embedding
        })
    }

    /// Embed multiple texts. Processes sequentially (no batch padding needed).
    pub fn embed_batch(&mut self, texts: &[String]) -> std::result::Result<Vec<Vec<f32>>, InferenceError> {
        texts.iter().map(|t| self.embed_one(t)).collect()
    }

    /// Embed with instruction prefix (for queries).
    /// Format: "Instruct: {instruction}\nQuery: {text}"
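    ///
    /// A sketch (the instruction and query text are illustrative only):
    /// ```ignore
    /// let q = backend.embed_query(
    ///     "how do I reset the tire pressure warning light",
    ///     "Given a car maintenance question, retrieve passages that answer it",
    /// )?;
    /// ```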
    pub fn embed_query(&mut self, text: &str, instruction: &str) -> std::result::Result<Vec<f32>, InferenceError> {
        let formatted = format!("Instruct: {instruction}\nQuery: {text}");
        self.embed_one(&formatted)
    }
}