1use anyhow::{anyhow, Result};
19use serde::{Deserialize, Serialize};
20
21use crate::{NamedNode, Triple};
22
/// Minimal deterministic linear congruential generator (LCG) used for
/// reproducible negative sampling; not cryptographically secure.
struct Lcg {
    /// Current generator state; advanced on every draw.
    state: u64,
}

impl Lcg {
    /// Multiplier/increment pair from Knuth's MMIX LCG.
    const MULT: u64 = 6_364_136_223_846_793_005;
    const INC: u64 = 1_442_695_040_888_963_407;

    /// Creates a generator; the seed is mixed with a fixed non-zero
    /// constant so a zero seed does not start from a zero state.
    fn new(seed: u64) -> Self {
        Self {
            state: seed ^ 0x6c62_272e_07bb_0142,
        }
    }

    /// Advances the state one LCG step and returns the new raw value.
    fn step(&mut self) -> u64 {
        self.state = self
            .state
            .wrapping_mul(Self::MULT)
            .wrapping_add(Self::INC);
        self.state
    }

    /// Returns a pseudo-random index in `0..modulus`.
    ///
    /// Uses the high bits (`>> 33`), which have better statistical quality
    /// than the low bits of an LCG. Panics if `modulus == 0`; callers must
    /// guarantee a non-empty range.
    fn next_usize(&mut self, modulus: usize) -> usize {
        debug_assert!(modulus > 0, "modulus must be non-zero");
        ((self.step() >> 33) as usize) % modulus
    }

