Skip to main content

rlx_llada2/tide/
generate.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16// RLX — block diffusion `generate()` (TIDE `LLaDA2MoeModelLM.generate`).
17
18use crate::config::LLaDA2MoeConfig;
19use crate::mask::block_diffusion_attention_mask;
20use crate::sampling::sample_logits;
21use crate::tide::{TideOffloadStats, refresh_experts};
22
/// One denoise step record when `collect_stats` is enabled.
///
/// [`run_block_diffusion`] pushes one entry per executed denoise step (steps skipped because
/// the block was already fully decoded produce no entry).
#[derive(Debug, Clone, Default)]
pub struct BlockDenoiseStepStats {
    /// Index of the block being denoised, counted over the whole prompt + generation buffer.
    pub block: usize,
    /// Denoise step index within `block`, starting at 0.
    pub step: usize,
    /// Time spent in the `sample_block` call for this step, in milliseconds.
    pub elapsed_ms: f64,
    /// Number of positions in the block still equal to the mask token before this step.
    pub active_tokens: usize,
    /// Number of positions committed (unmasked) during this step.
    pub transferred_tokens: usize,
    /// Snapshot returned by the caller-supplied `offload_stats` callback after this step.
    pub offload_stats: TideOffloadStats,
}
33
/// Schedule of mask tokens to unmask per denoise step (TIDE `_get_num_transfer_tokens`).
///
/// Each of the `steps` entries receives `block_length / steps` tokens. The first
/// `block_length % steps` entries take one extra token, so the schedule always sums to
/// `block_length`. A zero `steps` yields an empty schedule rather than dividing by zero.
pub fn num_transfer_tokens_schedule(block_length: usize, steps: usize) -> Vec<usize> {
    match steps {
        0 => Vec::new(),
        _ => {
            let (per_step, extra) = (block_length / steps, block_length % steps);
            (0..steps)
                .map(|idx| if idx < extra { per_step + 1 } else { per_step })
                .collect()
        }
    }
}
47
/// Generation options matching PyTorch `LLaDA2MoeModelLM.generate`.
///
/// Build defaults with [`GenerateConfig::from_model`] and override individual fields as needed.
#[derive(Debug, Clone)]
pub struct GenerateConfig {
    /// Temperature forwarded to `sample_logits`.
    pub temperature: f32,
    /// Tokens per diffusion block; also the block size of the attention mask.
    pub block_length: usize,
    /// Maximum denoise steps per block (further capped at `gen_length / max(minimal_topk, 1)`).
    pub steps: usize,
    /// Maximum number of tokens returned after the prompt.
    pub gen_length: usize,
    /// Optional top-p cutoff forwarded to `sample_logits` (presumably nucleus filtering).
    pub top_p: Option<f32>,
    /// Optional top-k cutoff forwarded to `sample_logits`.
    pub top_k: Option<usize>,
    /// Return as soon as a committed EOS is preceded only by decoded (non-mask) tokens.
    pub eos_early_stop: bool,
    /// Divisor that caps `steps` at `gen_length / minimal_topk`; 0 is treated as 1.
    pub minimal_topk: usize,
    /// Confidence threshold for unmasking (`eval_dinfer.py` default 0.9).
    ///
    /// Masked positions whose sampled confidence is strictly greater than this are committed
    /// immediately, in addition to the per-step schedule quota.
    pub threshold: f32,
    /// End-of-sequence token id; the returned tokens stop after the first one generated.
    pub eos_id: u32,
    /// Mask token id; fills not-yet-decoded positions of the buffer.
    pub mask_id: u32,
    /// Forwarded to `sample_logits` (presumably toggles stochastic sampling vs. argmax).
    pub do_sample: bool,
    /// Forwarded to `tide::refresh_experts` to decide per-step expert refresh.
    pub predictive_offload_enabled: bool,
    /// Forwarded to `tide::refresh_experts` alongside `predictive_offload_enabled`.
    pub jump_steps: usize,
    /// Record a [`BlockDenoiseStepStats`] entry for every executed denoise step.
    pub collect_stats: bool,
}
68
69impl GenerateConfig {
70    pub fn from_model(cfg: &LLaDA2MoeConfig) -> Self {
71        Self {
72            temperature: 0.0,
73            block_length: 32,
74            steps: 32,
75            gen_length: 2048,
76            top_p: None,
77            top_k: None,
78            eos_early_stop: false,
79            minimal_topk: 1,
80            threshold: 0.9,
81            eos_id: cfg.eos_token_id,
82            mask_id: cfg.mask_token_id,
83            do_sample: false,
84            predictive_offload_enabled: false,
85            jump_steps: 1,
86            collect_stats: false,
87        }
88    }
89}
90
/// Model forward pass over a prefix window of the token buffer.
///
/// Any implementor automatically gets a [`BlockDenoiseSampler`] implementation through the
/// blanket impl in this file.
pub trait GenerateForward {
    /// Runs the model over `tokens` and returns the logits for every window position.
    ///
    /// As called by the blanket `BlockDenoiseSampler` impl:
    /// - `tokens` has exactly `window_len` entries.
    /// - `attn_mask` is `block_diffusion_attention_mask(1, window_len, block_length)`.
    /// - `position_ids` are `0.0, 1.0, …, (window_len - 1) as f32`.
    /// - The result is read as row-major `[window_len, vocab_size]`: row `pos` starts at
    ///   `pos * vocab_size`.
    ///
    /// `refresh_experts` is the TIDE expert-refresh decision for the current denoise step.
    fn forward_window(
        &mut self,
        tokens: &[u32],
        window_len: usize,
        attn_mask: &[f32],
        position_ids: &[f32],
        refresh_experts: bool,
    ) -> anyhow::Result<Vec<f32>>;
}
101
/// Per-denoise-step context for MoE expert refresh (TIDE `generate` loop).
///
/// Passed to [`BlockDenoiseSampler::sample_block`]. The blanket `GenerateForward`-based sampler
/// ignores it; custom samplers may use it.
#[derive(Debug, Clone, Copy, Default)]
pub struct DenoiseStepCtx {
    /// Index of the block currently being denoised (over the whole buffer).
    pub num_block: usize,
    /// Number of blocks fully covered by the prompt (`prompt_len / block_length`).
    pub prefill_blocks: usize,
    /// Denoise step index within the current block, starting at 0.
    pub denoise_step: usize,
}
109
/// Sample `(x0, x0_p)` for the trailing block of the active window.
///
/// `x0` holds a candidate token and `x0_p` its confidence for each of the `block_length`
/// positions ending at `window_end`. [`run_block_diffusion`] indexes both by block offset
/// `0..block_length`, so each vector should have at least `block_length` entries.
pub trait BlockDenoiseSampler {
    /// Proposes tokens and confidences for `x[window_end - block_length..window_end]`.
    ///
    /// `x` is the whole token buffer (prompt, then decoded or mask tokens); implementations
    /// are expected to condition only on `x[..window_end]`. `refresh_experts` and `step_ctx`
    /// describe the current TIDE denoise step.
    fn sample_block(
        &mut self,
        x: &[u32],
        window_end: usize,
        block_length: usize,
        refresh_experts: bool,
        gen_cfg: &GenerateConfig,
        model_cfg: &LLaDA2MoeConfig,
        step_ctx: DenoiseStepCtx,
    ) -> anyhow::Result<(Vec<u32>, Vec<f32>)>;
}
123
124impl<F: GenerateForward> BlockDenoiseSampler for F {
125    fn sample_block(
126        &mut self,
127        x: &[u32],
128        window_end: usize,
129        block_length: usize,
130        refresh_experts: bool,
131        gen_cfg: &GenerateConfig,
132        model_cfg: &LLaDA2MoeConfig,
133        _step_ctx: DenoiseStepCtx,
134    ) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
135        let mask = block_diffusion_attention_mask(1, window_end, block_length);
136        let position_ids: Vec<f32> = (0..window_end).map(|i| i as f32).collect();
137        let logits = self.forward_window(
138            &x[..window_end],
139            window_end,
140            &mask,
141            &position_ids,
142            refresh_experts,
143        )?;
144        let block_start = window_end.saturating_sub(block_length);
145        let vocab = model_cfg.vocab_size;
146        let mut x0 = vec![0u32; block_length];
147        let mut x0_p = vec![0f32; block_length];
148        for i in 0..block_length {
149            let pos = block_start + i;
150            if pos >= window_end {
151                x0[i] = gen_cfg.mask_id;
152                x0_p[i] = 0.0;
153                continue;
154            }
155            let base = pos * vocab;
156            let (tok, prob) = sample_logits(
157                &logits[base..base + vocab],
158                gen_cfg.temperature,
159                gen_cfg.top_k,
160                gen_cfg.top_p,
161                gen_cfg.do_sample,
162            );
163            x0[i] = tok;
164            x0_p[i] = prob;
165        }
166        Ok((x0, x0_p))
167    }
168}
169
170/// Run TIDE block diffusion; returns generated suffix (after prompt) + stats.
171pub fn generate<S: BlockDenoiseSampler>(
172    sampler: &mut S,
173    cfg: &LLaDA2MoeConfig,
174    gen_cfg: &GenerateConfig,
175    prompt_ids: &[u32],
176) -> anyhow::Result<(Vec<u32>, Vec<BlockDenoiseStepStats>)> {
177    run_block_diffusion(sampler, cfg, gen_cfg, prompt_ids, |_| {
178        TideOffloadStats::default()
179    })
180}
181
182/// Run TIDE block diffusion; returns generated suffix (after prompt) + stats.
183pub fn run_block_diffusion<S: BlockDenoiseSampler>(
184    sampler: &mut S,
185    cfg: &LLaDA2MoeConfig,
186    gen_cfg: &GenerateConfig,
187    prompt_ids: &[u32],
188    mut offload_stats: impl FnMut(&mut S) -> TideOffloadStats,
189) -> anyhow::Result<(Vec<u32>, Vec<BlockDenoiseStepStats>)> {
190    let steps = gen_cfg
191        .steps
192        .min(gen_cfg.gen_length / gen_cfg.minimal_topk.max(1));
193    let block_length = gen_cfg.block_length;
194    let prompt_length = prompt_ids.len();
195    let num_blocks = (prompt_length + gen_cfg.gen_length).div_ceil(block_length);
196    let total_length = num_blocks * block_length;
197    let prefill_blocks = prompt_length / block_length;
198
199    let mut x = vec![gen_cfg.mask_id; total_length];
200    x[..prompt_length].copy_from_slice(prompt_ids);
201
202    let transfer_schedule = num_transfer_tokens_schedule(block_length, steps);
203    let mut stats = Vec::new();
204
205    for num_block in prefill_blocks..num_blocks {
206        let window_end = (num_block + 1) * block_length;
207
208        for step in 0..steps {
209            let block_start = window_end.saturating_sub(block_length);
210            let active_tokens = x[block_start..window_end]
211                .iter()
212                .filter(|&&t| t == gen_cfg.mask_id)
213                .count();
214            if active_tokens == 0 {
215                break;
216            }
217
218            let refresh = refresh_experts(
219                gen_cfg.predictive_offload_enabled,
220                gen_cfg.jump_steps,
221                num_block,
222                prefill_blocks,
223                step,
224            );
225
226            let t0 = std::time::Instant::now();
227            let step_ctx = DenoiseStepCtx {
228                num_block,
229                prefill_blocks,
230                denoise_step: step,
231            };
232            let (x0, x0_p) = sampler.sample_block(
233                &x,
234                window_end,
235                block_length,
236                refresh,
237                gen_cfg,
238                cfg,
239                step_ctx,
240            )?;
241            let elapsed_ms = t0.elapsed().as_secs_f64() * 1000.0;
242
243            let num_to_transfer = transfer_schedule
244                .get(step)
245                .copied()
246                .unwrap_or(0)
247                .min(active_tokens);
248
249            let mut transfer = vec![false; block_length];
250            let mut high_conf = 0usize;
251            for i in 0..block_length {
252                if x[block_start + i] != gen_cfg.mask_id {
253                    continue;
254                }
255                if x0_p[i] > gen_cfg.threshold {
256                    transfer[i] = true;
257                    high_conf += 1;
258                }
259            }
260            if high_conf < num_to_transfer {
261                let mut ranked: Vec<(f32, usize)> = (0..block_length)
262                    .filter(|&i| x[block_start + i] == gen_cfg.mask_id)
263                    .map(|i| (x0_p[i], i))
264                    .collect();
265                ranked.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
266                for (_, i) in ranked.into_iter().take(num_to_transfer) {
267                    transfer[i] = true;
268                }
269            }
270
271            let mut transferred = 0usize;
272            for (i, &sel) in transfer.iter().enumerate() {
273                if sel {
274                    x[block_start + i] = x0[i];
275                    transferred += 1;
276                }
277            }
278
279            if gen_cfg.collect_stats {
280                stats.push(BlockDenoiseStepStats {
281                    block: num_block,
282                    step,
283                    elapsed_ms,
284                    active_tokens,
285                    transferred_tokens: transferred,
286                    offload_stats: offload_stats(sampler),
287                });
288            }
289
290            if gen_cfg.eos_early_stop
291                && transfer
292                    .iter()
293                    .zip(x0.iter())
294                    .any(|(&s, &t)| s && t == gen_cfg.eos_id)
295            {
296                if let Some(eos_pos) = x.iter().position(|&t| t == gen_cfg.eos_id) {
297                    if x[prompt_length..eos_pos]
298                        .iter()
299                        .all(|&t| t != gen_cfg.mask_id)
300                    {
301                        return Ok((x[prompt_length..=eos_pos].to_vec(), stats));
302                    }
303                }
304            }
305        }
306
307        if x[prompt_length..window_end].contains(&gen_cfg.eos_id) {
308            break;
309        }
310    }
311
312    let end = (prompt_length + gen_cfg.gen_length).min(x.len());
313    let slice = &x[prompt_length..end];
314    let eos_off = slice
315        .iter()
316        .position(|&t| t == gen_cfg.eos_id)
317        .map(|p| p + 1)
318        .unwrap_or(slice.len());
319    Ok((slice[..eos_off].to_vec(), stats))
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn transfer_schedule_matches_tide() {
        assert_eq!(num_transfer_tokens_schedule(32, 32), vec![1; 32]);
        assert_eq!(num_transfer_tokens_schedule(10, 3), vec![4, 3, 3]);
    }

    #[test]
    fn transfer_schedule_edge_cases() {
        // Zero steps yields an empty schedule instead of dividing by zero.
        assert!(num_transfer_tokens_schedule(32, 0).is_empty());
        // More steps than tokens: leading steps take one token each, the rest take none.
        assert_eq!(num_transfer_tokens_schedule(2, 4), vec![1, 1, 0, 0]);
        // The schedule always sums to the block length.
        assert_eq!(num_transfer_tokens_schedule(17, 5).iter().sum::<usize>(), 17);
    }

    #[test]
    fn from_model_threshold_matches_eval_dinfer() {
        let cfg = crate::llada2::synth::tiny_cfg();
        assert!((GenerateConfig::from_model(&cfg).threshold - 0.9).abs() < f32::EPSILON);
    }

    /// Test sampler that proposes the same token with full confidence at every position.
    struct ConstSampler {
        token: u32,
    }

    impl BlockDenoiseSampler for ConstSampler {
        fn sample_block(
            &mut self,
            _x: &[u32],
            _window_end: usize,
            block_length: usize,
            _refresh_experts: bool,
            _gen_cfg: &GenerateConfig,
            _model_cfg: &LLaDA2MoeConfig,
            _step_ctx: DenoiseStepCtx,
        ) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
            Ok((vec![self.token; block_length], vec![1.0; block_length]))
        }
    }

    /// Small schedule: 4-token blocks, 4 steps, 8 generated tokens.
    fn small_gen_cfg(cfg: &LLaDA2MoeConfig) -> GenerateConfig {
        GenerateConfig {
            block_length: 4,
            steps: 4,
            gen_length: 8,
            ..GenerateConfig::from_model(cfg)
        }
    }

    /// A token id that is neither EOS nor MASK for `gen_cfg`.
    fn plain_token(gen_cfg: &GenerateConfig) -> u32 {
        (0u32..)
            .find(|&t| t != gen_cfg.eos_id && t != gen_cfg.mask_id)
            .expect("some token id differs from both EOS and MASK")
    }

    #[test]
    fn generate_fills_full_length_without_eos() {
        let cfg = crate::llada2::synth::tiny_cfg();
        let gen_cfg = small_gen_cfg(&cfg);
        let tok = plain_token(&gen_cfg);
        let mut sampler = ConstSampler { token: tok };
        let (out, stats) = generate(&mut sampler, &cfg, &gen_cfg, &[tok, tok]).unwrap();
        assert_eq!(out, vec![tok; 8]);
        assert!(stats.is_empty());
    }

    #[test]
    fn generate_eos_early_stop_returns_through_eos() {
        let cfg = crate::llada2::synth::tiny_cfg();
        let gen_cfg = GenerateConfig {
            eos_early_stop: true,
            ..small_gen_cfg(&cfg)
        };
        let tok = plain_token(&gen_cfg);
        let mut sampler = ConstSampler {
            token: gen_cfg.eos_id,
        };
        let (out, _) = generate(&mut sampler, &cfg, &gen_cfg, &[tok, tok]).unwrap();
        assert_eq!(out, vec![gen_cfg.eos_id]);
    }

    #[test]
    fn generate_collects_one_stat_per_executed_step() {
        let cfg = crate::llada2::synth::tiny_cfg();
        let gen_cfg = GenerateConfig {
            collect_stats: true,
            ..small_gen_cfg(&cfg)
        };
        let tok = plain_token(&gen_cfg);
        let mut sampler = ConstSampler { token: tok };
        let (_, stats) = generate(&mut sampler, &cfg, &gen_cfg, &[tok, tok]).unwrap();
        // Every block is fully decoded in its first step, so there is one step per block.
        assert_eq!(stats.len(), 3);
        assert_eq!((stats[0].block, stats[0].step), (0, 0));
        assert_eq!(stats[0].active_tokens, 2);
        assert_eq!(stats[0].transferred_tokens, 2);
        assert_eq!(stats[1].block, 1);
        assert_eq!(stats[1].active_tokens, 4);
        assert_eq!(stats[2].transferred_tokens, 4);
    }
}