1use anyhow::{bail, Result};
2use candle_core::{DType, Device, IndexOp, Tensor};
3use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT};
4use candle_transformers::quantized_var_builder;
5use mold_core::{GenerateRequest, GenerateResponse, ImageData, LoraWeight, ModelPaths};
6use std::collections::HashMap;
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9use std::time::Instant;
10use tokenizers::Tokenizer;
11
12use crate::cache::{
13 cfg_prompt_cache_key, clear_cache, get_or_insert_cached_tensor_pair,
14 restore_cached_tensor_pair, CachedTensorPair, CfgPromptCacheKey, LruCache,
15 DEFAULT_PROMPT_CACHE_CAPACITY,
16};
17use crate::device::{
18 check_memory_budget, fmt_gb, free_vram_bytes, memory_status_string, preflight_memory_check,
19 usable_free_vram_bytes,
20};
21use crate::encoders;
22use crate::engine::{
23 rand_seed, resolve_cfg_plus, InferenceEngine, LoadStrategy, OptionRestoreGuard,
24};
25use crate::engine_base::EngineBase;
26use crate::image::{build_output_metadata, encode_image};
27use crate::img_utils;
28use crate::progress::{ProgressCallback, ProgressReporter};
29
30use super::lora as sd3_lora;
31use super::quantized_mmdit::QuantizedMMDiT;
32use super::sampling::{self, SkipLayerGuidanceConfig};
33use super::transformer::SD3Transformer;
34use super::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
35
36const ZERO_SCALE_EPS: f64 = 1e-8;
41
42pub(crate) fn effective_loras(req: &GenerateRequest) -> Vec<LoraWeight> {
46 let raw: Vec<LoraWeight> = if let Some(plural) = &req.loras {
47 if !plural.is_empty() {
48 plural.clone()
49 } else {
50 req.lora.iter().cloned().collect()
51 }
52 } else {
53 req.lora.iter().cloned().collect()
54 };
55 raw.into_iter()
56 .filter(|w| {
57 let keep = w.scale.abs() > ZERO_SCALE_EPS;
58 if !keep {
59 tracing::debug!(
60 path = w.path.as_str(),
61 scale = w.scale,
62 "dropping zero-scale LoRA from SD3 effective stack"
63 );
64 }
65 keep
66 })
67 .collect()
68}
69
/// Outcome of deciding whether SD3 should use block-level transformer offload.
#[derive(Debug, PartialEq, Eq)]
enum SD3OffloadDecision {
    /// Offload was not requested.
    Disabled,
    /// Offload was requested and this configuration supports it.
    Selected,
    /// Offload was requested but cannot be honored; carries the reason.
    Unsupported(&'static str),
}

/// Decide how a forced-offload request applies to the current model setup.
///
/// Not forced → `Disabled`. Forced with a GGUF (quantized) transformer, or
/// forced with active LoRAs, → `Unsupported` (the quantized reason takes
/// precedence). Forced on a plain FP transformer without LoRA → `Selected`.
fn sd3_offload_decision(
    forced_offload: bool,
    is_quantized: bool,
    has_lora: bool,
) -> SD3OffloadDecision {
    match (forced_offload, is_quantized, has_lora) {
        (false, _, _) => SD3OffloadDecision::Disabled,
        (true, true, _) => SD3OffloadDecision::Unsupported(
            "SD3 block-level offload is only planned for BF16/FP transformers; \
             GGUF variants already use quantized transformer paths",
        ),
        (true, false, true) => SD3OffloadDecision::Unsupported(
            "SD3 block-level offload with LoRA is not wired yet; \
             LoRA merge/cache semantics need a dedicated offload design",
        ),
        (true, false, false) => SD3OffloadDecision::Selected,
    }
}
99
100fn sd3_lora_var_builder<'a>(
105 transformer_path: &Path,
106 loras: &[LoraWeight],
107 dtype: DType,
108 device: &Device,
109 progress: &ProgressReporter,
110 delta_cache: Option<Arc<Mutex<sd3_lora::LoraDeltaCache>>>,
111) -> Result<candle_nn::VarBuilder<'a>> {
112 let adapters: Vec<Arc<sd3_lora::LoraAdapter>> = loras
113 .iter()
114 .map(|w| {
115 progress.info("Loading SD3 LoRA adapter");
116 let adapter = sd3_lora::get_or_load_adapter(Path::new(&w.path))?;
117 progress.info(&format!(
118 "SD3 LoRA: {} layers, rank {}, scale {:.2}",
119 adapter.layers.len(),
120 adapter.rank,
121 w.scale,
122 ));
123 anyhow::Ok(adapter)
124 })
125 .collect::<Result<_>>()?;
126
127 let specs: Vec<sd3_lora::LoraSpec<'_>> = adapters
128 .iter()
129 .zip(loras.iter())
130 .map(|(adapter, w)| sd3_lora::LoraSpec {
131 adapter: adapter.as_ref(),
132 scale: w.scale,
133 path_hash: sd3_lora::lora_path_hash(&w.path),
134 })
135 .collect();
136
137 sd3_lora::lora_var_builder(
138 transformer_path,
139 &specs,
140 dtype,
141 device,
142 progress,
143 delta_cache,
144 )
145}
146
147fn sd3_gguf_lora_var_builder(
151 transformer_path: &Path,
152 loras: &[LoraWeight],
153 device: &Device,
154 progress: &ProgressReporter,
155 delta_cache: Option<Arc<Mutex<sd3_lora::LoraDeltaCache>>>,
156) -> Result<quantized_var_builder::VarBuilder> {
157 let adapters: Vec<Arc<sd3_lora::LoraAdapter>> = loras
158 .iter()
159 .map(|w| {
160 progress.info("Loading SD3 LoRA adapter");
161 let adapter = sd3_lora::get_or_load_adapter(Path::new(&w.path))?;
162 progress.info(&format!(
163 "SD3 LoRA: {} layers, rank {}, scale {:.2}",
164 adapter.layers.len(),
165 adapter.rank,
166 w.scale,
167 ));
168 anyhow::Ok(adapter)
169 })
170 .collect::<Result<_>>()?;
171
172 let specs: Vec<sd3_lora::LoraSpec<'_>> = adapters
173 .iter()
174 .zip(loras.iter())
175 .map(|(adapter, w)| sd3_lora::LoraSpec {
176 adapter: adapter.as_ref(),
177 scale: w.scale,
178 path_hash: sd3_lora::lora_path_hash(&w.path),
179 })
180 .collect();
181
182 sd3_lora::gguf_lora_var_builder(transformer_path, &specs, device, progress, delta_cache)
183}
184
/// Components built by `SD3Engine::load` for the non-sequential generate path.
struct LoadedSD3 {
    /// MMDiT transformer. Set to `None` when the base transformer is dropped
    /// so a LoRA-merged one can be built for the current request.
    transformer: Option<SD3Transformer>,
    /// CLIP-L + CLIP-G + T5 text encoders; may be parked to CPU or have its
    /// weights dropped after encoding, and is reloaded/unparked on demand.
    triple_encoder: encoders::sd3_clip::SD3TripleEncoder,
    /// Path to the VAE weights (the VAE itself is not loaded by `load`).
    vae_vb_path: std::path::PathBuf,
    /// Device the transformer was loaded onto.
    device: Device,
    /// Compute dtype chosen at load time: F16 on GPU, F32 otherwise.
    dtype: DType,
    /// Whether the transformer was loaded from a GGUF file.
    /// NOTE(review): despite the leading underscore, `generate_inner` reads it.
    _is_quantized: bool,
    is_turbo: bool,
    is_medium: bool,
}
197
/// Stable Diffusion 3.x engine: MMDiT transformer, CLIP-L/CLIP-G/T5 triple
/// text encoder, and VAE.
pub struct SD3Engine {
    /// Shared engine state: model name, paths, load strategy, GPU ordinal,
    /// progress reporter, and the eagerly loaded components (if any).
    base: EngineBase<LoadedSD3>,
    /// Turbo variant flag, copied into `LoadedSD3` on load.
    is_turbo: bool,
    /// Selects the SD3.5 Medium MMDiT config and skip-layer guidance
    /// instead of SD3.5 Large.
    is_medium: bool,
    /// Optional user preference passed to T5 variant resolution.
    t5_variant: Option<String>,
    /// Block-level transformer offload request; forces the sequential
    /// generate path.
    offload: bool,
    /// LRU cache of encoded (cond, uncond) prompt conditioning pairs.
    prompt_cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensorPair>>,
    /// Optional device placement override; its `text_encoders` entry picks
    /// the text-encoder device.
    pending_placement: Option<mold_core::types::DevicePlacement>,
    /// Cache handed to the LoRA var-builders when LoRAs are applied.
    lora_delta_cache: Arc<Mutex<sd3_lora::LoraDeltaCache>>,
    /// Optional pool used for tokenizers and CPU-resident safetensors.
    shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
}
217
218impl SD3Engine {
219 #[allow(clippy::too_many_arguments)]
221 pub fn new(
222 model_name: String,
223 paths: ModelPaths,
224 is_turbo: bool,
225 is_medium: bool,
226 t5_variant: Option<String>,
227 load_strategy: LoadStrategy,
228 gpu_ordinal: usize,
229 offload: bool,
230 shared_pool: Option<Arc<Mutex<crate::shared_pool::SharedPool>>>,
231 ) -> Self {
232 Self {
233 base: EngineBase::new(model_name, paths, load_strategy, gpu_ordinal),
234 is_turbo,
235 is_medium,
236 t5_variant,
237 offload,
238 prompt_cache: Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY)),
239 pending_placement: None,
240 lora_delta_cache: Arc::new(Mutex::new(sd3_lora::LoraDeltaCache::new())),
241 shared_pool,
242 }
243 }
244
245 fn load_text_tokenizers(
246 &self,
247 clip_l_tokenizer: &Path,
248 clip_g_tokenizer: &Path,
249 t5_tokenizer: &Path,
250 ) -> Result<(Arc<Tokenizer>, Arc<Tokenizer>, Arc<Tokenizer>)> {
251 if let Some(shared_pool) = &self.shared_pool {
252 let mut pool = shared_pool.lock().unwrap();
253 return Ok((
254 pool.load_tokenizer(clip_l_tokenizer)?,
255 pool.load_tokenizer(clip_g_tokenizer)?,
256 pool.load_tokenizer(t5_tokenizer)?,
257 ));
258 }
259
260 let load = |path: &Path, label: &str| {
261 Tokenizer::from_file(path)
262 .map(Arc::new)
263 .map_err(|e| anyhow::anyhow!("failed to load {label} tokenizer: {e}"))
264 };
265 Ok((
266 load(clip_l_tokenizer, "CLIP-L")?,
267 load(clip_g_tokenizer, "CLIP-G")?,
268 load(t5_tokenizer, "T5")?,
269 ))
270 }
271
272 #[cfg(test)]
273 fn load_vae_cpu_tensors(
274 &self,
275 vae_path: &Path,
276 ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
277 Self::load_vae_cpu_tensors_from_pool(self.shared_pool.as_ref(), vae_path)
278 }
279
280 fn load_vae_cpu_tensors_from_pool(
281 shared_pool: Option<&Arc<Mutex<crate::shared_pool::SharedPool>>>,
282 vae_path: &Path,
283 ) -> Result<Option<Arc<HashMap<String, Tensor>>>> {
284 let Some(shared_pool) = shared_pool else {
285 return Ok(None);
286 };
287 shared_pool
288 .lock()
289 .unwrap()
290 .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
291 }
292
293 fn load_transformer_cpu_tensors(&self) -> Result<Arc<HashMap<String, Tensor>>> {
294 if let Some(shared_pool) = &self.shared_pool {
295 if let Some(tensors) = shared_pool
296 .lock()
297 .unwrap()
298 .load_safetensors_cpu_tensors(std::slice::from_ref(&self.base.paths.transformer))?
299 {
300 return Ok(tensors);
301 }
302 }
303 Ok(Arc::new(crate::encoders::park::load_tensors_to_cpu(
304 std::slice::from_ref(&self.base.paths.transformer),
305 )?))
306 }
307
308 fn load_vae_var_builder<'a>(
309 &self,
310 vae_path: &Path,
311 dtype: DType,
312 device: &Device,
313 component: &str,
314 progress: &ProgressReporter,
315 ) -> Result<candle_nn::VarBuilder<'a>> {
316 Self::load_vae_var_builder_from_pool(
317 self.shared_pool.as_ref(),
318 vae_path,
319 dtype,
320 device,
321 component,
322 progress,
323 )
324 }
325
326 fn load_vae_var_builder_from_pool<'a>(
327 shared_pool: Option<&Arc<Mutex<crate::shared_pool::SharedPool>>>,
328 vae_path: &Path,
329 dtype: DType,
330 device: &Device,
331 component: &str,
332 progress: &ProgressReporter,
333 ) -> Result<candle_nn::VarBuilder<'a>> {
334 if let Some(tensors) = Self::load_vae_cpu_tensors_from_pool(shared_pool, vae_path)? {
335 return Ok(crate::encoders::park::varbuilder_from_parked(
336 tensors.as_ref(),
337 dtype,
338 device,
339 ));
340 }
341
342 crate::weight_loader::load_safetensors_with_progress(
343 std::slice::from_ref(&vae_path),
344 dtype,
345 device,
346 component,
347 progress,
348 )
349 }
350
351 #[allow(clippy::too_many_arguments)]
352 fn encode_conditioning(
353 progress: &ProgressReporter,
354 prompt_cache: &Mutex<LruCache<CfgPromptCacheKey, CachedTensorPair>>,
355 triple_encoder: &mut encoders::sd3_clip::SD3TripleEncoder,
356 prompt: &str,
357 negative_prompt: &str,
358 guidance: f64,
359 device: &Device,
360 dtype: DType,
361 is_quantized: bool,
362 ) -> Result<(candle_core::Tensor, candle_core::Tensor)> {
363 let cache_key = cfg_prompt_cache_key(prompt, negative_prompt, guidance);
368 let ((context, y), cache_hit) = get_or_insert_cached_tensor_pair(
369 prompt_cache,
370 cache_key,
371 device,
372 if is_quantized { DType::F32 } else { dtype },
373 || {
374 progress.stage_start("Encoding prompt (SD3 triple)");
375 let encode_start = Instant::now();
376 let (context_cond, y_cond) = triple_encoder.encode(prompt, device, dtype)?;
377 let (context_uncond, y_uncond) =
378 triple_encoder.encode(negative_prompt, device, dtype)?;
379 progress.stage_done("Encoding prompt (SD3 triple)", encode_start.elapsed());
380
381 let pair = if is_quantized {
382 (
383 candle_core::Tensor::cat(&[&context_cond, &context_uncond], 0)?
384 .to_dtype(DType::F32)?,
385 candle_core::Tensor::cat(&[&y_cond, &y_uncond], 0)?.to_dtype(DType::F32)?,
386 )
387 } else {
388 (
389 candle_core::Tensor::cat(&[&context_cond, &context_uncond], 0)?,
390 candle_core::Tensor::cat(&[&y_cond, &y_uncond], 0)?,
391 )
392 };
393 Ok(pair)
394 },
395 )?;
396 if cache_hit {
397 progress.cache_hit("prompt conditioning");
398 return Ok((context, y));
399 }
400 Ok((context, y))
401 }
402
403 fn img2img_source_normalize_range() -> img_utils::NormalizeRange {
404 img_utils::NormalizeRange::MinusOneToOne
405 }
406
407 fn uses_sequential_generate_path(&self) -> bool {
408 self.base.load_strategy == LoadStrategy::Sequential || self.offload
409 }
410
411 fn detect_is_quantized(&self) -> bool {
413 self.base
414 .paths
415 .transformer
416 .extension()
417 .and_then(|e| e.to_str())
418 .map(|e| e.eq_ignore_ascii_case("gguf"))
419 .unwrap_or(false)
420 }
421
422 fn mmdit_config(&self) -> MMDiTConfig {
424 if self.is_medium {
425 MMDiTConfig::sd3_5_medium()
426 } else {
427 MMDiTConfig::sd3_5_large()
428 }
429 }
430
431 fn validate_paths(
433 &self,
434 ) -> Result<(
435 std::path::PathBuf, std::path::PathBuf, std::path::PathBuf, std::path::PathBuf, std::path::PathBuf, std::path::PathBuf, )> {
442 let clip_l_path = self
443 .base
444 .paths
445 .clip_encoder
446 .as_ref()
447 .ok_or_else(|| anyhow::anyhow!("CLIP-L encoder path required for SD3 models"))?
448 .clone();
449 let clip_l_tokenizer = self
450 .base
451 .paths
452 .clip_tokenizer
453 .as_ref()
454 .ok_or_else(|| anyhow::anyhow!("CLIP-L tokenizer path required for SD3 models"))?
455 .clone();
456 let clip_g_path = self
457 .base
458 .paths
459 .clip_encoder_2
460 .as_ref()
461 .ok_or_else(|| anyhow::anyhow!("CLIP-G encoder path required for SD3 models"))?
462 .clone();
463 let clip_g_tokenizer = self
464 .base
465 .paths
466 .clip_tokenizer_2
467 .as_ref()
468 .ok_or_else(|| anyhow::anyhow!("CLIP-G tokenizer path required for SD3 models"))?
469 .clone();
470 let t5_encoder_path = self
471 .base
472 .paths
473 .t5_encoder
474 .as_ref()
475 .ok_or_else(|| anyhow::anyhow!("T5 encoder path required for SD3 models"))?
476 .clone();
477 let t5_tokenizer_path = self
478 .base
479 .paths
480 .t5_tokenizer
481 .as_ref()
482 .ok_or_else(|| anyhow::anyhow!("T5 tokenizer path required for SD3 models"))?
483 .clone();
484
485 for (label, path) in [
486 ("transformer", &self.base.paths.transformer),
487 ("vae", &self.base.paths.vae),
488 ("clip_encoder (CLIP-L)", &clip_l_path),
489 ("clip_tokenizer (CLIP-L)", &clip_l_tokenizer),
490 ("clip_encoder_2 (CLIP-G)", &clip_g_path),
491 ("clip_tokenizer_2 (CLIP-G)", &clip_g_tokenizer),
492 ("t5_encoder", &t5_encoder_path),
493 ("t5_tokenizer", &t5_tokenizer_path),
494 ] {
495 if !path.exists() {
496 bail!("{label} file not found: {}", path.display());
497 }
498 }
499
500 Ok((
501 clip_l_path,
502 clip_l_tokenizer,
503 clip_g_path,
504 clip_g_tokenizer,
505 t5_encoder_path,
506 t5_tokenizer_path,
507 ))
508 }
509
510 pub fn load(&mut self) -> Result<()> {
516 if self.base.loaded.is_some() {
517 return Ok(());
518 }
519
520 if self.base.load_strategy == LoadStrategy::Sequential {
522 return Ok(());
523 }
524
525 tracing::info!(model = %self.base.model_name, "loading SD3 model components...");
526
527 let (
528 clip_l_path,
529 clip_l_tokenizer,
530 clip_g_path,
531 clip_g_tokenizer,
532 t5_encoder_path,
533 t5_tokenizer_path,
534 ) = self.validate_paths()?;
535
536 let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
537 let gpu_dtype = if crate::device::is_gpu(&device) {
538 DType::F16
539 } else {
540 DType::F32
541 };
542
543 let is_quantized = self.detect_is_quantized();
544 let mmdit_config = self.mmdit_config();
545
546 let xformer_label = if is_quantized {
548 "Loading SD3 MMDiT transformer (GPU, quantized)"
549 } else {
550 "Loading SD3 MMDiT transformer (GPU, FP16)"
551 };
552 self.base.progress.stage_start(xformer_label);
553 let xformer_stage = Instant::now();
554
555 let transformer = if is_quantized {
556 let vb = quantized_var_builder::VarBuilder::from_gguf(
558 &self.base.paths.transformer,
559 &device,
560 )?;
561 SD3Transformer::Quantized(QuantizedMMDiT::new(&mmdit_config, vb)?)
562 } else {
563 let vb = crate::weight_loader::load_safetensors_with_progress(
565 std::slice::from_ref(&self.base.paths.transformer),
566 gpu_dtype,
567 &device,
568 "SD3 transformer",
569 &self.base.progress,
570 )?;
571 SD3Transformer::BF16(MMDiT::new(
572 &mmdit_config,
573 false,
574 vb.pp("model.diffusion_model"),
575 )?)
576 };
577 self.base
578 .progress
579 .stage_done(xformer_label, xformer_stage.elapsed());
580
581 let free_raw = free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
585 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
586 if free_raw > 0 {
587 self.base.progress.info(&format!(
588 "Free VRAM after transformer: {}",
589 fmt_gb(free_raw)
590 ));
591 }
592
593 self.base.progress.stage_start("Selecting T5 encoder");
596 let t5_resolve_start = Instant::now();
597 let t5_preference = self.t5_variant.as_deref();
598 let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
599 crate::encoders::variant_resolution::resolve_t5_variant(
600 &self.base.progress,
601 t5_preference,
602 &device,
603 free,
604 &t5_encoder_path,
605 )?;
606 self.base
607 .progress
608 .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());
609
610 let tier1 = self
612 .pending_placement
613 .as_ref()
614 .map(|p| p.text_encoders)
615 .unwrap_or_default();
616 let auto_encoder_device = if t5_on_gpu {
617 device.clone()
618 } else {
619 Device::Cpu
620 };
621 let encoder_device_owned =
622 crate::device::resolve_device(Some(tier1), || Ok(auto_encoder_device.clone()))?;
623 let encoder_device = &encoder_device_owned;
624 let t5_on_gpu = !encoder_device.is_cpu();
625 let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
626 let encoder_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };
627
628 let encoder_label = format!("Loading SD3 triple encoder ({t5_device_label})");
629 self.base.progress.stage_start(&encoder_label);
630 let encoder_stage = Instant::now();
631 let (clip_l_tokenizer_handle, clip_g_tokenizer_handle, t5_tokenizer_handle) =
632 self.load_text_tokenizers(&clip_l_tokenizer, &clip_g_tokenizer, &t5_tokenizer_path)?;
633
634 let triple_encoder = encoders::sd3_clip::SD3TripleEncoder::load_with_tokenizers(
635 &clip_l_path,
636 &clip_l_tokenizer,
637 Some(clip_l_tokenizer_handle),
638 &clip_g_path,
639 &clip_g_tokenizer,
640 Some(clip_g_tokenizer_handle),
641 &resolved_t5_path,
642 &t5_tokenizer_path,
643 Some(t5_tokenizer_handle),
644 encoder_device,
645 encoder_dtype,
646 &self.base.progress,
647 )?;
648
649 self.base
650 .progress
651 .stage_done(&encoder_label, encoder_stage.elapsed());
652
653 self.base.loaded = Some(LoadedSD3 {
654 transformer: Some(transformer),
655 triple_encoder,
656 vae_vb_path: self.base.paths.vae.clone(),
657 device,
658 dtype: gpu_dtype,
659 _is_quantized: is_quantized,
660 is_turbo: self.is_turbo,
661 is_medium: self.is_medium,
662 });
663
664 tracing::info!(model = %self.base.model_name, "all SD3 model components loaded successfully");
665 Ok(())
666 }
667
668 fn slg_config(&self) -> Option<SkipLayerGuidanceConfig> {
670 if self.is_medium {
671 Some(SkipLayerGuidanceConfig {
672 scale: 2.5,
673 start: 0.01,
674 end: 0.2,
675 layers: vec![7, 8, 9],
676 })
677 } else {
678 None
679 }
680 }
681
682 fn generate_sequential(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
684 let is_quantized = self.detect_is_quantized();
685 let active_loras = effective_loras(req);
686 match sd3_offload_decision(self.offload, is_quantized, !active_loras.is_empty()) {
687 SD3OffloadDecision::Disabled => {}
688 SD3OffloadDecision::Unsupported(reason) => bail!("{reason}"),
689 SD3OffloadDecision::Selected => {}
690 }
691
692 let (
693 clip_l_path,
694 clip_l_tokenizer,
695 clip_g_path,
696 clip_g_tokenizer,
697 t5_encoder_path,
698 t5_tokenizer_path,
699 ) = self.validate_paths()?;
700
701 if let Some(warning) = check_memory_budget(&self.base.paths, LoadStrategy::Sequential) {
702 self.base.progress.info(&warning);
703 }
704
705 let device = crate::device::create_device(self.base.gpu_ordinal, &self.base.progress)?;
706 let gpu_dtype = if crate::device::is_gpu(&device) {
707 DType::F16
708 } else {
709 DType::F32
710 };
711
712 let start = Instant::now();
713 let seed = req.seed.unwrap_or_else(rand_seed);
714
715 let width = req.width as usize;
716 let height = req.height as usize;
717
718 tracing::info!(
719 prompt = %req.prompt,
720 seed, width, height,
721 steps = req.steps,
722 guidance = req.guidance,
723 "starting sequential SD3 generation"
724 );
725
726 self.base
727 .progress
728 .info("Using sequential loading (load-use-drop) to minimize peak memory");
729
730 let neg = req.negative_prompt.as_deref().unwrap_or("");
732 let cache_key = cfg_prompt_cache_key(&req.prompt, neg, req.guidance);
733 let (context, y) = if let Some((context, y)) =
734 restore_cached_tensor_pair(&self.prompt_cache, &cache_key, &device, gpu_dtype)?
735 {
736 self.base.progress.cache_hit("prompt conditioning");
737 (context, y)
738 } else {
739 let free = usable_free_vram_bytes(self.base.gpu_ordinal).unwrap_or(0);
741 self.base.progress.stage_start("Selecting T5 encoder");
742 let t5_resolve_start = Instant::now();
743 let t5_preference = self.t5_variant.as_deref();
744 let (resolved_t5_path, t5_on_gpu, _t5_auto_device_label) =
745 crate::encoders::variant_resolution::resolve_t5_variant(
746 &self.base.progress,
747 t5_preference,
748 &device,
749 free,
750 &t5_encoder_path,
751 )?;
752 self.base
753 .progress
754 .stage_done("Selecting T5 encoder", t5_resolve_start.elapsed());
755
756 let tier1 = self
757 .pending_placement
758 .as_ref()
759 .map(|p| p.text_encoders)
760 .unwrap_or_default();
761 let auto_encoder_device = if t5_on_gpu {
762 device.clone()
763 } else {
764 Device::Cpu
765 };
766 let encoder_device_owned =
767 crate::device::resolve_device(Some(tier1), || Ok(auto_encoder_device.clone()))?;
768 let encoder_device = &encoder_device_owned;
769 let t5_on_gpu = !encoder_device.is_cpu();
770 let t5_device_label = if t5_on_gpu { "GPU" } else { "CPU" };
771 let encoder_dtype = if t5_on_gpu { gpu_dtype } else { DType::F32 };
772
773 let t5_size = std::fs::metadata(&resolved_t5_path)
774 .map(|m| m.len())
775 .unwrap_or(0);
776 let te_activation_budget = crate::device::activation_bytes(
777 req.width,
778 req.height,
779 1,
780 crate::device::dtype_bytes(encoder_dtype),
781 crate::device::ActivationFamily::SmallTransformer,
782 );
783 preflight_memory_check("SD3 triple encoder", t5_size, te_activation_budget)?;
784 if let Some(status) = memory_status_string() {
785 self.base.progress.info(&status);
786 }
787
788 let encoder_label = format!("Loading SD3 triple encoder ({t5_device_label})");
789 self.base.progress.stage_start(&encoder_label);
790 let encoder_stage = Instant::now();
791 let (clip_l_tokenizer_handle, clip_g_tokenizer_handle, t5_tokenizer_handle) = self
792 .load_text_tokenizers(&clip_l_tokenizer, &clip_g_tokenizer, &t5_tokenizer_path)?;
793 let mut triple_encoder = encoders::sd3_clip::SD3TripleEncoder::load_with_tokenizers(
794 &clip_l_path,
795 &clip_l_tokenizer,
796 Some(clip_l_tokenizer_handle),
797 &clip_g_path,
798 &clip_g_tokenizer,
799 Some(clip_g_tokenizer_handle),
800 &resolved_t5_path,
801 &t5_tokenizer_path,
802 Some(t5_tokenizer_handle),
803 encoder_device,
804 encoder_dtype,
805 &self.base.progress,
806 )?;
807 self.base
808 .progress
809 .stage_done(&encoder_label, encoder_stage.elapsed());
810
811 let (context, y) = Self::encode_conditioning(
812 &self.base.progress,
813 &self.prompt_cache,
814 &mut triple_encoder,
815 &req.prompt,
816 neg,
817 req.guidance,
818 &device,
819 gpu_dtype,
820 is_quantized,
821 )?;
822
823 drop(triple_encoder);
824 self.base.progress.info("Freed SD3 triple encoder");
825
826 (context, y)
827 };
828
829 let noise_dtype = if is_quantized { DType::F32 } else { gpu_dtype };
831 let latent_h = height / 16 * 2;
832 let latent_w = width / 16 * 2;
833 let time_shift = 3.0;
834
835 let num_steps = req.steps as usize;
837 let mut sigmas: Vec<f64> = (0..=num_steps)
838 .map(|s| s as f64 / num_steps as f64)
839 .rev()
840 .map(|t| sampling::time_snr_shift(time_shift, t))
841 .collect();
842
843 if req.source_image.is_some() {
844 let (trimmed, start_index) =
845 crate::img2img::trim_schedule_tail(&sigmas, req.steps as usize, req.strength);
846 sigmas = trimmed;
847 tracing::info!(
848 strength = req.strength,
849 start_index,
850 start_sigma = sigmas[0],
851 schedule = ?sigmas,
852 remaining_steps = sigmas.len().saturating_sub(1),
853 "img2img: truncated schedule from strength"
854 );
855 }
856
857 let (initial_latents, inpaint_ctx) = if let Some(ref source_bytes) = req.source_image {
858 let start_t = sigmas[0];
859
860 self.base.progress.stage_start("Loading VAE for encoding");
862 let vae_stage = Instant::now();
863 let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
864 let vae_vb = self.load_vae_var_builder(
865 &self.base.paths.vae,
866 vae_dtype,
867 &device,
868 "VAE",
869 &self.base.progress,
870 )?;
871 let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
872 let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
873 self.base
874 .progress
875 .stage_done("Loading VAE for encoding", vae_stage.elapsed());
876
877 self.base
878 .progress
879 .stage_start("Encoding source image (VAE)");
880 let encode_start = Instant::now();
881 let source_tensor = img_utils::decode_source_image(
882 source_bytes,
883 req.width,
884 req.height,
885 Self::img2img_source_normalize_range(),
886 &device,
887 vae_dtype,
888 )?;
889 let dist = autoencoder.encode(&source_tensor)?;
890 let encoded = ((dist.mode()? - 0.0609)? * 1.5305)?;
893 self.base
894 .progress
895 .stage_done("Encoding source image (VAE)", encode_start.elapsed());
896
897 drop(autoencoder);
899 device.synchronize()?;
900 self.base
901 .progress
902 .info("Freed VAE encoder to make room for transformer");
903
904 let encoded = encoded.to_dtype(noise_dtype)?;
905 let prepared = crate::img2img::prepare_flow_match_img2img(
906 &encoded,
907 seed,
908 &[1, 16, latent_h, latent_w],
909 start_t,
910 req.mask_image.as_deref(),
911 latent_h,
912 latent_w,
913 &device,
914 noise_dtype,
915 )?;
916 (Some(prepared.initial_latents), prepared.inpaint_ctx)
917 } else {
918 (None, None)
919 };
920
921 let mmdit_config = self.mmdit_config();
923
924 let xformer_size = if self.offload && !is_quantized && active_loras.is_empty() {
925 0
926 } else {
927 std::fs::metadata(&self.base.paths.transformer)
928 .map(|m| m.len())
929 .unwrap_or(0)
930 };
931 let xformer_batch = if req.guidance > 1.0 { 2 } else { 1 };
933 let xformer_activation_budget = crate::device::activation_bytes(
934 req.width,
935 req.height,
936 xformer_batch,
937 crate::device::dtype_bytes(gpu_dtype),
938 crate::device::ActivationFamily::Sd3Mmdit,
939 );
940 preflight_memory_check(
941 "SD3 MMDiT transformer",
942 xformer_size,
943 xformer_activation_budget,
944 )?;
945 if let Some(status) = memory_status_string() {
946 self.base.progress.info(&status);
947 }
948
949 let active_loras = effective_loras(req);
950 let lora_delta_cache = self.lora_delta_cache.clone();
951 let xformer_label = match (is_quantized, active_loras.is_empty(), self.offload) {
952 (true, true, _) => "Loading SD3 MMDiT transformer (GPU, quantized)",
953 (true, false, _) => "Loading SD3 MMDiT transformer (GPU, quantized, with LoRA)",
954 (false, true, true) => "Loading SD3 MMDiT transformer (offload, FP16)",
955 (false, true, false) => "Loading SD3 MMDiT transformer (GPU, FP16)",
956 (false, false, _) => "Loading SD3 MMDiT transformer (GPU, FP16, with LoRA)",
957 };
958 self.base.progress.stage_start(xformer_label);
959 let xformer_stage = Instant::now();
960
961 let transformer = if is_quantized {
962 let vb = if active_loras.is_empty() {
963 quantized_var_builder::VarBuilder::from_gguf(&self.base.paths.transformer, &device)?
964 } else {
965 sd3_gguf_lora_var_builder(
966 &self.base.paths.transformer,
967 &active_loras,
968 &device,
969 &self.base.progress,
970 Some(lora_delta_cache.clone()),
971 )?
972 };
973 SD3Transformer::Quantized(QuantizedMMDiT::new(&mmdit_config, vb)?)
974 } else if active_loras.is_empty() && self.offload {
975 let tensors = self.load_transformer_cpu_tensors()?;
976 SD3Transformer::Offloaded(Box::new(super::offload::OffloadedMMDiT::new(
977 &mmdit_config,
978 tensors,
979 gpu_dtype,
980 &device,
981 )?))
982 } else if active_loras.is_empty() {
983 let vb = crate::weight_loader::load_safetensors_with_progress(
985 std::slice::from_ref(&self.base.paths.transformer),
986 gpu_dtype,
987 &device,
988 "SD3 transformer",
989 &self.base.progress,
990 )?;
991 SD3Transformer::BF16(MMDiT::new(
992 &mmdit_config,
993 false,
994 vb.pp("model.diffusion_model"),
995 )?)
996 } else {
997 let vb = sd3_lora_var_builder(
1001 &self.base.paths.transformer,
1002 &active_loras,
1003 gpu_dtype,
1004 &device,
1005 &self.base.progress,
1006 Some(lora_delta_cache.clone()),
1007 )?;
1008 SD3Transformer::BF16(MMDiT::new(&mmdit_config, false, vb)?)
1009 };
1010 self.base
1011 .progress
1012 .stage_done(xformer_label, xformer_stage.elapsed());
1013
1014 let slg_config = self.slg_config();
1016 let actual_steps = sigmas.len().saturating_sub(1);
1017 let denoise_label = format!("Denoising ({actual_steps} steps)");
1018 self.base.progress.stage_start(&denoise_label);
1019 let denoise_start = Instant::now();
1020
1021 let x = sampling::euler_sample(
1022 &transformer,
1023 &y,
1024 &context,
1025 num_steps,
1026 req.guidance,
1027 resolve_cfg_plus(req),
1028 time_shift,
1029 height,
1030 width,
1031 slg_config.as_ref(),
1032 is_quantized,
1033 seed,
1034 &self.base.progress,
1035 initial_latents.as_ref(),
1036 Some(sigmas),
1037 inpaint_ctx.as_ref(),
1038 )?;
1039
1040 self.base
1041 .progress
1042 .stage_done(&denoise_label, denoise_start.elapsed());
1043
1044 drop(transformer);
1046 drop(context);
1047 drop(y);
1048 drop(inpaint_ctx);
1049 device.synchronize()?;
1050 self.base.progress.info("Freed SD3 MMDiT transformer");
1051
1052 self.base.progress.stage_start("Loading VAE (GPU)");
1054 let vae_stage = Instant::now();
1055 let vae_dtype = crate::device::resolve_vae_dtype(gpu_dtype);
1056 let vae_vb = self.load_vae_var_builder(
1057 &self.base.paths.vae,
1058 vae_dtype,
1059 &device,
1060 "VAE",
1061 &self.base.progress,
1062 )?;
1063 let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
1064 let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
1065 self.base
1066 .progress
1067 .stage_done("Loading VAE (GPU)", vae_stage.elapsed());
1068
1069 self.base.progress.stage_start("VAE decode");
1070 let vae_decode_start = Instant::now();
1071
1072 let x = ((x / 1.5305)? + 0.0609)?.to_dtype(vae_dtype)?;
1076 let device_for_sync = device.clone();
1077 let img = crate::vae_tiling::decode_with_oom_fallback(
1078 &x,
1079 |t| autoencoder.decode(t).map_err(Into::into),
1080 || {
1081 if let Err(e) = device_for_sync.synchronize() {
1082 tracing::warn!(
1083 "SD3 (sequential) device.synchronize() after VAE OOM failed: {e}"
1084 );
1085 }
1086 },
1087 )?;
1088
1089 let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1090 let img = img.i(0)?;
1091
1092 self.base
1093 .progress
1094 .stage_done("VAE decode", vae_decode_start.elapsed());
1095
1096 let output_metadata = build_output_metadata(req, seed, None);
1097 let image_bytes = encode_image(
1098 &img,
1099 req.resolved_output_format(),
1100 req.width,
1101 req.height,
1102 output_metadata.as_ref(),
1103 )?;
1104
1105 let generation_time_ms = start.elapsed().as_millis() as u64;
1106 tracing::info!(
1107 generation_time_ms,
1108 seed,
1109 "sequential SD3 generation complete"
1110 );
1111
1112 Ok(GenerateResponse {
1113 images: vec![ImageData {
1114 data: image_bytes,
1115 format: req.resolved_output_format(),
1116 width: req.width,
1117 height: req.height,
1118 index: 0,
1119 }],
1120 generation_time_ms,
1121 model: req.model.clone(),
1122 seed_used: seed,
1123 video: None,
1124 gpu: None,
1125 })
1126 }
1127}
1128
1129impl SD3Engine {
1130 fn generate_inner(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1131 if req.scheduler.is_some() {
1132 tracing::warn!("scheduler selection not supported for SD3 (flow-matching), ignoring");
1133 }
1134
1135 if self.uses_sequential_generate_path() {
1137 return self.generate_sequential(req);
1138 }
1139
1140 let progress = &self.base.progress;
1142 let prompt_cache = &self.prompt_cache;
1143 let mmdit_config = self.mmdit_config();
1144 let transformer_path = self.base.paths.transformer.clone();
1145 let active_loras = effective_loras(req);
1146 let lora_delta_cache = self.lora_delta_cache.clone();
1147 let shared_pool = self.shared_pool.clone();
1148
1149 let mut loaded = OptionRestoreGuard::take(&mut self.base.loaded)
1150 .ok_or_else(|| anyhow::anyhow!("model not loaded -- call load() first"))?;
1151 let loaded_dtype = loaded.dtype;
1152 let loaded_device = loaded.device.clone();
1153 let is_quantized = loaded._is_quantized;
1154
1155 if !active_loras.is_empty() && loaded.transformer.is_some() {
1162 loaded.transformer = None;
1163 loaded_device.synchronize()?;
1164 progress.info("SD3 LoRA: dropping base transformer for LoRA merge");
1165 }
1166
1167 let start = Instant::now();
1168 let seed = req.seed.unwrap_or_else(rand_seed);
1169
1170 let width = req.width as usize;
1171 let height = req.height as usize;
1172
1173 tracing::info!(
1174 prompt = %req.prompt,
1175 seed, width, height,
1176 steps = req.steps,
1177 guidance = req.guidance,
1178 turbo = loaded.is_turbo,
1179 medium = loaded.is_medium,
1180 "starting SD3 generation"
1181 );
1182
1183 (|| -> Result<GenerateResponse> {
1184 if !loaded.triple_encoder.is_loaded() {
1185 let label = if loaded.triple_encoder.is_parked() {
1186 "Unparking SD3 triple encoder (CPU→GPU)"
1187 } else {
1188 "Reloading SD3 triple encoder"
1189 };
1190 progress.stage_start(label);
1191 let reload_start = Instant::now();
1192 if loaded.triple_encoder.is_parked() {
1193 loaded
1194 .triple_encoder
1195 .unpark_to_gpu(loaded_dtype, progress)?;
1196 } else {
1197 loaded.triple_encoder.reload(loaded_dtype, progress)?;
1198 }
1199 progress.stage_done(label, reload_start.elapsed());
1200 }
1201
1202 let neg = req.negative_prompt.as_deref().unwrap_or("");
1203 let (context, y) = Self::encode_conditioning(
1204 progress,
1205 prompt_cache,
1206 &mut loaded.triple_encoder,
1207 &req.prompt,
1208 neg,
1209 req.guidance,
1210 &loaded_device,
1211 loaded_dtype,
1212 is_quantized,
1213 )?;
1214
1215 if loaded.triple_encoder.on_gpu {
1216 let park_mode = crate::device::keep_te_in_ram() && !loaded_device.is_metal();
1219 if park_mode {
1220 loaded.triple_encoder.park_to_cpu()?;
1221 tracing::info!("SD3 triple encoder parked to CPU host RAM");
1222 } else {
1223 loaded.triple_encoder.drop_weights();
1224 tracing::info!(
1225 "SD3 triple encoder dropped from GPU to free VRAM for denoising"
1226 );
1227 }
1228 }
1229
1230 let noise_dtype = if is_quantized {
1232 DType::F32
1233 } else {
1234 loaded_dtype
1235 };
1236 let latent_h = height / 16 * 2;
1237 let latent_w = width / 16 * 2;
1238 let time_shift = 3.0;
1239 let num_steps = req.steps as usize;
1240
1241 let mut sigmas: Vec<f64> = (0..=num_steps)
1242 .map(|s| s as f64 / num_steps as f64)
1243 .rev()
1244 .map(|t| sampling::time_snr_shift(time_shift, t))
1245 .collect();
1246
1247 if req.source_image.is_some() {
1248 let (trimmed, start_index) =
1249 crate::img2img::trim_schedule_tail(&sigmas, req.steps as usize, req.strength);
1250 sigmas = trimmed;
1251 tracing::info!(
1252 strength = req.strength,
1253 start_index,
1254 start_sigma = sigmas[0],
1255 schedule = ?sigmas,
1256 remaining_steps = sigmas.len().saturating_sub(1),
1257 "img2img: truncated schedule from strength"
1258 );
1259 }
1260
1261 let (initial_latents, inpaint_ctx, early_vae) =
1262 if let Some(ref source_bytes) = req.source_image {
1263 let start_t = sigmas[0];
1264
1265 loaded.transformer = None;
1267 loaded.device.synchronize()?;
1268
1269 progress.stage_start("Loading VAE for encoding");
1270 let vae_stage = Instant::now();
1271 let vae_dtype = crate::device::resolve_vae_dtype(loaded_dtype);
1272 let vae_vb = Self::load_vae_var_builder_from_pool(
1273 shared_pool.as_ref(),
1274 &loaded.vae_vb_path,
1275 vae_dtype,
1276 &loaded.device,
1277 "VAE",
1278 progress,
1279 )?;
1280 let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
1281 let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
1282 progress.stage_done("Loading VAE for encoding", vae_stage.elapsed());
1283
1284 progress.stage_start("Encoding source image (VAE)");
1285 let encode_start = Instant::now();
1286 let source_tensor = img_utils::decode_source_image(
1287 source_bytes,
1288 req.width,
1289 req.height,
1290 Self::img2img_source_normalize_range(),
1291 &loaded_device,
1292 vae_dtype,
1293 )?;
1294 let dist = autoencoder.encode(&source_tensor)?;
1295 let encoded = ((dist.mode()? - 0.0609)? * 1.5305)?;
1298 progress.stage_done("Encoding source image (VAE)", encode_start.elapsed());
1299
1300 drop(autoencoder);
1302 loaded.device.synchronize()?;
1303
1304 let encoded = encoded.to_dtype(noise_dtype)?;
1305 let prepared = crate::img2img::prepare_flow_match_img2img(
1306 &encoded,
1307 seed,
1308 &[1, 16, latent_h, latent_w],
1309 start_t,
1310 req.mask_image.as_deref(),
1311 latent_h,
1312 latent_w,
1313 &loaded_device,
1314 noise_dtype,
1315 )?;
1316 (
1317 Some(prepared.initial_latents),
1318 prepared.inpaint_ctx,
1319 None::<()>,
1320 )
1321 } else {
1322 (None, None, None)
1323 };
1324
1325 if loaded.transformer.is_none() {
1327 let reload_label = if active_loras.is_empty() {
1328 "Reloading SD3 transformer"
1329 } else {
1330 "Reloading SD3 transformer (with LoRA)"
1331 };
1332 progress.stage_start(reload_label);
1333 let reload_start = Instant::now();
1334 let transformer = if is_quantized {
1335 let vb = if active_loras.is_empty() {
1336 quantized_var_builder::VarBuilder::from_gguf(
1337 &transformer_path,
1338 &loaded_device,
1339 )?
1340 } else {
1341 sd3_gguf_lora_var_builder(
1342 &transformer_path,
1343 &active_loras,
1344 &loaded_device,
1345 progress,
1346 Some(lora_delta_cache.clone()),
1347 )?
1348 };
1349 SD3Transformer::Quantized(QuantizedMMDiT::new(&mmdit_config, vb)?)
1350 } else if active_loras.is_empty() {
1351 let vb = crate::weight_loader::load_safetensors_with_progress(
1352 std::slice::from_ref(&transformer_path),
1353 loaded_dtype,
1354 &loaded_device,
1355 "SD3 transformer",
1356 progress,
1357 )?;
1358 let vb = vb.pp("model.diffusion_model");
1359 SD3Transformer::BF16(MMDiT::new(&mmdit_config, false, vb)?)
1360 } else {
1361 let vb = sd3_lora_var_builder(
1365 &transformer_path,
1366 &active_loras,
1367 loaded_dtype,
1368 &loaded_device,
1369 progress,
1370 Some(lora_delta_cache.clone()),
1371 )?;
1372 SD3Transformer::BF16(MMDiT::new(&mmdit_config, false, vb)?)
1373 };
1374 loaded.transformer = Some(transformer);
1375 progress.stage_done(reload_label, reload_start.elapsed());
1376 }
1377
1378 let slg_config = if loaded.is_medium {
1379 Some(SkipLayerGuidanceConfig {
1380 scale: 2.5,
1381 start: 0.01,
1382 end: 0.2,
1383 layers: vec![7, 8, 9],
1384 })
1385 } else {
1386 None
1387 };
1388
1389 let actual_steps = sigmas.len().saturating_sub(1);
1390 let denoise_label = format!("Denoising ({actual_steps} steps)");
1391 progress.stage_start(&denoise_label);
1392 let denoise_start = Instant::now();
1393
1394 let transformer = loaded
1395 .transformer
1396 .as_ref()
1397 .ok_or_else(|| anyhow::anyhow!("SD3 transformer not loaded"))?;
1398 let x = sampling::euler_sample(
1399 transformer,
1400 &y,
1401 &context,
1402 num_steps,
1403 req.guidance,
1404 resolve_cfg_plus(req),
1405 time_shift,
1406 height,
1407 width,
1408 slg_config.as_ref(),
1409 loaded._is_quantized,
1410 seed,
1411 progress,
1412 initial_latents.as_ref(),
1413 Some(sigmas),
1414 inpaint_ctx.as_ref(),
1415 )?;
1416
1417 progress.stage_done(&denoise_label, denoise_start.elapsed());
1418 drop(context);
1419 drop(y);
1420 drop(inpaint_ctx);
1421 let _ = early_vae;
1422
1423 loaded.transformer = None;
1425 loaded.device.synchronize()?;
1426 tracing::info!("SD3 transformer dropped to free VRAM for VAE decode");
1427
1428 progress.stage_start("VAE decode");
1429 let vae_decode_start = Instant::now();
1430
1431 let vae_dtype = crate::device::resolve_vae_dtype(loaded.dtype);
1432 let vae_vb = Self::load_vae_var_builder_from_pool(
1433 shared_pool.as_ref(),
1434 &loaded.vae_vb_path,
1435 vae_dtype,
1436 &loaded.device,
1437 "VAE",
1438 progress,
1439 )?;
1440 let vae_vb = vae_vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
1441 let autoencoder = build_sd3_vae_autoencoder(vae_vb)?;
1442
1443 let x = ((x / 1.5305)? + 0.0609)?.to_dtype(vae_dtype)?;
1444 let device_for_sync = loaded.device.clone();
1445 let img = crate::vae_tiling::decode_with_oom_fallback(
1446 &x,
1447 |t| autoencoder.decode(t).map_err(Into::into),
1448 || {
1449 if let Err(e) = device_for_sync.synchronize() {
1450 tracing::warn!(
1451 "SD3 (parallel) device.synchronize() after VAE OOM failed: {e}"
1452 );
1453 }
1454 },
1455 )?;
1456
1457 let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(DType::U8)?;
1458 let img = img.i(0)?;
1459
1460 progress.stage_done("VAE decode", vae_decode_start.elapsed());
1461
1462 let output_metadata = build_output_metadata(req, seed, None);
1463 let image_bytes = encode_image(
1464 &img,
1465 req.resolved_output_format(),
1466 req.width,
1467 req.height,
1468 output_metadata.as_ref(),
1469 )?;
1470
1471 let generation_time_ms = start.elapsed().as_millis() as u64;
1472 tracing::info!(generation_time_ms, seed, "SD3 generation complete");
1473
1474 Ok(GenerateResponse {
1475 images: vec![ImageData {
1476 data: image_bytes,
1477 format: req.resolved_output_format(),
1478 width: req.width,
1479 height: req.height,
1480 index: 0,
1481 }],
1482 generation_time_ms,
1483 model: req.model.clone(),
1484 seed_used: seed,
1485 video: None,
1486 gpu: None,
1487 })
1488 })()
1489 }
1490}
1491
1492impl InferenceEngine for SD3Engine {
1493 fn generate(&mut self, req: &GenerateRequest) -> Result<GenerateResponse> {
1494 self.pending_placement = req.placement.clone();
1495 let result = self.generate_inner(req);
1496 self.pending_placement = None;
1497 result
1498 }
1499
1500 fn model_name(&self) -> &str {
1501 self.base.model_name()
1502 }
1503
1504 fn is_loaded(&self) -> bool {
1505 self.base.is_loaded()
1506 }
1507
1508 fn load(&mut self) -> Result<()> {
1509 SD3Engine::load(self)
1510 }
1511
1512 fn unload(&mut self) {
1513 self.base.unload();
1514 clear_cache(&self.prompt_cache);
1515 }
1516
1517 fn set_on_progress(&mut self, callback: ProgressCallback) {
1518 self.base.set_on_progress(callback);
1519 }
1520
1521 fn clear_on_progress(&mut self) {
1522 self.base.clear_on_progress();
1523 }
1524
1525 fn model_paths(&self) -> Option<&mold_core::ModelPaths> {
1526 Some(&self.base.paths)
1527 }
1528}
1529
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::LoadStrategy;
    use crate::shared_pool::SharedPool;
    use mold_core::ModelPaths;
    use safetensors::tensor::{serialize_to_file, Dtype as SafeDtype, TensorView};
    use std::collections::HashMap;
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::{Arc, Mutex};
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokenizers::models::bpe::BPE;

    /// Creates a fresh directory under the system temp dir.
    ///
    /// The directory name is built from `prefix`, the process id and the current time in
    /// nanoseconds. Panics if the directory cannot be created.
    fn temp_test_dir(prefix: &str) -> PathBuf {
        let suffix = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let dir = std::env::temp_dir().join(format!("{prefix}-{}-{suffix}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// Writes a small placeholder file `dir/name` (contents `b"test"`, not real weights)
    /// and returns its path.
    fn touch(dir: &Path, name: &str) -> PathBuf {
        let path = dir.join(name);
        fs::write(&path, b"test").unwrap();
        path
    }

    /// Builds a `ModelPaths` that fills only the slots SD3 uses:
    /// - transformer and VAE;
    /// - CLIP-L (`clip_encoder` / `clip_tokenizer`);
    /// - CLIP-G (`clip_encoder_2` / `clip_tokenizer_2`);
    /// - T5.
    ///
    /// Every other field is empty or `None`.
    #[allow(clippy::too_many_arguments)]
    fn sd3_model_paths(
        transformer: PathBuf,
        vae: PathBuf,
        clip_l_path: Option<PathBuf>,
        clip_l_tokenizer: Option<PathBuf>,
        clip_g_path: Option<PathBuf>,
        clip_g_tokenizer: Option<PathBuf>,
        t5_encoder: Option<PathBuf>,
        t5_tokenizer: Option<PathBuf>,
    ) -> ModelPaths {
        ModelPaths {
            transformer,
            transformer_shards: vec![],
            vae,
            spatial_upscaler: None,
            temporal_upscaler: None,
            distilled_lora: None,
            t5_encoder,
            clip_encoder: clip_l_path,
            t5_tokenizer,
            clip_tokenizer: clip_l_tokenizer,
            clip_encoder_2: clip_g_path,
            clip_tokenizer_2: clip_g_tokenizer,
            text_encoder_files: vec![],
            text_tokenizer: None,
            decoder: None,
        }
    }

    /// img2img source pixels must be normalized to [-1, 1] for the SD3 VAE.
    #[test]
    fn sd3_img2img_uses_minus_one_to_one_source_normalization() {
        assert_eq!(
            SD3Engine::img2img_source_normalize_range(),
            img_utils::NormalizeRange::MinusOneToOne
        );
    }

    /// Large vs. Medium MMDiT configs differ in depth and positional-embedding size. Only
    /// Medium gets skip-layer guidance. The engines are constructed without touching the
    /// (non-existent) weight files.
    #[test]
    fn sd3_mmdit_config_tracks_large_vs_medium_variants() {
        let base_dir = temp_test_dir("mold-sd3-config");
        let large = SD3Engine::new(
            "sd3.5-large:bf16".to_string(),
            sd3_model_paths(
                base_dir.join("transformer.safetensors"),
                base_dir.join("vae.safetensors"),
                None,
                None,
                None,
                None,
                None,
                None,
            ),
            false,
            false,
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );
        let medium = SD3Engine::new(
            "sd3.5-medium:bf16".to_string(),
            sd3_model_paths(
                base_dir.join("transformer.safetensors"),
                base_dir.join("vae.safetensors"),
                None,
                None,
                None,
                None,
                None,
                None,
            ),
            false,
            true,
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        let large_cfg = large.mmdit_config();
        let medium_cfg = medium.mmdit_config();

        assert_eq!(large_cfg.depth, 38);
        assert_eq!(large_cfg.pos_embed_max_size, 192);
        assert_eq!(medium_cfg.depth, 24);
        assert_eq!(medium_cfg.pos_embed_max_size, 384);
        assert!(large.slg_config().is_none());
        let slg = medium.slg_config().unwrap();
        assert_eq!(slg.scale, 2.5);
        assert_eq!(slg.layers, vec![7, 8, 9]);

        fs::remove_dir_all(base_dir).ok();
    }

    /// With every component file present, `validate_paths` succeeds and returns the T5
    /// tokenizer path unchanged. A `.gguf` transformer is detected as quantized.
    #[test]
    fn sd3_validate_paths_accepts_existing_files() {
        let dir = temp_test_dir("mold-sd3-validate-ok");
        let transformer = touch(&dir, "transformer.gguf");
        let vae = touch(&dir, "vae.safetensors");
        let clip_l = touch(&dir, "clip-l.safetensors");
        let clip_l_tok = touch(&dir, "clip-l-tokenizer.json");
        let clip_g = touch(&dir, "clip-g.safetensors");
        let clip_g_tok = touch(&dir, "clip-g-tokenizer.json");
        let t5 = touch(&dir, "t5.safetensors");
        let t5_tok = touch(&dir, "t5-tokenizer.json");

        let engine = SD3Engine::new(
            "sd3.5-large-turbo:q8".to_string(),
            sd3_model_paths(
                transformer,
                vae,
                Some(clip_l),
                Some(clip_l_tok),
                Some(clip_g),
                Some(clip_g_tok),
                Some(t5),
                Some(t5_tok.clone()),
            ),
            true,
            false,
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        let (_, _, _, _, _, resolved_t5_tok) = engine.validate_paths().unwrap();
        assert_eq!(resolved_t5_tok, t5_tok);
        assert!(engine.detect_is_quantized());

        fs::remove_dir_all(dir).ok();
    }

    /// An engine built with `LoadStrategy::Eager` must still select the sequential
    /// (staged) generate path when constructed with `true` in the argument position that
    /// (per the assertion message) carries the `--offload` request.
    #[test]
    fn sd3_forced_offload_uses_sequential_generation_path() {
        let dir = temp_test_dir("mold-sd3-offload-sequential");
        let engine = SD3Engine::new(
            "sd3.5-large:bf16".to_string(),
            sd3_model_paths(
                dir.join("transformer.safetensors"),
                dir.join("vae.safetensors"),
                None,
                None,
                None,
                None,
                None,
                None,
            ),
            false,
            false,
            None,
            LoadStrategy::Eager,
            0,
            true,
            None,
        );

        assert!(
            engine.uses_sequential_generate_path(),
            "SD3 --offload requests must reach the engine and select the \
             staged generation path instead of being silently ignored"
        );

        fs::remove_dir_all(dir).ok();
    }

    /// Truth table for `sd3_offload_decision`:
    /// - Offload off is `Disabled`.
    /// - Plain offload is `Selected`.
    /// - The two unsupported combinations must name their reason, "GGUF variants" and
    ///   "LoRA" respectively.
    #[test]
    fn sd3_offload_decision_gates_current_unsupported_cases() {
        assert_eq!(
            sd3_offload_decision(false, false, false),
            SD3OffloadDecision::Disabled
        );
        assert_eq!(
            sd3_offload_decision(true, false, false),
            SD3OffloadDecision::Selected
        );
        assert!(matches!(
            sd3_offload_decision(true, true, false),
            SD3OffloadDecision::Unsupported(reason)
                if reason.contains("GGUF variants")
        ));
        assert!(matches!(
            sd3_offload_decision(true, false, true),
            SD3OffloadDecision::Unsupported(reason)
                if reason.contains("LoRA")
        ));
    }

    /// Selected BF16 offload must get as far as actually loading weights.
    ///
    /// The prompt cache is pre-seeded for ("a cat", "", 1.0), presumably so conditioning
    /// is served from cache. The placeholder weight files make generation fail, but the
    /// failure must not be the old "streaming is not implemented yet" stub error.
    #[test]
    fn sd3_selected_bf16_offload_reaches_runtime_loader() {
        use crate::cache::store_cached_tensor_pair;

        let dir = temp_test_dir("mold-sd3-offload-loader");
        let transformer = touch(&dir, "transformer.safetensors");
        let vae = touch(&dir, "vae.safetensors");
        let clip_l = touch(&dir, "clip-l.safetensors");
        let clip_l_tok = touch(&dir, "clip-l-tokenizer.json");
        let clip_g = touch(&dir, "clip-g.safetensors");
        let clip_g_tok = touch(&dir, "clip-g-tokenizer.json");
        let t5 = touch(&dir, "t5.safetensors");
        let t5_tok = touch(&dir, "t5-tokenizer.json");
        let mut engine = SD3Engine::new(
            "sd3.5-large:bf16".to_string(),
            sd3_model_paths(
                transformer,
                vae,
                Some(clip_l),
                Some(clip_l_tok),
                Some(clip_g),
                Some(clip_g_tok),
                Some(t5),
                Some(t5_tok),
            ),
            false,
            false,
            None,
            LoadStrategy::Sequential,
            0,
            true,
            None,
        );
        let context = Tensor::zeros((1, 1, 4096), DType::F32, &Device::Cpu).unwrap();
        let y = Tensor::zeros((1, 2048), DType::F32, &Device::Cpu).unwrap();
        let key = cfg_prompt_cache_key("a cat", "", 1.0);
        store_cached_tensor_pair(&engine.prompt_cache, key, &context, &y).unwrap();
        let req = GenerateRequest {
            prompt: "a cat".to_string(),
            negative_prompt: None,
            model: "sd3.5-large:bf16".to_string(),
            width: 64,
            height: 64,
            steps: 1,
            guidance: 1.0,
            seed: Some(1),
            batch_size: 1,
            output_format: None,
            embed_metadata: None,
            scheduler: None,
            cfg_plus: None,
            source_image: None,
            edit_images: None,
            strength: 1.0,
            mask_image: None,
            control_image: None,
            control_model: None,
            control_scale: 1.0,
            expand: None,
            original_prompt: None,
            lora: None,
            frames: None,
            fps: None,
            upscale_model: None,
            gif_preview: false,
            enable_audio: None,
            audio_file: None,
            audio_file_path: None,
            source_video: None,
            source_video_path: None,
            keyframes: None,
            pipeline: None,
            loras: None,
            retake_range: None,
            spatial_upscale: None,
            temporal_upscale: None,
            placement: None,
        };

        let err = engine.generate_sequential(&req).unwrap_err().to_string();

        assert!(
            !err.contains("streaming is not implemented yet"),
            "selected BF16 offload must reach the runtime loader, got: {err}"
        );
        fs::remove_dir_all(dir).ok();
    }

    /// The conditioning cache key includes the negative prompt.
    ///
    /// Changing only the negative prompt must miss the cache. The identical
    /// (positive, negative, guidance) triple must still hit.
    #[test]
    fn sd3_prompt_cache_distinguishes_negative_prompt_changes() {
        use crate::cache::{cfg_prompt_cache_key, store_cached_tensor_pair};

        let cache: Mutex<LruCache<CfgPromptCacheKey, CachedTensorPair>> =
            Mutex::new(LruCache::new(DEFAULT_PROMPT_CACHE_CAPACITY));
        let device = Device::Cpu;
        let dtype = DType::F32;
        let context = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();
        let y = candle_core::Tensor::zeros((1, 4), dtype, &device).unwrap();

        let key_a = cfg_prompt_cache_key("a cat", "blurry", 7.0);
        store_cached_tensor_pair(&cache, key_a.clone(), &context, &y).unwrap();

        let key_b = cfg_prompt_cache_key("a cat", "low quality", 7.0);
        let restored = restore_cached_tensor_pair(&cache, &key_b, &device, dtype).unwrap();
        assert!(
            restored.is_none(),
            "different negative prompt must miss the cache (was the silent-wrong-output bug)"
        );

        let restored = restore_cached_tensor_pair(&cache, &key_a, &device, dtype).unwrap();
        assert!(
            restored.is_some(),
            "identical (pos, neg, guidance) must still hit"
        );
    }

    /// A missing T5 encoder path is a validation error mentioning "T5 encoder path
    /// required". A `.safetensors` transformer is not detected as quantized.
    #[test]
    fn sd3_validate_paths_requires_t5_encoder() {
        let dir = temp_test_dir("mold-sd3-validate-missing");
        let engine = SD3Engine::new(
            "sd3.5-large:bf16".to_string(),
            sd3_model_paths(
                dir.join("transformer.safetensors"),
                dir.join("vae.safetensors"),
                Some(dir.join("clip-l.safetensors")),
                Some(dir.join("clip-l-tokenizer.json")),
                Some(dir.join("clip-g.safetensors")),
                Some(dir.join("clip-g-tokenizer.json")),
                None,
                Some(dir.join("t5-tokenizer.json")),
            ),
            false,
            false,
            None,
            LoadStrategy::Sequential,
            0,
            false,
            None,
        );

        let err = engine.validate_paths().unwrap_err();
        assert!(err.to_string().contains("T5 encoder path required"));
        assert!(!engine.detect_is_quantized());

        fs::remove_dir_all(dir).ok();
    }

    /// With a shared pool attached, the three text tokenizers come back as the very same
    /// `Arc`s the pool already holds (pointer equality, not just equal contents).
    #[test]
    fn sd3_loads_text_tokenizers_through_shared_pool() {
        let dir = temp_test_dir("mold-sd3-tokenizer-pool");
        let clip_l_tok = dir.join("clip-l-tokenizer.json");
        let clip_g_tok = dir.join("clip-g-tokenizer.json");
        let t5_tok = dir.join("t5-tokenizer.json");
        // Write minimal (empty BPE) tokenizer files so the pool can load them.
        for path in [&clip_l_tok, &clip_g_tok, &t5_tok] {
            tokenizers::Tokenizer::new(BPE::default())
                .save(path, false)
                .unwrap();
        }

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled_clip_l = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&clip_l_tok)
            .unwrap();
        let pooled_clip_g = shared_pool
            .lock()
            .unwrap()
            .load_tokenizer(&clip_g_tok)
            .unwrap();
        let pooled_t5 = shared_pool.lock().unwrap().load_tokenizer(&t5_tok).unwrap();

        let engine = SD3Engine::new(
            "sd3.5-large:bf16".to_string(),
            sd3_model_paths(
                dir.join("transformer.safetensors"),
                dir.join("vae.safetensors"),
                Some(dir.join("clip-l.safetensors")),
                Some(clip_l_tok.clone()),
                Some(dir.join("clip-g.safetensors")),
                Some(clip_g_tok.clone()),
                Some(dir.join("t5.safetensors")),
                Some(t5_tok.clone()),
            ),
            false,
            false,
            None,
            LoadStrategy::Sequential,
            0,
            false,
            Some(shared_pool),
        );

        let (loaded_clip_l, loaded_clip_g, loaded_t5) = engine
            .load_text_tokenizers(&clip_l_tok, &clip_g_tok, &t5_tok)
            .unwrap();

        assert!(Arc::ptr_eq(&pooled_clip_l, &loaded_clip_l));
        assert!(Arc::ptr_eq(&pooled_clip_g, &loaded_clip_g));
        assert!(Arc::ptr_eq(&pooled_t5, &loaded_t5));
        fs::remove_dir_all(dir).ok();
    }

    /// With a shared pool attached, the VAE's CPU tensors come back as the same `Arc` the
    /// pool already holds. Uses a one-tensor safetensors file written on the fly.
    #[test]
    fn sd3_loads_vae_tensors_through_shared_pool() {
        let dir = temp_test_dir("mold-sd3-vae-pool");
        let vae_path = dir.join("vae.safetensors");
        let weight = 1.0f32.to_le_bytes();
        let mut tensors = HashMap::new();
        tensors.insert(
            "first_stage_model.encoder.conv_in.weight".to_string(),
            TensorView::new(SafeDtype::F32, vec![1], &weight).unwrap(),
        );
        serialize_to_file(&tensors, &None, &vae_path).unwrap();

        let shared_pool = Arc::new(Mutex::new(SharedPool::new()));
        let pooled = shared_pool
            .lock()
            .unwrap()
            .load_safetensors_cpu_tensors(std::slice::from_ref(&vae_path))
            .unwrap()
            .unwrap();

        let engine = SD3Engine::new(
            "sd3.5-large:bf16".to_string(),
            sd3_model_paths(
                dir.join("transformer.safetensors"),
                vae_path.clone(),
                Some(dir.join("clip-l.safetensors")),
                Some(dir.join("clip-l-tokenizer.json")),
                Some(dir.join("clip-g.safetensors")),
                Some(dir.join("clip-g-tokenizer.json")),
                Some(dir.join("t5.safetensors")),
                Some(dir.join("t5-tokenizer.json")),
            ),
            false,
            false,
            None,
            LoadStrategy::Sequential,
            0,
            false,
            Some(shared_pool),
        );

        let loaded = engine.load_vae_cpu_tensors(&vae_path).unwrap().unwrap();

        assert!(Arc::ptr_eq(&pooled, &loaded));
        fs::remove_dir_all(dir).ok();
    }

    /// Process-wide lock serializing the tests below that mutate `MOLD_CFG_PLUS`.
    ///
    /// A poisoned lock (from a panicking test) is recovered rather than propagated, so one
    /// failure does not cascade into the others.
    fn cfg_env_lock() -> std::sync::MutexGuard<'static, ()> {
        use std::sync::{Mutex, OnceLock};
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(|p| p.into_inner())
    }

    /// Minimal SD3 request parsed from JSON, with only `cfg_plus` overridden. Every field
    /// not listed in the JSON is left to deserialization defaults.
    fn req_with_cfg_plus(cfg_plus: Option<bool>) -> GenerateRequest {
        let mut req: GenerateRequest = serde_json::from_str(
            r#"{
                "prompt":"x",
                "model":"sd3.5-large:fp16",
                "width":1024,
                "height":1024,
                "steps":28,
                "guidance":4.5
            }"#,
        )
        .unwrap();
        req.cfg_plus = cfg_plus;
        req
    }

    /// With neither the env var nor the request field set, cfg++ is off.
    #[test]
    fn resolve_cfg_plus_defaults_off() {
        let _guard = cfg_env_lock();
        // SAFETY: env mutation is serialized against the other cfg++ tests by
        // `cfg_env_lock`. Tests that do not take this lock are not covered by it.
        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
        assert!(!resolve_cfg_plus(&req_with_cfg_plus(None)));
    }

    /// `MOLD_CFG_PLUS=1` enables cfg++ when the request leaves it unset.
    #[test]
    fn resolve_cfg_plus_env_enables() {
        let _guard = cfg_env_lock();
        // SAFETY: serialized with the other cfg++ tests by `cfg_env_lock`.
        unsafe { std::env::set_var("MOLD_CFG_PLUS", "1") };
        let on = resolve_cfg_plus(&req_with_cfg_plus(None));
        // Restore the environment before asserting, so a failure does not leak the var.
        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
        assert!(on, "MOLD_CFG_PLUS=1 must enable cfg++");
    }

    /// An explicit `cfg_plus: Some(false)` on the request overrides `MOLD_CFG_PLUS=1`.
    #[test]
    fn resolve_cfg_plus_request_field_wins_over_env() {
        let _guard = cfg_env_lock();
        // SAFETY: serialized with the other cfg++ tests by `cfg_env_lock`.
        unsafe { std::env::set_var("MOLD_CFG_PLUS", "1") };
        let off = resolve_cfg_plus(&req_with_cfg_plus(Some(false)));
        // Restore the environment before asserting, so a failure does not leak the var.
        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
        assert!(!off, "explicit Some(false) must override env=on");
    }

    /// `cfg_plus: Some(true)` on the request enables cfg++ with no env var set.
    #[test]
    fn resolve_cfg_plus_request_true_without_env() {
        let _guard = cfg_env_lock();
        // SAFETY: serialized with the other cfg++ tests by `cfg_env_lock`.
        unsafe { std::env::remove_var("MOLD_CFG_PLUS") };
        assert!(resolve_cfg_plus(&req_with_cfg_plus(Some(true))));
    }
}