tranz 0.5.1

Point-embedding knowledge graph models: TransE, RotatE, ComplEx, DistMult. GPU training via candle.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
//! Training loop for KGE models via candle.
//!
//! Implements:
//! - Negative sampling with configurable corruption strategy
//! - Self-adversarial negative sampling (SANS) weighting
//! - Log-sigmoid loss with margin
//! - N3 regularization (nuclear 3-norm, for ComplEx/DistMult)
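//! - AdamW or Adagrad optimizer (see `OptimizerType`)
//! - Optional 1-N scoring of every candidate entity per query (see `TrainConfig::one_to_n`)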
//!
//! ## Training protocol (negative-sampling mode)
//!
//! For each batch of positive triples `(h, r, t)`:
//! 1. Sample `k` negative triples by corrupting head or tail.
//! 2. Score positives and negatives.
//! 3. Weight negatives by SANS: `p_i = softmax(alpha * score_i)` (detached).
//! 4. Loss = `-log(sigma(gamma - score_pos)) - sum_i p_i * log(sigma(score_neg_i - gamma))`.
//! 5. Optionally add N3 regularization.
//! 6. Backward + optimizer step.

use candle_core::{DType, Device, IndexOp, Result, Tensor, Var, D};
use candle_nn::optim::{AdamW, Optimizer, ParamsAdamW};

/// Gradient-descent algorithm used to update the embedding tables.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum OptimizerType {
    /// AdamW: the default, a solid general-purpose choice.
    AdamW,
    /// Adagrad: well suited to sparse embedding updates (shown effective for
    /// KGE by Lacroix et al., 2018). Pair it with a larger learning rate
    /// (around 0.1) and a small initialization scale (around 1e-3).
    Adagrad,
}

/// Minimal Adagrad optimizer that updates candle `Var`s in place.
///
/// Every parameter `theta` owns an accumulator `G` of squared gradients.
/// A step with gradient `g` performs `G += g^2`, then
/// `theta -= lr * g / (sqrt(G) + eps)` using the freshly updated `G`.
struct Adagrad {
    /// Parameters updated by this optimizer.
    vars: Vec<Var>,
    /// Squared-gradient accumulators, index-aligned with `vars`.
    sum_sq: Vec<Var>,
    /// Step size; changed through `set_learning_rate`.
    lr: f64,
    /// Added to `sqrt(G)` so the division stays finite.
    eps: f64,
}

impl Adagrad {
    /// Builds the optimizer with one zero-filled accumulator per parameter,
    /// matching each parameter's shape, dtype and device.
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        let mut sum_sq = Vec::with_capacity(vars.len());
        for param in &vars {
            sum_sq.push(Var::zeros(param.shape(), param.dtype(), param.device())?);
        }
        Ok(Self {
            vars,
            sum_sq,
            lr,
            eps: 1e-10,
        })
    }

    /// Back-propagates `loss` and applies one Adagrad update to every
    /// parameter that received a gradient.
    fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
        let grads = loss.backward()?;
        for (param, accum) in self.vars.iter().zip(&self.sum_sq) {
            // Parameters that did not contribute to `loss` have no gradient.
            let g = match grads.get(param) {
                Some(g) => g,
                None => continue,
            };
            let accum_next = (accum.as_tensor() + g.sqr()?)?;
            let denom = (accum_next.sqrt()? + self.eps)?;
            let step = ((g / denom)? * self.lr)?;
            param.set(&(param.as_tensor() - step)?)?;
            accum.set(&accum_next)?;
        }
        Ok(())
    }

    /// Replaces the step size used by subsequent updates.
    fn set_learning_rate(&mut self, lr: f64) {
        self.lr = lr;
    }
}

/// Knowledge-graph embedding architecture to train.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ModelType {
    /// TransE: translation distance `||h + r - t||`.
    TransE,
    /// RotatE: rotation distance `||h * r - t||` with complex `h`, `t` and unit-modulus `r`.
    RotatE,
    /// ComplEx: real part of the complex trilinear product `Re(h * r * conj(t))`.
    ComplEx,
    /// DistMult: real trilinear product `sum(h * r * t)`.
    DistMult,
}

/// Training configuration.
///
/// Start from `TrainConfig::default()` and override fields with struct-update
/// syntax, e.g. `TrainConfig { model_type: ModelType::ComplEx, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct TrainConfig {
    /// Model architecture to train.
    pub model_type: ModelType,
    /// Optimizer type.
    pub optimizer: OptimizerType,
    /// Embedding dimension. For RotatE and ComplEx this is the *complex*
    /// dimension: each entity row stores `2 * dim` floats (the `dim` real parts
    /// followed by the `dim` imaginary parts).
    pub dim: usize,
    /// Standard deviation of the normal initializer: embeddings are drawn from
    /// `N(0, init_scale)`. RotatE relation phases do not use it; they are drawn
    /// uniformly from `[-pi, pi]`.
    /// Default: 1e-3 (Lacroix 2018). Xavier uses ~0.17 for dim=200.
    pub init_scale: f64,
    /// Number of negative samples per positive (ignored in 1-N mode).
    pub num_negatives: usize,
    /// Use 1-N scoring: score all entities per (h,r) query with BCE loss.
    /// Much faster convergence (5-10x fewer epochs) than negative sampling.
    /// Requires more memory per batch: `batch_size * num_entities * 4` bytes.
    pub one_to_n: bool,
    /// Label smoothing epsilon for 1-N mode. 0 = no smoothing.
    /// Recommended: 0.1. Replaces hard 0/1 targets with (eps, 1-eps).
    pub label_smoothing: f32,
    /// Use multi-hot targets in 1-N mode (KvsAll). If false, uses single-target
    /// CE (1vsAll, as in Lacroix et al. 2018). Default: false (1vsAll).
    pub multi_hot: bool,
    /// Margin gamma for the loss function (used in negative sampling mode).
    pub gamma: f32,
    /// Distance norm for TransE and RotatE in per-triple scoring
    /// (`TrainableModel::score_batch`): `1` selects L1, and *any other value*
    /// selects L2. The 1-N scorers (`TrainableModel::score_1n` and
    /// `score_1n_heads`) always use squared L2 and ignore this setting.
    /// The RotatE reference implementation uses L1 for TransE.
    pub distance_norm: u32,
    /// Apply subsampling weights based on entity frequency.
    /// Downweights triples involving high-frequency entities.
    /// Helps at convergence but can hurt during early training.
    pub subsampling: bool,
    /// SANS adversarial temperature. 0 = uniform weighting.
    pub adversarial_temperature: f32,
    /// Learning rate.
    pub lr: f64,
    /// Dropout rate on gathered entity/relation embeddings. 0 = no dropout.
    /// Applied inside `TrainableModel::score_batch`. The GEMM-based
    /// `score_1n` / `score_1n_heads` do not apply it themselves.
    /// Recommended: 0.1-0.2 (Ruffinelli et al. 2020).
    pub embedding_dropout: f32,
    /// N3 regularization coefficient (nuclear 3-norm). 0 = disabled.
    pub n3_reg: f32,
    /// L2 regularization coefficient on embeddings. 0 = disabled.
    /// Simpler and more universally effective than N3. Recommended: 1e-5 to 1e-3.
    pub l2_reg: f32,
    /// Batch size.
    pub batch_size: usize,
    /// Number of training epochs.
    pub epochs: usize,
    /// Normalize entity embeddings to unit L2 norm after each step.
    /// Standard for TransE (Bordes et al., 2013). Disabled by default.
    pub normalize_entities: bool,
    /// Linear warmup epochs. LR ramps from 0 to `lr` over this many epochs.
    /// 0 = no warmup.
    pub warmup_epochs: usize,
    /// Cosine annealing LR schedule. If > 0, divides training into this
    /// many cycles with cosine decay from `lr` to `lr * cosine_min_lr_frac`
    /// per cycle. Snapshots are saved at each cycle trough for ensembling
    /// (SnapE). 0 = no cosine annealing (use warmup + constant LR).
    pub cosine_cycles: usize,
    /// Minimum LR as fraction of base LR for cosine annealing.
    /// Default: 0.1 (10% of base LR). Values below 0.05 can cause
    /// late-stage stalling in long training runs.
    pub cosine_min_lr_frac: f64,
    /// Print loss to stderr every N epochs. 0 = silent.
    pub log_interval: usize,
    /// Evaluate on validation set every N epochs. 0 = no validation.
    pub eval_interval: usize,
    /// Stop if validation MRR doesn't improve for this many eval cycles.
    /// Only used when `eval_interval > 0`.
    pub patience: usize,
    /// Save embeddings to this directory every N epochs. None = no checkpoints.
    pub checkpoint_dir: Option<std::path::PathBuf>,
    /// Checkpoint interval in epochs. Only used when `checkpoint_dir` is Some.
    pub checkpoint_interval: usize,
    /// Stochastic Weight Averaging: start epoch. 0 = disabled.
    ///
    /// When > 0, maintains a running average of entity and relation
    /// embeddings starting at this epoch. The averaged model is returned
    /// in `TrainResult.swa_entity_vecs` / `swa_relation_vecs`.
    /// Typical: start at 75% of total epochs. Gives +1-3% MRR.
    pub swa_start_epoch: usize,
    /// Relation prediction auxiliary loss weight (1-N mode only). 0 = disabled.
    ///
    /// Adds a third loss term that predicts the relation given (head, tail).
    /// Weight is relative to the main loss. Recommended: 0.1.
    /// Chen et al. (2021) report +3-6% MRR on FB15k-237.
    pub relation_prediction_weight: f32,
}

impl Default for TrainConfig {
    /// TransE with L1 distance, AdamW at `lr = 1e-3`, `dim = 200`, and
    /// 256 negatives per positive with margin 12 and SANS temperature 1.
    /// Training runs 1000 epochs with batch size 512. Regularization, dropout,
    /// LR schedules, validation, checkpoints, SWA and relation prediction all
    /// start disabled.
    fn default() -> Self {
        Self {
            // Architecture.
            model_type: ModelType::TransE,
            dim: 200,
            init_scale: 1e-3,
            distance_norm: 1,
            // Optimization.
            optimizer: OptimizerType::AdamW,
            lr: 1e-3,
            batch_size: 512,
            epochs: 1000,
            // Negative sampling.
            num_negatives: 256,
            gamma: 12.0,
            adversarial_temperature: 1.0,
            subsampling: false,
            // 1-N scoring.
            one_to_n: false,
            label_smoothing: 0.0,
            multi_hot: false,
            relation_prediction_weight: 0.0,
            // Regularization.
            embedding_dropout: 0.0,
            n3_reg: 0.0,
            l2_reg: 0.0,
            normalize_entities: false,
            // Learning-rate schedule.
            warmup_epochs: 0,
            cosine_cycles: 0,
            cosine_min_lr_frac: 0.1,
            // Weight averaging.
            swa_start_epoch: 0,
            // Monitoring and persistence.
            log_interval: 0,
            eval_interval: 0,
            patience: 5,
            checkpoint_dir: None,
            checkpoint_interval: 0,
        }
    }
}

/// Knowledge-graph embedding model whose tables are candle `Var`s.
///
/// Keeping the tables as `Var`s lets gradients flow into them during
/// training. They can be copied to CPU `Vec<Vec<f32>>` (or converted into the
/// CPU model types) for evaluation through the `Scorer` trait.
pub struct TrainableModel {
    // --- learnable parameters ---
    /// Entity table; one row per entity.
    entity_embeddings: Var,
    /// Relation table; one row per relation.
    relation_embeddings: Var,
    // --- scoring hyperparameters copied from the training config ---
    /// Architecture that defines the scoring function.
    model_type: ModelType,
    /// Embedding dimension (complex dimension for RotatE/ComplEx).
    dim: usize,
    /// Margin, forwarded to the CPU RotatE model on conversion.
    gamma: f32,
    /// 1 selects L1 distance in per-triple scoring; other values select L2.
    distance_norm: u32,
    /// Dropout probability used in per-triple scoring; 0 disables it.
    embedding_dropout: f32,
    // --- placement ---
    /// Device holding both embedding tables.
    device: Device,
}

impl TrainableModel {
    /// Initialize a new trainable model with randomly drawn `F32` embeddings.
    ///
    /// Table shapes per model (complex rows store the `dim` real parts first,
    /// followed by the `dim` imaginary parts):
    ///
    /// - TransE: entities `[num_entities, dim]`, relations `[num_relations, dim]`.
    /// - RotatE: entities `[num_entities, 2 * dim]`; relations `[num_relations, dim]`
    ///   holding rotation phases drawn uniformly from `[-pi, pi]`.
    /// - ComplEx: entities `[num_entities, 2 * dim]`, relations `[num_relations, 2 * dim]`.
    /// - DistMult: entities `[num_entities, dim]`, relations `[num_relations, dim]`.
    ///
    /// Every table except the RotatE phases is drawn from `N(0, config.init_scale)`.
    ///
    /// # Errors
    ///
    /// Returns any candle error raised while allocating the tensors on `device`.
    pub fn new(
        num_entities: usize,
        num_relations: usize,
        config: &TrainConfig,
        device: &Device,
    ) -> Result<Self> {
        let dim = config.dim;
        let gamma = config.gamma;
        let s = config.init_scale;

        let (entity_embeddings, relation_embeddings) = match config.model_type {
            ModelType::TransE => {
                let ent = Var::randn_f64(0.0, s, (num_entities, dim), DType::F32, device)?;
                let rel = Var::randn_f64(0.0, s, (num_relations, dim), DType::F32, device)?;
                (ent, rel)
            }
            ModelType::RotatE => {
                let ent = Var::randn_f64(0.0, s, (num_entities, dim * 2), DType::F32, device)?;
                // Relations are phases theta; the rotation is cos(theta) + i sin(theta).
                let rel = Var::rand_f64(
                    -std::f64::consts::PI,
                    std::f64::consts::PI,
                    (num_relations, dim),
                    DType::F32,
                    device,
                )?;
                (ent, rel)
            }
            ModelType::ComplEx | ModelType::DistMult => {
                // ComplEx stores complex vectors (re | im); DistMult is purely real.
                let cols = if config.model_type == ModelType::ComplEx {
                    dim * 2
                } else {
                    dim
                };
                let ent = Var::randn_f64(0.0, s, (num_entities, cols), DType::F32, device)?;
                let rel = Var::randn_f64(0.0, s, (num_relations, cols), DType::F32, device)?;
                (ent, rel)
            }
        };

        Ok(Self {
            entity_embeddings,
            relation_embeddings,
            model_type: config.model_type,
            dim,
            gamma,
            distance_norm: config.distance_norm,
            embedding_dropout: config.embedding_dropout,
            device: device.clone(),
        })
    }

    /// Split a complex-layout tensor `[rows, 2 * dim]` into its real (`..dim`)
    /// and imaginary (`dim..`) column halves.
    fn split_complex(x: &Tensor, dim: usize) -> Result<(Tensor, Tensor)> {
        Ok((x.i((.., ..dim))?, x.i((.., dim..))?))
    }

    /// Negated squared Euclidean distance between every query row and every
    /// candidate row: `out[b, e] = -||queries[b] - candidates[e]||^2`.
    ///
    /// Expands the distance as `||q||^2 + ||c||^2 - 2 q.c`, so the `[B, E]`
    /// cross term is a single matmul instead of a `[B, E, cols]` difference.
    fn neg_sq_dist_gemm(queries: &Tensor, candidates: &Tensor) -> Result<Tensor> {
        let q_sq = queries.sqr()?.sum(D::Minus1)?; // [B]
        let c_sq = candidates.sqr()?.sum(D::Minus1)?; // [E]
        let cross = queries.matmul(&candidates.t()?)?; // [B, E]
        let dist_sq = (q_sq
            .unsqueeze(D::Minus1)?
            .broadcast_add(&c_sq.unsqueeze(0)?)?
            - (cross * 2.0)?)?;
        dist_sq.neg()
    }

    /// Score a batch of triples. Returns a tensor of shape `[batch]`.
    ///
    /// Uses the distance convention (lower = more likely):
    /// - TransE / RotatE: L1 distance when `distance_norm == 1`, L2 distance
    ///   for any other value. For RotatE, the L1 variant sums `|re| + |im|` per
    ///   coordinate rather than the complex modulus.
    ///   NOTE(review): the RotatE reference implementation sums moduli; confirm
    ///   this matches the CPU `RotatE` scorer used for evaluation.
    /// - ComplEx / DistMult: the negated similarity.
    ///
    /// `heads`, `relations` and `tails` are index tensors selecting rows of the
    /// embedding tables (presumably 1-D and of equal length; TODO confirm at call sites).
    ///
    /// When `embedding_dropout > 0`, dropout is applied to the gathered
    /// embeddings on every call. This method has no separate evaluation mode.
    pub fn score_batch(
        &self,
        heads: &Tensor,
        relations: &Tensor,
        tails: &Tensor,
    ) -> Result<Tensor> {
        let mut h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
        let mut r = self
            .relation_embeddings
            .as_tensor()
            .index_select(relations, 0)?;
        let mut t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;

        if self.embedding_dropout > 0.0 {
            h = candle_nn::ops::dropout(&h, self.embedding_dropout)?;
            r = candle_nn::ops::dropout(&r, self.embedding_dropout)?;
            t = candle_nn::ops::dropout(&t, self.embedding_dropout)?;
        }

        let dim = self.dim;
        match self.model_type {
            ModelType::TransE => {
                let diff = ((h + r)? - t)?;
                match self.distance_norm {
                    1 => diff.abs()?.sum(D::Minus1),
                    _ => diff.sqr()?.sum(D::Minus1)?.sqrt(),
                }
            }
            ModelType::RotatE => {
                let (h_re, h_im) = Self::split_complex(&h, dim)?;
                let (t_re, t_im) = Self::split_complex(&t, dim)?;
                // Relation rows are phases: r = cos(theta) + i sin(theta).
                let r_cos = r.cos()?;
                let r_sin = r.sin()?;
                // h * r (complex multiply)
                let hr_re = ((&h_re * &r_cos)? - (&h_im * &r_sin)?)?;
                let hr_im = ((&h_re * &r_sin)? + (&h_im * &r_cos)?)?;
                let d_re = (hr_re - t_re)?;
                let d_im = (hr_im - t_im)?;
                match self.distance_norm {
                    1 => (d_re.abs()? + d_im.abs()?)?.sum(D::Minus1),
                    _ => (d_re.sqr()? + d_im.sqr()?)?.sum(D::Minus1)?.sqrt(),
                }
            }
            ModelType::ComplEx => {
                let (h_re, h_im) = Self::split_complex(&h, dim)?;
                let (r_re, r_im) = Self::split_complex(&r, dim)?;
                let (t_re, t_im) = Self::split_complex(&t, dim)?;
                // h * r (complex multiply)
                let hr_re = ((&h_re * &r_re)? - (&h_im * &r_im)?)?;
                let hr_im = ((&h_re * &r_im)? + (&h_im * &r_re)?)?;
                // Re(hr * conj(t)) = hr_re * t_re + hr_im * t_im
                let score = ((&hr_re * &t_re)? + (&hr_im * &t_im)?)?;
                // Negate so lower = more likely (distance convention).
                score.sum(D::Minus1)?.neg()
            }
            ModelType::DistMult => {
                let score = ((&h * &r)? * &t)?;
                score.sum(D::Minus1)?.neg()
            }
        }
    }

    /// Score every entity as the tail for a batch of `(h, r)` queries.
    ///
    /// Returns a tensor of shape `[batch, num_entities]` using the similarity
    /// convention (higher = more likely):
    /// - DistMult / ComplEx: the trilinear score, computed with matmuls.
    /// - TransE / RotatE: the *negated squared* L2 distance, computed with the
    ///   GEMM expansion. `distance_norm` is not consulted here.
    ///
    /// Embedding dropout is not applied in this method.
    pub fn score_1n(&self, heads: &Tensor, relations: &Tensor) -> Result<Tensor> {
        let h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
        let r = self
            .relation_embeddings
            .as_tensor()
            .index_select(relations, 0)?;
        let ent_matrix = self.entity_embeddings.as_tensor(); // [E, cols]
        let dim = self.dim;

        match self.model_type {
            ModelType::TransE => {
                // The query point is h + r; compare every entity against it.
                let hr = (h + r)?; // [B, dim]
                Self::neg_sq_dist_gemm(&hr, ent_matrix)
            }
            ModelType::DistMult => {
                // sum(h * r * t) = (h * r) @ E^T
                let hr = (h * r)?; // [B, dim]
                hr.matmul(&ent_matrix.t()?)
            }
            ModelType::ComplEx => {
                let (h_re, h_im) = Self::split_complex(&h, dim)?;
                let (r_re, r_im) = Self::split_complex(&r, dim)?;
                // hr = h * r (complex multiply)
                let hr_re = ((&h_re * &r_re)? - (&h_im * &r_im)?)?;
                let hr_im = ((&h_re * &r_im)? + (&h_im * &r_re)?)?;
                // Materialize the column slices contiguously before the transposed matmul.
                let e_re = ent_matrix.i((.., ..dim))?.contiguous()?;
                let e_im = ent_matrix.i((.., dim..))?.contiguous()?;
                // Re(hr * conj(e)) = hr_re @ e_re^T + hr_im @ e_im^T
                hr_re.matmul(&e_re.t()?)? + hr_im.matmul(&e_im.t()?)?
            }
            ModelType::RotatE => {
                // Rotate h by the phases of r. The squared modulus of a complex
                // difference equals the squared L2 distance of the concatenated
                // [re, im] real vectors, so the TransE GEMM expansion applies.
                let (h_re, h_im) = Self::split_complex(&h, dim)?;
                let r_cos = r.cos()?;
                let r_sin = r.sin()?;
                let hr_re = ((&h_re * &r_cos)? - (&h_im * &r_sin)?)?;
                let hr_im = ((&h_re * &r_sin)? + (&h_im * &r_cos)?)?;
                let hr = Tensor::cat(&[&hr_re, &hr_im], D::Minus1)?; // [B, 2*dim]
                Self::neg_sq_dist_gemm(&hr, ent_matrix)
            }
        }
    }

    /// Score every entity as the head for a batch of `(r, t)` queries.
    ///
    /// Returns a tensor of shape `[batch, num_entities]`, higher = more likely,
    /// with the same conventions as `score_1n`: negated squared L2 for
    /// TransE/RotatE, no dropout. Its ComplEx/DistMult scores equal the
    /// negation of `score_batch` for the same triple (without dropout).
    pub fn score_1n_heads(&self, relations: &Tensor, tails: &Tensor) -> Result<Tensor> {
        let r = self
            .relation_embeddings
            .as_tensor()
            .index_select(relations, 0)?;
        let t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;
        let ent_matrix = self.entity_embeddings.as_tensor();
        let dim = self.dim;

        match self.model_type {
            ModelType::TransE => {
                // ||h + r - t|| = ||h - (t - r)||: compare every entity against t - r.
                let tr = (t - r)?;
                Self::neg_sq_dist_gemm(&tr, ent_matrix)
            }
            ModelType::DistMult => {
                // sum(h * r * t) = (r * t) @ E^T
                let rt = (r * t)?;
                rt.matmul(&ent_matrix.t()?)
            }
            ModelType::ComplEx => {
                // Re(h * r * conj(t)) = Re(h * q) with q = r * conj(t),
                // and Re(h * q) = h_re . q_re - h_im . q_im.
                // Unlike the tail case, the imaginary term enters with a minus sign.
                let (r_re, r_im) = Self::split_complex(&r, dim)?;
                let (t_re, t_im) = Self::split_complex(&t, dim)?;
                let q_re = ((&r_re * &t_re)? + (&r_im * &t_im)?)?;
                let q_im = ((&r_im * &t_re)? - (&r_re * &t_im)?)?;
                let e_re = ent_matrix.i((.., ..dim))?.contiguous()?;
                let e_im = ent_matrix.i((.., dim..))?.contiguous()?;
                q_re.matmul(&e_re.t()?)? - q_im.matmul(&e_im.t()?)?
            }
            ModelType::RotatE => {
                // |r| = 1, so ||h * r - t|| = ||h - t * conj(r)||.
                let (t_re, t_im) = Self::split_complex(&t, dim)?;
                let r_cos = r.cos()?;
                let r_sin = r.sin()?;
                // t * conj(r) = (t_re + i t_im)(cos - i sin)
                let tr_re = ((&t_re * &r_cos)? + (&t_im * &r_sin)?)?;
                let tr_im = ((&t_im * &r_cos)? - (&t_re * &r_sin)?)?;
                let tr = Tensor::cat(&[&tr_re, &tr_im], D::Minus1)?;
                Self::neg_sq_dist_gemm(&tr, ent_matrix)
            }
        }
    }

    /// Score every relation for a batch of `(h, t)` queries.
    ///
    /// Returns a tensor of shape `[batch, num_relations]`, higher = more likely.
    /// - DistMult / ComplEx: matmuls against the whole relation table. The row
    ///   count of that table is used; `num_relations` is not consulted.
    /// - TransE / RotatE: calls `score_batch` once per relation id in
    ///   `0..num_relations` and negates the distance. `distance_norm` and
    ///   embedding dropout therefore apply on this path, unlike the GEMM paths.
    pub fn score_1n_relations(
        &self,
        heads: &Tensor,
        tails: &Tensor,
        num_relations: usize,
    ) -> Result<Tensor> {
        let h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
        let t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;
        let rel_matrix = self.relation_embeddings.as_tensor(); // [R, cols]
        let dim = self.dim;

        match self.model_type {
            ModelType::DistMult => {
                // sum(h * r * t) = (h * t) @ R^T
                let ht = (h * t)?;
                ht.matmul(&rel_matrix.t()?)
            }
            ModelType::ComplEx => {
                // Re(h * r * conj(t)) = Re(r * p) with p = h * conj(t),
                // and Re(r * p) = r_re . p_re - r_im . p_im.
                let (h_re, h_im) = Self::split_complex(&h, dim)?;
                let (t_re, t_im) = Self::split_complex(&t, dim)?;
                let p_re = ((&h_re * &t_re)? + (&h_im * &t_im)?)?;
                let p_im = ((&h_im * &t_re)? - (&h_re * &t_im)?)?;
                let r_re = rel_matrix.i((.., ..dim))?.contiguous()?;
                let r_im = rel_matrix.i((.., dim..))?.contiguous()?;
                p_re.matmul(&r_re.t()?)? - p_im.matmul(&r_im.t()?)?
            }
            ModelType::TransE | ModelType::RotatE => {
                // Distance models are not a single dot product: score each relation
                // separately. This is slower, but relation prediction is auxiliary.
                let batch_size = h.dim(0)?;
                let mut scores = Vec::with_capacity(num_relations);
                for r_idx in 0..num_relations {
                    // NOTE(review): relation ids are built as u32; assumes callers
                    // index with u32 tensors as well.
                    let r_ids = Tensor::full(r_idx as u32, batch_size, &self.device)?;
                    let dist = self.score_batch(heads, &r_ids, tails)?;
                    scores.push(dist.neg()?); // negate: higher = more likely
                }
                Tensor::stack(&scores, 1)
            }
        }
    }

    /// Compute the N3 penalty `(||h||_3^3 + ||r||_3^3 + ||t||_3^3) / batch`.
    ///
    /// Each term sums cubed magnitudes over a row's coordinates and averages
    /// over the batch. ComplEx uses the modulus `sqrt(re^2 + im^2)` of each
    /// complex coordinate. Every other model uses the absolute value of each
    /// stored float, including RotatE entity rows and phase rows.
    fn n3_penalty(&self, heads: &Tensor, relations: &Tensor, tails: &Tensor) -> Result<Tensor> {
        let h = self.entity_embeddings.as_tensor().index_select(heads, 0)?;
        let r = self
            .relation_embeddings
            .as_tensor()
            .index_select(relations, 0)?;
        let t = self.entity_embeddings.as_tensor().index_select(tails, 0)?;

        let is_complex = self.model_type == ModelType::ComplEx;
        let dim = self.dim;
        // Same normalization for real and complex models: sum over coordinates,
        // mean over rows. (Averaging over coordinates too would shrink the
        // real-model penalty by a factor of `dim`.)
        let cubed_norm = |x: &Tensor| -> Result<Tensor> {
            let magnitudes = if is_complex {
                let (re, im) = Self::split_complex(x, dim)?;
                (re.sqr()? + im.sqr()?)?.sqrt()?
            } else {
                x.abs()?
            };
            let rows = x.dim(0)? as f64;
            magnitudes.powf(3.0)?.sum_all()?.affine(1.0 / rows, 0.0)
        };
        let penalty = ((cubed_norm(&h)? + cubed_norm(&r)?)? + cubed_norm(&t)?)?;
        Ok(penalty)
    }

    /// Model type.
    pub fn model_type(&self) -> ModelType {
        self.model_type
    }

    /// Embedding dimension (complex dimension for RotatE/ComplEx).
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Access the raw entity embedding tensor.
    pub fn entity_embeddings(&self) -> &Tensor {
        self.entity_embeddings.as_tensor()
    }

    /// Access the raw relation embedding tensor.
    pub fn relation_embeddings(&self) -> &Tensor {
        self.relation_embeddings.as_tensor()
    }

    /// Extract entity embeddings as `Vec<Vec<f32>>`, one inner vector per entity row.
    pub fn entity_vecs(&self) -> Result<Vec<Vec<f32>>> {
        tensor_to_vecs(self.entity_embeddings.as_tensor())
    }

    /// Extract relation embeddings as `Vec<Vec<f32>>`, one inner vector per relation row.
    pub fn relation_vecs(&self) -> Result<Vec<Vec<f32>>> {
        tensor_to_vecs(self.relation_embeddings.as_tensor())
    }

    /// Convert to a CPU-based TransE model for evaluation (keeps `distance_norm`).
    pub fn to_transe(&self) -> Result<crate::TransE> {
        Ok(crate::TransE::from_vecs_with_norm(
            self.entity_vecs()?,
            self.relation_vecs()?,
            self.dim,
            self.distance_norm,
        ))
    }

    /// Convert to a CPU-based RotatE model for evaluation (keeps `gamma`).
    pub fn to_rotate(&self) -> Result<crate::RotatE> {
        Ok(crate::RotatE::from_vecs(
            self.entity_vecs()?,
            self.relation_vecs()?,
            self.dim,
            self.gamma,
        ))
    }

    /// Convert to a CPU-based ComplEx model for evaluation.
    pub fn to_complex(&self) -> Result<crate::ComplEx> {
        Ok(crate::ComplEx::from_vecs(
            self.entity_vecs()?,
            self.relation_vecs()?,
            self.dim,
        ))
    }

    /// Convert to a CPU-based DistMult model for evaluation.
    pub fn to_distmult(&self) -> Result<crate::DistMult> {
        Ok(crate::DistMult::from_vecs(
            self.entity_vecs()?,
            self.relation_vecs()?,
            self.dim,
        ))
    }
}

/// A CPU copy of the entity and relation embeddings captured during training,
/// used for snapshot ensembling.
#[derive(Debug, Clone)]
pub struct Snapshot {
    /// Entity embeddings, one row per entity.
    pub entity_vecs: Vec<Vec<f32>>,
    /// Relation embeddings, one row per relation.
    pub relation_vecs: Vec<Vec<f32>>,
    /// 1-based number of the epoch after which this snapshot was taken.
    pub epoch: usize,
}

/// Everything produced by one training run.
pub struct TrainResult {
    /// Model holding the final embeddings. When validation ran, these are the
    /// embeddings from the best-MRR evaluation.
    pub model: TrainableModel,
    /// Mean batch loss of each completed epoch.
    pub losses: Vec<f32>,
    /// Wall-clock seconds for each completed epoch (shuffle plus batch updates).
    pub epoch_times: Vec<f32>,
    /// Embedding copies taken at the last epoch of each cosine-annealing cycle
    /// (SnapE ensembling).
    pub snapshots: Vec<Snapshot>,
    /// Running average of entity embeddings from `swa_start_epoch` onward.
    /// `None` when SWA is disabled (`swa_start_epoch == 0`) or never started.
    pub swa_entity_vecs: Option<Vec<Vec<f32>>>,
    /// Relation counterpart of `swa_entity_vecs`, with the same `None` rules.
    pub swa_relation_vecs: Option<Vec<Vec<f32>>>,
}

/// Borrowed validation inputs for periodic evaluation and early stopping.
///
/// It holds only shared references, so it is `Copy` and can be passed to
/// several training runs without rebuilding it.
#[derive(Clone, Copy)]
pub struct ValidationData<'a> {
    /// Validation triples to evaluate.
    pub valid_triples: &'a [crate::dataset::TripleIds],
    /// Pre-built filter index for filtered evaluation.
    pub filter: &'a crate::dataset::FilterIndex,
}

/// Learning rate to use at `epoch` (0-based) under `config`'s schedule.
///
/// The schedule works as follows:
/// - During the first `warmup_epochs` epochs, the rate ramps up linearly to `config.lr`.
/// - After warmup, with `cosine_cycles > 0`, the remaining epochs are split into
///   equal cycles. Each cycle anneals from `config.lr` toward
///   `config.lr * cosine_min_lr_frac` along a half cosine (SnapE).
/// - Otherwise, or when the cycles would be shorter than one epoch, the rate is `config.lr`.
pub fn learning_rate(epoch: usize, config: &TrainConfig) -> f64 {
    let base_lr = config.lr;

    if config.warmup_epochs > 0 && epoch < config.warmup_epochs {
        return base_lr * (epoch + 1) as f64 / config.warmup_epochs as f64;
    }
    if config.cosine_cycles == 0 {
        return base_lr;
    }

    let since_warmup = epoch.saturating_sub(config.warmup_epochs);
    let annealed_epochs = config.epochs.saturating_sub(config.warmup_epochs);
    let cycle_len = annealed_epochs / config.cosine_cycles;
    if cycle_len == 0 {
        return base_lr;
    }

    // Position within the current cycle, in [0, 1).
    let phase = (since_warmup % cycle_len) as f64 / cycle_len as f64;
    let floor_lr = base_lr * config.cosine_min_lr_frac;
    let cosine_term = 1.0 + (phase * std::f64::consts::PI).cos();
    floor_lr + 0.5 * (base_lr - floor_lr) * cosine_term
}

/// Train a KGE model on the given triples.
///
/// `train_triples` is a slice of `(head, relation, tail)` ID triples.
/// `num_entities` and `num_relations` define the vocabulary size.
///
/// If `validation` is provided and `config.eval_interval > 0`, evaluates
/// on the validation set periodically and stops early if MRR doesn't
/// improve for `config.patience` evaluation cycles.
///
/// Returns the trained model and per-epoch loss history.
pub fn train(
    train_triples: &[crate::dataset::TripleIds],
    num_entities: usize,
    num_relations: usize,
    config: &TrainConfig,
    device: &Device,
) -> Result<TrainResult> {
    train_with_validation(
        train_triples,
        num_entities,
        num_relations,
        config,
        device,
        None,
    )
}

/// Train a KGE model with optional validation-based early stopping.
///
/// Runs up to `config.epochs` epochs of mini-batch training over a copy of
/// `train_triples` that is reshuffled every epoch. The loss depends on `config.one_to_n`:
/// - When set, the loss is 1-N cross-entropy over all entities for both tail
///   and head prediction. Targets are single (1vsAll), or every known
///   answer when `config.multi_hot` is set (KvsAll).
/// - Otherwise the loss is negative sampling, weighted self-adversarially
///   when `config.adversarial_temperature > 0`.
///
/// The matching `config` fields enable optional extras:
/// - N3 and L2 regularization,
/// - the relation-prediction auxiliary loss,
/// - entity L2 normalization,
/// - snapshots at cosine-cycle ends,
/// - periodic checkpoints,
/// - SWA.
///
/// When `validation` is `Some` and `config.eval_interval > 0`, the model is
/// evaluated every `eval_interval` epochs. Training stops once MRR has not
/// improved for `config.patience` consecutive evaluations. The embeddings
/// from the best-MRR evaluation are then restored into the returned model.
///
/// A `config.batch_size` of 0 is treated as 1. With an empty `train_triples`,
/// no batches run and every recorded epoch loss is 0.0.
///
/// # Errors
///
/// Propagates tensor and optimizer errors from model construction, the
/// forward/backward passes, CPU-model conversion for validation, and the final
/// best-embedding restore.
pub fn train_with_validation(
    train_triples: &[crate::dataset::TripleIds],
    num_entities: usize,
    num_relations: usize,
    config: &TrainConfig,
    device: &Device,
    validation: Option<ValidationData<'_>>,
) -> Result<TrainResult> {
    let model = TrainableModel::new(num_entities, num_relations, config, device)?;
    let vars = vec![
        model.entity_embeddings.clone(),
        model.relation_embeddings.clone(),
    ];

    // Static dispatch over the two supported optimizers.
    enum Opt {
        Adam(AdamW),
        Adagrad(self::Adagrad),
    }
    impl Opt {
        fn backward_step(&mut self, loss: &Tensor) -> Result<()> {
            match self {
                Opt::Adam(o) => o.backward_step(loss),
                Opt::Adagrad(o) => o.backward_step(loss),
            }
        }
        fn set_learning_rate(&mut self, lr: f64) {
            match self {
                Opt::Adam(o) => o.set_learning_rate(lr),
                Opt::Adagrad(o) => o.set_learning_rate(lr),
            }
        }
    }

    let mut optimizer = match config.optimizer {
        OptimizerType::AdamW => Opt::Adam(AdamW::new(
            vars,
            ParamsAdamW {
                lr: config.lr,
                weight_decay: 0.0,
                ..ParamsAdamW::default()
            },
        )?),
        OptimizerType::Adagrad => Opt::Adagrad(self::Adagrad::new(vars, config.lr)?),
    };

    let n_triples = train_triples.len();
    // Clamp to >= 1: with a zero batch size `offset` would never advance in the
    // batch loop below and training would spin forever.
    let batch_size = config.batch_size.min(n_triples).max(1);
    let gamma = config.gamma;
    let alpha = config.adversarial_temperature;
    let n3_coeff = config.n3_reg;
    let multi_hot = config.one_to_n && config.multi_hot;

    // KvsAll label lists, only needed for multi-hot 1-N training:
    //   known_tails[(h, r)] = known tail entity IDs
    //   known_heads[(r, t)] = known head entity IDs
    // The lists are deduplicated. Otherwise repeated training triples would
    // inflate `len()` and the 1/len target weights would sum to less than 1.
    let (known_tails, known_heads) = if multi_hot {
        let mut kt: std::collections::HashMap<(usize, usize), Vec<usize>> =
            std::collections::HashMap::new();
        let mut kh: std::collections::HashMap<(usize, usize), Vec<usize>> =
            std::collections::HashMap::new();
        for triple in train_triples {
            kt.entry((triple.head, triple.relation))
                .or_default()
                .push(triple.tail);
            kh.entry((triple.relation, triple.tail))
                .or_default()
                .push(triple.head);
        }
        for ids in kt.values_mut().chain(kh.values_mut()) {
            ids.sort_unstable();
            ids.dedup();
        }
        (Some(kt), Some(kh))
    } else {
        (None, None)
    };

    // Multi-hot target buffers, allocated once and reused for every batch.
    let mut tail_target_buf = if multi_hot {
        vec![0.0_f32; batch_size * num_entities]
    } else {
        Vec::new()
    };
    let mut head_target_buf = if multi_hot {
        vec![0.0_f32; batch_size * num_entities]
    } else {
        Vec::new()
    };

    // Entity frequencies (head + tail occurrences) for optional subsampling weights.
    let entity_freq = if config.subsampling {
        let mut freq = vec![0u32; num_entities];
        for triple in train_triples {
            freq[triple.head] += 1;
            freq[triple.tail] += 1;
        }
        Some(freq)
    } else {
        None
    };

    let mut losses = Vec::with_capacity(config.epochs);
    let mut epoch_times = Vec::with_capacity(config.epochs);
    let mut snapshots = Vec::new();
    let mut shuffled: Vec<crate::dataset::TripleIds> = train_triples.to_vec();
    let mut best_mrr = f32::NEG_INFINITY;
    let mut patience_counter = 0_usize;
    let mut best_entity_vecs: Option<Vec<Vec<f32>>> = None;
    let mut best_relation_vecs: Option<Vec<Vec<f32>>> = None;

    // SWA running average, kept in flat f32 buffers.
    let swa_active = config.swa_start_epoch > 0;
    let mut swa_ent: Option<Vec<f32>> = None;
    let mut swa_rel: Option<Vec<f32>> = None;
    let mut swa_count = 0u64;

    for epoch in 0..config.epochs {
        let lr = learning_rate(epoch, config);
        optimizer.set_learning_rate(lr);

        let mut epoch_loss = 0.0_f64;
        let mut n_batches = 0u32;
        let epoch_start = std::time::Instant::now();

        // Shuffle triples each epoch.
        {
            use rand::seq::SliceRandom;
            shuffled.shuffle(&mut rand::rng());
        }

        let mut offset = 0;
        while offset < n_triples {
            let end = (offset + batch_size).min(n_triples);
            let batch = &shuffled[offset..end];
            let actual_bs = batch.len();
            offset = end;

            let heads_data: Vec<u32> = batch.iter().map(|t| t.head as u32).collect();
            let rels_data: Vec<u32> = batch.iter().map(|t| t.relation as u32).collect();
            let tails_data: Vec<u32> = batch.iter().map(|t| t.tail as u32).collect();

            let heads = Tensor::from_vec(heads_data, actual_bs, &model.device)?;
            let rels = Tensor::from_vec(rels_data, actual_bs, &model.device)?;
            let tails = Tensor::from_vec(tails_data, actual_bs, &model.device)?;

            let mut loss = if config.one_to_n {
                let eps = config.label_smoothing as f64;

                // Tail prediction: score every entity as the tail of (h, r, ?).
                let tail_scores = model.score_1n(&heads, &rels)?;
                let tail_log_probs = candle_nn::ops::log_softmax(&tail_scores, D::Minus1)?;

                // Head prediction: score every entity as the head of (?, r, t).
                let head_scores = model.score_1n_heads(&rels, &tails)?;
                let head_log_probs = candle_nn::ops::log_softmax(&head_scores, D::Minus1)?;

                let (tail_nll, head_nll) = if multi_hot {
                    // KvsAll: uniform distribution over all known tails/heads.
                    let kt = known_tails
                        .as_ref()
                        .expect("known_tails is built whenever multi-hot 1-N is enabled");
                    let tgt = &mut tail_target_buf[..actual_bs * num_entities];
                    tgt.fill(0.0);
                    for (i, triple) in batch.iter().enumerate() {
                        let known = kt
                            .get(&(triple.head, triple.relation))
                            .expect("every training (head, relation) pair is indexed");
                        let w = 1.0 / known.len() as f32;
                        for &t in known {
                            tgt[i * num_entities + t] = w;
                        }
                    }
                    let tail_t = Tensor::from_slice(tgt, (actual_bs, num_entities), &model.device)?;
                    let t_nll = (&tail_t * &tail_log_probs)?
                        .sum_all()?
                        .neg()?
                        .affine(1.0 / actual_bs as f64, 0.0)?;

                    let kh = known_heads
                        .as_ref()
                        .expect("known_heads is built whenever multi-hot 1-N is enabled");
                    let htgt = &mut head_target_buf[..actual_bs * num_entities];
                    htgt.fill(0.0);
                    for (i, triple) in batch.iter().enumerate() {
                        let known = kh
                            .get(&(triple.relation, triple.tail))
                            .expect("every training (relation, tail) pair is indexed");
                        let w = 1.0 / known.len() as f32;
                        for &h in known {
                            htgt[i * num_entities + h] = w;
                        }
                    }
                    let head_t =
                        Tensor::from_slice(htgt, (actual_bs, num_entities), &model.device)?;
                    let h_nll = (&head_t * &head_log_probs)?
                        .sum_all()?
                        .neg()?
                        .affine(1.0 / actual_bs as f64, 0.0)?;

                    (t_nll, h_nll)
                } else {
                    // 1vsAll: single-target CE via gather (Lacroix 2018). The
                    // batch's own `tails` / `heads` index tensors are the targets.
                    let t_nll = tail_log_probs
                        .gather(&tails.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .neg()?
                        .mean_all()?;
                    let h_nll = head_log_probs
                        .gather(&heads.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .neg()?
                        .mean_all()?;

                    (t_nll, h_nll)
                };

                // Average head and tail losses.
                let nll = ((tail_nll + head_nll)? * 0.5)?;

                // Label smoothing: mix in the CE against a uniform target.
                let main_loss = if eps > 0.0 {
                    let tail_uniform = tail_log_probs.mean_all()?.neg()?;
                    let head_uniform = head_log_probs.mean_all()?.neg()?;
                    let uniform = ((tail_uniform + head_uniform)? * 0.5)?;
                    ((nll * (1.0 - eps))? + (uniform * eps)?)?
                } else {
                    nll
                };

                // Relation prediction auxiliary loss (Chen et al. 2021).
                if config.relation_prediction_weight > 0.0 {
                    let rel_scores = model.score_1n_relations(&heads, &tails, num_relations)?;
                    let rel_log_probs = candle_nn::ops::log_softmax(&rel_scores, D::Minus1)?;
                    let rel_nll = rel_log_probs
                        .gather(&rels.unsqueeze(1)?, 1)?
                        .squeeze(1)?
                        .neg()?
                        .mean_all()?;
                    (main_loss + (rel_nll * config.relation_prediction_weight as f64)?)?
                } else {
                    main_loss
                }
            } else {
                // Negative sampling with SANS.
                let pos_scores = model.score_batch(&heads, &rels, &tails)?;

                let neg_entities = Tensor::rand(
                    0.0_f32,
                    num_entities as f32,
                    (actual_bs, config.num_negatives),
                    &model.device,
                )?
                .to_dtype(DType::U32)?;

                // Corrupt the head or the tail of each negative with probability 0.5.
                let corrupt_mask = Tensor::rand(
                    0.0_f32,
                    1.0_f32,
                    (actual_bs, config.num_negatives),
                    &model.device,
                )?;
                let half = Tensor::full(0.5_f32, (actual_bs, config.num_negatives), &model.device)?;
                let corrupt_head = corrupt_mask.lt(&half)?;

                let heads_exp = heads
                    .unsqueeze(1)?
                    .expand((actual_bs, config.num_negatives))?;
                let rels_exp = rels
                    .unsqueeze(1)?
                    .expand((actual_bs, config.num_negatives))?;
                let tails_exp = tails
                    .unsqueeze(1)?
                    .expand((actual_bs, config.num_negatives))?;

                let neg_heads = corrupt_head.where_cond(&neg_entities, &heads_exp)?;
                let neg_tails = corrupt_head.where_cond(&tails_exp, &neg_entities)?;

                let neg_scores = model
                    .score_batch(
                        &neg_heads.flatten_all()?,
                        &rels_exp.flatten_all()?,
                        &neg_tails.flatten_all()?,
                    )?
                    .reshape((actual_bs, config.num_negatives))?;

                // Self-adversarial weights (detached), or uniform 1/k weights.
                let neg_weights = if alpha > 0.0 {
                    let scaled = (neg_scores.detach() * (-(alpha as f64)))?;
                    candle_nn::ops::softmax(&scaled, D::Minus1)?
                } else {
                    Tensor::ones((actual_bs, config.num_negatives), DType::F32, &model.device)?
                        .affine(1.0 / config.num_negatives as f64, 0.0)?
                };

                let pos_loss = log_sigmoid(&(pos_scores.neg()? + gamma as f64)?)?.neg()?;
                let neg_loss_per = log_sigmoid(&(neg_scores - gamma as f64)?)?;
                let weighted_neg_loss = (&neg_weights * &neg_loss_per)?.sum(D::Minus1)?.neg()?;

                let per_triple_loss = (pos_loss + weighted_neg_loss)?;
                if let Some(ref freq) = entity_freq {
                    let subsample_w: Vec<f32> = batch
                        .iter()
                        .map(|triple| 1.0 / ((freq[triple.head] + freq[triple.tail]) as f32).sqrt())
                        .collect();
                    let subsample_t = Tensor::from_vec(subsample_w, actual_bs, &model.device)?;
                    (&per_triple_loss * &subsample_t)?.mean_all()?
                } else {
                    per_triple_loss.mean_all()?
                }
            };

            // N3 regularization.
            if n3_coeff > 0.0 {
                let n3 = model.n3_penalty(&heads, &rels, &tails)?;
                loss = (loss + (n3 * n3_coeff as f64)?)?;
            }

            // L2 regularization on the embeddings used by this batch.
            if config.l2_reg > 0.0 {
                let h = model
                    .entity_embeddings
                    .as_tensor()
                    .index_select(&heads, 0)?;
                let r = model
                    .relation_embeddings
                    .as_tensor()
                    .index_select(&rels, 0)?;
                let t = model
                    .entity_embeddings
                    .as_tensor()
                    .index_select(&tails, 0)?;
                let l2 = ((h.sqr()?.mean_all()? + r.sqr()?.mean_all()?)? + t.sqr()?.mean_all()?)?;
                loss = (loss + (l2 * config.l2_reg as f64)?)?;
            }

            optimizer.backward_step(&loss)?;

            // Entity normalization (L2 unit norm per row).
            if config.normalize_entities {
                let ent = model.entity_embeddings.as_tensor();
                let norms = ent.sqr()?.sum(D::Minus1)?.sqrt()?.unsqueeze(D::Minus1)?;
                let normalized = ent.broadcast_div(&norms.clamp(1e-8, f64::MAX)?)?;
                model.entity_embeddings.set(&normalized)?;
            }

            // Accumulate loss. to_scalar forces a device sync per batch, which is
            // acceptable since the value is needed for epoch averaging.
            epoch_loss += loss.to_scalar::<f32>()? as f64;
            n_batches += 1;
        }

        // An empty training set runs no batches; record 0.0 instead of 0/0 = NaN.
        let avg_loss = if n_batches == 0 {
            0.0
        } else {
            (epoch_loss / n_batches as f64) as f32
        };
        losses.push(avg_loss);
        epoch_times.push(epoch_start.elapsed().as_secs_f32());

        // Snapshot at the last epoch of each cosine cycle, which has the lowest LR.
        // Warmup epochs are excluded explicitly so that `effective_epoch` is exact.
        // The old `effective_epoch > 0` test wrongly skipped the first cycle when
        // cycles are one epoch long.
        if config.cosine_cycles > 0 && epoch >= config.warmup_epochs {
            let effective_epoch = epoch - config.warmup_epochs;
            let total_effective = config.epochs.saturating_sub(config.warmup_epochs);
            let epochs_per_cycle = total_effective / config.cosine_cycles;
            if epochs_per_cycle > 0 && (effective_epoch + 1) % epochs_per_cycle == 0 {
                if let (Ok(ev), Ok(rv)) = (model.entity_vecs(), model.relation_vecs()) {
                    eprintln!(
                        "Snapshot {} saved at epoch {}",
                        snapshots.len() + 1,
                        epoch + 1
                    );
                    snapshots.push(Snapshot {
                        entity_vecs: ev,
                        relation_vecs: rv,
                        epoch: epoch + 1,
                    });
                }
            }
        }

        if config.log_interval > 0 && (epoch + 1) % config.log_interval == 0 {
            let epoch_secs = epoch_start.elapsed().as_secs_f32();
            let ent_norm = model
                .entity_embeddings
                .as_tensor()
                .sqr()
                .and_then(|t| t.mean_all())
                .and_then(|t| t.to_scalar::<f32>())
                .map(|v| v.sqrt())
                .unwrap_or(0.0);
            eprintln!(
                "epoch {:>4} | loss {:.4} | {:.1}s | emb_rms {:.4}",
                epoch + 1,
                avg_loss,
                epoch_secs,
                ent_norm,
            );
        }

        // Checkpoint save. This is best-effort: a failure is reported but does not
        // abort training, and success is announced only when the export worked.
        if let Some(ref dir) = config.checkpoint_dir {
            if config.checkpoint_interval > 0 && (epoch + 1) % config.checkpoint_interval == 0 {
                if let (Ok(ent), Ok(rel)) = (model.entity_vecs(), model.relation_vecs()) {
                    let ent_names: Vec<String> = (0..num_entities).map(|i| i.to_string()).collect();
                    let rel_names: Vec<String> =
                        (0..num_relations).map(|i| i.to_string()).collect();
                    match crate::io::export_embeddings(dir, &ent_names, &ent, &rel_names, &rel) {
                        Ok(_) => eprintln!("Checkpoint saved to {}", dir.display()),
                        Err(e) => eprintln!("Checkpoint save to {} failed: {e}", dir.display()),
                    }
                }
            }
        }

        // Stochastic Weight Averaging (incremental mean over participating epochs).
        if swa_active && epoch + 1 >= config.swa_start_epoch {
            if let (Ok(ent_flat), Ok(rel_flat)) = (
                model
                    .entity_embeddings
                    .as_tensor()
                    .flatten_all()?
                    .to_vec1::<f32>(),
                model
                    .relation_embeddings
                    .as_tensor()
                    .flatten_all()?
                    .to_vec1::<f32>(),
            ) {
                swa_count += 1;
                let update = |avg: &mut Option<Vec<f32>>, current: &[f32]| match avg {
                    None => *avg = Some(current.to_vec()),
                    Some(ref mut buf) => {
                        for (a, &c) in buf.iter_mut().zip(current.iter()) {
                            *a += (c - *a) / swa_count as f32;
                        }
                    }
                };
                update(&mut swa_ent, &ent_flat);
                update(&mut swa_rel, &rel_flat);
            }
        }

        // Validation-based early stopping.
        if let Some(ref val) = validation {
            if config.eval_interval > 0 && (epoch + 1) % config.eval_interval == 0 {
                let scorer: Box<dyn crate::Scorer + Sync> = match model.model_type {
                    ModelType::TransE => Box::new(model.to_transe()?),
                    ModelType::RotatE => Box::new(model.to_rotate()?),
                    ModelType::ComplEx => Box::new(model.to_complex()?),
                    ModelType::DistMult => Box::new(model.to_distmult()?),
                };
                let metrics = crate::eval::evaluate_link_prediction(
                    scorer.as_ref(),
                    val.valid_triples,
                    val.filter,
                    num_entities,
                );
                if metrics.mrr > best_mrr {
                    best_mrr = metrics.mrr;
                    patience_counter = 0;
                    // Snapshot the best model (copied through the CPU to avoid Var storage sharing).
                    best_entity_vecs = model.entity_vecs().ok();
                    best_relation_vecs = model.relation_vecs().ok();
                } else {
                    patience_counter += 1;
                    if patience_counter >= config.patience {
                        eprintln!(
                            "Early stopping at epoch {} (best MRR: {:.4})",
                            epoch + 1,
                            best_mrr,
                        );
                        break;
                    }
                }
            }
        }
    }

    // Restore the best-validation model if validation recorded one.
    if let (Some(ent_vecs), Some(rel_vecs)) = (best_entity_vecs, best_relation_vecs) {
        let ent_flat: Vec<f32> = ent_vecs.iter().flat_map(|v| v.iter().copied()).collect();
        let rel_flat: Vec<f32> = rel_vecs.iter().flat_map(|v| v.iter().copied()).collect();
        let ent_shape = model.entity_embeddings.shape().clone();
        let rel_shape = model.relation_embeddings.shape().clone();
        let ent_t = Tensor::from_vec(ent_flat, ent_shape, &model.device)?;
        let rel_t = Tensor::from_vec(rel_flat, rel_shape, &model.device)?;
        model.entity_embeddings.set(&ent_t)?;
        model.relation_embeddings.set(&rel_t)?;
    }

    // Convert the SWA flat buffers to Vec<Vec<f32>> rows.
    let ent_cols = model.entity_embeddings.as_tensor().dim(1)?;
    let rel_cols = model.relation_embeddings.as_tensor().dim(1)?;
    let swa_entity_vecs =
        swa_ent.map(|flat| flat.chunks_exact(ent_cols).map(|c| c.to_vec()).collect());
    let swa_relation_vecs =
        swa_rel.map(|flat| flat.chunks_exact(rel_cols).map(|c| c.to_vec()).collect());

    Ok(TrainResult {
        model,
        losses,
        epoch_times,
        snapshots,
        swa_entity_vecs,
        swa_relation_vecs,
    })
}

/// Numerically stable, elementwise `log(sigmoid(x))`.
///
/// Uses the identity `log(sigmoid(x)) = -max(0, -x) - log(1 + exp(-|x|))`.
/// The only exponential taken is of `-|x|` (at most 1), so it cannot overflow.
fn log_sigmoid(x: &Tensor) -> Result<Tensor> {
    // -max(0, -x): linear part, nonzero only for negative inputs.
    let linear_part = x.neg()?.relu()?.neg()?;
    // log(1 + exp(-|x|)); candle has no log1p, so the +1 is explicit.
    let decay = x.abs()?.neg()?.exp()?;
    let correction = (decay + 1.0)?.log()?;
    let out = (linear_part - correction)?;
    Ok(out)
}

/// Copies a rank-2 `f32` tensor to the host and returns it as one `Vec<f32>` per row.
///
/// # Errors
///
/// Fails if the tensor is not rank 2, if its dtype is not `f32`, or if the
/// device-to-host copy fails.
fn tensor_to_vecs(t: &Tensor) -> Result<Vec<Vec<f32>>> {
    // `to_vec2` validates rank and dtype itself and handles non-contiguous layouts, so a
    // tensor of any other rank is rejected instead of being flattened and silently
    // regrouped into `dim(1)`-sized rows (which previously truncated rank-3+ inputs).
    Ok(t.to_device(&Device::Cpu)?.to_vec2::<f32>()?)
}

#[cfg(test)]
mod tests {
    // Tests for the training loop: numerical helpers, per-model smoke runs, the LR
    // schedule, SWA, and end-to-end MRR on a tiny chain graph.

    use super::*;
    use crate::dataset::TripleIds;
    use crate::Scorer;

    // -- Helpers ---------------------------------------------------------------

    /// Shorthand for building a `(head, relation, tail)` id triple.
    fn tid(h: usize, r: usize, t: usize) -> TripleIds {
        TripleIds::new(h, r, t)
    }

    /// 5 entities, 1 relation, a single chain: 0->1, 1->2, 2->3, 3->4.
    fn make_trivial_graph() -> Vec<TripleIds> {
        vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 0, 3), tid(3, 0, 4)]
    }

    /// Returns the link-prediction MRR of `model` on `triples`, using a `FilterIndex`
    /// built from those same triples.
    ///
    /// Ids are converted to strings because the `Dataset` is constructed from string
    /// triples and then interned before the filter index is derived from it.
    fn eval_mrr(
        triples: &[TripleIds],
        model: &(dyn crate::Scorer + Sync),
        num_entities: usize,
    ) -> f32 {
        let ds = crate::dataset::Dataset::new(
            triples
                .iter()
                .map(|t| {
                    crate::dataset::Triple::new(
                        t.head.to_string(),
                        t.relation.to_string(),
                        t.tail.to_string(),
                    )
                })
                .collect(),
            Vec::new(),
            Vec::new(),
        )
        .into_interned();
        let filter = crate::dataset::FilterIndex::from_dataset(&ds);
        crate::eval::evaluate_link_prediction(model, triples, &filter, num_entities).mrr
    }

    /// Asserts that the last recorded epoch loss is strictly below the first.
    ///
    /// `what` names the quantity in the failure message (e.g. `"RotatE loss"`).
    ///
    /// # Panics
    ///
    /// Panics if `losses` is empty or if the loss did not decrease.
    fn assert_loss_decreases<T>(losses: &[T], what: &str)
    where
        T: PartialOrd + std::fmt::Display + Copy,
    {
        let first = *losses.first().expect("training recorded no losses");
        let last = *losses.last().expect("training recorded no losses");
        assert!(last < first, "{what} should decrease: {first} -> {last}");
    }

    // -- Numerical helpers -----------------------------------------------------

    #[test]
    fn log_sigmoid_basic() {
        let device = Device::Cpu;
        let x = Tensor::new(&[0.0_f32, 10.0, -10.0], &device).unwrap();
        let result = log_sigmoid(&x).unwrap().to_vec1::<f32>().unwrap();
        // log(sigmoid(0)) = log(0.5) ~ -0.693
        assert!((result[0] - (-0.693)).abs() < 0.01, "got {}", result[0]);
        // log(sigmoid(10)) ~ 0
        assert!(result[1] > -0.001, "got {}", result[1]);
        // log(sigmoid(-10)) ~ -10
        assert!((result[2] - (-10.0)).abs() < 0.01, "got {}", result[2]);
    }

    // -- Per-model training smoke tests ----------------------------------------

    #[test]
    fn train_transe_smoke() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0), tid(0, 1, 2)];
        let config = TrainConfig {
            model_type: ModelType::TransE,
            dim: 8,
            num_negatives: 4,
            gamma: 6.0,
            adversarial_temperature: 0.5,
            lr: 0.01,
            n3_reg: 0.0,
            batch_size: 4,
            epochs: 5,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 2, &config, &device).unwrap();
        assert_eq!(result.losses.len(), 5);
        assert!(result.losses.iter().all(|l| l.is_finite()));
        let model = result.model.to_transe().unwrap();
        assert_eq!(model.num_entities(), 3);
    }

    #[test]
    fn train_rotate_smoke() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
        let config = TrainConfig {
            model_type: ModelType::RotatE,
            dim: 4,
            num_negatives: 2,
            gamma: 6.0,
            adversarial_temperature: 1.0,
            lr: 0.01,
            n3_reg: 0.0,
            batch_size: 2,
            epochs: 10,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 1, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
        assert_loss_decreases(&result.losses, "RotatE loss");
        let model = result.model.to_rotate().unwrap();
        assert_eq!(model.num_entities(), 3);
    }

    #[test]
    fn train_complex_with_n3() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
        let config = TrainConfig {
            model_type: ModelType::ComplEx,
            dim: 4,
            num_negatives: 2,
            gamma: 6.0,
            adversarial_temperature: 1.0,
            lr: 0.01,
            n3_reg: 0.001,
            batch_size: 2,
            epochs: 10,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 1, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
        assert_loss_decreases(&result.losses, "ComplEx loss");
        let model = result.model.to_complex().unwrap();
        assert_eq!(model.num_entities(), 3);
    }

    #[test]
    fn train_distmult_smoke() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 8,
            num_negatives: 2,
            gamma: 6.0,
            adversarial_temperature: 0.0,
            lr: 0.01,
            n3_reg: 0.0,
            batch_size: 2,
            epochs: 10,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 1, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
        assert_loss_decreases(&result.losses, "DistMult loss");
        let model = result.model.to_distmult().unwrap();
        assert_eq!(model.num_entities(), 3);
    }

    #[test]
    fn loss_decreases() {
        let device = Device::Cpu;
        // Enough data and epochs for loss to decrease.
        let triples: Vec<_> = (0..20).map(|i| tid(i % 10, i % 3, (i + 1) % 10)).collect();
        let config = TrainConfig {
            model_type: ModelType::TransE,
            dim: 16,
            num_negatives: 8,
            gamma: 6.0,
            adversarial_temperature: 0.5,
            lr: 0.01,
            n3_reg: 0.0,
            batch_size: 10,
            epochs: 50,
            ..TrainConfig::default()
        };
        let result = train(&triples, 10, 3, &config, &device).unwrap();
        assert_loss_decreases(&result.losses, "Loss");
    }

    #[test]
    fn transe_achieves_nonzero_mrr_on_trivial_graph() {
        // On the 5-entity chain, TransE should rank true tails well enough for MRR > 0.3.
        let device = Device::Cpu;
        let triples = make_trivial_graph();
        let config = TrainConfig {
            model_type: ModelType::TransE,
            dim: 32,
            num_negatives: 4,
            gamma: 6.0,
            adversarial_temperature: 0.0,
            lr: 0.01,
            n3_reg: 0.0,
            batch_size: 4,
            epochs: 500,
            ..TrainConfig::default()
        };
        let result = train(&triples, 5, 1, &config, &device).unwrap();
        let model = result.model.to_transe().unwrap();
        let mrr = eval_mrr(&triples, &model, 5);
        assert!(
            mrr > 0.3,
            "TransE should achieve MRR > 0.3 on trivial graph, got {mrr:.4}"
        );
    }

    // -- 1-N scoring, optimizers, regularization -------------------------------

    #[test]
    fn one_to_n_distmult_smoke() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0), tid(0, 1, 2)];
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 8,
            one_to_n: true,
            label_smoothing: 0.1,
            lr: 0.01,
            batch_size: 4,
            epochs: 10,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 2, &config, &device).unwrap();
        assert_eq!(result.losses.len(), 10);
        assert!(result.losses.iter().all(|l| l.is_finite()));
        assert_loss_decreases(&result.losses, "1-N loss");
    }

    #[test]
    fn one_to_n_transe_smoke() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 0, 3)];
        let config = TrainConfig {
            model_type: ModelType::TransE,
            dim: 8,
            one_to_n: true,
            label_smoothing: 0.1,
            lr: 0.001,
            batch_size: 3,
            epochs: 10,
            ..TrainConfig::default()
        };
        let result = train(&triples, 4, 1, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
    }

    #[test]
    fn adagrad_optimizer_smoke() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0), tid(0, 1, 2)];
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            optimizer: OptimizerType::Adagrad,
            dim: 8,
            init_scale: 1e-3,
            lr: 0.1,
            one_to_n: true,
            batch_size: 4,
            epochs: 10,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 2, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
        assert_loss_decreases(&result.losses, "Adagrad loss");
    }

    #[test]
    fn multi_hot_labels_with_duplicate_tails() {
        // (0,0,1) and (0,0,2) share the (h=0, r=0) query, so that query's 1-N target row
        // has more than one positive. This only checks that training on such rows stays
        // finite; it does not inspect the label weights themselves.
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(0, 0, 2), tid(1, 0, 0)];
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 8,
            one_to_n: true,
            batch_size: 3,
            epochs: 5,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 1, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
    }

    #[test]
    fn n3_regularization_complex_moduli() {
        // Verify N3 with ComplEx doesn't NaN.
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
        let config = TrainConfig {
            model_type: ModelType::ComplEx,
            dim: 4,
            n3_reg: 0.1,
            one_to_n: true,
            batch_size: 2,
            epochs: 5,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 1, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
    }

    #[test]
    fn l2_regularization_reduces_embedding_norm() {
        let device = Device::Cpu;
        let triples: Vec<_> = (0..20).map(|i| tid(i % 5, 0, (i + 1) % 5)).collect();
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 8,
            l2_reg: 0.1,
            one_to_n: true,
            batch_size: 10,
            epochs: 20,
            ..TrainConfig::default()
        };
        let result = train(&triples, 5, 1, &config, &device).unwrap();
        // With strong L2 reg, embedding norms should stay small. Only an absolute upper
        // bound is checked; there is no unregularized baseline run to compare against.
        let ent_vecs = result.model.entity_vecs().unwrap();
        let max_norm: f32 = ent_vecs
            .iter()
            .map(|v| v.iter().map(|x| x * x).sum::<f32>().sqrt())
            .fold(0.0_f32, f32::max);
        assert!(
            max_norm < 10.0,
            "L2 reg should keep norms small, got max_norm={max_norm}"
        );
    }

    // -- LR schedule -----------------------------------------------------------

    #[test]
    fn lr_warmup_ramps_linearly() {
        // Expected values correspond to lr * (epoch + 1) / warmup_epochs.
        let config = TrainConfig {
            lr: 0.01,
            warmup_epochs: 10,
            epochs: 100,
            ..TrainConfig::default()
        };
        let lr0 = learning_rate(0, &config);
        let lr5 = learning_rate(5, &config);
        let lr9 = learning_rate(9, &config);
        assert!((lr0 - 0.001).abs() < 1e-10, "epoch 0: {lr0}");
        assert!((lr5 - 0.006).abs() < 1e-10, "epoch 5: {lr5}");
        assert!((lr9 - 0.01).abs() < 1e-10, "epoch 9: {lr9}");
    }

    #[test]
    fn lr_constant_after_warmup_without_cosine() {
        let config = TrainConfig {
            lr: 0.01,
            warmup_epochs: 5,
            cosine_cycles: 0,
            epochs: 100,
            ..TrainConfig::default()
        };
        let lr = learning_rate(50, &config);
        assert!((lr - 0.01).abs() < 1e-10, "should be base LR: {lr}");
    }

    #[test]
    fn lr_cosine_starts_at_base_and_decays() {
        let config = TrainConfig {
            lr: 0.01,
            warmup_epochs: 0,
            cosine_cycles: 1,
            cosine_min_lr_frac: 0.1,
            epochs: 100,
            ..TrainConfig::default()
        };
        let lr_start = learning_rate(0, &config);
        let lr_mid = learning_rate(50, &config);
        let lr_end = learning_rate(99, &config);
        assert!(
            (lr_start - 0.01).abs() < 1e-6,
            "cosine should start at base LR: {lr_start}"
        );
        assert!(
            lr_mid < lr_start,
            "mid-cycle LR should be below start: {lr_mid}"
        );
        assert!(
            lr_end < lr_mid,
            "end-of-cycle LR should be below mid: {lr_end}"
        );
        // Should not go below min_frac * base_lr
        assert!(
            lr_end >= 0.001 - 1e-10,
            "LR should not drop below min: {lr_end}"
        );
    }

    #[test]
    fn lr_cosine_min_frac_respected() {
        let config = TrainConfig {
            lr: 0.1,
            warmup_epochs: 0,
            cosine_cycles: 1,
            cosine_min_lr_frac: 0.1,
            epochs: 100,
            ..TrainConfig::default()
        };
        for epoch in 0..100 {
            let lr = learning_rate(epoch, &config);
            assert!(lr >= 0.1 * 0.1 - 1e-10, "epoch {epoch}: LR {lr} below min");
            assert!(lr <= 0.1 + 1e-10, "epoch {epoch}: LR {lr} above base");
        }
    }

    #[test]
    fn lr_always_positive() {
        let config = TrainConfig {
            lr: 0.001,
            warmup_epochs: 10,
            cosine_cycles: 3,
            cosine_min_lr_frac: 0.1,
            epochs: 300,
            ..TrainConfig::default()
        };
        for epoch in 0..300 {
            let lr = learning_rate(epoch, &config);
            assert!(lr > 0.0, "epoch {epoch}: LR must be positive, got {lr}");
        }
    }

    // -- MRR integration tests for all models ---------------------------------

    #[test]
    fn rotate_achieves_nonzero_mrr_on_trivial_graph() {
        let device = Device::Cpu;
        let triples = make_trivial_graph();
        let config = TrainConfig {
            model_type: ModelType::RotatE,
            dim: 32,
            num_negatives: 4,
            gamma: 6.0,
            adversarial_temperature: 0.0,
            lr: 0.01,
            batch_size: 4,
            epochs: 500,
            ..TrainConfig::default()
        };
        let result = train(&triples, 5, 1, &config, &device).unwrap();
        let model = result.model.to_rotate().unwrap();
        let mrr = eval_mrr(&triples, &model, 5);
        assert!(
            mrr > 0.3,
            "RotatE should achieve MRR > 0.3 on trivial graph, got {mrr:.4}"
        );
    }

    #[test]
    fn complex_achieves_nonzero_mrr_on_trivial_graph() {
        let device = Device::Cpu;
        let triples = make_trivial_graph();
        let config = TrainConfig {
            model_type: ModelType::ComplEx,
            dim: 32,
            one_to_n: true,
            lr: 0.01,
            batch_size: 4,
            epochs: 200,
            ..TrainConfig::default()
        };
        let result = train(&triples, 5, 1, &config, &device).unwrap();
        let model = result.model.to_complex().unwrap();
        let mrr = eval_mrr(&triples, &model, 5);
        assert!(
            mrr > 0.3,
            "ComplEx should achieve MRR > 0.3 on trivial graph, got {mrr:.4}"
        );
    }

    #[test]
    fn distmult_achieves_nonzero_mrr_on_trivial_graph() {
        let device = Device::Cpu;
        let triples = make_trivial_graph();
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 32,
            one_to_n: true,
            lr: 0.01,
            batch_size: 4,
            epochs: 200,
            ..TrainConfig::default()
        };
        let result = train(&triples, 5, 1, &config, &device).unwrap();
        let model = result.model.to_distmult().unwrap();
        let mrr = eval_mrr(&triples, &model, 5);
        assert!(
            mrr > 0.3,
            "DistMult should achieve MRR > 0.3 on trivial graph, got {mrr:.4}"
        );
    }

    // -- SWA and auxiliary losses ----------------------------------------------

    #[test]
    fn swa_produces_averaged_embeddings() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2), tid(2, 1, 0)];
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 8,
            one_to_n: true,
            batch_size: 3,
            epochs: 10,
            swa_start_epoch: 5,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 2, &config, &device).unwrap();
        assert!(
            result.swa_entity_vecs.is_some(),
            "SWA should produce entity vecs"
        );
        assert!(
            result.swa_relation_vecs.is_some(),
            "SWA should produce relation vecs"
        );
        let swa_ent = result.swa_entity_vecs.unwrap();
        assert_eq!(swa_ent.len(), 3);
        assert_eq!(swa_ent[0].len(), 8);
        // SWA vecs should differ from the final model vecs.
        let final_ent = result.model.entity_vecs().unwrap();
        let differs = swa_ent
            .iter()
            .zip(final_ent.iter())
            .any(|(a, b)| a.iter().zip(b.iter()).any(|(x, y)| (x - y).abs() > 1e-8));
        assert!(differs, "SWA average should differ from final model");
    }

    #[test]
    fn swa_disabled_returns_none() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 0, 2)];
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 4,
            one_to_n: true,
            batch_size: 2,
            epochs: 5,
            swa_start_epoch: 0, // disabled
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 1, &config, &device).unwrap();
        assert!(result.swa_entity_vecs.is_none());
        assert!(result.swa_relation_vecs.is_none());
    }

    #[test]
    fn relation_prediction_loss_smoke() {
        let device = Device::Cpu;
        let triples = vec![tid(0, 0, 1), tid(1, 1, 2), tid(2, 0, 0), tid(0, 1, 2)];
        let config = TrainConfig {
            model_type: ModelType::DistMult,
            dim: 8,
            one_to_n: true,
            relation_prediction_weight: 0.1,
            batch_size: 4,
            epochs: 10,
            ..TrainConfig::default()
        };
        let result = train(&triples, 3, 2, &config, &device).unwrap();
        assert!(result.losses.iter().all(|l| l.is_finite()));
        assert_loss_decreases(&result.losses, "Loss with relation prediction");
    }
}