1use anyhow::Result;
2use mold_core::GenerateRequest;
3use mold_core::GenerateResponse;
4use std::ops::{Deref, DerefMut};
5
6use crate::progress::ProgressCallback;
7
/// Strategy selector for how model components are loaded.
///
/// NOTE(review): the exact semantics of each variant are defined by the engines that
/// consume this type (not visible in this file). The names suggest loading everything
/// up front (`Eager`) versus one component at a time (`Sequential`); confirm against callers.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LoadStrategy {
    /// The default strategy (`LoadStrategy::default()` returns this variant).
    #[default]
    Eager,
    /// Non-default alternative strategy; see the type-level note.
    Sequential,
}
17
/// A loadable inference backend that turns a [`GenerateRequest`] into a [`GenerateResponse`].
///
/// Implementors must be `Send + Sync`. Only `generate`, `model_name`, `is_loaded` and `load`
/// are required. Every other method has a default that either does nothing or returns `None`.
pub trait InferenceEngine: Send + Sync {
    /// Runs generation for `req` and returns the response.
    ///
    /// NOTE(review): whether an unloaded engine loads itself here or returns an error is
    /// decided by each implementor; it is not visible in this file.
    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse>;
    /// Name of the model this engine serves.
    fn model_name(&self) -> &str;
    /// Whether the engine currently reports itself as loaded.
    fn is_loaded(&self) -> bool;
    /// Loads the engine's resources, returning an error if loading fails.
    fn load(&mut self) -> Result<()>;
    /// Releases loaded resources. The default implementation is a no-op.
    fn unload(&mut self) {}
    /// Installs a progress callback. The default implementation drops the callback unused.
    fn set_on_progress(&mut self, _callback: ProgressCallback) {}
    /// Removes any previously installed progress callback. The default implementation is a no-op.
    fn clear_on_progress(&mut self) {}
    /// The engine's `ModelPaths`, if it exposes them. The default returns `None`.
    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
        None
    }

    /// Returns this engine as a [`crate::ltx2::ChainStageRenderer`] when it supports that
    /// role. The default returns `None`, meaning the role is not supported.
    fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
        None
    }
}
50
/// Guard that temporarily moves the value out of an `Option` slot and writes it back on drop.
///
/// While the guard is alive the slot holds `None` and the value is reachable through
/// `Deref`/`DerefMut`. When the guard is dropped, the value (including any mutations made
/// through the guard) is written back into the slot. This also happens when a panic unwinds
/// through the guard's scope.
pub(crate) struct OptionRestoreGuard<'a, T> {
    /// The slot the value was taken from and will be restored into.
    slot: &'a mut Option<T>,
    /// The taken value. It is always `Some` until `Drop::drop` moves it back into `slot`.
    value: Option<T>,
}

impl<'a, T> OptionRestoreGuard<'a, T> {
    /// Takes the value out of `slot` and returns a guard that restores it on drop.
    ///
    /// Returns `None`, leaving `slot` untouched, when `slot` is already empty.
    ///
    /// The result is `#[must_use]`. Discarding the guard drops it immediately, which puts
    /// the value straight back and makes the call a no-op.
    #[must_use = "dropping the guard immediately restores the value; bind it to a variable"]
    pub(crate) fn take(slot: &'a mut Option<T>) -> Option<Self> {
        let value = slot.take()?;
        Some(Self {
            slot,
            value: Some(value),
        })
    }
}

impl<T> Deref for OptionRestoreGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // `value` is only emptied inside `Drop::drop`, so it is always present here.
        self.value
            .as_ref()
            .expect("option restore guard must hold a value")
    }
}

impl<T> DerefMut for OptionRestoreGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Same invariant as `deref`: `value` is `Some` for the guard's whole lifetime.
        self.value
            .as_mut()
            .expect("option restore guard must hold a value")
    }
}

impl<T> Drop for OptionRestoreGuard<'_, T> {
    fn drop(&mut self) {
        // Move the (possibly mutated) value back into the slot it was taken from.
        *self.slot = self.value.take();
    }
}
90
91pub(crate) fn gpu_dtype(device: &candle_core::Device) -> candle_core::DType {
95 crate::device::gpu_dtype(device)
96}
97
/// Returns a seed derived from the current wall-clock time.
///
/// The seed is the number of nanoseconds since the Unix epoch. The `as u64` cast keeps only
/// the low 64 bits of that count. If the system clock reports a time before the epoch, the
/// seed is `0`.
pub(crate) fn rand_seed() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    }
}
106
/// Tolerance that [`cfg_active`] uses when deciding whether a guidance scale counts as 1.0.
pub(crate) const CFG_DISABLE_EPSILON: f64 = 1e-4;

