use crate::dataset::Triple;
#[cfg(feature = "rand")]
use rand::Rng;
use std::collections::{HashMap, HashSet};
use super::NegativeSamplingStrategy;
/// Generates up to `n` corrupted copies of `triple` using the thread-local RNG.
///
/// This is a convenience wrapper around [`generate_negative_samples_with_rng`]
/// that draws randomness from `rand::rng()`. The RNG is not seeded, so repeated
/// calls produce different samples; use the `_with_rng` variant with a seeded
/// RNG when reproducibility matters.
///
/// Candidates identical to `triple` are discarded, so the result may contain
/// fewer than `n` triples (and is empty when `entities` is empty).
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[cfg(feature = "rand")]
pub fn generate_negative_samples(
    triple: &Triple,
    entities: &HashSet<String>,
    strategy: &NegativeSamplingStrategy,
    n: usize,
) -> Vec<Triple> {
    generate_negative_samples_with_rng(triple, entities, strategy, n, &mut rand::rng())
}
/// Generates up to `n` corrupted copies of `triple`, drawing from `rng`.
///
/// The entity set is first turned into a [`SortedEntityPool`], so the mapping
/// from RNG output to entity does not depend on `HashSet` iteration order. A
/// seeded `rng` therefore gives reproducible results.
///
/// Building the pool sorts `entities` on every call. When sampling for many
/// triples, build the pool once and call
/// [`generate_negative_samples_from_sorted_pool_with_rng`] directly.
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[cfg(feature = "rand")]
pub fn generate_negative_samples_with_rng<R: Rng>(
    triple: &Triple,
    entities: &HashSet<String>,
    strategy: &NegativeSamplingStrategy,
    n: usize,
    rng: &mut R,
) -> Vec<Triple> {
    generate_negative_samples_from_sorted_pool_with_rng(
        triple,
        &SortedEntityPool::new(entities),
        strategy,
        n,
        rng,
    )
}
/// A lexicographically sorted, borrowed view over a set of entity names.
///
/// `HashSet` iteration order is arbitrary. Sorting once gives every entity a
/// stable index, so a seeded RNG picks the same entities regardless of how the
/// source set happens to iterate.
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[cfg(feature = "rand")]
#[derive(Debug, Clone)]
pub struct SortedEntityPool<'a> {
    /// Entity names borrowed from the source set, in ascending order.
    pub(crate) entities: Vec<&'a str>,
}

#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[cfg(feature = "rand")]
impl<'a> SortedEntityPool<'a> {
    /// Builds a pool that borrows every name in `entities`, sorted ascending.
    pub fn new(entities: &'a HashSet<String>) -> Self {
        let mut sorted: Vec<&'a str> = Vec::with_capacity(entities.len());
        sorted.extend(entities.iter().map(String::as_str));
        // Names come from a set and are therefore unique, so an unstable sort
        // produces exactly the same order as a stable one.
        sorted.sort_unstable();
        SortedEntityPool { entities: sorted }
    }

