use unicode_segmentation::UnicodeSegmentation;
use crate::traits::{Entity, ExtractionSource, Relation};
/// PMI cutoff used by [`CoOccurrenceMiner::default`].
///
/// When passed to [`mine_relations`], a pair is only emitted if its PMI is
/// strictly greater than this value.
pub const DEFAULT_PMI_THRESHOLD: f32 = 1.0;
/// Configuration bundle for co-occurrence mining.
///
/// Holds the two settings that [`CoOccurrenceMiner::mine`] forwards to
/// [`mine_relations`].
#[derive(Debug, Clone)]
pub struct CoOccurrenceMiner {
/// PMI cutoff forwarded to [`mine_relations`]; pairs scoring at or below it
/// are dropped.
pub threshold: f32,
/// Provenance tag cloned onto every relation produced by this miner.
pub source: ExtractionSource,
}
impl Default for CoOccurrenceMiner {
    /// Builds a miner that uses [`DEFAULT_PMI_THRESHOLD`] as its cutoff and tags
    /// its output with [`ExtractionSource::Statistical`].
    fn default() -> Self {
        let threshold = DEFAULT_PMI_THRESHOLD;
        let source = ExtractionSource::Statistical;
        Self { threshold, source }
    }
}
impl CoOccurrenceMiner {
    /// Mines relations from `text` and `entities` by delegating to
    /// [`mine_relations`] with this miner's `threshold` and a clone of its
    /// `source`.
    #[must_use]
    pub fn mine(&self, text: &str, entities: &[Entity]) -> Vec<Relation> {
        let Self { threshold, source } = self;
        mine_relations(text, entities, *threshold, source.clone())
    }
}
/// Mines pairwise relations between entities from sentence-level
/// co-occurrence, scored by pointwise mutual information (PMI).
///
/// `text` is split at sentence boundaries. An entity is counted as present in
/// a sentence when its `span` `(start, end)` satisfies
/// `start < sentence_end && end > sentence_start`, where the sentence bounds
/// are byte offsets into `text`. Entities are tracked by their position in
/// `entities`, not by mention text, so two entries with the same mention are
/// counted separately.
///
/// For every pair of entities sharing at least one sentence, the score is
/// `ln(p(i, j) / (p(i) * p(j)))`, where each probability is the number of
/// sentences containing the entity (or both entities) divided by the total
/// number of sentences.
///
/// # Parameters
///
/// * `text` - the text the entity spans refer to.
/// * `entities` - the entities to relate.
/// * `threshold` - a pair is emitted only if its PMI is finite and strictly
///   greater than this value.
/// * `source` - provenance tag cloned onto every emitted relation.
///
/// # Returns
///
/// One [`Relation`] per surviving pair. Within a relation the two mentions
/// are ordered by their lowercased form (`src` first). The result is sorted
/// by `src`, then `dst`, then ascending `weight`. Returns an empty vector when
/// `text` is empty or fewer than two entities are supplied.
#[must_use]
pub fn mine_relations(
    text: &str,
    entities: &[Entity],
    threshold: f32,
    source: ExtractionSource,
) -> Vec<Relation> {
    if entities.len() < 2 || text.is_empty() {
        return Vec::new();
    }

    // Sentence extents as (start, end) byte offsets into `text`.
    let sentences: Vec<(usize, usize)> = text
        .split_sentence_bound_indices()
        .map(|(start, frag)| (start, start + frag.len()))
        .filter(|(s, e)| e > s)
        .collect();
    if sentences.is_empty() {
        return Vec::new();
    }

    // Converting the length directly avoids the silent wrap-around a narrowing
    // `as u32` would introduce; `f64` represents every count below 2^53 exactly.
    #[allow(clippy::cast_precision_loss)]
    let total = sentences.len() as f64;

    // per_entity[i]: number of sentences entity `i` appears in.
    // per_pair[(i, j)], with i < j: number of sentences containing both.
    let mut per_entity: Vec<u32> = vec![0; entities.len()];
    let mut per_pair: std::collections::BTreeMap<(usize, usize), u32> =
        std::collections::BTreeMap::new();

    // Scratch buffer reused across sentences to avoid one allocation per sentence.
    let mut present: Vec<usize> = Vec::new();
    for &(s_start, s_end) in &sentences {
        present.clear();
        // Indices are produced in ascending order and each at most once, so
        // `present` is already sorted and duplicate-free.
        present.extend(
            entities
                .iter()
                .enumerate()
                .filter(|(_, e)| {
                    let (e_start, e_end) = e.span;
                    e_start < s_end && e_end > s_start
                })
                .map(|(idx, _)| idx),
        );

        for (pos, &a) in present.iter().enumerate() {
            per_entity[a] += 1;
            // Every later element is strictly greater than `a`, so `(a, b)`
            // is already a normalised key.
            for &b in &present[pos + 1..] {
                *per_pair.entry((a, b)).or_insert(0) += 1;
            }
        }
    }

    let mut out: Vec<Relation> = Vec::with_capacity(per_pair.len());
    for ((i, j), c_ij) in per_pair {
        let c_i = per_entity[i];
        let c_j = per_entity[j];
        // Defensive: a zero count would make the ratio below undefined.
        if c_ij == 0 || c_i == 0 || c_j == 0 {
            continue;
        }
        let p_ij = f64::from(c_ij) / total;
        let p_i = f64::from(c_i) / total;
        let p_j = f64::from(c_j) / total;
        let pmi = (p_ij / (p_i * p_j)).ln();
        // `Relation::weight` is `f32`; narrowing only costs precision here.
        #[allow(clippy::cast_possible_truncation)]
        let pmi_f32 = pmi as f32;
        if !pmi_f32.is_finite() || pmi_f32 <= threshold {
            continue;
        }

        // Orient the pair by lowercased mention so the direction does not
        // depend on the order of `entities`.
        let m_i = &entities[i].mention;
        let m_j = &entities[j].mention;
        let (src, dst) = if lc(m_i) <= lc(m_j) {
            (m_i.clone(), m_j.clone())
        } else {
            (m_j.clone(), m_i.clone())
        };
        out.push(Relation {
            src,
            dst,
            weight: pmi_f32,
            source: source.clone(),
        });
    }

    // `total_cmp` gives a total order that is numerically ascending for both
    // signs; comparing raw bit patterns would reverse the negative weights and
    // place them after the positive ones.
    out.sort_by(|a, b| {
        a.src
            .cmp(&b.src)
            .then_with(|| a.dst.cmp(&b.dst))
            .then_with(|| a.weight.total_cmp(&b.weight))
    });
    out
}
/// Returns `s` with every character replaced by its lowercase mapping.
///
/// Each `char` is lowercased independently via `char::to_lowercase`, so a
/// single input character may expand to several output characters.
fn lc(s: &str) -> String {
    let mut lowered = String::with_capacity(s.len());
    for ch in s.chars() {
        lowered.extend(ch.to_lowercase());
    }
    lowered
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an [`Entity`] fixture with the given mention and span and a
    /// fixed score of `0.5`.
    fn ent(mention: &str, span: (usize, usize)) -> Entity {
        let mention = mention.to_owned();
        Entity {
            mention,
            score: 0.5,
            span,
        }
    }

    /// Empty text and no entities yield no relations.
    #[test]
    fn empty_inputs_return_empty() {
        let relations = mine_relations("", &[], 0.0, ExtractionSource::Statistical);
        assert!(relations.is_empty());
    }

    /// A lone entity has nothing to pair with.
    #[test]
    fn single_entity_returns_empty() {
        let entities = [ent("dog", (4, 7))];
        let relations = mine_relations(
            "The dog ran fast.",
            &entities,
            0.0,
            ExtractionSource::Statistical,
        );
        assert!(relations.is_empty());
    }

    /// Two entities whose spans fall at the start of the text produce exactly
    /// one relation with a positive weight.
    #[test]
    fn cooccurring_pair_emits_positive_pmi() {
        let entities = [ent("Alice", (0, 5)), ent("Bob", (10, 13))];
        let relations = mine_relations(
            "Alice met Bob. They shook hands.",
            &entities,
            0.0,
            ExtractionSource::Statistical,
        );
        assert_eq!(relations.len(), 1);
        assert!(relations[0].weight > 0.0);
    }
}