Skip to main content

rlx_llada2/tide/
diffusion.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16// RLX — block diffusion driver (delegates to [`super::generate`] for TIDE parity).
17
18use crate::config::LLaDA2MoeConfig;
19use crate::tide::stats::TideOffloadStats;
20use crate::tide::{DenoiseStepCtx, GenerateConfig, run_block_diffusion};
21
/// Options for one block diffusion run (mirrors the TIDE `generate` defaults).
///
/// Every field is forwarded unchanged into a [`GenerateConfig`] by
/// [`BlockDenoiseConfig::to_generate_config`]; the generation code that consumes
/// that config decides how each value is interpreted. The defaults quoted below
/// are the ones produced by the `Default` impl.
#[derive(Clone, Debug)]
pub struct BlockDenoiseConfig {
    /// Temperature setting. Default: `0.0`.
    pub temperature: f32,
    /// Length of one block (presumably in tokens — confirm against the generate loop).
    /// Default: `32`.
    pub block_length: usize,
    /// Step count (presumably denoise steps — confirm against the generate loop).
    /// Default: `32`.
    pub steps: usize,
    /// Generation length (presumably in tokens). Default: `2048`.
    pub gen_length: usize,
    /// Optional top-p value; `None` by default.
    pub top_p: Option<f32>,
    /// Optional top-k value; `None` by default.
    pub top_k: Option<usize>,
    /// Early-stop-on-EOS switch. Default: `false`.
    pub eos_early_stop: bool,
    /// Default: `1`. NOTE(review): semantics are defined by the consuming generate code.
    pub minimal_topk: usize,
    /// Threshold value (presumably a confidence cutoff — confirm). Default: `0.9`.
    pub threshold: f32,
    /// Id of the mask token. Default: `156895`.
    pub mask_id: u32,
    /// Id of the end-of-sequence token. Default: `156892`.
    pub eos_id: u32,
    /// Sampling switch. Default: `false`.
    pub do_sample: bool,
    /// Predictive offload switch. Default: `false`.
    pub predictive_offload_enabled: bool,
    /// Default: `1`. NOTE(review): semantics are defined by the consuming generate code.
    pub jump_steps: usize,
    /// Enables collection of per-step records. Default: `false`.
    pub collect_stats: bool,
}
41
42impl Default for BlockDenoiseConfig {
43    fn default() -> Self {
44        Self {
45            temperature: 0.0,
46            block_length: 32,
47            steps: 32,
48            gen_length: 2048,
49            top_p: None,
50            top_k: None,
51            eos_early_stop: false,
52            minimal_topk: 1,
53            threshold: 0.9,
54            mask_id: 156895,
55            eos_id: 156892,
56            do_sample: false,
57            predictive_offload_enabled: false,
58            jump_steps: 1,
59            collect_stats: false,
60        }
61    }
62}
63
64impl BlockDenoiseConfig {
65    pub fn to_generate_config(&self) -> GenerateConfig {
66        GenerateConfig {
67            temperature: self.temperature,
68            block_length: self.block_length,
69            steps: self.steps,
70            gen_length: self.gen_length,
71            top_p: self.top_p,
72            top_k: self.top_k,
73            eos_early_stop: self.eos_early_stop,
74            minimal_topk: self.minimal_topk,
75            threshold: self.threshold,
76            eos_id: self.eos_id,
77            mask_id: self.mask_id,
78            do_sample: self.do_sample,
79            predictive_offload_enabled: self.predictive_offload_enabled,
80            jump_steps: self.jump_steps,
81            collect_stats: self.collect_stats,
82        }
83    }
84}
85
86/// One denoise step record when `collect_stats` is enabled.
87pub use crate::tide::generate::BlockDenoiseStepStats;
88
89/// Forward callback: `refresh_experts` flag is set per TIDE policy before each call.
90pub trait BlockDiffusionForward {
91    fn forward_block(
92        &mut self,
93        token_ids: &[u32],
94        seq_len: usize,
95        refresh_experts: bool,
96    ) -> Result<BlockForwardOutput, anyhow::Error>;
97}
98
/// Result of one [`BlockDiffusionForward::forward_block`] call.
///
/// The two vectors are read in parallel, index by index. A position that has a
/// token in `x0` but no matching entry in `x0_p` is treated as having value `0.0`.
#[derive(Clone, Debug)]
pub struct BlockForwardOutput {
    /// Token ids produced by the forward pass.
    pub x0: Vec<u32>,
    /// One `f32` per entry of `x0` (presumably the confidence/probability of that
    /// token — confirm against the forward implementation).
    pub x0_p: Vec<f32>,
}
104
105struct BlockSamplerAdapter<'a, F: BlockDiffusionForward> {
106    forward: &'a mut F,
107    block_length: usize,
108}
109
110impl<F: BlockDiffusionForward> crate::tide::generate::BlockDenoiseSampler
111    for BlockSamplerAdapter<'_, F>
112{
113    fn sample_block(
114        &mut self,
115        x: &[u32],
116        window_end: usize,
117        block_length: usize,
118        refresh_experts: bool,
119        _gen_cfg: &GenerateConfig,
120        _model_cfg: &LLaDA2MoeConfig,
121        _step_ctx: DenoiseStepCtx,
122    ) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
123        let out = self.forward.forward_block(x, window_end, refresh_experts)?;
124        let mut x0 = vec![0u32; block_length];
125        let mut x0_p = vec![0f32; block_length];
126        for i in 0..block_length.min(out.x0.len()) {
127            x0[i] = out.x0[i];
128            x0_p[i] = out.x0_p.get(i).copied().unwrap_or(0.0);
129        }
130        let _ = self.block_length;
131        Ok((x0, x0_p))
132    }
133}
134
135/// Driver for TIDE-style block masked diffusion (host-side token state).
136pub struct BlockDenoiseLoop<F: BlockDiffusionForward> {
137    pub cfg: BlockDenoiseConfig,
138    pub model_cfg: LLaDA2MoeConfig,
139    pub forward: F,
140    pub offload_stats: Option<fn() -> TideOffloadStats>,
141}
142
143impl<F: BlockDiffusionForward> BlockDenoiseLoop<F> {
144    pub fn new(cfg: BlockDenoiseConfig, model_cfg: LLaDA2MoeConfig, forward: F) -> Self {
145        Self {
146            cfg,
147            model_cfg,
148            forward,
149            offload_stats: None,
150        }
151    }
152
153    pub fn with_offload_stats(mut self, f: fn() -> TideOffloadStats) -> Self {
154        self.offload_stats = Some(f);
155        self
156    }
157
158    /// Run block diffusion from `prompt_ids`; returns generated suffix + optional stats.
159    pub fn generate(
160        &mut self,
161        prompt_ids: &[u32],
162    ) -> Result<(Vec<u32>, Vec<BlockDenoiseStepStats>), anyhow::Error> {
163        let gen_cfg = self.cfg.to_generate_config();
164        let mut adapter = BlockSamplerAdapter {
165            forward: &mut self.forward,
166            block_length: gen_cfg.block_length,
167        };
168        let stats_fn = self.offload_stats;
169        run_block_diffusion(
170            &mut adapter,
171            &self.model_cfg,
172            &gen_cfg,
173            prompt_ids,
174            move |_s| stats_fn.map(|f| f()).unwrap_or_default(),
175        )
176    }
177}