    /// Returns an entity chosen uniformly at random from the pool.
    ///
    /// # Panics
    ///
    /// Panics if the pool is empty (the sampled index range would be empty);
    /// callers must check emptiness first.
    #[inline]
    pub(crate) fn pick<R: Rng>(&self, rng: &mut R) -> &'a str {
        self.entities[rng.random_range(0..self.entities.len())]
    }
}
/// Generates up to `n` corrupted copies of `triple` from a prebuilt entity pool.
///
/// Each of the `n` draws replaces the head and/or tail with an entity picked
/// uniformly from `entity_pool`, according to `strategy`:
///
/// * `Uniform` flips a fair coin per draw and corrupts either the head or the tail.
/// * `CorruptHead` / `CorruptTail` always replace that one side.
/// * `CorruptBoth` replaces the head and then the tail independently.
///
/// The relation is always preserved. Draws equal to `triple` are discarded, so
/// fewer than `n` triples may be returned. An empty pool yields an empty vector.
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[cfg(feature = "rand")]
pub fn generate_negative_samples_from_sorted_pool_with_rng<R: Rng>(
    triple: &Triple,
    entity_pool: &SortedEntityPool<'_>,
    strategy: &NegativeSamplingStrategy,
    n: usize,
    rng: &mut R,
) -> Vec<Triple> {
    let mut negatives = Vec::with_capacity(n);
    if entity_pool.entities.is_empty() {
        return negatives;
    }
    let draws = (0..n).map(|_| {
        // Decide which side(s) to corrupt. Only `Uniform` consumes an RNG draw
        // here (the coin flip), and it does so before any entity is picked.
        let (replace_head, replace_tail) = match strategy {
            NegativeSamplingStrategy::Uniform => {
                let heads = rng.random::<bool>();
                (heads, !heads)
            }
            NegativeSamplingStrategy::CorruptHead => (true, false),
            NegativeSamplingStrategy::CorruptTail => (false, true),
            NegativeSamplingStrategy::CorruptBoth => (true, true),
        };
        // The head is always picked before the tail, keeping the RNG stream
        // consumption order fixed for reproducibility.
        let head = if replace_head {
            entity_pool.pick(rng).to_string()
        } else {
            triple.head.clone()
        };
        let tail = if replace_tail {
            entity_pool.pick(rng).to_string()
        } else {
            triple.tail.clone()
        };
        Triple {
            head,
            relation: triple.relation.clone(),
            tail,
        }
    });
    negatives.extend(draws.filter(|candidate| candidate != triple));
    negatives
}
/// Generates corrupted index triples, sampling replacements by smoothed degree.
///
/// Each entity id in `entity_degrees` is drawn with probability proportional to
/// `degree ^ DEGREE_SMOOTHING_EXPONENT`. For every `(head, relation, tail)` in
/// `positive_triples`, `num_negatives` draws are made. Each draw flips a fair
/// coin to replace either the head or the tail with the sampled entity, and the
/// relation is kept.
///
/// Draws equal to their own positive triple are discarded. Other positives are
/// not filtered out. Entities with degree `0` get zero weight and are never
/// sampled.
///
/// Returns an empty vector in two cases:
///
/// * any input is empty or `num_negatives` is `0`;
/// * the weights cannot form a distribution (e.g. every degree is zero).
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[cfg(feature = "rand")]
pub fn generate_degree_weighted_negatives(
    positive_triples: &[(usize, usize, usize)],
    entity_degrees: &HashMap<usize, usize>,
    num_negatives: usize,
    rng: &mut impl rand::Rng,
) -> Vec<(usize, usize, usize)> {
    use rand::distr::weighted::WeightedIndex;
    use rand::distr::Distribution;
    let nothing_to_do =
        entity_degrees.is_empty() || positive_triples.is_empty() || num_negatives == 0;
    if nothing_to_do {
        return Vec::new();
    }
    // Sort by id so that sampled index -> entity id is independent of the
    // HashMap's arbitrary iteration order (reproducible with a seeded RNG).
    let mut by_id: Vec<(usize, usize)> = entity_degrees
        .iter()
        .map(|(&id, &degree)| (id, degree))
        .collect();
    by_id.sort_unstable_by_key(|&(id, _)| id);
    let (entities, weights): (Vec<usize>, Vec<f64>) = by_id
        .into_iter()
        .map(|(id, degree)| (id, (degree as f64).powf(DEGREE_SMOOTHING_EXPONENT)))
        .unzip();
    let sampler = match WeightedIndex::new(&weights) {
        Ok(sampler) => sampler,
        // e.g. every weight is zero: there is nothing to sample from.
        Err(_) => return Vec::new(),
    };
    let mut negatives = Vec::with_capacity(positive_triples.len() * num_negatives);
    for &(head, relation, tail) in positive_triples {
        let positive = (head, relation, tail);
        for _ in 0..num_negatives {
            // The coin flip is drawn before the entity; keep this order for
            // reproducible streams.
            let corrupt_head = rng.random::<bool>();
            let replacement = entities[sampler.sample(rng)];
            let negative = if corrupt_head {
                (replacement, relation, tail)
            } else {
                (head, relation, replacement)
            };
            if negative != positive {
                negatives.push(negative);
            }
        }
    }
    negatives
}
/// Generates tail-corrupted negatives by self-adversarial sampling.
///
/// For each positive `(h, r, t)`, `num_negatives` replacement tails are drawn
/// from `entity_ids`. Candidate `c` is chosen with probability proportional to
/// `exp(scores_fn(h, c) / temperature)`, a softmax over candidate scores, so
/// higher-scoring candidates are drawn more often. `scores_fn` receives the
/// head and the candidate tail; the relation is not passed to it. The positive
/// tail `t` always has zero probability under the softmax. `temperature`
/// should be positive; it is not validated.
///
/// # Numerical behaviour
///
/// The softmax subtracts the largest finite logit before exponentiating.
/// Large scores or a small `temperature` therefore do not overflow to `inf`,
/// and very negative scores do not underflow to all-zero weights.
///
/// * Candidates whose logit is `NaN` get zero weight.
/// * If any logit is `+inf`, the probability mass is split evenly across those
///   candidates.
/// * If no candidate ends up with positive weight (e.g. `t` is the only
///   candidate, or every score is `NaN`), tails are drawn uniformly from
///   `entity_ids` instead. Draws equal to `t` are then discarded, so fewer than
///   `num_negatives` negatives may be produced for that positive.
///
/// Returns an empty vector when `positive_triples` or `entity_ids` is empty, or
/// when `num_negatives` is zero.
#[cfg_attr(docsrs, doc(cfg(feature = "rand")))]
#[cfg(feature = "rand")]
pub fn generate_self_adversarial_negatives<F>(
    positive_triples: &[(usize, usize, usize)],
    scores_fn: F,
    entity_ids: &[usize],
    temperature: f32,
    num_negatives: usize,
    rng: &mut impl rand::Rng,
) -> Vec<(usize, usize, usize)>
where
    F: Fn(usize, usize) -> f32,
{
    use rand::distr::weighted::WeightedIndex;
    use rand::distr::Distribution;
    if entity_ids.is_empty() || positive_triples.is_empty() || num_negatives == 0 {
        return Vec::new();
    }
    let temperature = f64::from(temperature);
    let mut negatives = Vec::with_capacity(positive_triples.len() * num_negatives);
    for &(h, r, t) in positive_triples {
        // The positive tail gets a logit of -inf, i.e. exactly zero weight;
        // `scores_fn` is only invoked for the other candidates.
        let logits: Vec<f64> = entity_ids
            .iter()
            .map(|&candidate| {
                if candidate == t {
                    f64::NEG_INFINITY
                } else {
                    f64::from(scores_fn(h, candidate)) / temperature
                }
            })
            .collect();
        let weights = self_adversarial_weights(&logits);
        let total: f64 = weights.iter().sum();
        if total <= 0.0 {
            // No candidate carries weight: fall back to uniform sampling.
            for _ in 0..num_negatives {
                let candidate = entity_ids[rng.random_range(0..entity_ids.len())];
                if candidate != t {
                    negatives.push((h, r, candidate));
                }
            }
            continue;
        }
        // Every weight is finite, in [0, 1], with a positive sum, so this
        // cannot fail; the `continue` is purely defensive.
        let dist = match WeightedIndex::new(&weights) {
            Ok(d) => d,
            Err(_) => continue,
        };
        for _ in 0..num_negatives {
            let candidate = entity_ids[dist.sample(rng)];
            if candidate != t {
                negatives.push((h, r, candidate));
            }
        }
    }
    negatives
}

