use anyhow::Result;
use mold_core::GenerateRequest;
use mold_core::GenerateResponse;
use std::ops::{Deref, DerefMut};
use crate::progress::ProgressCallback;
/// Selects how an engine loads its model.
///
/// `Eager` is the `Default` variant.
// NOTE(review): the variants are not consumed anywhere in this file; presumably
// `Eager` loads every component up front and `Sequential` loads them one at a
// time — confirm against the engines that read this value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum LoadStrategy {
/// Default strategy.
#[default]
Eager,
/// Alternative strategy (see the review note on the enum).
Sequential,
}
/// Common interface implemented by every inference engine.
///
/// Implementors must be `Send + Sync`. Only `generate`, `model_name`,
/// `is_loaded` and `load` are required; every other method has a default
/// body that does nothing or returns `None`.
pub trait InferenceEngine: Send + Sync {
/// Runs a generation for `req` and returns the response, or an error.
// NOTE(review): whether this loads the model on demand when `is_loaded()` is
// false is up to each implementation — not visible from the trait.
fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse>;
/// Returns the engine's model name.
fn model_name(&self) -> &str;
/// Reports whether the engine currently considers itself loaded.
fn is_loaded(&self) -> bool;
/// Loads the engine, returning an error on failure.
fn load(&mut self) -> Result<()>;
/// Unloads the engine. The default implementation does nothing.
fn unload(&mut self) {}
/// Installs a progress callback. The default implementation discards it.
fn set_on_progress(&mut self, _callback: ProgressCallback) {}
/// Removes a previously installed progress callback. The default
/// implementation does nothing.
fn clear_on_progress(&mut self) {}
/// Returns the engine's model paths, if it exposes any. Defaults to `None`.
fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
None
}
/// Returns this engine as a `ChainStageRenderer` if it implements one.
/// Defaults to `None`.
fn as_chain_renderer(&mut self) -> Option<&mut dyn crate::ltx2::ChainStageRenderer> {
None
}
}
/// Scope guard that moves the value out of an `Option` slot and writes it
/// back into that slot when the guard is dropped.
///
/// While the guard is alive the slot holds `None` and the guard dereferences
/// (immutably and mutably) to the taken value, so any mutation made through
/// the guard is visible in the slot once the guard goes out of scope.
pub(crate) struct OptionRestoreGuard<'a, T> {
    /// Slot the value was taken from; receives the value back on drop.
    slot: &'a mut Option<T>,
    /// The taken value. `Some` from construction until `drop` moves it out.
    value: Option<T>,
}

impl<'a, T> OptionRestoreGuard<'a, T> {
    /// Takes the value out of `slot` and wraps it in a guard.
    ///
    /// Returns `None` (leaving `slot` untouched, i.e. still `None`) when the
    /// slot is empty.
    pub(crate) fn take(slot: &'a mut Option<T>) -> Option<Self> {
        match slot.take() {
            Some(taken) => Some(Self {
                slot,
                value: Some(taken),
            }),
            None => None,
        }
    }
}

impl<'a, T> Deref for OptionRestoreGuard<'a, T> {
    type Target = T;

    /// Borrows the held value.
    ///
    /// # Panics
    /// Panics if the guard no longer holds a value; that only happens once
    /// `drop` has run.
    fn deref(&self) -> &T {
        let held = self.value.as_ref();
        held.expect("option restore guard must hold a value")
    }
}

impl<'a, T> DerefMut for OptionRestoreGuard<'a, T> {
    /// Mutably borrows the held value.
    ///
    /// # Panics
    /// Panics if the guard no longer holds a value; that only happens once
    /// `drop` has run.
    fn deref_mut(&mut self) -> &mut T {
        let held = self.value.as_mut();
        held.expect("option restore guard must hold a value")
    }
}

impl<'a, T> Drop for OptionRestoreGuard<'a, T> {
    /// Moves the held value back into the slot it came from.
    fn drop(&mut self) {
        let restored = self.value.take();
        *self.slot = restored;
    }
}
/// Returns the dtype to use for `device` by forwarding to
/// `crate::device::gpu_dtype`.
///
/// The unit test in this file expects `DType::F32` for `Device::Cpu`; the
/// choice for other devices is made by the forwarded-to function.
pub(crate) fn gpu_dtype(device: &candle_core::Device) -> candle_core::DType {
crate::device::gpu_dtype(device)
}
/// Derives a seed from the current wall-clock time.
///
/// The seed is the number of nanoseconds since the Unix epoch, keeping only
/// the low 64 bits of the `u128` count. If the system clock reports a time
/// before the epoch, the elapsed duration falls back to zero and the seed
/// is `0`.
pub(crate) fn rand_seed() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let now = SystemTime::now();
    let elapsed = match now.duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        // Clock is set before the epoch: use a zero duration.
        Err(_) => Default::default(),
    };
    // Intentional truncating cast: only the low 64 bits are kept.
    elapsed.as_nanos() as u64
}
/// Half-width of the band around a guidance of `1.0` inside which
/// `cfg_active` reports `false`.
pub(crate) const CFG_DISABLE_EPSILON: f64 = 1e-4;

