use anyhow::Result;
use candle_core::{DType, IndexOp, Tensor};
use std::time::Instant;
use super::transformer::SD3Transformer;
use crate::engine::cfg_active;
use crate::img_utils;
use crate::progress::{ProgressEvent, ProgressReporter};
/// Skip-layer guidance (SLG) settings consumed by [`euler_sample`].
///
/// On steps where SLG is active, the sampler runs an extra single-row conditional
/// forward pass with `layers` handed to the transformer as its skip list. It then adds
/// `scale * (cond_prediction - skipped_prediction)` to the guidance velocity.
#[derive(Debug, Clone, PartialEq)]
pub struct SkipLayerGuidanceConfig {
    /// Weight applied to the SLG correction term.
    pub scale: f64,
    /// Lower bound of the active window, as a fraction of the total step count (exclusive).
    pub start: f64,
    /// Upper bound of the active window, as a fraction of the total step count (exclusive).
    pub end: f64,
    /// Layer indices passed to `SD3Transformer::forward` as the skip list.
    pub layers: Vec<usize>,
}
/// Prints the minimum and maximum of `tensor` to stderr, tagged with `name`.
///
/// Does nothing unless the `MOLD_SD3_DEBUG` environment variable is set. It is purely
/// diagnostic: if computing the statistics fails, the failure is printed instead of
/// being returned to the caller.
fn debug_tensor_stats(name: &str, tensor: &Tensor) {
    // Converts a reduced (single-value) tensor to an `f32`, whatever its dtype.
    fn read_f32(reduced: Tensor) -> Result<f32> {
        Ok(reduced.to_dtype(DType::F32)?.to_scalar::<f32>()?)
    }

    if std::env::var_os("MOLD_SD3_DEBUG").is_some() {
        // The minimum is reduced before the maximum; the first failure short-circuits.
        let extrema = (|| -> Result<(f32, f32)> {
            let lowest = read_f32(tensor.min_all()?)?;
            let highest = read_f32(tensor.max_all()?)?;
            Ok((lowest, highest))
        })();
        match extrema {
            Ok((min, max)) => eprintln!("[sd3-debug] {name}: min={min:.4} max={max:.4}"),
            Err(err) => eprintln!("[sd3-debug] {name}: <failed: {err}>"),
        }
    }
}
/// Runs the SD3 Euler sampling loop and returns the final latents.
///
/// Each adjacent pair `(s_curr, s_prev)` of the sigma schedule is one denoising step, so
/// `sigmas.len() - 1` steps run. `s_prev` is the sigma being stepped *to*, the next entry
/// in the schedule.
///
/// # Arguments
/// * `mmdit` - transformer that produces the velocity prediction at every step.
/// * `y`, `context` - conditioning tensors.
///   - When CFG is active, they are passed as-is alongside a doubled latent batch, and
///     row 0 / row 1 of the prediction are combined by `apply_cfg`.
///   - Otherwise only their first row (`.i(..1)`) is used.
///   - NOTE(review): this assumes row 0 is the conditional prompt and row 1 the
///     unconditional one; confirm against the caller.
/// * `num_inference_steps` - step count for the default schedule. It is ignored when
///   `sigmas_override` is given.
/// * `cfg_scale` - guidance scale. When `cfg_active(cfg_scale)` is false, a single
///   conditional forward pass is used per step.
/// * `cfg_plus` - use the CFG++ update (`cfg_plus_step`). It only takes effect when CFG
///   is active; otherwise a warning is logged and the standard step is used.
/// * `time_shift` - `alpha` passed to [`time_snr_shift`] when building the default schedule.
/// * `height`, `width` - requested pixel size; the latent grid is `dim / 16 * 2`.
/// * `slg_config` - optional skip-layer guidance. It is applied on steps strictly inside
///   `(total_steps * start, total_steps * end)`.
/// * `is_quantized` - selects F32 (true) or F16 (false) for freshly generated noise.
/// * `seed` - seed for the initial noise when `initial_latents` is `None`.
/// * `progress` - receives one `ProgressEvent::DenoiseStep` per completed step.
/// * `initial_latents` - used as the starting latents instead of seeded noise when provided.
/// * `sigmas_override` - explicit schedule that replaces the default one.
/// * `inpaint_ctx` - when set, `img2img::apply_flow_match_inpaint` is applied to the
///   latents after every step with that step's target sigma.
///
/// # Errors
/// Propagates failures from noise generation, tensor ops, the transformer forward pass
/// or the inpaint blend.
// Every sampling option is a separate positional parameter, which trips this lint.
#[allow(clippy::too_many_arguments)]
pub fn euler_sample(
    mmdit: &SD3Transformer,
    y: &Tensor,
    context: &Tensor,
    num_inference_steps: usize,
    cfg_scale: f64,
    cfg_plus: bool,
    time_shift: f64,
    height: usize,
    width: usize,
    slg_config: Option<&SkipLayerGuidanceConfig>,
    is_quantized: bool,
    seed: u64,
    progress: &ProgressReporter,
    initial_latents: Option<&Tensor>,
    sigmas_override: Option<Vec<f64>>,
    inpaint_ctx: Option<&img_utils::InpaintContext>,
) -> Result<Tensor> {
    // Fresh noise is F32 for quantized models and F16 otherwise.
    let noise_dtype = if is_quantized { DType::F32 } else { DType::F16 };
    // `dim / 16 * 2` equals `dim / 8` rounded down to an even number.
    // NOTE(review): presumably 8x VAE downsampling plus an even grid for 2x2 patching; confirm.
    let latent_h = height / 16 * 2;
    let latent_w = width / 16 * 2;
    // Start from the caller's latents if given; otherwise draw seeded noise shaped
    // [1, 16, latent_h, latent_w].
    let mut x = if let Some(latents) = initial_latents {
        latents.clone()
    } else {
        crate::engine::seeded_randn(seed, &[1, 16, latent_h, latent_w], y.device(), noise_dtype)?
    };
    // Default schedule: `num_inference_steps + 1` evenly spaced t values from 1.0 down to
    // 0.0, each warped by `time_snr_shift`.
    let sigmas = sigmas_override.unwrap_or_else(|| {
        (0..=num_inference_steps)
            .map(|s| s as f64 / num_inference_steps as f64)
            .rev()
            .map(|t| time_snr_shift(time_shift, t))
            .collect()
    });
    let total_steps = sigmas.len().saturating_sub(1);
    let use_cfg = cfg_active(cfg_scale);
    // Without CFG, slice the conditional row of `y` / `context` once, up front, so each
    // step's forward pass runs on a batch of one.
    let (y_cond_only, context_cond_only) = if use_cfg {
        (None, None)
    } else {
        (Some(y.i(..1)?), Some(context.i(..1)?))
    };
    // CFG++ needs the unconditional prediction, which only exists when CFG is active.
    let cfg_plus_active = cfg_plus && use_cfg;
    if cfg_plus && !use_cfg {
        tracing::warn!(
            cfg_scale,
            "cfg_plus requested but cfg_scale ≈ 1.0 — falling back to standard step (no uncond available)"
        );
    }
    for (step, window) in sigmas.windows(2).enumerate() {
        let step_start = Instant::now();
        // `windows(2)` only yields two-element slices; the `_` arm just satisfies
        // slice-pattern exhaustiveness.
        let (s_curr, s_prev) = match window {
            [a, b] => (a, b),
            _ => continue,
        };
        // The transformer receives the current sigma scaled by 1000 as its timestep.
        let timestep = (*s_curr) * 1000.0;
        // `guidance` is the velocity used for the update. `noise_pred_full` keeps the raw
        // model output (two rows under CFG, one row otherwise) for SLG and CFG++.
        let (mut guidance, noise_pred_full) = if use_cfg {
            // Duplicate the latents so one forward pass over the batched `y` / `context`
            // yields both rows of the prediction.
            let noise_pred = mmdit.forward(
                &Tensor::cat(&[&x, &x], 0)?,
                &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
                y,
                context,
                None,
            )?;
            if step == 0 {
                debug_tensor_stats("noise_pred", &noise_pred);
            }
            let g = apply_cfg(cfg_scale, &noise_pred)?;
            if step == 0 {
                debug_tensor_stats("guidance", &g);
            }
            (g, noise_pred)
        } else {
            let noise_pred = mmdit.forward(
                &x,
                &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
                // Both slices were populated above because `use_cfg` is false here.
                y_cond_only.as_ref().expect("cfg-disabled cond slice"),
                context_cond_only.as_ref().expect("cfg-disabled cond slice"),
                None,
            )?;
            if step == 0 {
                debug_tensor_stats("noise_pred (cfg=1)", &noise_pred);
            }
            // Without CFG, the raw prediction is the guidance.
            (noise_pred.clone(), noise_pred)
        };
        // Skip-layer guidance: on steps strictly inside the configured fraction window,
        // run an extra single-row forward pass with `slg_config.layers` as the skip list,
        // then add `scale * (cond_row - skipped_prediction)` to the guidance.
        if let Some(slg_config) = slg_config {
            if (total_steps as f64) * slg_config.start < (step as f64)
                && (step as f64) < (total_steps as f64) * slg_config.end
            {
                let slg_noise_pred = mmdit.forward(
                    &x,
                    &Tensor::full(timestep as f32, (1,), x.device())?.contiguous()?,
                    &y.i(..1)?,
                    &context.i(..1)?,
                    Some(&slg_config.layers),
                )?;
                guidance = (guidance
                    + (slg_config.scale * (noise_pred_full.i(..1)? - slg_noise_pred.i(..1))?)?)?;
            }
        }
        // CFG++ (`cfg_plus_step`) steps from the guided x0 estimate using the unconditional
        // velocity. The standard path is the plain Euler step
        // `x + (s_prev - s_curr) * guidance`.
        x = if cfg_plus_active {
            cfg_plus_step(&x, &guidance, &noise_pred_full, *s_curr, *s_prev)?
        } else {
            (x + (&guidance * (*s_prev - *s_curr))?)?
        };
        // Inpainting: let `apply_flow_match_inpaint` adjust the latents at the new sigma.
        // NOTE(review): presumably this restores the unmasked region; confirm in img2img.
        if let Some(ctx) = inpaint_ctx {
            x = crate::img2img::apply_flow_match_inpaint(&x, ctx, *s_prev)?;
        }
        if step + 1 == total_steps {
            debug_tensor_stats("latents_final", &x);
        }
        // `step` is zero-based, so the reported step number is 1-based.
        progress.emit(ProgressEvent::DenoiseStep {
            step: step + 1,
            total: total_steps,
            elapsed: step_start.elapsed(),
        });
    }
    Ok(x)
}
/// Warps a timestep `t` with shift factor `alpha`: `alpha * t / (1 + (alpha - 1) * t)`.
///
/// Properties:
/// - `alpha == 1.0` is the identity.
/// - `t = 0.0` and `t = 1.0` map to themselves for any non-zero `alpha`.
/// - With `alpha > 1.0`, every `t` strictly between 0 and 1 is mapped upward.
pub fn time_snr_shift(alpha: f64, t: f64) -> f64 {
    let numerator = alpha * t;
    let denominator = 1.0 + (alpha - 1.0) * t;
    numerator / denominator
}
/// Combines a stacked `[cond, uncond]` prediction into one guided prediction.
///
/// Row 0 of `noise_pred` is taken as the conditional prediction and row 1 as the
/// unconditional one. The result is `cfg_scale * cond - (cfg_scale - 1) * uncond`,
/// which is `uncond + cfg_scale * (cond - uncond)`. It has a leading dimension of 1.
///
/// # Errors
/// Fails if `noise_pred` has fewer than two rows along dim 0, or if a tensor op fails.
fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
    let cond_row = noise_pred.narrow(0, 0, 1)?;
    let weighted_cond = (cfg_scale * cond_row)?;
    let uncond_row = noise_pred.narrow(0, 1, 1)?;
    let weighted_uncond = ((cfg_scale - 1.0) * uncond_row)?;
    Ok((weighted_cond - weighted_uncond)?)
}
/// Performs one CFG++ update from sigma `s_curr` to sigma `s_prev`.
///
/// It first estimates the clean latent with the guided velocity, `x0 = x - s_curr * guidance`.
/// It then moves to `s_prev` with the unconditional velocity (row 1 of `noise_pred_full`):
/// `x0 + s_prev * v_uncond`. When `v_cond == v_uncond`, this equals the plain Euler step
/// `x + (s_prev - s_curr) * guidance`.
///
/// # Errors
/// Fails if `noise_pred_full` has fewer than two rows along dim 0, or if a tensor op fails.
fn cfg_plus_step(
    x: &Tensor,
    guidance: &Tensor,
    noise_pred_full: &Tensor,
    s_curr: f64,
    s_prev: f64,
) -> Result<Tensor> {
    let v_uncond = noise_pred_full.narrow(0, 1, 1)?;
    let guided_offset = (guidance * s_curr)?;
    let denoised = (x - guided_offset)?;
    let uncond_offset = (v_uncond * s_prev)?;
    Ok((denoised + uncond_offset)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Flattens a tensor into a `Vec<f32>` for element-wise comparisons.
    fn to_f32_vec(tensor: &Tensor) -> Vec<f32> {
        tensor.flatten_all().unwrap().to_vec1().unwrap()
    }

    /// Asserts that `got` and `expected` have equal length and agree element-wise
    /// within `tol`.
    ///
    /// The length check is essential. `Iterator::zip` stops at the shorter side, so a
    /// truncated or empty result would otherwise pass every element comparison vacuously.
    fn assert_all_close(got: &[f32], expected: &[f32], tol: f32, what: &str) {
        assert_eq!(
            got.len(),
            expected.len(),
            "{what}: length mismatch (got {} values, expected {})",
            got.len(),
            expected.len()
        );
        for (idx, (g, e)) in got.iter().zip(expected).enumerate() {
            assert!(
                (g - e).abs() < tol,
                "{what}: element {idx}: got {g}, expected {e}"
            );
        }
    }

    #[test]
    fn test_time_snr_shift_alpha_1() {
        for i in 0..=100 {
            let t = i as f64 / 100.0;
            let shifted = time_snr_shift(1.0, t);
            assert!(
                (shifted - t).abs() < 1e-12,
                "alpha=1 should be identity: time_snr_shift(1.0, {t}) = {shifted}, expected {t}"
            );
        }
    }

    #[test]
    fn test_time_snr_shift_boundaries() {
        for alpha in [0.5, 1.0, 3.0, 10.0, 100.0] {
            let at_zero = time_snr_shift(alpha, 0.0);
            let at_one = time_snr_shift(alpha, 1.0);
            assert!(
                at_zero.abs() < 1e-12,
                "time_snr_shift({alpha}, 0.0) = {at_zero}, expected 0.0"
            );
            assert!(
                (at_one - 1.0).abs() < 1e-12,
                "time_snr_shift({alpha}, 1.0) = {at_one}, expected 1.0"
            );
        }
    }

    #[test]
    fn test_time_snr_shift_midpoint() {
        let result = time_snr_shift(3.0, 0.5);
        assert!(
            (result - 0.75).abs() < 1e-12,
            "time_snr_shift(3.0, 0.5) = {result}, expected 0.75"
        );
    }

    #[test]
    fn test_time_snr_shift_monotonic() {
        let alpha = 3.0;
        let mut prev = time_snr_shift(alpha, 0.0);
        for i in 1..=100 {
            let t = i as f64 / 100.0;
            let curr = time_snr_shift(alpha, t);
            assert!(
                curr >= prev - 1e-12,
                "non-monotonic at t={t}: {curr} < {prev}"
            );
            prev = curr;
        }
    }

    #[test]
    fn test_apply_cfg_scale_1() {
        let dev = Device::Cpu;
        let cond = Tensor::new(&[[1.0f32, 2.0, 3.0]], &dev).unwrap();
        let uncond = Tensor::new(&[[10.0f32, 20.0, 30.0]], &dev).unwrap();
        let noise_pred = Tensor::cat(&[&cond, &uncond], 0).unwrap();
        let result = apply_cfg(1.0, &noise_pred).unwrap();
        assert_all_close(
            &to_f32_vec(&result),
            &to_f32_vec(&cond),
            1e-6,
            "cfg=1 should return cond",
        );
    }

    #[test]
    fn test_apply_cfg_scale_7_5() {
        let dev = Device::Cpu;
        let cond = Tensor::new(&[[2.0f32, 4.0]], &dev).unwrap();
        let uncond = Tensor::new(&[[1.0f32, 1.0]], &dev).unwrap();
        let noise_pred = Tensor::cat(&[&cond, &uncond], 0).unwrap();
        let result = apply_cfg(7.5, &noise_pred).unwrap();
        let expected: [f32; 2] = [8.5, 23.5];
        assert_all_close(&to_f32_vec(&result), &expected, 1e-4, "cfg=7.5 mismatch");
    }

    #[test]
    fn test_sigma_schedule_endpoints() {
        let num_steps = 28;
        let alpha = 3.0;
        let sigmas: Vec<f64> = (0..=num_steps)
            .map(|s| s as f64 / num_steps as f64)
            .rev()
            .map(|t| time_snr_shift(alpha, t))
            .collect();
        assert_eq!(
            sigmas.len(),
            num_steps + 1,
            "schedule length should be steps+1"
        );
        assert!(
            (sigmas[0] - 1.0).abs() < 1e-12,
            "first sigma should be 1.0, got {}",
            sigmas[0]
        );
        assert!(
            sigmas[sigmas.len() - 1].abs() < 1e-12,
            "last sigma should be 0.0, got {}",
            sigmas[sigmas.len() - 1]
        );
    }

    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0));
    }

    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        assert!(!cfg_active(1.0 - 1e-5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5));
    }

    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5));
    }

    /// Builds a stacked `[cond, uncond]` prediction plus its CFG-combined guidance.
    fn toy_noise_pair(cond: &[f32], uncond: &[f32], cfg_scale: f64) -> (Tensor, Tensor) {
        assert_eq!(cond.len(), uncond.len(), "cond/uncond shapes must match");
        let n = cond.len();
        let dev = Device::Cpu;
        let cond_t = Tensor::from_slice(cond, (1, n), &dev).unwrap();
        let uncond_t = Tensor::from_slice(uncond, (1, n), &dev).unwrap();
        let noise_pred = Tensor::cat(&[&cond_t, &uncond_t], 0).unwrap();
        let guidance = apply_cfg(cfg_scale, &noise_pred).unwrap();
        (noise_pred, guidance)
    }

    #[test]
    fn cfg_plus_step_matches_manifold_formula() {
        let dev = Device::Cpu;
        let x = Tensor::new(&[[10.0f32, 20.0, 30.0]], &dev).unwrap();
        let (noise_pred, guidance) = toy_noise_pair(&[2.0, 4.0, 6.0], &[1.0, 1.0, 1.0], 7.5);
        let s_curr = 0.8;
        let s_prev = 0.6;
        let result = cfg_plus_step(&x, &guidance, &noise_pred, s_curr, s_prev).unwrap();
        // Guidance at cfg=7.5 is [8.5, 23.5, 38.5]; v_uncond is all ones.
        let expected: [f32; 3] = [
            10.0 - 0.8 * 8.5 + 0.6 * 1.0,
            20.0 - 0.8 * 23.5 + 0.6 * 1.0,
            30.0 - 0.8 * 38.5 + 0.6 * 1.0,
        ];
        assert_all_close(&to_f32_vec(&result), &expected, 1e-4, "cfg++ step mismatch");
    }

    #[test]
    fn cfg_plus_step_collapses_to_standard_when_cond_eq_uncond() {
        let dev = Device::Cpu;
        let x = Tensor::new(&[[5.0f32, 7.0]], &dev).unwrap();
        let (noise_pred, guidance) = toy_noise_pair(&[3.0, 4.0], &[3.0, 4.0], 7.5);
        let s_curr = 0.5;
        let s_prev = 0.25;
        let cfg_pp = cfg_plus_step(&x, &guidance, &noise_pred, s_curr, s_prev).unwrap();
        let standard = (&x + (&guidance * (s_prev - s_curr)).unwrap()).unwrap();
        assert_all_close(
            &to_f32_vec(&cfg_pp),
            &to_f32_vec(&standard),
            1e-5,
            "cfg++ ≠ standard when v_cond=v_uncond",
        );
    }

    #[test]
    fn cfg_plus_step_diverges_from_standard_under_high_cfg() {
        let dev = Device::Cpu;
        let x = Tensor::new(&[[0.0f32, 0.0, 0.0]], &dev).unwrap();
        let (noise_pred, guidance) = toy_noise_pair(&[2.0, 4.0, 6.0], &[1.0, 1.0, 1.0], 7.5);
        let s_curr = 0.9;
        let s_prev = 0.7;
        let cfg_pp = cfg_plus_step(&x, &guidance, &noise_pred, s_curr, s_prev).unwrap();
        let standard = (&x + (&guidance * (s_prev - s_curr)).unwrap()).unwrap();
        let cfg_pp_vec = to_f32_vec(&cfg_pp);
        let std_vec = to_f32_vec(&standard);
        // Without this, a shorter output would silently shrink the comparison below.
        assert_eq!(
            cfg_pp_vec.len(),
            std_vec.len(),
            "cfg++ and standard steps produced different element counts"
        );
        let max_diff = cfg_pp_vec
            .iter()
            .zip(std_vec.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_diff > 0.1,
            "cfg++ should differ from standard at cfg=7.5 with v_cond≠v_uncond, max_diff={max_diff}"
        );
    }
}