car-inference 0.13.0

Local model inference for CAR — Candle backend with Qwen3 models
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
//! Qwen3-MoE model with naive (non-fused) expert routing.
//!
//! Candle's `quantized_qwen3_moe::GGUFQWenMoE` uses `FusedMoeGGUF` which requires
//! a custom CUDA kernel. This module provides the same model architecture but with
//! a naive MoE forward pass using standard tensor ops — works on Metal, CPU, and CUDA.
//!
//! Architecture: identical to Qwen3 but with MoE MLP layers where each layer has
//! N experts (gate/up/down projections) with top-K routing.

use candle_core::quantized::gguf_file;
use candle_core::{DType, Device as CandleDevice, Result, Tensor, D};
use candle_nn::kv_cache::ConcatKvCache;
use candle_nn::{Embedding, Linear, Module};
use candle_transformers::models::quantized_qwen3::{Gguf, RotaryEmbedding};
use candle_transformers::models::with_tracing::QMatMul;
use candle_transformers::quantized_nn::RmsNorm;
use candle_transformers::utils::repeat_kv;
use std::sync::Arc;

#[allow(unused_imports)]
use candle_core::quantized::QTensor;

// --- MoE MLP ---

/// Dense (non-MoE) feed-forward block built from three quantized projections.
///
/// `forward` computes `down(silu(gate(x)) * up(x))`.
struct Mlp {
    /// Gate projection; its output is passed through SiLU.
    gate: QMatMul,
    /// Up projection; multiplied element-wise with the activated gate output.
    up: QMatMul,
    /// Down projection applied to the gated product.
    down: QMatMul,
}

impl Module for Mlp {
    /// Apply the gated feed-forward: `down(silu(gate(xs)) * up(xs))`.
    ///
    /// Any error from the projections or the element-wise ops is propagated.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gate_proj = self.gate.forward(xs)?;
        let up_proj = self.up.forward(xs)?;
        // SiLU acts on the gate branch only; the up branch stays linear.
        let activated = candle_nn::ops::silu(&gate_proj)?;
        let gated = (activated * up_proj)?;
        self.down.forward(&gated)
    }
}

/// Naive MoE layer — routes to top-K experts using standard tensor ops.
///
/// Expert weights are dequantized once (first forward call) and cached.
/// Per-token, only the selected top-K expert slices are used for matmuls.
struct NaiveMoe {
    /// Router projection; its softmaxed outputs are used as per-expert routing weights.
    gate: Linear,
    /// Quantized stacked gate-projection weights of all experts (source for the cache).
    gate_experts: Arc<QTensor>,
    /// Quantized stacked up-projection weights of all experts (source for the cache).
    up_experts: Arc<QTensor>,
    /// Quantized stacked down-projection weights of all experts (source for the cache).
    down_experts: Arc<QTensor>,
    /// Cached dequantized expert weights — filled on first forward call.
    /// Stored as F16 to save memory (half the size of F32).
    /// Tuple order is `(gate, up, down)`.
    cache: Option<(Tensor, Tensor, Tensor)>,
    /// Number of experts; routed ids at or above this value are skipped in `forward`.
    num_experts: usize,
    /// How many experts (the K of top-K) are selected for each token.
    num_experts_per_tok: usize,
    /// When true, the selected top-K routing weights are divided by their sum.
    norm_topk_prob: bool,
}

impl NaiveMoe {
    /// Ensure dequantized weights are cached. Only runs once per layer.
    fn ensure_cache(&mut self, device: &CandleDevice) -> Result<()> {
        if self.cache.is_some() {
            return Ok(());
        }
        // Dequantize to F16 to save memory (vs F32)
        let gate_w = self.gate_experts.dequantize(device)?.to_dtype(DType::F16)?;
        let up_w = self.up_experts.dequantize(device)?.to_dtype(DType::F16)?;
        let down_w = self.down_experts.dequantize(device)?.to_dtype(DType::F16)?;
        self.cache = Some((gate_w, up_w, down_w));
        Ok(())
    }

    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        let (num_tokens, hidden_dim) = xs.dims2()?;
        let device = xs.device();

