car-inference 0.13.0

Local model inference for CAR — Candle backend with Qwen3 models
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
//! Embedding backend — loads Qwen3-Embedding GGUF, returns hidden states (not logits).
//!
//! Mirrors the quantized_qwen3 architecture but omits lm_head and returns
//! the last-token hidden state from the final norm layer, matching the
//! Qwen3-Embedding last-token pooling strategy.

use std::io::{Read, Seek};
use std::path::Path;
use std::sync::Arc;

use candle_core::quantized::gguf_file;
use candle_core::{DType, Device as CandleDevice, Tensor};

type Result<T> = candle_core::Result<T>;
use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
use candle_transformers::models::quantized_qwen3::{Gguf, RotaryEmbedding};

use crate::backend::candle::to_candle_device_pub;
use crate::{Device, InferenceError};
use candle_transformers::models::with_tracing::QMatMul;
use candle_transformers::quantized_nn::RmsNorm;
use candle_transformers::utils::repeat_kv;

// --- Layer internals (mirrored from quantized_qwen3, which keeps these private) ---
// TODO(upstream): candle-transformers quantized_qwen3 keeps MlpWeights, AttentionWeights,
// LayerWeights private. We duplicate ~150 lines because ModelWeights::forward() applies
// lm_head unconditionally and we need pre-lm_head hidden states.
// Track: if candle-transformers exposes a forward_hidden() method or makes these public,
// delete this duplication and use the upstream types directly.
// Pinned to candle-transformers 0.9.2 — test parity on upgrade.

/// Gated feed-forward sub-layer: `down_proj(silu(gate_proj(x)) * up_proj(x))`.
///
/// The activation is fixed to SiLU at load time.
#[derive(Debug, Clone)]
struct MlpWeights {
    gate_proj: QMatMul,
    up_proj: QMatMul,
    down_proj: QMatMul,
    act_fn: Activation,
}

impl MlpWeights {
    /// Loads the gate/up/down projections stored under `{prefix}.ffn_*.weight`.
    fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
        let gate_proj = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?;
        let up_proj = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?;
        let down_proj = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: Activation::Silu,
        })
    }
}

impl Module for MlpWeights {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Both branches read the same input; only the gate branch is activated.
        let activated_gate = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let up_branch = x.apply(&self.up_proj)?;
        let gated = (activated_gate * up_branch)?;
        gated.apply(&self.down_proj)
    }
}

/// Self-attention sub-layer with grouped-query attention (GQA) and per-head
/// RMSNorm on queries and keys.
///
/// Each group of `num_kv_groups` query heads shares one key/value head; K/V are
/// expanded with `repeat_kv` before the score computation. New K/V are appended to
/// `kv_cache` on every `forward` call until `clear_kv_cache` is called.
#[derive(Debug, Clone)]
struct AttentionWeights {
    q_proj: QMatMul,
    k_proj: QMatMul,
    v_proj: QMatMul,
    o_proj: QMatMul,
    // RMSNorms applied to every individual query / key head vector (see `forward`).
    q_norm: RmsNorm,
    k_norm: RmsNorm,
    num_heads: usize,
    num_kv_heads: usize,
    // Query heads per KV head: `num_heads / num_kv_heads` (integer division).
    num_kv_groups: usize,
    head_dim: usize,
    // Rotary position tables; behind `Arc` so one instance can be shared between layers.
    rotary_emb: Arc<RotaryEmbedding>,
    // Built with dim 2, which is the sequence axis of the (b, heads, seq, head_dim)
    // K/V tensors produced in `forward` — presumably the concatenation axis.
    kv_cache: ConcatKvCache,
}

