1use crate::config::LLaDA2MoeConfig;
19use crate::mask::block_diffusion_attention_mask;
20use crate::sampling::sample_logits;
21use crate::tide::{TideOffloadStats, refresh_experts};
22
23#[derive(Debug, Clone, Default)]
25pub struct BlockDenoiseStepStats {
26 pub block: usize,
27 pub step: usize,
28 pub elapsed_ms: f64,
29 pub active_tokens: usize,
30 pub transferred_tokens: usize,
31 pub offload_stats: TideOffloadStats,
32}
33
/// Splits `block_length` tokens over `steps` denoising steps as evenly as possible.
///
/// Every step receives `block_length / steps` tokens and the first
/// `block_length % steps` steps receive one extra, so the entries sum to
/// `block_length`. Returns an empty schedule when `steps == 0`.
pub fn num_transfer_tokens_schedule(block_length: usize, steps: usize) -> Vec<usize> {
    let Some(per_step) = block_length.checked_div(steps) else {
        return Vec::new();
    };
    let extra = block_length % steps;
    (0..steps)
        .map(|slot| per_step + usize::from(slot < extra))
        .collect()
}
47
/// Settings that control block-diffusion decoding.
///
/// All fields are public; `GenerateConfig::from_model` provides a baseline that callers
/// can then override field by field.
#[derive(Clone, Debug)]
pub struct GenerateConfig {
    /// Temperature forwarded to `sample_logits`.
    pub temperature: f32,
    /// Number of tokens in each diffusion block.
    pub block_length: usize,
    /// Upper bound on denoising steps per block; generation further caps it at
    /// `gen_length / max(minimal_topk, 1)`.
    pub steps: usize,
    /// Number of tokens to generate after the prompt.
    pub gen_length: usize,
    /// `top_p` value forwarded to `sample_logits`, if any.
    pub top_p: Option<f32>,
    /// `top_k` value forwarded to `sample_logits`, if any.
    pub top_k: Option<usize>,
    /// Return as soon as an EOS token is committed with no masked positions before it.
    pub eos_early_stop: bool,
    /// Divisor used to cap the step count (`gen_length / minimal_topk`; 0 counts as 1).
    pub minimal_topk: usize,
    /// Masked positions whose confidence is strictly above this value are committed.
    pub threshold: f32,
    /// End-of-sequence token id.
    pub eos_id: u32,
    /// Token id marking a position that has not been decoded yet.
    pub mask_id: u32,
    /// `do_sample` flag forwarded to `sample_logits`.
    pub do_sample: bool,
    /// Forwarded to `refresh_experts`, whose result becomes the sampler's refresh flag.
    pub predictive_offload_enabled: bool,
    /// Forwarded to `refresh_experts` together with `predictive_offload_enabled`.
    pub jump_steps: usize,
    /// Record a [`BlockDenoiseStepStats`] entry for every denoising step.
    pub collect_stats: bool,
}
68
69impl GenerateConfig {
70 pub fn from_model(cfg: &LLaDA2MoeConfig) -> Self {
71 Self {
72 temperature: 0.0,
73 block_length: 32,
74 steps: 32,
75 gen_length: 2048,
76 top_p: None,
77 top_k: None,
78 eos_early_stop: false,
79 minimal_topk: 1,
80 threshold: 0.9,
81 eos_id: cfg.eos_token_id,
82 mask_id: cfg.mask_token_id,
83 do_sample: false,
84 predictive_offload_enabled: false,
85 jump_steps: 1,
86 collect_stats: false,
87 }
88 }
89}
90
91pub trait GenerateForward {
92 fn forward_window(
93 &mut self,
94 tokens: &[u32],
95 window_len: usize,
96 attn_mask: &[f32],
97 position_ids: &[f32],
98 refresh_experts: bool,
99 ) -> anyhow::Result<Vec<f32>>;
100}
101
/// Where the current denoising step sits within a generation run.
#[derive(Clone, Copy, Debug, Default)]
pub struct DenoiseStepCtx {
    /// Index of the block being denoised, counted in whole blocks from sequence start.
    pub num_block: usize,
    /// Number of whole blocks filled by the prompt.
    pub prefill_blocks: usize,
    /// Step index within the current block, starting at 0.
    pub denoise_step: usize,
}
109
110pub trait BlockDenoiseSampler {
112 fn sample_block(
113 &mut self,
114 x: &[u32],
115 window_end: usize,
116 block_length: usize,
117 refresh_experts: bool,
118 gen_cfg: &GenerateConfig,
119 model_cfg: &LLaDA2MoeConfig,
120 step_ctx: DenoiseStepCtx,
121 ) -> anyhow::Result<(Vec<u32>, Vec<f32>)>;
122}
123
124impl<F: GenerateForward> BlockDenoiseSampler for F {
125 fn sample_block(
126 &mut self,
127 x: &[u32],
128 window_end: usize,
129 block_length: usize,
130 refresh_experts: bool,
131 gen_cfg: &GenerateConfig,
132 model_cfg: &LLaDA2MoeConfig,
133 _step_ctx: DenoiseStepCtx,
134 ) -> anyhow::Result<(Vec<u32>, Vec<f32>)> {
135 let mask = block_diffusion_attention_mask(1, window_end, block_length);
136 let position_ids: Vec<f32> = (0..window_end).map(|i| i as f32).collect();
137 let logits = self.forward_window(
138 &x[..window_end],
139 window_end,
140 &mask,
141 &position_ids,
142 refresh_experts,
143 )?;
144 let block_start = window_end.saturating_sub(block_length);
145 let vocab = model_cfg.vocab_size;
146 let mut x0 = vec![0u32; block_length];
147 let mut x0_p = vec![0f32; block_length];
148 for i in 0..block_length {
149 let pos = block_start + i;
150 if pos >= window_end {
151 x0[i] = gen_cfg.mask_id;
152 x0_p[i] = 0.0;
153 continue;
154 }
155 let base = pos * vocab;
156 let (tok, prob) = sample_logits(
157 &logits[base..base + vocab],
158 gen_cfg.temperature,
159 gen_cfg.top_k,
160 gen_cfg.top_p,
161 gen_cfg.do_sample,
162 );
163 x0[i] = tok;
164 x0_p[i] = prob;
165 }
166 Ok((x0, x0_p))
167 }
168}
169
170pub fn generate<S: BlockDenoiseSampler>(
172 sampler: &mut S,
173 cfg: &LLaDA2MoeConfig,
174 gen_cfg: &GenerateConfig,
175 prompt_ids: &[u32],
176) -> anyhow::Result<(Vec<u32>, Vec<BlockDenoiseStepStats>)> {
177 run_block_diffusion(sampler, cfg, gen_cfg, prompt_ids, |_| {
178 TideOffloadStats::default()
179 })
180}
181
182pub fn run_block_diffusion<S: BlockDenoiseSampler>(
184 sampler: &mut S,
185 cfg: &LLaDA2MoeConfig,
186 gen_cfg: &GenerateConfig,
187 prompt_ids: &[u32],
188 mut offload_stats: impl FnMut(&mut S) -> TideOffloadStats,
189) -> anyhow::Result<(Vec<u32>, Vec<BlockDenoiseStepStats>)> {
190 let steps = gen_cfg
191 .steps
192 .min(gen_cfg.gen_length / gen_cfg.minimal_topk.max(1));
193 let block_length = gen_cfg.block_length;
194 let prompt_length = prompt_ids.len();
195 let num_blocks = (prompt_length + gen_cfg.gen_length).div_ceil(block_length);
196 let total_length = num_blocks * block_length;
197 let prefill_blocks = prompt_length / block_length;
198
199 let mut x = vec![gen_cfg.mask_id; total_length];
200 x[..prompt_length].copy_from_slice(prompt_ids);
201
202 let transfer_schedule = num_transfer_tokens_schedule(block_length, steps);
203 let mut stats = Vec::new();
204
205 for num_block in prefill_blocks..num_blocks {
206 let window_end = (num_block + 1) * block_length;
207
208 for step in 0..steps {
209 let block_start = window_end.saturating_sub(block_length);
210 let active_tokens = x[block_start..window_end]
211 .iter()
212 .filter(|&&t| t == gen_cfg.mask_id)
213 .count();
214 if active_tokens == 0 {
215 break;
216 }
217
218 let refresh = refresh_experts(
219 gen_cfg.predictive_offload_enabled,
220 gen_cfg.jump_steps,
221 num_block,
222 prefill_blocks,
223 step,
224 );
225
226 let t0 = std::time::Instant::now();
227 let step_ctx = DenoiseStepCtx {
228 num_block,
229 prefill_blocks,
230 denoise_step: step,
231 };
232 let (x0, x0_p) = sampler.sample_block(
233 &x,
234 window_end,
235 block_length,
236 refresh,
237 gen_cfg,
238 cfg,
239 step_ctx,
240 )?;
241 let elapsed_ms = t0.elapsed().as_secs_f64() * 1000.0;
242
243 let num_to_transfer = transfer_schedule
244 .get(step)
245 .copied()
246 .unwrap_or(0)
247 .min(active_tokens);
248
249 let mut transfer = vec![false; block_length];
250 let mut high_conf = 0usize;
251 for i in 0..block_length {
252 if x[block_start + i] != gen_cfg.mask_id {
253 continue;
254 }
255 if x0_p[i] > gen_cfg.threshold {
256 transfer[i] = true;
257 high_conf += 1;
258 }
259 }
260 if high_conf < num_to_transfer {
261 let mut ranked: Vec<(f32, usize)> = (0..block_length)
262 .filter(|&i| x[block_start + i] == gen_cfg.mask_id)
263 .map(|i| (x0_p[i], i))
264 .collect();
265 ranked.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
266 for (_, i) in ranked.into_iter().take(num_to_transfer) {
267 transfer[i] = true;
268 }
269 }
270
271 let mut transferred = 0usize;
272 for (i, &sel) in transfer.iter().enumerate() {
273 if sel {
274 x[block_start + i] = x0[i];
275 transferred += 1;
276 }
277 }
278
279 if gen_cfg.collect_stats {
280 stats.push(BlockDenoiseStepStats {
281 block: num_block,
282 step,
283 elapsed_ms,
284 active_tokens,
285 transferred_tokens: transferred,
286 offload_stats: offload_stats(sampler),
287 });
288 }
289
290 if gen_cfg.eos_early_stop
291 && transfer
292 .iter()
293 .zip(x0.iter())
294 .any(|(&s, &t)| s && t == gen_cfg.eos_id)
295 {
296 if let Some(eos_pos) = x.iter().position(|&t| t == gen_cfg.eos_id) {
297 if x[prompt_length..eos_pos]
298 .iter()
299 .all(|&t| t != gen_cfg.mask_id)
300 {
301 return Ok((x[prompt_length..=eos_pos].to_vec(), stats));
302 }
303 }
304 }
305 }
306
307 if x[prompt_length..window_end].contains(&gen_cfg.eos_id) {
308 break;
309 }
310 }
311
312 let end = (prompt_length + gen_cfg.gen_length).min(x.len());
313 let slice = &x[prompt_length..end];
314 let eos_off = slice
315 .iter()
316 .position(|&t| t == gen_cfg.eos_id)
317 .map(|p| p + 1)
318 .unwrap_or(slice.len());
319 Ok((slice[..eos_off].to_vec(), stats))
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn transfer_schedule_matches_tide() {
        // Even split: one token committed per step.
        let even = num_transfer_tokens_schedule(32, 32);
        assert_eq!(even, vec![1; 32]);

        // Uneven split: the remainder goes to the earliest steps.
        let uneven = num_transfer_tokens_schedule(10, 3);
        assert_eq!(uneven, vec![4, 3, 3]);
    }

    #[test]
    fn from_model_threshold_matches_eval_dinfer() {
        let model_cfg = crate::llada2::synth::tiny_cfg();
        let threshold = GenerateConfig::from_model(&model_cfg).threshold;
        assert!((threshold - 0.9).abs() < f32::EPSILON);
    }
}