use std::collections::{HashMap, HashSet};
/// Per-relation cardinality statistics, as produced by
/// [`compute_relation_cardinalities`].
#[derive(Clone, Debug)]
pub struct RelationCardinality {
    /// Tails per head: how many triples of the relation exist per distinct
    /// head entity.
    pub tph: f32,
    /// Heads per tail: how many triples of the relation exist per distinct
    /// tail entity.
    pub hpt: f32,
}
/// Entities observed on each side of a single relation, as produced by
/// [`compute_relation_entity_pools`].
///
/// The default value has both pools empty.
#[derive(Clone, Debug, Default)]
pub struct RelationEntityPools {
    /// Entity ids seen in the head position of the relation.
    pub heads: Vec<usize>,
    /// Entity ids seen in the tail position of the relation.
    pub tails: Vec<usize>,
}
impl RelationCardinality {
    /// Returns `tph / (tph + hpt)`.
    ///
    /// When `tph + hpt` is exactly zero the ratio is undefined, so the
    /// neutral value `0.5` is returned instead.
    ///
    /// NOTE(review): judging by the name, this is presumably the probability
    /// with which a negative sampler replaces the head (rather than the tail)
    /// of a triple — confirm against the caller.
    #[inline]
    pub fn head_corrupt_prob(&self) -> f32 {
        match self.tph + self.hpt {
            total if total == 0.0 => 0.5,
            total => self.tph / total,
        }
    }
}
/// Computes per-relation `tph` / `hpt` statistics from
/// `(head, relation, tail)` triples.
///
/// For every relation id that occurs in `triples`:
///
/// * `tph` = distinct `(head, tail)` pairs / distinct heads
/// * `hpt` = distinct `(head, tail)` pairs / distinct tails
///
/// Repeated triples are counted once, so a triple list that contains
/// duplicates yields the same statistics as its de-duplicated form.
///
/// Relations that never occur in `triples` have no entry in the returned map;
/// an empty slice yields an empty map.
pub fn compute_relation_cardinalities(
    triples: &[(usize, usize, usize)],
) -> HashMap<usize, RelationCardinality> {
    /// Everything observed so far for one relation.
    #[derive(Default)]
    struct Seen {
        heads: HashSet<usize>,
        tails: HashSet<usize>,
        pairs: HashSet<(usize, usize)>,
    }

    let mut per_relation: HashMap<usize, Seen> = HashMap::new();
    for &(head, relation, tail) in triples {
        let seen = per_relation.entry(relation).or_default();
        // `insert` returns false for a pair that was already recorded; such a
        // duplicate triple must not inflate the averages, and its head and
        // tail are necessarily in the sets already.
        if seen.pairs.insert((head, tail)) {
            seen.heads.insert(head);
            seen.tails.insert(tail);
        }
    }

    per_relation
        .into_iter()
        .map(|(relation, seen)| {
            // An entry only exists once a pair was inserted, so both entity
            // sets are non-empty and the divisions below are well defined.
            // The `as f32` conversions round for counts above 2^24.
            let facts = seen.pairs.len() as f32;
            let tph = facts / seen.heads.len() as f32;
            let hpt = facts / seen.tails.len() as f32;
            (relation, RelationCardinality { tph, hpt })
        })
        .collect()
}
/// Collects, for every relation in `triples`, the entities that appear as its
/// head and as its tail.
///
/// Each triple is `(head, relation, tail)`. In the returned pools every entity
/// id appears at most once and both vectors are sorted in ascending order.
/// Relations that never occur in `triples` have no entry.
pub fn compute_relation_entity_pools(
    triples: &[(usize, usize, usize)],
) -> HashMap<usize, RelationEntityPools> {
    /// Turns a set of ids into an ascending vector.
    fn into_sorted(ids: HashSet<usize>) -> Vec<usize> {
        let mut ordered: Vec<usize> = ids.into_iter().collect();
        ordered.sort_unstable();
        ordered
    }

    // relation id -> (distinct heads, distinct tails)
    let mut seen: HashMap<usize, (HashSet<usize>, HashSet<usize>)> = HashMap::new();
    for &(head, relation, tail) in triples {
        let (head_ids, tail_ids) = seen.entry(relation).or_default();
        head_ids.insert(head);
        tail_ids.insert(tail);
    }

    let mut pools = HashMap::with_capacity(seen.len());
    for (relation, (head_ids, tail_ids)) in seen {
        pools.insert(
            relation,
            RelationEntityPools {
                heads: into_sorted(head_ids),
                tails: into_sorted(tail_ids),
            },
        );
    }
    pools
}
/// Picks an element of `candidates` that differs from `exclude`.
///
/// `sample_index` is called with `candidates.len()` and is expected to return
/// an index into `candidates`; an out-of-range index simply counts as a failed
/// attempt. Up to eight draws are made, and the first draw that lands on an
/// element other than `exclude` is returned. If every draw fails, the first
/// element of `candidates` (in slice order) that differs from `exclude` is
/// returned instead.
///
/// Returns `None` when `candidates` is empty (without calling `sample_index`)
/// or when no element differs from `exclude`.
pub fn sample_excluding<F>(
    candidates: &[usize],
    exclude: usize,
    mut sample_index: F,
) -> Option<usize>
where
    F: FnMut(usize) -> usize,
{
    /// Number of sampled draws attempted before the linear-scan fallback.
    const RANDOM_ATTEMPTS: usize = 8;

    let pool_size = candidates.len();
    if pool_size == 0 {
        return None;
    }

    // `find_map` is lazy: it stops calling the sampler at the first accepted draw.
    let drawn = (0..RANDOM_ATTEMPTS).find_map(|_| {
        candidates
            .get(sample_index(pool_size))
            .copied()
            .filter(|&value| value != exclude)
    });

    // Deterministic fallback so a valid candidate is never missed.
    drawn.or_else(|| candidates.iter().copied().find(|&value| value != exclude))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used by every floating-point comparison in this module.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` lies within `TOLERANCE` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    #[test]
    fn compute_relation_cardinalities_empty() {
        let no_triples: &[(usize, usize, usize)] = &[];
        assert!(compute_relation_cardinalities(no_triples).is_empty());
    }

    #[test]
    fn compute_relation_cardinalities_one_to_many() {
        // One head (0) linked to three different tails through relation 0.
        let cards = compute_relation_cardinalities(&[(0, 0, 1), (0, 0, 2), (0, 0, 3)]);
        let stats = cards.get(&0).expect("relation 0 should be present");

        assert!(close(stats.tph, 3.0), "tph should be 3.0, got {}", stats.tph);
        assert!(close(stats.hpt, 1.0), "hpt should be 1.0, got {}", stats.hpt);

        let prob = stats.head_corrupt_prob();
        assert!(
            close(prob, 0.75),
            "P(corrupt_head) should be 0.75, got {}",
            prob
        );
    }

    #[test]
    fn compute_relation_cardinalities_many_to_one() {
        // Three different heads all pointing at tail 3 through relation 0.
        let cards = compute_relation_cardinalities(&[(0, 0, 3), (1, 0, 3), (2, 0, 3)]);
        let stats = cards.get(&0).unwrap();

        assert!(close(stats.tph, 1.0));
        assert!(close(stats.hpt, 3.0));

        let prob = stats.head_corrupt_prob();
        assert!(
            close(prob, 0.25),
            "P(corrupt_head) should be 0.25, got {}",
            prob
        );
    }

    #[test]
    fn compute_relation_cardinalities_symmetric() {
        // Same number of distinct heads and tails, so neither side is favoured.
        let cards = compute_relation_cardinalities(&[(0, 0, 1), (1, 0, 0)]);
        let stats = cards.get(&0).unwrap();
        assert!(close(stats.head_corrupt_prob(), 0.5));
    }

    #[test]
    fn compute_relation_cardinalities_multiple_relations() {
        let triples = [
            // Relation 0: head 0 with three tails.
            (0, 0, 1),
            (0, 0, 2),
            (0, 0, 3),
            // Relation 1: three heads sharing tail 0.
            (1, 1, 0),
            (2, 1, 0),
            (3, 1, 0),
        ];
        let cards = compute_relation_cardinalities(&triples);
        assert_eq!(cards.len(), 2);

        let first = cards.get(&0).unwrap();
        let second = cards.get(&1).unwrap();
        assert!(close(first.tph, 3.0));
        assert!(close(second.hpt, 3.0));
    }

    #[test]
    fn compute_relation_entity_pools_deduplicates_and_sorts() {
        // Heads and tails each repeat and arrive out of order.
        let pools = compute_relation_entity_pools(&[(2, 0, 3), (1, 0, 3), (2, 0, 4), (1, 0, 4)]);
        let pool = pools.get(&0).expect("relation 0 should exist");
        assert_eq!(pool.heads, vec![1, 2]);
        assert_eq!(pool.tails, vec![3, 4]);
    }

    #[test]
    fn sample_excluding_skips_target_and_falls_back() {
        // The sampler always points at the excluded element (index 1).
        let picked = sample_excluding(&[3, 4, 5], 4, |_| 1).expect("should pick valid candidate");
        assert_ne!(picked, 4);

        // A pool holding only the excluded element has nothing to offer.
        assert_eq!(sample_excluding(&[7], 7, |_| 0), None);
    }

    #[test]
    fn relation_cardinality_head_corrupt_prob_zero_denom() {
        let degenerate = RelationCardinality { tph: 0.0, hpt: 0.0 };
        assert!(close(degenerate.head_corrupt_prob(), 0.5));
    }

    #[test]
    fn sample_excluding_empty_returns_none() {
        let nothing: &[usize] = &[];
        assert_eq!(sample_excluding(nothing, 0, |_| 0), None);
    }

    #[test]
    fn sample_excluding_all_valid_returns_some() {
        // No candidate equals the excluded value, so the first draw is accepted.
        let mut invocations = 0usize;
        let picked = sample_excluding(&[1, 2, 3, 4, 5], 99, |len| {
            invocations += 1;
            0 % len
        });
        assert_eq!(picked, Some(1));
        assert_eq!(invocations, 1);
    }

    #[test]
    fn relation_entity_pools_multiple_relations() {
        let triples = [
            // Relation 0: one head, two tails.
            (0, 0, 1),
            (0, 0, 2),
            // Relation 1: two heads, one tail.
            (3, 1, 4),
            (5, 1, 4),
        ];
        let pools = compute_relation_entity_pools(&triples);
        assert_eq!(pools.len(), 2);

        let first = pools.get(&0).unwrap();
        assert_eq!(first.heads, vec![0]);
        assert_eq!(first.tails, vec![1, 2]);

        let second = pools.get(&1).unwrap();
        assert_eq!(second.heads, vec![3, 5]);
        assert_eq!(second.tails, vec![4]);
    }

    #[test]
    fn type_constrained_sampling_draws_from_pool() {
        let pool: [usize; 3] = [10, 20, 30];
        // Scripted draws: index 1 hits the excluded value 20, index 0 yields 10.
        let scripted: [usize; 3] = [1, 0, 2];
        let mut draws = scripted.iter().cycle();
        let picked = sample_excluding(&pool, 20, |_len| *draws.next().unwrap());
        assert_eq!(picked, Some(10));
    }
}