mold-ai-inference 0.13.1

Candle-based inference engine for mold — FLUX, SDXL, SD3.5, Z-Image diffusion models
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
//! Quantized (GGUF) MMDiT for SD3.5
//!
//! Mirrors `candle_transformers::models::mmdit::model::MMDiT` but uses quantized layer
//! types from `candle_transformers::quantized_nn`. The GGUF tensor naming from city96
//! quantizations preserves the BF16 naming convention, so tensor paths match directly.
//!
//! Supports both sd3_5_large (depth=38) and sd3_5_medium (depth=24) configs.

use anyhow::Result;
use candle_core::{DType, Module, Tensor, D};
use candle_nn::RmsNorm as CandleRmsNorm;
use candle_transformers::models::mmdit::model::Config as MMDiTConfig;
use candle_transformers::quantized_nn::{self, Linear};
use candle_transformers::quantized_var_builder::VarBuilder;

/// Runs the quantized `linear` layer on `x` and zeroes every NaN element of the result.
///
/// Workaround: Candle's CUDA `QMatMul` can emit stray NaN values in some output
/// elements when processing large tensors. Masking them here stops a single bad
/// element from propagating through every later layer.
fn linear_nan_safe(linear: &Linear, x: &Tensor) -> candle_core::Result<Tensor> {
    let projected = linear.forward(x)?;
    // NaN is the only value that compares unequal to itself, so `ne` against
    // itself is a NaN mask; those positions take the zero branch.
    projected
        .ne(&projected)?
        .where_cond(&Tensor::zeros_like(&projected)?, &projected)
}

// ==================== LayerNormNoAffine ====================

/// Layer normalisation over the last axis with no learnable scale or shift.
///
/// Computes `(x - mean) / sqrt(var + eps)`, where the mean and the (biased) variance
/// are taken over the final dimension and kept for broadcasting.
struct LayerNormNoAffine {
    /// Constant added to the variance before the square root.
    eps: f64,
}

impl LayerNormNoAffine {
    fn new(eps: f64) -> Self {
        LayerNormNoAffine { eps }
    }

    /// Normalises `x` along its last dimension.
    ///
    /// The statistics are computed by hand (matching Z-Image's `LayerNormNoParams`).
    /// An earlier version built an all-ones weight with `Tensor::ones_like(x)`. That
    /// weight had the full input shape instead of `[hidden_size]`, which gave wrong
    /// normalisation and cascaded into NaN/inf in attention.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let width = x.dim(D::Minus1)? as f64;
        let mean = (x.sum_keepdim(D::Minus1)? / width)?;
        let centered = x.broadcast_sub(&mean)?;
        let variance = (centered.sqr()?.sum_keepdim(D::Minus1)? / width)?;
        let denom = (variance + self.eps)?.sqrt()?;
        centered.broadcast_div(&denom)
    }
}

// ==================== PatchEmbedder ====================

/// Patchifies a latent image with a strided convolution and flattens the resulting
/// patch grid into a token sequence.
struct PatchEmbedder {
    /// Dequantized conv kernel, loaded as `(embed_dim, in_channels, patch_size, patch_size)`.
    proj_weight: Tensor,
    /// Dequantized conv bias, loaded as `(embed_dim,)`.
    proj_bias: Tensor,
    /// Kernel size and stride of the patchifying convolution.
    patch_size: usize,
}

impl PatchEmbedder {
    fn new(
        patch_size: usize,
        in_channels: usize,
        embed_dim: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let proj = vb.pp("proj");
        let kernel_shape = (embed_dim, in_channels, patch_size, patch_size);
        let proj_weight = proj.get(kernel_shape, "weight")?.dequantize(vb.device())?;
        let proj_bias = proj.get(embed_dim, "bias")?.dequantize(vb.device())?;
        Ok(PatchEmbedder {
            proj_weight,
            proj_bias,
            patch_size,
        })
    }

    /// Applies the conv (no padding, stride = `patch_size`) and bias. The
    /// `(batch, embed_dim, grid_h, grid_w)` feature map is then returned as
    /// `(batch, grid_h * grid_w, embed_dim)`.
    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let conv = x.conv2d(&self.proj_weight, 0, self.patch_size, 1, 1)?;
        let bias = self.proj_bias.reshape((1, (), 1, 1))?;
        let feat = conv.broadcast_add(&bias)?;
        let (batch, channels, grid_h, grid_w) = feat.dims4()?;
        let tokens = feat.reshape((batch, channels, grid_h * grid_w))?;
        tokens.transpose(1, 2)
    }
}

// ==================== PositionEmbedder ====================

/// Learned absolute position table laid out as a
/// `pos_embed_max_size x pos_embed_max_size` grid. A centred window matching the
/// input's patch grid is cropped out per forward pass.
struct PositionEmbedder {
    /// Dequantized table, loaded as `(1, pos_embed_max_size^2, hidden_size)`.
    pos_embed: Tensor,
    patch_size: usize,
    pos_embed_max_size: usize,
}

impl PositionEmbedder {
    fn new(
        hidden_size: usize,
        patch_size: usize,
        pos_embed_max_size: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let num_positions = pos_embed_max_size * pos_embed_max_size;
        let pos_embed = vb
            .get((1, num_positions, hidden_size), "pos_embed")?
            .dequantize(vb.device())?;
        Ok(PositionEmbedder {
            pos_embed,
            patch_size,
            pos_embed_max_size,
        })
    }

    /// Returns the `(1, rows * cols, hidden_size)` slice of the table centred in the grid.
    /// `rows`/`cols` are the patch-grid dimensions derived from the latent size `h` x `w`.
    ///
    /// The grid size uses `(h + 1) / patch_size`; for `patch_size == 2` this is a
    /// ceiling division.
    ///
    /// # Errors
    /// Fails if the patch grid is larger than `pos_embed_max_size` in either direction.
    fn get_cropped_pos_embed(&self, h: usize, w: usize) -> candle_core::Result<Tensor> {
        let max = self.pos_embed_max_size;
        let (rows, cols) = ((h + 1) / self.patch_size, (w + 1) / self.patch_size);
        if rows > max || cols > max {
            candle_core::bail!("Input size is too large for the position embedding");
        }
        // Offsets that centre the crop inside the full table.
        let (top, left) = ((max - rows) / 2, (max - cols) / 2);
        self.pos_embed
            .reshape((1, max, max, ()))?
            .narrow(1, top, rows)?
            .narrow(2, left, cols)?
            .reshape((1, rows * cols, ()))
    }
}

// ==================== TimestepEmbedder ====================

/// Embeds scalar diffusion timesteps: sinusoidal features, then `Linear -> SiLU -> Linear`.
struct TimestepEmbedder {
    mlp_0: Linear,
    mlp_2: Linear,
    /// Width of the sinusoidal feature vector fed into `mlp_0`.
    frequency_embedding_size: usize,
}

