/// Smallest value a denominator is allowed to take in this module.
const EPS: f32 = 1e-8;

/// Floors `x` at `EPS` so it can be used safely as a divisor.
///
/// Any value below `EPS` (including zero and all negative values) is replaced by
/// `EPS`, so the sign of a negative input is NOT preserved. This is only appropriate
/// for quantities that are expected to be non-negative.
/// A NaN input also yields `EPS`.
#[inline]
pub fn clip_zero(x: f32) -> f32 {
    // NaN fails the comparison and falls through to EPS, matching `f32::max` semantics.
    if x > EPS {
        x
    } else {
        EPS
    }
}
/// Signal coefficient `alpha(t) = t` of the linear schedule.
///
/// Equals 0 at `t = 0` and 1 at `t = 1`; it is the weight applied to the clean
/// sample in [`noise_sample`].
#[inline]
pub fn alpha(t: f32) -> f32 {
    t
}
/// Noise coefficient `sigma(t) = 1 - t` of the linear schedule.
///
/// Equals 1 at `t = 0` and 0 at `t = 1`; it is the weight applied to the noise
/// term in [`noise_sample`].
#[inline]
pub fn sigma(t: f32) -> f32 {
    -t + 1.0
}
/// Time derivative of [`alpha`]: constant 1 for the linear schedule.
///
/// The argument is unused; it is kept so every schedule function shares the
/// same `fn(f32) -> f32` shape.
#[inline]
pub fn alpha_dot(_t: f32) -> f32 {
    1.0
}
/// Time derivative of [`sigma`]: constant -1 for the linear schedule.
///
/// The argument is unused; it is kept so every schedule function shares the
/// same `fn(f32) -> f32` shape.
#[inline]
pub fn sigma_dot(_t: f32) -> f32 {
    -1.0
}
/// Squared noise-to-signal ratio `(sigma(t) / alpha(t))^2`.
///
/// For the linear schedule this is `((1 - t) / t)^2`; it is 0 at `t = 1`.
/// `alpha^2` is floored at `EPS` via [`clip_zero`], so for `t` below about `1e-4`
/// the value saturates near `1e8 * sigma^2`. In that region [`g_inv`] no longer
/// recovers `t` exactly.
#[inline]
pub fn g(t: f32) -> f32 {
    let (a, s) = (alpha(t), sigma(t));
    let numerator = s * s;
    let denominator = clip_zero(a * a);
    numerator / denominator
}
/// Recovers `t` from `g_val = ((1 - t) / t)^2`, i.e. inverts [`g`] for the linear schedule.
///
/// Since `sqrt(g) = (1 - t) / t`, solving for `t` gives `t = 1 / (1 + sqrt(g))`.
///
/// Edge cases:
/// - Negative inputs are clamped to 0 and yield 1.0.
/// - NaN also yields 1.0, because `f32::max` returns the non-NaN operand.
/// - `+inf` yields 0.0.
///
/// NOTE(review): this closed form is hard-coded for `alpha = t`, `sigma = 1 - t` and
/// does not call [`alpha`]/[`sigma`]; it must be revisited if the schedule changes.
#[inline]
pub fn g_inv(g_val: f32) -> f32 {
    let root = g_val.max(0.0).sqrt();
    (1.0 + root).recip()
}
/// Converts a velocity prediction into a denoiser (clean-sample) prediction at time `t`.
///
/// Assumes `x = alpha * z + sigma * eps` (as in [`noise_sample`]) and
/// `velocity = alpha_dot * z + sigma_dot * eps`. Eliminating `eps` gives
/// `z = (sigma * velocity - sigma_dot * x) / W`, where
/// `W = alpha_dot * sigma - alpha * sigma_dot` is the Wronskian of the schedule.
///
/// For the linear schedule `W = 1`. The denominator goes through [`clip_zero`], which
/// would turn a negative `W` into `EPS`, so this is only valid for schedules with `W > 0`.
#[inline]
pub fn denoiser_from_velocity(t: f32, x: f32, velocity: f32) -> f32 {
    let (a, s) = (alpha(t), sigma(t));
    let (a_dot, s_dot) = (alpha_dot(t), sigma_dot(t));
    let wronskian = a_dot * s - a * s_dot;
    let numerator = s * velocity - s_dot * x;
    numerator / clip_zero(wronskian)
}
#[inline]
pub fn velocity_from_denoiser(t: f32, x: f32, denoised: f32) -> f32 {
let s = sigma(t);
let ad = alpha_dot(t);
let sd = sigma_dot(t);
(ad * denoised + sd * x) / clip_zero(s)
}
/// Returns `sigma^2 * alpha_dot / alpha - sigma_dot * sigma` at time `t`, with `alpha`
/// floored at `EPS` via [`clip_zero`].
///
/// For the linear schedule this simplifies to `(1 - t) / t`.
/// Under `x = alpha * z + sigma * eps`, with score `-eps / sigma`, this is the factor `c` in
/// `u = (alpha_dot / alpha) * x + c * score`.
/// NOTE(review): that identity is presumably why it is called the guidance coefficient;
/// confirm against callers.
#[inline]
pub fn guidance_coefficient(t: f32) -> f32 {
    let s = sigma(t);
    let ratio_term = s * s * alpha_dot(t) / clip_zero(alpha(t));
    let sigma_term = sigma_dot(t) * s;
    ratio_term - sigma_term
}
/// Interpolates between the noise `eps` and the clean sample `z`:
/// `x_t = alpha(t) * z + sigma(t) * eps`.
///
/// At `t = 0` this returns `eps`; at `t = 1` it returns `z`.
#[inline]
pub fn noise_sample(z: f32, t: f32, eps: f32) -> f32 {
    let signal = alpha(t) * z;
    let noise = sigma(t) * eps;
    signal + noise
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for the float comparisons added below.
    const TOL: f32 = 1e-5;

    /// Absolute-difference comparison helper.
    fn close(a: f32, b: f32, tol: f32) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn boundaries() {
        assert!((alpha(0.0)).abs() < 1e-6);
        assert!((alpha(1.0) - 1.0).abs() < 1e-6);
        assert!((sigma(0.0) - 1.0).abs() < 1e-6);
        assert!((sigma(1.0)).abs() < 1e-6);
    }

    #[test]
    fn derivatives_match_finite_differences() {
        // Central differences must agree with the analytic derivatives.
        let h = 1e-3f32;
        for &t in &[0.2f32, 0.5, 0.8] {
            let da = (alpha(t + h) - alpha(t - h)) / (2.0 * h);
            let ds = (sigma(t + h) - sigma(t - h)) / (2.0 * h);
            assert!(close(da, alpha_dot(t), 1e-3), "t={t} da={da}");
            assert!(close(ds, sigma_dot(t), 1e-3), "t={t} ds={ds}");
        }
    }

    #[test]
    fn g_inv_roundtrip() {
        for &t in &[0.1f32, 0.3, 0.5, 0.9] {
            let gi = g_inv(g(t));
            assert!((gi - t).abs() < 1e-5, "t={t} gi={gi}");
        }
    }

    #[test]
    fn g_inv_edges() {
        // g vanishes at t = 1, and g_inv maps 0 back to t = 1.
        assert_eq!(g(1.0), 0.0);
        assert_eq!(g_inv(0.0), 1.0);
        // Negative inputs are clamped to 0 before the square root.
        assert_eq!(g_inv(-4.0), 1.0);
        assert_eq!(g_inv(f32::INFINITY), 0.0);
    }

    #[test]
    fn denoiser_velocity_roundtrip() {
        let t = 0.4;
        let x = 0.7;
        let u = 1.2;
        let d = denoiser_from_velocity(t, x, u);
        let u2 = velocity_from_denoiser(t, x, d);
        assert!((u - u2).abs() < 1e-5);
    }

    #[test]
    fn conversions_recover_ground_truth() {
        // Build x_t from a known (z, eps) pair and check that both conversions
        // recover the true clean sample and the true velocity.
        let (z, eps) = (0.3f32, -0.8f32);
        for &t in &[0.1f32, 0.4, 0.6, 0.9] {
            let x = noise_sample(z, t, eps);
            let u = alpha_dot(t) * z + sigma_dot(t) * eps;
            let d = denoiser_from_velocity(t, x, u);
            let v = velocity_from_denoiser(t, x, z);
            assert!(close(d, z, TOL), "t={t} d={d}");
            assert!(close(v, u, TOL), "t={t} v={v} u={u}");
        }
    }

    #[test]
    fn guidance_coefficient_closed_form() {
        // For alpha = t, sigma = 1 - t: sigma^2 / alpha + sigma = (1 - t) / t.
        for &t in &[0.2f32, 0.5, 0.8] {
            let expected = (1.0 - t) / t;
            let got = guidance_coefficient(t);
            assert!(close(got, expected, TOL), "t={t} got={got} expected={expected}");
        }
    }
}