// oxirs_embed/kg_completion/mod.rs
//! Knowledge Graph Completion: negative sampling and batched training loops.
//!
//! This module provides the building blocks for training knowledge graph embedding
//! models using the standard *positive / negative sample* approach:
//!
//! - [`NegativeSampler`] — three strategies for generating corrupted triples.
//! - [`KgCompletionTask`] — produces negative samples given a positive triple.
//! - [`TrainingBatch`] — a bundle of positive and negative triples ready for training.
//! - [`BatchedTrainingLoop`] — prepares batches and computes training losses.
//!
//! ## Supported loss functions
//!
//! | Function | Formula |
//! |---|---|
//! | Margin loss (TransE-style) | `max(0, margin - score_pos + score_neg)` averaged over all positive–negative pairs |
//! | Binary cross-entropy | `−(1/N) Σ [log σ(score_pos) + log(1 − σ(score_neg))]` |

use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};

use crate::{NamedNode, Triple};

// ─────────────────────────────────────────────────────────────────────────────
// Pseudo-random helper — a small linear-congruential generator keeps this
// module free of the `rand` crate (scirs2_core::random requires feature flags
// that may not always be active in unit tests).
// ─────────────────────────────────────────────────────────────────────────────

/// Minimal LCG random number generator seeded deterministically.
struct Lcg {
    // Current generator state; advanced on every draw.
    state: u64,
}

impl Lcg {
    /// Create a generator; the seed is XOR-mixed with a fixed constant so a
    /// zero seed does not start from an all-zero state.
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x6c62_272e_07bb_0142,
        }
    }

    /// Advance the state one step (Knuth LCG constants) and return it.
    fn step(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(6_364_136_223_846_793_005)
            .wrapping_add(1_442_695_040_888_963_407);
        self.state
    }

    /// Return the next value in `[0, modulus)`.
    ///
    /// Panics if `modulus == 0` (division by zero); callers guard against
    /// empty pools before drawing.
    fn next_usize(&mut self, modulus: usize) -> usize {
        // Use the high bits — the low bits of an LCG are the least random.
        ((self.step() >> 33) as usize) % modulus
    }

    /// Return the next `f64` in `[0.0, 1.0)`, built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.step() >> 11) as f64 / (1u64 << 53) as f64
    }
}

61// ─────────────────────────────────────────────────────────────────────────────
62// NegativeSampler
63// ─────────────────────────────────────────────────────────────────────────────
64
65/// Strategy used to generate corrupted (*negative*) triples from a positive one.
66#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
67pub enum NegativeSampler {
68    /// Randomly replace head or tail with any entity drawn uniformly at random.
69    Uniform,
70    /// Replace only with entities that appear in the same position (head or tail)
71    /// across the observed triple set — provides type-constrained negatives.
72    TypeConstrained,
73    /// Self-adversarial sampling: weight candidates proportionally to their
74    /// current model score so hard negatives are sampled more often.
75    SelfAdversarial {
76        /// Temperature parameter controlling sharpness of the distribution.
77        /// Higher temperature → more uniform; lower temperature → harder negatives.
78        temperature: f64,
79    },
80}
81
// ─────────────────────────────────────────────────────────────────────────────
// KgCompletionTask
// ─────────────────────────────────────────────────────────────────────────────

/// Generates negative triples for knowledge graph completion training.
///
/// The task keeps an overall entity pool plus optional *head* and *tail*
/// pools used by type-constrained sampling; when a positional pool is empty,
/// sampling falls back to the overall pool.
#[derive(Debug, Clone)]
pub struct KgCompletionTask {
    /// Candidate entity IRIs used for uniform / self-adversarial corruption.
    known_entities: Vec<String>,
    /// Entities observed in the head (subject) position.
    head_entities: Vec<String>,
    /// Entities observed in the tail (object) position.
    tail_entities: Vec<String>,
}

