zensim 0.2.4

Fast psychovisual image similarity metric
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
//! Core zensim metric computation.
//!
//! Multi-scale SSIM + edge + high-frequency features in XYB color space,
//! with trained weights per feature.
//!
//! # Feature extraction pipeline
//!
//! Both images are converted to the XYB perceptual color space (cube-root LMS,
//! same as ssimulacra2 and butteraugli), then processed at multiple scales.
//! Each scale halves resolution via 2× box downscale. At each scale, 19 features
//! are extracted per XYB channel (X, Y, B): 13 basic + 6 peak/diagnostic,
//! giving **228 features total** (4 scales × 3 channels × 19 features).
//!
//! ## SSIM features (3 per channel per scale)
//!
//! Uses the ssimulacra2 variant of SSIM, which differs from standard SSIM:
//!
//! ```text
//! mu1 = blur(src), mu2 = blur(dst)
//! sigma12 = blur(src * dst)
//! sum_sq  = blur(src² + dst²)     // one blur instead of two
//!
//! num_m   = 1 - (mu1 - mu2)²      // luminance (no C1, no denominator)
//! num_s   = 2·sigma12 - 2·mu1·mu2 + C2   // structure × contrast
//! denom_s = sum_sq - mu1² - mu2² + C2     // = sigma1² + sigma2² + C2
//!
//! d = max(0, 1 - num_m · num_s / denom_s) // per-pixel SSIM error
//! ```
//!
//! The luminance component drops the standard SSIM denominator
//! `(mu1² + mu2² + C1)` — ssimulacra2's reasoning is that the denominator
//! over-weights dark-region errors, which is wrong for perceptually uniform
//! values (XYB is already gamma-like). There is no C1 constant; C2 = 0.0009.
//!
//! The `sum_sq` optimization computes `blur(src² + dst²)` with one blur
//! instead of separate `blur(src²)` and `blur(dst²)`, because the SSIM
//! formula only needs `sigma1² + sigma2²`, not each individually.
//!
//! Three pooling norms capture different aspects of the error distribution:
//! - **ssim_mean** = `mean(d)` — average error
//! - **ssim_4th**  = `(mean(d⁴))^(1/4)` — L4 norm, emphasizes worst-case errors
//! - **ssim_2nd**  = `(mean(d²))^(1/2)` — L2 norm, intermediate sensitivity
//!
//! ## Edge features (6 per channel per scale)
//!
//! Edge detection compares local detail (pixel minus local mean) between
//! source and distorted:
//!
//! ```text
//! diff_src = |src - mu1|    // source edge magnitude
//! diff_dst = |dst - mu2|    // distorted edge magnitude
//!
//! d = (1 + diff_dst) / (1 + diff_src) - 1   // per-pixel edge ratio
//!
//! artifact    = max(0,  d)   // distorted has MORE edge than source
//! detail_lost = max(0, -d)   // distorted has LESS edge than source
//! ```
//!
//! The `1 +` offsets prevent division by zero and dampen sensitivity in flat
//! regions. The ratio formulation is scale-invariant. Splitting into artifact
//! (ringing, banding, blockiness) vs detail_lost (blur, smoothing) lets the
//! model weight them independently.
//!
//! Each is pooled with three norms (mean, L4, L2) = 6 features.
//!
//! ## MSE (1 per channel per scale)
//!
//! Plain mean squared error in XYB space: `mean((src - dst)²)`.
//! No blur dependency, computed directly from pixels.
//!
//! ## High-frequency features (3 per channel per scale)
//!
//! These measure changes in local detail energy by comparing `pixel - blur(pixel)`
//! (the high-frequency residual) between source and distorted. Despite their
//! former names ("variance_loss", "texture_loss", "contrast_increase"), they
//! do NOT measure image variance — they measure the ratio of high-pass energy.
//!
//! ```text
//! hf_src_L2 = Σ(src - mu1)²    // source HF energy (L2)
//! hf_dst_L2 = Σ(dst - mu2)²    // distorted HF energy (L2)
//! hf_src_L1 = Σ|src - mu1|     // source HF magnitude (L1)
//! hf_dst_L1 = Σ|dst - mu2|     // distorted HF magnitude (L1)
//! ```
//!
//! - **hf_energy_loss** = `max(0, 1 - hf_dst_L2 / hf_src_L2)` — detail smoothed away
//! - **hf_mag_loss**    = `max(0, 1 - hf_dst_L1 / hf_src_L1)` — same, L1 (robust to outliers)
//! - **hf_energy_gain** = `max(0, hf_dst_L2 / hf_src_L2 - 1)` — detail added (ringing/sharpening)
//!
//! `hf_energy_loss` and `hf_energy_gain` are the positive and negative halves
//! of the same signal, split by ReLU — this gives the linear model separate
//! knobs for blur vs ringing without needing signed weights.
//!
//! ## Peak features (6 per channel per scale)
//!
//! Computed during the fused V-blur kernel at no extra cost:
//! - **ssim_max**, **art_max**, **det_max** — per-pixel maximum of each error type
//! - **ssim_l8**, **art_l8**, **det_l8** — L8-pooled (near-worst-case) values
//!
//! These capture outlier sensitivity that mean/L2/L4 pooling may miss.
//!
//! ## Scoring
//!
//! All 228 features are multiplied by trained weights, summed, normalized by
//! scale count, then mapped to a 0–100 score via:
//! `score = 100 - a · distance^b` (default a=18.0, b=0.7).

use crate::error::ZensimError;

/// Configuration for zensim computation.
///
/// All computation uses the streaming path, which processes scale 0 in
/// horizontal strips with fused blur+feature extraction for minimal memory
/// traffic. When `blur_passes == 1` (the default), fused H-blur + V-blur+reduce
/// SIMD kernels are used for peak performance.
///
/// Blur kernel shape for local-mean computation.
///
/// Controls how `blur(src)` and `blur(dst)` are computed at each scale.
/// The default `Box` kernel uses iterated box blur, which is O(1) per pixel
/// regardless of radius and has full SIMD optimization.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BlurKernel {
    /// Iterated box blur. `passes` controls the kernel shape:
    /// - 1 = rectangular (fastest, enables fused streaming kernels)
    /// - 2 = triangular (~1.5× slower at scale 0)
    /// - 3 = piecewise-quadratic ≈ Gaussian (~2× slower)
    Box {
        /// Number of passes (1 = rectangular, 2 = triangular, 3 ≈ Gaussian).
        passes: u8,
    },
}

impl Default for BlurKernel {
    fn default() -> Self {
        Self::Box { passes: 1 }
    }
}

/// Downscale filter for pyramid construction.
///
/// Controls how each pyramid level is produced from the previous one.
/// The default `Box2x2` averages 2×2 pixel blocks, halving resolution.
/// Enable the `zenresize` feature for `Mitchell` and `Lanczos` variants.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub enum DownscaleFilter {
    /// 2×2 box averaging (fastest, current default).
    #[default]
    Box2x2,
    /// Mitchell-Netravali bicubic (B=1/3, C=1/3). Good balance of sharpness
    /// and ringing. Requires the `zenresize` feature.
    #[cfg(feature = "zenresize")]
    #[allow(dead_code)]
    Mitchell,
    /// Lanczos-3 windowed sinc. Sharper than Mitchell but may ring on edges.
    /// Requires the `zenresize` feature.
    #[cfg(feature = "zenresize")]
    #[allow(dead_code)]
    Lanczos,
    /// Mitchell-Netravali bicubic followed by a Gaussian blur with the given
    /// sigma. This anti-aliases the pyramid more aggressively than plain
    /// Mitchell, which may help metrics that are sensitive to high-frequency
    /// ringing. Requires the `zenresize` feature.
    #[cfg(feature = "zenresize")]
    #[allow(dead_code)]
    MitchellBlur(f32),
}

