use std::collections::{HashMap, HashSet};
/// Per-relation cardinality statistics for a knowledge-graph relation.
///
/// Values produced by [`compute_relation_cardinalities`] are averages over the triples of a
/// single relation and are always at least `1.0`, because a relation's triple count is never
/// smaller than its number of distinct heads or distinct tails.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RelationCardinality {
    /// Tails per head: number of triples divided by the number of distinct heads.
    pub tph: f32,
    /// Heads per tail: number of triples divided by the number of distinct tails.
    pub hpt: f32,
}

impl RelationCardinality {
    /// Probability of corrupting the head (rather than the tail) of a triple when building a
    /// negative sample, computed as `tph / (tph + hpt)`.
    ///
    /// A larger `tph` (one-to-many relation) raises the probability of replacing the head. A
    /// larger `hpt` (many-to-one relation) lowers it. This is the ratio used by the Bernoulli
    /// ("bern") sampling heuristic from TransH.
    ///
    /// The result is always in `[0.0, 1.0]`. If the ratio is NaN or falls outside that range,
    /// this returns `0.5`, i.e. head and tail are equally likely to be corrupted. That happens
    /// when both fields are zero, a field is NaN, both fields are infinite, or negative values
    /// are present. Such values can only come from constructing the struct by hand.
    #[inline]
    #[must_use]
    pub fn head_corrupt_prob(&self) -> f32 {
        let ratio = self.tph / (self.tph + self.hpt);
        // Float division never panics. A zero sum yields NaN or ±inf, and NaN fails every
        // comparison, so all degenerate inputs fail this range check and hit the fallback.
        if (0.0..=1.0).contains(&ratio) {
            ratio
        } else {
            0.5
        }
    }
}
/// Computes per-relation cardinality statistics from `(head, relation, tail)` triples.
///
/// For each relation id occurring in `triples`:
/// - `tph` = number of triples with that relation / number of distinct heads,
/// - `hpt` = number of triples with that relation / number of distinct tails.
///
/// Relations that never occur are absent from the result, so an empty slice yields an empty
/// map. Every triple is counted, including exact repeats, while heads and tails are
/// deduplicated. NOTE(review): duplicate triples therefore inflate both ratios; deduplicate
/// upstream if that is not intended.
pub fn compute_relation_cardinalities(
    triples: &[(usize, usize, usize)],
) -> HashMap<usize, RelationCardinality> {
    // relation id -> (distinct heads, distinct tails, number of triples)
    let mut per_relation: HashMap<usize, (HashSet<usize>, HashSet<usize>, usize)> =
        HashMap::new();

    for &(head, relation, tail) in triples {
        let (heads, tails, total) = per_relation.entry(relation).or_default();
        heads.insert(head);
        tails.insert(tail);
        *total += 1;
    }

    let mut cardinalities = HashMap::with_capacity(per_relation.len());
    for (relation, (heads, tails, total)) in per_relation {
        let triples_for_relation = total as f32;
        // Every stored relation has at least one head and one tail; `max(1)` only keeps the
        // divisor non-zero.
        let distinct_heads = heads.len().max(1) as f32;
        let distinct_tails = tails.len().max(1) as f32;
        cardinalities.insert(
            relation,
            RelationCardinality {
                tph: triples_for_relation / distinct_heads,
                hpt: triples_for_relation / distinct_tails,
            },
        );
    }
    cardinalities
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for every floating-point comparison in this module.
    const EPS: f32 = 1e-6;

    /// Asserts that `actual` lies within [`EPS`] of `expected`, naming the quantity on failure.
    fn assert_close(actual: f32, expected: f32, what: &str) {
        assert!(
            (actual - expected).abs() < EPS,
            "{} should be {}, got {}",
            what,
            expected,
            actual
        );
    }

    #[test]
    fn compute_relation_cardinalities_empty() {
        let cards = compute_relation_cardinalities(&[]);
        assert!(cards.is_empty());
    }

    #[test]
    fn compute_relation_cardinalities_one_to_many() {
        // One head linked to three distinct tails.
        let triples = [(0, 0, 1), (0, 0, 2), (0, 0, 3)];
        let cards = compute_relation_cardinalities(&triples);
        let c = cards.get(&0).expect("relation 0 should be present");
        assert_close(c.tph, 3.0, "tph");
        assert_close(c.hpt, 1.0, "hpt");
        assert_close(c.head_corrupt_prob(), 0.75, "P(corrupt_head)");
    }

    #[test]
    fn compute_relation_cardinalities_many_to_one() {
        // Three distinct heads linked to one tail.
        let triples = [(0, 0, 3), (1, 0, 3), (2, 0, 3)];
        let cards = compute_relation_cardinalities(&triples);
        let c = cards.get(&0).expect("relation 0 should be present");
        assert_close(c.tph, 1.0, "tph");
        assert_close(c.hpt, 3.0, "hpt");
        assert_close(c.head_corrupt_prob(), 0.25, "P(corrupt_head)");
    }

    #[test]
    fn compute_relation_cardinalities_many_to_many() {
        // Two heads, each linked to the same two tails: 4 triples / 2 heads / 2 tails.
        let triples = [(0, 0, 0), (0, 0, 1), (1, 0, 0), (1, 0, 1)];
        let cards = compute_relation_cardinalities(&triples);
        let c = cards.get(&0).expect("relation 0 should be present");
        assert_close(c.tph, 2.0, "tph");
        assert_close(c.hpt, 2.0, "hpt");
        assert_close(c.head_corrupt_prob(), 0.5, "P(corrupt_head)");
    }

    #[test]
    fn compute_relation_cardinalities_symmetric() {
        let triples = [(0, 0, 1), (1, 0, 0)];
        let cards = compute_relation_cardinalities(&triples);
        let c = cards.get(&0).expect("relation 0 should be present");
        assert_close(c.head_corrupt_prob(), 0.5, "P(corrupt_head)");
    }

    #[test]
    fn compute_relation_cardinalities_multiple_relations() {
        // Relation 0 is one-to-many and relation 1 is many-to-one. Checking all four fields
        // ensures the statistics of one relation do not leak into the other.
        let triples = [
            (0, 0, 1),
            (0, 0, 2),
            (0, 0, 3),
            (1, 1, 0),
            (2, 1, 0),
            (3, 1, 0),
        ];
        let cards = compute_relation_cardinalities(&triples);
        assert_eq!(cards.len(), 2);
        let c0 = cards.get(&0).expect("relation 0 should be present");
        let c1 = cards.get(&1).expect("relation 1 should be present");
        assert_close(c0.tph, 3.0, "relation 0 tph");
        assert_close(c0.hpt, 1.0, "relation 0 hpt");
        assert_close(c1.tph, 1.0, "relation 1 tph");
        assert_close(c1.hpt, 3.0, "relation 1 hpt");
    }

    #[test]
    fn relation_cardinality_head_corrupt_prob_zero_denom() {
        let c = RelationCardinality { tph: 0.0, hpt: 0.0 };
        assert_close(c.head_corrupt_prob(), 0.5, "P(corrupt_head) with zero denominator");
    }
}