Skip to main content

mold_inference/
engine.rs

1use anyhow::Result;
2use mold_core::GenerateRequest;
3use mold_core::GenerateResponse;
4use std::ops::{Deref, DerefMut};
5
6use crate::progress::ProgressCallback;
7
/// How an engine brings its model components into memory for inference.
///
/// `Eager` is the `Default` variant.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum LoadStrategy {
    /// Load every component up front and keep them resident between
    /// requests (the server-mode choice).
    #[default]
    Eager,
    /// Load, use, then drop each component in turn so peak memory stays low
    /// (the one-shot CLI choice).
    Sequential,
}
17
/// Trait for inference backends.
///
/// The `Send + Sync` supertraits allow an engine to be moved to, and shared
/// between, threads.
pub trait InferenceEngine: Send + Sync {
    /// Run one generation request and return its response.
    ///
    /// Per the contract on [`Self::load`], implementations are expected to
    /// load weights on demand if [`Self::is_loaded`] is `false`.
    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse>;
    /// Name of the model this engine serves.
    fn model_name(&self) -> &str;
    /// Whether model weights are currently loaded (see [`Self::load`] /
    /// [`Self::unload`]).
    fn is_loaded(&self) -> bool;
    /// Load model weights. Called automatically on first generate if not yet loaded.
    fn load(&mut self) -> Result<()>;
    /// Unload model weights to free GPU memory. The engine remains valid and
    /// can be re-loaded by calling `load()` or generating again.
    ///
    /// The default implementation does nothing.
    fn unload(&mut self) {}
    /// Set a progress callback for receiving loading/inference status updates.
    /// Default implementation is a no-op for engines that don't support progress.
    fn set_on_progress(&mut self, _callback: ProgressCallback) {}
    /// Clear any previously installed progress callback.
    ///
    /// The default implementation does nothing.
    fn clear_on_progress(&mut self) {}
    /// Return the model's resolved file paths, if available.
    /// Used by the server for pre-load memory checks on unified-memory systems.
    ///
    /// The default implementation returns `None`.
    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
        None
    }

    /// Returns a [`ChainStageRenderer`](crate::ltx2::ChainStageRenderer) view
    /// of this engine if the family supports chained video generation.
    /// Default is `None` — only LTX-2 distilled overrides this in v1.
    ///
    /// Callers (the server chain route) invoke this once per stage to drive
    /// [`crate::ltx2::Ltx2ChainOrchestrator::run`]; engines that don't support
    /// chaining return `None` and the caller responds with 422.
    fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
        None
    }
}
50
/// Temporarily moves the value out of an `Option<T>` slot and puts it back
/// when the guard is dropped — including when the scope unwinds from a panic.
///
/// While the guard is alive it holds the mutable borrow of the slot, so safe
/// code cannot observe the slot in its temporarily-empty state. The taken
/// value is reachable through `Deref` / `DerefMut`. Leaking the guard (e.g.
/// via `std::mem::forget`) skips `Drop` and leaves the slot `None`.
pub(crate) struct OptionRestoreGuard<'a, T> {
    /// Slot the value was taken from; `Drop` writes the value back here.
    slot: &'a mut Option<T>,
    /// The taken value. Always `Some` until `Drop` moves it back into `slot`.
    value: Option<T>,
}

impl<'a, T> OptionRestoreGuard<'a, T> {
    /// Takes the value out of `slot` and returns a guard that restores it on
    /// drop. Returns `None` (leaving `slot` untouched) if `slot` is empty.
    ///
    /// Marked `#[must_use]`: discarding the guard drops it immediately, which
    /// puts the value straight back and turns the call into a no-op.
    #[must_use]
    pub(crate) fn take(slot: &'a mut Option<T>) -> Option<Self> {
        let value = slot.take()?;
        Some(Self {
            slot,
            value: Some(value),
        })
    }
}

impl<T> Deref for OptionRestoreGuard<'_, T> {
    type Target = T;

    /// Borrow the taken value. The `expect` cannot fire: `value` is only
    /// emptied by `Drop`, after which no borrow of the guard exists.
    fn deref(&self) -> &Self::Target {
        self.value
            .as_ref()
            .expect("option restore guard must hold a value")
    }
}

impl<T> DerefMut for OptionRestoreGuard<'_, T> {
    /// Mutably borrow the taken value; mutations are what `Drop` restores.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.value
            .as_mut()
            .expect("option restore guard must hold a value")
    }
}