/// Configuration for the zensim metric computation pipeline.
///
/// Controls blur kernel, pyramid construction, and feature extraction.
/// The defaults match the trained profile and give peak performance;
/// only change these for training or research.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct ZensimConfig {
    /// Box blur radius at scale 0 (default: 5, giving an 11-pixel kernel).
    ///
    /// The blur kernel width is `2 * blur_radius + 1`. Larger radii capture
    /// coarser structure but increase computation proportionally.
    /// Both streaming and full-image paths are SIMD-optimized for any radius.
    pub blur_radius: usize,

    /// Number of box blur passes (1, 2, or 3; default: 1).
    ///
    /// Controls the blur kernel shape:
    /// - **1 pass** — rectangular kernel. Enables fused blur+feature SIMD kernels
    ///   in the streaming path (fastest).
    /// - **2 passes** — triangular kernel. Falls back to separate blur+reduce in
    ///   the streaming path (~1.5× slower at scale 0).
    /// - **3 passes** — piecewise-quadratic ≈ Gaussian. Same fallback (~2× slower).
    ///
    /// All three variants have full SIMD optimization (AVX-512 + AVX2 dispatch).
    /// The performance difference comes from whether the fused streaming kernels
    /// can be used, not from the blur itself.
    pub blur_passes: u8,

    /// Blur kernel shape (default: `BlurKernel::Box { passes: 1 }`).
    ///
    /// Overrides `blur_passes` when set. The `blur_radius` field still controls
    /// the kernel width. Currently only the `Box` variant is implemented.
    #[allow(dead_code)] // planned: not yet wired into blur dispatch
    pub blur_kernel: BlurKernel,

    /// Downscale filter for pyramid construction (default: `DownscaleFilter::Box2x2`).
    ///
    /// Controls how each pyramid level is produced. Enable the `zenresize`
    /// feature for `Mitchell` and `Lanczos` variants.
    #[allow(dead_code)] // planned: not yet wired into pyramid construction
    pub downscale_filter: DownscaleFilter,

    /// Compute all 156 features even when their weights are zero (default: false).
    ///
    /// When false, channels/features with zero weight are skipped entirely.
    /// Enable for weight training to avoid circular dependency (need all features
    /// to determine which weights should be nonzero).
    pub compute_all_features: bool,

    /// Compute extended features (25 per channel instead of 13; default: false).
    ///
    /// When true, adds 12 extra features per channel per scale:
    /// - 6 masked features (SSIM/edge/MSE weighted by source flatness)
    /// - 6 percentile/max features (worst-case SSIM/edge errors)
    ///
    /// The masking strength for extended features is controlled by
    /// `extended_masking_strength`.
    pub extended_features: bool,

    /// Masking strength for extended masked features (default: 4.0).
    ///
    /// Only used when `extended_features` is true. Controls the flatness mask:
    /// `mask[i] = 1 / (1 + k * blur(|src - mu|))`.
    ///
    /// Higher values = more aggressive masking of textured regions.
    /// Typical range: 2.0–8.0.
    pub extended_masking_strength: f32,

    /// Maximum number of downscale levels (default: 4).
    ///
    /// Each level halves resolution. 4 scales covers 1×, 2×, 4×, 8× — sufficient
    /// for most perceptual effects. The feature vector length scales linearly:
    /// `num_scales × 3 channels × 13 features`.
    ///
    /// Both paths are SIMD-optimized for any scale count.
    pub num_scales: usize,

    /// Score mapping scale factor (default: 18.0).
    ///
    /// Used in the final score formula: `score = 100 - a × d^b`, where `d` is
    /// the raw weighted distance. Larger values spread scores more aggressively.
    pub score_mapping_a: f64,

    /// Score mapping gamma exponent (default: 0.7).
    ///
    /// Used in the final score formula: `score = 100 - a × d^b`. Sub-linear
    /// gamma (< 1.0) compresses high distances, giving more resolution in the
    /// high-quality range.
    pub score_mapping_b: f64,

    /// Enable multi-threaded computation via rayon (default: true).
    pub allow_multithreading: bool,
}

impl Default for ZensimConfig {
    fn default() -> Self {
        Self {
            blur_radius: 5,
            blur_passes: 1,
            blur_kernel: BlurKernel::default(),
            downscale_filter: DownscaleFilter::default(),
            compute_all_features: false,
            extended_features: false,
            extended_masking_strength: 4.0,
            num_scales: crate::NUM_SCALES,
            score_mapping_a: 18.0,
            score_mapping_b: 0.7,
            allow_multithreading: true,
        }
    }
}

/// Map a raw weighted distance to the 0–100 quality score.
///
/// Uses the default power-law mapping: `score = 100 - 18 * d^0.7`, clamped to \[0, 100\].
/// Identical images (d = 0) score 100.
///
/// For profile-specific mapping, use [`Zensim::compute`] which applies the profile's
/// `score_mapping_a` and `score_mapping_b` automatically.
pub(crate) fn distance_to_score(raw_distance: f64) -> f64 {
    distance_to_score_mapped(raw_distance, 18.0, 0.7)
}

/// Map a raw weighted distance to the quality score with custom parameters.
///
/// `score = 100 - a * d^b`. Nominally 0–100 but can go negative for
/// extreme distortions (the magnitude below zero is informative —
/// it distinguishes "slightly wrong" from "completely wrong").
fn distance_to_score_mapped(raw_distance: f64, a: f64, b: f64) -> f64 {
    if raw_distance <= 0.0 {
        100.0
    } else {
        100.0 - a * raw_distance.powf(b)
    }
}

/// Compute score from raw features using custom weights.
/// `features`: raw features from ZensimResult.features
/// `weights`: one weight per feature (len must equal features.len())
/// Returns (score, raw_distance)
#[cfg_attr(not(feature = "training"), allow(dead_code))]
pub fn score_from_features(features: &[f64], weights: &[f64]) -> (f64, f64) {
    assert_eq!(
        features.len(),
        weights.len(),
        "features and weights must have same length"
    );
    let raw_distance: f64 = features
        .iter()
        .zip(weights.iter())
        .map(|(&f, &w)| w * f)
        .sum();
    // Normalize by number of scales.
    // Layout: [scored × N_scales] [peaks × N_scales] [masked × N_scales]
    // 156 = 39×4, 228 = 57×4, 300 = 75×4 — all divide by 4 scales.
    let per_scale_candidates = [
        FEATURES_PER_CHANNEL_EXTENDED * 3,   // 75
        FEATURES_PER_CHANNEL_WITH_PEAKS * 3, // 57
        FEATURES_PER_CHANNEL_BASIC * 3,      // 39
    ];
    let features_per_scale = per_scale_candidates
        .iter()
        .copied()
        .find(|&ps| ps > 0 && features.len().is_multiple_of(ps))
        .unwrap_or(FEATURES_PER_CHANNEL_BASIC * 3);
    let n_scales = features.len() / features_per_scale;
    let raw_distance = raw_distance / n_scales.max(1) as f64;
    (distance_to_score(raw_distance), raw_distance)
}

/// Pre-compute reference with a custom number of pyramid scales.
///
/// Use this when calling [`compute_zensim_with_ref_and_config`] with a non-default
/// `num_scales`. The precomputed data must have at least as many scales as the config
/// requests.
#[cfg_attr(not(feature = "training"), allow(dead_code))]
pub fn precompute_reference_with_scales(
    source: &[[u8; 3]],
    width: usize,
    height: usize,
    num_scales: usize,
) -> Result<crate::streaming::PrecomputedReference, ZensimError> {
    if width < 8 || height < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if source.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    let src_img = crate::source::RgbSlice::new(source, width, height);
    Ok(crate::streaming::PrecomputedReference::new(
        &src_img, num_scales, true,
    ))
}

/// Compute zensim with a precomputed reference and custom configuration.
///
/// Training/research variant. The `config.num_scales`
/// must not exceed the number of scales in `precomputed`.
#[cfg(feature = "training")]
pub fn compute_zensim_with_ref_and_config(
    precomputed: &crate::streaming::PrecomputedReference,
    distorted: &[[u8; 3]],
    width: usize,
    height: usize,
    config: ZensimConfig,
) -> Result<ZensimResult, ZensimError> {
    if width < 8 || height < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if distorted.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    let dst_img = crate::source::RgbSlice::new(distorted, width, height);
    let result = crate::streaming::compute_zensim_streaming_with_ref(
        precomputed,
        &dst_img,
        &config,
        WEIGHTS,
    );
    Ok(result)
}

/// Per-scale statistics collected during computation.
#[derive(Default)]
pub(crate) struct ScaleStats {
    /// SSIM statistics: [mean_d, root4_d] per channel = 6 values
    pub(crate) ssim: [f64; 6],
    /// Edge features: [art_mean, art_4th, det_mean, det_4th] per channel = 12 values
    pub(crate) edge: [f64; 12],
    /// Per-channel MSE: mean((src - dst)²) for X, Y, B
    pub(crate) mse: [f64; 3],
    /// High-frequency energy loss (L2): max(0, 1 - Σ(dst-mu_dst)²/Σ(src-mu_src)²) per channel.
    /// Measures loss of local detail energy relative to source. Sensitive to blur/smoothing.
    pub(crate) hf_energy_loss: [f64; 3],
    /// High-frequency magnitude loss (L1): max(0, 1 - Σ|dst-mu_dst|/Σ|src-mu_src|) per channel.
    /// Like hf_energy_loss but with L1 norm — more robust to outliers.
    pub(crate) hf_mag_loss: [f64; 3],
    /// 2nd-power pooled SSIM: [root2_d] per channel = 3 values
    pub(crate) ssim_2nd: [f64; 3],
    /// Edge 2nd power: [art_2nd, det_2nd] per channel = 6 values
    pub(crate) edge_2nd: [f64; 6],
    /// High-frequency energy gain (L2): max(0, Σ(dst-mu_dst)²/Σ(src-mu_src)² - 1) per channel.
    /// Measures added local detail energy (ringing, sharpening artifacts).
    pub(crate) hf_energy_gain: [f64; 3],
    // --- Extended features (only populated when extended_features=true) ---
    /// Masked SSIM: [mean, 4th, 2nd] per channel = 9 values
    pub(crate) masked_ssim: [f64; 9],
    /// Masked edge artifact L4 per channel = 3 values
    pub(crate) masked_art_4th: [f64; 3],
    /// Masked edge detail_lost L4 per channel = 3 values
    pub(crate) masked_det_4th: [f64; 3],
    /// Masked MSE per channel = 3 values
    pub(crate) masked_mse: [f64; 3],
    /// Max SSIM error per channel = 3 values
    pub(crate) ssim_max: [f64; 3],
    /// Max edge artifact per channel = 3 values
    pub(crate) art_max: [f64; 3],
    /// Max edge detail_lost per channel = 3 values
    pub(crate) det_max: [f64; 3],
    /// L8 power pool SSIM error per channel = 3 values: (Σd⁸/N)^(1/8)
    pub(crate) ssim_p95: [f64; 3],
    /// L8 power pool edge artifact per channel = 3 values: (Σd⁸/N)^(1/8)
    pub(crate) art_p95: [f64; 3],
    /// L8 power pool edge detail_lost per channel = 3 values: (Σd⁸/N)^(1/8)
    pub(crate) det_p95: [f64; 3],
}