        // Route: softmax over gating logits, select top-K
        let router_logits = self.gate.forward(xs)?;
        let routing_weights =
            candle_nn::ops::softmax_last_dim(&router_logits.to_dtype(DType::F32)?)?;

        let topk_ids = routing_weights
            .arg_sort_last_dim(false)?
            .narrow(D::Minus1, 0, self.num_experts_per_tok)?
            .contiguous()?;

        let mut topk_weights = routing_weights.gather(&topk_ids, D::Minus1)?;
        if self.norm_topk_prob {
            topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?;
        }

        let topk_ids_vec: Vec<u32> = topk_ids.flatten_all()?.to_vec1()?;
        let topk_weights_vec: Vec<f32> = topk_weights.flatten_all()?.to_vec1()?;

        // Use cached dequantized weights
        self.ensure_cache(device)?;
        let (gate_w, up_w, down_w) = self.cache.as_ref().unwrap();

        // For each token, gather selected expert weights and batch compute
        let mut outputs = Vec::with_capacity(num_tokens);
        for t in 0..num_tokens {
            let token = xs.narrow(0, t, 1)?.contiguous()?.to_dtype(DType::F16)?;

            // Collect selected expert IDs and weights for this token
            let mut expert_ids = Vec::with_capacity(self.num_experts_per_tok);
            let mut weights = Vec::with_capacity(self.num_experts_per_tok);
            for k in 0..self.num_experts_per_tok {
                let idx = t * self.num_experts_per_tok + k;
                let eid = topk_ids_vec[idx] as usize;
                if eid < self.num_experts {
                    expert_ids.push(eid);
                    weights.push(topk_weights_vec[idx]);
                }
            }

            if expert_ids.is_empty() {
                outputs.push(Tensor::zeros((1, hidden_dim), DType::F32, device)?);
                continue;
            }

            // Gather selected expert gate/up weights into [K, intermediate, hidden]
            let gate_selected: Vec<Tensor> = expert_ids
                .iter()
                .map(|&eid| gate_w.narrow(0, eid, 1))
                .collect::<Result<Vec<_>>>()?;
            let gate_batch = Tensor::cat(&gate_selected, 0)?; // [K, intermediate, hidden]

            let up_selected: Vec<Tensor> = expert_ids
                .iter()
                .map(|&eid| up_w.narrow(0, eid, 1))
                .collect::<Result<Vec<_>>>()?;
            let up_batch = Tensor::cat(&up_selected, 0)?; // [K, intermediate, hidden]

            // Broadcast token across K experts: [1, hidden] → [K, 1, hidden]
            let k = expert_ids.len();
            let token_k = token.unsqueeze(0)?.expand((k, 1, hidden_dim))?; // [K, 1, hidden]

            // Batched matmul: [K, 1, hidden] @ [K, hidden, intermediate] = [K, 1, intermediate]
            let gate_t = gate_batch.transpose(1, 2)?; // [K, hidden, intermediate]
            let gate_out = token_k.matmul(&gate_t)?; // [K, 1, intermediate]

            let up_t = up_batch.transpose(1, 2)?;
            let up_out = token_k.matmul(&up_t)?; // [K, 1, intermediate]

            let activated = candle_nn::ops::silu(&gate_out)?.mul(&up_out)?; // [K, 1, intermediate]

            // Down projection: gather selected down experts
            let down_selected: Vec<Tensor> = expert_ids
                .iter()
                .map(|&eid| down_w.narrow(0, eid, 1))
                .collect::<Result<Vec<_>>>()?;
            let down_batch = Tensor::cat(&down_selected, 0)?; // [K, hidden, intermediate]
            let down_t = down_batch.transpose(1, 2)?; // [K, intermediate, hidden]
            let expert_outs = activated.matmul(&down_t)?; // [K, 1, hidden]

            // Weight and sum expert outputs
            let expert_outs_f32 = expert_outs.to_dtype(DType::F32)?;
            let weights_t = Tensor::from_slice(&weights, (k, 1, 1), device)?;
            let weighted = expert_outs_f32.broadcast_mul(&weights_t)?; // [K, 1, hidden]
            let combined = weighted.sum(0)?; // [1, hidden]

            outputs.push(combined);
        }