impl<T> Drop for OptionRestoreGuard<'_, T> {
    /// Move the (possibly mutated) value back into the original slot. This
    /// also runs during unwinding, which is the point of the guard.
    fn drop(&mut self) {
        *self.slot = self.value.take();
    }
}
90
91/// Select the optimal dtype for GPU inference.
92///
93/// Re-exported from `device::gpu_dtype` for backward compatibility.
94pub(crate) fn gpu_dtype(device: &candle_core::Device) -> candle_core::DType {
95    crate::device::gpu_dtype(device)
96}
97
/// Derive a seed from the current wall-clock time.
///
/// The result is the low 64 bits of the nanoseconds elapsed since the Unix
/// epoch (the `as u64` cast truncates the `u128` count). If the system clock
/// reports a time before the epoch, the seed falls back to `0`.
///
/// This is not a source of real randomness: calls landing in the same clock
/// tick return the same value.
pub(crate) fn rand_seed() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let now = SystemTime::now();
    now.duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos() as u64)
        .unwrap_or(0)
}
106
/// Absolute tolerance within which a classifier-free guidance (CFG) scale is
/// treated as "1.0, i.e. disabled".
///
/// At a guidance scale of 1.0 the unconditional pass contributes nothing —
/// `cond + (cond - uncond) * 0 == cond` — so a pipeline can run one
/// conditional forward instead of batching `[uncond, cond]`. Guidance-distilled
/// workflows (LCM / Lightning / Turbo) ship with `cfg ≈ 1.0` and take this path.
///
/// Matches ComfyUI's short-circuit at `comfy/samplers.py:370` (at time of
/// writing): `if math.isclose(cond_scale, 1.0): uncond_ = None`. Python's
/// `math.isclose` defaults to a *relative* tolerance of `1e-9`. This constant
/// is an *absolute* tolerance and is deliberately looser (`1e-4`), because the
/// caller-visible knob is a user-typed `f64` such as `1.0` or `1.0000`. Near
/// 1.0 the relative and absolute forms are practically equivalent.
pub(crate) const CFG_DISABLE_EPSILON: f64 = 1e-4;

