use crate::error::{NeuralError, Result};
/// Strategy for the intermediate point of higher-order solver steps.
///
/// NOTE(review): currently carried as configuration only — the solver code
/// visible in this file does not branch on this value; confirm intended use
/// before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SolverType {
    /// Midpoint rule (the conventional DPM-Solver++ default).
    Midpoint,
    /// Heun's (trapezoidal) rule.
    Heun,
}
impl Default for SolverType {
fn default() -> Self {
Self::Midpoint
}
}
/// Configuration for the DPM-Solver++ sampler.
#[derive(Debug, Clone)]
pub struct DpmSolverConfig {
    /// Number of sampling steps (the time grid has `n_steps + 1` points).
    pub n_steps: usize,
    /// Solver order; `validate` accepts 1..=3, though the sampler only
    /// implements orders 1 and 2 (higher orders fall back to 1st order).
    pub order: usize,
    /// Intermediate-point strategy for higher-order steps.
    pub solver_type: SolverType,
    /// Enable Imagen-style dynamic thresholding of the data prediction.
    pub thresholding: bool,
    /// Percentile (0..1) of |x| used as the dynamic-thresholding quantile.
    pub dynamic_thresholding_ratio: f64,
    /// Smallest noise level targeted by the time grid.
    pub sigma_min: f64,
    /// Largest noise level targeted by the time grid.
    /// NOTE(review): the cosine schedule below only reaches sigma < 1, so
    /// EDM-style values like 80.0 saturate near t = 1 — confirm intent.
    pub sigma_max: f64,
}
impl Default for DpmSolverConfig {
fn default() -> Self {
Self {
n_steps: 20,
order: 2,
solver_type: SolverType::Midpoint,
thresholding: false,
dynamic_thresholding_ratio: 0.995,
sigma_min: 0.002,
sigma_max: 80.0,
}
}
}
impl DpmSolverConfig {
    /// Minimal configuration for quick tests: 5 first-order steps.
    pub fn fast_test() -> Self {
        Self {
            n_steps: 5,
            order: 1,
            ..Self::default()
        }
    }