impl TimestepEmbedder {
    fn new(hidden_size: usize, frequency_embedding_size: usize, vb: VarBuilder) -> Result<Self> {
        let mlp_0 = quantized_nn::linear(frequency_embedding_size, hidden_size, vb.pp("mlp.0"))?;
        let mlp_2 = quantized_nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?;
        Ok(Self {
            mlp_0,
            mlp_2,
            frequency_embedding_size,
        })
    }

    /// Builds `dim`-wide sinusoidal timestep features `[cos(t * f_i) | sin(t * f_i)]`,
    /// with `f_i = exp(-ln(10000) * i / (dim / 2))` for `i` in `0..dim / 2`.
    ///
    /// `t` is unsqueezed on dim 1, so it is expected to be `(batch,)`. The result is
    /// `(batch, dim)` in F32.
    ///
    /// When `dim` is odd, a zero column is appended so the width always equals `dim`, as
    /// in the reference DiT `timestep_embedding`. Without it, an odd `dim` would produce
    /// `dim - 1` columns and fail later inside `mlp_0`.
    fn timestep_embedding(t: &Tensor, dim: usize) -> candle_core::Result<Tensor> {
        const MAX_PERIOD: f64 = 10000.0;
        let half = dim / 2;
        // `arange(0f32, ..)` is already F32, so no dtype conversion is needed here.
        let freqs = Tensor::arange(0f32, half as f32, t.device())?
            .affine(-MAX_PERIOD.ln() / half as f64, 0.0)?
            .exp()?;
        // Outer product: (batch, 1) x (1, half) -> (batch, half).
        let args = t
            .unsqueeze(1)?
            .to_dtype(DType::F32)?
            .matmul(&freqs.unsqueeze(0)?)?;
        let embedding = Tensor::cat(&[args.cos()?, args.sin()?], 1)?;
        if dim % 2 == 1 {
            let pad = Tensor::zeros((embedding.dim(0)?, 1), DType::F32, t.device())?;
            return Tensor::cat(&[embedding, pad], 1);
        }
        // Keep F32 for quantized path (QMatMul dequantizes weights to F32)
        Ok(embedding)
    }

    fn forward(&self, t: &Tensor) -> candle_core::Result<Tensor> {
        let t_freq = Self::timestep_embedding(t, self.frequency_embedding_size)?;
        let x = linear_nan_safe(&self.mlp_0, &t_freq)?.silu()?;
        linear_nan_safe(&self.mlp_2, &x)
    }
}

// ==================== VectorEmbedder ====================

/// Projects a conditioning vector to the model width via `Linear -> SiLU -> Linear`.
struct VectorEmbedder {
    mlp_0: Linear,
    mlp_2: Linear,
}

impl VectorEmbedder {
    fn new(input_dim: usize, hidden_size: usize, vb: VarBuilder) -> Result<Self> {
        Ok(VectorEmbedder {
            mlp_0: quantized_nn::linear(input_dim, hidden_size, vb.pp("mlp.0"))?,
            mlp_2: quantized_nn::linear(hidden_size, hidden_size, vb.pp("mlp.2"))?,
        })
    }

    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        let hidden = linear_nan_safe(&self.mlp_0, x)?;
        let activated = hidden.silu()?;
        linear_nan_safe(&self.mlp_2, &activated)
    }
}

// ==================== Mlp ====================

/// Transformer feed-forward: `fc1 -> GELU (tanh approximation) -> fc2`, with
/// `hidden_features` as the inner width.
struct Mlp {
    fc1: Linear,
    fc2: Linear,
}

impl Mlp {
    fn new(in_features: usize, hidden_features: usize, vb: VarBuilder) -> Result<Self> {
        Ok(Mlp {
            fc1: quantized_nn::linear(in_features, hidden_features, vb.pp("fc1"))?,
            fc2: quantized_nn::linear(hidden_features, in_features, vb.pp("fc2"))?,
        })
    }

    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        // Both quantized projections are fed contiguous inputs.
        let contiguous_in = x.contiguous()?;
        let hidden = linear_nan_safe(&self.fc1, &contiguous_in)?;
        let activated = hidden.apply(&candle_nn::Activation::GeluPytorchTanh)?;
        linear_nan_safe(&self.fc2, &activated.contiguous()?)
    }
}

// ==================== AttnProjections ====================

/// Attention input/output projections.
///
/// Holds a fused QKV projection, an optional per-head RMS norm on Q and K
/// (eps `1e-6`), and the output projection.
struct AttnProjections {
    head_dim: usize,
    qkv: Linear,
    ln_k: Option<CandleRmsNorm>,
    ln_q: Option<CandleRmsNorm>,
    proj: Linear,
}

impl AttnProjections {
    fn new(dim: usize, num_heads: usize, has_qk_norm: bool, vb: VarBuilder) -> Result<Self> {
        let head_dim = dim / num_heads;
        let qkv = quantized_nn::linear(dim, dim * 3, vb.pp("qkv"))?;
        let proj = quantized_nn::linear(dim, dim, vb.pp("proj"))?;
        let (ln_k, ln_q) = if has_qk_norm {
            // Loads a `(head_dim,)` RMS-norm weight from `<name>.weight`.
            let load = |name: &str| -> Result<CandleRmsNorm> {
                let weight = vb
                    .pp(name)
                    .get(head_dim, "weight")?
                    .dequantize(vb.device())?;
                Ok(CandleRmsNorm::new(weight, 1e-6))
            };
            let ln_k = load("ln_k")?;
            let ln_q = load("ln_q")?;
            (Some(ln_k), Some(ln_q))
        } else {
            (None, None)
        };
        Ok(AttnProjections {
            head_dim,
            qkv,
            ln_k,
            ln_q,
            proj,
        })
    }

    /// Projects `x` to Q/K/V (each `(batch, seq, dim)`) and applies QK-norm when configured.
    fn pre_attention(&self, x: &Tensor) -> candle_core::Result<Qkv> {
        let fused = linear_nan_safe(&self.qkv, x)?;
        let Qkv { q, k, v } = split_qkv(&fused, self.head_dim)?;
        let q = self.norm_per_head(self.ln_q.as_ref(), q)?;
        let k = self.norm_per_head(self.ln_k.as_ref(), k)?;
        Ok(Qkv { q, k, v })
    }

    /// Applies `norm` to each head of a `(batch, seq, heads * head_dim)` tensor
    /// independently, or passes `t` through unchanged when `norm` is `None`.
    fn norm_per_head(
        &self,
        norm: Option<&CandleRmsNorm>,
        t: Tensor,
    ) -> candle_core::Result<Tensor> {
        match norm {
            None => Ok(t),
            Some(norm) => {
                let (batch, seq, width) = t.dims3()?;
                let per_head = t.reshape((batch, seq, (), self.head_dim))?;
                norm.forward(&per_head)?.reshape((batch, seq, width))
            }
        }
    }