/// Returns `true` when `guidance` differs from `1.0` by more than
/// `CFG_DISABLE_EPSILON`.
///
/// A NaN `guidance` yields `false`, because every comparison with NaN is
/// false.
pub(crate) fn cfg_active(guidance: f64) -> bool {
    let distance_from_unity = (guidance - 1.0).abs();
    distance_from_unity > CFG_DISABLE_EPSILON
}
/// Decides whether CFG+ is enabled for `req`.
///
/// An explicit `req.cfg_plus` value always wins. Otherwise the
/// `MOLD_CFG_PLUS` environment variable is consulted, and only the exact
/// values `1`, `true` and `yes` enable it; an unset, non-Unicode or any
/// other value leaves it disabled.
pub(crate) fn resolve_cfg_plus(req: &GenerateRequest) -> bool {
    match req.cfg_plus {
        Some(explicit) => explicit,
        None => {
            let env_value = std::env::var("MOLD_CFG_PLUS").ok();
            matches!(env_value.as_deref(), Some("1" | "true" | "yes"))
        }
    }
}
/// Builds a tensor of standard-normal noise that is reproducible for a
/// given `seed`.
///
/// Samples are drawn as `f32` on the CPU from a `StdRng` seeded with `seed`,
/// one per element of `shape`. The resulting tensor is then converted to
/// `dtype` and moved to `device`, so the sampled values themselves do not
/// depend on the target device.
///
/// # Errors
/// Returns an error if tensor construction, dtype conversion or the device
/// transfer fails.
pub(crate) fn seeded_randn(
    seed: u64,
    shape: &[usize],
    device: &candle_core::Device,
    dtype: candle_core::DType,
) -> anyhow::Result<candle_core::Tensor> {
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use rand_distr::{Distribution, StandardNormal};

    let total_elems: usize = shape.iter().product();
    let mut rng = StdRng::seed_from_u64(seed);
    let samples: Vec<f32> = std::iter::repeat_with(|| {
        let sample: f32 = StandardNormal.sample(&mut rng);
        sample
    })
    .take(total_elems)
    .collect();

    let on_cpu = candle_core::Tensor::from_vec(samples, shape, &candle_core::Device::Cpu)?;
    let converted = on_cpu.to_dtype(dtype)?;
    let on_target = converted.to_device(device)?;
    Ok(on_target)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of absolute element-wise differences between two `f32` tensors.
    /// Shared by the determinism tests so the unwrap chain lives in one place.
    fn abs_diff_sum(a: candle_core::Tensor, b: candle_core::Tensor) -> f32 {
        (a - b)
            .unwrap()
            .abs()
            .unwrap()
            .sum_all()
            .unwrap()
            .to_scalar::<f32>()
            .unwrap()
    }

    #[test]
    fn seeded_randn_produces_correct_shape() {
        let dev = candle_core::Device::Cpu;
        let t = seeded_randn(42, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert_eq!(t.dims(), &[1, 4, 8, 8]);
    }

    #[test]
    fn seeded_randn_respects_dtype() {
        let dev = candle_core::Device::Cpu;
        let t = seeded_randn(42, &[2, 2], &dev, candle_core::DType::BF16).unwrap();
        assert_eq!(t.dtype(), candle_core::DType::BF16);
    }

    #[test]
    fn seeded_randn_deterministic_same_seed() {
        let dev = candle_core::Device::Cpu;
        let a = seeded_randn(1337, &[1, 16, 8, 8], &dev, candle_core::DType::F32).unwrap();
        let b = seeded_randn(1337, &[1, 16, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert_eq!(
            abs_diff_sum(a, b),
            0.0,
            "same seed must produce identical noise"
        );
    }

    #[test]
    fn seeded_randn_different_seeds_differ() {
        let dev = candle_core::Device::Cpu;
        let a = seeded_randn(42, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        let b = seeded_randn(43, &[1, 4, 8, 8], &dev, candle_core::DType::F32).unwrap();
        assert!(
            abs_diff_sum(a, b) > 0.0,
            "different seeds must produce different noise"
        );
    }

    #[test]
    fn gpu_dtype_cpu_returns_f32() {
        assert_eq!(
            gpu_dtype(&candle_core::Device::Cpu),
            candle_core::DType::F32
        );
    }

    #[test]
    fn option_restore_guard_restores_taken_value_on_drop() {
        let mut slot = Some(String::from("loaded"));
        {
            let mut guard = OptionRestoreGuard::take(&mut slot).unwrap();
            guard.push_str("-mutated");
        }
        assert_eq!(slot.as_deref(), Some("loaded-mutated"));
    }

    #[test]
    fn option_restore_guard_take_on_empty_slot_returns_none() {
        let mut slot: Option<String> = None;
        assert!(OptionRestoreGuard::take(&mut slot).is_none());
        assert!(slot.is_none(), "an empty slot must stay empty");
    }

    #[test]
    fn test_cfg_disabled_at_guidance_1_0() {
        assert!(!cfg_active(1.0), "guidance=1.0 must take the fast path");
    }

    #[test]
    fn test_cfg_disabled_just_below_1_0() {
        assert!(
            !cfg_active(1.0 - 1e-5),
            "guidance just under 1.0 must take the fast path"
        );
    }

    #[test]
    fn test_cfg_disabled_just_above_1_0() {
        assert!(
            !cfg_active(1.0 + 1e-5),
            "guidance just over 1.0 must take the fast path"
        );
    }

    #[test]
    fn test_cfg_enabled_at_guidance_1_5() {
        assert!(cfg_active(1.5), "guidance=1.5 must run full CFG");
    }

    #[test]
    fn test_cfg_enabled_at_guidance_7_5() {
        assert!(cfg_active(7.5), "guidance=7.5 must run full CFG");
    }

    #[test]
    fn test_cfg_enabled_just_outside_epsilon() {
        assert!(
            cfg_active(1.0 + 2.0 * CFG_DISABLE_EPSILON),
            "guidance just past the epsilon must run full CFG"
        );
    }
}