Skip to main content

mold_inference/sd3/
pipeline.rs

1use anyhow::{bail, Result};
2use candle_core::{DType, Device, IndexOp, Tensor};
3use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
4use candle_transformers::quantized_var_builder;
5use mold_core::{GenerateRequest, GenerateResponse, ImageData, LoraWeight, ModelPaths};
6use std::collections::HashMap;
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10use tokenizers::Tokenizer;
11
12use crate::cache::{
13    cfg_prompt_cache_key, clear_cache, get_or_insert_cached_tensor_pair,
14    restore_cached_tensor_pair, CachedTensorPair, CfgPromptCacheKey, LruCache,
15    DEFAULT_PROMPT_CACHE_CAPACITY,
16};
17use crate::device::{
18    check_memory_budget, fmt_gb, free_vram_bytes, memory_status_string, preflight_memory_check,
19    usable_free_vram_bytes,
20};
21use crate::encoders;
22use crate::engine::{
23    rand_seed, resolve_cfg_plus, InferenceEngine, LoadStrategy, OptionRestoreGuard,
24};
25use crate::engine_base::EngineBase;
26use crate::image::{build_output_metadata, encode_image};
27use crate::img_utils;
28use crate::progress::{ProgressCallback, ProgressReporter};
29
30use super::lora as sd3_lora;
31use super::quantized_mmdit::QuantizedMMDiT;
32use super::sampling::{self, SkipLayerGuidanceConfig};
33use super::transformer::SD3Transformer;
34use super::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
35
/// Smallest LoRA scale magnitude the engine treats as non-zero. Sliders
/// pinned to 0.0 are dropped from the effective stack — forcing a
/// transformer rebuild for a no-op patch is pure overhead. The threshold
/// is meant to match the precision of an f64 scrubbed by a UI slider.
///
/// `effective_loras` keeps an entry only when `scale.abs()` is strictly
/// greater than this value.
const ZERO_SCALE_EPS: f64 = 1e-8;
41
42/// Collapse a `GenerateRequest`'s `lora` (legacy single) and `loras`
43/// (plural) fields into an ordered, zero-pruned working list. Plural
44/// wins when both are set — the singular form is for backward compat.
45pub(crate) fn effective_loras(req: &GenerateRequest) -> Vec<LoraWeight> {
46    let raw: Vec<LoraWeight> = if let Some(plural) = &req.loras {
47        if !plural.is_empty() {
48            plural.clone()
49        } else {
50            req.lora.iter().cloned().collect()
51        }
52    } else {
53        req.lora.iter().cloned().collect()
54    };
55    raw.into_iter()
56        .filter(|w| {
57            let keep = w.scale.abs() > ZERO_SCALE_EPS;
58            if !keep {
59                tracing::debug!(
60                    path = w.path.as_str(),
61                    scale = w.scale,
62                    "dropping zero-scale LoRA from SD3 effective stack"
63                );
64            }
65            keep
66        })
67        .collect()
68}
69
/// Outcome of asking whether SD3 block-level offload applies to a request.
#[derive(Debug, PartialEq, Eq)]
enum SD3OffloadDecision {
    /// Offload was not requested.
    Disabled,
    /// Offload was requested and the configuration supports it.
    Selected,
    /// Offload was requested but cannot be honored; carries the reason.
    Unsupported(&'static str),
}

/// Decide whether block-level offload can be used.
///
/// Offload is only `Selected` when it was requested for a non-quantized
/// transformer with no active LoRA. When both blockers are present, the
/// quantized reason is reported (it is checked first).
fn sd3_offload_decision(
    forced_offload: bool,
    is_quantized: bool,
    has_lora: bool,
) -> SD3OffloadDecision {
    match (forced_offload, is_quantized, has_lora) {
        (false, _, _) => SD3OffloadDecision::Disabled,
        (true, true, _) => SD3OffloadDecision::Unsupported(
            "SD3 block-level offload is only planned for BF16/FP transformers; \
             GGUF variants already use quantized transformer paths",
        ),
        (true, false, true) => SD3OffloadDecision::Unsupported(
            "SD3 block-level offload with LoRA is not wired yet; \
             LoRA merge/cache semantics need a dedicated offload design",
        ),
        (true, false, false) => SD3OffloadDecision::Selected,
    }
}
99
100/// Build a LoRA-aware SD3 `VarBuilder` from BF16 safetensors. Caller
101/// is responsible for short-circuiting to the non-LoRA mmap path when
102/// `loras` is empty — this function errors on an empty stack so a stray
103/// caller doesn't silently get the wrong builder shape.
104fn sd3_lora_var_builder<'a>(
105    transformer_path: &Path,
106    loras: &[LoraWeight],
107    dtype: DType,
108    device: &Device,
109    progress: &ProgressReporter,
110    delta_cache: Option<Arc<Mutex<sd3_lora::LoraDeltaCache>>>,
111) -> Result<candle_nn::VarBuilder<'a>> {
112    let adapters: Vec<Arc<sd3_lora::LoraAdapter>> = loras
113        .iter()
114        .map(|w| {
115            progress.info("Loading SD3 LoRA adapter");
116            let adapter = sd3_lora::get_or_load_adapter(Path::new(&w.path))?;
117            progress.info(&format!(
118                "SD3 LoRA: {} layers, rank {}, scale {:.2}",
119                adapter.layers.len(),
120                adapter.rank,
121                w.scale,
122            ));
123            anyhow::Ok(adapter)
124        })
125        .collect::<Result<_>>()?;
126
127    let specs: Vec<sd3_lora::LoraSpec<'_>> = adapters
128        .iter()
129        .zip(loras.iter())
130        .map(|(adapter, w)| sd3_lora::LoraSpec {
131            adapter: adapter.as_ref(),
132            scale: w.scale,
133            path_hash: sd3_lora::lora_path_hash(&w.path),
134        })
135        .collect();
136
137    sd3_lora::lora_var_builder(
138        transformer_path,
139        &specs,
140        dtype,
141        device,
142        progress,
143        delta_cache,
144    )
145}
146
147/// GGUF counterpart to `sd3_lora_var_builder` — selectively dequantizes
148/// patched tensors, merges, and re-quantizes back to the original GGML
149/// dtype on the target device.
150fn sd3_gguf_lora_var_builder(
151    transformer_path: &Path,
152    loras: &[LoraWeight],
153    device: &Device,
154    progress: &ProgressReporter,
155    delta_cache: Option<Arc<Mutex<sd3_lora::LoraDeltaCache>>>,
156) -> Result<quantized_var_builder::VarBuilder> {
157    let adapters: Vec<Arc<sd3_lora::LoraAdapter>> = loras
158        .iter()
159        .map(|w| {
160            progress.info("Loading SD3 LoRA adapter");
161            let adapter = sd3_lora::get_or_load_adapter(Path::new(&w.path))?;
162            progress.info(&format!(
163                "SD3 LoRA: {} layers, rank {}, scale {:.2}",
164                adapter.layers.len(),
165                adapter.rank,
166                w.scale,
167            ));
168            anyhow::Ok(adapter)
169        })
170        .collect::<Result<_>>()?;
171
172    let specs: Vec<sd3_lora::LoraSpec<'_>> = adapters
173        .iter()
174        .zip(loras.iter())
175        .map(|(adapter, w)| sd3_lora::LoraSpec {
176            adapter: adapter.as_ref(),
177            scale: w.scale,
178            path_hash: sd3_lora::lora_path_hash(&w.path),
179        })
180        .collect();
181
182    sd3_lora::gguf_lora_var_builder(transformer_path, &specs, device, progress, delta_cache)
183}
184
/// Loaded SD3 model components, ready for inference.
///
/// Populated in one step by `SD3Engine::load` once every component has
/// been built successfully.
struct LoadedSD3 {
    /// None after being dropped for VAE decode VRAM; reloaded on next generate.
    transformer: Option<SD3Transformer>,
    /// Combined CLIP-L + CLIP-G + T5 text encoder.
    triple_encoder: encoders::sd3_clip::SD3TripleEncoder,
    /// Path to the VAE weights file (copied from the engine's `paths.vae`).
    vae_vb_path: std::path::PathBuf,
    /// Device created for the engine's GPU ordinal at load time.
    device: Device,
    /// Dtype chosen at load time: F16 when the device is a GPU, F32 otherwise.
    dtype: DType,
    /// Whether the transformer file was detected as GGUF.
    /// NOTE(review): the leading underscore suggests this is currently unread — confirm.
    _is_quantized: bool,
    /// Copy of the engine's `is_turbo` flag at load time.
    is_turbo: bool,
    /// Copy of the engine's `is_medium` flag at load time.
    is_medium: bool,
}
197
/// SD3.5 inference engine backed by candle.
///
/// Supports SD3.5 Large (8.1B, depth=38), SD3.5 Large Turbo (8.1B, 4 steps),
/// and SD3.5 Medium (2.5B, depth=24, SLG support).
/// Both BF16 safetensors and GGUF quantized transformers are supported.
pub struct SD3Engine {
    /// Shared engine state: model name, paths, load strategy, GPU ordinal,
    /// progress reporter, and the loaded components (if any).
    base: EngineBase<LoadedSD3>,
    /// Turbo variant flag; copied into `LoadedSD3` on load.
    is_turbo: bool,
    /// Medium variant flag; selects the medium MMDiT config and enables SLG.
    is_medium: bool,
    /// Optional T5 variant preference passed to T5 variant resolution.
    t5_variant: Option<String>,
    /// When set, generation uses the sequential path and block-level
    /// offload eligibility is evaluated.
    offload: bool,
    /// LRU cache of prompt conditioning tensor pairs, keyed on prompt,
    /// negative prompt, and guidance.
    prompt_cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensorPair>>,
    /// Device placement request, if any; its `text_encoders` setting is
    /// consulted when choosing the encoder device. Starts as `None`.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// CPU-resident cache of pre-computed LoRA deltas, shared across
    /// transformer rebuilds. Saves the `B @ A * scale` matmul when the
    /// same LoRA stack reappears on a later generate.
    lora_delta_cache: Arc<Mutex<sd3_lora::LoraDeltaCache>>,
    /// Optional pool used to load tokenizers and CPU-resident safetensors
    /// instead of reading them directly.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
}
217
218impl SD3Engine {
219    /// Create a new SD3Engine. Does not load models until `load()` is called.
220    #[allow(clippy::too_many_arguments)]
221    pub fn new(
222        model_name: String,
223        paths: ModelPaths,
224        is_turbo: bool,
225        is_medium: bool,
226        t5_variant: Option<String>,
227        load_strategy: LoadStrategy,
228        gpu_ordinal: usize,
229        offload: bool,
230        shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
231    ) -> Self {
232        Self {
233            base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
234            is_turbo,
235            is_medium,
236            t5_variant,
237            offload,
238            prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
239            pending_placement: None,
240            lora_delta_cache: Arc::new(Mutex::new(sd3_lora::LoraDeltaCache::new())),
241            shared_pool,
242        }
243    }
244
245    fn load_text_tokenizers(
246        &self,
247        clip_l_tokenizer: &Path,
248        clip_g_tokenizer: &Path,
249        t5_tokenizer: &Path,
250    ) -> Result<(Arc<Tokenizer>, Arc<Tokenizer>, Arc<Tokenizer>)> {
251        if let Some(shared_pool) = &self.shared_pool {
252            let mut pool = shared_pool.lock().unwrap();
253            return Ok((
254                pool.load_tokenizer(clip_l_tokenizer)?,
255                pool.load_tokenizer(clip_g_tokenizer)?,
256                pool.load_tokenizer(t5_tokenizer)?,
257            ));
258        }
259
260        let load = |path: &Path, label: &str| {
261            Tokenizer::from_file(path)
262                .map(Arc::new)
263                .map_err(|e| anyhow::anyhow!("failed to load {label} tokenizer: {e}"))
264        };
265        Ok((
266            load(clip_l_tokenizer, "CLIP-L")?,
267            load(clip_g_tokenizer, "CLIP-G")?,
268            load(t5_tokenizer, "T5")?,
269        ))
270    }
271
272    #[cfg(test)]
273    fn load_vae_cpu_tensors(
274        &self,
275        vae_path: &Path,
276    ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
277        Self::load_vae_cpu_tensors_from_pool(self.shared_pool.as_ref(), vae_path)
278    }
279
280    fn load_vae_cpu_tensors_from_pool(
281        shared_pool: Option<&Arc<Mutex<crate::shared_pool::SharedPool>>>,
282        vae_path: &Path,
283    ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
284        let Some(shared_pool) = shared_pool else {
285            return Ok(None);
286        };
287        shared_pool
288            .lock()
289            .unwrap()
290            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
291    }
292
293    fn load_transformer_cpu_tensors(&self) -> Result<Arc<HashMap<String, Tensor>>> {
294        if let Some(shared_pool) = &self.shared_pool {
295            if let Some(tensors) = shared_pool
296                .lock()
297                .unwrap()
298                .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.transformer))?
299            {
300                return Ok(tensors);
301            }
302        }
303        Ok(Arc::new(crate::encoders::park::load_tensors_to_cpu(
304            std::slice::from_ref(&self.base.paths.transformer),
305        )?))
306    }
307
308    fn load_vae_var_builder<'a>(
309        &self,
310        vae_path: &Path,
311        dtype: DType,
312        device: &Device,
313        component: &str,
314        progress: &ProgressReporter,
315    ) -> Result<candle_nn::VarBuilder<'a>> {
316        Self::load_vae_var_builder_from_pool(
317            self.shared_pool.as_ref(),
318            vae_path,
319            dtype,
320            device,
321            component,
322            progress,
323        )
324    }
325
326    fn load_vae_var_builder_from_pool<'a>(
327        shared_pool: Option<&Arc<Mutex<crate::shared_pool::SharedPool>>>,
328        vae_path: &Path,
329        dtype: DType,
330        device: &Device,
331        component: &str,
332        progress: &ProgressReporter,
333    ) -> Result<candle_nn::VarBuilder<'a>> {
334        if let Some(tensors) = Self::load_vae_cpu_tensors_from_pool(shared_pool, vae_path)? {
335            return Ok(crate::encoders::park::varbuilder_from_parked(
336                tensors.as_ref(),
337                dtype,
338                device,
339            ));
340        }
341
342        crate::weight_loader::load_safetensors_with_progress(
343            std::slice::from_ref(&vae_path),
344            dtype,
345            device,
346            component,
347            progress,
348        )
349    }
350
351    #[allow(clippy::too_many_arguments)]
352    fn encode_conditioning(
353        progress: &ProgressReporter,
354        prompt_cache: &Mutex<LruCache<CfgPromptCacheKey, CachedTensorPair>>,
355        triple_encoder: &mut encoders::sd3_clip::SD3TripleEncoder,
356        prompt: &str,
357        negative_prompt: &str,
358        guidance: f64,
359        device: &Device,
360        dtype: DType,
361        is_quantized: bool,
362    ) -> Result<(candle_core::Tensor, candle_core::Tensor)> {
363        // SD3 always concatenates `(cond, uncond)` for CFG, so the cache key
364        // must include the negative prompt and guidance — keying only on the
365        // positive prompt returned a stale `(cond, uncond_old)` pair when the
366        // user changed only the negative.
367        let cache_key = cfg_prompt_cache_key(prompt, negative_prompt, guidance);
368        let ((context, y), cache_hit) = get_or_insert_cached_tensor_pair(
369            prompt_cache,
370            cache_key,
371            device,
372            if is_quantized { DType::F32 } else { dtype },
373            || {
374                progress.stage_start("Encoding prompt (SD3 triple)");
375                let encode_start = Instant::now();
376                let (context_cond, y_cond) = triple_encoder.encode(prompt, device, dtype)?;
377                let (context_uncond, y_uncond) =
378                    triple_encoder.encode(negative_prompt, device, dtype)?;
379                progress.stage_done("Encoding prompt (SD3 triple)", encode_start.elapsed());
380
381                let pair = if is_quantized {
382                    (
383                        candle_core::Tensor::cat(&[&context_cond, &context_uncond], 0)?
384                            .to_dtype(DType::F32)?,
385                        candle_core::Tensor::cat(&[&y_cond, &y_uncond], 0)?.to_dtype(DType::F32)?,
386                    )
387                } else {
388                    (
389                        candle_core::Tensor::cat(&[&context_cond, &context_uncond], 0)?,
390                        candle_core::Tensor::cat(&[&y_cond, &y_uncond], 0)?,
391                    )
392                };
393                Ok(pair)
394            },
395        )?;
396        if cache_hit {
397            progress.cache_hit("prompt conditioning");
398            return Ok((context, y));
399        }
400        Ok((context, y))
401    }
402
403    fn img2img_source_normalize_range() -> img_utils::NormalizeRange {
404        img_utils::NormalizeRange::MinusOneToOne
405    }
406
407    fn uses_sequential_generate_path(&self) -> bool {
408        self.base.load_strategy == LoadStrategy::Sequential || self.offload
409    }
410
411    /// Detect if the transformer is quantized (GGUF).
412    fn detect_is_quantized(&self) -> bool {
413        self.base
414            .paths
415            .transformer
416            .extension()
417            .and_then(|e| e.to_str())
418            .map(|e| e.eq_ignore_ascii_case("gguf"))
419            .unwrap_or(false)
420    }
421
422    /// Get the MMDiT config for this model variant.
423    fn mmdit_config(&self) -> MMDiTConfig {
424        if self.is_medium {
425            MMDiTConfig::sd3_5_medium()
426        } else {
427            MMDiTConfig::sd3_5_large()
428        }
429    }
430
431    /// Validate that all required paths exist.
432    fn validate_paths(
433        &self,
434    ) -> Result<(
435        std::path::PathBuf, // clip_l_path
436        std::path::PathBuf, // clip_l_tokenizer
437        std::path::PathBuf, // clip_g_path
438        std::path::PathBuf, // clip_g_tokenizer
439        std::path::PathBuf, // t5_encoder_path
440        std::path::PathBuf, // t5_tokenizer_path
441    )> {
442        let clip_l_path = self
443            .base
444            .paths
445            .clip_encoder
446            .as_ref()
447            .ok_or_else(|| anyhow::anyhow!("CLIP-L encoder path required for SD3 models"))?
448            .clone();
449        let clip_l_tokenizer = self
450            .base
451            .paths
452            .clip_tokenizer
453            .as_ref()
454            .ok_or_else(|| anyhow::anyhow!("CLIP-L tokenizer path required for SD3 models"))?
455            .clone();
456        let clip_g_path = self
457            .base
458            .paths
459            .clip_encoder_2
460            .as_ref()
461            .ok_or_else(|| anyhow::anyhow!("CLIP-G encoder path required for SD3 models"))?
462            .clone();
463        let clip_g_tokenizer = self
464            .base
465            .paths
466            .clip_tokenizer_2
467            .as_ref()
468            .ok_or_else(|| anyhow::anyhow!("CLIP-G tokenizer path required for SD3 models"))?
469            .clone();
470        let t5_encoder_path = self
471            .base
472            .paths
473            .t5_encoder
474            .as_ref()
475            .ok_or_else(|| anyhow::anyhow!("T5 encoder path required for SD3 models"))?
476            .clone();
477        let t5_tokenizer_path = self
478            .base
479            .paths
480            .t5_tokenizer
481            .as_ref()
482            .ok_or_else(|| anyhow::anyhow!("T5 tokenizer path required for SD3 models"))?
483            .clone();
484
485        for (label, path) in [
486            ("transformer", &self.base.paths.transformer),
487            ("vae", &self.base.paths.vae),
488            ("clip_encoder (CLIP-L)", &clip_l_path),
489            ("clip_tokenizer (CLIP-L)", &clip_l_tokenizer),
490            ("clip_encoder_2 (CLIP-G)", &clip_g_path),
491            ("clip_tokenizer_2 (CLIP-G)", &clip_g_tokenizer),
492            ("t5_encoder", &t5_encoder_path),
493            ("t5_tokenizer", &t5_tokenizer_path),
494        ] {
495            if !path.exists() {
496                bail!("{label} file not found: {}", path.display());
497            }
498        }
499
500        Ok((
501            clip_l_path,
502            clip_l_tokenizer,
503            clip_g_path,
504            clip_g_tokenizer,
505            t5_encoder_path,
506            t5_tokenizer_path,
507        ))
508    }
509
    /// Load all model components into GPU memory (Eager mode).
    ///
    /// On error, `self.base.loaded` remains `None` — all components are assembled into
    /// local variables and only stored in `self.base.loaded` on success, so partial loads
    /// cannot leave the engine in an inconsistent state.
    ///
    /// Returns immediately with `Ok(())` when the model is already loaded or when the
    /// load strategy is `Sequential`.
    ///
    /// # Errors
    /// Propagates failures from path validation, device creation, transformer
    /// construction, T5 variant resolution, device resolution, tokenizer loading, and
    /// triple-encoder loading.
    pub fn load(&mut self) -> Result<()> {
        // Already loaded: nothing to do.
        if self.base.loaded.is_some() {
            return Ok(());
        }

        // Sequential mode defers loading to generate_sequential()
        if self.base.load_strategy == LoadStrategy::Sequential {
            return Ok(());
        }

        tracing::info!(model = %self.base.model_name, "loading SD3 model components...");

        let (
            clip_l_path,
            clip_l_tokenizer,
            clip_g_path,
            clip_g_tokenizer,
            t5_encoder_path,
            t5_tokenizer_path,
        ) = self.validate_paths()?;

        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
        // F16 on a GPU device, F32 otherwise.
        let gpu_dtype = if crate::device::is_gpu(&device) {
            DType::F16
        } else {
            DType::F32
        };

        let is_quantized = self.detect_is_quantized();
        let mmdit_config = self.mmdit_config();

        // --- Load MMDiT transformer on GPU first ---
        let xformer_label = if is_quantized {
            "Loading SD3 MMDiT transformer (GPU, quantized)"
        } else {
            "Loading SD3 MMDiT transformer (GPU, FP16)"
        };
        self.base.progress.stage_start(xformer_label);
        let xformer_stage = Instant::now();

        let transformer = if is_quantized {
            // GGUF files from city96 use unprefixed tensor names (no "model.diffusion_model.")
            let vb = quantized_var_builder::VarBuilder::from_gguf(
                &self.base.paths.transformer,
                &device,
            )?;
            SD3Transformer::Quantized(QuantizedMMDiT::new(&mmdit_config, vb)?)
        } else {
            // BF16 safetensors from stabilityai use "model.diffusion_model." prefix
            // NOTE(review): the weights are loaded as `gpu_dtype` (F16 on GPU) although
            // the variant is named `BF16` — confirm the naming is intentional.
            let vb = crate::weight_loader::load_safetensors_with_progress(
                std::slice::from_ref(&self.base.paths.transformer),
                gpu_dtype,
                &device,
                "SD3 transformer",
                &self.base.progress,
            )?;
            SD3Transformer::BF16(MMDiT::new(
                &mmdit_config,
                false,
                vb.pp("model.diffusion_model"),
            )?)
        };
        self.base
            .progress
            .stage_done(xformer_label, xformer_stage.elapsed());

        // --- Decide encoder placement based on remaining VRAM ---
        // Log the raw driver reading; pass the reserve-adjusted budget to
        // variant resolution below.
        // Both readings fall back to 0 when the query yields no value.
        let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
        if free_raw > 0 {
            self.base.progress.info(&format!(
                "Free VRAM after transformer: {}",
                fmt_gb(free_raw)
            ));
        }

        // --- Load triple encoder (CLIP-L + CLIP-G + T5) ---
        // For T5, use variant resolution logic
        self.base.progress.stage_start("Selecting T5 encoder");
        let t5_resolve_start = Instant::now();
        let t5_preference = self.t5_variant.as_deref();
        // The third element (auto device label) is unused here; the label is
        // recomputed below from the final encoder device.
        let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
            crate::encoders::variant_resolution::resolve_t5_variant(
                &self.base.progress,
                t5_preference,
                &device,
                free,
                &t5_encoder_path,
            )?;
        self.base
            .progress
            .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());

        // Tier 1: honor `placement.text_encoders` — all three encoders share the knob.
        let tier1 = self
            .pending_placement
            .as_ref()
            .map(|p| p.text_encoders)
            .unwrap_or_default();
        // Automatic choice from T5 resolution, offered to `resolve_device` through
        // the closure (presumably used when `tier1` does not pin a device — confirm).
        let auto_encoder_device = if t5_on_gpu {
            device.clone()
        } else {
            Device::Cpu
        };
        let encoder_device_owned =
            crate::device::resolve_device(Some(tier1), || Ok(auto_encoder_device.clone()))?;
        let encoder_device = &encoder_device_owned;
        // Re-derive GPU/CPU status from the device that was actually resolved.
        let t5_on_gpu = !encoder_device.is_cpu();
        let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
        // Encoders on CPU always run in F32.
        let encoder_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };

        let encoder_label = format!("Loading SD3 triple encoder ({t5_device_label})");
        self.base.progress.stage_start(&encoder_label);
        let encoder_stage = Instant::now();
        let (clip_l_tokenizer_handle, clip_g_tokenizer_handle, t5_tokenizer_handle) =
            self.load_text_tokenizers(&clip_l_tokenizer, &clip_g_tokenizer, &t5_tokenizer_path)?;

        let triple_encoder = encoders::sd3_clip::SD3TripleEncoder::load_with_tokenizers(
            &clip_l_path,
            &clip_l_tokenizer,
            Some(clip_l_tokenizer_handle),
            &clip_g_path,
            &clip_g_tokenizer,
            Some(clip_g_tokenizer_handle),
            &resolved_t5_path,
            &t5_tokenizer_path,
            Some(t5_tokenizer_handle),
            encoder_device,
            encoder_dtype,
            &self.base.progress,
        )?;

        self.base
            .progress
            .stage_done(&encoder_label, encoder_stage.elapsed());

        // Everything succeeded: publish the assembled components in one step.
        self.base.loaded = Some(LoadedSD3 {
            transformer: Some(transformer),
            triple_encoder,
            vae_vb_path: self.base.paths.vae.clone(),
            device,
            dtype: gpu_dtype,
            _is_quantized: is_quantized,
            is_turbo: self.is_turbo,
            is_medium: self.is_medium,
        });

        tracing::info!(model = %self.base.model_name, "all SD3 model components loaded successfully");
        Ok(())
    }