        Tensor::cat(&outputs, 0)
    }
}

/// Feed-forward variant used by a layer: mixture-of-experts or a dense MLP.
enum MoeOrMlp {
    /// Mixture-of-experts feed-forward with top-K routing.
    Moe(NaiveMoe),
    /// Dense gated feed-forward.
    Mlp(Mlp),
}

impl MoeOrMlp {
    /// Forward `xs` through whichever feed-forward variant this value holds.
    ///
    /// Takes `&mut self` because the MoE variant fills its weight cache lazily.
    fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        match self {
            Self::Moe(experts) => experts.forward(xs),
            Self::Mlp(dense) => dense.forward(xs),
        }
    }
}

// --- Attention (same as Qwen3) ---

/// Self-attention block with per-head q/k RMS norms, rotary embedding and a KV cache.
struct Attention {
    /// Query projection.
    q_proj: QMatMul,
    /// Key projection.
    k_proj: QMatMul,
    /// Value projection.
    v_proj: QMatMul,
    /// Output projection applied to the merged heads.
    o_proj: QMatMul,
    /// RMS norm applied to the queries after they are split into heads.
    q_norm: RmsNorm,
    /// RMS norm applied to the keys after they are split into heads.
    k_norm: RmsNorm,
    /// Number of query heads.
    num_heads: usize,
    /// Number of key/value heads; K and V are repeated `num_heads / num_kv_heads` times.
    num_kv_heads: usize,
    /// Size of each head; also sets the `1 / sqrt(head_dim)` score scale.
    head_dim: usize,
    /// Rotary embedding shared by all layers, applied to q and k at the given offset.
    rotary: Arc<RotaryEmbedding>,
    /// Keys/values from earlier calls; each forward call appends to it.
    kv_cache: ConcatKvCache,
}

impl Attention {
    /// Scaled dot-product attention over `x` (3-D: batch, sequence, features).
    ///
    /// * `mask` — optional additive mask, broadcast-added to the attention scores.
    /// * `offset` — position offset handed to the rotary embedding.
    ///
    /// The new keys and values are appended to the KV cache, and attention is
    /// computed against the full cached sequence.
    fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
        let (batch, seq_len, _) = x.dims3()?;
        let head_dim = self.head_dim;

        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        // [batch, seq, heads * head_dim] -> [batch, heads, seq, head_dim]
        let split_heads = |t: Tensor, heads: usize| -> Result<Tensor> {
            t.reshape((batch, seq_len, heads, head_dim))?
                .transpose(1, 2)?
                .contiguous()
        };
        let q = split_heads(q, self.num_heads)?;
        let k = split_heads(k, self.num_kv_heads)?;
        let v = split_heads(v, self.num_kv_heads)?;

        // Per-head normalization comes before the rotary embedding.
        let q = self.q_norm.forward(&q)?;
        let k = self.k_norm.forward(&k)?;
        let (q, k) = self.rotary.apply(&q, &k, offset)?;

        let (k, v) = self.kv_cache.append(&k, &v)?;

        // Repeat each KV head so K and V line up with the query heads.
        let groups = self.num_heads / self.num_kv_heads;
        let k = repeat_kv(k, groups)?;
        let v = repeat_kv(v, groups)?;

        let scale = (head_dim as f64).sqrt().recip();
        let scores = (q.matmul(&k.t()?)? * scale)?;
        let scores = if let Some(m) = mask {
            scores.broadcast_add(m)?
        } else {
            scores
        };
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;

        // [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim]
        let merged = probs
            .matmul(&v)?
            .transpose(1, 2)?
            .reshape((batch, seq_len, ()))?;
        self.o_proj.forward(&merged)
    }
}

// --- Layer ---

