1use crate::config::LLaDA2MoeConfig;
19use crate::tide::stats::TideOffloadStats;
20use crate::tide::{DenoiseStepCtx, GenerateConfig, run_block_diffusion};
21
/// User-facing settings for the block denoise loop.
///
/// `to_generate_config` copies every field, under the same name, into a
/// `GenerateConfig`; this type performs no validation of its own.
#[derive(Debug, Clone)]
pub struct BlockDenoiseConfig {
    /// Temperature forwarded to the generator. Default: `0.0`.
    pub temperature: f32,
    /// Block length forwarded to the generator. Default: `32`.
    pub block_length: usize,
    /// Step count forwarded to the generator (presumably denoise steps per
    /// block — confirm against the generator). Default: `32`.
    pub steps: usize,
    /// Generation length forwarded to the generator. Default: `2048`.
    pub gen_length: usize,
    /// Optional top-p value; `None` by default.
    pub top_p: Option<f32>,
    /// Optional top-k value; `None` by default.
    pub top_k: Option<usize>,
    /// Early-stop-on-EOS flag. Default: `false`.
    pub eos_early_stop: bool,
    /// Minimal top-k value forwarded to the generator. Default: `1`.
    pub minimal_topk: usize,
    /// Threshold forwarded to the generator (presumably a confidence cut-off —
    /// confirm against the generator). Default: `0.9`.
    pub threshold: f32,
    /// Token id used for masked positions. Default: `156895`.
    pub mask_id: u32,
    /// Token id treated as end-of-sequence. Default: `156892`.
    pub eos_id: u32,
    /// Sampling flag forwarded to the generator. Default: `false`.
    pub do_sample: bool,
    /// Predictive-offload flag forwarded to the generator. Default: `false`.
    pub predictive_offload_enabled: bool,
    /// Jump-step value forwarded to the generator. Default: `1`.
    pub jump_steps: usize,
    /// Statistics-collection flag forwarded to the generator. Default: `false`.
    pub collect_stats: bool,
}