    /// Returns a pseudo-random f64 in `[0.0, 1.0)` built from the top
    /// 53 bits of the state (the full mantissa width of an f64).
    fn next_f64(&mut self) -> f64 {
        (self.step() >> 11) as f64 / (1u64 << 53) as f64
    }
}
60
/// Strategy used to generate corrupted (negative) triples during training.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NegativeSampler {
    /// Replace the head or tail with an entity drawn uniformly from the
    /// full known-entity pool.
    Uniform,
    /// Replace the head (resp. tail) only with entities from the head
    /// (resp. tail) candidate pool.
    TypeConstrained,
    /// Draw replacements from a temperature-scaled softmax distribution
    /// (currently over rank-based placeholder scores, not model scores).
    SelfAdversarial {
        /// Softmax temperature; smaller values sharpen the distribution.
        /// Clamped to at least 1e-6 at sampling time.
        temperature: f64,
    },
}
81
/// Knowledge-graph completion task: holds the entity pools consulted when
/// sampling negative (corrupted) triples for a positive example.
#[derive(Debug, Clone)]
pub struct KgCompletionTask {
    /// Every IRI eligible as a replacement entity (as built by
    /// `from_triples`, this also includes predicate IRIs).
    known_entities: Vec<String>,
    /// Candidates allowed in the subject (head) position.
    head_entities: Vec<String>,
    /// Candidates allowed in the object (tail) position.
    tail_entities: Vec<String>,
}
100
101impl KgCompletionTask {
102 pub fn new(known_entities: Vec<String>) -> Self {
104 let head_entities = known_entities.clone();
105 let tail_entities = known_entities.clone();
106 Self {
107 known_entities,
108 head_entities,
109 tail_entities,
110 }
111 }
112
113 pub fn with_type_constraints(
115 known_entities: Vec<String>,
116 head_entities: Vec<String>,
117 tail_entities: Vec<String>,
118 ) -> Self {
119 Self {
120 known_entities,
121 head_entities,
122 tail_entities,
123 }
124 }
125
126 pub fn from_triples(triples: &[Triple]) -> Self {
130 let mut all: std::collections::HashSet<String> = std::collections::HashSet::new();
131 let mut heads: std::collections::HashSet<String> = std::collections::HashSet::new();
132 let mut tails: std::collections::HashSet<String> = std::collections::HashSet::new();
133
134 for t in triples {
135 all.insert(t.subject.iri.clone());
136 all.insert(t.predicate.iri.clone());
137 all.insert(t.object.iri.clone());
138 heads.insert(t.subject.iri.clone());
139 tails.insert(t.object.iri.clone());
140 }
141
142 let mut known: Vec<String> = all.into_iter().collect();
143 let mut head_vec: Vec<String> = heads.into_iter().collect();
144 let mut tail_vec: Vec<String> = tails.into_iter().collect();
145 known.sort_unstable();
146 head_vec.sort_unstable();
147 tail_vec.sort_unstable();
148
149 Self {
150 known_entities: known,
151 head_entities: head_vec,
152 tail_entities: tail_vec,
153 }
154 }
155
156 pub fn sample_negatives(
164 &self,
165 triple: &Triple,
166 _entity_count: usize,
167 n: usize,
168 strategy: &NegativeSampler,
169 ) -> Vec<Triple> {
170 if self.known_entities.is_empty() || n == 0 {
171 return Vec::new();
172 }
173
174 let seed: u64 = triple
177 .subject
178 .iri
179 .bytes()
180 .chain(triple.predicate.iri.bytes())
181 .chain(triple.object.iri.bytes())
182 .enumerate()
183 .fold(0u64, |acc, (i, b)| {
184 acc.wrapping_add((b as u64).wrapping_mul(i as u64 + 1))
185 });
186 let mut rng = Lcg::new(seed);
187
188 match strategy {
189 NegativeSampler::Uniform => self.sample_uniform(triple, n, &mut rng),
190 NegativeSampler::TypeConstrained => self.sample_type_constrained(triple, n, &mut rng),
191 NegativeSampler::SelfAdversarial { temperature } => {
192 self.sample_self_adversarial(triple, n, *temperature, &mut rng)
193 }
194 }
195 }
196
197 fn sample_uniform(&self, triple: &Triple, n: usize, rng: &mut Lcg) -> Vec<Triple> {
200 let pool = &self.known_entities;
201 let mut result = Vec::with_capacity(n);
202 let mut attempts = 0usize;
203 while result.len() < n && attempts < n * 10 {
204 attempts += 1;
205 let idx = rng.next_usize(pool.len());
206 let replacement = &pool[idx];
207 let neg = if rng.next_usize(2) == 0 {
209 make_triple(replacement, &triple.predicate.iri, &triple.object.iri)
210 } else {
211 make_triple(&triple.subject.iri, &triple.predicate.iri, replacement)
212 };
213 if is_different(&neg, triple) {
215 result.push(neg);
216 }
217 }
218 result
219 }
220
221 fn sample_type_constrained(&self, triple: &Triple, n: usize, rng: &mut Lcg) -> Vec<Triple> {
222 let heads = if self.head_entities.is_empty() {
223 &self.known_entities
224 } else {
225 &self.head_entities
226 };
227 let tails = if self.tail_entities.is_empty() {
228 &self.known_entities
229 } else {
230 &self.tail_entities
231 };
232
233 let mut result = Vec::with_capacity(n);
234 let mut attempts = 0usize;
235 while result.len() < n && attempts < n * 10 {
236 attempts += 1;
237 let neg = if rng.next_usize(2) == 0 {
238 let idx = rng.next_usize(heads.len());
240 make_triple(&heads[idx], &triple.predicate.iri, &triple.object.iri)
241 } else {
242 let idx = rng.next_usize(tails.len());
244 make_triple(&triple.subject.iri, &triple.predicate.iri, &tails[idx])
245 };
246 if is_different(&neg, triple) {
247 result.push(neg);
248 }
249 }
250 result
251 }
252
253 fn sample_self_adversarial(
254 &self,
255 triple: &Triple,
256 n: usize,
257 temperature: f64,
258 rng: &mut Lcg,
259 ) -> Vec<Triple> {
260 let pool = &self.known_entities;
261 if pool.is_empty() {
262 return Vec::new();
263 }
264
265 let temp = temperature.max(1e-6);
268 let raw_scores: Vec<f64> = pool
269 .iter()
270 .enumerate()
271 .map(|(i, _)| {
272 1.0 / (i as f64 + 1.0)
274 })
275 .collect();
276
277 let max_score = raw_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
279 let exp_scores: Vec<f64> = raw_scores
280 .iter()
281 .map(|s| ((s - max_score) / temp).exp())
282 .collect();
283 let sum_exp: f64 = exp_scores.iter().sum();
284
285 let mut cdf: Vec<f64> = Vec::with_capacity(pool.len());
287 let mut cumsum = 0.0_f64;
288 for s in &exp_scores {
289 cumsum += s / sum_exp;
290 cdf.push(cumsum);
291 }
292
293 let mut result = Vec::with_capacity(n);
294 let mut attempts = 0usize;
295 while result.len() < n && attempts < n * 10 {
296 attempts += 1;
297 let u = rng.next_f64();
298 let idx = cdf.iter().position(|&c| u <= c).unwrap_or(pool.len() - 1);
299 let replacement = &pool[idx];
300
301 let neg = if rng.next_usize(2) == 0 {
302 make_triple(replacement, &triple.predicate.iri, &triple.object.iri)
303 } else {
304 make_triple(&triple.subject.iri, &triple.predicate.iri, replacement)
305 };
306 if is_different(&neg, triple) {
307 result.push(neg);
308 }
309 }
310 result
311 }
312}
313
/// A single training batch: positive triples plus their sampled negatives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingBatch {
    /// Ground-truth triples for this batch.
    pub positive_triples: Vec<Triple>,
    /// Corrupted triples produced by a `NegativeSampler`.
    pub negative_triples: Vec<Triple>,
}

