use crate::interpolant;
/// Maps a FLUX noise level `sigma` onto the paper's time axis as `1 - sigma`.
///
/// The result is clamped to `[0, 1]`, so out-of-range `sigma` values saturate
/// at an endpoint. A NaN `sigma` yields NaN.
#[inline]
pub fn flux_sigma_to_paper_t(sigma: f32) -> f32 {
    f32::clamp(1.0 - sigma, 0.0, 1.0)
}
/// Maps a paper-convention time `t` back to a FLUX noise level as `1 - t`.
///
/// This is the inverse of the sigma-to-time mapping on `[0, 1]`. The result is
/// clamped to `[0, 1]`, and a NaN `t` yields NaN.
#[inline]
pub fn paper_t_to_flux_sigma(t: f32) -> f32 {
    let sigma = 1.0 - t;
    sigma.clamp(0.0, 1.0)
}
/// Recovers a clean-sample estimate `x0 = x - sigma * velocity` from a velocity
/// prediction.
///
/// Algebraically, this exactly inverts `x = (1 - sigma) * x0 + sigma * noise`
/// when `velocity = noise - x0`.
///
/// NOTE(review): it presumably matches the FLUX rectified-flow parameterization.
/// Confirm against the model before relying on it.
#[inline]
pub fn flux_x0_from_velocity(x: f32, sigma: f32, velocity: f32) -> f32 {
    let displacement = sigma * velocity;
    x - displacement
}
/// Paper-convention denoiser output, computed from a FLUX velocity prediction.
///
/// This is exactly the clean-sample estimate produced by
/// [`flux_x0_from_velocity`] for the same `(x, sigma, velocity)`. Only the time
/// variable differs between the two conventions; the data-space estimate is
/// shared.
///
/// NOTE(review): the equivalence of the paper denoiser and the flux x0 estimate
/// is taken from this delegation. Confirm it against the paper's definition.
#[inline]
pub fn paper_denoiser_from_flux(x: f32, sigma: f32, velocity: f32) -> f32 {
    // Deliberately delegate instead of duplicating the formula, so that both
    // conventions stay in lockstep if the x0 estimate ever changes.
    flux_x0_from_velocity(x, sigma, velocity)
}
/// Guidance coefficient `sigma / (1 - sigma)` in FLUX sigma space, capped to
/// `[0, max_abs]`.
///
/// `sigma` is first clamped to `[1e-6, 1 - 1e-6]`. This keeps the denominator
/// away from zero, so the ratio stays finite and strictly positive even at the
/// endpoints `sigma = 0` and `sigma = 1`.
///
/// `max_abs` is the upper bound on the returned coefficient. A negative or NaN
/// `max_abs` is treated as `0.0`, which forces the result to `0.0`; it does not
/// panic. A NaN `sigma` still propagates to a NaN result.
#[inline]
pub fn flux_guidance_coefficient(sigma: f32, max_abs: f32) -> f32 {
    // `s` is the flux-space noise level (not paper time), clamped off the poles.
    let s = sigma.clamp(1e-6, 1.0 - 1e-6);
    let alpha = 1.0 - s;
    // `f32::clamp` panics if `min > max` or a bound is NaN. `f32::max` returns
    // the non-NaN argument, so `cap` is always a valid bound `>= 0.0`.
    let cap = max_abs.max(0.0);
    (s / alpha).clamp(0.0, cap)
}
/// Paper-convention guidance coefficient, evaluated at the paper time that
/// corresponds to the FLUX noise level `sigma`.
///
/// `sigma` is converted with [`flux_sigma_to_paper_t`], which clamps the time to
/// `[0, 1]`. The result is whatever `interpolant::guidance_coefficient` returns
/// for that time.
///
/// NOTE(review): unlike `flux_guidance_coefficient`, no endpoint clamping or cap
/// is applied here. Confirm that `interpolant::guidance_coefficient` handles
/// `t = 0` and `t = 1`.
#[inline]
pub fn paper_guidance_coefficient_from_flux_sigma(sigma: f32) -> f32 {
    let paper_t = flux_sigma_to_paper_t(sigma);
    interpolant::guidance_coefficient(paper_t)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converting sigma to paper time and back is the identity on `[0, 1]`.
    #[test]
    fn roundtrip_time() {
        for s in [0.0f32, 0.25, 0.5, 0.75, 1.0] {
            assert!((paper_t_to_flux_sigma(flux_sigma_to_paper_t(s)) - s).abs() < 1e-6);
        }
    }

    /// Out-of-range inputs saturate at the `[0, 1]` endpoints.
    #[test]
    fn time_conversions_clamp_out_of_range_inputs() {
        assert_eq!(flux_sigma_to_paper_t(1.5), 0.0);
        assert_eq!(flux_sigma_to_paper_t(-0.5), 1.0);
        assert_eq!(paper_t_to_flux_sigma(2.0), 0.0);
        assert_eq!(paper_t_to_flux_sigma(-1.0), 1.0);
    }

    /// `x - sigma * v` inverts `x = (1 - sigma) * x0 + sigma * noise` with
    /// `v = noise - x0`, and the paper denoiser agrees with it.
    #[test]
    fn x0_inverts_linear_interpolation() {
        let (x0, noise, sigma) = (0.3f32, -1.2f32, 0.25f32);
        let x = (1.0 - sigma) * x0 + sigma * noise;
        let velocity = noise - x0;
        assert!((flux_x0_from_velocity(x, sigma, velocity) - x0).abs() < 1e-6);
        assert_eq!(
            paper_denoiser_from_flux(x, sigma, velocity),
            flux_x0_from_velocity(x, sigma, velocity)
        );
    }

    #[test]
    fn guidance_matches_flux_formula() {
        let sigma = 0.4f32;
        let b_flux = flux_guidance_coefficient(sigma, 100.0);
        let expected = sigma / (1.0 - sigma);
        assert!((b_flux - expected).abs() < 1e-5);
    }

    /// The cap is honoured, and the endpoint clamping keeps the ratio finite.
    #[test]
    fn guidance_respects_cap_and_stays_finite_at_endpoints() {
        assert_eq!(flux_guidance_coefficient(0.99, 10.0), 10.0);
        for sigma in [0.0f32, 1.0] {
            let b = flux_guidance_coefficient(sigma, f32::INFINITY);
            assert!(b.is_finite() && b > 0.0);
        }
    }
}