/// Result from a zensim comparison.
///
/// Contains the final score, the raw distance used to derive it, and the
/// full per-scale feature vector (useful for diagnostics or weight training).
#[derive(Debug, Clone)]
pub struct ZensimResult {
    score: f64,
    raw_distance: f64,
    features: Vec<f64>,
    profile: crate::profile::ZensimProfile,
    mean_offset: [f64; 3],
}

impl ZensimResult {
    /// Create a result from computed values. Internal use only.
    pub(crate) fn new(
        score: f64,
        raw_distance: f64,
        features: Vec<f64>,
        profile: crate::profile::ZensimProfile,
        mean_offset: [f64; 3],
    ) -> Self {
        Self {
            score,
            raw_distance,
            features,
            profile,
            mean_offset,
        }
    }

    /// Set the profile on this result (builder pattern). Internal use only.
    pub(crate) fn with_profile(mut self, profile: crate::profile::ZensimProfile) -> Self {
        self.profile = profile;
        self
    }

    /// Create a NaN sentinel result (for error/placeholder paths).
    pub fn nan() -> Self {
        Self {
            score: f64::NAN,
            raw_distance: f64::NAN,
            features: vec![],
            profile: crate::profile::ZensimProfile::PreviewV0_1,
            mean_offset: [f64::NAN; 3],
        }
    }

    /// Quality score on a 0–100 scale. 100 = identical, 0 = maximally different.
    /// Derived from `raw_distance` via a power-law mapping.
    pub fn score(&self) -> f64 {
        self.score
    }

    /// Raw weighted feature distance before nonlinear mapping. Lower = more similar.
    /// Not bounded to a fixed range; depends on image content and weights.
    pub fn raw_distance(&self) -> f64 {
        self.raw_distance
    }

    /// Per-scale raw features as a slice.
    ///
    /// Layout: 4 scales × 3 channels (X, Y, B) × 19 features per channel = 228.
    /// See [`FeatureView`] for named access.
    pub fn features(&self) -> &[f64] {
        &self.features
    }

    /// Consume the result and return the owned feature vector.
    pub fn into_features(self) -> Vec<f64> {
        self.features
    }

    /// Which profile produced this score.
    pub fn profile(&self) -> crate::profile::ZensimProfile {
        self.profile
    }

    /// Per-channel XYB mean offset: `mean(src_xyb[c]) - mean(dst_xyb[c])`.
    ///
    /// Captures global color/luminance shifts (CMS errors, white balance changes).
    /// Channels: `[X, Y, B]`, signed. Positive = distorted is darker/less saturated.
    pub fn mean_offset(&self) -> [f64; 3] {
        self.mean_offset
    }

    /// Convert the score to a dissimilarity value.
    ///
    /// Dissimilarity is `(100 - score) / 100`: 0 = identical, higher = worse.
    /// This is the inverse of the 0–100 score scale, normalized to 0–1.
    ///
    /// See also [`score_to_dissimilarity`] for the standalone conversion.
    pub fn dissimilarity(&self) -> f64 {
        score_to_dissimilarity(self.score)
    }

    /// Approximate SSIMULACRA2 score from the raw distance.
    ///
    /// Direct power-law fit: `100 - 19.04 × d^0.598`, calibrated on 344k
    /// synthetic pairs. MAE: 4.4 SSIM2 points, Pearson r = 0.974.
    ///
    /// More accurate than `mapping::zensim_to_ssim2(score)` (MAE 4.9, r = 0.932)
    /// because it skips the intermediate score mapping.
    pub fn approx_ssim2(&self) -> f64 {
        if self.raw_distance <= 0.0 {
            return 100.0;
        }
        (100.0 - 19.0379 * self.raw_distance.powf(0.5979)).max(-100.0)
    }

    /// Approximate DSSIM value from the raw distance.
    ///
    /// Direct power-law fit: `0.000922 × d^1.224`, calibrated on 344k
    /// synthetic pairs. MAE: 0.00129, Pearson r = 0.952.
    ///
    /// Significantly more accurate than `mapping::zensim_to_dssim(score)`
    /// (MAE 0.00213, r = 0.719) because DSSIM's natural exponent (1.22)
    /// differs from the score mapping exponent (0.70).
    pub fn approx_dssim(&self) -> f64 {
        if self.raw_distance <= 0.0 {
            return 0.0;
        }
        0.000922 * self.raw_distance.powf(1.2244)
    }

    /// Approximate butteraugli distance from the raw distance.
    ///
    /// Direct power-law fit: `2.365 × d^0.613`, calibrated on 344k
    /// synthetic pairs. MAE: 1.65 distance units, Pearson r = 0.713.
    ///
    /// Butteraugli's weak correlation with our features (r = 0.71) limits
    /// approximation accuracy regardless of mapping choice.
    pub fn approx_butteraugli(&self) -> f64 {
        if self.raw_distance <= 0.0 {
            return 0.0;
        }
        2.365353 * self.raw_distance.powf(0.6130)
    }
}

/// Convert a zensim score (0–100, 100 = identical) to a dissimilarity value
/// (0 = identical, higher = worse).
///
/// Linear conversion: `(100 - score) / 100`.
///
/// | score | dissimilarity |
/// |-------|---------------|
/// | 100.0 | 0.0           |
/// | 99.5  | 0.005         |
/// | 95.0  | 0.05          |
/// | 50.0  | 0.5           |
/// | 0.0   | 1.0           |
pub fn score_to_dissimilarity(score: f64) -> f64 {
    (100.0 - score) / 100.0
}

/// Convert a dissimilarity value (0 = identical, higher = worse) back to a
/// zensim score (0–100, 100 = identical).
///
/// Inverse of [`score_to_dissimilarity`]: `score = 100 * (1 - dissimilarity)`.
pub fn dissimilarity_to_score(dissimilarity: f64) -> f64 {
    (100.0 * (1.0 - dissimilarity)).clamp(0.0, 100.0)
}

/// What kind of perceptual difference dominates between source and distorted.
///
/// Only categories with provably defensible statistical signatures are offered.
/// If no category can be identified with high confidence, `Unclassified` is returned.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCategory {
    /// Images are perceptually identical (score ≈ 100).
    Identical,
    /// Max delta ≤ N/255 — integer rounding, LUT precision, truncation.
    RoundingError,
    /// One channel zero-delta, others large — RGB↔BGR swap.
    ChannelSwap,
    /// Alpha compositing error (e.g. straight/premul confusion, wrong background).
    AlphaCompositing,
    /// Images differ but no category reached sufficient confidence.
    Unclassified,
}

/// Decomposed error classification for a source/distorted pair.
///
/// `dominant` is the category with the highest confidence (or `Identical`
/// if the overall score is ≈ 100).
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ErrorClassification {
    /// The dominant error category.
    pub dominant: ErrorCategory,
    /// Overall confidence in the classification (0.0–1.0).
    pub confidence: f64,
    /// Rounding bias analysis (only populated when `dominant == RoundingError`).
    ///
    /// Measures how balanced the rounding errors are across positive and negative
    /// directions. `None` when not a rounding error or insufficient data.
    pub rounding_bias: Option<RoundingBias>,
}

/// Analysis of whether rounding errors are balanced (+/-) or systematic.
///
/// A balanced distribution (roughly equal +1 and -1 counts) indicates normal
/// rounding mode differences — nothing to worry about. A heavily skewed
/// distribution (mostly one direction) suggests systematic truncation or
/// a floor/ceil bias that may indicate a pipeline bug.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RoundingBias {
    /// Per-channel ratio of positive-to-total differing pixels.
    ///
    /// 0.5 = perfectly balanced, 0.0 = all negative, 1.0 = all positive.
    /// Channels: `[R, G, B]`.
    pub positive_fraction: [f64; 3],
    /// Whether the rounding appears balanced (within statistical norms).
    ///
    /// `true` means the +/- distribution is consistent with unbiased rounding
    /// and is likely nothing to worry about. `false` means systematic bias
    /// was detected (e.g., all errors in one direction = truncation).
    pub balanced: bool,
}