/// Turns softmax logits into unnormalised sampling weights without overflow.
///
/// * If any logit is `+inf`, each `+inf` entry gets weight `1.0` and every
///   other entry gets `0.0` (the limiting softmax distribution).
/// * Otherwise each finite logit maps to `exp(logit - max_finite_logit)`, which
///   lies in `[0, 1]` and equals `1.0` for the maximum. `NaN` and `-inf`
///   logits map to `0.0`.
#[cfg(feature = "rand")]
fn self_adversarial_weights(logits: &[f64]) -> Vec<f64> {
    if logits.contains(&f64::INFINITY) {
        return logits
            .iter()
            .map(|&logit| if logit == f64::INFINITY { 1.0 } else { 0.0 })
            .collect();
    }
    let max_finite = logits
        .iter()
        .copied()
        .filter(|logit| logit.is_finite())
        .fold(f64::NEG_INFINITY, f64::max);
    logits
        .iter()
        .map(|&logit| {
            if logit.is_finite() {
                // logit <= max_finite, so the exponent is <= 0 and cannot overflow.
                (logit - max_finite).exp()
            } else {
                0.0
            }
        })
        .collect()
}
/// Exponent applied to entity degrees by [`generate_degree_weighted_negatives`].
///
/// An entity of degree `d` is sampled with weight `d^0.75`. Because the
/// exponent is below 1, the distribution is flatter than raw degree.
#[cfg(feature = "rand")]
const DEGREE_SMOOTHING_EXPONENT: f64 = 3.0 / 4.0;
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "rand")]
    use proptest::prelude::*;
    #[cfg(feature = "rand")]
    use proptest::proptest;
    #[cfg(feature = "rand")]
    use std::collections::HashSet;

    /// Number of draws for tests that use the unseeded thread RNG.
    ///
    /// With `k` candidates, every draw reproduces the positive (and is filtered
    /// out) with probability at most `(1/k)^UNSEEDED_DRAWS`. With 32 draws that
    /// is below 1e-19 for the 4- and 5-entity fixtures below, so the
    /// non-emptiness assertions are not flaky.
    #[cfg(feature = "rand")]
    const UNSEEDED_DRAWS: usize = 32;

    #[test]
    #[cfg(feature = "rand")]
    fn test_generate_negative_samples() {
        let triple = Triple {
            head: "e1".to_string(),
            relation: "r1".to_string(),
            tail: "e2".to_string(),
        };
        let entities: HashSet<String> = ["e1", "e2", "e3", "e4"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let negatives = generate_negative_samples(
            &triple,
            &entities,
            &NegativeSamplingStrategy::CorruptTail,
            UNSEEDED_DRAWS,
        );
        assert!(
            !negatives.is_empty(),
            "Expected at least 1 negative, got {}",
            negatives.len()
        );
        for neg in negatives {
            assert_eq!(neg.head, "e1");
            assert_eq!(neg.relation, "r1");
            assert_ne!(neg.tail, "e2");
        }
    }

    #[cfg(feature = "rand")]
    proptest! {
        #[test]
        fn prop_generate_negative_samples_with_rng_is_deterministic(seed in any::<u64>()) {
            use rand::SeedableRng;
            use rand::rngs::StdRng;
            let triple = Triple {
                head: "e1".to_string(),
                relation: "r1".to_string(),
                tail: "e2".to_string(),
            };
            let entities: HashSet<String> = ["e1", "e2", "e3", "e4"]
                .iter()
                .map(|s| s.to_string())
                .collect();
            let mut rng1 = StdRng::seed_from_u64(seed);
            let mut rng2 = StdRng::seed_from_u64(seed);
            let a = generate_negative_samples_with_rng(
                &triple,
                &entities,
                &NegativeSamplingStrategy::Uniform,
                25,
                &mut rng1,
            );
            let b = generate_negative_samples_with_rng(
                &triple,
                &entities,
                &NegativeSamplingStrategy::Uniform,
                25,
                &mut rng2,
            );
            prop_assert_eq!(a, b);
        }
    }

    #[test]
    #[cfg(feature = "rand")]
    fn test_generate_negative_samples_all_strategies() {
        let triple = Triple {
            head: "e1".to_string(),
            relation: "r1".to_string(),
            tail: "e2".to_string(),
        };
        let entities: HashSet<String> = ["e1", "e2", "e3", "e4", "e5"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        for strategy in [
            NegativeSamplingStrategy::Uniform,
            NegativeSamplingStrategy::CorruptHead,
            NegativeSamplingStrategy::CorruptTail,
            NegativeSamplingStrategy::CorruptBoth,
        ] {
            let negatives =
                generate_negative_samples(&triple, &entities, &strategy, UNSEEDED_DRAWS);
            assert!(
                !negatives.is_empty(),
                "Strategy {:?} should generate negatives",
                strategy
            );
            for neg in &negatives {
                assert_ne!(neg, &triple, "Negative should differ from positive");
            }
        }
    }

    #[test]
    #[cfg(feature = "rand")]
    fn sorted_entity_pool_is_sorted_and_stable() {
        let entities: HashSet<String> = ["charlie", "alice", "bob", "delta"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let pool = SortedEntityPool::new(&entities);
        let sorted: Vec<&str> = pool.entities.clone();
        let mut expected = sorted.clone();
        expected.sort();
        assert_eq!(sorted, expected, "pool must be sorted lexicographically");
    }

    #[test]
    #[cfg(feature = "rand")]
    fn sorted_entity_pool_pick_returns_pool_member() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;
        let entities: HashSet<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let pool = SortedEntityPool::new(&entities);
        let mut rng = StdRng::seed_from_u64(42);
        for _ in 0..20 {
            let picked = pool.pick(&mut rng);
            assert!(
                entities.contains(picked),
                "picked entity '{picked}' not in original set"
            );
        }
    }

    #[test]
    #[cfg(feature = "rand")]
    fn negative_samples_empty_pool_returns_empty() {
        let entities: HashSet<String> = HashSet::new();
        let triple = Triple {
            head: "h".into(),
            relation: "r".into(),
            tail: "t".into(),
        };
        let negatives = generate_negative_samples(
            &triple,
            &entities,
            &NegativeSamplingStrategy::CorruptTail,
            5,
        );
        assert!(
            negatives.is_empty(),
            "empty entity set should yield no negatives"
        );
    }

    #[test]
    #[cfg(feature = "rand")]
    fn negative_samples_single_entity_may_produce_none() {
        let entities: HashSet<String> = ["t"].iter().map(|s| s.to_string()).collect();
        let triple = Triple {
            head: "h".into(),
            relation: "r".into(),
            tail: "t".into(),
        };
        let negatives = generate_negative_samples(
            &triple,
            &entities,
            &NegativeSamplingStrategy::CorruptTail,
            10,
        );
        assert!(
            negatives.is_empty(),
            "single-entity pool matching the positive tail should yield no negatives"
        );
    }

    #[test]
    #[cfg(feature = "rand")]
    fn negative_samples_corrupt_head_preserves_tail_and_relation() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;
        let entities: HashSet<String> = ["e1", "e2", "e3", "e4"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let triple = Triple {
            head: "e1".into(),
            relation: "r".into(),
            tail: "e2".into(),
        };
        let mut rng = StdRng::seed_from_u64(99);
        let negatives = generate_negative_samples_with_rng(
            &triple,
            &entities,
            &NegativeSamplingStrategy::CorruptHead,
            20,
            &mut rng,
        );
        for neg in &negatives {
            assert_eq!(neg.tail, "e2", "CorruptHead must preserve tail");
            assert_eq!(neg.relation, "r", "CorruptHead must preserve relation");
        }
    }

    #[test]
    #[cfg(feature = "rand")]
    fn negative_samples_corrupt_both_may_change_head_and_tail() {
        use rand::rngs::StdRng;
        use rand::SeedableRng;
        let entities: HashSet<String> = (0..20).map(|i| format!("e{i}")).collect();
        let triple = Triple {
            head: "e0".into(),
            relation: "r".into(),
            tail: "e1".into(),
        };
        let mut rng = StdRng::seed_from_u64(123);
        let negatives = generate_negative_samples_with_rng(
            &triple,
            &entities,
            &NegativeSamplingStrategy::CorruptBoth,
            50,
            &mut rng,
        );
        let any_head_changed = negatives.iter().any(|n| n.head != "e0");
        let any_tail_changed = negatives.iter().any(|n| n.tail != "e1");
        assert!(
            any_head_changed,
            "CorruptBoth should change head at least sometimes"
        );
        assert!(
            any_tail_changed,
            "CorruptBoth should change tail at least sometimes"
        );
    }

    #[test]
    #[cfg(feature = "rand")]
    fn negative_sampling_uniformity() {
        use rand::SeedableRng;
        use std::collections::HashMap;
        let triple = Triple {
            head: "e0".to_string(),
            relation: "r".to_string(),
            tail: "e0".to_string(),
        };
        let entities: HashSet<String> = (0..5).map(|i| format!("e{i}")).collect();
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let negatives = generate_negative_samples_with_rng(
            &triple,
            &entities,
            &NegativeSamplingStrategy::CorruptTail,
            1000,
            &mut rng,
        );
        let mut counts: HashMap<String, usize> = HashMap::new();
        for neg in &negatives {
            *counts.entry(neg.tail.clone()).or_insert(0) += 1;
        }
        assert!(
            !counts.contains_key("e0"),
            "positive tail should be filtered: counts={counts:?}"
        );
        for i in 1..5 {
            let key = format!("e{i}");
            let c = counts.get(&key).copied().unwrap_or(0);
            assert!(
                c >= 100,
                "entity {key} appeared only {c} times out of {} negatives (expected >= 100)",
                negatives.len()
            );
        }
    }
}