impl AttentionWeights {
    /// Loads the Q/K/V/output projections and the Q/K norms stored under
    /// `{prefix}.attn_*.weight`.
    ///
    /// # Panics
    /// Panics (integer division by zero) if `num_kv_heads == 0`.
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary_emb: Arc<RotaryEmbedding>,
        prefix: &str,
    ) -> Result<Self> {
        Ok(Self {
            q_proj: gg.qmatmul(&format!("{prefix}.attn_q.weight"))?,
            k_proj: gg.qmatmul(&format!("{prefix}.attn_k.weight"))?,
            v_proj: gg.qmatmul(&format!("{prefix}.attn_v.weight"))?,
            o_proj: gg.qmatmul(&format!("{prefix}.attn_output.weight"))?,
            q_norm: gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?,
            k_norm: gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?,
            num_heads,
            num_kv_heads,
            num_kv_groups: num_heads / num_kv_heads,
            head_dim,
            rotary_emb,
            kv_cache: ConcatKvCache::new(2),
        })
    }

    /// Runs attention for `x` and appends its keys/values to the KV cache.
    ///
    /// * `x` — activations of shape `(b, l, _)`; the last dim is consumed by the projections.
    /// * `attn_mask` — optional additive mask, cast to the score dtype if needed and
    ///   broadcast-added to the `(b, num_heads, l, kv_len)` scores.
    /// * `offset` — rotary position of the first token of `x`.
    ///
    /// Returns `o_proj` applied to the `(b, l, num_heads * head_dim)` attention context.
    fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let (b, l, _) = x.dims3()?;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        // Split into heads: (b, l, heads, head_dim) -> (b, heads, l, head_dim).
        let q = q
            .reshape((b, l, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b, l, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        // Per-head RMSNorm: flatten to (b*heads*l, head_dim) so the norm runs over each
        // head vector, then restore the 4-D layout.
        let q = self.q_norm.forward(&q.flatten(0, 2)?)?.reshape((
            b,
            self.num_heads,
            l,
            self.head_dim,
        ))?;
        let k = self.k_norm.forward(&k.flatten(0, 2)?)?.reshape((
            b,
            self.num_kv_heads,
            l,
            self.head_dim,
        ))?;

        // Rotary embedding is applied before caching, so cached keys already carry
        // their positions.
        let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
        let (k, v) = self.kv_cache.append(&k, &v)?;

        // GQA: expand KV heads so each query head has a matching K/V head.
        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

        // Scaled dot-product attention: softmax(q k^T / sqrt(head_dim) + mask) v.
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
        if let Some(m) = attn_mask {
            let mask = if m.dtype() != scores.dtype() {
                m.to_dtype(scores.dtype())?
            } else {
                m.clone()
            };
            scores = scores.broadcast_add(&mask)?;
        }
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let ctx = probs.matmul(&v)?;
        // Merge heads back: (b, heads, l, head_dim) -> (b, l, heads * head_dim).
        let out = ctx
            .transpose(1, 2)?
            .reshape((b, l, self.num_heads * self.head_dim))?;
        self.o_proj.forward(&out)
    }

    /// Drops all cached keys/values so the next `forward` starts a fresh sequence.
    fn clear_kv_cache(&mut self) {
        self.kv_cache.reset();
    }
}

/// One pre-norm transformer block.
///
/// The attention sub-layer and then the MLP sub-layer each see an RMSNorm-ed input,
/// and each adds its output back onto the residual stream.
#[derive(Debug, Clone)]
struct LayerWeights {
    self_attn: AttentionWeights,
    mlp: MlpWeights,
    ln1: RmsNorm,
    ln2: RmsNorm,
}

impl LayerWeights {
    /// Loads block `layer_idx`, whose tensors are named `blk.{layer_idx}.*`.
    fn new<R: Read + Seek>(
        gg: &mut Gguf<R>,
        num_heads: usize,
        num_kv_heads: usize,
        head_dim: usize,
        rms_norm_eps: f64,
        rotary: Arc<RotaryEmbedding>,
        layer_idx: usize,
    ) -> Result<Self> {
        let block = format!("blk.{layer_idx}");

        // Same load order as the field list in the original constructor:
        // norms first, then attention, then MLP.
        let ln1 = gg.rms_norm(&format!("{block}.attn_norm.weight"), rms_norm_eps)?;
        let ln2 = gg.rms_norm(&format!("{block}.ffn_norm.weight"), rms_norm_eps)?;
        let self_attn = AttentionWeights::new(
            gg,
            num_heads,
            num_kv_heads,
            head_dim,
            rms_norm_eps,
            rotary,
            &block,
        )?;
        let mlp = MlpWeights::new(gg, &block)?;

        Ok(Self {
            self_attn,
            mlp,
            ln1,
            ln2,
        })
    }