/// Pixel-level delta analysis for error classification.
///
/// All deltas are `src - dst` (positive = distorted is darker/lower).
/// Values normalized to [0.0, 1.0] regardless of input bit depth.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct DeltaStats {
    // --- Per-channel [R, G, B] summary stats ---
    /// Mean delta (signed). Positive = dst darker.
    pub mean_delta: [f64; 3],
    /// Standard deviation of delta.
    pub stddev_delta: [f64; 3],
    /// Maximum |delta|.
    pub max_abs_delta: [f64; 3],

    // --- Signed small-delta histogram ---
    /// Per-channel pixel counts for signed deltas -3 to +3 (in 1/native_max units).
    ///
    /// Index mapping: `[0]`=−3, `[1]`=−2, `[2]`=−1, `[3]`=0, `[4]`=+1, `[5]`=+2, `[6]`=+3.
    /// Delta convention: `src - dst`, so +1 means dst is 1 LSB lower than src.
    /// Only counts pixels whose per-channel delta falls in \[−3, +3\]; pixels
    /// outside this range are not tracked here.
    pub signed_small_histogram: [[u64; 7]; 3],

    /// Maximum representable value for the native pixel format.
    ///
    /// 255.0 for u8 formats, 65535.0 for u16, 1.0 for f32/f16.
    /// Used to interpret delta magnitudes at native precision.
    pub native_max: f64,

    // --- Pixel counts ---
    /// Total pixels compared.
    pub pixel_count: u64,
    /// Pixels where any channel differs.
    pub pixels_differing: u64,
    /// Pixels where any channel |delta| > 1/255.
    pub pixels_differing_by_more_than_1: u64,

    // --- Alpha channel ---
    /// Whether the input format has an alpha channel.
    pub has_alpha: bool,
    /// Max |src_alpha - dst_alpha| in 0-255 units. 0 for RGB-only formats.
    pub alpha_max_delta: u8,
    /// Pixels where alpha differs at all. 0 for RGB-only formats.
    pub alpha_pixels_differing: u64,

    // --- Per-channel value histograms (256 bins, quantized to 8-bit) ---
    /// Source image histogram. `[channel][value]`. R=0, G=1, B=2, A=3.
    pub src_histogram: [[u64; 256]; 4],
    /// Distorted image histogram. `[channel][value]`. R=0, G=1, B=2, A=3.
    pub dst_histogram: [[u64; 256]; 4],

    // --- Alpha-stratified stats (only for RGBA/BGRA inputs) ---
    /// Delta stats for fully opaque pixels (A = max).
    pub opaque_stats: Option<AlphaStratifiedStats>,
    /// Delta stats for semitransparent pixels (0 < A < max).
    pub semitransparent_stats: Option<AlphaStratifiedStats>,
    /// Pearson correlation between |delta| and (1 - alpha).
    /// High (> 0.8) = compositing/premul error. None if no alpha.
    pub alpha_error_correlation: Option<f64>,
}

/// Stats for a subset of pixels grouped by alpha.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct AlphaStratifiedStats {
    /// Number of pixels in this stratum.
    pub pixel_count: u64,
    /// Mean |delta| per channel in this alpha stratum.
    pub mean_abs_delta: [f64; 3],
    /// Max |delta| per channel.
    pub max_abs_delta: [f64; 3],
}

/// Result from `classify()`: the zensim score plus delta analysis and error classification.
#[cfg(feature = "classification")]
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ClassifiedResult {
    /// The standard zensim result (score, features, etc.).
    pub result: ZensimResult,
    /// Error classification with per-category confidence scores.
    pub classification: ErrorClassification,
    /// Pixel-level delta statistics.
    pub delta_stats: DeltaStats,
}

// --- Zensim config struct (primary API) ---

use crate::profile::{ProfileParams, ZensimProfile};
use crate::source::ImageSource;

/// Metric configuration. Methods on this struct are the primary API.
///
/// ```no_run
/// use zensim::{Zensim, ZensimProfile, RgbSlice};
/// # let (src, dst) = (vec![[0u8; 3]; 64], vec![[0u8; 3]; 64]);
/// let z = Zensim::new(ZensimProfile::latest());
/// let source = RgbSlice::new(&src, 8, 8);
/// let distorted = RgbSlice::new(&dst, 8, 8);
/// let result = z.compute(&source, &distorted).unwrap();
/// println!("{}: {:.2}", result.profile(), result.score());
/// ```
#[derive(Clone, Debug)]
pub struct Zensim {
    profile: ZensimProfile,
    parallel: bool,
}

impl Zensim {
    /// Create a new `Zensim` with the given profile. Parallel by default.
    pub fn new(profile: ZensimProfile) -> Self {
        Self {
            profile,
            parallel: true,
        }
    }

    /// Enable or disable multi-threaded computation (rayon).
    /// Default: `true`.
    pub fn with_parallel(mut self, parallel: bool) -> Self {
        self.parallel = parallel;
        self
    }

    /// Current profile.
    pub fn profile(&self) -> ZensimProfile {
        self.profile
    }

    /// Whether multi-threaded computation is enabled.
    pub fn parallel(&self) -> bool {
        self.parallel
    }

    /// Compare source and distorted images.
    ///
    /// # Errors
    ///
    /// Returns [`ZensimError`] if dimensions are mismatched or too small.
    pub fn compute(
        &self,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        let params = self.profile.params();
        validate_pair(source, distorted)?;
        let config = config_from_params(params, self.parallel);
        let result = compute_with_config_inner(source, distorted, &config, params.weights);
        Ok(result.with_profile(self.profile))
    }

    /// Pre-compute reference image data for batch comparison.
    ///
    /// # Errors
    ///
    /// Returns [`ZensimError::ImageTooSmall`] if dimensions < 8×8.
    pub fn precompute_reference(
        &self,
        source: &impl ImageSource,
    ) -> Result<crate::streaming::PrecomputedReference, ZensimError> {
        let params = self.profile.params();
        if source.width() < 8 || source.height() < 8 {
            return Err(ZensimError::ImageTooSmall);
        }
        Ok(crate::streaming::PrecomputedReference::new(
            source,
            params.num_scales,
            self.parallel,
        ))
    }

    /// Compare a distorted image against a precomputed reference.
    ///
    /// # Errors
    ///
    /// Returns [`ZensimError::ImageTooSmall`] if dimensions < 8×8.
    pub fn compute_with_ref(
        &self,
        precomputed: &crate::streaming::PrecomputedReference,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        let params = self.profile.params();
        if distorted.width() < 8 || distorted.height() < 8 {
            return Err(ZensimError::ImageTooSmall);
        }
        let config = config_from_params(params, self.parallel);
        let result = crate::streaming::compute_zensim_streaming_with_ref(
            precomputed,
            distorted,
            &config,
            params.weights,
        );
        Ok(result.with_profile(self.profile))
    }

    /// Precompute reference from planar linear RGB f32 data.
    ///
    /// `planes` are `[R, G, B]`, each with at least `stride * height` elements.
    /// `stride` is the number of f32 elements per row (≥ `width`; may be larger
    /// for padded buffers like the encoder's `padded_width`).
    ///
    /// This avoids the interleave-to-RGBA overhead when the caller already has
    /// separate channel buffers in linear light.
    ///
    /// # Errors
    ///
    /// Returns [`ZensimError::ImageTooSmall`] if dimensions < 8×8.
    pub fn precompute_reference_linear_planar(
        &self,
        planes: [&[f32]; 3],
        width: usize,
        height: usize,
        stride: usize,
    ) -> Result<crate::streaming::PrecomputedReference, ZensimError> {
        let params = self.profile.params();
        if width < 8 || height < 8 {
            return Err(ZensimError::ImageTooSmall);
        }
        Ok(crate::streaming::PrecomputedReference::from_linear_planar(
            planes,
            width,
            height,
            stride,
            params.num_scales,
            self.parallel,
        ))
    }

    /// Like `compute`, but always computes all features regardless of
    /// zero weights (forces every channel active). For training/research.
    #[cfg(feature = "training")]
    pub fn compute_all_features(
        &self,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        let params = self.profile.params();
        validate_pair(source, distorted)?;
        let mut config = config_from_params(params, self.parallel);
        config.compute_all_features = true;
        let result = compute_with_config_inner(source, distorted, &config, params.weights);
        Ok(result.with_profile(self.profile))
    }
}