    /// Second-order configuration with a caller-chosen step count.
    pub fn second_order(n_steps: usize) -> Self {
        Self {
            order: 2,
            n_steps,
            ..Self::default()
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    /// Returns `InvalidArgument` when `n_steps` is zero, `order` is outside
    /// 1..=3, or the sigma bounds do not satisfy 0 < sigma_min < sigma_max.
    pub fn validate(&self) -> Result<()> {
        if self.n_steps == 0 {
            return Err(NeuralError::InvalidArgument(
                "DpmSolverConfig: n_steps must be > 0".to_string(),
            ));
        }
        if !(1..=3).contains(&self.order) {
            return Err(NeuralError::InvalidArgument(format!(
                "DpmSolverConfig: order must be 1, 2, or 3; got {}",
                self.order
            )));
        }
        if self.sigma_min <= 0.0 || self.sigma_max <= self.sigma_min {
            return Err(NeuralError::InvalidArgument(format!(
                "DpmSolverConfig: require 0 < sigma_min ({}) < sigma_max ({})",
                self.sigma_min, self.sigma_max
            )));
        }
        Ok(())
    }
}
/// Cosine noise schedule (Nichol & Dhariwal style): ᾱ(t) ∝ cos²(π/2·(t+s)/(1+s)),
/// normalized so that ᾱ(0) = 1.
#[derive(Debug, Clone)]
pub struct DpmSchedule {
    /// Small offset `s` that keeps the schedule well-behaved near t = 0.
    s: f64,
    /// Raw ᾱ at t = 0, used as the normalizer so `alpha_bar(0) == 1`.
    alpha_bar_0: f64,
}
impl Default for DpmSchedule {
fn default() -> Self {
Self::new()
}
}
impl DpmSchedule {
    /// Builds the schedule with the standard cosine offset s = 0.008 and
    /// precomputes the t = 0 normalizer.
    pub fn new() -> Self {
        let s = 0.008_f64;
        Self {
            s,
            alpha_bar_0: Self::alpha_bar_raw(0.0, s),
        }
    }

    /// Unnormalized cosine schedule: cos²(π/2 · (t + s)/(1 + s)).
    fn alpha_bar_raw(t: f64, s: f64) -> f64 {
        let phase = std::f64::consts::FRAC_PI_2 * (t + s) / (1.0 + s);
        phase.cos().powi(2)
    }

    /// Normalized ᾱ(t), with t clamped into [0, 1 − 1e-6] and the result
    /// clamped into [1e-12, 1] to avoid degenerate values at the endpoints.
    pub fn alpha_bar(&self, t: f64) -> f64 {
        let t = t.clamp(0.0, 1.0 - 1e-6);
        let normalized = Self::alpha_bar_raw(t, self.s) / self.alpha_bar_0;
        normalized.clamp(1e-12, 1.0)
    }

    /// Signal scale α(t) = √ᾱ(t).
    pub fn alpha_t(&self, t: f64) -> f64 {
        self.alpha_bar(t).sqrt()
    }

    /// Noise scale σ(t) = √(1 − ᾱ(t)).
    pub fn sigma_t(&self, t: f64) -> f64 {
        (1.0 - self.alpha_bar(t)).max(0.0).sqrt()
    }

    /// Half log-SNR λ(t) = ln(α(t)/σ(t)); strictly decreasing in t.
    pub fn lambda_t(&self, t: f64) -> f64 {
        let alpha = self.alpha_t(t).max(1e-12);
        let sigma = self.sigma_t(t).max(1e-12);
        (alpha / sigma).ln()
    }

    /// Inverts λ(t) by 64 rounds of bisection on [0, 1 − 1e-7].
    /// Values outside the attainable λ range clamp to the nearest endpoint.
    pub fn t_from_lambda(&self, lambda: f64) -> f64 {
        let mut lo = 0.0_f64;
        let mut hi = 1.0_f64 - 1e-7;
        // λ is decreasing: the largest λ lives at lo, the smallest at hi.
        if lambda >= self.lambda_t(lo) {
            return lo;
        }
        if lambda <= self.lambda_t(hi) {
            return hi;
        }
        for _ in 0..64 {
            let mid = 0.5 * (lo + hi);
            // λ(mid) above the target means the answer sits at larger t.
            if self.lambda_t(mid) > lambda {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        0.5 * (lo + hi)
    }
}
/// DPM-Solver++ sampler (data-prediction parameterization) over the cosine
/// variance-preserving schedule.
#[derive(Debug, Clone)]
pub struct DpmSolverPlusPlus {
    /// Validated solver configuration.
    pub config: DpmSolverConfig,
    /// Noise schedule supplying α(t), σ(t), λ(t).
    pub schedule: DpmSchedule,
}
impl DpmSolverPlusPlus {
    /// Creates a solver for `config`, backed by the cosine noise schedule.
    ///
    /// # Errors
    /// Propagates `InvalidArgument` from [`DpmSolverConfig::validate`].
    pub fn new(config: DpmSolverConfig) -> Result<Self> {
        config.validate()?;
        let schedule = DpmSchedule::new();
        Ok(Self { config, schedule })
    }

    /// Returns `n_steps + 1` time points ordered from high noise to low
    /// noise, uniformly spaced in λ (half log-SNR) — the discretization
    /// recommended for DPM-Solver. `validate()` guarantees `n_steps >= 1`,
    /// so the division by `n` below is safe.
    pub fn timestep_schedule(&self) -> Vec<f64> {
        let n = self.config.n_steps;
        let t_start = self.sigma_to_t(self.config.sigma_max);
        let t_end = self.sigma_to_t(self.config.sigma_min);
        let lambda_start = self.schedule.lambda_t(t_start);
        let lambda_end = self.schedule.lambda_t(t_end);
        (0..=n)
            .map(|i| {
                let frac = i as f64 / n as f64;
                let lam = lambda_start + frac * (lambda_end - lambda_start);
                self.schedule.t_from_lambda(lam)
            })
            .collect()
    }

    /// Inverts σ(t) by 64 rounds of bisection on [0, 1 − 1e-7].
    ///
    /// NOTE(review): this VP schedule only reaches σ ∈ [0, 1), so the
    /// requested sigma is clamped into [1e-6, 1 − 1e-6]; EDM-style values
    /// such as sigma_max = 80 saturate near t = 1.
    fn sigma_to_t(&self, sigma: f64) -> f64 {
        let target = sigma.clamp(1e-6, 1.0 - 1e-6);
        let mut lo = 0.0_f64;
        let mut hi = 1.0_f64 - 1e-7;
        for _ in 0..64 {
            let mid = 0.5 * (lo + hi);
            // σ(t) is increasing in t.
            if self.schedule.sigma_t(mid) < target {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        0.5 * (lo + hi)
    }

    /// First-order DPM-Solver++ step with data (x₀) prediction `d0`:
    ///
    ///   x_to = (σ_to/σ_from)·x − α_to·(e^{−h} − 1)·D0,  h = λ_to − λ_from.
    ///
    /// During sampling h > 0, hence −(e^{−h} − 1) = 1 − e^{−h} > 0 and the
    /// data term must be ADDED. (Bug fix: the term was previously
    /// subtracted, which drives the iterate toward −D0 instead of D0 as
    /// σ_to → 0 — with a perfect model, only the `+` sign reproduces
    /// x_t = α_t·x₀ + σ_t·ε.)
    pub fn first_order_update(
        &self,
        x_t: &[f64],
        t_from: f64,
        t_to: f64,
        d0: &[f64],
    ) -> Vec<f64> {
        let alpha_to = self.schedule.alpha_t(t_to);
        let sigma_from = self.schedule.sigma_t(t_from).max(1e-12);
        let sigma_to = self.schedule.sigma_t(t_to).max(1e-12);
        let h = self.schedule.lambda_t(t_to) - self.schedule.lambda_t(t_from);
        let sigma_ratio = sigma_to / sigma_from;
        let coeff_d0 = alpha_to * (1.0 - (-h).exp());
        x_t.iter()
            .zip(d0.iter())
            .map(|(&x, &d)| sigma_ratio * x + coeff_d0 * d)
            .collect()
    }

    /// Second-order (multistep) DPM-Solver++ step.
    ///
    /// `d0` is the data prediction at `t_from`, `d1` the prediction from the
    /// previous step, and `r` the relative size of the previous λ-step. The
    /// first-order term carries the same sign fix as
    /// [`Self::first_order_update`]; the correction coefficient
    /// (1 − e^{−h} − h·e^{−h}) ≈ h²/2 for small h extrapolates along the
    /// backward difference (d0 − d1).
    pub fn second_order_update(
        &self,
        x_t: &[f64],
        t_from: f64,
        t_to: f64,
        d0: &[f64],
        d1: &[f64],
        r: f64,
    ) -> Vec<f64> {
        let alpha_to = self.schedule.alpha_t(t_to);
        let sigma_from = self.schedule.sigma_t(t_from).max(1e-12);
        let sigma_to = self.schedule.sigma_t(t_to).max(1e-12);
        let h = self.schedule.lambda_t(t_to) - self.schedule.lambda_t(t_from);
        let sigma_ratio = sigma_to / sigma_from;
        let eh = (-h).exp();
        let coeff_d0 = alpha_to * (1.0 - eh);
        let correction = 1.0 - eh - h * eh;
        // Keep the correction bounded when the previous step was tiny.
        let r_safe = r.abs().max(1e-8);
        let coeff_2nd = alpha_to * correction / r_safe;
        x_t.iter()
            .zip(d0.iter())
            .zip(d1.iter())
            .map(|((&x, &d_curr), &d_prev)| {
                sigma_ratio * x + coeff_d0 * d_curr + coeff_2nd * (d_curr - d_prev)
            })
            .collect()
    }

    /// Dynamic thresholding (Imagen-style): divides `x` by the p-th
    /// percentile of |x| (floored at 1.0 so already-small signals pass
    /// through) and clamps into [−1, 1].
    fn dynamic_threshold(&self, x: &[f64]) -> Vec<f64> {
        if x.is_empty() {
            return vec![];
        }
        let p = self.config.dynamic_thresholding_ratio;
        let mut abs_vals: Vec<f64> = x.iter().map(|v| v.abs()).collect();
        // NaNs (if any) compare Equal; fine for a percentile estimate.
        abs_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        // Index is usize, so the former `.max(0)` was a no-op and is removed.
        let idx = ((abs_vals.len() as f64 * p).ceil() as usize).min(abs_vals.len() - 1);
        let quantile = abs_vals[idx].max(1.0);
        x.iter().map(|&v| (v / quantile).clamp(-1.0, 1.0)).collect()
    }

    /// Runs the sampler from noise `x_t` along the λ-uniform time grid.
    ///
    /// `model_fn(x, t)` must return the data (x₀) prediction with the same
    /// length as `x`.
    ///
    /// # Errors
    /// - `InvalidArgument` if `x_t` is empty or the grid has fewer than 2 points.
    /// - `ShapeMismatch` if `model_fn` returns the wrong length.
    pub fn sample<F>(&self, x_t: &[f64], model_fn: F) -> Result<Vec<f64>>
    where
        F: Fn(&[f64], f64) -> Vec<f64>,
    {
        let d = x_t.len();
        if d == 0 {
            return Err(NeuralError::InvalidArgument(
                "DpmSolverPlusPlus sample: x_T must be non-empty".to_string(),
            ));
        }
        let timesteps = self.timestep_schedule();
        if timesteps.len() < 2 {
            return Err(NeuralError::InvalidArgument(
                "DpmSolverPlusPlus sample: need at least 2 timesteps".to_string(),
            ));
        }
        let mut x = x_t.to_vec();
        // Previous step's data prediction and λ-step size (for 2nd order).
        // The unused previous time — formerly a `let _ = t_prev;` hack —
        // is no longer stored.
        let mut prev: Option<(Vec<f64>, f64)> = None;
        for window in timesteps.windows(2) {
            let (t_from, t_to) = (window[0], window[1]);
            let d0 = model_fn(&x, t_from);
            if d0.len() != d {
                return Err(NeuralError::ShapeMismatch(format!(
                    "DpmSolverPlusPlus: model_fn returned {} values, expected {d}",
                    d0.len()
                )));
            }
            let d0 = if self.config.thresholding {
                self.dynamic_threshold(&d0)
            } else {
                d0
            };
            // λ-step size, computed once (it was duplicated before).
            let h_curr = (self.schedule.lambda_t(t_to) - self.schedule.lambda_t(t_from)).abs();
            x = match (self.config.order, &prev) {
                // First-order config, or no history yet on the first step.
                (1, _) | (_, None) => self.first_order_update(&x, t_from, t_to, &d0),
                (2, Some((d1, h_prev))) => {
                    // Fraction of the combined λ-span covered by the previous step.
                    let r = *h_prev / (*h_prev + h_curr).max(1e-12);
                    self.second_order_update(&x, t_from, t_to, &d0, d1, r)
                }
                // Orders > 2 pass validate() but are not implemented; fall
                // back to a first-order step.
                _ => self.first_order_update(&x, t_from, t_to, &d0),
            };
            prev = Some((d0, h_curr));
        }
        Ok(x)
    }

    /// Convenience wrapper for ε-prediction models: converts the noise
    /// prediction into a data prediction via x₀ = (x − σ·ε)/α and delegates
    /// to [`Self::sample`].
    pub fn sample_from_noise_pred<F>(&self, x_t: &[f64], noise_pred_fn: F) -> Result<Vec<f64>>
    where
        F: Fn(&[f64], f64) -> Vec<f64>,
    {
        let schedule_ref = &self.schedule;
        self.sample(x_t, move |x, t| {
            let eps = noise_pred_fn(x, t);
            let at = schedule_ref.alpha_t(t).max(1e-12);
            let st = schedule_ref.sigma_t(t).max(1e-12);
            x.iter()
                .zip(eps.iter())
                .map(|(&xi, &ei)| (xi - st * ei) / at)
                .collect()
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default config carries the documented values.
    #[test]
    fn test_config_defaults() {
        let cfg = DpmSolverConfig::default();
        assert_eq!(cfg.n_steps, 20);
        assert_eq!(cfg.order, 2);
        assert_eq!(cfg.solver_type, SolverType::Midpoint);
        assert!(!cfg.thresholding);
        assert!((cfg.dynamic_thresholding_ratio - 0.995).abs() < 1e-9);
        assert!((cfg.sigma_min - 0.002).abs() < 1e-9);
        assert!((cfg.sigma_max - 80.0).abs() < 1e-9);
    }

    /// validate() rejects zero steps, out-of-range order, and bad sigma bounds.
    #[test]
    fn test_config_validation() {
        let mut cfg = DpmSolverConfig::default();
        assert!(cfg.validate().is_ok());
        cfg.n_steps = 0;
        assert!(cfg.validate().is_err());
        cfg.n_steps = 10;
        cfg.order = 0;
        assert!(cfg.validate().is_err());
        cfg.order = 4;
        assert!(cfg.validate().is_err());
        cfg.order = 2;
        cfg.sigma_min = 0.0;
        assert!(cfg.validate().is_err());
        // sigma_min above sigma_max must also fail.
        cfg.sigma_min = 100.0;
        assert!(cfg.validate().is_err());
    }

    /// The normalized schedule starts at ᾱ(0) = 1, so α(0) ≈ 1.
    #[test]
    fn test_schedule_alpha_at_zero() {
        let sched = DpmSchedule::new();
        let at = sched.alpha_t(0.0);
        assert!(
            (at - 1.0).abs() < 1e-6,
            "alpha(t=0) should be ≈ 1, got {at}"
        );
    }

    /// Near t = 1 essentially all signal is noise: σ ≈ 1.
    #[test]
    fn test_schedule_sigma_at_one() {
        let sched = DpmSchedule::new();
        let st = sched.sigma_t(1.0 - 1e-7);
        assert!(st > 0.99, "sigma near t=1 should be close to 1, got {st}");
    }

    /// λ(t) = ln(α/σ) must strictly decrease as t grows.
    #[test]
    fn test_lambda_monotone_decreasing() {
        let sched = DpmSchedule::new();
        let ts: Vec<f64> = (0..20).map(|i| i as f64 / 20.0).collect();
        for window in ts.windows(2) {
            let l0 = sched.lambda_t(window[0]);
            let l1 = sched.lambda_t(window[1]);
            assert!(
                l0 > l1,
                "λ should be monotonically decreasing: λ({})={l0} > λ({})={l1}",
                window[0],
                window[1]
            );
        }
    }

    /// Bisection inverse of λ recovers t to ~1e-4 across the interior.
    #[test]
    fn test_t_from_lambda_roundtrip() {
        let sched = DpmSchedule::new();
        for t_orig in [0.05, 0.2, 0.5, 0.8, 0.95] {
            let lam = sched.lambda_t(t_orig);
            let t_recovered = sched.t_from_lambda(lam);
            assert!(
                (t_recovered - t_orig).abs() < 1e-4,
                "t_from_lambda roundtrip failed: orig={t_orig}, recovered={t_recovered}"
            );
        }
    }

    /// The grid has n_steps + 1 points (step endpoints inclusive).
    #[test]
    fn test_timestep_schedule_length() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg.clone()).expect("solver");
        let ts = solver.timestep_schedule();
        assert_eq!(
            ts.len(),
            cfg.n_steps + 1,
            "timestep_schedule should return n_steps+1 values"
        );
    }

    /// Sampling runs from high noise (large t) down to low noise.
    #[test]
    fn test_timestep_schedule_endpoints() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let ts = solver.timestep_schedule();
        assert!(
            ts[0] > *ts.last().expect("non-empty"),
            "First timestep should have higher t (higher noise): {} vs {}",
            ts[0],
            ts.last().unwrap()
        );
    }

    /// Every consecutive pair of grid points must strictly decrease in t.
    #[test]
    fn test_timestep_schedule_monotone() {
        let cfg = DpmSolverConfig::second_order(10);
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let ts = solver.timestep_schedule();
        for window in ts.windows(2) {
            assert!(
                window[0] > window[1],
                "timestep schedule must be strictly decreasing: {} > {}",
                window[0],
                window[1]
            );
        }
    }

    /// Shape and finiteness of a single first-order step.
    #[test]
    fn test_first_order_update_shape() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 8;
        let x_t: Vec<f64> = (0..d).map(|i| i as f64 * 0.1 - 0.35).collect();
        let d0: Vec<f64> = vec![0.0; d];
        let result = solver.first_order_update(&x_t, 0.8, 0.6, &d0);
        assert_eq!(result.len(), d);
        for &v in &result {
            assert!(v.is_finite(), "first_order_update output must be finite");
        }
    }

    /// Shape and finiteness of a single second-order step.
    #[test]
    fn test_second_order_update_shape() {
        let cfg = DpmSolverConfig::second_order(10);
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 8;
        let x_t: Vec<f64> = (0..d).map(|i| i as f64 * 0.1 - 0.35).collect();
        let d0 = vec![0.1; d];
        let d1 = vec![0.05; d];
        let result = solver.second_order_update(&x_t, 0.8, 0.6, &d0, &d1, 0.5);
        assert_eq!(result.len(), d);
        for &v in &result {
            assert!(v.is_finite(), "second_order_update output must be finite");
        }
    }

    /// With a zero data prediction only the σ-ratio term survives.
    #[test]
    fn test_first_order_update_zero_model() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 4;
        let x_t = vec![1.0; d];
        let d0 = vec![0.0; d];
        let t_from = 0.8;
        let t_to = 0.6;
        let result = solver.first_order_update(&x_t, t_from, t_to, &d0);
        let sigma_from = solver.schedule.sigma_t(t_from);
        let sigma_to = solver.schedule.sigma_t(t_to);
        let expected = sigma_to / sigma_from;
        for &v in &result {
            assert!(
                (v - expected).abs() < 1e-8,
                "With D₀=0, result should be σ_to/σ_from · x; got {v}, expected {expected}"
            );
        }
    }

    /// End-to-end sample with a zero-prediction model: shape and finiteness.
    #[test]
    fn test_sample_with_identity_model_shape() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 16;
        let x_noise: Vec<f64> = (0..d).map(|i| (i as f64 - 8.0) * 0.5).collect();
        let x0 = solver
            .sample(&x_noise, |x, _t| vec![0.0; x.len()])
            .expect("sample");
        assert_eq!(x0.len(), d, "output shape should match input");
        for &v in &x0 {
            assert!(v.is_finite(), "sample output must be finite");
        }
    }

    /// First-order sampling with an identity model stays finite.
    #[test]
    fn test_sample_first_order_shape() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 8;
        let x_noise = vec![1.0; d];
        let x0 = solver.sample(&x_noise, |x, _t| x.to_vec()).expect("sample");
        assert_eq!(x0.len(), d);
        for &v in &x0 {
            assert!(v.is_finite());
        }
    }

    /// Second-order sampling produces the right shape.
    #[test]
    fn test_sample_second_order_shape() {
        let cfg = DpmSolverConfig::second_order(10);
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 8;
        let x_noise = vec![0.5; d];
        let x0 = solver.sample(&x_noise, |x, _t| x.to_vec()).expect("sample");
        assert_eq!(x0.len(), d);
    }

    /// The VP schedule's σ stays in [0, 1] everywhere on the grid
    /// (the configured EDM-style sigma range is clamped internally).
    #[test]
    fn test_sample_sigma_bounds_respected() {
        let cfg = DpmSolverConfig {
            n_steps: 5,
            order: 1,
            sigma_min: 0.002,
            sigma_max: 80.0,
            ..Default::default()
        };
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let ts = solver.timestep_schedule();
        for &t in &ts {
            let sigma = solver.schedule.sigma_t(t);
            assert!(
                sigma >= 0.0 && sigma <= 1.0,
                "sigma out of [0,1]: {sigma} at t={t}"
            );
        }
    }

    /// The ε-prediction wrapper runs and preserves shape.
    #[test]
    fn test_sample_from_noise_pred() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 8;
        let x_noise = vec![0.3; d];
        let x0 = solver
            .sample_from_noise_pred(&x_noise, |x, _t| vec![0.0; x.len()])
            .expect("sample_from_noise_pred");
        assert_eq!(x0.len(), d);
        for &v in &x0 {
            assert!(v.is_finite());
        }
    }

    /// Dynamic thresholding keeps large predictions finite.
    #[test]
    fn test_thresholding() {
        let cfg = DpmSolverConfig {
            n_steps: 3,
            order: 1,
            thresholding: true,
            ..Default::default()
        };
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let d = 8;
        let x_large = vec![100.0; d];
        let x0 = solver
            .sample(&x_large, |x, _t| x.to_vec())
            .expect("sample with thresholding");
        assert_eq!(x0.len(), d);
        for &v in &x0 {
            assert!(v.is_finite());
        }
    }

    /// Empty input is rejected, not silently accepted.
    #[test]
    fn test_empty_input_error() {
        let cfg = DpmSolverConfig::fast_test();
        let solver = DpmSolverPlusPlus::new(cfg).expect("solver");
        let result = solver.sample(&[], |x, _t| x.to_vec());
        assert!(result.is_err(), "empty input should return error");
    }

    /// Default solver type is Midpoint.
    #[test]
    fn test_solver_type_default() {
        assert_eq!(SolverType::default(), SolverType::Midpoint);
    }
}