impl Default for BlockDenoiseConfig {
    /// Returns the baseline configuration: all flags off, no top-p/top-k
    /// filtering, 32-wide blocks with 32 steps, and a 2048-token generation
    /// length.
    fn default() -> Self {
        // NOTE(review): these ids presumably belong to one specific tokenizer
        // vocabulary — confirm before reusing the defaults with another model.
        const DEFAULT_MASK_ID: u32 = 156_895;
        const DEFAULT_EOS_ID: u32 = 156_892;

        Self {
            // Sizes and counts.
            block_length: 32,
            steps: 32,
            gen_length: 2048,
            minimal_topk: 1,
            jump_steps: 1,

            // Numeric sampling knobs.
            temperature: 0.0,
            threshold: 0.9,
            top_p: None,
            top_k: None,

            // Special token ids.
            mask_id: DEFAULT_MASK_ID,
            eos_id: DEFAULT_EOS_ID,

            // Every boolean switch starts disabled.
            eos_early_stop: false,
            do_sample: false,
            predictive_offload_enabled: false,
            collect_stats: false,
        }
    }
}
63
64impl BlockDenoiseConfig {
65 pub fn to_generate_config(&self) -> GenerateConfig {
66 GenerateConfig {
67 temperature: self.temperature,
68 block_length: self.block_length,
69 steps: self.steps,
70 gen_length: self.gen_length,
71 top_p: self.top_p,
72 top_k: self.top_k,
73 eos_early_stop: self.eos_early_stop,
74 minimal_topk: self.minimal_topk,
75 threshold: self.threshold,
76 eos_id: self.eos_id,
77 mask_id: self.mask_id,
78 do_sample: self.do_sample,
79 predictive_offload_enabled: self.predictive_offload_enabled,
80 jump_steps: self.jump_steps,
81 collect_stats: self.collect_stats,
82 }
83 }
84}
85
86pub use crate::tide::generate::BlockDenoiseStepStats;
88
/// Abstraction over a single model forward pass used by the block denoise loop.
///
/// Implementors are driven through `BlockSamplerAdapter`, which exposes them as
/// a `crate::tide::generate::BlockDenoiseSampler`.
pub trait BlockDiffusionForward {
    /// Runs one forward pass and returns the per-position results.
    ///
    /// # Parameters
    /// - `token_ids`: the token buffer handed over by the sampler adapter.
    /// - `seq_len`: the adapter passes its `window_end` argument here
    ///   (presumably the number of leading entries of `token_ids` that are
    ///   in play — confirm with implementors).
    /// - `refresh_experts`: forwarded unchanged from the sampler call; its
    ///   meaning is defined by the implementor.
    ///
    /// # Errors
    /// Any failure is reported as an `anyhow::Error`; the adapter propagates
    /// it to its caller with `?`.
    fn forward_block(
        &mut self,
        token_ids: &[u32],
        seq_len: usize,
        refresh_experts: bool,
    ) -> Result<BlockForwardOutput, anyhow::Error>;
}
98
/// Result of one `BlockDiffusionForward::forward_block` call.
///
/// `BlockSamplerAdapter::sample_block` copies the leading entries of both
/// vectors, index by index, into its fixed-length outputs.
#[derive(Debug, Clone)]
pub struct BlockForwardOutput {
    /// Per-position `u32` values; presumably the predicted token ids — TODO
    /// confirm with implementors.
    pub x0: Vec<u32>,
    /// Per-position `f32` values paired by index with `x0`; presumably the
    /// confidence of each prediction — TODO confirm. May be shorter than `x0`:
    /// the adapter substitutes `0.0` for any missing entry.
    pub x0_p: Vec<f32>,
}
104
/// Private bridge that lets any `BlockDiffusionForward` be used where a
/// `crate::tide::generate::BlockDenoiseSampler` is expected.
struct BlockSamplerAdapter<'a, F: BlockDiffusionForward> {
    /// Borrowed forward implementation that performs the actual model call.
    forward: &'a mut F,
    /// Block length captured when the adapter is built. `sample_block` only
    /// reads it in a discarded `let _`; the output size comes from that
    /// method's own `block_length` argument.
    block_length: usize,
}
109
110impl<F: BlockDiffusionForward> crate::tide::generate::BlockDenoiseSampler
111 for BlockSamplerAdapter<'_, F>
112{
113 fn sample_block(
114 &mut self,
115 x: &[u32],
116 window_end: usize,
117 block_length: usize,
118 refresh_experts: bool,
119 _gen_cfg: &GenerateConfig,
120 _model_cfg: &LLaDA2MoeConfig,
121 _step_ctx: DenoiseStepCtx,
122 ) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
123 let out = self.forward.forward_block(x, window_end, refresh_experts)?;
124 let mut x0 = vec![0u32; block_length];
125 let mut x0_p = vec![0f32; block_length];
126 for i in 0..block_length.min(out.x0.len()) {
127 x0[i] = out.x0[i];
128 x0_p[i] = out.x0_p.get(i).copied().unwrap_or(0.0);
129 }
130 let _ = self.block_length;
131 Ok((x0, x0_p))
132 }
133}
134
/// Bundles the configuration, model description, and forward implementation
/// needed to run block diffusion generation via `generate`.
pub struct BlockDenoiseLoop<F: BlockDiffusionForward> {
    /// Loop settings; converted with `to_generate_config` at the start of
    /// every `generate` call.
    pub cfg: BlockDenoiseConfig,
    /// Model configuration, passed by reference to `run_block_diffusion`.
    pub model_cfg: LLaDA2MoeConfig,
    /// Forward implementation, mutably borrowed for the duration of `generate`.
    pub forward: F,
    /// Optional function pointer that returns offload statistics. When it is
    /// `None`, `generate` supplies `TideOffloadStats::default()` instead.
    pub offload_stats: Option<fn() -> TideOffloadStats>,
}
142
143impl<F: BlockDiffusionForward> BlockDenoiseLoop<F> {
144 pub fn new(cfg: BlockDenoiseConfig, model_cfg: LLaDA2MoeConfig, forward: F) -> Self {
145 Self {
146 cfg,
147 model_cfg,
148 forward,
149 offload_stats: None,
150 }
151 }
152
153 pub fn with_offload_stats(mut self, f: fn() -> TideOffloadStats) -> Self {
154 self.offload_stats = Some(f);
155 self
156 }
157
158 pub fn generate(
160 &mut self,
161 prompt_ids: &[u32],
162 ) -> Result<(Vec<u32>, Vec<BlockDenoiseStepStats>), anyhow::Error> {
163 let gen_cfg = self.cfg.to_generate_config();
164 let mut adapter = BlockSamplerAdapter {
165 forward: &mut self.forward,
166 block_length: gen_cfg.block_length,
167 };
168 let stats_fn = self.offload_stats;
169 run_block_diffusion(
170 &mut adapter,
171 &self.model_cfg,
172 &gen_cfg,
173 prompt_ids,
174 move |_s| stats_fn.map(|f| f()).unwrap_or_default(),
175 )
176 }
177}