impl TrainingBatch {
    /// Number of positive triples in the batch.
    pub fn positive_count(&self) -> usize {
        self.positive_triples.len()
    }

    /// Number of negative triples in the batch.
    pub fn negative_count(&self) -> usize {
        self.negative_triples.len()
    }
}
338
/// Stateless helper that assembles training batches and computes losses.
#[derive(Debug, Clone, Default)]
pub struct BatchedTrainingLoop;
345
346impl BatchedTrainingLoop {
347 pub fn new() -> Self {
349 Self
350 }
351
352 pub fn prepare_batch(
359 &self,
360 task: &KgCompletionTask,
361 positives: &[Triple],
362 neg_ratio: u32,
363 sampler: &NegativeSampler,
364 ) -> Result<TrainingBatch> {
365 if positives.is_empty() {
366 return Err(anyhow!("positives must not be empty"));
367 }
368 let mut negatives = Vec::with_capacity(positives.len() * neg_ratio as usize);
369 for triple in positives {
370 let mut neg_samples = task.sample_negatives(
371 triple,
372 task.known_entities.len(),
373 neg_ratio as usize,
374 sampler,
375 );
376 negatives.append(&mut neg_samples);
377 }
378
379 Ok(TrainingBatch {
380 positive_triples: positives.to_vec(),
381 negative_triples: negatives,
382 })
383 }
384
385 pub fn compute_margin_loss(
395 &self,
396 pos_scores: &[f64],
397 neg_scores: &[f64],
398 margin: f64,
399 ) -> Result<f64> {
400 if pos_scores.is_empty() {
401 return Err(anyhow!("pos_scores must not be empty"));
402 }
403 if neg_scores.is_empty() {
404 return Err(anyhow!("neg_scores must not be empty"));
405 }
406
407 let n_neg = neg_scores.len();
409 let loss: f64 = pos_scores
410 .iter()
411 .enumerate()
412 .flat_map(|(i, &pos)| {
413 neg_scores.iter().enumerate().map(move |(j, &neg)| {
415 let _ = (i, j); (margin - pos + neg).max(0.0)
417 })
418 })
419 .sum();
420
421 Ok(loss / (pos_scores.len() * n_neg) as f64)
422 }
423
424 pub fn compute_binary_cross_entropy(
432 &self,
433 pos_scores: &[f64],
434 neg_scores: &[f64],
435 ) -> Result<f64> {
436 if pos_scores.is_empty() {
437 return Err(anyhow!("pos_scores must not be empty"));
438 }
439 if neg_scores.is_empty() {
440 return Err(anyhow!("neg_scores must not be empty"));
441 }
442
443 let sigmoid = |x: f64| 1.0 / (1.0 + (-x).exp());
444 let eps = 1e-12_f64;
445
446 let pos_loss: f64 = pos_scores
447 .iter()
448 .map(|&s| -(sigmoid(s).max(eps).ln()))
449 .sum();
450 let neg_loss: f64 = neg_scores
451 .iter()
452 .map(|&s| -((1.0 - sigmoid(s)).max(eps).ln()))
453 .sum();
454
455 let n = (pos_scores.len() + neg_scores.len()) as f64;
456 Ok((pos_loss + neg_loss) / n)
457 }
458}
459
460fn make_triple(subject: &str, predicate: &str, object: &str) -> Triple {
465 Triple::new(
466 NamedNode {
467 iri: subject.to_string(),
468 },
469 NamedNode {
470 iri: predicate.to_string(),
471 },
472 NamedNode {
473 iri: object.to_string(),
474 },
475 )
476}
477
478fn is_different(a: &Triple, b: &Triple) -> bool {
479 a.subject.iri != b.subject.iri
480 || a.predicate.iri != b.predicate.iri
481 || a.object.iri != b.object.iri
482}
483
#[cfg(test)]
mod tests {
    //! Unit tests for negative sampling, batch preparation, losses, and
    //! serde round-trips. Sampling is deterministic (seed derived from the
    //! triple), so exact counts can be asserted.
    use super::*;