667
668    /// Get SLG config if applicable (Medium only).
669    fn slg_config(&self) -> Option<SkipLayerGuidanceConfig> {
670        if self.is_medium {
671            Some(SkipLayerGuidanceConfig {
672                scale: 2.5,
673                start: 0.01,
674                end: 0.2,
675                layers: vec![7, 8, 9],
676            })
677        } else {
678            None
679        }
680    }
681
682    /// Generate an image using sequential loading strategy.
683    fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
684        let is_quantized = self.detect_is_quantized();
685        let active_loras = effective_loras(req);
686        match sd3_offload_decision(self.offload, is_quantized, !active_loras.is_empty()) {
687            SD3OffloadDecision::Disabled => {}
688            SD3OffloadDecision::Unsupported(reason) => bail!("{reason}"),
689            SD3OffloadDecision::Selected => {}
690        }
691
692        let (
693            clip_l_path,
694            clip_l_tokenizer,
695            clip_g_path,
696            clip_g_tokenizer,
697            t5_encoder_path,
698            t5_tokenizer_path,
699        ) = self.validate_paths()?;
700
701        if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
702            self.base.progress.info(&warning);
703        }
704
705        let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
706        let gpu_dtype = if crate::device::is_gpu(&device) {
707            DType::F16
708        } else {
709            DType::F32
710        };
711
712        let start = Instant::now();
713        let seed = req.seed.unwrap_or_else(rand_seed);
714
715        let width = req.width as usize;
716        let height = req.height as usize;
717
718        tracing::info!(
719            prompt = %req.prompt,
720            seed, width, height,
721            steps = req.steps,
722            guidance = req.guidance,
723            "starting sequential SD3 generation"
724        );
725
726        self.base
727            .progress
728            .info("Using sequential loading (load-use-drop) to minimize peak memory");
729
730        // --- Phase 1: Encode prompt (check cache first to skip encoder load) ---
731        let neg = req.negative_prompt.as_deref().unwrap_or("");
732        let cache_key = cfg_prompt_cache_key(&req.prompt, neg, req.guidance);
733        let (context, y) = if let Some((context, y)) =
734            restore_cached_tensor_pair(&self.prompt_cache, &cache_key, &device, gpu_dtype)?
735        {
736            self.base.progress.cache_hit("prompt conditioning");
737            (context, y)
738        } else {
739            // Reserve-adjusted reading drives the T5 variant selection.
740            let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
741            self.base.progress.stage_start("Selecting T5 encoder");
742            let t5_resolve_start = Instant::now();
743            let t5_preference = self.t5_variant.as_deref();
744            let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
745                crate::encoders::variant_resolution::resolve_t5_variant(
746                    &self.base.progress,
747                    t5_preference,
748                    &device,
749                    free,
750                    &t5_encoder_path,
751                )?;
752            self.base
753                .progress
754                .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());
755
756            let tier1 = self
757                .pending_placement
758                .as_ref()
759                .map(|p| p.text_encoders)
760                .unwrap_or_default();
761            let auto_encoder_device = if t5_on_gpu {
762                device.clone()
763            } else {
764                Device::Cpu
765            };
766            let encoder_device_owned =
767                crate::device::resolve_device(Some(tier1), || Ok(auto_encoder_device.clone()))?;
768            let encoder_device = &encoder_device_owned;
769            let t5_on_gpu = !encoder_device.is_cpu();
770            let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
771            let encoder_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };
772
773            let t5_size = std::fs::metadata(&resolved_t5_path)
774                .map(|m| m.len())
775                .unwrap_or(0);
776            let te_activation_budget = crate::device::activation_bytes(
777                req.width,
778                req.height,
779                1,
780                crate::device::dtype_bytes(encoder_dtype),
781                crate::device::ActivationFamily::SmallTransformer,
782            );
783            preflight_memory_check("SD3 triple encoder", t5_size, te_activation_budget)?;
784            if let Some(status) = memory_status_string() {
785                self.base.progress.info(&status);
786            }
787
788            let encoder_label = format!("Loading SD3 triple encoder ({t5_device_label})");
789            self.base.progress.stage_start(&encoder_label);
790            let encoder_stage = Instant::now();
791            let (clip_l_tokenizer_handle, clip_g_tokenizer_handle, t5_tokenizer_handle) = self
792                .load_text_tokenizers(&clip_l_tokenizer, &clip_g_tokenizer, &t5_tokenizer_path)?;
793            let mut triple_encoder = encoders::sd3_clip::SD3TripleEncoder::load_with_tokenizers(
794                &clip_l_path,
795                &clip_l_tokenizer,
796                Some(clip_l_tokenizer_handle),
797                &clip_g_path,
798                &clip_g_tokenizer,
799                Some(clip_g_tokenizer_handle),
800                &resolved_t5_path,
801                &t5_tokenizer_path,
802                Some(t5_tokenizer_handle),
803                encoder_device,
804                encoder_dtype,
805                &self.base.progress,
806            )?;
807            self.base
808                .progress
809                .stage_done(&encoder_label, encoder_stage.elapsed());
810
811            let (context, y) = Self::encode_conditioning(
812                &self.base.progress,
813                &self.prompt_cache,
814                &mut triple_encoder,
815                &req.prompt,
816                neg,
817                req.guidance,
818                &device,
819                gpu_dtype,
820                is_quantized,
821            )?;
822
823            drop(triple_encoder);
824            self.base.progress.info("Freed SD3 triple encoder");
825
826            (context, y)
827        };
828
829        // --- Phase 2: img2img — encode source image if provided ---
830        let noise_dtype = if is_quantized { DType::F32 } else { gpu_dtype };
831        let latent_h = height / 16 * 2;
832        let latent_w = width / 16 * 2;
833        let time_shift = 3.0;
834
835        // Build sigma schedule
836        let num_steps = req.steps as usize;
837        let mut sigmas: Vec<f64> = (0..=num_steps)
838            .map(|s| s as f64 / num_steps as f64)
839            .rev()
840            .map(|t| sampling::time_snr_shift(time_shift, t))
841            .collect();
842
843        if req.source_image.is_some() {
844            let (trimmed, start_index) =
845                crate::img2img::trim_schedule_tail(&sigmas, req.steps as usize, req.strength);
846            sigmas = trimmed;
847            tracing::info!(
848                strength = req.strength,
849                start_index,
850                start_sigma = sigmas[0],
851                schedule = ?sigmas,
852                remaining_steps = sigmas.len().saturating_sub(1),
853                "img2img: truncated schedule from strength"
854            );
855        }
856
857        let (initial_latents, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
858            let start_t = sigmas[0];
859
860            // Load VAE early for source image encoding
861            self.base.progress.stage_start("Loading VAE for encoding");
862            let vae_stage = Instant::now();
863            let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
864            let vae_vb = self.load_vae_var_builder(
865                &self.base.paths.vae,
866                vae_dtype,
867                &device,
868                "VAE",
869                &self.base.progress,
870            )?;
871            let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
872            let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
873            self.base
874                .progress
875                .stage_done("Loading VAE for encoding", vae_stage.elapsed());
876
877            self.base
878                .progress
879                .stage_start("Encoding source image (VAE)");
880            let encode_start = Instant::now();
881            let source_tensor = img_utils::decode_source_image(
882                source_bytes,
883                req.width,
884                req.height,
885                Self::img2img_source_normalize_range(),
886                &device,
887                vae_dtype,
888            )?;
889            let dist = autoencoder.encode(&source_tensor)?;
890            // SD3 VAE encode scaling: reverse of decode's x / 1.5305 + 0.0609.
891            // Use the posterior mean so img2img remains deterministic.
892            let encoded = ((dist.mode()? - 0.0609)? * 1.5305)?;
893            self.base
894                .progress
895                .stage_done("Encoding source image (VAE)", encode_start.elapsed());
896
897            // Drop VAE to free VRAM for transformer (will reload for decode)
898            drop(autoencoder);
899            device.synchronize()?;
900            self.base
901                .progress
902                .info("Freed VAE encoder to make room for transformer");
903
904            let encoded = encoded.to_dtype(noise_dtype)?;
905            let prepared = crate::img2img::prepare_flow_match_img2img(
906                &encoded,
907                seed,
908                &[1, 16, latent_h, latent_w],
909                start_t,
910                req.mask_image.as_deref(),
911                latent_h,
912                latent_w,
913                &device,
914                noise_dtype,
915            )?;
916            (Some(prepared.initial_latents), prepared.inpaint_ctx)
917        } else {
918            (None, None)
919        };
920
921        // --- Phase 3: Load transformer + denoise ---
922        let mmdit_config = self.mmdit_config();
923
924        let xformer_size = if self.offload && !is_quantized && active_loras.is_empty() {
925            0
926        } else {
927            std::fs::metadata(&self.base.paths.transformer)
928                .map(|m| m.len())
929                .unwrap_or(0)
930        };
931        // SD3 runs CFG by default → batch=2 if guidance > 1, else batch=1.
932        let xformer_batch = if req.guidance > 1.0 { 2 } else { 1 };
933        let xformer_activation_budget = crate::device::activation_bytes(
934            req.width,
935            req.height,
936            xformer_batch,
937            crate::device::dtype_bytes(gpu_dtype),
938            crate::device::ActivationFamily::Sd3Mmdit,
939        );
940        preflight_memory_check(
941            "SD3 MMDiT transformer",
942            xformer_size,
943            xformer_activation_budget,
944        )?;
945        if let Some(status) = memory_status_string() {
946            self.base.progress.info(&status);
947        }
948
949        let active_loras = effective_loras(req);
950        let lora_delta_cache = self.lora_delta_cache.clone();
951        let xformer_label = match (is_quantized, active_loras.is_empty(), self.offload) {
952            (true, true, _) => "Loading SD3 MMDiT transformer (GPU, quantized)",
953            (true, false, _) => "Loading SD3 MMDiT transformer (GPU, quantized, with LoRA)",
954            (false, true, true) => "Loading SD3 MMDiT transformer (offload, FP16)",
955            (false, true, false) => "Loading SD3 MMDiT transformer (GPU, FP16)",
956            (false, false, _) => "Loading SD3 MMDiT transformer (GPU, FP16, with LoRA)",
957        };
958        self.base.progress.stage_start(xformer_label);
959        let xformer_stage = Instant::now();
960
961        let transformer = if is_quantized {
962            let vb = if active_loras.is_empty() {
963                quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, &device)?
964            } else {
965                sd3_gguf_lora_var_builder(
966                    &self.base.paths.transformer,
967                    &active_loras,
968                    &device,
969                    &self.base.progress,
970                    Some(lora_delta_cache.clone()),
971                )?
972            };
973            SD3Transformer::Quantized(QuantizedMMDiT::new(&mmdit_config, vb)?)
974        } else if active_loras.is_empty() && self.offload {
975            let tensors = self.load_transformer_cpu_tensors()?;
976            SD3Transformer::Offloaded(Box::new(super::offload::OffloadedMMDiT::new(
977                &mmdit_config,
978                tensors,
979                gpu_dtype,
980                &device,
981            )?))
982        } else if active_loras.is_empty() {
983            // BF16 safetensors from stabilityai use "model.diffusion_model." prefix
984            let vb = crate::weight_loader::load_safetensors_with_progress(
985                std::slice::from_ref(&self.base.paths.transformer),
986                gpu_dtype,
987                &device,
988                "SD3 transformer",
989                &self.base.progress,
990            )?;
991            SD3Transformer::BF16(MMDiT::new(
992                &mmdit_config,
993                false,
994                vb.pp("model.diffusion_model"),
995            )?)
996        } else {
997            // LoRA path: the `LoraBackend` strips the on-disk prefix
998            // automatically, so the builder is positioned at the prefix
999            // root — feed it straight to `MMDiT::new`.
1000            let vb = sd3_lora_var_builder(
1001                &self.base.paths.transformer,
1002                &active_loras,
1003                gpu_dtype,
1004                &device,
1005                &self.base.progress,
1006                Some(lora_delta_cache.clone()),
1007            )?;
1008            SD3Transformer::BF16(MMDiT::new(&mmdit_config, false, vb)?)
1009        };
1010        self.base
1011            .progress
1012            .stage_done(xformer_label, xformer_stage.elapsed());
1013
1014        // Denoise
1015        let slg_config = self.slg_config();
1016        let actual_steps = sigmas.len().saturating_sub(1);
1017        let denoise_label = format!("Denoising ({actual_steps} steps)");
1018        self.base.progress.stage_start(&denoise_label);
1019        let denoise_start = Instant::now();
1020
1021        let x = sampling::euler_sample(
1022            &transformer,
1023            &y,
1024            &context,
1025            num_steps,
1026            req.guidance,
1027            resolve_cfg_plus(req),
1028            time_shift,
1029            height,
1030            width,
1031            slg_config.as_ref(),
1032            is_quantized,
1033            seed,
1034            &self.base.progress,
1035            initial_latents.as_ref(),
1036            Some(sigmas),
1037            inpaint_ctx.as_ref(),
1038        )?;
1039
1040        self.base
1041            .progress
1042            .stage_done(&denoise_label, denoise_start.elapsed());
1043
1044        // Drop transformer to free memory for VAE
1045        drop(transformer);
1046        drop(context);
1047        drop(y);
1048        drop(inpaint_ctx);
1049        device.synchronize()?;
1050        self.base.progress.info("Freed SD3 MMDiT transformer");
1051
1052        // --- Phase 4: VAE decode ---
1053        self.base.progress.stage_start("Loading VAE (GPU)");
1054        let vae_stage = Instant::now();
1055        let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
1056        let vae_vb = self.load_vae_var_builder(
1057            &self.base.paths.vae,
1058            vae_dtype,
1059            &device,
1060            "VAE",
1061            &self.base.progress,
1062        )?;
1063        let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
1064        let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
1065        self.base
1066            .progress
1067            .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
1068
1069        self.base.progress.stage_start("VAE decode");
1070        let vae_decode_start = Instant::now();
1071
1072        // SD3 VAE scaling: x / 1.5305 + 0.0609
1073        // Cast to VAE dtype (quantized path outputs F32, VAE is F16/BF16/F32
1074        // depending on MOLD_VAE_DTYPE).
1075        let x = ((x / 1.5305)? + 0.0609)?.to_dtype(vae_dtype)?;
1076        let device_for_sync = device.clone();
1077        let img = crate::vae_tiling::decode_with_oom_fallback(
1078            &x,
1079            |t| autoencoder.decode(t).map_err(Into::into),
1080            || {
1081                if let Err(e) = device_for_sync.synchronize() {
1082                    tracing::warn!(
1083                        "SD3 (sequential) device.synchronize() after VAE OOM failed: {e}"
1084                    );
1085                }
1086            },
1087        )?;
1088
1089        let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1090        let img = img.i(0)?;
1091
1092        self.base
1093            .progress
1094            .stage_done("VAE decode", vae_decode_start.elapsed());
1095
1096        let output_metadata = build_output_metadata(req, seed, None);
1097        let image_bytes = encode_image(
1098            &img,
1099            req.resolved_output_format(),
1100            req.width,
1101            req.height,
1102            output_metadata.as_ref(),
1103        )?;
1104
1105        let generation_time_ms = start.elapsed().as_millis() as u64;
1106        tracing::info!(
1107            generation_time_ms,
1108            seed,
1109            "sequential SD3 generation complete"
1110        );
1111
1112        Ok(GenerateResponse {
1113            images: vec![ImageData {
1114                data: image_bytes,
1115                format: req.resolved_output_format(),
1116                width: req.width,
1117                height: req.height,
1118                index: 0,
1119            }],
1120            generation_time_ms,
1121            model: req.model.clone(),
1122            seed_used: seed,
1123            video: None,
1124            gpu: None,
1125        })
1126    }
1127}
1128
1129impl SD3Engine {
1130    fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1131        if req.scheduler.is_some() {
1132            tracing::warn!("scheduler selection not supported for SD3 (flow-matching), ignoring");
1133        }
1134
1135        // Sequential mode: load-use-drop each component
1136        if self.uses_sequential_generate_path() {
1137            return self.generate_sequential(req);
1138        }
1139
1140        // Eager mode: use pre-loaded components
1141        let progress = &self.base.progress;
1142        let prompt_cache = &self.prompt_cache;
1143        let mmdit_config = self.mmdit_config();
1144        let transformer_path = self.base.paths.transformer.clone();
1145        let active_loras = effective_loras(req);
1146        let lora_delta_cache = self.lora_delta_cache.clone();
1147        let shared_pool = self.shared_pool.clone();
1148
1149        let mut loaded = OptionRestoreGuard::take(&mut self.base.loaded)
1150            .ok_or_else(|| anyhow::anyhow!("model not loaded -- call load() first"))?;
1151        let loaded_dtype = loaded.dtype;
1152        let loaded_device = loaded.device.clone();
1153        let is_quantized = loaded._is_quantized;
1154
1155        // LoRA: force a transformer rebuild via the LoRA path. The
1156        // eager-loaded transformer is the unpatched base; even if the
1157        // same stack was applied on a previous generate, the in-memory
1158        // tensors no longer carry that delta after the post-decode
1159        // drop. The `LoraDeltaCache` keeps the matmul work cheap on
1160        // every rebuild.
1161        if !active_loras.is_empty() && loaded.transformer.is_some() {
1162            loaded.transformer = None;
1163            loaded_device.synchronize()?;
1164            progress.info("SD3 LoRA: dropping base transformer for LoRA merge");
1165        }
1166
1167        let start = Instant::now();
1168        let seed = req.seed.unwrap_or_else(rand_seed);
1169
1170        let width = req.width as usize;
1171        let height = req.height as usize;
1172
1173        tracing::info!(
1174            prompt = %req.prompt,
1175            seed, width, height,
1176            steps = req.steps,
1177            guidance = req.guidance,
1178            turbo = loaded.is_turbo,
1179            medium = loaded.is_medium,
1180            "starting SD3 generation"
1181        );
1182
1183        (|| -> Result<GenerateResponse> {
1184            if !loaded.triple_encoder.is_loaded() {
1185                let label = if loaded.triple_encoder.is_parked() {
1186                    "Unparking SD3 triple encoder (CPU→GPU)"
1187                } else {
1188                    "Reloading SD3 triple encoder"
1189                };
1190                progress.stage_start(label);
1191                let reload_start = Instant::now();
1192                if loaded.triple_encoder.is_parked() {
1193                    loaded
1194                        .triple_encoder
1195                        .unpark_to_gpu(loaded_dtype, progress)?;
1196                } else {
1197                    loaded.triple_encoder.reload(loaded_dtype, progress)?;
1198                }
1199                progress.stage_done(label, reload_start.elapsed());
1200            }
1201
1202            let neg = req.negative_prompt.as_deref().unwrap_or("");
1203            let (context, y) = Self::encode_conditioning(
1204                progress,
1205                prompt_cache,
1206                &mut loaded.triple_encoder,
1207                &req.prompt,
1208                neg,
1209                req.guidance,
1210                &loaded_device,
1211                loaded_dtype,
1212                is_quantized,
1213            )?;
1214
1215            if loaded.triple_encoder.on_gpu {
1216                // Park mode keeps the FP16 encoders alive on host RAM (~9 GB
1217                // T5 + ~1.6 GB CLIPs). Disabled on Metal (unified memory).
1218                let park_mode = crate::device::keep_te_in_ram() && !loaded_device.is_metal();
1219                if park_mode {
1220                    loaded.triple_encoder.park_to_cpu()?;
1221                    tracing::info!("SD3 triple encoder parked to CPU host RAM");
1222                } else {
1223                    loaded.triple_encoder.drop_weights();
1224                    tracing::info!(
1225                        "SD3 triple encoder dropped from GPU to free VRAM for denoising"
1226                    );
1227                }
1228            }
1229
1230            // --- img2img: build schedule and encode source image ---
1231            let noise_dtype = if is_quantized {
1232                DType::F32
1233            } else {
1234                loaded_dtype
1235            };
1236            let latent_h = height / 16 * 2;
1237            let latent_w = width / 16 * 2;
1238            let time_shift = 3.0;
1239            let num_steps = req.steps as usize;
1240
1241            let mut sigmas: Vec<f64> = (0..=num_steps)
1242                .map(|s| s as f64 / num_steps as f64)
1243                .rev()
1244                .map(|t| sampling::time_snr_shift(time_shift, t))
1245                .collect();
1246
1247            if req.source_image.is_some() {
1248                let (trimmed, start_index) =
1249                    crate::img2img::trim_schedule_tail(&sigmas, req.steps as usize, req.strength);
1250                sigmas = trimmed;
1251                tracing::info!(
1252                    strength = req.strength,
1253                    start_index,
1254                    start_sigma = sigmas[0],
1255                    schedule = ?sigmas,
1256                    remaining_steps = sigmas.len().saturating_sub(1),
1257                    "img2img: truncated schedule from strength"
1258                );
1259            }
1260
1261            let (initial_latents, inpaint_ctx, early_vae) =
1262                if let Some(ref source_bytes) = req.source_image {
1263                    let start_t = sigmas[0];
1264
1265                    // Drop transformer to make room for VAE encoding
1266                    loaded.transformer = None;
1267                    loaded.device.synchronize()?;
1268
1269                    progress.stage_start("Loading VAE for encoding");
1270                    let vae_stage = Instant::now();
1271                    let vae_dtype = crate::device::resolve_vae_dtype(loaded_dtype);
1272                    let vae_vb = Self::load_vae_var_builder_from_pool(
1273                        shared_pool.as_ref(),
1274                        &loaded.vae_vb_path,
1275                        vae_dtype,
1276                        &loaded.device,
1277                        "VAE",
1278                        progress,
1279                    )?;
1280                    let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
1281                    let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
1282                    progress.stage_done("Loading VAE for encoding", vae_stage.elapsed());
1283
1284                    progress.stage_start("Encoding source image (VAE)");
1285                    let encode_start = Instant::now();
1286                    let source_tensor = img_utils::decode_source_image(
1287                        source_bytes,
1288                        req.width,
1289                        req.height,
1290                        Self::img2img_source_normalize_range(),
1291                        &loaded_device,
1292                        vae_dtype,
1293                    )?;
1294                    let dist = autoencoder.encode(&source_tensor)?;
1295                    // SD3 VAE encode scaling: reverse of decode's x / 1.5305 + 0.0609.
1296                    // Use the posterior mean so img2img remains deterministic.
1297                    let encoded = ((dist.mode()? - 0.0609)? * 1.5305)?;
1298                    progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
1299
1300                    // Drop VAE to free VRAM for transformer reload
1301                    drop(autoencoder);
1302                    loaded.device.synchronize()?;
1303
1304                    let encoded = encoded.to_dtype(noise_dtype)?;
1305                    let prepared = crate::img2img::prepare_flow_match_img2img(
1306                        &encoded,
1307                        seed,
1308                        &[1, 16, latent_h, latent_w],
1309                        start_t,
1310                        req.mask_image.as_deref(),
1311                        latent_h,
1312                        latent_w,
1313                        &loaded_device,
1314                        noise_dtype,
1315                    )?;
1316                    (
1317                        Some(prepared.initial_latents),
1318                        prepared.inpaint_ctx,
1319                        None::<()>,
1320                    )
1321                } else {
1322                    (None, None, None)
1323                };
1324
1325            // Reload transformer if needed (dropped for img2img VAE encoding, or prior VAE decode)
1326            if loaded.transformer.is_none() {
1327                let reload_label = if active_loras.is_empty() {
1328                    "Reloading SD3 transformer"
1329                } else {
1330                    "Reloading SD3 transformer (with LoRA)"
1331                };
1332                progress.stage_start(reload_label);
1333                let reload_start = Instant::now();
1334                let transformer = if is_quantized {
1335                    let vb = if active_loras.is_empty() {
1336                        quantized_var_builder::VarBuilder::from_gguf(
1337                            &transformer_path,
1338                            &loaded_device,
1339                        )?
1340                    } else {
1341                        sd3_gguf_lora_var_builder(
1342                            &transformer_path,
1343                            &active_loras,
1344                            &loaded_device,
1345                            progress,
1346                            Some(lora_delta_cache.clone()),
1347                        )?
1348                    };
1349                    SD3Transformer::Quantized(QuantizedMMDiT::new(&mmdit_config, vb)?)
1350                } else if active_loras.is_empty() {
1351                    let vb = crate::weight_loader::load_safetensors_with_progress(
1352                        std::slice::from_ref(&transformer_path),
1353                        loaded_dtype,
1354                        &loaded_device,
1355                        "SD3 transformer",
1356                        progress,
1357                    )?;
1358                    let vb = vb.pp("model.diffusion_model");
1359                    SD3Transformer::BF16(MMDiT::new(&mmdit_config, false, vb)?)
1360                } else {
1361                    // LoRA path: the `LoraBackend` already absorbs the
1362                    // on-disk `model.diffusion_model.` prefix, so the
1363                    // builder is positioned at the prefix root.
1364                    let vb = sd3_lora_var_builder(
1365                        &transformer_path,
1366                        &active_loras,
1367                        loaded_dtype,
1368                        &loaded_device,
1369                        progress,
1370                        Some(lora_delta_cache.clone()),
1371                    )?;
1372                    SD3Transformer::BF16(MMDiT::new(&mmdit_config, false, vb)?)
1373                };
1374                loaded.transformer = Some(transformer);
1375                progress.stage_done(reload_label, reload_start.elapsed());
1376            }
1377
1378            let slg_config = if loaded.is_medium {
1379                Some(SkipLayerGuidanceConfig {
1380                    scale: 2.5,
1381                    start: 0.01,
1382                    end: 0.2,
1383                    layers: vec![7, 8, 9],
1384                })
1385            } else {
1386                None
1387            };
1388
1389            let actual_steps = sigmas.len().saturating_sub(1);
1390            let denoise_label = format!("Denoising ({actual_steps} steps)");
1391            progress.stage_start(&denoise_label);
1392            let denoise_start = Instant::now();
1393
1394            let transformer = loaded
1395                .transformer
1396                .as_ref()
1397                .ok_or_else(|| anyhow::anyhow!("SD3 transformer not loaded"))?;
1398            let x = sampling::euler_sample(
1399                transformer,
1400                &y,
1401                &context,
1402                num_steps,
1403                req.guidance,
1404                resolve_cfg_plus(req),
1405                time_shift,
1406                height,
1407                width,
1408                slg_config.as_ref(),
1409                loaded._is_quantized,
1410                seed,
1411                progress,
1412                initial_latents.as_ref(),
1413                Some(sigmas),
1414                inpaint_ctx.as_ref(),
1415            )?;
1416
1417            progress.stage_done(&denoise_label, denoise_start.elapsed());
1418            drop(context);
1419            drop(y);
1420            drop(inpaint_ctx);
1421            let _ = early_vae;
1422
1423            // Drop transformer before VAE decode to free VRAM.
1424            loaded.transformer = None;
1425            loaded.device.synchronize()?;
1426            tracing::info!("SD3 transformer dropped to free VRAM for VAE decode");
1427
1428            progress.stage_start("VAE decode");
1429            let vae_decode_start = Instant::now();
1430
1431            let vae_dtype = crate::device::resolve_vae_dtype(loaded.dtype);
1432            let vae_vb = Self::load_vae_var_builder_from_pool(
1433                shared_pool.as_ref(),
1434                &loaded.vae_vb_path,
1435                vae_dtype,
1436                &loaded.device,
1437                "VAE",
1438                progress,
1439            )?;
1440            let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
1441            let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
1442
1443            let x = ((x / 1.5305)? + 0.0609)?.to_dtype(vae_dtype)?;
1444            let device_for_sync = loaded.device.clone();
1445            let img = crate::vae_tiling::decode_with_oom_fallback(
1446                &x,
1447                |t| autoencoder.decode(t).map_err(Into::into),
1448                || {
1449                    if let Err(e) = device_for_sync.synchronize() {
1450                        tracing::warn!(
1451                            "SD3 (parallel) device.synchronize() after VAE OOM failed: {e}"
1452                        );
1453                    }
1454                },
1455            )?;
1456
1457            let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1458            let img = img.i(0)?;
1459
1460            progress.stage_done("VAE decode", vae_decode_start.elapsed());
1461
1462            let output_metadata = build_output_metadata(req, seed, None);
1463            let image_bytes = encode_image(
1464                &img,
1465                req.resolved_output_format(),
1466                req.width,
1467                req.height,
1468                output_metadata.as_ref(),
1469            )?;
1470
1471            let generation_time_ms = start.elapsed().as_millis() as u64;
1472            tracing::info!(generation_time_ms, seed, "SD3 generation complete");
1473
1474            Ok(GenerateResponse {
1475                images: vec![ImageData {
1476                    data: image_bytes,
1477                    format: req.resolved_output_format(),
1478                    width: req.width,
1479                    height: req.height,
1480                    index: 0,
1481                }],
1482                generation_time_ms,
1483                model: req.model.clone(),
1484                seed_used: seed,
1485                video: None,
1486                gpu: None,
1487            })
1488        })()
1489    }
1490}
1491
1492impl InferenceEngine for SD3Engine {
1493    fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1494        self.pending_placement = req.placement.clone();
1495        let result = self.generate_inner(req);
1496        self.pending_placement = None;
1497        result
1498    }
1499
1500    fn model_name(&self) -> &str {
1501        self.base.model_name()
1502    }
1503
1504    fn is_loaded(&self) -> bool {
1505        self.base.is_loaded()
1506    }
1507
1508    fn load(&mut self) -> Result<()> {
1509        SD3Engine::load(self)
1510    }
1511
1512    fn unload(&mut self) {
1513        self.base.unload();
1514        clear_cache(&self.prompt_cache);
1515    }
1516
1517    fn set_on_progress(&mut self, callback: ProgressCallback) {
1518        self.base.set_on_progress(callback);
1519    }
1520
1521    fn clear_on_progress(&mut self) {
1522        self.base.clear_on_progress();
1523    }
1524
1525    fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1526        Some(&self.base.paths)
1527    }
1528}
1529
1530#[cfg(test)]
1531mod tests {
1532    use super::*;
1533    use crate::engine::LoadStrategy;
1534    use crate::shared_pool::SharedPool;
1535    use mold_core::ModelPaths;
1536    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
1537    use std::collections::HashMap;
1538    use std::fs;
1539    use std::path::{Path, PathBuf};
1540    use std::sync::{Arc, Mutex};
1541    use std::time::{SystemTime, UNIX_EPOCH};
1542    use tokenizers::models::bpe::BPE;
1543
1544    fn temp_test_dir(prefix: &str) -> PathBuf {
1545        let suffix = SystemTime::now()
1546            .duration_since(UNIX_EPOCH)
1547            .unwrap()
1548            .as_nanos();
1549        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
1550        fs::create_dir_all(&dir).unwrap();
1551        dir
1552    }
1553
1554    fn touch(dir: &Path, name: &str) -> PathBuf {
1555        let path = dir.join(name);
1556        fs::write(&path, b"test").unwrap();
1557        path
1558    }
1559
1560    #[allow(clippy::too_many_arguments)]
1561    fn sd3_model_paths(
1562        transformer: PathBuf,
1563        vae: PathBuf,
1564        clip_l_path: Option<PathBuf>,
1565        clip_l_tokenizer: Option<PathBuf>,
1566        clip_g_path: Option<PathBuf>,
1567        clip_g_tokenizer: Option<PathBuf>,
1568        t5_encoder: Option<PathBuf>,
1569        t5_tokenizer: Option<PathBuf>,
1570    ) -> ModelPaths {
1571        ModelPaths {
1572            transformer,
1573            transformer_shards: vec![],
1574            vae,
1575            spatial_upscaler: None,
1576            temporal_upscaler: None,
1577            distilled_lora: None,
1578            t5_encoder,
1579            clip_encoder: clip_l_path,
1580            t5_tokenizer,
1581            clip_tokenizer: clip_l_tokenizer,
1582            clip_encoder_2: clip_g_path,
1583            clip_tokenizer_2: clip_g_tokenizer,
1584            text_encoder_files: vec![],
1585            text_tokenizer: None,
1586            decoder: None,
1587        }
1588    }
1589
1590    #[test]
1591    fn sd3_img2img_uses_minus_one_to_one_source_normalization() {
1592        assert_eq!(
1593            SD3Engine::img2img_source_normalize_range(),
1594            img_utils::NormalizeRange::MinusOneToOne
1595        );
1596    }
1597
1598    #[test]
1599    fn sd3_mmdit_config_tracks_large_vs_medium_variants() {
1600        let base_dir = temp_test_dir("mold-sd3-config");
1601        let large = SD3Engine::new(
1602            "sd3.5-large:bf16".to_string(),
1603            sd3_model_paths(
1604                base_dir.join("transformer.safetensors"),
1605                base_dir.join("vae.safetensors"),
1606                None,
1607                None,
1608                None,
1609                None,
1610                None,
1611                None,
1612            ),
1613            false,
1614            false,
1615            None,
1616            LoadStrategy::Sequential,
1617            0,
1618            false,
1619            None,
1620        );
1621        let medium = SD3Engine::new(
1622            "sd3.5-medium:bf16".to_string(),
1623            sd3_model_paths(
1624                base_dir.join("transformer.safetensors"),
1625                base_dir.join("vae.safetensors"),
1626                None,
1627                None,
1628                None,
1629                None,
1630                None,
1631                None,
1632            ),
1633            false,
1634            true,
1635            None,
1636            LoadStrategy::Sequential,
1637            0,
1638            false,
1639            None,
1640        );
1641
1642        let large_cfg = large.mmdit_config();
1643        let medium_cfg = medium.mmdit_config();
1644
1645        assert_eq!(large_cfg.depth, 38);
1646        assert_eq!(large_cfg.pos_embed_max_size, 192);
1647        assert_eq!(medium_cfg.depth, 24);
1648        assert_eq!(medium_cfg.pos_embed_max_size, 384);
1649        assert!(large.slg_config().is_none());
1650        let slg = medium.slg_config().unwrap();
1651        assert_eq!(slg.scale, 2.5);
1652        assert_eq!(slg.layers, vec![7, 8, 9]);
1653
1654        fs::remove_dir_all(base_dir).ok();
1655    }
1656
1657    #[test]
1658    fn sd3_validate_paths_accepts_existing_files() {
1659        let dir = temp_test_dir("mold-sd3-validate-ok");
1660        let transformer = touch(&dir, "transformer.gguf");
1661        let vae = touch(&dir, "vae.safetensors");
1662        let clip_l = touch(&dir, "clip-l.safetensors");
1663        let clip_l_tok = touch(&dir, "clip-l-tokenizer.json");
1664        let clip_g = touch(&dir, "clip-g.safetensors");
1665        let clip_g_tok = touch(&dir, "clip-g-tokenizer.json");
1666        let t5 = touch(&dir, "t5.safetensors");
1667        let t5_tok = touch(&dir, "t5-tokenizer.json");
1668
1669        let engine = SD3Engine::new(
1670            "sd3.5-large-turbo:q8".to_string(),
1671            sd3_model_paths(
1672                transformer,
1673                vae,
1674                Some(clip_l),
1675                Some(clip_l_tok),
1676                Some(clip_g),
1677                Some(clip_g_tok),
1678                Some(t5),
1679                Some(t5_tok.clone()),
1680            ),
1681            true,
1682            false,
1683            None,
1684            LoadStrategy::Sequential,
1685            0,
1686            false,
1687            None,
1688        );
1689
1690        let (_, _, _, _, _, resolved_t5_tok) = engine.validate_paths().unwrap();
1691        assert_eq!(resolved_t5_tok, t5_tok);
1692        assert!(engine.detect_is_quantized());
1693
1694        fs::remove_dir_all(dir).ok();
1695    }
1696
1697    #[test]
1698    fn sd3_forced_offload_uses_sequential_generation_path() {
1699        let dir = temp_test_dir("mold-sd3-offload-sequential");
1700        let engine = SD3Engine::new(
1701            "sd3.5-large:bf16".to_string(),
1702            sd3_model_paths(
1703                dir.join("transformer.safetensors"),
1704                dir.join("vae.safetensors"),
1705                None,
1706                None,
1707                None,
1708                None,
1709                None,
1710                None,
1711            ),
1712            false,
1713            false,
1714            None,
1715            LoadStrategy::Eager,
1716            0,
1717            true,
1718            None,
1719        );
1720
1721        assert!(
1722            engine.uses_sequential_generate_path(),
1723            "SD3 --offload requests must reach the engine and select the \
1724             staged generation path instead of being silently ignored"
1725        );
1726
1727        fs::remove_dir_all(dir).ok();
1728    }
1729
1730    #[test]
1731    fn sd3_offload_decision_gates_current_unsupported_cases() {
1732        assert_eq!(
1733            sd3_offload_decision(false, false, false),
1734            SD3OffloadDecision::Disabled
1735        );
1736        assert_eq!(
1737            sd3_offload_decision(true, false, false),
1738            SD3OffloadDecision::Selected
1739        );
1740        assert!(matches!(
1741            sd3_offload_decision(true, true, false),
1742            SD3OffloadDecision::Unsupported(reason)
1743                if reason.contains("GGUF variants")
1744        ));
1745        assert!(matches!(
1746            sd3_offload_decision(true, false, true),
1747            SD3OffloadDecision::Unsupported(reason)
1748                if reason.contains("LoRA")
1749        ));
1750    }
1751
1752    #[test]
1753    fn sd3_selected_bf16_offload_reaches_runtime_loader() {
1754        use crate::cache::store_cached_tensor_pair;
1755
1756        let dir = temp_test_dir("mold-sd3-offload-loader");
1757        let transformer = touch(&dir, "transformer.safetensors");
1758        let vae = touch(&dir, "vae.safetensors");
1759        let clip_l = touch(&dir, "clip-l.safetensors");
1760        let clip_l_tok = touch(&dir, "clip-l-tokenizer.json");
1761        let clip_g = touch(&dir, "clip-g.safetensors");
1762        let clip_g_tok = touch(&dir, "clip-g-tokenizer.json");
1763        let t5 = touch(&dir, "t5.safetensors");
1764        let t5_tok = touch(&dir, "t5-tokenizer.json");
1765        let mut engine = SD3Engine::new(
1766            "sd3.5-large:bf16".to_string(),
1767            sd3_model_paths(
1768                transformer,
1769                vae,
1770                Some(clip_l),
1771                Some(clip_l_tok),
1772                Some(clip_g),
1773                Some(clip_g_tok),
1774                Some(t5),
1775                Some(t5_tok),
1776            ),
1777            false,
1778            false,
1779            None,
1780            LoadStrategy::Sequential,
1781            0,
1782            true,
1783            None,
1784        );
1785        let context = Tensor::zeros((1, 1, 4096), DType::F32, &Device::Cpu).unwrap();
1786        let y = Tensor::zeros((1, 2048), DType::F32, &Device::Cpu).unwrap();
1787        let key = cfg_prompt_cache_key("a cat", "", 1.0);
1788        store_cached_tensor_pair(&engine.prompt_cache, key, &context, &y).unwrap();
1789        let req = GenerateRequest {
1790            prompt: "a cat".to_string(),
1791            negative_prompt: None,
1792            model: "sd3.5-large:bf16".to_string(),
1793            width: 64,
1794            height: 64,
1795            steps: 1,
1796            guidance: 1.0,
1797            seed: Some(1),
1798            batch_size: 1,
1799            output_format: None,
1800            embed_metadata: None,
1801            scheduler: None,
1802            cfg_plus: None,
1803            source_image: None,
1804            edit_images: None,
1805            strength: 1.0,
1806            mask_image: None,
1807            control_image: None,
1808            control_model: None,
1809            control_scale: 1.0,
1810            expand: None,
1811            original_prompt: None,
1812            lora: None,
1813            frames: None,
1814            fps: None,
1815            upscale_model: None,
1816            gif_preview: false,
1817            enable_audio: None,
1818            audio_file: None,
1819            audio_file_path: None,
1820            source_video: None,
1821            source_video_path: None,
1822            keyframes: None,
1823            pipeline: None,
1824            loras: None,
1825            retake_range: None,
1826            spatial_upscale: None,
1827            temporal_upscale: None,
1828            placement: None,
1829        };
1830
1831        let err = engine.generate_sequential(&req).unwrap_err().to_string();
1832
1833        assert!(
1834            !err.contains("streaming is not implemented yet"),
1835            "selected BF16 offload must reach the runtime loader, got: {err}"
1836        );
1837        fs::remove_dir_all(dir).ok();
1838    }
1839
1840    /// Regression test for the SD3 prompt-cache key bug: when the cache is
1841    /// keyed only on the positive prompt, a follow-up request that changes
1842    /// just the negative prompt returns the previous `(cond, uncond_old)`
1843    /// pair — silently producing wrong output.
1844    ///
1845    /// Drives the cache through `restore_cached_tensor_pair` directly because
1846    /// the surrounding `encode_conditioning` requires loaded T5/CLIP weights;
1847    /// the contract under test is the keying, not the encoder forward pass.
1848    #[test]
1849    fn sd3_prompt_cache_distinguishes_negative_prompt_changes() {
1850        use crate::cache::{cfg_prompt_cache_key, store_cached_tensor_pair};
1851
1852        let cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensorPair>> =
1853            Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
1854        let device = Device::Cpu;
1855        let dtype = DType::F32;
1856        let context = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();
1857        let y = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();
1858
1859        let key_a = cfg_prompt_cache_key("a cat", "blurry", 7.0);
1860        store_cached_tensor_pair(&cache, key_a.clone(), &context, &y).unwrap();
1861
1862        // Same positive + same guidance, different negative → MUST miss.
1863        let key_b = cfg_prompt_cache_key("a cat", "low quality", 7.0);
1864        let restored = restore_cached_tensor_pair(&cache, &key_b, &device, dtype).unwrap();
1865        assert!(
1866            restored.is_none(),
1867            "different negative prompt must miss the cache (was the silent-wrong-output bug)"
1868        );
1869
1870        // Same key as the insert → MUST hit.
1871        let restored = restore_cached_tensor_pair(&cache, &key_a, &device, dtype).unwrap();
1872        assert!(
1873            restored.is_some(),
1874            "identical (pos, neg, guidance) must still hit"
1875        );
1876    }
1877
1878    #[test]
1879    fn sd3_validate_paths_requires_t5_encoder() {
1880        let dir = temp_test_dir("mold-sd3-validate-missing");
1881        let engine = SD3Engine::new(
1882            "sd3.5-large:bf16".to_string(),
1883            sd3_model_paths(
1884                dir.join("transformer.safetensors"),
1885                dir.join("vae.safetensors"),
1886                Some(dir.join("clip-l.safetensors")),
1887                Some(dir.join("clip-l-tokenizer.json")),
1888                Some(dir.join("clip-g.safetensors")),
1889                Some(dir.join("clip-g-tokenizer.json")),
1890                None,
1891                Some(dir.join("t5-tokenizer.json")),
1892            ),
1893            false,
1894            false,
1895            None,
1896            LoadStrategy::Sequential,
1897            0,
1898            false,
1899            None,
1900        );
1901
1902        let err = engine.validate_paths().unwrap_err();
1903        assert!(err.to_string().contains("T5 encoder path required"));
1904        assert!(!engine.detect_is_quantized());
1905
1906        fs::remove_dir_all(dir).ok();
1907    }
1908
1909    #[test]
1910    fn sd3_loads_text_tokenizers_through_shared_pool() {
1911        let dir = temp_test_dir("mold-sd3-tokenizer-pool");
1912        let clip_l_tok = dir.join("clip-l-tokenizer.json");
1913        let clip_g_tok = dir.join("clip-g-tokenizer.json");
1914        let t5_tok = dir.join("t5-tokenizer.json");
1915        for path in [&clip_l_tok, &clip_g_tok, &t5_tok] {
1916            tokenizers::Tokenizer::new(BPE::default())
1917                .save(path, false)
1918                .unwrap();
1919        }
1920
1921        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
1922        let pooled_clip_l = shared_pool
1923            .lock()
1924            .unwrap()
1925            .load_tokenizer(&clip_l_tok)
1926            .unwrap();
1927        let pooled_clip_g = shared_pool
1928            .lock()
1929            .unwrap()
1930            .load_tokenizer(&clip_g_tok)
1931            .unwrap();
1932        let pooled_t5 = shared_pool.lock().unwrap().load_tokenizer(&t5_tok).unwrap();
1933
1934        let engine = SD3Engine::new(
1935            "sd3.5-large:bf16".to_string(),
1936            sd3_model_paths(
1937                dir.join("transformer.safetensors"),
1938                dir.join("vae.safetensors"),
1939                Some(dir.join("clip-l.safetensors")),
1940                Some(clip_l_tok.clone()),
1941                Some(dir.join("clip-g.safetensors")),
1942                Some(clip_g_tok.clone()),
1943                Some(dir.join("t5.safetensors")),
1944                Some(t5_tok.clone()),
1945            ),
1946            false,
1947            false,
1948            None,
1949            LoadStrategy::Sequential,
1950            0,
1951            false,
1952            Some(shared_pool),
1953        );
1954
1955        let (loaded_clip_l, loaded_clip_g, loaded_t5) = engine
1956            .load_text_tokenizers(&clip_l_tok, &clip_g_tok, &t5_tok)
1957            .unwrap();
1958
1959        assert!(Arc::ptr_eq(&pooled_clip_l, &loaded_clip_l));
1960        assert!(Arc::ptr_eq(&pooled_clip_g, &loaded_clip_g));
1961        assert!(Arc::ptr_eq(&pooled_t5, &loaded_t5));
1962        fs::remove_dir_all(dir).ok();
1963    }
1964
1965    #[test]
1966    fn sd3_loads_vae_tensors_through_shared_pool() {
1967        let dir = temp_test_dir("mold-sd3-vae-pool");
1968        let vae_path = dir.join("vae.safetensors");
1969        let weight = 1.0f32.to_le_bytes();
1970        let mut tensors = HashMap::new();
1971        tensors.insert(
1972            "first_stage_model.encoder.conv_in.weight".to_string(),
1973            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
1974        );
1975        serialize_to_file(&tensors, &None, &vae_path).unwrap();
1976
1977        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
1978        let pooled = shared_pool
1979            .lock()
1980            .unwrap()
1981            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
1982            .unwrap()
1983            .unwrap();
1984
1985        let engine = SD3Engine::new(
1986            "sd3.5-large:bf16".to_string(),
1987            sd3_model_paths(
1988                dir.join("transformer.safetensors"),
1989                vae_path.clone(),
1990                Some(dir.join("clip-l.safetensors")),
1991                Some(dir.join("clip-l-tokenizer.json")),
1992                Some(dir.join("clip-g.safetensors")),
1993                Some(dir.join("clip-g-tokenizer.json")),
1994                Some(dir.join("t5.safetensors")),
1995                Some(dir.join("t5-tokenizer.json")),
1996            ),
1997            false,
1998            false,
1999            None,
2000            LoadStrategy::Sequential,
2001            0,
2002            false,
2003            Some(shared_pool),
2004        );
2005
2006        let loaded = engine.load_vae_cpu_tensors(&vae_path).unwrap().unwrap();
2007
2008        assert!(Arc::ptr_eq(&pooled, &loaded));
2009        fs::remove_dir_all(dir).ok();
2010    }
2011
2012    // -----------------------------------------------------------------------
2013    // resolve_cfg_plus precedence: explicit request field beats env var,
2014    // env var beats default-off. MOLD_CFG_PLUS is process-global so these
2015    // tests serialize via a static mutex.
2016    // -----------------------------------------------------------------------
2017
2018    fn cfg_env_lock() -> std::sync::MutexGuard<'static, ()> {
2019        use std::sync::{Mutex, OnceLock};
2020        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
2021        LOCK.get_or_init(|| Mutex::new(()))
2022            .lock()
2023            .unwrap_or_else(|p| p.into_inner())
2024    }
2025
2026    fn req_with_cfg_plus(cfg_plus: Option<bool>) -> GenerateRequest {
2027        // Build a minimal SD3-shaped request via JSON to avoid maintaining a
2028        // by-hand list of every GenerateRequest field across schema changes.
2029        let mut req: GenerateRequest = serde_json::from_str(
2030            r#"{
2031                "prompt":"x",
2032                "model":"sd3.5-large:fp16",
2033                "width":1024,
2034                "height":1024,
2035                "steps":28,
2036                "guidance":4.5
2037            }"#,
2038        )
2039        .unwrap();
2040        req.cfg_plus = cfg_plus;
2041        req
2042    }
2043
2044    #[test]
2045    fn resolve_cfg_plus_defaults_off() {
2046        let _guard = cfg_env_lock();
2047        // SAFETY: serialized via cfg_env_lock to avoid racing parallel tests.
2048        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
2049        assert!(!resolve_cfg_plus(&req_with_cfg_plus(None)));
2050    }
2051
2052    #[test]
2053    fn resolve_cfg_plus_env_enables() {
2054        let _guard = cfg_env_lock();
2055        unsafe { std::env::set_var("MOLD_CFG_PLUS", "1") };
2056        let on = resolve_cfg_plus(&req_with_cfg_plus(None));
2057        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
2058        assert!(on, "MOLD_CFG_PLUS=1 must enable cfg++");
2059    }
2060
2061    #[test]
2062    fn resolve_cfg_plus_request_field_wins_over_env() {
2063        let _guard = cfg_env_lock();
2064        // Env says on, request explicitly says off → request wins. Without
2065        // this precedence a server with a global env default could not be
2066        // overridden per-request, which is the whole point of having a
2067        // request field.
2068        unsafe { std::env::set_var("MOLD_CFG_PLUS", "1") };
2069        let off = resolve_cfg_plus(&req_with_cfg_plus(Some(false)));
2070        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
2071        assert!(!off, "explicit Some(false) must override env=on");
2072    }
2073
2074    #[test]
2075    fn resolve_cfg_plus_request_true_without_env() {
2076        let _guard = cfg_env_lock();
2077        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
2078        assert!(resolve_cfg_plus(&req_with_cfg_plus(Some(true))));
2079    }
2080}