ferrum_models/models/qwen3_moe/api.rs
1use super::*;
2
3impl<B: MoeLlmBackend + BackendPagedKv, K: KvDtypeKind> DecoderOnlyLLM for Qwen3MoeModel<B, K> {
4 fn config(&self) -> &LlmRuntimeConfig {
5 &self.runtime_cfg
6 }
7
8 fn prepare(&mut self, cache_id: &str, max_tokens: usize) {
9 // Eager scratch + KV cache grow + a 1-token forward warmup so
10 // the first real prefill / decode doesn't pay the cold-start
11 // ~25-MTLBuffer scratch alloc + ~96-MTLBuffer KV alloc + Metal
12 // pipeline-state first-bind costs (~265 ms total on Qwen3-MoE
13 // 30B-A3B / M1 Max). Mirrors what llama-bench's --warmup does
14 // (which runs a same-shape forward before the timer).
15 self.ensure_scratch(max_tokens);
16 self.ensure_kv(cache_id);
17
18 // Warmup forward through all 48 layers under a scratch cache_id
19 // so the real `cache_id` starts at pos_offset=0. Token 0 is
20 // valid for any tokenizer (BOS or pad).
21 const WARMUP_CACHE: &str = "__ferrum_warmup__";
22 let _ = self.prefill_internal(WARMUP_CACHE, &[0u32]);
23 // Drop the warmup KV cache slot — real cache_id is unaffected.
24 if let Some(mut caches) = self.kv_caches.remove(WARMUP_CACHE) {
25 let paged_cache = caches
26 .first()
27 .is_some_and(|cache| cache.block_table.is_some());
28 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
29 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
30 if let Some(c0) = caches.first() {
31 if !c0.paged_block_indices.is_empty() {
32 alloc.free(&c0.paged_block_indices);
33 }
34 }
35 for c in caches.iter_mut() {
36 c.paged_block_indices.clear();
37 }
38 }
39 if !paged_cache {
40 self.kv_free_pool.push(caches);
41 }
42 }
43 }
44
45 fn kv_capacity(&self) -> usize {
46 // Mirror the bound `ensure_kv` will use when allocating the cache.
47 let model_max = self.cfg.base.max_seq_len;
48 self.runtime_env.kv_capacity(model_max)
49 }
50
51 fn prefill(&mut self, cache_id: &str, tokens: &[u32]) -> Vec<f32> {
52 self.prefill_internal(cache_id, tokens)
53 }
54
55 fn decode(&mut self, cache_id: &str, token: u32, pos: u32) -> Vec<f32> {
56 self.decode_internal(cache_id, token, pos)
57 }
58
59 // decode_batch is gated to use the batched path only when it's a
60 // measurable win. The crossover depends on M:
61 //
62 // - At low M (≤ ~8) the per-item `decode_internal` loop wins
63 // because: (a) it stays at scratch offset 0 (no copy_slice
64 // overhead), (b) it preserves the cross-layer rms_norm fusion
65 // fast path (`weighted_sum_residual_norm_stacked`).
66 // - At high M (≥ ~12) the batched path wins because the dense
67 // GEMM batching (qkv_proj, o_proj, router, lm_head at m=M) and
68 // the prefill-batched MoE dispatch (one `gemm_quant_moe_id` for
69 // all tokens) amortise the ~48-dispatch lost-fusion penalty.
70 //
71 // Default ON in 0.7.2+. On CUDA with paged KV + vLLM MoE, the
72 // crossover is now M=4: 2026-05-28/29 Vast RTX 4090 random-256/128
73 // probes saw the old threshold=8 stay on sequential per-token decode
74 // (~89-122 tok/s), while threshold=4 measured 425.6 ± 36.6 tok/s.
75 // `FERRUM_MOE_BATCHED=0` forces the
76 // legacy loop; `FERRUM_MOE_BATCH_THRESHOLD` remains an escape hatch
77 // for future hardware/backends.
78 fn decode_batch(&mut self, batch: &[(String, u32, u32)]) -> Vec<Vec<f32>> {
79 let m = batch.len();
80 let opted_in = self.runtime_env.moe_batched_enabled;
81 let threshold = self.runtime_env.moe_batch_threshold;
82 if opted_in && m >= threshold {
83 self.decode_batch_internal(batch)
84 } else {
85 batch
86 .iter()
87 .map(|(cid, tok, p)| self.decode(cid, *tok, *p))
88 .collect()
89 }
90 }
91
92 fn unified_forward(
93 &mut self,
94 items: &[(String, Vec<u32>, usize, bool)],
95 ) -> std::result::Result<Vec<Option<Vec<f32>>>, FerrumError> {
96 if items.is_empty() {
97 return Ok(Vec::new());
98 }
99 if self.runtime_env.qwen_unified_trace {
100 let lens: Vec<usize> = items.iter().map(|it| it.1.len()).collect();
101 let positions: Vec<usize> = items.iter().map(|it| it.2).collect();
102 let finals: Vec<bool> = items.iter().map(|it| it.3).collect();
103 eprintln!(
104 "[qwen-unified] items={} lens={:?} positions={:?} finals={:?} use_vllm_paged_attn={}",
105 items.len(),
106 lens,
107 positions,
108 finals,
109 self.use_vllm_paged_attn
110 );
111 }
112 if !B::supports_varlen_qkv() {
113 return Err(FerrumError::unsupported(
114 "Qwen3MoeModel::unified_forward: backend lacks varlen QKV kernels. \
115 Engine will fall back to legacy paths.",
116 ));
117 }
118 // Pure-decode shortcut: every item is q_len=1 + is_final_chunk.
119 // For this shape, ferrum's legacy `forward_layer_batched_decode`
120 // path (with FERRUM_MOE_GRAPH=1 graph capture + decode-tuned
121 // moe_forward_stacked) is faster than our generic varlen +
122 // bucketed-MoE unified path. Returning Unsupported routes the
123 // engine to the legacy decode_batch path via LlmExecutor's
124 // fallback partition.
125 let all_decode = items.iter().all(|it| it.1.len() == 1 && it.3);
126 if all_decode {
127 return Err(FerrumError::unsupported(
128 "Qwen3MoeModel::unified_forward: pure-decode batch — \
129 routed to legacy decode_batch (faster for q_len=1)",
130 ));
131 }
132 if items.len() == 1 && items[0].1.len() > 1 {
133 return Err(FerrumError::unsupported(
134 "Qwen3MoeModel::unified_forward: single-seq prefill — \
135 routed to specialized prefill path",
136 ));
137 }
138 if !self.runtime_env.qwen_unified_prefill && items.iter().any(|it| it.1.len() > 1) {
139 return Err(FerrumError::unsupported(
140 "Qwen3MoeModel::unified_forward: prefill disabled by \
141 FERRUM_QWEN_UNIFIED_PREFILL=0",
142 ));
143 }
144 // Any prefill chunk (q_len > 1) OR non-final-chunk item:
145 // unified path wins by collapsing N serial prefills into one
146 // [M_total, hidden] forward.
147 if self.paged_pools.is_none() {
148 return Err(FerrumError::unsupported(
149 "Qwen3MoeModel::unified_forward: paged KV required \
150 (set FERRUM_METAL_PAGED_KV=1).",
151 ));
152 }
153 let m_total: usize = items.iter().map(|it| it.1.len()).sum();
154 if m_total > self.scratch.max_tokens {
155 return Err(FerrumError::unsupported(format!(
156 "Qwen3MoeModel::unified_forward: m_total={} > scratch.max_tokens={}",
157 m_total, self.scratch.max_tokens,
158 )));
159 }
160 Ok(self.unified_forward_internal(items))
161 }
162
163 fn release(&mut self, cache_id: &str) {
164 // Mirror LlamaFamilyModel::release — do NOT reset the captured
165 // graphs here. Graphs reference paged_pool addresses (model-
166 // level + stable) and paged_batch_* scratch addresses (also
167 // model-level + stable); the per-cache_id state (paged_block_
168 // indices) lives in `kv_caches` and never appears in graph
169 // node args. Wiping graphs on release would invalidate them
170 // mid-flight (a release between capture and the next replay
171 // → CUDA_ERROR_INVALID_VALUE on cuGraphLaunch).
172 let mut ctx = B::new_context();
173 B::sync(&mut ctx);
174 if let Some(mut caches) = self.kv_caches.remove(cache_id) {
175 let paged_cache = caches
176 .first()
177 .is_some_and(|cache| cache.block_table.is_some());
178 // Paged mode: return the cache_id's blocks to the shared
179 // allocator so other sequences can reuse them. Without this,
180 // every request consumes max_blocks_per_seq blocks
181 // permanently — pool exhausts after FERRUM_PAGED_MAX_SEQS
182 // requests and subsequent ensure_kv panics with
183 // "scratch residual missing" (the cascade panic from a
184 // failed ensure_kv path leaving scratch poisoned).
185 if let Some(alloc_arc) = self.paged_block_alloc.as_ref() {
186 let mut alloc = alloc_arc.lock().unwrap_or_else(|p| p.into_inner());
187 if let Some(c0) = caches.first() {
188 if !c0.paged_block_indices.is_empty() {
189 alloc.free(&c0.paged_block_indices);
190 }
191 }
192 for c in caches.iter_mut() {
193 c.paged_block_indices.clear();
194 }
195 }
196 // In paged mode the cache metadata (block_table/context_lens)
197 // is tiny compared with the shared K/V pools. Reusing that
198 // metadata on Metal GGUF MoE can leak stale per-request state
199 // across independent HTTP requests, producing empty completions
200 // or repeated `<think>` tokens after the first request. Drop it
201 // after returning physical blocks; the next ensure_kv allocates
202 // fresh metadata while reusing the shared pools.
203 if !paged_cache {
204 self.kv_free_pool.push(caches);
205 }
206 if paged_cache && self.runtime_env.paged_max_seqs <= 1 {
207 // Product Metal GGUF MoE serve currently uses one active
208 // paged sequence for correctness. In that mode there cannot
209 // be another live request sharing captured graph/model state,
210 // so reset graph/KV bookkeeping after each completed request
211 // to avoid stale paged state leaking into the next HTTP
212 // request. This keeps `ferrum serve` correct without asking
213 // users to manage env combinations.
214 self.reset();
215 }
216 }
217 }
218
219 fn reset(&mut self) {
220 let mut ctx = B::new_context();
221 B::sync(&mut ctx);
222 B::reset_all_graphs(&mut ctx);
223 self.batched_graph_keys_seen.clear();
224 self.batched_graph_warmup = 0;
225 self.batched_graph_failed = false;
226 B::sync(&mut ctx);
227 self.kv_caches.clear();
228 self.kv_free_pool.clear();
229 self.paged_pools = None;
230 self.paged_fa_pools = None;
231 self.paged_block_alloc = None;
232 self.paged_dims = None;
233 let initial_scratch_tokens = if B::supports_varlen_qkv() {
234 self.runtime_env.initial_scratch_tokens
235 } else {
236 1
237 };
238 self.scratch = Qwen3MoeScratch::alloc(&self.cfg, initial_scratch_tokens);
239 }
240}