1use crate::builder::build_llada2_forward_graph;
19use crate::capabilities::{default_memory_budget_bytes, validate_device};
20use crate::compile_util::{compile_llada2_built, llada2_profile};
21use crate::config::LLaDA2MoeConfig;
22use crate::gate_op::ensure_group_limited_gate_registered;
23use crate::load::load_llada2_from_dir;
24use crate::mask::block_diffusion_attention_mask;
25use crate::moe_offload::{self, MoeOffloadState};
26use crate::moe_store::{
27 apply_moe_store_to_compiled, build_moe_expert_store, moe_host_bind_from_store,
28};
29use crate::sampling::sample_logits;
30use crate::tide::{
31 BlockDenoiseConfig, BlockDenoiseLoop, BlockDenoiseSampler, BlockDiffusionForward,
32 BlockForwardOutput, DenoiseStepCtx, GenerateConfig, run_block_diffusion,
33};
34use crate::weights::LLaDA2Weights;
35use anyhow::{Result, anyhow};
36use rlx_core::flow_util::built_from_graph;
37use rlx_runtime::{CompiledGraph, Device, MoeExpertStore, MoeResidencyStats};
38
39fn push_moe_residency(compiled: &mut CompiledGraph, layers: &[Vec<bool>]) {
40 let refs: Vec<&[bool]> = layers.iter().map(|m| m.as_slice()).collect();
41 compiled.set_moe_resident_experts_per_layer(&refs);
42}
43
44#[derive(Default)]
45pub struct LLaDA2RunnerBuilder {
46 inline: Option<(LLaDA2MoeConfig, LLaDA2Weights)>,
47 weights_path: Option<std::path::PathBuf>,
48 device: Option<Device>,
49 batch: usize,
50 max_seq: Option<usize>,
51 max_gpu_experts_per_layer: Option<usize>,
52 memory_budget_bytes: Option<usize>,
53 jump_steps: Option<usize>,
54 reserve_vram_gb: f64,
55 moe_collect_stats: bool,
56}
57
58impl LLaDA2RunnerBuilder {
59 pub fn inline_weights(mut self, cfg: LLaDA2MoeConfig, weights: LLaDA2Weights) -> Self {
60 self.inline = Some((cfg, weights));
61 self
62 }
63
64 pub fn weights_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
65 self.weights_path = Some(path.into());
66 self
67 }
68
69 pub fn device(mut self, device: Device) -> Self {
70 self.device = Some(device);
71 self
72 }
73
74 pub fn batch_seq(mut self, batch: usize, max_seq: usize) -> Self {
75 self.batch = batch.max(1);
76 self.max_seq = Some(max_seq.max(1));
77 self
78 }
79
80 pub fn enable_predictive_expert_offload(mut self, max_per_layer: usize) -> Self {
81 self.max_gpu_experts_per_layer = Some(max_per_layer);
82 self
83 }
84
85 pub fn jump_steps(mut self, n: usize) -> Self {
86 self.jump_steps = Some(n);
87 self
88 }
89
90 pub fn reserve_vram_gb(mut self, gb: f64) -> Self {
91 self.reserve_vram_gb = gb;
92 self
93 }
94
95 pub fn moe_collect_stats(mut self, on: bool) -> Self {
96 self.moe_collect_stats = on;
97 self
98 }
99
100 pub fn memory_budget_bytes(mut self, bytes: usize) -> Self {
101 self.memory_budget_bytes = Some(bytes);
102 self
103 }
104
105 pub fn build(self) -> Result<LLaDA2Runner> {
106 ensure_group_limited_gate_registered();
107
108 let (cfg, weights) = match self.inline {
109 Some(x) => x,
110 None => {
111 let path = self.weights_path.as_ref().ok_or_else(|| {
112 anyhow!("LLaDA2Runner: weights_path or inline_weights required")
113 })?;
114 load_llada2_from_dir(path.as_path())?
115 }
116 };
117
118 let device = self.device.unwrap_or(Device::Cpu);
119 validate_device(&cfg, device)?;
120 let batch = self.batch.max(1);
121 let seq = self.max_seq.unwrap_or(128).max(1);
122
123 let (graph, params) = build_llada2_forward_graph(&cfg, &weights, batch, seq)?;
124 let mut built = built_from_graph(graph, params)?;
125 built.profile = llada2_profile();
126 let mut compiled = compile_llada2_built(built, device)?;
127
128 let moe_store = if cfg.num_experts > 0 {
129 Some(build_moe_expert_store(&cfg, &weights)?)
130 } else {
131 None
132 };
133
134 let mem_budget = self
135 .memory_budget_bytes
136 .or_else(|| default_memory_budget_bytes(device));
137
138 let moe = moe_offload::build_moe_offload(
139 &cfg,
140 &weights,
141 device,
142 self.max_gpu_experts_per_layer,
143 mem_budget,
144 self.jump_steps,
145 self.reserve_vram_gb,
146 self.moe_collect_stats,
147 );
148
149 if let Some(mo) = &moe {
150 push_moe_residency(&mut compiled, &mo.per_layer_resident_masks());
151 compiled.enable_moe_topk_capture(cfg.num_experts);
152 if let Some(store) = &moe_store {
153 apply_moe_store_to_compiled(store, &mut compiled);
154 }
155 }
156
157 Ok(LLaDA2Runner {
158 cfg,
159 weights,
160 compiled,
161 device,
162 batch,
163 seq,
164 block_length: 32,
165 moe,
166 moe_store,
167 })
168 }
169}
170
171pub struct LLaDA2Runner {
172 pub cfg: LLaDA2MoeConfig,
173 pub weights: LLaDA2Weights,
174 compiled: CompiledGraph,
175 device: Device,
176 batch: usize,
177 seq: usize,
178 block_length: usize,
179 moe: Option<MoeOffloadState>,
180 moe_store: Option<MoeExpertStore>,
181}
182
183impl LLaDA2Runner {
184 pub fn builder() -> LLaDA2RunnerBuilder {
185 LLaDA2RunnerBuilder::default()
186 }
187
188 pub fn config(&self) -> &LLaDA2MoeConfig {
189 &self.cfg
190 }
191
192 pub fn device(&self) -> Device {
193 self.device
194 }
195
196 pub fn max_seq(&self) -> usize {
197 self.seq
198 }
199
200 pub fn predictive_offload_enabled(&self) -> bool {
201 self.moe.as_ref().is_some_and(|m| m.predictive_enabled)
202 }
203
204 pub fn jump_steps(&self) -> usize {
205 self.moe.as_ref().map(|m| m.jump_steps).unwrap_or(1)
206 }
207
208 pub fn predictive_offload_info(&self) -> Option<crate::tide::PredictiveOffloadInfo> {
209 self.moe.as_ref().map(|m| m.info.clone())
210 }
211
212 pub fn moe_offload(&self) -> Option<&MoeOffloadState> {
213 self.moe.as_ref()
214 }
215
216 pub fn moe_store(&self) -> Option<&MoeExpertStore> {
217 self.moe_store.as_ref()
218 }
219
220 pub fn sync_moe_residency(&self, compiled: &mut CompiledGraph) {
221 if let Some(mo) = &self.moe {
222 push_moe_residency(compiled, &mo.per_layer_resident_masks());
223 if let Some(store) = &self.moe_store {
224 apply_moe_store_to_compiled(store, compiled);
225 }
226 }
227 }
228
229 fn bind_moe_host_weights(&self) {
230 if self.moe.is_none() {
231 rlx_cpu::moe_residency::bind_host_weights(None);
232 return;
233 }
234 if let Some(store) = &self.moe_store {
235 rlx_cpu::moe_residency::bind_host_weights(Some(moe_host_bind_from_store(store)));
236 } else {
237 rlx_cpu::moe_residency::bind_host_weights(None);
238 }
239 }
240
241 fn refresh_moe_after_forward(&mut self, step_ctx: DenoiseStepCtx, want_refresh: bool) {
242 let Some(layers) = self.compiled.take_moe_topk_capture() else {
243 return;
244 };
245 let Some(mo) = self.moe.as_mut() else {
246 return;
247 };
248 let is_prefill = step_ctx.num_block == step_ctx.prefill_blocks;
249 if !want_refresh || !mo.should_refresh_forward(step_ctx.denoise_step, is_prefill) {
250 return;
251 }
252 let refreshed = if let Some(store) = self.moe_store.as_ref() {
253 mo.refresh_from_capture_with_store(store, &layers, step_ctx.denoise_step, is_prefill)
254 } else {
255 mo.refresh_from_capture(&layers, step_ctx.denoise_step, is_prefill)
256 };
257 if refreshed {
258 let masks = mo.per_layer_resident_masks();
259 push_moe_residency(&mut self.compiled, &masks);
260 if let Some(store) = &self.moe_store {
261 apply_moe_store_to_compiled(store, &mut self.compiled);
262 }
263 }
264 }
265
266 fn forward_window_padded(
267 &mut self,
268 tokens: &[u32],
269 window_len: usize,
270 attn_mask: &[f32],
271 position_ids: &[f32],
272 step_ctx: DenoiseStepCtx,
273 want_refresh: bool,
274 ) -> Result<Vec<f32>> {
275 let b = self.batch;
276 let s = self.seq;
277 let w = window_len.min(tokens.len()).min(s);
278 let mut ids = vec![0f32; b * s];
279 let mut pos = vec![0f32; b * s];
280 for i in 0..w {
281 ids[i] = tokens[i] as f32;
282 pos[i] = position_ids.get(i).copied().unwrap_or(i as f32);
283 }
284 let mut full_mask = vec![f32::NEG_INFINITY; b * s * s];
285 for r in 0..w {
286 for c in 0..w {
287 full_mask[r * s + c] = attn_mask[r * w + c];
288 }
289 }
290 let logits = self.forward_logits(&ids, &pos, &full_mask)?;
291 self.refresh_moe_after_forward(step_ctx, want_refresh);
292 Ok(logits)
293 }
294
295 pub fn forward_logits(
296 &mut self,
297 input_ids: &[f32],
298 position_ids: &[f32],
299 attn_mask: &[f32],
300 ) -> Result<Vec<f32>> {
301 let b = self.batch;
302 let s = self.seq;
303 if input_ids.len() != b * s {
304 return Err(anyhow!("input_ids len {} != {b}*{s}", input_ids.len()));
305 }
306 if attn_mask.len() != b * s * s {
307 return Err(anyhow!(
308 "attn_mask len {} != {b}*1*{s}*{s}",
309 attn_mask.len()
310 ));
311 }
312 self.bind_moe_host_weights();
313 let outs = self.compiled.run(&[
314 ("input_ids", input_ids),
315 ("position_ids", position_ids),
316 ("attn_mask", attn_mask),
317 ]);
318 Ok(outs.into_iter().next().unwrap_or_default())
319 }
320
321 pub fn block_denoise_loop(
322 &mut self,
323 cfg: BlockDenoiseConfig,
324 ) -> BlockDenoiseLoop<LLaDA2RunnerForward<'_>> {
325 self.block_length = cfg.block_length;
326 let model_cfg = self.cfg.clone();
327 BlockDenoiseLoop::new(cfg, model_cfg, LLaDA2RunnerForward { runner: self })
328 }
329
330 pub fn get_offload_stats(&mut self) -> crate::tide::TideOffloadStats {
331 let residency = self
332 .compiled
333 .take_moe_residency_stats()
334 .or_else(rlx_cpu::moe_residency::peek_stats);
335 let residency_ref = residency.as_ref();
336 self.offload_stats(residency_ref).unwrap_or_default()
337 }
338
339 pub fn reset_offload_step_stats(&mut self) {
340 if let Some(mo) = self.moe.as_mut() {
341 for pool in &mut mo.pools {
342 pool.reset_step_stats();
343 }
344 }
345 let _ = self.compiled.take_moe_residency_stats();
346 }
347
348 pub fn generate(
349 &mut self,
350 gen_cfg: &GenerateConfig,
351 prompt_ids: &[u32],
352 ) -> Result<(Vec<u32>, Vec<crate::tide::BlockDenoiseStepStats>)> {
353 let max_window = (prompt_ids.len() + gen_cfg.gen_length).div_ceil(gen_cfg.block_length)
354 * gen_cfg.block_length;
355 if max_window > self.seq {
356 return Err(anyhow!(
357 "generate needs max_seq >= {max_window} (set .batch_seq(batch, max_seq) on builder)"
358 ));
359 }
360 let cfg = self.cfg.clone();
361 let collect = gen_cfg.collect_stats;
362 run_block_diffusion(self, &cfg, gen_cfg, prompt_ids, |runner| {
363 let stats = runner.get_offload_stats();
364 if collect {
365 runner.reset_offload_step_stats();
366 }
367 stats
368 })
369 }
370
371 pub fn offload_stats(
372 &self,
373 residency: Option<&MoeResidencyStats>,
374 ) -> Option<crate::tide::TideOffloadStats> {
375 self.moe
376 .as_ref()
377 .map(|m| moe_offload::tide_stats(m, residency))
378 }
379}
380
381impl BlockDenoiseSampler for LLaDA2Runner {
382 fn sample_block(
383 &mut self,
384 x: &[u32],
385 window_end: usize,
386 block_length: usize,
387 refresh_experts: bool,
388 gen_cfg: &GenerateConfig,
389 model_cfg: &LLaDA2MoeConfig,
390 step_ctx: DenoiseStepCtx,
391 ) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
392 let mask = block_diffusion_attention_mask(1, window_end, block_length);
393 let position_ids: Vec<f32> = (0..window_end).map(|i| i as f32).collect();
394 let logits = self.forward_window_padded(
395 &x[..window_end],
396 window_end,
397 &mask,
398 &position_ids,
399 step_ctx,
400 refresh_experts,
401 )?;
402 let block_start = window_end.saturating_sub(block_length);
403 let vocab = model_cfg.vocab_size;
404 let mut x0 = vec![0u32; block_length];
405 let mut x0_p = vec![0f32; block_length];
406 for i in 0..block_length {
407 let pos = block_start + i;
408 if pos >= window_end {
409 x0[i] = gen_cfg.mask_id;
410 x0_p[i] = 0.0;
411 continue;
412 }
413 let base = pos * vocab;
414 let (tok, prob) = sample_logits(
415 &logits[base..base + vocab],
416 gen_cfg.temperature,
417 gen_cfg.top_k,
418 gen_cfg.top_p,
419 gen_cfg.do_sample,
420 );
421 x0[i] = tok;
422 x0_p[i] = prob;
423 }
424 Ok((x0, x0_p))
425 }
426}
427
428pub struct LLaDA2RunnerForward<'a> {
429 pub runner: &'a mut LLaDA2Runner,
430}
431
432impl BlockDiffusionForward for LLaDA2RunnerForward<'_> {
433 fn forward_block(
434 &mut self,
435 token_ids: &[u32],
436 seq_len: usize,
437 refresh_experts: bool,
438 ) -> Result<BlockForwardOutput, anyhow::Error> {
439 let b = self.runner.batch;
440 let s = self.runner.seq;
441 let block = self.runner.block_length;
442 let window = seq_len.min(token_ids.len()).min(s);
443 let block_start = window.saturating_sub(block);
444
445 let mut ids = vec![0f32; b * s];
446 let mut pos = vec![0f32; b * s];
447 for i in 0..window {
448 ids[i] = token_ids[i] as f32;
449 pos[i] = i as f32;
450 }
451
452 let mask = block_diffusion_attention_mask(b, window, block);
453 let position_ids: Vec<f32> = (0..window).map(|i| i as f32).collect();
454 let step_ctx = DenoiseStepCtx {
455 num_block: 0,
456 prefill_blocks: 0,
457 denoise_step: 0,
458 };
459 let logits = self.runner.forward_window_padded(
460 &token_ids[..window],
461 window,
462 &mask,
463 &position_ids,
464 step_ctx,
465 refresh_experts,
466 )?;
467
468 let vocab = self.runner.cfg.vocab_size;
469 let mut x0 = Vec::with_capacity(block);
470 let mut x0_p = Vec::with_capacity(block);
471 for i in 0..block {
472 let tok_pos = block_start + i;
473 if tok_pos >= window {
474 x0.push(self.runner.cfg.mask_token_id);
475 x0_p.push(0.0);
476 continue;
477 }
478 let base = tok_pos * vocab;
479 if base + vocab > logits.len() {
480 break;
481 }
482 let (tok, conf) = sample_logits(&logits[base..base + vocab], 0.0, None, None, false);
483 x0.push(tok);
484 x0_p.push(conf);
485 }
486 Ok(BlockForwardOutput { x0, x0_p })
487 }
488}