Skip to main content

ferrum_models/models/qwen3_moe/
api.rs

1use super::*;
2
3impl<B: MoeLlmBackend + BackendPagedKv, K: KvDtypeKind> DecoderOnlyLLM for Qwen3MoeModel<B, K> {
4    fn config(&self) -> &LlmRuntimeConfig {
5        &self.runtime_cfg
6    }
7
8    fn cache_metrics_snapshot(&self) -> Option<serde_json::Value> {
9        Some(self.prefix_cache_snapshot_json())
10    }
11
12    fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
13        // Eager scratch + KV cache grow + a 1-token forward warmup so
14        // the first real prefill / decode doesn't pay the cold-start
15        // ~25-MTLBuffer scratch alloc + ~96-MTLBuffer KV alloc + Metal
16        // pipeline-state first-bind costs (~265 ms total on Qwen3-MoE
17        // 30B-A3B / M1 Max). Mirrors what llama-bench's --warmup does
18        // (which runs a same-shape forward before the timer).
19        self.ensure_scratch(max_tokens);
20        self.ensure_kv(cache_id);
21
22        // Warmup forward through all 48 layers under a scratch cache_id
23        // so the real `cache_id` starts at pos_offset=0. Token 0 is
24        // valid for any tokenizer (BOS or pad).
25        const WARMUP_CACHE: &str = "__ferrum_warmup__";
26        let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
27        // Drop the warmup KV cache slot — real cache_id is unaffected.
28        if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
29            let paged_cache = caches
30                .first()
31                .is_some_and(|cache| cache.block_table.is_some());
32            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
33                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
34                if let Some(c0) = caches.first() {
35                    if !c0.paged_block_indices.is_empty() {
36                        alloc.free(&c0.paged_block_indices);
37                    }
38                }
39                for c in caches.iter_mut() {
40                    c.paged_block_indices.clear();
41                }
42            }
43            if !paged_cache {
44                self.kv_free_pool.push(caches);
45            }
46        }
47    }
48
49    fn kv_capacity(&self) -> usize {
50        // Mirror the bound `ensure_kv` will use when allocating the cache.
51        let model_max = self.cfg.base.max_seq_len;
52        self.runtime_env.kv_capacity(model_max)
53    }
54
55    fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
56        self.prefill_internal(cache_id, tokens)
57    }
58
59    fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
60        self.decode_internal(cache_id, token, pos)
61    }
62
    // decode_batch is gated to use the batched path only when it's a
    // measurable win (the gate itself lives in
    // `decode_batch_with_full_logits`). The crossover depends on the
    // batch size M = batch.len():
    //
    //   - At low M (≤ ~8) the per-item `decode_internal` loop wins
    //     because: (a) it stays at scratch offset 0 (no copy_slice
    //     overhead), (b) it preserves the cross-layer rms_norm fusion
    //     fast path (`weighted_sum_residual_norm_stacked`).
    //   - At high M (≥ ~12) the batched path wins because the dense
    //     GEMM batching (qkv_proj, o_proj, router, lm_head at m=M) and
    //     the prefill-batched MoE dispatch (one `gemm_quant_moe_id` for
    //     all tokens) amortise the ~48-dispatch lost-fusion penalty.
    //
    // Default ON in 0.7.2+. On CUDA with paged KV + vLLM MoE, the
    // crossover is now M=4: 2026-05-28/29 Vast RTX 4090 random-256/128
    // probes saw the old threshold=8 stay on sequential per-token decode
    // (~89-122 tok/s), while threshold=4 measured 425.6 ± 36.6 tok/s.
    // `FERRUM_MOE_BATCHED=0` forces the legacy loop;
    // `FERRUM_MOE_BATCH_THRESHOLD` remains an escape hatch for future
    // hardware/backends.