#[cfg(feature = "classification")]
impl Zensim {
    /// Compare source and distorted images with full error classification.
    ///
    /// Returns a [`ClassifiedResult`] containing the standard zensim score,
    /// pixel-level delta statistics, and error type classification.
    ///
    /// The `result.score()` is identical to what `compute()` returns — classification
    /// is a separate analysis pass that doesn't affect the score.
    ///
    /// # Errors
    ///
    /// Returns [`ZensimError`] if dimensions are mismatched or too small.
    pub fn classify(
        &self,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ClassifiedResult, ZensimError> {
        validate_pair(source, distorted)?;

        // Compute delta stats (pixel-level analysis in sRGB space)
        let delta_stats = crate::streaming::compute_delta_stats(source, distorted);

        // Compute the standard zensim score
        let result = self.compute(source, distorted)?;

        // Derive classification from delta stats and zensim features
        let classification = derive_classification(&delta_stats, &result);

        Ok(ClassifiedResult {
            result,
            classification,
            delta_stats,
        })
    }
}

#[cfg(feature = "training")]
impl Zensim {
    /// Compute with explicit custom params (for training).
    pub fn compute_with_params(
        params: &ProfileParams,
        source: &impl ImageSource,
        distorted: &impl ImageSource,
    ) -> Result<ZensimResult, ZensimError> {
        validate_pair(source, distorted)?;
        let config = config_from_params(params, true);
        let result = compute_with_config_inner(source, distorted, &config, params.weights);
        Ok(result)
    }
}

/// Derive error classification from pixel-level delta statistics.
///
/// Uses only 3 provable detectors with mathematically defensible signatures:
/// 1. **RoundingError** — max delta ≤ 3/255, based on `pixels_differing_by_more_than_1`
/// 2. **ChannelSwap** — one zero-delta channel with large deltas in others
/// 3. **AlphaCompositing** — opaque unchanged, semitransparent changed (tightened)
///
/// No `Mixed` category — highest score wins, or `Unclassified`.
#[cfg(feature = "classification")]
fn derive_classification(delta_stats: &DeltaStats, _result: &ZensimResult) -> ErrorClassification {
    let mut rounding_bias: Option<RoundingBias> = None;

    // Track per-detector scores internally
    let mut score_rounding = 0.0f64;
    let mut score_swap = 0.0f64;
    let mut score_alpha = 0.0f64;

    // If images are identical, short circuit
    if delta_stats.pixels_differing == 0 {
        return ErrorClassification {
            dominant: ErrorCategory::Identical,
            confidence: 1.0,
            rounding_bias: None,
        };
    }

    let max_delta = delta_stats
        .max_abs_delta
        .iter()
        .copied()
        .fold(0.0f64, f64::max);

    // === 1. Rounding error: based on max_delta + pixels_differing_by_more_than_1 ===
    //
    // If no pixel in any channel exceeds 1/255 delta, this is provably off-by-1.
    // The only operations that produce max_delta ≤ 3/255 are: integer rounding
    // mode differences, sRGB LUT precision, float→int truncation.
    if delta_stats.pixels_differing_by_more_than_1 == 0 {
        score_rounding = 1.0;
    } else if max_delta <= 2.0 / 255.0 {
        score_rounding = 0.95;
    } else if max_delta <= 3.0 / 255.0 {
        score_rounding = 0.9;
    }

    // === 2. Channel swap: one zero-delta channel, others large ===
    //
    // The only way to get one channel with zero delta and others with large
    // deltas is a channel swap. No other operation produces this pattern.
    let mut zero_channels = 0u32;
    let mut hot_channels = 0u32;
    for ch in 0..3 {
        if delta_stats.max_abs_delta[ch] < 1.0 / 255.0 {
            zero_channels += 1;
        }
        if delta_stats.max_abs_delta[ch] > 0.1 {
            hot_channels += 1;
        }
    }
    if zero_channels == 1 && hot_channels >= 1 && max_delta > 0.05 {
        score_swap = 0.9;
    }

    // === 3. Alpha compositing: tightened thresholds ===
    //
    // Stratification: opaque pixels unchanged, semitransparent changed.
    // Tightened from 0.01→0.005 opaque threshold, 0.7→0.8 correlation threshold.
    if let Some(ref opaque) = delta_stats.opaque_stats
        && let Some(ref semi) = delta_stats.semitransparent_stats
    {
        let opaque_max = opaque.mean_abs_delta.iter().copied().fold(0.0f64, f64::max);
        let semi_mean = semi.mean_abs_delta.iter().copied().fold(0.0f64, f64::max);
        if opaque_max < 0.005 && semi_mean > 0.02 && semi.pixel_count > 100 {
            score_alpha = 0.9;
        }
    }
    if let Some(corr) = delta_stats.alpha_error_correlation
        && corr > 0.8
    {
        score_alpha = score_alpha.max(corr);
    }

    // === Determine dominant category ===
    // Highest score wins. No Mixed category.
    let scores = [
        (ErrorCategory::RoundingError, score_rounding),
        (ErrorCategory::ChannelSwap, score_swap),
        (ErrorCategory::AlphaCompositing, score_alpha),
    ];

    let (best_cat, best_score) = scores
        .iter()
        .copied()
        .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
        .unwrap();

    let (dominant, confidence) = if best_score > 0.0 {
        (best_cat, best_score)
    } else {
        (ErrorCategory::Unclassified, 0.0)
    };

    // === Rounding bias analysis ===
    // When RoundingError is detected, analyze the signed small-delta histogram
    // to determine if errors are balanced (+/-) or systematic (one direction).
    if dominant == ErrorCategory::RoundingError {
        rounding_bias = Some(compute_rounding_bias(delta_stats));
    }

    ErrorClassification {
        dominant,
        confidence,
        rounding_bias,
    }
}

/// Compute rounding bias from the signed small-delta histogram.
///
/// Examines the +1/-1, +2/-2, +3/-3 bins per channel to determine whether
/// errors are balanced (unbiased rounding) or systematic (truncation/floor).
#[cfg(feature = "classification")]
fn compute_rounding_bias(delta_stats: &DeltaStats) -> RoundingBias {
    let h = &delta_stats.signed_small_histogram;
    let mut positive_fraction = [0.5f64; 3];
    let mut all_balanced = true;

    for ch in 0..3 {
        // Count positive deltas (+1, +2, +3) and negative deltas (-1, -2, -3)
        let neg = h[ch][0] + h[ch][1] + h[ch][2]; // bins -3, -2, -1
        let pos = h[ch][4] + h[ch][5] + h[ch][6]; // bins +1, +2, +3
        let total_nonzero = neg + pos;

        if total_nonzero == 0 {
            // No differing pixels in this channel — perfectly balanced
            positive_fraction[ch] = 0.5;
            continue;
        }

        positive_fraction[ch] = pos as f64 / total_nonzero as f64;

        // Statistical test: for balanced rounding, we'd expect ~50% positive.
        // With N trials and p=0.5, the standard deviation is sqrt(N)/2.
        // Use a 3-sigma threshold: if |pos_frac - 0.5| > 3 * 0.5 / sqrt(N),
        // consider it unbalanced. But also require a minimum absolute skew
        // (at least 60/40 split) to avoid flagging trivially small deviations
        // in large samples.
        let n = total_nonzero as f64;
        let expected_std = 0.5 / n.sqrt();
        let deviation = (positive_fraction[ch] - 0.5).abs();
        if deviation > 3.0 * expected_std && deviation > 0.1 {
            all_balanced = false;
        }
    }

    RoundingBias {
        positive_fraction,
        balanced: all_balanced,
    }
}

pub(crate) fn validate_pair(
    source: &impl ImageSource,
    distorted: &impl ImageSource,
) -> Result<(), ZensimError> {
    if source.width() < 8 || source.height() < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if source.width() != distorted.width() || source.height() != distorted.height() {
        return Err(ZensimError::DimensionMismatch);
    }
    Ok(())
}

/// Check if source and distorted images have byte-identical pixel data
/// and matching color interpretation (format + primaries).
fn images_byte_identical(source: &impl ImageSource, distorted: &impl ImageSource) -> bool {
    use crate::source::{AlphaMode, PixelFormat};

    let (w, h) = (source.width(), source.height());
    if w != distorted.width() || h != distorted.height() {
        return false;
    }
    if source.pixel_format() != distorted.pixel_format() {
        return false;
    }
    // Different primaries mean different perceptual colors even with identical bytes.
    if source.color_primaries() != distorted.color_primaries() {
        return false;
    }
    let fmt = source.pixel_format();
    let bpp = fmt.bytes_per_pixel();
    let row_len = w * bpp;

    // For RGBA formats with non-opaque alpha: pixels where both have A=0
    // composite to the same background, so they're visually identical
    // regardless of their RGB values.
    let alpha_aware = fmt.has_alpha()
        && !matches!(source.alpha_mode(), AlphaMode::Opaque)
        && !matches!(distorted.alpha_mode(), AlphaMode::Opaque);

    for y in 0..h {
        let sr = source.row_bytes(y);
        let dr = distorted.row_bytes(y);
        if sr[..row_len] == dr[..row_len] {
            continue; // fast path: row is byte-identical
        }
        if !alpha_aware {
            return false;
        }
        // Slow path: check pixel-by-pixel, skipping A=0 pairs
        match fmt {
            PixelFormat::Srgb8Rgba | PixelFormat::Srgb8Bgra => {
                for x in 0..w {
                    let o = x * 4;
                    if sr[o + 3] == 0 && dr[o + 3] == 0 {
                        continue;
                    }
                    if sr[o..o + 4] != dr[o..o + 4] {
                        return false;
                    }
                }
            }
            PixelFormat::Srgb16Rgba => {
                for x in 0..w {
                    let o = x * 8;
                    let sa = u16::from_ne_bytes([sr[o + 6], sr[o + 7]]);
                    let da = u16::from_ne_bytes([dr[o + 6], dr[o + 7]]);
                    if sa == 0 && da == 0 {
                        continue;
                    }
                    if sr[o..o + 8] != dr[o..o + 8] {
                        return false;
                    }
                }
            }
            PixelFormat::LinearF32Rgba => {
                for x in 0..w {
                    let o = x * 16;
                    let sa = f32::from_ne_bytes([sr[o + 12], sr[o + 13], sr[o + 14], sr[o + 15]]);
                    let da = f32::from_ne_bytes([dr[o + 12], dr[o + 13], dr[o + 14], dr[o + 15]]);
                    if sa <= 0.0 && da <= 0.0 {
                        continue;
                    }
                    if sr[o..o + 16] != dr[o..o + 16] {
                        return false;
                    }
                }
            }
            _ => return false,
        }
    }
    true
}

fn compute_with_config_inner(
    source: &impl ImageSource,
    distorted: &impl ImageSource,
    config: &ZensimConfig,
    weights: &[f64],
) -> ZensimResult {
    // Identical images must score exactly 100.0 — short-circuit before
    // floating-point arithmetic introduces sub-ULP noise in SSIM/edge features.
    if images_byte_identical(source, distorted) {
        let fpc = if config.extended_features {
            FEATURES_PER_CHANNEL_EXTENDED
        } else {
            FEATURES_PER_CHANNEL_WITH_PEAKS
        };
        let num_features = config.num_scales * 3 * fpc;
        return ZensimResult::new(
            100.0,
            0.0,
            vec![0.0; num_features],
            ZensimProfile::latest(),
            [0.0; 3],
        );
    }

    crate::streaming::compute_zensim_streaming(source, distorted, config, weights)
}

pub(crate) fn config_from_params(params: &ProfileParams, parallel: bool) -> ZensimConfig {
    ZensimConfig {
        blur_radius: params.blur_radius,
        blur_passes: params.blur_passes,
        blur_kernel: BlurKernel::Box {
            passes: params.blur_passes,
        },
        downscale_filter: DownscaleFilter::default(),
        compute_all_features: false,
        extended_features: false,
        extended_masking_strength: 4.0,
        num_scales: params.num_scales,
        score_mapping_a: params.score_mapping_a,
        score_mapping_b: params.score_mapping_b,
        allow_multithreading: parallel,
    }
}

/// Features per channel per scale: 19 features always emitted.
///
/// ```text
///  Index  Name             Pooling  Source
///  ─────  ───────────────  ───────  ──────────────────
///   0     ssim_mean        mean     SSIM error map
///   1     ssim_4th         L4       SSIM error map
///   2     ssim_2nd         L2       SSIM error map
///   3     art_mean         mean     edge artifact (ringing)
///   4     art_4th          L4       edge artifact
///   5     art_2nd          L2       edge artifact
///   6     det_mean         mean     edge detail lost (blur)
///   7     det_4th          L4       edge detail lost
///   8     det_2nd          L2       edge detail lost
///   9     mse              mean     (src - dst)²
///  10     hf_energy_loss   ratio    1 - Σ(dst-mu)²/Σ(src-mu)²
///  11     hf_mag_loss      ratio    1 - Σ|dst-mu|/Σ|src-mu|
///  12     hf_energy_gain   ratio    Σ(dst-mu)²/Σ(src-mu)² - 1
///  13     ssim_max         max      per-pixel SSIM error
///  14     art_max          max      per-pixel edge artifact
///  15     det_max          max      per-pixel edge detail_lost
///  16     ssim_l8          L8       (Σd⁸/N)^(1/8) SSIM error
///  17     art_l8           L8       (Σd⁸/N)^(1/8) edge artifact
///  18     det_l8           L8       (Σd⁸/N)^(1/8) edge detail_lost
/// ```
///
/// Total features = `num_scales × 3 channels × 13` = 156 at 4 scales.
///
/// Note: 6 additional "peak" features (max/l8) are always computed
/// but only included when `compute_all_features` is true. This keeps
/// the default feature vector compatible with existing profiles.
pub const FEATURES_PER_CHANNEL_BASIC: usize = 13;

/// Features per channel when `compute_all_features` is true: 19 features
/// (13 basic + 6 peak/l8). Peak features are always computed (near-zero cost)
/// but excluded from the default feature vector for profile compatibility.
pub const FEATURES_PER_CHANNEL_WITH_PEAKS: usize = 19;

/// Extended features per channel per scale: 25 features (19 with peaks + 6 masked).
///
/// ```text
///  Index  Name               Pooling  Source
///  ─────  ─────────────────  ───────  ──────────────────
///  0–12   (same as basic 13)
///  13–18  (same as peak features: max/l8)
///  19     masked_ssim_mean   mean     SSIM × flatness mask
///  20     masked_ssim_4th    L4       SSIM × flatness mask
///  21     masked_ssim_2nd    L2       SSIM × flatness mask
///  22     masked_art_4th     L4       edge artifact × flatness mask
///  23     masked_det_4th     L4       edge detail_lost × flatness mask
///  24     masked_mse         mean     (src-dst)² × flatness mask
/// ```
///
/// Total features = `num_scales × 3 channels × 25` = 300 at 4 scales.
pub const FEATURES_PER_CHANNEL_EXTENDED: usize = 25;

/// Named view over a flat feature vector.
///
/// Provides ergonomic access to features by name, scale, and channel
/// without changing the underlying storage format.
///
/// ```ignore
/// let result = z.compute_all_features(&src, &dst)?;
/// let view = FeatureView::new(result.features(), 4)?;
/// let ssim_mean_s0_y = view.ssim_mean(0, 1);
/// let ssim_max_s2_x = view.ssim_max(0, 2).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct FeatureView<'a> {
    features: &'a [f64],
    n_scales: usize,
    /// Number of features in the scored block
    scored_total: usize,
    /// Number of features in the peaks block (0 if not present)
    peaks_total: usize,
}