/// Returns `true` when CFG should run for the given `guidance` scale.
///
/// A scale whose distance from `1.0` is at most [`CFG_DISABLE_EPSILON`] is treated as 1.0,
/// and CFG is reported inactive for it. A `NaN` scale also yields `false`, because any
/// comparison against `NaN` is false.
pub(crate) fn cfg_active(guidance: f64) -> bool {
    let distance_from_unity = (guidance - 1.0).abs();
    distance_from_unity > CFG_DISABLE_EPSILON
}
127
128pub(crate) fn resolve_cfg_plus(req: &GenerateRequest) -> bool {
134 if let Some(explicit) = req.cfg_plus {
135 return explicit;
136 }
137 matches!(
138 std::env::var("MOLD_CFG_PLUS").ok().as_deref(),
139 Some("1") | Some("true") | Some("yes")
140 )
141}
142
143pub(crate) fn seeded_randn(
158 seed: u64,
159 shape: &[usize],
160 device: &candle_core::Device,
161 dtype: candle_core::DType,
162) -> anyhow::Result<candle_core::Tensor> {
163 use rand::rngs::StdRng;
164 use rand::SeedableRng;
165 use rand_distr::{Distribution, StandardNormal};
166
167 let mut rng = StdRng::seed_from_u64(seed);
169 let elem_count: usize = shape.iter().product();
170 let noise: Vec<f32> = (0..elem_count)
171 .map(|_| StandardNormal.sample(&mut rng))
172 .collect();
173
174 let tensor = candle_core::Tensor::from_vec(noise, shape, &candle_core::Device::Cpu)?;
175 Ok(tensor.to_dtype(dtype)?.to_device(device)?)
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of element-wise absolute differences between two `f32` tensors of equal shape.
    fn total_abs_diff(a: candle_core::Tensor, b: candle_core::Tensor) -> f32 {
        (a - b)
            .unwrap()
            .abs()
            .unwrap()
            .sum_all()
            .unwrap()
            .to_scalar::<f32>()
            .unwrap()
    }

    #[test]
    fn load_strategy_defaults_to_eager() {
        assert_eq!(LoadStrategy::default(), LoadStrategy::Eager);
    }

    #[test]
    fn seeded_randn_produces_correct_shape() {
        let dev = candle_core::Device::Cpu;
        let t = seeded_randn(42, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert_eq!(t.dims(), &[1, 4, 8, 8]);
    }

    #[test]
    fn seeded_randn_respects_dtype() {
        let dev = candle_core::Device::Cpu;
        let t = seeded_randn(42, &[2, 2], &dev, candle_core::DType::BF16).unwrap();
        assert_eq!(t.dtype(), candle_core::DType::BF16);
    }

    #[test]
    fn seeded_randn_deterministic_same_seed() {
        let dev = candle_core::Device::Cpu;
        let a = seeded_randn(1337, &[1, 16, 8, 8], &dev, candle_core::DType::F32).unwrap();
        let b = seeded_randn(1337, &[1, 16, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert_eq!(
            total_abs_diff(a, b),
            0.0,
            "same seed must produce identical noise"
        );
    }

    #[test]
    fn seeded_randn_different_seeds_differ() {
        let dev = candle_core::Device::Cpu;
        let a = seeded_randn(42, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        let b = seeded_randn(43, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert!(
            total_abs_diff(a, b) > 0.0,
            "different seeds must produce different noise"
        );
    }

    #[test]
    fn gpu_dtype_cpu_returns_f32() {
        assert_eq!(
            gpu_dtype(&candle_core::Device::Cpu),
            candle_core::DType::F32
        );
    }

    #[test]
    fn option_restore_guard_restores_taken_value_on_drop() {
        let mut slot = Some(String::from("loaded"));
        {
            let mut guard = OptionRestoreGuard::take(&mut slot).unwrap();
            guard.push_str("-mutated");
        }
        assert_eq!(slot.as_deref(), Some("loaded-mutated"));
    }

    #[test]
    fn option_restore_guard_take_on_empty_slot_returns_none() {
        let mut slot: Option<String> = None;
        assert!(OptionRestoreGuard::take(&mut slot).is_none());
        assert!(slot.is_none(), "an empty slot must stay empty");
    }

    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0), "guidance=1.0 must take the fast path");
    }

    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        assert!(
            !cfg_active(1.0 - 1e-5),
            "guidance just under 1.0 must take the fast path"
        );
    }

    #[test]
    fn test_cfg_disabled_just_above_1_0() {
        assert!(
            !cfg_active(1.0 + 1e-5),
            "guidance just over 1.0 must take the fast path"
        );
    }

    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5), "guidance=1.5 must run full CFG");
    }

    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5), "guidance=7.5 must run full CFG");
    }

    #[test]
    fn test_cfg_enabled_just_outside_epsilon() {
        assert!(
            cfg_active(1.0 + 2.0 * CFG_DISABLE_EPSILON),
            "guidance just past the epsilon must run full CFG"
        );
    }
}