    /// Applies attention and MLP, each with a residual connection.
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let attn_in = self.ln1.forward(x)?;
        let attn_out = self.self_attn.forward(&attn_in, mask, offset)?;
        let residual = (x + attn_out)?;

        let mlp_in = self.ln2.forward(&residual)?;
        let mlp_out = mlp_in.apply(&self.mlp)?;
        residual + mlp_out
    }

    /// Resets this block's attention KV cache.
    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache();
    }
}

// --- Embedding model (no lm_head) ---

/// Embedding model weights. Same as `quantized_qwen3::ModelWeights` but without lm_head.
///
/// [`EmbeddingModelWeights::forward`] returns the final-norm hidden state of the last
/// input token rather than vocabulary logits.
#[derive(Debug, Clone)]
pub struct EmbeddingModelWeights {
    embed_tokens: Embedding,
    layers: Vec<LayerWeights>,
    norm: RmsNorm,
    device: CandleDevice,
    dtype: DType,
    hidden_size: usize,
}

impl EmbeddingModelWeights {
    /// Load from a GGUF file. Same format as generative Qwen3 models.
    ///
    /// The `qwen3.*` hyper-parameters are read from the GGUF metadata. The token
    /// embedding table is dequantized onto `device`. Every `blk.{i}` transformer block
    /// and the final `output_norm` are loaded. No lm_head tensor is loaded.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// * a required metadata key or tensor is missing;
    /// * the head counts are inconsistent (zero KV heads, or attention heads not a
    ///   multiple of KV heads);
    /// * a tensor cannot be materialized on `device`.
    pub fn from_gguf<R: Read + Seek>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &CandleDevice,
    ) -> Result<Self> {
        let mut gg = Gguf::new(ct, reader, device.clone());
        let md_get = |s: &str| match gg.metadata().get(s) {
            None => candle_core::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
        let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
        let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
        let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
        let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
        let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
        let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
        let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;

        // The attention layers compute `num_attention_heads / num_kv_heads` as the GQA
        // group size. Reject metadata that would divide by zero or leave query heads
        // without a KV head, instead of panicking or silently mis-grouping.
        if num_kv_heads == 0 || num_attention_heads % num_kv_heads != 0 {
            candle_core::bail!(
                "invalid head configuration: {num_attention_heads} attention heads, \
                 {num_kv_heads} kv heads"
            );
        }

        // `general.dtype == 0` selects F32; any other value, or no key at all, means F16.
        let dtype = match gg.metadata().get("general.dtype").map(|v| v.to_u32()) {
            Some(Ok(0)) => DType::F32,
            _ => DType::F16,
        };

        let embed_tensor = gg.tensor("token_embd.weight")?;
        let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);

        // One rotary table shared by every layer.
        let rotary = Arc::new(RotaryEmbedding::new(
            dtype,
            head_dim,
            max_position_embeddings,
            rope_freq_base,
            device,
        )?);

        let mut layers = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            layers.push(LayerWeights::new(
                &mut gg,
                num_attention_heads,
                num_kv_heads,
                head_dim,
                rms_norm_eps,
                rotary.clone(),
                i,
            )?);
        }

        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
        // Intentionally skip lm_head — we want hidden states, not logits

        Ok(Self {
            embed_tokens,
            layers,
            norm,
            device: device.clone(),
            dtype,
            hidden_size,
        })
    }

    /// Forward pass returning the last-token hidden state after the final RMSNorm.
    ///
    /// Uses a causal attention mask (same as the generative model) since
    /// Qwen3-Embedding is architecturally Qwen3ForCausalLM fine-tuned for embedding.
    ///
    /// * `input` — token ids of shape `(1, seq_len)`; only batch size 1 is supported.
    /// * `offset` — position of the first token in `input`, i.e. how many tokens are
    ///   already held in the KV cache (0 for a fresh sequence).
    ///
    /// Returns a 1-D tensor of shape `(hidden_size,)`.
    ///
    /// # Errors
    /// Fails if `input` is not 2-D, if its batch size is not 1, if it contains no
    /// tokens, or if any layer's computation fails.
    pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
        let (b, l) = input.dims2()?;
        // The final squeeze(0) produces a single vector, so only one sequence fits.
        // Report this as an error rather than panicking inside a fallible API.
        if b != 1 {
            candle_core::bail!("EmbeddingModelWeights only supports batch_size=1, got {b}");
        }
        // Last-token pooling indexes position `l - 1`, which would underflow here.
        if l == 0 {
            candle_core::bail!("EmbeddingModelWeights::forward called with an empty sequence");
        }
        let mut h = self.embed_tokens.forward(input)?;

        // A single query token may attend to every cached position, so it needs no mask.
        let causal_mask = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset)?)
        };

        for layer in &mut self.layers {
            h = layer.forward(&h, causal_mask.as_ref(), offset)?;
        }
        let h = self.norm.forward(&h)?;

        // Last-token pooling: extract hidden state at position l-1
        // narrow(1, l-1, 1) → (1, 1, hidden), squeeze(1) → (1, hidden), squeeze(0) → (hidden,)
        h.narrow(1, l - 1, 1)?.squeeze(1)?.squeeze(0)
    }

    /// Empties the KV cache of every layer so the next `forward` starts a new sequence.
    pub fn clear_kv_cache(&mut self) {
        for layer in &mut self.layers {
            layer.clear_kv_cache();
        }
    }

    /// Width of the vectors returned by [`Self::forward`] (`qwen3.embedding_length`).
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Builds an additive causal mask of shape `(b, 1, tgt, tgt + offset)`.
    ///
    /// Query row `i` sits at absolute position `i + offset`. It may attend to key
    /// column `j` iff `j <= i + offset`; all other entries are `-inf`.
    ///
    /// The `(tgt, tgt + offset)` pattern is built once and broadcast over the batch and
    /// head axes. The tensor therefore always has the right number of elements,
    /// whatever `b` is.
    fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> Result<Tensor> {
        let kv_len = tgt + offset;
        let mask: Vec<f32> = (0..tgt)
            .flat_map(|i| {
                (0..kv_len).map(move |j| if j <= i + offset { 0.0 } else { f32::NEG_INFINITY })
            })
            .collect();
        Tensor::from_slice(&mask, (tgt, kv_len), &self.device)?
            .broadcast_as((b, 1, tgt, kv_len))?
            .to_dtype(self.dtype)
    }
}