101impl KgCompletionTask {
102    /// Create a task with a flat list of entities (used for uniform sampling).
103    pub fn new(known_entities: Vec<String>) -> Self {
104        let head_entities = known_entities.clone();
105        let tail_entities = known_entities.clone();
106        Self {
107            known_entities,
108            head_entities,
109            tail_entities,
110        }
111    }
112
113    /// Create a task with separate head and tail entity pools for type-constrained sampling.
114    pub fn with_type_constraints(
115        known_entities: Vec<String>,
116        head_entities: Vec<String>,
117        tail_entities: Vec<String>,
118    ) -> Self {
119        Self {
120            known_entities,
121            head_entities,
122            tail_entities,
123        }
124    }
125
126    /// Build a `KgCompletionTask` by inspecting a set of observed triples.
127    ///
128    /// Extracts all unique entities, heads, and tails automatically.
129    pub fn from_triples(triples: &[Triple]) -> Self {
130        let mut all: std::collections::HashSet<String> = std::collections::HashSet::new();
131        let mut heads: std::collections::HashSet<String> = std::collections::HashSet::new();
132        let mut tails: std::collections::HashSet<String> = std::collections::HashSet::new();
133
134        for t in triples {
135            all.insert(t.subject.iri.clone());
136            all.insert(t.predicate.iri.clone());
137            all.insert(t.object.iri.clone());
138            heads.insert(t.subject.iri.clone());
139            tails.insert(t.object.iri.clone());
140        }
141
142        let mut known: Vec<String> = all.into_iter().collect();
143        let mut head_vec: Vec<String> = heads.into_iter().collect();
144        let mut tail_vec: Vec<String> = tails.into_iter().collect();
145        known.sort_unstable();
146        head_vec.sort_unstable();
147        tail_vec.sort_unstable();
148
149        Self {
150            known_entities: known,
151            head_entities: head_vec,
152            tail_entities: tail_vec,
153        }
154    }
155
156    /// Sample `n` negative triples for the given positive `triple`.
157    ///
158    /// * `entity_count` — not used internally (the task uses its own entity pool),
159    ///   but kept in the signature for API compatibility.
160    /// * `strategy` — sampling strategy to apply.
161    ///
162    /// Returns an empty `Vec` when the entity pool is empty.
163    pub fn sample_negatives(
164        &self,
165        triple: &Triple,
166        _entity_count: usize,
167        n: usize,
168        strategy: &NegativeSampler,
169    ) -> Vec<Triple> {
170        if self.known_entities.is_empty() || n == 0 {
171            return Vec::new();
172        }
173
174        // Seed the LCG with a deterministic hash of the positive triple so
175        // different triples produce different negatives.
176        let seed: u64 = triple
177            .subject
178            .iri
179            .bytes()
180            .chain(triple.predicate.iri.bytes())
181            .chain(triple.object.iri.bytes())
182            .enumerate()
183            .fold(0u64, |acc, (i, b)| {
184                acc.wrapping_add((b as u64).wrapping_mul(i as u64 + 1))
185            });
186        let mut rng = Lcg::new(seed);
187
188        match strategy {
189            NegativeSampler::Uniform => self.sample_uniform(triple, n, &mut rng),
190            NegativeSampler::TypeConstrained => self.sample_type_constrained(triple, n, &mut rng),
191            NegativeSampler::SelfAdversarial { temperature } => {
192                self.sample_self_adversarial(triple, n, *temperature, &mut rng)
193            }
194        }
195    }
196
197    // ── private sampling methods ──────────────────────────────────────────────
198
199    fn sample_uniform(&self, triple: &Triple, n: usize, rng: &mut Lcg) -> Vec<Triple> {
200        let pool = &self.known_entities;
201        let mut result = Vec::with_capacity(n);
202        let mut attempts = 0usize;
203        while result.len() < n && attempts < n * 10 {
204            attempts += 1;
205            let idx = rng.next_usize(pool.len());
206            let replacement = &pool[idx];
207            // Randomly choose to corrupt head (0) or tail (1).
208            let neg = if rng.next_usize(2) == 0 {
209                make_triple(replacement, &triple.predicate.iri, &triple.object.iri)
210            } else {
211                make_triple(&triple.subject.iri, &triple.predicate.iri, replacement)
212            };
213            // Only accept if it differs from the positive triple.
214            if is_different(&neg, triple) {
215                result.push(neg);
216            }
217        }
218        result
219    }
220
221    fn sample_type_constrained(&self, triple: &Triple, n: usize, rng: &mut Lcg) -> Vec<Triple> {
222        let heads = if self.head_entities.is_empty() {
223            &self.known_entities
224        } else {
225            &self.head_entities
226        };
227        let tails = if self.tail_entities.is_empty() {
228            &self.known_entities
229        } else {
230            &self.tail_entities
231        };
232
233        let mut result = Vec::with_capacity(n);
234        let mut attempts = 0usize;
235        while result.len() < n && attempts < n * 10 {
236            attempts += 1;
237            let neg = if rng.next_usize(2) == 0 {
238                // Replace head with a type-compatible head entity.
239                let idx = rng.next_usize(heads.len());
240                make_triple(&heads[idx], &triple.predicate.iri, &triple.object.iri)
241            } else {
242                // Replace tail with a type-compatible tail entity.
243                let idx = rng.next_usize(tails.len());
244                make_triple(&triple.subject.iri, &triple.predicate.iri, &tails[idx])
245            };
246            if is_different(&neg, triple) {
247                result.push(neg);
248            }
249        }
250        result
251    }
252
253    fn sample_self_adversarial(
254        &self,
255        triple: &Triple,
256        n: usize,
257        temperature: f64,
258        rng: &mut Lcg,
259    ) -> Vec<Triple> {
260        let pool = &self.known_entities;
261        if pool.is_empty() {
262            return Vec::new();
263        }
264
265        // Assign pseudo-scores as position-based values (simulating model scores
266        // without requiring actual model access).
267        let temp = temperature.max(1e-6);
268        let raw_scores: Vec<f64> = pool
269            .iter()
270            .enumerate()
271            .map(|(i, _)| {
272                // Deterministic mock score — decreases with position.
273                1.0 / (i as f64 + 1.0)
274            })
275            .collect();
276
277        // Softmax with temperature.
278        let max_score = raw_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
279        let exp_scores: Vec<f64> = raw_scores
280            .iter()
281            .map(|s| ((s - max_score) / temp).exp())
282            .collect();
283        let sum_exp: f64 = exp_scores.iter().sum();
284
285        // Build cumulative distribution.
286        let mut cdf: Vec<f64> = Vec::with_capacity(pool.len());
287        let mut cumsum = 0.0_f64;
288        for s in &exp_scores {
289            cumsum += s / sum_exp;
290            cdf.push(cumsum);
291        }
292
293        let mut result = Vec::with_capacity(n);
294        let mut attempts = 0usize;
295        while result.len() < n && attempts < n * 10 {
296            attempts += 1;
297            let u = rng.next_f64();
298            let idx = cdf.iter().position(|&c| u <= c).unwrap_or(pool.len() - 1);
299            let replacement = &pool[idx];
300
301            let neg = if rng.next_usize(2) == 0 {
302                make_triple(replacement, &triple.predicate.iri, &triple.object.iri)
303            } else {
304                make_triple(&triple.subject.iri, &triple.predicate.iri, replacement)
305            };
306            if is_different(&neg, triple) {
307                result.push(neg);
308            }
309        }
310        result
311    }
312}
313
314// ─────────────────────────────────────────────────────────────────────────────
315// BatchedTrainingLoop
316// ─────────────────────────────────────────────────────────────────────────────
317
318/// A prepared batch of positive and negative triples for one training step.
319#[derive(Debug, Clone, Serialize, Deserialize)]
320pub struct TrainingBatch {
321    /// Observed (positive) triples.
322    pub positive_triples: Vec<Triple>,
323    /// Corrupted (negative) triples — `neg_ratio × |positives|` in total.
324    pub negative_triples: Vec<Triple>,
325}
326
327impl TrainingBatch {
328    /// Total number of positive triples in this batch.
329    pub fn positive_count(&self) -> usize {
330        self.positive_triples.len()
331    }
332
333    /// Total number of negative triples in this batch.
334    pub fn negative_count(&self) -> usize {
335        self.negative_triples.len()
336    }
337}
338
339/// Efficient batched training for knowledge graph completion.
340///
341/// Combines negative-sample generation, batch preparation, and loss computation
342/// into a single cohesive API.
343#[derive(Debug, Clone, Default)]
344pub struct BatchedTrainingLoop;
345
346impl BatchedTrainingLoop {
347    /// Create a new `BatchedTrainingLoop`.
348    pub fn new() -> Self {
349        Self
350    }
351
352    /// Prepare a `TrainingBatch` from a slice of positive triples.
353    ///
354    /// For every positive triple, `neg_ratio` negative triples are sampled via
355    /// the given `sampler`.
356    ///
357    /// Returns an error when `positives` is empty.
358    pub fn prepare_batch(
359        &self,
360        task: &KgCompletionTask,
361        positives: &[Triple],
362        neg_ratio: u32,
363        sampler: &NegativeSampler,
364    ) -> Result<TrainingBatch> {
365        if positives.is_empty() {
366            return Err(anyhow!("positives must not be empty"));
367        }
368        let mut negatives = Vec::with_capacity(positives.len() * neg_ratio as usize);
369        for triple in positives {
370            let mut neg_samples = task.sample_negatives(
371                triple,
372                task.known_entities.len(),
373                neg_ratio as usize,
374                sampler,
375            );
376            negatives.append(&mut neg_samples);
377        }
378
379        Ok(TrainingBatch {
380            positive_triples: positives.to_vec(),
381            negative_triples: negatives,
382        })
383    }
384
385    /// Compute the TransE-style margin ranking loss.
386    ///
387    /// ```text
388    /// L = Σ_neg max(0, margin − score_pos + score_neg)
389    /// ```
390    ///
391    /// Higher scores are assumed to be *better* for positive triples.
392    ///
393    /// Returns an error when `pos_scores` or `neg_scores` is empty.
394    pub fn compute_margin_loss(
395        &self,
396        pos_scores: &[f64],
397        neg_scores: &[f64],
398        margin: f64,
399    ) -> Result<f64> {
400        if pos_scores.is_empty() {
401            return Err(anyhow!("pos_scores must not be empty"));
402        }
403        if neg_scores.is_empty() {
404            return Err(anyhow!("neg_scores must not be empty"));
405        }
406
407        // Pair each positive with all negatives (or in round-robin if counts differ).
408        let n_neg = neg_scores.len();
409        let loss: f64 = pos_scores
410            .iter()
411            .enumerate()
412            .flat_map(|(i, &pos)| {
413                // Assign all negatives to each positive when counts differ.
414                neg_scores.iter().enumerate().map(move |(j, &neg)| {
415                    let _ = (i, j); // suppress unused warning
416                    (margin - pos + neg).max(0.0)
417                })
418            })
419            .sum();
420
421        Ok(loss / (pos_scores.len() * n_neg) as f64)
422    }
423
424    /// Compute binary cross-entropy loss over positive and negative scores.
425    ///
426    /// ```text
427    /// L = −(1/N) Σ [log σ(s_pos) + log(1 − σ(s_neg))]
428    /// ```
429    ///
430    /// Returns an error when either input slice is empty.
431    pub fn compute_binary_cross_entropy(
432        &self,
433        pos_scores: &[f64],
434        neg_scores: &[f64],
435    ) -> Result<f64> {
436        if pos_scores.is_empty() {
437            return Err(anyhow!("pos_scores must not be empty"));
438        }
439        if neg_scores.is_empty() {
440            return Err(anyhow!("neg_scores must not be empty"));
441        }
442
443        let sigmoid = |x: f64| 1.0 / (1.0 + (-x).exp());
444        let eps = 1e-12_f64;
445
446        let pos_loss: f64 = pos_scores
447            .iter()
448            .map(|&s| -(sigmoid(s).max(eps).ln()))
449            .sum();
450        let neg_loss: f64 = neg_scores
451            .iter()
452            .map(|&s| -((1.0 - sigmoid(s)).max(eps).ln()))
453            .sum();
454
455        let n = (pos_scores.len() + neg_scores.len()) as f64;
456        Ok((pos_loss + neg_loss) / n)
457    }
458}
459
460// ─────────────────────────────────────────────────────────────────────────────
461// Utility functions
462// ─────────────────────────────────────────────────────────────────────────────
463
464fn make_triple(subject: &str, predicate: &str, object: &str) -> Triple {
465    Triple::new(
466        NamedNode {
467            iri: subject.to_string(),
468        },
469        NamedNode {
470            iri: predicate.to_string(),
471        },
472        NamedNode {
473            iri: object.to_string(),
474        },
475    )
476}
477
478fn is_different(a: &Triple, b: &Triple) -> bool {
479    a.subject.iri != b.subject.iri
480        || a.predicate.iri != b.predicate.iri
481        || a.object.iri != b.object.iri
482}
483
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Ten distinct entity IRIs: `entity_0` … `entity_9`.
    fn sample_entities() -> Vec<String> {
        (0..10).map(|i| format!("entity_{i}")).collect()
    }

    /// A fixed positive triple reused across tests.
    fn sample_triple() -> Triple {
        make_triple("entity_0", "relation_A", "entity_1")
    }

    // ── NegativeSampler / KgCompletionTask ────────────────────────────────────

    #[test]
    fn test_uniform_sampling_returns_correct_count() {
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 5, &NegativeSampler::Uniform);
        assert_eq!(negatives.len(), 5);
    }

    #[test]
    fn test_uniform_negatives_differ_from_positive() {
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 8, &NegativeSampler::Uniform);
        for neg in &negatives {
            assert!(is_different(neg, &positive), "negative == positive");
        }
    }

    #[test]
    fn test_type_constrained_sampling() {
        let entities = sample_entities();
        let heads = vec!["entity_0".into(), "entity_2".into(), "entity_4".into()];
        let tails = vec!["entity_1".into(), "entity_3".into(), "entity_5".into()];
        let task = KgCompletionTask::with_type_constraints(entities, heads.clone(), tails.clone());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 6, &NegativeSampler::TypeConstrained);
        assert!(!negatives.is_empty());
        for neg in &negatives {
            // Every negative must have either a head from the head pool or a tail from the tail pool.
            let head_ok = heads.contains(&neg.subject.iri);
            let tail_ok = tails.contains(&neg.object.iri);
            assert!(
                head_ok || tail_ok,
                "corrupted entity not in allowed pool: {neg:?}"
            );
        }
    }

    #[test]
    fn test_self_adversarial_sampling() {
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(
            &positive,
            10,
            6,
            &NegativeSampler::SelfAdversarial { temperature: 0.5 },
        );
        assert_eq!(negatives.len(), 6);
        for neg in &negatives {
            assert!(is_different(neg, &positive));
        }
    }

    #[test]
    fn test_sampling_empty_entity_pool() {
        // An empty entity pool must yield no negatives rather than panic.
        let task = KgCompletionTask::new(vec![]);
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 0, 5, &NegativeSampler::Uniform);
        assert!(negatives.is_empty());
    }

    #[test]
    fn test_sampling_n_zero() {
        // Requesting zero negatives is a no-op.
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 0, &NegativeSampler::Uniform);
        assert!(negatives.is_empty());
    }

    #[test]
    fn test_from_triples_builds_pools() {
        let triples = vec![
            make_triple("alice", "knows", "bob"),
            make_triple("bob", "knows", "charlie"),
        ];
        let task = KgCompletionTask::from_triples(&triples);
        assert!(task.known_entities.contains(&"alice".to_string()));
        assert!(task.head_entities.contains(&"alice".to_string()));
        assert!(task.tail_entities.contains(&"bob".to_string()));
    }

    // ── BatchedTrainingLoop / prepare_batch ───────────────────────────────────

    #[test]
    fn test_prepare_batch_basic() {
        let task = KgCompletionTask::new(sample_entities());
        let positives = vec![sample_triple()];
        let batch_loop = BatchedTrainingLoop::new();
        let batch = batch_loop
            .prepare_batch(&task, &positives, 3, &NegativeSampler::Uniform)
            .expect("batch");
        assert_eq!(batch.positive_count(), 1);
        // Should have up to 3 negatives (may be fewer if uniqueness is hard to satisfy,
        // but should have at least 1 given 10 available entities).
        assert!(!batch.negative_triples.is_empty());
    }

    #[test]
    fn test_prepare_batch_empty_positives_error() {
        let task = KgCompletionTask::new(sample_entities());
        let batch_loop = BatchedTrainingLoop::new();
        let result = batch_loop.prepare_batch(&task, &[], 3, &NegativeSampler::Uniform);
        assert!(result.is_err());
    }

    #[test]
    fn test_training_batch_counts() {
        let batch = TrainingBatch {
            positive_triples: vec![sample_triple(), sample_triple()],
            negative_triples: vec![sample_triple(); 6],
        };
        assert_eq!(batch.positive_count(), 2);
        assert_eq!(batch.negative_count(), 6);
    }

    // ── BatchedTrainingLoop / compute_margin_loss ─────────────────────────────

    #[test]
    fn test_margin_loss_zero_when_pos_larger() {
        let bl = BatchedTrainingLoop::new();
        // pos=10 >> neg=1, margin=1 → loss = max(0, 1-10+1) = 0
        let loss = bl.compute_margin_loss(&[10.0], &[1.0], 1.0).expect("loss");
        assert!((loss).abs() < 1e-9, "expected 0 loss, got {loss}");
    }

    #[test]
    fn test_margin_loss_positive_when_neg_larger() {
        let bl = BatchedTrainingLoop::new();
        // pos=1, neg=10, margin=1 → loss = max(0, 1-1+10) = 10
        let loss = bl.compute_margin_loss(&[1.0], &[10.0], 1.0).expect("loss");
        assert!(loss > 0.0, "expected positive loss, got {loss}");
    }

    #[test]
    fn test_margin_loss_multiple_pairs() {
        let bl = BatchedTrainingLoop::new();
        let pos = vec![5.0, 5.0];
        let neg = vec![4.0, 3.0];
        // All negatives lower than positive → zero loss
        let loss = bl.compute_margin_loss(&pos, &neg, 1.0).expect("loss");
        assert!((loss).abs() < 1e-9);
    }

    #[test]
    fn test_margin_loss_empty_pos_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_margin_loss(&[], &[1.0], 1.0).is_err());
    }

    #[test]
    fn test_margin_loss_empty_neg_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_margin_loss(&[1.0], &[], 1.0).is_err());
    }

    // ── BatchedTrainingLoop / compute_binary_cross_entropy ────────────────────

    #[test]
    fn test_bce_positive_loss() {
        let bl = BatchedTrainingLoop::new();
        // High positive scores and very negative scores → low loss.
        let loss = bl
            .compute_binary_cross_entropy(&[10.0], &[-10.0])
            .expect("bce");
        assert!(loss < 0.01, "expected near-zero loss, got {loss}");
    }

    #[test]
    fn test_bce_high_loss_when_wrong() {
        let bl = BatchedTrainingLoop::new();
        // Negative score high, positive score low → high loss.
        let loss = bl
            .compute_binary_cross_entropy(&[-10.0], &[10.0])
            .expect("bce");
        assert!(loss > 5.0, "expected high loss, got {loss}");
    }

    #[test]
    fn test_bce_empty_pos_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_binary_cross_entropy(&[], &[1.0]).is_err());
    }

    #[test]
    fn test_bce_empty_neg_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_binary_cross_entropy(&[1.0], &[]).is_err());
    }

    #[test]
    fn test_bce_symmetric_scores_moderate_loss() {
        let bl = BatchedTrainingLoop::new();
        // Zero scores → sigmoid = 0.5, loss = -log(0.5) ≈ 0.693 each.
        let loss = bl
            .compute_binary_cross_entropy(&[0.0], &[0.0])
            .expect("bce");
        assert!(
            (loss - std::f64::consts::LN_2).abs() < 0.001,
            "expected ln(2) ≈ 0.693, got {loss}"
        );
    }

    // ── Serialization ─────────────────────────────────────────────────────────

    #[test]
    fn test_negative_sampler_serialization() {
        let s = NegativeSampler::SelfAdversarial { temperature: 0.5 };
        let json = serde_json::to_string(&s).expect("serialize");
        let s2: NegativeSampler = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(s, s2);
    }

    #[test]
    fn test_training_batch_serialization() {
        let batch = TrainingBatch {
            positive_triples: vec![sample_triple()],
            negative_triples: vec![],
        };
        let json = serde_json::to_string(&batch).expect("serialize");
        let batch2: TrainingBatch = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(batch2.positive_count(), 1);
    }
}
725}