    /// Applies the output projection to the attention result.
    ///
    /// With `MOLD_SD3_DEBUG` set, the first call in the process prints the min/max
    /// of its input.
    fn post_attention(&self, x: &Tensor) -> candle_core::Result<Tensor> {
        use std::sync::atomic::{AtomicBool, Ordering};
        static PROJ_DIAG: AtomicBool = AtomicBool::new(false);
        // The flag is consumed on the first call whether or not debugging is enabled.
        let first_call = !PROJ_DIAG.swap(true, Ordering::Relaxed);
        if first_call && std::env::var_os("MOLD_SD3_DEBUG").is_some() {
            let as_f32 = x.to_dtype(DType::F32)?;
            let mn = as_f32.min_all()?.to_scalar::<f32>().unwrap_or(f32::NAN);
            let mx = as_f32.max_all()?.to_scalar::<f32>().unwrap_or(f32::NAN);
            eprintln!("[sd3-proj] attention output (before proj): [{mn:.4},{mx:.4}] shape={:?} dtype={:?}", x.shape(), x.dtype());
        }
        // The quantized matmul is given a contiguous input.
        let contiguous_in = x.contiguous()?;
        linear_nan_safe(&self.proj, &contiguous_in)
    }
}

// ==================== QkvOnlyAttnProjections ====================

/// Fused QKV projection with no QK-norm and no output projection. It is used by
/// `QkvOnlyDiTBlock`, which only needs the pre-attention Q/K/V.
struct QkvOnlyAttnProjections {
    qkv: Linear,
    head_dim: usize,
}

impl QkvOnlyAttnProjections {
    fn new(dim: usize, num_heads: usize, vb: VarBuilder) -> Result<Self> {
        let head_dim = dim / num_heads;
        Ok(QkvOnlyAttnProjections {
            qkv: quantized_nn::linear(dim, dim * 3, vb.pp("qkv"))?,
            head_dim,
        })
    }

    fn pre_attention(&self, x: &Tensor) -> candle_core::Result<Qkv> {
        let fused = linear_nan_safe(&self.qkv, x)?;
        split_qkv(&fused, self.head_dim)
    }
}

// ==================== Qkv + helpers ====================

/// Query, key and value tensors, each shaped `(batch, seq, heads * head_dim)`.
struct Qkv {
    q: Tensor,
    k: Tensor,
    v: Tensor,
}

/// Splits a fused `(batch, seq, 3 * heads * head_dim)` projection into Q, K and V.
///
/// All three are reshaped back to `(batch, seq, heads * head_dim)` so they are
/// treated symmetrically. An earlier version left V 4-D, which caused shape
/// inconsistencies in attention.
fn split_qkv(qkv: &Tensor, head_dim: usize) -> candle_core::Result<Qkv> {
    let (batch, seq, _) = qkv.dims3()?;
    let fused = qkv.reshape((batch, seq, 3, (), head_dim))?;
    // Slot 0 = Q, 1 = K, 2 = V along the "3" axis.
    let take = |slot: usize| -> candle_core::Result<Tensor> {
        fused.get_on_dim(2, slot)?.reshape((batch, seq, ()))
    };
    Ok(Qkv {
        q: take(0)?,
        k: take(1)?,
        v: take(2)?,
    })
}

/// adaLN modulation: `shift + x * (1 + scale)`.
///
/// `shift` and `scale` are unsqueezed on dim 1 so they broadcast over `x`'s second
/// axis. The result is made contiguous for the quantized matmuls that consume it.
fn modulate(x: &Tensor, shift: &Tensor, scale: &Tensor) -> candle_core::Result<Tensor> {
    let shift = shift.unsqueeze(1)?;
    // Scalar `+ 1.0` avoids materialising a full-shape ones tensor.
    let gain = (scale.unsqueeze(1)? + 1.0)?;
    let scaled = x.broadcast_mul(&gain)?;
    let shifted = shift.broadcast_add(&scaled)?;
    shifted.contiguous()
}

/// Scaled dot-product attention over `num_heads` heads using the standard
/// `(batch, heads, seq, head_dim)` layout.
///
/// Shapes:
/// - `q` is `(batch, seq_q, num_heads * head_dim)`.
/// - `k` and `v` are `(batch, seq_kv, num_heads * head_dim)`. Their sequence length is
///   read from `k`, so it no longer has to equal `seq_q`; joint attention passes equal
///   lengths.
/// - The result is `(batch, seq_q, num_heads * head_dim)` and contiguous.
///
/// An earlier approach merged batch and heads with `flatten_to(1)` and ran a manual
/// chunked matmul; that produced hidden NaN in `Q @ K^T` on CUDA. This version keeps
/// the 4-D layout. To avoid materialising the full `(seq_q, seq_kv)` score matrix, it
/// processes `ATTN_CHUNK` query rows at a time. NaN coming out of either matmul is
/// replaced with `0.0`.
///
/// With `MOLD_SD3_DEBUG` set, one-shot statistics for the first query chunk and for
/// the first full output are printed to stderr.
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, num_heads: usize) -> candle_core::Result<Tensor> {
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Replaces every NaN element of `t` with `0.0` (NaN != NaN).
    fn zero_nan(t: &Tensor) -> candle_core::Result<Tensor> {
        t.ne(t)?.where_cond(&Tensor::zeros_like(t)?, t)
    }

    /// `(min, max)` of `t` as `f32`; NaN for a value that cannot be read back.
    fn min_max(t: &Tensor) -> candle_core::Result<(f32, f32)> {
        let t = t.to_dtype(DType::F32)?;
        let mn = t.min_all()?.to_scalar::<f32>().unwrap_or(f32::NAN);
        let mx = t.max_all()?.to_scalar::<f32>().unwrap_or(f32::NAN);
        Ok((mn, mx))
    }

    // One-shot flags: each diagnostic prints at most once per process.
    static STEP_DIAG: AtomicBool = AtomicBool::new(false);
    static SM_STEP: AtomicBool = AtomicBool::new(false);
    static SC_STEP: AtomicBool = AtomicBool::new(false);
    static ATTN_DIAG: AtomicBool = AtomicBool::new(false);

    // Query rows per block of the score matrix.
    const ATTN_CHUNK: usize = 256;

    let debug = std::env::var_os("MOLD_SD3_DEBUG").is_some();
    let batch_size = q.dim(0)?;
    let seqlen = q.dim(1)?;
    let kv_seqlen = k.dim(1)?;

    // (batch, len, heads * head_dim) -> (batch, heads, len, head_dim), contiguous.
    let to_heads = |t: &Tensor, len: usize| -> candle_core::Result<Tensor> {
        t.reshape((batch_size, len, num_heads, ()))?
            .transpose(1, 2)?
            .contiguous()
    };
    let q = to_heads(q, seqlen)?;
    let k = to_heads(k, kv_seqlen)?;
    let v = to_heads(v, kv_seqlen)?;
    let headdim = q.dim(D::Minus1)?;
    let softmax_scale = 1.0 / (headdim as f64).sqrt();

    let k_t = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
    let mut attn_chunks = Vec::new();
    for offset in (0..seqlen).step_by(ATTN_CHUNK) {
        let len = ATTN_CHUNK.min(seqlen - offset);
        let first = offset == 0;

        // Q @ K^T for this block of queries: (batch, heads, len, seq_kv).
        let q_chunk = q.narrow(2, offset, len)?.contiguous()?;
        let logits = (q_chunk.matmul(&k_t)? * softmax_scale)?;
        // Replace hidden NaN from CUDA matmul with 0.0; keep the mask for diagnostics.
        let nan_mask = logits.ne(&logits)?;
        let logits = nan_mask.where_cond(&Tensor::zeros_like(&logits)?, &logits)?;
        if first && debug && !STEP_DIAG.swap(true, Ordering::Relaxed) {
            let nan_count = nan_mask
                .to_dtype(DType::F32)?
                .sum_all()?
                .to_scalar::<f32>()
                .unwrap_or(-1.0);
            let (wmn, wmx) = min_max(&logits)?;
            eprintln!("[sd3-step] after nan_to_zero: nan_count={nan_count} weights=[{wmn:.4},{wmx:.4}]");
        }

        let probs = candle_nn::ops::softmax_last_dim(&logits)?;
        if first && debug && !SM_STEP.swap(true, Ordering::Relaxed) {
            let (smn, smx) = min_max(&probs)?;
            let ssum = probs
                .to_dtype(DType::F32)?
                .sum_all()?
                .to_scalar::<f32>()
                .unwrap_or(f32::NAN);
            eprintln!("[sd3-step] softmax: [{smn:.6},{smx:.6}] sum={ssum:.2}");
        }

        // probs @ V: (batch, heads, len, head_dim), NaN-guarded as well.
        let scores = zero_nan(&probs.matmul(&v)?)?;
        if first && debug && !SC_STEP.swap(true, Ordering::Relaxed) {
            let (scmn, scmx) = min_max(&scores)?;
            eprintln!("[sd3-step] scores (sm@V): [{scmn:.4},{scmx:.4}]");
        }
        attn_chunks.push(scores);
    }
    let attn = Tensor::cat(&attn_chunks, 2)?;

    if debug && !ATTN_DIAG.swap(true, Ordering::Relaxed) {
        let (mn, mx) = min_max(&attn)?;
        eprintln!(
            "[sd3-attn] attention output: [{mn:.4},{mx:.4}] shape={:?}",
            attn.shape()
        );
    }

    // Back to (batch, seq, heads * head_dim) for downstream linear layers
    attn.transpose(1, 2)?
        .contiguous()?
        .reshape((batch_size, seqlen, num_heads * headdim))?
        .contiguous()
}