// --- EmbeddingBackend wraps EmbeddingModelWeights + tokenizer ---

/// A loaded embedding model ready for inference.
///
/// Bundles the Qwen3-Embedding weights, their tokenizer, and the Candle device that
/// input tensors are created on.
pub struct EmbeddingBackend {
    /// Transformer weights without lm_head; see [`EmbeddingModelWeights`].
    pub model: EmbeddingModelWeights,
    /// Tokenizer loaded from `tokenizer.json` in the model directory.
    pub tokenizer: tokenizers::Tokenizer,
    /// Device the model was loaded onto; input tensors are created here too.
    pub device: CandleDevice,
}

impl EmbeddingBackend {
    /// Load a Qwen3-Embedding GGUF model + tokenizer from a directory.
    ///
    /// Expects `model_dir/model.gguf` and `model_dir/tokenizer.json`.
    ///
    /// # Errors
    /// * [`InferenceError::InferenceFailed`] if the GGUF file cannot be opened, parsed,
    ///   or loaded onto the device.
    /// * [`InferenceError::TokenizationError`] if the tokenizer file cannot be loaded.
    /// * Whatever error `to_candle_device_pub` reports for an unusable `device`.
    pub fn load(model_dir: &Path, device: Device) -> std::result::Result<Self, InferenceError> {
        let candle_device = to_candle_device_pub(device)?;

        let model_path = model_dir.join("model.gguf");
        let mut file = std::fs::File::open(&model_path)
            .map_err(|e| InferenceError::InferenceFailed(format!("open embedding model: {e}")))?;
        let gguf = gguf_file::Content::read(&mut file)
            .map_err(|e| InferenceError::InferenceFailed(format!("read gguf: {e}")))?;

        let model = EmbeddingModelWeights::from_gguf(gguf, &mut file, &candle_device)
            .map_err(|e| InferenceError::InferenceFailed(format!("load embedding weights: {e}")))?;

        let tokenizer_path = model_dir.join("tokenizer.json");
        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(format!("load tokenizer: {e}")))?;