/// One transformer block: normalized attention followed by a normalized feed-forward.
struct Layer {
    /// Self-attention sub-block.
    attn: Attention,
    /// RMS norm applied to the layer input before attention.
    attn_norm: RmsNorm,
    /// Feed-forward sub-block (MoE or dense).
    mlp: MoeOrMlp,
    /// RMS norm applied before the feed-forward sub-block.
    ffn_norm: RmsNorm,
}

// --- Full Model ---

/// Qwen3-MoE model loaded from GGUF, using the naive (non-fused) expert routing.
///
/// Build it with `from_gguf` and run it with `forward`. The model keeps a KV
/// cache per layer; call `clear_kv_cache` before starting an unrelated sequence.
pub struct Qwen3MoeModel {
    /// Token embedding table.
    embeddings: Embedding,
    /// Transformer blocks, applied in order.
    layers: Vec<Layer>,
    /// Final RMS norm applied before the output head.
    norm: RmsNorm,
    /// Output head projecting the last hidden state.
    output: QMatMul,
    /// Dtype used for the rotary embedding and the causal mask.
    dtype: DType,
    /// Device on which the causal mask is created.
    device: CandleDevice,
}

impl Qwen3MoeModel {
    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &CandleDevice,
    ) -> Result<Self> {
        let dtype = DType::F32;
        let mut gg = Gguf::new(ct, reader, device.clone());

        let arch = "qwen3moe";
        let metadata = gg.metadata().clone();
        let md_get = |key: &str| -> Result<gguf_file::Value> {
            metadata
                .get(key)
                .cloned()
                .ok_or_else(|| candle_core::Error::Msg(format!("cannot find {key} in metadata")))
        };

        let head_count = md_get(&format!("{arch}.attention.head_count"))?.to_u32()? as usize;
        let head_count_kv = md_get(&format!("{arch}.attention.head_count_kv"))?.to_u32()? as usize;
        let head_dim = md_get(&format!("{arch}.attention.key_length"))?.to_u32()? as usize;
        let block_count = md_get(&format!("{arch}.block_count"))?.to_u32()? as usize;
        let embedding_length = md_get(&format!("{arch}.embedding_length"))?.to_u32()? as usize;
        let rms_norm_eps =
            md_get(&format!("{arch}.attention.layer_norm_rms_epsilon"))?.to_f32()? as f64;
        let context_length = md_get(&format!("{arch}.context_length"))?.to_u32()? as usize;
        let rope_freq_base = md_get(&format!("{arch}.rope.freq_base"))?.to_f32()? as f64;
        let num_experts = md_get(&format!("{arch}.expert_count"))?.to_u32()? as usize;
        let num_experts_per_tok = md_get(&format!("{arch}.expert_used_count"))?.to_u32()? as usize;

        // Token embeddings
        let tok_embd = gg.tensor("token_embd.weight")?.dequantize(device)?;
        let embeddings = Embedding::new(tok_embd, embedding_length);

        // Output head
        let output = match gg.qmatmul("output.weight") {
            Ok(v) => v,
            _ => gg.qmatmul("token_embd.weight")?,
        };

        let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;

        let rotary = Arc::new(RotaryEmbedding::new(
            dtype,
            head_dim,
            context_length,
            rope_freq_base,
            device,
        )?);

        let mut layers = Vec::with_capacity(block_count);
        for i in 0..block_count {
            let pfx = format!("blk.{i}");

            // Attention
            let q_proj = gg.qmatmul(&format!("{pfx}.attn_q.weight"))?;
            let k_proj = gg.qmatmul(&format!("{pfx}.attn_k.weight"))?;
            let v_proj = gg.qmatmul(&format!("{pfx}.attn_v.weight"))?;
            let o_proj = gg.qmatmul(&format!("{pfx}.attn_output.weight"))?;
            let q_norm = gg.rms_norm(&format!("{pfx}.attn_q_norm.weight"), rms_norm_eps)?;
            let k_norm = gg.rms_norm(&format!("{pfx}.attn_k_norm.weight"), rms_norm_eps)?;

            let attn = Attention {
                q_proj,
                k_proj,
                v_proj,
                o_proj,
                q_norm,
                k_norm,
                num_heads: head_count,
                num_kv_heads: head_count_kv,
                head_dim,
                rotary: rotary.clone(),
                kv_cache: ConcatKvCache::new(2),
            };

            let attn_norm = gg.rms_norm(&format!("{pfx}.attn_norm.weight"), rms_norm_eps)?;
            let ffn_norm = gg.rms_norm(&format!("{pfx}.ffn_norm.weight"), rms_norm_eps)?;

            // MLP: MoE or dense
            let mlp = if num_experts > 0 {
                // Gate router
                let gate_ws = gg
                    .tensor(&format!("{pfx}.ffn_gate_inp.weight"))?
                    .dequantize(device)?
                    .to_dtype(DType::F32)?;
                let gate = Linear::new(gate_ws, None);

                // Packed expert weights as QTensor (stacked: [num_experts, size, hidden])
                let gate_experts = Arc::new(gg.tensor(&format!("{pfx}.ffn_gate_exps.weight"))?);
                let up_experts = Arc::new(gg.tensor(&format!("{pfx}.ffn_up_exps.weight"))?);
                let down_experts = Arc::new(gg.tensor(&format!("{pfx}.ffn_down_exps.weight"))?);

                MoeOrMlp::Moe(NaiveMoe {
                    gate,
                    gate_experts,
                    up_experts,
                    down_experts,
                    cache: None,
                    num_experts,
                    num_experts_per_tok,
                    norm_topk_prob: true,
                })
            } else {
                let gate = gg.qmatmul(&format!("{pfx}.ffn_gate.weight"))?;
                let up = gg.qmatmul(&format!("{pfx}.ffn_up.weight"))?;
                let down = gg.qmatmul(&format!("{pfx}.ffn_down.weight"))?;
                MoeOrMlp::Mlp(Mlp { gate, up, down })
            };

            layers.push(Layer {
                attn,
                attn_norm,
                mlp,
                ffn_norm,
            });
        }

        Ok(Self {
            embeddings,
            layers,
            norm,
            output,
            dtype,
            device: device.clone(),
        })
    }

    pub fn forward(&mut self, x: &Tensor, offset: usize) -> Result<Tensor> {
        let mut xs = self.embeddings.forward(x)?;
        let (b, l) = x.dims2()?;

        let mask = if l == 1 {
            None
        } else {
            Some(self.causal_mask(b, l, offset)?)
        };

        for layer in self.layers.iter_mut() {
            let residual = xs.clone();
            let x = layer.attn_norm.forward(&xs)?;
            let x = layer.attn.forward(&x, mask.as_ref(), offset)?;
            let x = (x + residual)?;

            let residual = x.clone();
            let ffn_in = layer.ffn_norm.forward(&x)?;
            // Reshape for MoE: [batch, seq, hidden] → [batch*seq, hidden] → back
            let (fb, fl, fh) = ffn_in.dims3()?;
            let ffn_flat = ffn_in.reshape((fb * fl, fh))?;
            let ffn_out = layer.mlp.forward(&ffn_flat)?;
            let ffn_out = ffn_out.reshape((fb, fl, fh))?;
            xs = (ffn_out + residual)?;
        }

        let xs = xs.narrow(1, l - 1, 1)?;
        let xs = self.norm.forward(&xs)?;
        self.output.forward(&xs)?.to_dtype(DType::F32)?.squeeze(1)
    }

    pub fn clear_kv_cache(&mut self) {
        // ConcatKvCache doesn't expose clear, so we rebuild
        for layer in &mut self.layers {
            layer.attn.kv_cache = ConcatKvCache::new(2);
        }
    }

    fn causal_mask(&self, b: usize, tgt: usize, offset: usize) -> Result<Tensor> {
        let minf = f32::NEG_INFINITY;
        let mask: Vec<f32> = (0..tgt)
            .flat_map(|i| {
                (0..(tgt + offset)).map(move |j| if j <= i + offset { 0.0 } else { minf })
            })
            .collect();
        Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
    }
}