/// Joint (MMDiT) attention over context and image tokens.
///
/// Concatenates context tokens ahead of image tokens along the sequence axis and
/// attends over the combined sequence. The output is split back into
/// `(context_attn, x_attn)`, in that order.
fn joint_attn(
    context_qkv: &Qkv,
    x_qkv: &Qkv,
    num_heads: usize,
) -> candle_core::Result<(Tensor, Tensor)> {
    let join = |ctx: &Tensor, img: &Tensor| Tensor::cat(&[ctx, img], 1);
    let q = join(&context_qkv.q, &x_qkv.q)?;
    let k = join(&context_qkv.k, &x_qkv.k)?;
    let v = join(&context_qkv.v, &x_qkv.v)?;

    let total_len = q.dim(1)?;
    let attn = attention(&q, &k, &v, num_heads)?;
    let ctx_len = context_qkv.q.dim(1)?;
    let context_attn = attn.narrow(1, 0, ctx_len)?;
    let x_attn = attn.narrow(1, ctx_len, total_len - ctx_len)?;
    Ok((context_attn, x_attn))
}

// ==================== DiTBlock (MMDiT standard) ====================

struct DiTBlock {
    norm1: LayerNormNoAffine,
    attn: AttnProjections,
    norm2: LayerNormNoAffine,
    mlp: Mlp,
    ada_ln_modulation_1: Linear,
}

impl DiTBlock {
    fn new(
        hidden_size: usize,
        num_heads: usize,
        has_qk_norm: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let norm1 = LayerNormNoAffine::new(1e-6);
        let attn = AttnProjections::new(hidden_size, num_heads, has_qk_norm, vb.pp("attn"))?;
        let norm2 = LayerNormNoAffine::new(1e-6);
        let mlp_ratio = 4;
        let mlp = Mlp::new(hidden_size, hidden_size * mlp_ratio, vb.pp("mlp"))?;
        let n_mods = 6;
        let ada_ln_modulation_1 = quantized_nn::linear(
            hidden_size,
            n_mods * hidden_size,
            vb.pp("adaLN_modulation.1"),
        )?;
        Ok(Self {
            norm1,
            attn,
            norm2,
            mlp,
            ada_ln_modulation_1,
        })
    }