/// XYB channel index: X (red-green chrominance).
#[cfg(feature = "training")]
pub const CH_X: usize = 0;
/// XYB channel index: Y (luminance).
#[cfg(feature = "training")]
pub const CH_Y: usize = 1;
/// XYB channel index: B (blue-yellow chrominance).
#[cfg(feature = "training")]
pub const CH_B: usize = 2;

impl<'a> FeatureView<'a> {
    /// Create a view over a feature vector.
    ///
    /// Automatically detects the tier (peaks/extended) from length.
    /// Returns `None` if the length doesn't match any valid layout.
    /// Peaks are always present (basic-only 156-element vectors are no longer generated).
    pub fn new(features: &'a [f64], n_scales: usize) -> Option<Self> {
        let basic_total = n_scales * 3 * FEATURES_PER_CHANNEL_BASIC;
        let peaks_total = n_scales * 3 * 6;
        let masked_total = n_scales * 3 * 6;

        let (scored_total, peaks_total) = if features.len() == basic_total {
            // Legacy basic-only layout (backward compat)
            (basic_total, 0)
        } else if features.len() == basic_total + peaks_total
            || features.len() == basic_total + peaks_total + masked_total
        {
            (basic_total, peaks_total)
        } else {
            return None;
        };

        Some(Self {
            features,
            n_scales,
            scored_total,
            peaks_total,
        })
    }

    /// Number of scales in this feature vector.
    pub fn n_scales(&self) -> usize {
        self.n_scales
    }

    /// Whether peak features (max/L8) are present.
    pub fn has_peaks(&self) -> bool {
        self.peaks_total > 0
    }

    /// Whether masked features are present.
    pub fn has_masked(&self) -> bool {
        self.features.len() > self.scored_total + self.peaks_total
    }

    // --- Scored features (always present) ---

    fn scored_idx(&self, scale: usize, ch: usize, offset: usize) -> usize {
        scale * 3 * FEATURES_PER_CHANNEL_BASIC + ch * FEATURES_PER_CHANNEL_BASIC + offset
    }

