Skip to main content

rlx_flux2/diamond/
params.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Diamond Maps guidance parameters for FLUX.2 sampling.
17
/// Hugging Face Hub repository that hosts the reference FLUX.1-dev flow-map LoRA
/// used by weighted Diamond Maps.
pub const FLOW_MAP_LORA_HF_REPO: &str = "gabeguofanclub/flux-1-dev-flowmap-lsd";
20
/// Default LoRA weight file, as a path relative to the root of
/// [`FLOW_MAP_LORA_HF_REPO`] (see reference `weighted_diamond_maps`).
// The trailing `\` continues the literal and drops the next line's leading
// whitespace, so the value is one contiguous path with no embedded spaces.
pub const FLOW_MAP_LORA_HF_WEIGHT: &str = "01-12-26/runs/res_512_steps_50k_rank_64_lr_1e-4/\
    checkpoint-43000/pytorch_lora_weights.safetensors";
26
27/// Which inference-time reward alignment path to use.
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
29pub enum DiamondMethod {
30    /// Multi-step GLASS posterior + value gradient (no flow-map weights).
31    #[default]
32    Glass,
33    /// Renoise + flow-map-style x0 lookahead (single-timestep denoiser; no dual-time LoRA).
34    Weighted,
35    /// Denoiser approximation V_t ≈ r(D_t(x_t)) (fast baseline).
36    Dps,
37}
38
39impl DiamondMethod {
40    pub fn parse(s: &str) -> Option<Self> {
41        match s.trim().to_lowercase().as_str() {
42            "glass" => Some(Self::Glass),
43            "weighted" | "weighted_diamond" => Some(Self::Weighted),
44            "dps" | "flow_map" | "fmtt" => Some(Self::Dps),
45            _ => None,
46        }
47    }
48}
49
50/// Inference-time reward alignment settings (no base-model retraining).
51#[derive(Debug, Clone)]
52pub struct DiamondGuidanceParams {
53    pub method: DiamondMethod,
54    /// Monte Carlo particles for value / gradient estimation.
55    pub mc_samples: usize,
56    /// Inner GLASS ODE steps per particle (`Glass` only).
57    pub inner_steps: usize,
58    /// Last N outer denoising steps that apply reward guidance.
59    pub guidance_steps: usize,
60    /// Multiplier on reward before softmax.
61    pub reward_scale: f32,
62    /// Max |b_t| for FLUX guidance coefficient.
63    pub max_guidance_b: f32,
64    /// SNR factor for weighted renoising time t′ (`Weighted` only).
65    pub snr_factor: f32,
66    /// Include Gaussian likelihood term in weighted gradient.
67    pub include_likelihood: bool,
68    /// Include score correction in weighted gradient.
69    pub include_score: bool,
70    /// Softmax logits use full weighting (likelihood + score + reward).
71    pub include_weights: bool,
72    /// Temperature on particle logits when `include_weights`.
73    pub weight_temperature: f32,
74    /// Scale combined guidance vector before Euler step.
75    pub gradient_norm_scale: f32,
76    /// Use dual-time flow-map x0 for weighted particles (`Weighted` only).
77    pub use_flow_map: bool,
78    /// Evaluate reward on VAE-decoded RGB when VAE is loaded.
79    pub decode_reward: bool,
80    /// RNG seed offset for particle noise.
81    pub seed: u64,
82}
83
84impl Default for DiamondGuidanceParams {
85    fn default() -> Self {
86        Self {
87            method: DiamondMethod::Glass,
88            mc_samples: 4,
89            inner_steps: 10,
90            guidance_steps: 5,
91            reward_scale: 1.0,
92            max_guidance_b: 20.0,
93            snr_factor: 5.0,
94            include_likelihood: true,
95            include_score: true,
96            include_weights: false,
97            weight_temperature: 1.0,
98            gradient_norm_scale: 1.0,
99            use_flow_map: true,
100            decode_reward: false,
101            seed: 0,
102        }
103    }
104}