    fn pre_attention(
        &self,
        x: &Tensor,
        c: &Tensor,
    ) -> candle_core::Result<(Qkv, ModulateIntermediates)> {
        let modulation = linear_nan_safe(&self.ada_ln_modulation_1, &c.silu()?)?;
        let chunks = modulation.chunk(6, D::Minus1)?;
        let (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = (
            chunks[0].clone(),
            chunks[1].clone(),
            chunks[2].clone(),
            chunks[3].clone(),
            chunks[4].clone(),
            chunks[5].clone(),
        );

        let norm_x = self.norm1.forward(x)?;
        let modulated_x = modulate(&norm_x, &shift_msa, &scale_msa)?;
        let qkv = self.attn.pre_attention(&modulated_x)?;

        Ok((
            qkv,
            ModulateIntermediates {
                gate_msa,
                shift_mlp,
                scale_mlp,
                gate_mlp,
            },
        ))
    }

    fn post_attention(
        &self,
        attn: &Tensor,
        x: &Tensor,
        mod_interm: &ModulateIntermediates,
    ) -> candle_core::Result<Tensor> {
        let attn_out = self.attn.post_attention(attn)?;

        // Trace post_attention on first call
        static PA_DIAG: std::sync::atomic::AtomicBool = std::sync::atomic::AtomicBool::new(false);
        let do_diag = !PA_DIAG.swap(true, std::sync::atomic::Ordering::Relaxed)
            && std::env::var_os("MOLD_SD3_DEBUG").is_some();
        macro_rules! trace {
            ($name:expr, $t:expr) => {
                if do_diag {
                    let mn = $t
                        .min_all()
                        .and_then(|v| v.to_dtype(DType::F32)?.to_scalar::<f32>())
                        .unwrap_or(f32::NAN);
                    let mx = $t
                        .max_all()
                        .and_then(|v| v.to_dtype(DType::F32)?.to_scalar::<f32>())
                        .unwrap_or(f32::NAN);
                    eprintln!("[sd3-post] {}: [{mn:.4},{mx:.4}]", $name);
                }
            };
        }

        trace!("attn_proj", &attn_out);
        let gated_attn = attn_out.broadcast_mul(&mod_interm.gate_msa.unsqueeze(1)?)?;
        trace!("gated_attn", &gated_attn);
        let x = x.broadcast_add(&gated_attn)?;
        trace!("x+gated", &x);

        let norm_x = self.norm2.forward(&x)?;
        trace!("norm_x", &norm_x);
        let modulated_x = modulate(&norm_x, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
        trace!("modulated", &modulated_x);
        let mlp_out = self.mlp.forward(&modulated_x)?;
        trace!("mlp_out", &mlp_out);
        let gated_mlp = mlp_out.broadcast_mul(&mod_interm.gate_mlp.unsqueeze(1)?)?;
        trace!("gated_mlp", &gated_mlp);
        let result = x.broadcast_add(&gated_mlp)?;
        trace!("final", &result);
        Ok(result)
    }
}

struct ModulateIntermediates {
    gate_msa: Tensor,
    shift_mlp: Tensor,
    scale_mlp: Tensor,
    gate_mlp: Tensor,
}

// ==================== SelfAttnDiTBlock (MMDiT-X) ====================

/// MMDiT-X block with two self-attention branches and a gated MLP.
///
/// The branches (`attn`, `attn2`) share one normalised input but have separate
/// adaLN shift/scale/gate terms. Nine modulation chunks are produced per call.
struct SelfAttnDiTBlock {
    norm1: LayerNormNoAffine,
    attn: AttnProjections,
    attn2: AttnProjections,
    norm2: LayerNormNoAffine,
    mlp: Mlp,
    ada_ln_modulation_1: Linear,
}

/// Gates and MLP modulation produced by `SelfAttnDiTBlock::pre_attention` and
/// consumed by `SelfAttnDiTBlock::post_attention`.
struct SelfAttnModulateIntermediates {
    gate_msa: Tensor,
    shift_mlp: Tensor,
    scale_mlp: Tensor,
    gate_mlp: Tensor,
    gate_msa2: Tensor,
}

impl SelfAttnDiTBlock {
    fn new(
        hidden_size: usize,
        num_heads: usize,
        has_qk_norm: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        const MLP_RATIO: usize = 4;
        const N_MODS: usize = 9;
        let attn = AttnProjections::new(hidden_size, num_heads, has_qk_norm, vb.pp("attn"))?;
        let attn2 = AttnProjections::new(hidden_size, num_heads, has_qk_norm, vb.pp("attn2"))?;
        let mlp = Mlp::new(hidden_size, hidden_size * MLP_RATIO, vb.pp("mlp"))?;
        let ada_ln_modulation_1 = quantized_nn::linear(
            hidden_size,
            N_MODS * hidden_size,
            vb.pp("adaLN_modulation.1"),
        )?;
        Ok(SelfAttnDiTBlock {
            norm1: LayerNormNoAffine::new(1e-6),
            attn,
            attn2,
            norm2: LayerNormNoAffine::new(1e-6),
            mlp,
            ada_ln_modulation_1,
        })
    }

    /// Computes the Q/K/V for both attention branches from the same normalised `x`.
    ///
    /// Each branch uses its own shift/scale. The remaining modulation terms are
    /// returned for `post_attention`.
    fn pre_attention(
        &self,
        x: &Tensor,
        c: &Tensor,
    ) -> candle_core::Result<(Qkv, Qkv, SelfAttnModulateIntermediates)> {
        let chunks = linear_nan_safe(&self.ada_ln_modulation_1, &c.silu()?)?
            .chunk(9, D::Minus1)?;
        let mods: [Tensor; 9] = std::array::from_fn(|i| chunks[i].clone());
        let [
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            shift_msa2,
            scale_msa2,
            gate_msa2,
        ] = mods;

        let normed = self.norm1.forward(x)?;
        let qkv = self
            .attn
            .pre_attention(&modulate(&normed, &shift_msa, &scale_msa)?)?;
        let qkv2 = self
            .attn2
            .pre_attention(&modulate(&normed, &shift_msa2, &scale_msa2)?)?;

        let interm = SelfAttnModulateIntermediates {
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            gate_msa2,
        };
        Ok((qkv, qkv2, interm))
    }

    /// Adds both gated attention branches to `x`, then adds the gated MLP of the
    /// re-modulated result.
    fn post_attention(
        &self,
        attn: &Tensor,
        attn2: &Tensor,
        x: &Tensor,
        mod_interm: &SelfAttnModulateIntermediates,
    ) -> candle_core::Result<Tensor> {
        // Multiplies a branch output by a per-sample gate broadcast over the sequence.
        let gated = |branch: &Tensor, gate: &Tensor| -> candle_core::Result<Tensor> {
            branch.broadcast_mul(&gate.unsqueeze(1)?)
        };

        let branch1 = self.attn.post_attention(attn)?;
        let x = x.broadcast_add(&gated(&branch1, &mod_interm.gate_msa)?)?;
        let branch2 = self.attn2.post_attention(attn2)?;
        let x = x.broadcast_add(&gated(&branch2, &mod_interm.gate_msa2)?)?;

        let normed = self.norm2.forward(&x)?;
        let modulated = modulate(&normed, &mod_interm.shift_mlp, &mod_interm.scale_mlp)?;
        let mlp_out = self.mlp.forward(&modulated)?;
        x.broadcast_add(&gated(&mlp_out, &mod_interm.gate_mlp)?)
    }
}

// ==================== QkvOnlyDiTBlock (final joint block context) ====================

/// Pre-attention-only block: normalises and modulates its input, then produces Q/K/V.
///
/// It has no output projection and no MLP; the section header describes it as the
/// final joint block's context stream.
struct QkvOnlyDiTBlock {
    norm1: LayerNormNoAffine,
    attn: QkvOnlyAttnProjections,
    ada_ln_modulation_1: Linear,
}

impl QkvOnlyDiTBlock {
    fn new(hidden_size: usize, num_heads: usize, vb: VarBuilder) -> Result<Self> {
        const N_MODS: usize = 2;
        let attn = QkvOnlyAttnProjections::new(hidden_size, num_heads, vb.pp("attn"))?;
        let ada_ln_modulation_1 = quantized_nn::linear(
            hidden_size,
            N_MODS * hidden_size,
            vb.pp("adaLN_modulation.1"),
        )?;
        Ok(QkvOnlyDiTBlock {
            norm1: LayerNormNoAffine::new(1e-6),
            attn,
            ada_ln_modulation_1,
        })
    }

    fn pre_attention(&self, x: &Tensor, c: &Tensor) -> candle_core::Result<Qkv> {
        let chunks = linear_nan_safe(&self.ada_ln_modulation_1, &c.silu()?)?
            .chunk(2, D::Minus1)?;
        let [shift_msa, scale_msa]: [Tensor; 2] = std::array::from_fn(|i| chunks[i].clone());
        let normed = self.norm1.forward(x)?;
        let modulated = modulate(&normed, &shift_msa, &scale_msa)?;
        self.attn.pre_attention(&modulated)
    }
}

// ==================== FinalLayer ====================

/// Output head: an adaLN-modulated layer norm followed by a projection to
/// `patch_size * patch_size * out_channels` values per token.
struct FinalLayer {
    norm_final: LayerNormNoAffine,
    linear: Linear,
    ada_ln_modulation_1: Linear,
}

impl FinalLayer {
    fn new(
        hidden_size: usize,
        patch_size: usize,
        out_channels: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let out_dim = patch_size * patch_size * out_channels;
        Ok(FinalLayer {
            norm_final: LayerNormNoAffine::new(1e-6),
            linear: quantized_nn::linear(hidden_size, out_dim, vb.pp("linear"))?,
            ada_ln_modulation_1: quantized_nn::linear(
                hidden_size,
                2 * hidden_size,
                vb.pp("adaLN_modulation.1"),
            )?,
        })
    }

    fn forward(&self, x: &Tensor, c: &Tensor) -> candle_core::Result<Tensor> {
        let chunks = linear_nan_safe(&self.ada_ln_modulation_1, &c.silu()?)?
            .chunk(2, D::Minus1)?;
        let [shift, scale]: [Tensor; 2] = std::array::from_fn(|i| chunks[i].clone());
        let normed = self.norm_final.forward(x)?;
        let modulated = modulate(&normed, &shift, &scale)?;
        linear_nan_safe(&self.linear, &modulated)
    }
}

// ==================== Unpatchifier ====================

/// Folds per-token patch predictions back into a latent image.
struct Unpatchifier {
    patch_size: usize,
    out_channels: usize,
}

impl Unpatchifier {
    fn new(patch_size: usize, out_channels: usize) -> Self {
        Self {
            patch_size,
            out_channels,
        }
    }

    /// Reassembles `x` into `(batch, out_channels, patch_size * gh, patch_size * gw)`.
    ///
    /// `gh = ceil(h / patch_size)` and `gw = ceil(w / patch_size)`. Here `h`/`w` are
    /// the latent spatial dims before patchification, and each token of `x` holds
    /// `patch_size * patch_size * out_channels` values. When `h`/`w` are not multiples
    /// of `patch_size`, the result is larger than `h x w` and the caller crops it.
    ///
    /// # Errors
    /// Returns an error if `x` does not hold exactly `gh * gw` tokens of that width.
    fn unpatchify(&self, x: &Tensor, h: usize, w: usize) -> candle_core::Result<Tensor> {
        let p = self.patch_size;
        // True ceiling division for any patch size. The former `(h + 1) / p` only
        // equalled ceil(h / p) when p == 2 (e.g. p == 1 gave h + 1 rows).
        let h = h.div_ceil(p);
        let w = w.div_ceil(p);
        let batch = x.dim(0)?;
        let x = x.reshape((batch, h, w, p, p, self.out_channels))?;
        let x = x.permute((0, 5, 1, 3, 2, 4))?; // "nhwpqc->nchpwq"
        x.reshape((batch, self.out_channels, p * h, p * w))
    }
}

// ==================== JointBlock trait + implementations ====================

/// Common interface for joint (context + image token) transformer blocks.
///
/// This lets `QuantizedMMDiT` keep a mix of block variants (plain MMDiT and MMDiT-X)
/// in one `Vec<Box<dyn JointBlock + Send + Sync>>`.
trait JointBlock {
    /// Runs one joint block.
    ///
    /// Inputs are the context token stream, the image token stream `x`, the
    /// conditioning tensor `c` (timestep + `y` embedding in `QuantizedMMDiT::forward`)
    /// and the attention head count. Returns the updated `(context, x)`, in that order.
    fn forward(
        &self,
        context: &Tensor,
        x: &Tensor,
        c: &Tensor,
        num_heads: usize,
    ) -> candle_core::Result<(Tensor, Tensor)>;
}

/// Standard MMDiT joint block (SD3.5 Large uses these exclusively).
///
/// Image tokens and context tokens each have their own `DiTBlock`; the two streams
/// interact only through a shared joint attention.
struct MMDiTJointBlock {
    x_block: DiTBlock,
    context_block: DiTBlock,
}

impl MMDiTJointBlock {
    /// Loads `x_block.*` and `context_block.*` from `vb`, both with the same QK-norm setting.
    fn new(
        hidden_size: usize,
        num_heads: usize,
        has_qk_norm: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let build =
            |prefix: &str| DiTBlock::new(hidden_size, num_heads, has_qk_norm, vb.pp(prefix));
        Ok(Self {
            x_block: build("x_block")?,
            context_block: build("context_block")?,
        })
    }
}

impl JointBlock for MMDiTJointBlock {
    /// Runs both streams through pre-attention, one joint attention, and then their
    /// own post-attention. Returns `(context, x)`.
    ///
    /// When `MOLD_SD3_DEBUG` is set, the first call of this method in the process
    /// prints min/max ranges of the intermediate tensors to stderr. The one-shot flag
    /// is shared by every `MMDiTJointBlock`. Diagnostics are best-effort: a failure
    /// while computing one prints `NaN` or an error note instead of aborting inference.
    fn forward(
        &self,
        context: &Tensor,
        x: &Tensor,
        c: &Tensor,
        num_heads: usize,
    ) -> candle_core::Result<(Tensor, Tensor)> {
        // Best-effort "[min,max]" of a tensor, 4 decimals; NaN for any failed reduction.
        fn range(t: &Tensor) -> String {
            let scalar = |r: candle_core::Result<Tensor>| {
                r.and_then(|v| v.to_dtype(DType::F32)?.to_scalar::<f32>())
                    .unwrap_or(f32::NAN)
            };
            let (mn, mx) = (scalar(t.min_all()), scalar(t.max_all()));
            format!("[{mn:.4},{mx:.4}]")
        }

        static BLK_DIAG: std::sync::atomic::AtomicBool =
            std::sync::atomic::AtomicBool::new(false);

        let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?;
        let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?;

        // The flag is consumed on the first call even when the env var is unset.
        // At most one block call per process is ever traced.
        let do_diag = !BLK_DIAG.swap(true, std::sync::atomic::Ordering::Relaxed)
            && std::env::var_os("MOLD_SD3_DEBUG").is_some();
        if do_diag {
            eprintln!(
                "[sd3-blk0] pre_attn: x_q={} x_k={} ctx_q={} ctx_k={} gate_msa={} gate_mlp={}",
                range(&x_qkv.q),
                range(&x_qkv.k),
                range(&context_qkv.q),
                range(&context_qkv.k),
                range(&x_interm.gate_msa),
                range(&x_interm.gate_mlp)
            );
        }

        let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, num_heads)?;
        if do_diag {
            eprintln!(
                "[sd3-blk0] attn out: x={} ctx={}",
                range(&x_attn),
                range(&context_attn)
            );
        }

        let context_out =
            self.context_block
                .post_attention(&context_attn, context, &context_interm)?;
        if do_diag {
            // Compare against a contiguous copy to rule out view/stride issues in x_attn.
            match x_attn.contiguous() {
                Ok(xac) => eprintln!(
                    "[sd3-blk0] x_attn (contiguous): {} shape={:?}",
                    range(&xac),
                    xac.shape()
                ),
                Err(e) => eprintln!("[sd3-blk0] x_attn (contiguous): failed: {e}"),
            }
            eprintln!("[sd3-blk0] context_out: {}", range(&context_out));
        }

        let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?;
        if do_diag {
            eprintln!(
                "[sd3-blk0] post_attn: x={} ctx={}",
                range(&x_out),
                range(&context_out)
            );
        }

        Ok((context_out, x_out))
    }
}

/// MMDiT-X joint block (SD3.5 Medium uses these with self-attention on x).
///
/// The image stream is a `SelfAttnDiTBlock`. Besides the joint attention, it runs
/// a second attention (`attn2`) over image tokens only. The context stream is a
/// plain `DiTBlock`.
struct MMDiTXJointBlock {
    x_block: SelfAttnDiTBlock,
    context_block: DiTBlock,
}

impl MMDiTXJointBlock {
    /// Loads `x_block.*` and `context_block.*` from `vb`, both with the same QK-norm setting.
    fn new(
        hidden_size: usize,
        num_heads: usize,
        has_qk_norm: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let x_vb = vb.pp("x_block");
        let context_vb = vb.pp("context_block");
        Ok(Self {
            x_block: SelfAttnDiTBlock::new(hidden_size, num_heads, has_qk_norm, x_vb)?,
            context_block: DiTBlock::new(hidden_size, num_heads, has_qk_norm, context_vb)?,
        })
    }
}

impl JointBlock for MMDiTXJointBlock {
    /// Runs a joint attention over (context, x) plus an image-only self-attention.
    ///
    /// Both attention outputs feed the image stream's post-attention. Returns
    /// `(context, x)`.
    fn forward(
        &self,
        context: &Tensor,
        x: &Tensor,
        c: &Tensor,
        num_heads: usize,
    ) -> candle_core::Result<(Tensor, Tensor)> {
        // Projections for both streams.
        let (ctx_qkv, ctx_mods) = self.context_block.pre_attention(context, c)?;
        let (joint_qkv, self_qkv, x_mods) = self.x_block.pre_attention(x, c)?;

        // Attention: joint over both streams, then self-attention over image tokens.
        let (ctx_attn, x_joint_attn) = joint_attn(&ctx_qkv, &joint_qkv, num_heads)?;
        let x_self_attn = attention(&self_qkv.q, &self_qkv.k, &self_qkv.v, num_heads)?;

        // Per-stream residual/MLP updates.
        let new_context = self
            .context_block
            .post_attention(&ctx_attn, context, &ctx_mods)?;
        let new_x = self
            .x_block
            .post_attention(&x_joint_attn, &x_self_attn, x, &x_mods)?;
        Ok((new_context, new_x))
    }
}

// ==================== QuantizedMMDiT (main struct) ====================

/// Quantized MMDiT model for SD3.5 inference with GGUF weights.
///
/// Joint blocks `0..depth - 1` are stored as boxed `JointBlock`s (plain MMDiT or
/// MMDiT-X, chosen per block at load time). The last joint block is stored separately
/// because its context stream only produces Q/K/V.
pub(crate) struct QuantizedMMDiT {
    /// Patchifies the input latent into tokens (`x_embedder.*`).
    patch_embedder: PatchEmbedder,
    /// Positional embedding, cropped to the input's spatial size on each forward.
    pos_embedder: PositionEmbedder,
    /// Timestep embedding (`t_embedder.*`); summed with the `y` embedding to form `c`.
    timestep_embedder: TimestepEmbedder,
    /// Embeds the `y` vector input (`y_embedder.*`).
    vector_embedder: VectorEmbedder,
    /// Projects `context` from `context_embed_size` to `hidden_size` (`context_embedder.*`).
    context_embedder: Linear,
    /// Joint blocks `0..depth - 1`.
    joint_blocks: Vec<Box<dyn JointBlock + Send + Sync>>,
    /// Joint block `depth - 1`, whose context stream only produces Q/K/V.
    context_qkv_only_block: ContextQkvOnlyBlock,
    /// adaLN-modulated output projection (`final_layer.*`).
    final_layer: FinalLayer,
    /// Folds per-token outputs back into a latent image.
    unpatchifier: Unpatchifier,
    /// Attention head count (set to `cfg.depth` at load time).
    num_heads: usize,
}

/// The last joint block where context only produces QKV (no MLP).
///
/// The image stream is an ordinary `DiTBlock`. The context stream only feeds the
/// joint attention, whose context-side output the forward pass discards.
struct ContextQkvOnlyBlock {
    x_block: DiTBlock,
    context_block: QkvOnlyDiTBlock,
}

impl ContextQkvOnlyBlock {
    /// Loads `x_block.*` (with QK-norm per `has_qk_norm`) and `context_block.*` from `vb`.
    fn new(
        hidden_size: usize,
        num_heads: usize,
        has_qk_norm: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let x_vb = vb.pp("x_block");
        let context_vb = vb.pp("context_block");
        Ok(Self {
            x_block: DiTBlock::new(hidden_size, num_heads, has_qk_norm, x_vb)?,
            context_block: QkvOnlyDiTBlock::new(hidden_size, num_heads, context_vb)?,
        })
    }
}

impl QuantizedMMDiT {
    /// Load a quantized MMDiT from GGUF weights.
    ///
    /// GGUF files from city96 use unprefixed tensor names (e.g. `x_embedder.proj.weight`),
    /// so `vb` should NOT have a `model.diffusion_model` prefix.
    pub fn new(cfg: &MMDiTConfig, vb: VarBuilder) -> Result<Self> {
        let hidden_size = cfg.head_size * cfg.depth;
        let num_heads = cfg.depth;

        let patch_embedder = PatchEmbedder::new(
            cfg.patch_size,
            cfg.in_channels,
            hidden_size,
            vb.pp("x_embedder"),
        )?;
        let pos_embedder = PositionEmbedder::new(
            hidden_size,
            cfg.patch_size,
            cfg.pos_embed_max_size,
            vb.clone(),
        )?;
        let timestep_embedder = TimestepEmbedder::new(
            hidden_size,
            cfg.frequency_embedding_size,
            vb.pp("t_embedder"),
        )?;
        let vector_embedder =
            VectorEmbedder::new(cfg.adm_in_channels, hidden_size, vb.pp("y_embedder"))?;
        let context_embedder = quantized_nn::linear(
            cfg.context_embed_size,
            hidden_size,
            vb.pp("context_embedder"),
        )?;

        // Detect MMDiT vs MMDiT-X blocks by checking for attn2 weights
        let mut joint_blocks: Vec<Box<dyn JointBlock + Send + Sync>> =
            Vec::with_capacity(cfg.depth - 1);
        for i in 0..cfg.depth - 1 {
            let block_vb = vb.pp(format!("joint_blocks.{i}"));
            // Check if this block has attn2 (MMDiT-X) by probing for tensor presence only.
            // Using a shaped get() here can produce a false negative on real tensors if the
            // probe shape is wrong, which misclassifies SD3.5 Medium blocks as plain MMDiT.
            let has_attn2 = block_vb
                .pp("x_block")
                .pp("attn2")
                .pp("qkv")
                .get_no_shape("weight")
                .is_ok();
            // Check for QK norm by probing ln_k (head_dim shape)
            let head_dim = hidden_size / num_heads;
            let has_qk_norm = block_vb
                .pp("x_block")
                .pp("attn")
                .pp("ln_k")
                .get(head_dim, "weight")
                .is_ok();
            let block: Box<dyn JointBlock + Send + Sync> = if has_attn2 {
                Box::new(MMDiTXJointBlock::new(
                    hidden_size,
                    num_heads,
                    has_qk_norm,
                    block_vb,
                )?)
            } else {
                Box::new(MMDiTJointBlock::new(
                    hidden_size,
                    num_heads,
                    has_qk_norm,
                    block_vb,
                )?)
            };
            joint_blocks.push(block);
        }

        // Check for QK norm on the final block
        let final_block_vb = vb.pp(format!("joint_blocks.{}", cfg.depth - 1));
        let head_dim = hidden_size / num_heads;
        let final_has_qk_norm = final_block_vb
            .pp("x_block")
            .pp("attn")
            .pp("ln_k")
            .get(head_dim, "weight")
            .is_ok();

        let context_qkv_only_block =
            ContextQkvOnlyBlock::new(hidden_size, num_heads, final_has_qk_norm, final_block_vb)?;

        let final_layer = FinalLayer::new(
            hidden_size,
            cfg.patch_size,
            cfg.out_channels,
            vb.pp("final_layer"),
        )?;

        let unpatchifier = Unpatchifier::new(cfg.patch_size, cfg.out_channels);

        Ok(Self {
            patch_embedder,
            pos_embedder,
            timestep_embedder,
            vector_embedder,
            context_embedder,
            joint_blocks,
            context_qkv_only_block,
            final_layer,
            unpatchifier,
            num_heads,
        })
    }