    /// Ten entities named `entity_0` .. `entity_9`.
    fn sample_entities() -> Vec<String> {
        (0..10).map(|i| format!("entity_{i}")).collect()
    }

    /// Fixed positive triple drawn from the sample entity pool.
    fn sample_triple() -> Triple {
        make_triple("entity_0", "relation_A", "entity_1")
    }

    // --- sampling ---

    #[test]
    fn test_uniform_sampling_returns_correct_count() {
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 5, &NegativeSampler::Uniform);
        assert_eq!(negatives.len(), 5);
    }

    #[test]
    fn test_uniform_negatives_differ_from_positive() {
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 8, &NegativeSampler::Uniform);
        for neg in &negatives {
            assert!(is_different(neg, &positive), "negative == positive");
        }
    }

    #[test]
    fn test_type_constrained_sampling() {
        let entities = sample_entities();
        let heads = vec!["entity_0".into(), "entity_2".into(), "entity_4".into()];
        let tails = vec!["entity_1".into(), "entity_3".into(), "entity_5".into()];
        let task = KgCompletionTask::with_type_constraints(entities, heads.clone(), tails.clone());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 6, &NegativeSampler::TypeConstrained);
        assert!(!negatives.is_empty());
        for neg in &negatives {
            // Only one side is corrupted per sample, so it suffices that
            // either the head or the tail comes from its allowed pool.
            let head_ok = heads.contains(&neg.subject.iri);
            let tail_ok = tails.contains(&neg.object.iri);
            assert!(
                head_ok || tail_ok,
                "corrupted entity not in allowed pool: {neg:?}"
            );
        }
    }

    #[test]
    fn test_self_adversarial_sampling() {
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(
            &positive,
            10,
            6,
            &NegativeSampler::SelfAdversarial { temperature: 0.5 },
        );
        assert_eq!(negatives.len(), 6);
        for neg in &negatives {
            assert!(is_different(neg, &positive));
        }
    }

    #[test]
    fn test_sampling_empty_entity_pool() {
        let task = KgCompletionTask::new(vec![]);
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 0, 5, &NegativeSampler::Uniform);
        assert!(negatives.is_empty());
    }

    #[test]
    fn test_sampling_n_zero() {
        let task = KgCompletionTask::new(sample_entities());
        let positive = sample_triple();
        let negatives = task.sample_negatives(&positive, 10, 0, &NegativeSampler::Uniform);
        assert!(negatives.is_empty());
    }

    #[test]
    fn test_from_triples_builds_pools() {
        let triples = vec![
            make_triple("alice", "knows", "bob"),
            make_triple("bob", "knows", "charlie"),
        ];
        let task = KgCompletionTask::from_triples(&triples);
        assert!(task.known_entities.contains(&"alice".to_string()));
        assert!(task.head_entities.contains(&"alice".to_string()));
        assert!(task.tail_entities.contains(&"bob".to_string()));
    }

    // --- batch preparation ---

    #[test]
    fn test_prepare_batch_basic() {
        let task = KgCompletionTask::new(sample_entities());
        let positives = vec![sample_triple()];
        let batch_loop = BatchedTrainingLoop::new();
        let batch = batch_loop
            .prepare_batch(&task, &positives, 3, &NegativeSampler::Uniform)
            .expect("batch");
        assert_eq!(batch.positive_count(), 1);
        assert!(!batch.negative_triples.is_empty());
    }

    #[test]
    fn test_prepare_batch_empty_positives_error() {
        let task = KgCompletionTask::new(sample_entities());
        let batch_loop = BatchedTrainingLoop::new();
        let result = batch_loop.prepare_batch(&task, &[], 3, &NegativeSampler::Uniform);
        assert!(result.is_err());
    }

    #[test]
    fn test_training_batch_counts() {
        let batch = TrainingBatch {
            positive_triples: vec![sample_triple(), sample_triple()],
            negative_triples: vec![sample_triple(); 6],
        };
        assert_eq!(batch.positive_count(), 2);
        assert_eq!(batch.negative_count(), 6);
    }

    // --- margin ranking loss ---

    #[test]
    fn test_margin_loss_zero_when_pos_larger() {
        let bl = BatchedTrainingLoop::new();
        let loss = bl.compute_margin_loss(&[10.0], &[1.0], 1.0).expect("loss");
        assert!((loss).abs() < 1e-9, "expected 0 loss, got {loss}");
    }

    #[test]
    fn test_margin_loss_positive_when_neg_larger() {
        let bl = BatchedTrainingLoop::new();
        let loss = bl.compute_margin_loss(&[1.0], &[10.0], 1.0).expect("loss");
        assert!(loss > 0.0, "expected positive loss, got {loss}");
    }

    #[test]
    fn test_margin_loss_multiple_pairs() {
        let bl = BatchedTrainingLoop::new();
        // Every pos beats every neg by at least the margin of 1.0,
        // so the hinge is zero for all four pairs.
        let pos = vec![5.0, 5.0];
        let neg = vec![4.0, 3.0];
        let loss = bl.compute_margin_loss(&pos, &neg, 1.0).expect("loss");
        assert!((loss).abs() < 1e-9);
    }

    #[test]
    fn test_margin_loss_empty_pos_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_margin_loss(&[], &[1.0], 1.0).is_err());
    }

    #[test]
    fn test_margin_loss_empty_neg_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_margin_loss(&[1.0], &[], 1.0).is_err());
    }

    // --- binary cross-entropy ---

    #[test]
    fn test_bce_positive_loss() {
        let bl = BatchedTrainingLoop::new();
        let loss = bl
            .compute_binary_cross_entropy(&[10.0], &[-10.0])
            .expect("bce");
        assert!(loss < 0.01, "expected near-zero loss, got {loss}");
    }

    #[test]
    fn test_bce_high_loss_when_wrong() {
        let bl = BatchedTrainingLoop::new();
        let loss = bl
            .compute_binary_cross_entropy(&[-10.0], &[10.0])
            .expect("bce");
        assert!(loss > 5.0, "expected high loss, got {loss}");
    }

    #[test]
    fn test_bce_empty_pos_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_binary_cross_entropy(&[], &[1.0]).is_err());
    }

    #[test]
    fn test_bce_empty_neg_error() {
        let bl = BatchedTrainingLoop::new();
        assert!(bl.compute_binary_cross_entropy(&[1.0], &[]).is_err());
    }

    #[test]
    fn test_bce_symmetric_scores_moderate_loss() {
        let bl = BatchedTrainingLoop::new();
        // sigmoid(0) = 0.5 on both sides, so the loss is -ln(0.5) = ln(2).
        let loss = bl
            .compute_binary_cross_entropy(&[0.0], &[0.0])
            .expect("bce");
        assert!(
            (loss - std::f64::consts::LN_2).abs() < 0.001,
            "expected ln(2) ≈ 0.693, got {loss}"
        );
    }

    // --- serde round-trips ---

    #[test]
    fn test_negative_sampler_serialization() {
        let s = NegativeSampler::SelfAdversarial { temperature: 0.5 };
        let json = serde_json::to_string(&s).expect("serialize");
        let s2: NegativeSampler = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(s, s2);
    }

    #[test]
    fn test_training_batch_serialization() {
        let batch = TrainingBatch {
            positive_triples: vec![sample_triple()],
            negative_triples: vec![],
        };
        let json = serde_json::to_string(&batch).expect("serialize");
        let batch2: TrainingBatch = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(batch2.positive_count(), 1);
    }
}