use crate::interpolant;
/// Adds a guidance term to a scalar base velocity.
///
/// Returns `base_velocity + c * grad_v`, where `c` is the value returned by
/// `interpolant::guidance_coefficient(t_paper)`.
pub fn guided_velocity(base_velocity: f32, t_paper: f32, grad_v: f32) -> f32 {
    let coeff = interpolant::guidance_coefficient(t_paper);
    base_velocity + coeff * grad_v
}
/// Advances a scalar state by one explicit Euler step.
///
/// Returns `x + dt * velocity`.
pub fn euler_step(x: f32, dt: f32, velocity: f32) -> f32 {
    let increment = dt * velocity;
    x + increment
}
/// Applies an explicit Euler step element-wise, in place: `x[i] += dt * velocity[i]`.
///
/// # Panics
///
/// Panics if `x` and `velocity` have different lengths.
pub fn euler_step_vec(x: &mut [f32], dt: f32, velocity: &[f32]) {
    assert_eq!(x.len(), velocity.len());
    x.iter_mut()
        .zip(velocity)
        .for_each(|(state, vel)| *state += dt * *vel);
}
pub fn flux_guided_euler_step(
latents: &mut [f32],
base_velocity: &[f32],
grad_v: &[f32],
sigma: f32,
sigma_next: f32,
max_b: f32,
) {
assert_eq!(latents.len(), base_velocity.len());
assert_eq!(latents.len(), grad_v.len());
let t_paper = crate::flux_time::flux_sigma_to_paper_t(sigma);
let b = crate::flux_time::flux_guidance_coefficient(sigma, max_b);
let dt = sigma_next - sigma;
let _ = t_paper;
for ((l, &u), &g) in latents
.iter_mut()
.zip(base_velocity.iter())
.zip(grad_v.iter())
{
*l += dt * (u + b * g);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One step from x = 1.0 with dt = 0.1 and velocity 2.0 should land on 1.2.
    #[test]
    fn euler_increments() {
        let stepped = euler_step(1.0, 0.1, 2.0);
        let error = (stepped - 1.2).abs();
        assert!(error < 1e-6);
    }
}