    /// Forward pass through the quantized MMDiT.
    pub fn forward(
        &self,
        x: &Tensor,
        t: &Tensor,
        y: &Tensor,
        context: &Tensor,
        skip_layers: Option<&[usize]>,
    ) -> Result<Tensor> {
        // Quantized model operates in F32 (QMatMul dequantizes weights to F32)
        let x = &x.to_dtype(DType::F32)?;
        let t = &t.to_dtype(DType::F32)?;
        let y = &y.to_dtype(DType::F32)?;
        let context = &context.to_dtype(DType::F32)?;

        let h = x.dim(D::Minus2)?;
        let w = x.dim(D::Minus1)?;
        let cropped_pos_embed = self.pos_embedder.get_cropped_pos_embed(h, w)?;
        let x = self
            .patch_embedder
            .forward(x)?
            .broadcast_add(&cropped_pos_embed)?;
        let c = self.timestep_embedder.forward(t)?;
        let y = self.vector_embedder.forward(y)?;
        let c = (c + y)?;
        let context = linear_nan_safe(&self.context_embedder, context)?;

        // Diagnostic: check if embeddings are finite before any block runs
        if std::env::var_os("MOLD_SD3_DEBUG").is_some() {
            static EMB_DIAG: std::sync::atomic::AtomicBool =
                std::sync::atomic::AtomicBool::new(false);
            if !EMB_DIAG.swap(true, std::sync::atomic::Ordering::Relaxed) {
                let stats = |name: &str, t: &Tensor| {
                    let mn = t
                        .min_all()
                        .and_then(|v| v.to_dtype(DType::F32)?.to_scalar::<f32>())
                        .unwrap_or(f32::NAN);
                    let mx = t
                        .max_all()
                        .and_then(|v| v.to_dtype(DType::F32)?.to_scalar::<f32>())
                        .unwrap_or(f32::NAN);
                    eprintln!("[sd3-emb] {name}: [{mn:.4},{mx:.4}] shape={:?}", t.shape());
                };
                stats("x (patch+pos)", &x);
                stats("c (timestep+y)", &c);
                stats("context (encoded)", &context);
            }
        }

        // Joint blocks
        let (mut context, mut x) = (context, x);
        for (i, joint_block) in self.joint_blocks.iter().enumerate() {
            if let Some(skip) = &skip_layers {
                if skip.contains(&i) {
                    continue;
                }
            }
            let result = joint_block.forward(&context, &x, &c, self.num_heads)?;
            context = result.0;
            x = result.1;
            // One-shot diagnostic for first forward pass
            // Trace ALL blocks on first forward pass
            {
                static FWD_DIAG: std::sync::atomic::AtomicBool =
                    std::sync::atomic::AtomicBool::new(false);
                if !FWD_DIAG.load(std::sync::atomic::Ordering::Relaxed)
                    && std::env::var_os("MOLD_SD3_DEBUG").is_some()
                {
                    let xmin = x
                        .min_all()?
                        .to_dtype(DType::F32)?
                        .to_scalar::<f32>()
                        .unwrap_or(f32::NAN);
                    let xmax = x
                        .max_all()?
                        .to_dtype(DType::F32)?
                        .to_scalar::<f32>()
                        .unwrap_or(f32::NAN);
                    eprintln!("[sd3-debug] block {i} output: x=[{xmin:.4},{xmax:.4}]");
                    if !xmin.is_finite() || !xmax.is_finite() {
                        FWD_DIAG.store(true, std::sync::atomic::Ordering::Relaxed);
                        eprintln!("[sd3-debug] INF FIRST AT BLOCK {i}");
                    }
                }
            }
        }

        // Final context QKV only block
        let context_qkv = self
            .context_qkv_only_block
            .context_block
            .pre_attention(&context, &c)?;
        let (x_qkv, x_interm) = self.context_qkv_only_block.x_block.pre_attention(&x, &c)?;
        let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?;
        let x = self
            .context_qkv_only_block
            .x_block
            .post_attention(&x_attn, &x, &x_interm)?;

        let x = self.final_layer.forward(&x, &c)?;
        let x = self.unpatchifier.unpatchify(&x, h, w)?;
        Ok(x.narrow(2, 0, h)?.narrow(3, 0, w)?)
    }
}