rlx_flux2/diamond/
flow_map.rs1use crate::runner::Flux2Runner;
19use anyhow::{Result, ensure};
20
/// Result of a single [`flow_map_predict`] call.
///
/// Both vectors are flat `f32` buffers produced from the same input latents.
#[derive(Debug, Clone, PartialEq)]
pub struct FlowMapPrediction {
    /// Raw denoiser output for the `(sigma, sigma_next)` pair.
    pub noise_pred: Vec<f32>,
    /// Element-wise `latents - sigma * v0`, where `v0` is the denoiser output for the
    /// `(sigma, 0.0)` pair.
    // NOTE(review): presumably the clean-sample (x0) estimate — confirm against the sampler.
    pub x0_hat: Vec<f32>,
}
28
29pub fn forward_noise_dual(
31 runner: &Flux2Runner,
32 hidden_states: &[f32],
33 encoder_hidden_states: &[f32],
34 sigma: f32,
35 sigma_target: f32,
36 guidance: Option<&[f32]>,
37 img_ids: &[f32],
38 txt_ids: &[f32],
39) -> Result<Vec<f32>> {
40 let batch = runner.batch();
41 let timestep = vec![sigma; batch];
42 let target = vec![sigma_target; batch];
43 if runner.uses_compiled_denoiser() {
44 runner.forward_noise_dual_compiled(
45 hidden_states,
46 encoder_hidden_states,
47 ×tep,
48 &target,
49 guidance,
50 img_ids,
51 txt_ids,
52 )
53 } else {
54 runner.forward_noise_dual_native(
55 hidden_states,
56 encoder_hidden_states,
57 ×tep,
58 &target,
59 guidance,
60 img_ids,
61 txt_ids,
62 )
63 }
64}
65
66pub fn flow_map_predict(
68 runner: &Flux2Runner,
69 latents: &[f32],
70 sigma: f32,
71 sigma_next: f32,
72 encoder_hidden_states: &[f32],
73 guidance: Option<&[f32]>,
74 img_ids: &[f32],
75 txt_ids: &[f32],
76) -> Result<FlowMapPrediction> {
77 ensure!(latents.len() == runner.batch() * runner.img_seq() * runner.config().in_channels);
78 let noise_pred = forward_noise_dual(
79 runner,
80 latents,
81 encoder_hidden_states,
82 sigma,
83 sigma_next,
84 guidance,
85 img_ids,
86 txt_ids,
87 )?;
88 let v0 = forward_noise_dual(
89 runner,
90 latents,
91 encoder_hidden_states,
92 sigma,
93 0.0,
94 guidance,
95 img_ids,
96 txt_ids,
97 )?;
98 let x0_hat = latents
99 .iter()
100 .zip(v0.iter())
101 .map(|(&x, &v)| x - sigma * v)
102 .collect();
103 Ok(FlowMapPrediction { noise_pred, x0_hat })
104}