/// Reports whether classifier-free guidance is in effect for `guidance`.
///
/// Returns `true` only when `guidance` is strictly farther than
/// [`CFG_DISABLE_EPSILON`] from 1.0, i.e. when the unconditional forward pass
/// meaningfully changes the final noise prediction. When this returns
/// `false`, callers should run a single conditional forward, saving roughly
/// 2× denoise time. A NaN `guidance` fails the comparison and yields `false`.
pub(crate) fn cfg_active(guidance: f64) -> bool {
    let distance_from_one = (guidance - 1.0).abs();
    distance_from_one > CFG_DISABLE_EPSILON
}
127
128/// Resolve the effective `cfg_plus` flag for a request.
129///
130/// Precedence: explicit request field > `MOLD_CFG_PLUS` env var > false.
131/// Mirrors `MOLD_OFFLOAD` / `MOLD_KEEP_TE_RAM`. Shared across SD3, SDXL,
132/// and SD1.5 so the wire-format and env contract stay identical.
133pub(crate) fn resolve_cfg_plus(req: &GenerateRequest) -> bool {
134    if let Some(explicit) = req.cfg_plus {
135        return explicit;
136    }
137    matches!(
138        std::env::var("MOLD_CFG_PLUS").ok().as_deref(),
139        Some("1") | Some("true") | Some("yes")
140    )
141}
142
143/// Generate deterministic noise on a device with a given seed.
144///
145/// This is the ONLY correct way to generate initial noise for denoising.
146/// All pipelines MUST use this instead of calling `device.set_seed()` +
147/// `Tensor::randn()` separately.
148///
149/// Noise is generated on CPU using a deterministic Rust RNG, then moved to
150/// the target device. This guarantees:
151/// 1. Same seed always produces identical noise (deterministic)
152/// 2. Same seed produces the same noise across CUDA, Metal, and CPU backends
153///    (cross-platform reproducibility)
154///
155/// GPU-native RNG (Metal's HybridTaus, CUDA's cuRAND) use different algorithms
156/// that produce different sequences from the same seed. CPU generation avoids this.
157pub(crate) fn seeded_randn(
158    seed: u64,
159    shape: &[usize],
160    device: &candle_core::Device,
161    dtype: candle_core::DType,
162) -> anyhow::Result<candle_core::Tensor> {
163    use rand::rngs::StdRng;
164    use rand::SeedableRng;
165    use rand_distr::{Distribution, StandardNormal};
166
167    // Generate noise on CPU with a deterministic RNG for cross-platform reproducibility.
168    let mut rng = StdRng::seed_from_u64(seed);
169    let elem_count: usize = shape.iter().product();
170    let noise: Vec<f32> = (0..elem_count)
171        .map(|_| StandardNormal.sample(&mut rng))
172        .collect();
173
174    let tensor = candle_core::Tensor::from_vec(noise, shape, &candle_core::Device::Cpu)?;
175    Ok(tensor.to_dtype(dtype)?.to_device(device)?)
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of element-wise absolute differences between two tensors; `0.0`
    /// means the tensors are identical.
    fn abs_diff_sum(a: candle_core::Tensor, b: candle_core::Tensor) -> f32 {
        (a - b)
            .unwrap()
            .abs()
            .unwrap()
            .sum_all()
            .unwrap()
            .to_scalar::<f32>()
            .unwrap()
    }

    #[test]
    fn seeded_randn_produces_correct_shape() {
        let dev = candle_core::Device::Cpu;
        let t = seeded_randn(42, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert_eq!(t.dims(), &[1, 4, 8, 8]);
    }

    #[test]
    fn seeded_randn_respects_dtype() {
        let dev = candle_core::Device::Cpu;
        let t = seeded_randn(42, &[2, 2], &dev, candle_core::DType::BF16).unwrap();
        assert_eq!(t.dtype(), candle_core::DType::BF16);
    }

    #[test]
    fn seeded_randn_deterministic_same_seed() {
        let dev = candle_core::Device::Cpu;
        let a = seeded_randn(1337, &[1, 16, 8, 8], &dev, candle_core::DType::F32).unwrap();
        let b = seeded_randn(1337, &[1, 16, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert_eq!(
            abs_diff_sum(a, b),
            0.0,
            "same seed must produce identical noise"
        );
    }

    #[test]
    fn seeded_randn_different_seeds_differ() {
        let dev = candle_core::Device::Cpu;
        let a = seeded_randn(42, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        let b = seeded_randn(43, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert!(
            abs_diff_sum(a, b) > 0.0,
            "different seeds must produce different noise"
        );
    }

    #[test]
    fn gpu_dtype_cpu_returns_f32() {
        assert_eq!(
            gpu_dtype(&candle_core::Device::Cpu),
            candle_core::DType::F32
        );
    }

    #[test]
    fn option_restore_guard_restores_taken_value_on_drop() {
        let mut slot = Some(String::from("loaded"));
        {
            let mut guard = OptionRestoreGuard::take(&mut slot).unwrap();
            guard.push_str("-mutated");
        }
        assert_eq!(slot.as_deref(), Some("loaded-mutated"));
    }

    #[test]
    fn option_restore_guard_take_on_empty_slot_returns_none() {
        let mut slot: Option<String> = None;
        assert!(OptionRestoreGuard::take(&mut slot).is_none());
        assert!(slot.is_none(), "an empty slot must stay empty");
    }

    #[test]
    fn option_restore_guard_restores_value_on_unwind() {
        // The guard exists precisely so a panic mid-use does not lose the
        // value; verify the slot is refilled (with the mutation) on unwind.
        let mut slot = Some(1_u32);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut guard = OptionRestoreGuard::take(&mut slot).unwrap();
            *guard += 1;
            panic!("simulated failure while the value is taken");
        }));
        assert!(result.is_err());
        assert_eq!(slot, Some(2));
    }

    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0), "guidance=1.0 must take the fast path");
    }

    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        // LCM / Lightning workflows often expose 1.0 as a float that round-
        // trips with tiny noise; anything within the epsilon on either side
        // of 1.0 is "disabled".
        assert!(
            !cfg_active(1.0 - 1e-5),
            "guidance just under 1.0 must take the fast path"
        );
        assert!(
            !cfg_active(1.0 + 1e-5),
            "guidance just over 1.0 must take the fast path"
        );
    }

    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5), "guidance=1.5 must run full CFG");
    }

    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5), "guidance=7.5 must run full CFG");
    }

    #[test]
    fn test_cfg_enabled_just_outside_epsilon() {
        // `1.0 + CFG_DISABLE_EPSILON` itself is subject to float rounding, so
        // we don't pin the exact boundary; anything visibly past it must
        // engage CFG.
        assert!(
            cfg_active(1.0 + 2.0 * CFG_DISABLE_EPSILON),
            "guidance just past the epsilon must run full CFG"
        );
    }
}