        Ok(Self {
            model,
            tokenizer,
            device: candle_device,
        })
    }

    /// Encode text to token IDs, adding special tokens as configured by the tokenizer.
    ///
    /// NOTE(review): last-token pooling relies on the final token being the one the
    /// model was trained to pool (Qwen3-Embedding uses an end-of-text token). Confirm
    /// that `tokenizer.json`'s post-processor appends it when special tokens are added.
    pub fn encode(&self, text: &str) -> std::result::Result<Vec<u32>, InferenceError> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }

    /// Embed a single text. Returns an L2-normalized hidden-state vector.
    ///
    /// Text that tokenizes to no tokens yields an all-zero vector of length
    /// `hidden_size`, since a zero vector cannot be normalized.
    ///
    /// # Errors
    /// Tokenization failures map to [`InferenceError::TokenizationError`]. Tensor and
    /// forward-pass failures map to [`InferenceError::InferenceFailed`].
    pub fn embed_one(&mut self, text: &str) -> std::result::Result<Vec<f32>, InferenceError> {
        // Each text is an independent sequence. Drop any K/V left by an earlier call,
        // including one that failed part-way through the forward pass.
        self.model.clear_kv_cache();

        let tokens = self.encode(text)?;
        if tokens.is_empty() {
            return Ok(vec![0.0; self.model.hidden_size()]);
        }

        let input = Tensor::new(&tokens[..], &self.device)
            .map_err(|e| InferenceError::InferenceFailed(format!("tensor: {e}")))?
            .unsqueeze(0)
            .map_err(|e| InferenceError::InferenceFailed(format!("unsqueeze: {e}")))?;

        let forward_result = self.model.forward(&input, 0);
        // The cache is never reused across calls. Release it now instead of keeping
        // every layer's K/V for the whole input alive until the next call.
        self.model.clear_kv_cache();
        let hidden =
            forward_result.map_err(|e| InferenceError::InferenceFailed(format!("forward: {e}")))?;

        let embedding: Vec<f32> = hidden
            .to_dtype(candle_core::DType::F32)
            .map_err(|e| InferenceError::InferenceFailed(format!("dtype: {e}")))?
            .to_vec1()
            .map_err(|e| InferenceError::InferenceFailed(format!("to_vec: {e}")))?;

        Ok(car_ir::linalg::l2_normalize(&embedding))
    }

    /// Embed multiple texts one at a time via [`Self::embed_one`], with no batch padding.
    ///
    /// # Errors
    /// Stops at the first text that fails and returns that error.
    pub fn embed_batch(
        &mut self,
        texts: &[String],
    ) -> std::result::Result<Vec<Vec<f32>>, InferenceError> {
        texts.iter().map(|t| self.embed_one(t)).collect()
    }

    /// Embed with instruction prefix (for queries).
    ///
    /// Format: `"Instruct: {instruction}\nQuery:{text}"`. There is deliberately no space
    /// after `Query:`, matching the prompt template published for Qwen3-Embedding.
    pub fn embed_query(
        &mut self,
        text: &str,
        instruction: &str,
    ) -> std::result::Result<Vec<f32>, InferenceError> {
        let formatted = format!("Instruct: {instruction}\nQuery:{text}");
        self.embed_one(&formatted)
    }
}