Skip to main content

rlx_llada2/llada2/
runner.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

// RLX — LLaDA2 MoE runner (block diffusion + TIDE offload, all backends).

18use crate::builder::build_llada2_forward_graph;
19use crate::capabilities::{default_memory_budget_bytes, validate_device};
20use crate::compile_util::{compile_llada2_built, llada2_profile};
21use crate::config::LLaDA2MoeConfig;
22use crate::gate_op::ensure_group_limited_gate_registered;
23use crate::load::load_llada2_from_dir;
24use crate::mask::block_diffusion_attention_mask;
25use crate::moe_offload::{self, MoeOffloadState};
26use crate::moe_store::{
27    apply_moe_store_to_compiled, build_moe_expert_store, moe_host_bind_from_store,
28};
29use crate::sampling::sample_logits;
30use crate::tide::{
31    BlockDenoiseConfig, BlockDenoiseLoop, BlockDenoiseSampler, BlockDiffusionForward,
32    BlockForwardOutput, DenoiseStepCtx, GenerateConfig, run_block_diffusion,
33};
34use crate::weights::LLaDA2Weights;
35use anyhow::{Result, anyhow};
36use rlx_core::flow_util::built_from_graph;
37use rlx_runtime::{CompiledGraph, Device, MoeExpertStore, MoeResidencyStats};
38
39fn push_moe_residency(compiled: &mut CompiledGraph, layers: &[Vec<bool>]) {
40    let refs: Vec<&[bool]> = layers.iter().map(|m| m.as_slice()).collect();
41    compiled.set_moe_resident_experts_per_layer(&refs);
42}
43
44#[derive(Default)]
45pub struct LLaDA2RunnerBuilder {
46    inline: Option<(LLaDA2MoeConfig, LLaDA2Weights)>,
47    weights_path: Option<std::path::PathBuf>,
48    device: Option<Device>,
49    batch: usize,
50    max_seq: Option<usize>,
51    max_gpu_experts_per_layer: Option<usize>,
52    memory_budget_bytes: Option<usize>,
53    jump_steps: Option<usize>,
54    reserve_vram_gb: f64,
55    moe_collect_stats: bool,
56}
57
58impl LLaDA2RunnerBuilder {
59    pub fn inline_weights(mut self, cfg: LLaDA2MoeConfig, weights: LLaDA2Weights) -> Self {
60        self.inline = Some((cfg, weights));
61        self
62    }
63
64    pub fn weights_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
65        self.weights_path = Some(path.into());
66        self
67    }
68
69    pub fn device(mut self, device: Device) -> Self {
70        self.device = Some(device);
71        self
72    }
73
74    pub fn batch_seq(mut self, batch: usize, max_seq: usize) -> Self {
75        self.batch = batch.max(1);
76        self.max_seq = Some(max_seq.max(1));
77        self
78    }
79
80    pub fn enable_predictive_expert_offload(mut self, max_per_layer: usize) -> Self {
81        self.max_gpu_experts_per_layer = Some(max_per_layer);
82        self
83    }
84
85    pub fn jump_steps(mut self, n: usize) -> Self {
86        self.jump_steps = Some(n);
87        self
88    }
89
90    pub fn reserve_vram_gb(mut self, gb: f64) -> Self {
91        self.reserve_vram_gb = gb;
92        self
93    }
94
95    pub fn moe_collect_stats(mut self, on: bool) -> Self {
96        self.moe_collect_stats = on;
97        self
98    }
99
100    pub fn memory_budget_bytes(mut self, bytes: usize) -> Self {
101        self.memory_budget_bytes = Some(bytes);
102        self
103    }
104
105    pub fn build(self) -> Result<LLaDA2Runner> {
106        ensure_group_limited_gate_registered();
107
108        let (cfg, weights) = match self.inline {
109            Some(x) => x,
110            None => {
111                let path = self.weights_path.as_ref().ok_or_else(|| {
112                    anyhow!("LLaDA2Runner: weights_path or inline_weights required")
113                })?;
114                load_llada2_from_dir(path.as_path())?
115            }
116        };
117
118        let device = self.device.unwrap_or(Device::Cpu);
119        validate_device(&cfg, device)?;
120        let batch = self.batch.max(1);
121        let seq = self.max_seq.unwrap_or(128).max(1);
122
123        let (graph, params) = build_llada2_forward_graph(&cfg, &weights, batch, seq)?;
124        let mut built = built_from_graph(graph, params)?;
125        built.profile = llada2_profile();
126        let mut compiled = compile_llada2_built(built, device)?;
127
128        let moe_store = if cfg.num_experts > 0 {
129            Some(build_moe_expert_store(&cfg, &weights)?)
130        } else {
131            None
132        };
133
134        let mem_budget = self
135            .memory_budget_bytes
136            .or_else(|| default_memory_budget_bytes(device));
137
138        let moe = moe_offload::build_moe_offload(
139            &cfg,
140            &weights,
141            device,
142            self.max_gpu_experts_per_layer,
143            mem_budget,
144            self.jump_steps,
145            self.reserve_vram_gb,
146            self.moe_collect_stats,
147        );
148
149        if let Some(mo) = &moe {
150            push_moe_residency(&mut compiled, &mo.per_layer_resident_masks());
151            compiled.enable_moe_topk_capture(cfg.num_experts);
152            if let Some(store) = &moe_store {
153                apply_moe_store_to_compiled(store, &mut compiled);
154            }
155        }
156
157        Ok(LLaDA2Runner {
158            cfg,
159            weights,
160            compiled,
161            device,
162            batch,
163            seq,
164            block_length: 32,
165            moe,
166            moe_store,
167        })
168    }
169}
170
171pub struct LLaDA2Runner {
172    pub cfg: LLaDA2MoeConfig,
173    pub weights: LLaDA2Weights,
174    compiled: CompiledGraph,
175    device: Device,
176    batch: usize,
177    seq: usize,
178    block_length: usize,
179    moe: Option<MoeOffloadState>,
180    moe_store: Option<MoeExpertStore>,
181}
182
183impl LLaDA2Runner {
184    pub fn builder() -> LLaDA2RunnerBuilder {
185        LLaDA2RunnerBuilder::default()
186    }
187
188    pub fn config(&self) -> &LLaDA2MoeConfig {
189        &self.cfg
190    }
191
192    pub fn device(&self) -> Device {
193        self.device
194    }
195
196    pub fn max_seq(&self) -> usize {
197        self.seq
198    }
199
200    pub fn predictive_offload_enabled(&self) -> bool {
201        self.moe.as_ref().is_some_and(|m| m.predictive_enabled)
202    }
203
204    pub fn jump_steps(&self) -> usize {
205        self.moe.as_ref().map(|m| m.jump_steps).unwrap_or(1)
206    }
207
208    pub fn predictive_offload_info(&self) -> Option<crate::tide::PredictiveOffloadInfo> {
209        self.moe.as_ref().map(|m| m.info.clone())
210    }
211
212    pub fn moe_offload(&self) -> Option<&MoeOffloadState> {
213        self.moe.as_ref()
214    }
215
216    pub fn moe_store(&self) -> Option<&MoeExpertStore> {
217        self.moe_store.as_ref()
218    }
219
220    pub fn sync_moe_residency(&self, compiled: &mut CompiledGraph) {
221        if let Some(mo) = &self.moe {
222            push_moe_residency(compiled, &mo.per_layer_resident_masks());
223            if let Some(store) = &self.moe_store {
224                apply_moe_store_to_compiled(store, compiled);
225            }
226        }
227    }
228
229    fn bind_moe_host_weights(&self) {
230        if self.moe.is_none() {
231            rlx_cpu::moe_residency::bind_host_weights(None);
232            return;
233        }
234        if let Some(store) = &self.moe_store {
235            rlx_cpu::moe_residency::bind_host_weights(Some(moe_host_bind_from_store(store)));
236        } else {
237            rlx_cpu::moe_residency::bind_host_weights(None);
238        }
239    }
240
241    fn refresh_moe_after_forward(&mut self, step_ctx: DenoiseStepCtx, want_refresh: bool) {
242        let Some(layers) = self.compiled.take_moe_topk_capture() else {
243            return;
244        };
245        let Some(mo) = self.moe.as_mut() else {
246            return;
247        };
248        let is_prefill = step_ctx.num_block == step_ctx.prefill_blocks;
249        if !want_refresh || !mo.should_refresh_forward(step_ctx.denoise_step, is_prefill) {
250            return;
251        }
252        let refreshed = if let Some(store) = self.moe_store.as_ref() {
253            mo.refresh_from_capture_with_store(store, &layers, step_ctx.denoise_step, is_prefill)
254        } else {
255            mo.refresh_from_capture(&layers, step_ctx.denoise_step, is_prefill)
256        };
257        if refreshed {
258            let masks = mo.per_layer_resident_masks();
259            push_moe_residency(&mut self.compiled, &masks);
260            if let Some(store) = &self.moe_store {
261                apply_moe_store_to_compiled(store, &mut self.compiled);
262            }
263        }
264    }
265
266    fn forward_window_padded(
267        &mut self,
268        tokens: &[u32],
269        window_len: usize,
270        attn_mask: &[f32],
271        position_ids: &[f32],
272        step_ctx: DenoiseStepCtx,
273        want_refresh: bool,
274    ) -> Result<Vec<f32>> {
275        let b = self.batch;
276        let s = self.seq;
277        let w = window_len.min(tokens.len()).min(s);
278        let mut ids = vec![0f32; b * s];
279        let mut pos = vec![0f32; b * s];
280        for i in 0..w {
281            ids[i] = tokens[i] as f32;
282            pos[i] = position_ids.get(i).copied().unwrap_or(i as f32);
283        }
284        let mut full_mask = vec![f32::NEG_INFINITY; b * s * s];
285        for r in 0..w {
286            for c in 0..w {
287                full_mask[r * s + c] = attn_mask[r * w + c];
288            }
289        }
290        let logits = self.forward_logits(&ids, &pos, &full_mask)?;
291        self.refresh_moe_after_forward(step_ctx, want_refresh);
292        Ok(logits)
293    }
294
295    pub fn forward_logits(
296        &mut self,
297        input_ids: &[f32],
298        position_ids: &[f32],
299        attn_mask: &[f32],
300    ) -> Result<Vec<f32>> {
301        let b = self.batch;
302        let s = self.seq;
303        if input_ids.len() != b * s {
304            return Err(anyhow!("input_ids len {} != {b}*{s}", input_ids.len()));
305        }
306        if attn_mask.len() != b * s * s {
307            return Err(anyhow!(
308                "attn_mask len {} != {b}*1*{s}*{s}",
309                attn_mask.len()
310            ));
311        }
312        self.bind_moe_host_weights();
313        let outs = self.compiled.run(&[
314            ("input_ids", input_ids),
315            ("position_ids", position_ids),
316            ("attn_mask", attn_mask),
317        ]);
318        Ok(outs.into_iter().next().unwrap_or_default())
319    }
320
321    pub fn block_denoise_loop(
322        &mut self,
323        cfg: BlockDenoiseConfig,
324    ) -> BlockDenoiseLoop<LLaDA2RunnerForward<'_>> {
325        self.block_length = cfg.block_length;
326        let model_cfg = self.cfg.clone();
327        BlockDenoiseLoop::new(cfg, model_cfg, LLaDA2RunnerForward { runner: self })
328    }
329
330    pub fn get_offload_stats(&mut self) -> crate::tide::TideOffloadStats {
331        let residency = self
332            .compiled
333            .take_moe_residency_stats()
334            .or_else(rlx_cpu::moe_residency::peek_stats);
335        let residency_ref = residency.as_ref();
336        self.offload_stats(residency_ref).unwrap_or_default()
337    }
338
339    pub fn reset_offload_step_stats(&mut self) {
340        if let Some(mo) = self.moe.as_mut() {
341            for pool in &mut mo.pools {
342                pool.reset_step_stats();
343            }
344        }
345        let _ = self.compiled.take_moe_residency_stats();
346    }
347
348    pub fn generate(
349        &mut self,
350        gen_cfg: &GenerateConfig,
351        prompt_ids: &[u32],
352    ) -> Result<(Vec<u32>, Vec<crate::tide::BlockDenoiseStepStats>)> {
353        let max_window = (prompt_ids.len() + gen_cfg.gen_length).div_ceil(gen_cfg.block_length)
354            * gen_cfg.block_length;
355        if max_window > self.seq {
356            return Err(anyhow!(
357                "generate needs max_seq >= {max_window} (set .batch_seq(batch, max_seq) on builder)"
358            ));
359        }
360        let cfg = self.cfg.clone();
361        let collect = gen_cfg.collect_stats;
362        run_block_diffusion(self, &cfg, gen_cfg, prompt_ids, |runner| {
363            let stats = runner.get_offload_stats();
364            if collect {
365                runner.reset_offload_step_stats();
366            }
367            stats
368        })
369    }
370
371    pub fn offload_stats(
372        &self,
373        residency: Option<&MoeResidencyStats>,
374    ) -> Option<crate::tide::TideOffloadStats> {
375        self.moe
376            .as_ref()
377            .map(|m| moe_offload::tide_stats(m, residency))
378    }
379}
380
381impl BlockDenoiseSampler for LLaDA2Runner {
382    fn sample_block(
383        &mut self,
384        x: &[u32],
385        window_end: usize,
386        block_length: usize,
387        refresh_experts: bool,
388        gen_cfg: &GenerateConfig,
389        model_cfg: &LLaDA2MoeConfig,
390        step_ctx: DenoiseStepCtx,
391    ) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
392        let mask = block_diffusion_attention_mask(1, window_end, block_length);
393        let position_ids: Vec<f32> = (0..window_end).map(|i| i as f32).collect();
394        let logits = self.forward_window_padded(
395            &x[..window_end],
396            window_end,
397            &mask,
398            &position_ids,
399            step_ctx,
400            refresh_experts,
401        )?;
402        let block_start = window_end.saturating_sub(block_length);
403        let vocab = model_cfg.vocab_size;
404        let mut x0 = vec![0u32; block_length];
405        let mut x0_p = vec![0f32; block_length];
406        for i in 0..block_length {
407            let pos = block_start + i;
408            if pos >= window_end {
409                x0[i] = gen_cfg.mask_id;
410                x0_p[i] = 0.0;
411                continue;
412            }
413            let base = pos * vocab;
414            let (tok, prob) = sample_logits(
415                &logits[base..base + vocab],
416                gen_cfg.temperature,
417                gen_cfg.top_k,
418                gen_cfg.top_p,
419                gen_cfg.do_sample,
420            );
421            x0[i] = tok;
422            x0_p[i] = prob;
423        }
424        Ok((x0, x0_p))
425    }
426}
427
428pub struct LLaDA2RunnerForward<'a> {
429    pub runner: &'a mut LLaDA2Runner,
430}
431
432impl BlockDiffusionForward for LLaDA2RunnerForward<'_> {
433    fn forward_block(
434        &mut self,
435        token_ids: &[u32],
436        seq_len: usize,
437        refresh_experts: bool,
438    ) -> Result<BlockForwardOutput, anyhow::Error> {
439        let b = self.runner.batch;
440        let s = self.runner.seq;
441        let block = self.runner.block_length;
442        let window = seq_len.min(token_ids.len()).min(s);
443        let block_start = window.saturating_sub(block);
444
445        let mut ids = vec![0f32; b * s];
446        let mut pos = vec![0f32; b * s];
447        for i in 0..window {
448            ids[i] = token_ids[i] as f32;
449            pos[i] = i as f32;
450        }
451
452        let mask = block_diffusion_attention_mask(b, window, block);
453        let position_ids: Vec<f32> = (0..window).map(|i| i as f32).collect();
454        let step_ctx = DenoiseStepCtx {
455            num_block: 0,
456            prefill_blocks: 0,
457            denoise_step: 0,
458        };
459        let logits = self.runner.forward_window_padded(
460            &token_ids[..window],
461            window,
462            &mask,
463            &position_ids,
464            step_ctx,
465            refresh_experts,
466        )?;
467
468        let vocab = self.runner.cfg.vocab_size;
469        let mut x0 = Vec::with_capacity(block);
470        let mut x0_p = Vec::with_capacity(block);
471        for i in 0..block {
472            let tok_pos = block_start + i;
473            if tok_pos >= window {
474                x0.push(self.runner.cfg.mask_token_id);
475                x0_p.push(0.0);
476                continue;
477            }
478            let base = tok_pos * vocab;
479            if base + vocab > logits.len() {
480                break;
481            }
482            let (tok, conf) = sample_logits(&logits[base..base + vocab], 0.0, None, None, false);
483            x0.push(tok);
484            x0_p.push(conf);
485        }
486        Ok(BlockForwardOutput { x0, x0_p })
487    }
488}