    /// SSIM error, mean pooling.
    pub fn ssim_mean(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 0)]
    }
    /// SSIM error, L4 norm.
    pub fn ssim_4th(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 1)]
    }
    /// SSIM error, L2 norm.
    pub fn ssim_2nd(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 2)]
    }
    /// Edge artifact (ringing), mean pooling.
    pub fn art_mean(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 3)]
    }
    /// Edge artifact, L4 norm.
    pub fn art_4th(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 4)]
    }
    /// Edge artifact, L2 norm.
    pub fn art_2nd(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 5)]
    }
    /// Edge detail lost (blur), mean pooling.
    pub fn det_mean(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 6)]
    }
    /// Edge detail lost, L4 norm.
    pub fn det_4th(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 7)]
    }
    /// Edge detail lost, L2 norm.
    pub fn det_2nd(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 8)]
    }
    /// Mean squared error.
    pub fn mse(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 9)]
    }
    /// High-frequency energy loss ratio.
    pub fn hf_energy_loss(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 10)]
    }
    /// High-frequency magnitude loss ratio.
    pub fn hf_mag_loss(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 11)]
    }
    /// High-frequency energy gain ratio.
    pub fn hf_energy_gain(&self, scale: usize, ch: usize) -> f64 {
        self.features[self.scored_idx(scale, ch, 12)]
    }

    // --- Peak features (always present) ---

    fn peak_idx(&self, scale: usize, ch: usize, offset: usize) -> Option<usize> {
        if self.peaks_total == 0 {
            return None;
        }
        Some(self.scored_total + scale * 3 * 6 + ch * 6 + offset)
    }

    /// SSIM error, pixel-wise max.
    pub fn ssim_max(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 0).map(|i| self.features[i])
    }
    /// Edge artifact, pixel-wise max.
    pub fn art_max(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 1).map(|i| self.features[i])
    }
    /// Edge detail lost, pixel-wise max.
    pub fn det_max(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 2).map(|i| self.features[i])
    }
    /// SSIM error, L8 norm `(Σd⁸/N)^(1/8)`.
    pub fn ssim_l8(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 3).map(|i| self.features[i])
    }
    /// Edge artifact, L8 norm.
    pub fn art_l8(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 4).map(|i| self.features[i])
    }
    /// Edge detail lost, L8 norm.
    pub fn det_l8(&self, scale: usize, ch: usize) -> Option<f64> {
        self.peak_idx(scale, ch, 5).map(|i| self.features[i])
    }

    // --- Masked features (require extended_features) ---

    fn masked_idx(&self, scale: usize, ch: usize, offset: usize) -> Option<usize> {
        if !self.has_masked() {
            return None;
        }
        Some(self.scored_total + self.peaks_total + scale * 3 * 6 + ch * 6 + offset)
    }

    /// Masked SSIM error, mean pooling.
    pub fn masked_ssim_mean(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 0).map(|i| self.features[i])
    }
    /// Masked SSIM error, L4 norm.
    pub fn masked_ssim_4th(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 1).map(|i| self.features[i])
    }
    /// Masked SSIM error, L2 norm.
    pub fn masked_ssim_2nd(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 2).map(|i| self.features[i])
    }
    /// Masked edge artifact, L4 norm.
    pub fn masked_art_4th(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 3).map(|i| self.features[i])
    }
    /// Masked edge detail lost, L4 norm.
    pub fn masked_det_4th(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 4).map(|i| self.features[i])
    }
    /// Masked MSE.
    pub fn masked_mse(&self, scale: usize, ch: usize) -> Option<f64> {
        self.masked_idx(scale, ch, 5).map(|i| self.features[i])
    }

    /// Get the scored features slice (first N features, WEIGHTS-compatible).
    pub fn scored_features(&self) -> &[f64] {
        &self.features[..self.scored_total]
    }

    /// Get the peak features slice, if present.
    pub fn peak_features(&self) -> Option<&[f64]> {
        if self.peaks_total == 0 {
            None
        } else {
            Some(&self.features[self.scored_total..self.scored_total + self.peaks_total])
        }
    }

    /// Get the masked features slice, if present.
    pub fn masked_features(&self) -> Option<&[f64]> {
        if !self.has_masked() {
            None
        } else {
            Some(&self.features[self.scored_total + self.peaks_total..])
        }
    }
}

/// Compute zensim with custom configuration (training API).
///
/// Uses the v0.2 weights (latest general-purpose profile).
#[cfg(any(feature = "training", test))]
pub fn compute_zensim_with_config(
    source: &[[u8; 3]],
    distorted: &[[u8; 3]],
    width: usize,
    height: usize,
    config: ZensimConfig,
) -> Result<ZensimResult, ZensimError> {
    // Validation
    if width < 8 || height < 8 {
        return Err(ZensimError::ImageTooSmall);
    }
    if source.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    if distorted.len() != width * height {
        return Err(ZensimError::InvalidDataLength);
    }
    if source.len() != distorted.len() {
        return Err(ZensimError::DimensionMismatch);
    }

    // Identical images must score exactly 100.0 — short-circuit before
    // floating-point arithmetic introduces sub-ULP noise in SSIM/edge features.
    if source == distorted {
        let fpc = if config.extended_features {
            FEATURES_PER_CHANNEL_EXTENDED
        } else {
            FEATURES_PER_CHANNEL_WITH_PEAKS
        };
        let num_features = config.num_scales * 3 * fpc;
        return Ok(ZensimResult::new(
            100.0,
            0.0,
            vec![0.0; num_features],
            ZensimProfile::latest(),
            [0.0; 3],
        ));
    }

    let src_img = crate::source::RgbSlice::new(source, width, height);
    let dst_img = crate::source::RgbSlice::new(distorted, width, height);

    let result = crate::streaming::compute_zensim_streaming(&src_img, &dst_img, &config, WEIGHTS);
    Ok(result)
}

/// Combine per-scale statistics into a final score.
///
/// Uses learned weights that balance:
/// - Per-channel sensitivity (Y > X > B, matching human vision)
/// - Per-scale importance (medium scales most important)
/// - SSIM vs edge features
/// - Mean vs 4th-power pooling
///
/// Weights are trained against synthetic quality scores (see `weights/` directory).
/// Features per scale for the default scoring profile (3 channels × 13 features = 39).
#[cfg_attr(not(feature = "training"), allow(dead_code))]
pub const FEATURES_PER_SCALE: usize = FEATURES_PER_CHANNEL_WITH_PEAKS * 3;

/// Default scoring weights — references the latest profile weights.
///
/// Layout: 4 scales × 3 channels (X,Y,B) × 13 basic features, then
///         4 scales × 3 channels × 6 peak features = 228 total.
#[cfg(any(feature = "training", test))]
pub const WEIGHTS: &[f64; 228] = &crate::profile::WEIGHTS_PREVIEW_V0_2;

