Skip to main content

rlx_flux2/diamond/
flow_map.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Flow-map-style denoiser calls: dual-time embedding for (t → t′) and an x0 lookahead (t → 0).
17
18use crate::runner::Flux2Runner;
19use anyhow::{Result, ensure};
20
/// Result of one flow-map denoiser step at the current noise level σ.
///
/// Holds the velocity prediction used by the scheduler together with a
/// lookahead estimate of the clean sample.
#[derive(Clone, Debug)]
pub struct FlowMapPrediction {
    /// Velocity (noise) prediction handed to the scheduler; `flow_map_predict`
    /// fills it from the dual-time forward at (σ → σ_next).
    pub noise_pred: Vec<f32>,
    /// Clean-sample estimate `x - σ·v0`, where `v0` is the dual-time forward
    /// prediction with target σ = 0.
    pub x0_hat: Vec<f32>,
}

29/// One dual-timestep denoiser evaluation (native or compiled).
30pub fn forward_noise_dual(
31    runner: &Flux2Runner,
32    hidden_states: &[f32],
33    encoder_hidden_states: &[f32],
34    sigma: f32,
35    sigma_target: f32,
36    guidance: Option<&[f32]>,
37    img_ids: &[f32],
38    txt_ids: &[f32],
39) -> Result<Vec<f32>> {
40    let batch = runner.batch();
41    let timestep = vec![sigma; batch];
42    let target = vec![sigma_target; batch];
43    if runner.uses_compiled_denoiser() {
44        runner.forward_noise_dual_compiled(
45            hidden_states,
46            encoder_hidden_states,
47            &timestep,
48            &target,
49            guidance,
50            img_ids,
51            txt_ids,
52        )
53    } else {
54        runner.forward_noise_dual_native(
55            hidden_states,
56            encoder_hidden_states,
57            &timestep,
58            &target,
59            guidance,
60            img_ids,
61            txt_ids,
62        )
63    }
64}
65
66/// Flow-map step: predict noise for scheduler (t → t_next) and x0_hat (t → 0).
67pub fn flow_map_predict(
68    runner: &Flux2Runner,
69    latents: &[f32],
70    sigma: f32,
71    sigma_next: f32,
72    encoder_hidden_states: &[f32],
73    guidance: Option<&[f32]>,
74    img_ids: &[f32],
75    txt_ids: &[f32],
76) -> Result<FlowMapPrediction> {
77    ensure!(latents.len() == runner.batch() * runner.img_seq() * runner.config().in_channels);
78    let noise_pred = forward_noise_dual(
79        runner,
80        latents,
81        encoder_hidden_states,
82        sigma,
83        sigma_next,
84        guidance,
85        img_ids,
86        txt_ids,
87    )?;
88    let v0 = forward_noise_dual(
89        runner,
90        latents,
91        encoder_hidden_states,
92        sigma,
93        0.0,
94        guidance,
95        img_ids,
96        txt_ids,
97    )?;
98    let x0_hat = latents
99        .iter()
100        .zip(v0.iter())
101        .map(|(&x, &v)| x - sigma * v)
102        .collect();
103    Ok(FlowMapPrediction { noise_pred, x0_hat })
104}