82    fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
83        self.decode_batch_with_full_logits(batch, false)
84    }
85
86    fn decode_batch_with_full_logits(
87        &mut self,
88        batch: &[(String, u32, u32)],
89        force_full_logits: bool,
90    ) -> Vec<Vec<f32>> {
91        let m = batch.len();
92        let opted_in = self.runtime_env.moe_batched_enabled;
93        let threshold = self.runtime_env.moe_batch_threshold;
94        if opted_in && m >= threshold {
95            self.decode_batch_internal_with_full_logits(batch, force_full_logits)
96        } else {
97            batch
98                .iter()
99                .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
100                .collect()
101        }
102    }
103
104    fn unified_forward(
105        &mut self,
106        items: &[(String, Vec<u32>, usize, bool)],
107    ) -> std::result::Result<Vec<Option<Vec<f32>>>, FerrumError> {
108        if items.is_empty() {
109            return Ok(Vec::new());
110        }
111        if self.runtime_env.qwen_unified_trace {
112            let lens: Vec<usize> = items.iter().map(|it| it.1.len()).collect();
113            let positions: Vec<usize> = items.iter().map(|it| it.2).collect();
114            let finals: Vec<bool> = items.iter().map(|it| it.3).collect();
115            eprintln!(
116                "[qwen-unified] items={} lens={:?} positions={:?} finals={:?} use_vllm_paged_attn={}",
117                items.len(),
118                lens,
119                positions,
120                finals,
121                self.use_vllm_paged_attn
122            );
123        }
124        if !self.supports_varlen_qkv {
125            return Err(FerrumError::unsupported(
126                "Qwen3MoeModel::unified_forward: backend lacks varlen QKV kernels. \
127                 Engine will fall back to legacy paths.",
128            ));
129        }
130        // Pure-decode shortcut: every item is q_len=1 + is_final_chunk.
131        // For this shape, ferrum's legacy `forward_layer_batched_decode`
132        // path (with FERRUM_MOE_GRAPH=1 graph capture + decode-tuned
133        // moe_forward_stacked) is faster than our generic varlen +
134        // bucketed-MoE unified path. Returning Unsupported routes the
135        // engine to the legacy decode_batch path via LlmExecutor's
136        // fallback partition.
137        let all_decode = items.iter().all(|it| it.1.len() == 1 && it.3);
138        if all_decode {
139            return Err(FerrumError::unsupported(
140                "Qwen3MoeModel::unified_forward: pure-decode batch — \
141                 routed to legacy decode_batch (faster for q_len=1)",
142            ));
143        }
144        if items.len() == 1 && items[0].1.len() > 1 {
145            return Err(FerrumError::unsupported(
146                "Qwen3MoeModel::unified_forward: single-seq prefill — \
147                 routed to specialized prefill path",
148            ));
149        }
150        if !self.runtime_env.qwen_unified_prefill && items.iter().any(|it| it.1.len() > 1) {
151            return Err(FerrumError::unsupported(
152                "Qwen3MoeModel::unified_forward: prefill disabled by \
153                 FERRUM_QWEN_UNIFIED_PREFILL=0",
154            ));
155        }
156        // Any prefill chunk (q_len > 1) OR non-final-chunk item:
157        // unified path wins by collapsing N serial prefills into one
158        // [M_total, hidden] forward.
159        if self.paged_pools.is_none() {
160            return Err(FerrumError::unsupported(
161                "Qwen3MoeModel::unified_forward: paged KV required \
162                 (set FERRUM_METAL_PAGED_KV=1).",
163            ));
164        }
165        let m_total: usize = items.iter().map(|it| it.1.len()).sum();
166        if m_total > self.scratch.max_tokens {
167            return Err(FerrumError::unsupported(format!(
168                "Qwen3MoeModel::unified_forward: m_total={} > scratch.max_tokens={}",
169                m_total, self.scratch.max_tokens,
170            )));
171        }
172        Ok(self.unified_forward_internal(items))
173    }
174
175    fn release(&mut self, cache_id: &str) {
176        // Mirror LlamaFamilyModel::release — do NOT reset the captured
177        // graphs here. Graphs reference paged_pool addresses (model-
178        // level + stable) and paged_batch_* scratch addresses (also
179        // model-level + stable); the per-cache_id state (paged_block_
180        // indices) lives in `kv_caches` and never appears in graph
181        // node args. Wiping graphs on release would invalidate them
182        // mid-flight (a release between capture and the next replay
183        // → CUDA_ERROR_INVALID_VALUE on cuGraphLaunch).
184        let mut ctx = B::new_context();
185        B::sync(&mut ctx);
186        if let Some(mut caches) = self.kv_caches.remove(cache_id) {
187            let paged_cache = caches
188                .first()
189                .is_some_and(|cache| cache.block_table.is_some());
190            // Paged mode: return the cache_id's blocks to the shared
191            // allocator so other sequences can reuse them. Without this,
192            // every request consumes max_blocks_per_seq blocks
193            // permanently — pool exhausts after FERRUM_PAGED_MAX_SEQS
194            // requests and subsequent ensure_kv panics with
195            // "scratch residual missing" (the cascade panic from a
196            // failed ensure_kv path leaving scratch poisoned).
197            if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
198                let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
199                if let Some(c0) = caches.first() {
200                    if !c0.paged_block_indices.is_empty() {
201                        alloc.free(&c0.paged_block_indices);
202                    }
203                }
204                for c in caches.iter_mut() {
205                    c.paged_block_indices.clear();
206                }
207            }
208            // In paged mode the cache metadata (block_table/context_lens)
209            // is tiny compared with the shared K/V pools. Reusing that
210            // metadata can leak stale per-request state across independent
211            // HTTP requests, producing empty completions or corrupted
212            // batched-decode output after the first request. Drop metadata
213            // after returning physical blocks; the next ensure_kv allocates
214            // fresh metadata.
215            if !paged_cache {
216                self.kv_free_pool.push(caches);
217            }
218            if paged_cache && self.kv_caches.is_empty() {
219                // Reset only when the model is idle. That prevents stale
220                // paged/unified/batched scratch from leaking into the next
221                // independent request while preserving active concurrent
222                // requests that still own model cache IDs.
223                self.reset();
224            }
225        }
226    }
227
228    fn reset(&mut self) {
229        let mut ctx = B::new_context();
230        B::sync(&mut ctx);
231        B::reset_all_graphs(&mut ctx);
232        self.batched_graph_keys_seen.clear();
233        self.batched_graph_warmup = 0;
234        self.batched_graph_failed = false;
235        B::sync(&mut ctx);
236        self.kv_caches.clear();
237        self.kv_free_pool.clear();
238        self.paged_pools = None;
239        self.paged_fa_pools = None;
240        self.paged_block_alloc = None;
241        self.paged_dims = None;
242        let initial_scratch_tokens = if self.supports_varlen_qkv {
243            self.runtime_env.initial_scratch_tokens
244        } else {
245            1
246        };
247        self.scratch = Qwen3MoeScratch::alloc(&self.cfg, initial_scratch_tokens);
248    }
249}