pub(crate) fn combine_scores(
    scale_stats: &[ScaleStats],
    weights: &[f64],
    config: &ZensimConfig,
    mean_offset: [f64; 3],
) -> ZensimResult {
    let extended = config.extended_features;

    // Feature vector layout:
    //   [0..N_basic)        — 13/ch × 3ch × n_scales (basic features)
    //   [N_basic..N_peaks)  — 6/ch × 3ch × n_scales peak features (always included)
    //   [N_peaks..N_all)    — 6/ch × 3ch × n_scales masked features (if extended)
    //
    // Both basic and peak features are scored: features[0..WEIGHTS.len()]
    // produces the dot product used for the final score.
    let n_scales = scale_stats.len();
    let basic_per_ch = FEATURES_PER_CHANNEL_BASIC; // 13
    let basic_total = n_scales * basic_per_ch * 3;
    let peak_total = n_scales * 6 * 3;
    let masked_total = if extended { n_scales * 6 * 3 } else { 0 };
    let total = basic_total + peak_total + masked_total;

    let mut features = Vec::with_capacity(total);
    let mut raw_distance = 0.0f64;

    // Pass 1: scored features (13/ch, weight-compatible order)
    for ss in scale_stats.iter() {
        for c in 0..3 {
            features.push(ss.ssim[c * 2].abs());
            features.push(ss.ssim[c * 2 + 1].abs());
            features.push(ss.ssim_2nd[c].abs());
            features.push(ss.edge[c * 4].abs());
            features.push(ss.edge[c * 4 + 1].abs());
            features.push(ss.edge_2nd[c * 2].abs());
            features.push(ss.edge[c * 4 + 2].abs());
            features.push(ss.edge[c * 4 + 3].abs());
            features.push(ss.edge_2nd[c * 2 + 1].abs());
            features.push(ss.mse[c]);
            features.push(ss.hf_energy_loss[c]);
            features.push(ss.hf_mag_loss[c]);
            features.push(ss.hf_energy_gain[c]);
        }
    }

    // Pass 2: peak features (6/ch — max + L8, always computed at near-zero cost)
    for ss in scale_stats.iter() {
        for c in 0..3 {
            features.push(ss.ssim_max[c]);
            features.push(ss.art_max[c]);
            features.push(ss.det_max[c]);
            features.push(ss.ssim_p95[c]);
            features.push(ss.art_p95[c]);
            features.push(ss.det_p95[c]);
        }
    }

    // Pass 3: masked features (6/ch — expensive, training only)
    if extended {
        for ss in scale_stats.iter() {
            for c in 0..3 {
                features.push(ss.masked_ssim[c * 3].abs());
                features.push(ss.masked_ssim[c * 3 + 1].abs());
                features.push(ss.masked_ssim[c * 3 + 2].abs());
                features.push(ss.masked_art_4th[c].abs());
                features.push(ss.masked_det_4th[c].abs());
                features.push(ss.masked_mse[c]);
            }
        }
    }

    // Apply weights — basic + peak features are scored
    let scored_total = basic_total + peak_total;
    let n_score = scored_total.min(weights.len());
    for (i, &feat) in features[..n_score].iter().enumerate() {
        raw_distance += feat * weights[i];
    }

    // Normalize by number of scales
    raw_distance /= scale_stats.len().max(1) as f64;

    let score =
        distance_to_score_mapped(raw_distance, config.score_mapping_a, config.score_mapping_b);

    ZensimResult::new(
        score,
        raw_distance,
        features,
        ZensimProfile::PreviewV0_1,
        mean_offset,
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Verify compute_all_features produces same score as default (weight-skipped) path.
    /// This exercises the multi-SSIM channel code path where ssim_chs.len() > 1.
    #[test]
    fn compute_all_matches_default() {
        // Generate a simple test pattern: gradient source, slightly different distorted
        let w = 128;
        let h = 128;
        let n = w * h;
        let mut src = vec![[128u8, 128, 128]; n];
        let mut dst = vec![[128u8, 128, 128]; n];
        for y in 0..h {
            for x in 0..w {
                let r = ((x * 255) / w) as u8;
                let g = ((y * 255) / h) as u8;
                let b = 128;
                src[y * w + x] = [r, g, b];
                // Slight distortion
                dst[y * w + x] = [r.saturating_add(5), g, b.saturating_sub(3)];
            }
        }

        let default_result =
            compute_zensim_with_config(&src, &dst, w, h, ZensimConfig::default()).unwrap();
        let all_result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();

        // Same score (default weights skip zero-weight channels; compute_all computes them
        // but zero weights still produce same weighted distance)
        assert!(
            (default_result.score - all_result.score).abs() < 0.01,
            "default {} vs all_features {}",
            default_result.score,
            all_result.score,
        );

        // Both default and compute_all now include peak features (228)
        assert_eq!(all_result.features.len(), 228);
        assert_eq!(default_result.features.len(), 228);
        // With compute_all, previously-skipped channels should now have nonzero features
        let all_nonzero = all_result
            .features
            .iter()
            .filter(|f| f.abs() > 1e-12)
            .count();
        let default_nonzero = default_result
            .features
            .iter()
            .filter(|f| f.abs() > 1e-12)
            .count();
        assert!(
            all_nonzero >= default_nonzero,
            "compute_all should have >= features: {} vs {}",
            all_nonzero,
            default_nonzero,
        );
    }

    /// Helper: create a gradient test image pair.
    fn make_gradient_pair(w: usize, h: usize) -> (Vec<[u8; 3]>, Vec<[u8; 3]>) {
        let n = w * h;
        let mut src = vec![[128u8, 128, 128]; n];
        let mut dst = vec![[128u8, 128, 128]; n];
        for y in 0..h {
            for x in 0..w {
                let r = ((x * 255) / w) as u8;
                let g = ((y * 255) / h) as u8;
                let b = 128;
                src[y * w + x] = [r, g, b];
                dst[y * w + x] = [
                    r.saturating_add(10),
                    g.saturating_sub(5),
                    b.saturating_add(3),
                ];
            }
        }
        (src, dst)
    }

    /// Extended features: default config produces same score as non-extended.
    #[test]
    fn extended_features_backward_compat() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);

        let basic = compute_zensim_with_config(&src, &dst, w, h, ZensimConfig::default()).unwrap();

        let extended = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: false,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();

        // Both produce 228 features now (peaks always included)
        assert_eq!(basic.features.len(), 228);
        assert_eq!(extended.features.len(), 228);
        // Score should be the same — compute_all forces all channels active but result is same
        assert!(
            (basic.score - extended.score).abs() < 0.01,
            "basic {} vs compute_all {}",
            basic.score,
            extended.score,
        );
    }

    /// Extended features produce 300 values and all are non-negative.
    #[test]
    fn extended_features_count_and_nonneg() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);

        let result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();

        assert_eq!(
            result.features.len(),
            300,
            "Expected 25 × 3 × 4 = 300 features"
        );
        for (i, &f) in result.features.iter().enumerate() {
            assert!(f >= 0.0, "Feature {} is negative: {}", i, f);
        }
    }

    /// ssim_max >= ssim_4th >= ssim_mean ordering.
    #[test]
    fn extended_features_ordering() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);

        let result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();

        // Feature layout (block-separated):
        //   [0..156)   scored: 13/ch × 3ch × 4 scales
        //   [156..228)  peaks: 6/ch × 3ch × 4 scales
        //   [228..300) masked: 6/ch × 3ch × 4 scales
        let scored_per_ch = FEATURES_PER_CHANNEL_BASIC; // 13
        let peaks_offset = 4 * scored_per_ch * 3; // 156
        let peaks_per_ch = 6;
        for scale in 0..4 {
            for ch in 0..3 {
                let scored_base = scale * scored_per_ch * 3 + ch * scored_per_ch;
                let peaks_base = peaks_offset + scale * peaks_per_ch * 3 + ch * peaks_per_ch;
                let ssim_mean = result.features[scored_base]; // scored[0]
                let ssim_4th = result.features[scored_base + 1]; // scored[1]
                let ssim_max = result.features[peaks_base]; // peaks[0]
                let ssim_p95 = result.features[peaks_base + 3]; // peaks[3]

                // max >= 4th >= mean (4th is L4 norm, always >= mean for non-negative values)
                assert!(
                    ssim_max >= ssim_4th - 1e-10,
                    "s{} c{}: max {:.6} < 4th {:.6}",
                    scale,
                    ch,
                    ssim_max,
                    ssim_4th,
                );
                assert!(
                    ssim_4th >= ssim_mean - 1e-10,
                    "s{} c{}: 4th {:.6} < mean {:.6}",
                    scale,
                    ch,
                    ssim_4th,
                    ssim_mean,
                );
                // p95 between 4th and max
                assert!(
                    ssim_p95 <= ssim_max + 1e-10,
                    "s{} c{}: p95 {:.6} > max {:.6}",
                    scale,
                    ch,
                    ssim_p95,
                    ssim_max,
                );
            }
        }
    }

    /// Identical images: all features zero.
    #[test]
    fn extended_features_identical_zero() {
        let (w, h) = (64, 64);
        let (src, _) = make_gradient_pair(w, h);

        let result = compute_zensim_with_config(
            &src,
            &src,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();

        assert_eq!(result.score, 100.0);
        assert_eq!(result.features.len(), 300);
        for (i, &f) in result.features.iter().enumerate() {
            assert!(
                f.abs() < 1e-10,
                "Feature {} not zero for identical: {}",
                i,
                f
            );
        }
    }

    /// Masked features <= unmasked features (masking reduces).
    #[test]
    fn extended_masked_leq_unmasked() {
        let (w, h) = (64, 64);
        let (src, dst) = make_gradient_pair(w, h);

        let result = compute_zensim_with_config(
            &src,
            &dst,
            w,
            h,
            ZensimConfig {
                extended_features: true,
                compute_all_features: true,
                ..Default::default()
            },
        )
        .unwrap();

        // Feature layout (block-separated):
        //   [0..156)   scored: 13/ch × 3ch × 4 scales
        //   [156..228)  peaks: 6/ch × 3ch × 4 scales
        //   [228..300) masked: 6/ch × 3ch × 4 scales
        let scored_per_ch = FEATURES_PER_CHANNEL_BASIC; // 13
        let masked_offset = 4 * scored_per_ch * 3 + 4 * 6 * 3; // 156 + 72 = 228
        let masked_per_ch = 6;
        for scale in 0..4 {
            for ch in 0..3 {
                let scored_base = scale * scored_per_ch * 3 + ch * scored_per_ch;
                let masked_base = masked_offset + scale * masked_per_ch * 3 + ch * masked_per_ch;
                let ssim_mean = result.features[scored_base]; // scored[0]
                let ssim_4th = result.features[scored_base + 1]; // scored[1]
                let ssim_2nd = result.features[scored_base + 2]; // scored[2]
                let masked_ssim_mean = result.features[masked_base]; // masked[0]
                let masked_ssim_4th = result.features[masked_base + 1]; // masked[1]
                let masked_ssim_2nd = result.features[masked_base + 2]; // masked[2]

                // Masked values should be <= unmasked (mask weights ∈ [0,1])
                assert!(
                    masked_ssim_mean <= ssim_mean + 1e-10,
                    "s{} c{}: masked_mean {:.6} > mean {:.6}",
                    scale,
                    ch,
                    masked_ssim_mean,
                    ssim_mean,
                );
                assert!(
                    masked_ssim_4th <= ssim_4th + 1e-10,
                    "s{} c{}: masked_4th {:.6} > 4th {:.6}",
                    scale,
                    ch,
                    masked_ssim_4th,
                    ssim_4th,
                );
                assert!(
                    masked_ssim_2nd <= ssim_2nd + 1e-10,
                    "s{} c{}: masked_2nd {:.6} > 2nd {:.6}",
                    scale,
                    ch,
                    masked_ssim_2nd,
                    ssim_2nd,
                );
            }
        }
    }
}