use crate::error::Result;
use crate::topic_modeling::Topic;
use scirs2_core::ndarray::Array2;
use std::collections::{HashMap, HashSet};
/// Computes topic-coherence metrics (C_v, UMass, UCI) for a set of topics
/// against a reference corpus of tokenized documents.
pub struct TopicCoherence {
// Sliding-window size (in tokens) for UCI co-occurrence counting.
window_size: usize,
// Minimum-count threshold; currently unused (kept for config compatibility).
_min_count: usize,
// Smoothing constant that keeps UMass logarithms finite when counts are 0.
epsilon: f64,
}
impl Default for TopicCoherence {
    /// Default configuration: 10-token co-occurrence window, minimum count
    /// of 5 (currently unused), and a tiny epsilon for log smoothing.
    fn default() -> Self {
        let window_size = 10;
        let _min_count = 5;
        let epsilon = 1e-12;
        Self {
            window_size,
            _min_count,
            epsilon,
        }
    }
}
// Maps a word to the number of documents (or corpus tokens) counted for it.
type DocFreqMap = HashMap<String, usize>;
// Maps a canonically ordered word pair (lexicographically smaller word first)
// to its co-occurrence count.
type CoDocFreqMap = HashMap<(String, String), usize>;
impl TopicCoherence {
    /// Creates a calculator with the default configuration (see [`Default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style setter for the sliding-window size used by UCI
    /// co-occurrence counting.
    pub fn with_window_size(mut self, windowsize: usize) -> Self {
        self.window_size = windowsize;
        self
    }

    /// C_v-style coherence: mean pairwise NPMI over each topic's top words,
    /// estimated from document-level (co-)occurrence counts, averaged over
    /// topics.
    ///
    /// Returns `Ok(0.0)` when `topics` is empty (previously NaN from 0/0).
    pub fn cv_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        let top_words_per_topic: Vec<Vec<String>> = topics
            .iter()
            .map(|topic| {
                topic
                    .top_words
                    .iter()
                    .map(|(word, _)| word.clone())
                    .collect()
            })
            .collect();
        let (doc_freq, co_doc_freq) =
            self.calculate_document_frequencies(&top_words_per_topic, documents)?;
        let mut coherence_scores = Vec::with_capacity(top_words_per_topic.len());
        for topic_words in &top_words_per_topic {
            coherence_scores.push(self.calculate_topic_coherence_cv(
                topic_words,
                &doc_freq,
                &co_doc_freq,
                documents.len(),
            )?);
        }
        Ok(Self::mean_or_zero(&coherence_scores))
    }

    /// UMass coherence: mean of `ln((D(w_i, w_j) + eps) / D(w_j))` over
    /// ordered top-word pairs of each topic, averaged over topics.
    ///
    /// Returns `Ok(0.0)` when `topics` is empty (previously NaN from 0/0).
    pub fn umass_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        // Build one membership set per document up front so the pair scans
        // below only pay O(1) lookups.
        let doc_sets: Vec<HashSet<String>> = documents
            .iter()
            .map(|doc| doc.iter().cloned().collect())
            .collect();
        let mut coherence_scores = Vec::with_capacity(topics.len());
        for topic in topics {
            let top_words: Vec<&String> = topic.top_words.iter().map(|(word, _)| word).collect();
            coherence_scores.push(self.calculate_topic_coherence_umass(&top_words, &doc_sets)?);
        }
        Ok(Self::mean_or_zero(&coherence_scores))
    }

    /// UCI coherence: mean pairwise PMI over each topic's top words using
    /// sliding-window co-occurrence counts, averaged over topics.
    ///
    /// Returns `Ok(0.0)` when `topics` is empty (previously NaN from 0/0).
    pub fn uci_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        let (word_freq, co_occurrence) = self.build_co_occurrence_matrix(documents)?;
        let mut coherence_scores = Vec::with_capacity(topics.len());
        for topic in topics {
            let top_words: Vec<&String> = topic.top_words.iter().map(|(word, _)| word).collect();
            coherence_scores
                .push(self.calculate_topic_coherence_uci(&top_words, &word_freq, &co_occurrence)?);
        }
        Ok(Self::mean_or_zero(&coherence_scores))
    }

    /// Arithmetic mean, or 0.0 for an empty slice (avoids NaN from 0/0).
    fn mean_or_zero(scores: &[f64]) -> f64 {
        if scores.is_empty() {
            0.0
        } else {
            scores.iter().sum::<f64>() / scores.len() as f64
        }
    }

    /// Counts, over `documents`, how many documents contain each topic word
    /// and each unordered pair of topic words (pair keys hold the
    /// lexicographically smaller word first).
    fn calculate_document_frequencies(
        &self,
        topics: &[Vec<String>],
        documents: &[Vec<String>],
    ) -> Result<(DocFreqMap, CoDocFreqMap)> {
        let mut doc_freq: DocFreqMap = HashMap::new();
        let mut co_doc_freq: CoDocFreqMap = HashMap::new();
        let all_words: HashSet<String> = topics.iter().flatten().cloned().collect();
        // Hoisted out of the document loop: the vocabulary never changes,
        // so there is no need to rebuild this Vec per document.
        let words_vec: Vec<&String> = all_words.iter().collect();
        for doc in documents {
            let doc_set: HashSet<String> = doc.iter().cloned().collect();
            for word in &all_words {
                if doc_set.contains(word) {
                    *doc_freq.entry(word.clone()).or_insert(0) += 1;
                }
            }
            for i in 0..words_vec.len() {
                for j in (i + 1)..words_vec.len() {
                    let word_1 = words_vec[i];
                    let word_2 = words_vec[j];
                    if doc_set.contains(word_1) && doc_set.contains(word_2) {
                        // Canonical ordering so (a, b) and (b, a) share a key.
                        let key = if word_1 < word_2 {
                            (word_1.clone(), word_2.clone())
                        } else {
                            (word_2.clone(), word_1.clone())
                        };
                        *co_doc_freq.entry(key).or_insert(0) += 1;
                    }
                }
            }
        }
        Ok((doc_freq, co_doc_freq))
    }

    /// Mean NPMI over all unordered pairs of `topic_words`, using document
    /// frequencies as probability estimates. Returns 0.0 when there are
    /// fewer than two words (no pairs).
    fn calculate_topic_coherence_cv(
        &self,
        topic_words: &[String],
        doc_freq: &DocFreqMap,
        co_doc_freq: &CoDocFreqMap,
        n_docs: usize,
    ) -> Result<f64> {
        let mut scores = Vec::new();
        for i in 0..topic_words.len() {
            for j in (i + 1)..topic_words.len() {
                let word_1 = &topic_words[i];
                let word_2 = &topic_words[j];
                let freq1 = doc_freq.get(word_1).copied().unwrap_or(0) as f64;
                let freq2 = doc_freq.get(word_2).copied().unwrap_or(0) as f64;
                let key = if word_1 < word_2 {
                    (word_1.clone(), word_2.clone())
                } else {
                    (word_2.clone(), word_1.clone())
                };
                let co_freq = co_doc_freq.get(&key).copied().unwrap_or(0) as f64;
                scores.push(self.calculate_npmi(freq1, freq2, co_freq, n_docs as f64));
            }
        }
        Ok(Self::mean_or_zero(&scores))
    }

    /// UMass score for one topic: for each ordered pair (w_i, w_j) with
    /// j < i, `ln((D(w_i, w_j) + eps) / D(w_j))`; when the pair never
    /// co-occurs, eps alone keeps the logarithm finite.
    fn calculate_topic_coherence_umass(
        &self,
        topic_words: &[&String],
        doc_sets: &[HashSet<String>],
    ) -> Result<f64> {
        let mut scores = Vec::new();
        for i in 1..topic_words.len() {
            for j in 0..i {
                let word_i = topic_words[i];
                let word_j = topic_words[j];
                let mut count_j = 0usize;
                let mut count_both = 0usize;
                for doc_set in doc_sets {
                    if doc_set.contains(word_j) {
                        count_j += 1;
                        // The pair count only matters in documents that
                        // already contain w_j.
                        if doc_set.contains(word_i) {
                            count_both += 1;
                        }
                    }
                }
                let score = if count_both > 0 {
                    ((count_both as f64 + self.epsilon) / count_j as f64).ln()
                } else {
                    // count_j may also be 0 here; max(1) avoids 0/0.
                    (self.epsilon / count_j.max(1) as f64).ln()
                };
                scores.push(score);
            }
        }
        Ok(Self::mean_or_zero(&scores))
    }

    /// UCI score for one topic: mean PMI over all unordered word pairs with
    /// window-based co-occurrence counts; pairs with any zero count are
    /// skipped rather than smoothed.
    fn calculate_topic_coherence_uci(
        &self,
        topic_words: &[&String],
        word_freq: &DocFreqMap,
        co_occurrence: &CoDocFreqMap,
    ) -> Result<f64> {
        // Loop-invariant normalizer, hoisted out of the pair loops
        // (previously re-summed for every scored pair).
        let total = word_freq.values().sum::<usize>() as f64;
        let mut scores = Vec::new();
        for i in 0..topic_words.len() {
            for j in (i + 1)..topic_words.len() {
                let word_1 = topic_words[i];
                let word_2 = topic_words[j];
                let freq1 = word_freq.get(word_1).copied().unwrap_or(0) as f64;
                let freq2 = word_freq.get(word_2).copied().unwrap_or(0) as f64;
                let key = if word_1 < word_2 {
                    (word_1.clone(), word_2.clone())
                } else {
                    (word_2.clone(), word_1.clone())
                };
                let co_freq = co_occurrence.get(&key).copied().unwrap_or(0) as f64;
                if freq1 > 0.0 && freq2 > 0.0 && co_freq > 0.0 {
                    // PMI = ln( p(w1,w2) / (p(w1) p(w2)) ) with all
                    // probabilities normalized by the same total.
                    scores.push((co_freq * total / (freq1 * freq2)).ln());
                }
            }
        }
        Ok(Self::mean_or_zero(&scores))
    }

    /// Builds corpus-wide word frequencies and sliding-window pair
    /// co-occurrence counts. Windows span `self.window_size` tokens; pair
    /// keys hold the lexicographically smaller word first.
    fn build_co_occurrence_matrix(
        &self,
        documents: &[Vec<String>],
    ) -> Result<(DocFreqMap, CoDocFreqMap)> {
        let mut word_freq: DocFreqMap = HashMap::new();
        let mut co_occurrence: CoDocFreqMap = HashMap::new();
        for doc in documents {
            for word in doc {
                *word_freq.entry(word.clone()).or_insert(0) += 1;
            }
            for i in 0..doc.len() {
                let window_end = (i + self.window_size).min(doc.len());
                for j in (i + 1)..window_end {
                    let word_1 = &doc[i];
                    let word_2 = &doc[j];
                    // A token repeated inside one window is not a self-pair.
                    if word_1 != word_2 {
                        let key = if word_1 < word_2 {
                            (word_1.clone(), word_2.clone())
                        } else {
                            (word_2.clone(), word_1.clone())
                        };
                        *co_occurrence.entry(key).or_insert(0) += 1;
                    }
                }
            }
        }
        Ok((word_freq, co_occurrence))
    }

    /// Normalized pointwise mutual information, clamped to [-1, 1].
    /// Zero counts yield -1.0 (no observed association).
    fn calculate_npmi(&self, freq1: f64, freq2: f64, co_freq: f64, ntotal: f64) -> f64 {
        if freq1 == 0.0 || freq2 == 0.0 || co_freq == 0.0 {
            return -1.0;
        }
        let p1 = freq1 / ntotal;
        let p2 = freq2 / ntotal;
        let p12 = co_freq / ntotal;
        // When the pair co-occurs in every document, ln(p12) == 0 and the
        // normalization below would produce NaN/-inf; by convention that is
        // perfect association, so return 1.0 directly.
        if p12 >= 1.0 {
            return 1.0;
        }
        let pmi = (p12 / (p1 * p2)).ln();
        let npmi = pmi / -(p12.ln());
        npmi.clamp(-1.0, 1.0)
    }
}
/// Stateless namespace for topic-diversity metrics (unit struct).
pub struct TopicDiversity;
impl TopicDiversity {
    /// Fraction of distinct words among all topics' top-word lists:
    /// 1.0 means no word is shared across topics. Returns 0.0 when there
    /// are no words at all.
    pub fn calculate(topics: &[Topic]) -> f64 {
        let mut total_words = 0usize;
        // Borrow instead of cloning each word: membership is all we need.
        let mut unique_words: HashSet<&String> = HashSet::new();
        for topic in topics {
            for (word, _) in &topic.top_words {
                total_words += 1;
                unique_words.insert(word);
            }
        }
        if total_words == 0 {
            return 0.0;
        }
        unique_words.len() as f64 / total_words as f64
    }

    /// Symmetric matrix of pairwise Jaccard distances between topics'
    /// top-word sets. The diagonal is 0. Two topics with empty word lists
    /// are treated as identical (distance 0) instead of producing NaN
    /// from a 0/0 division.
    pub fn pairwise_distances(topics: &[Topic]) -> Array2<f64> {
        let ntopics = topics.len();
        // Build each topic's word set once instead of once per (i, j) pair.
        let word_sets: Vec<HashSet<&String>> = topics
            .iter()
            .map(|topic| topic.top_words.iter().map(|(word, _)| word).collect())
            .collect();
        let mut distances = Array2::zeros((ntopics, ntopics));
        for i in 0..ntopics {
            for j in 0..ntopics {
                if i == j {
                    continue; // zeros() already set the diagonal
                }
                let intersection = word_sets[i].intersection(&word_sets[j]).count();
                let union = word_sets[i].union(&word_sets[j]).count();
                distances[[i, j]] = if union == 0 {
                    0.0
                } else {
                    1.0 - (intersection as f64 / union as f64)
                };
            }
        }
        distances
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two disjoint three-word topics shared by the tests below.
    fn sample_topics() -> Vec<Topic> {
        let make = |id, pairs: &[(&str, f64)]| Topic {
            id,
            top_words: pairs.iter().map(|(w, s)| ((*w).to_string(), *s)).collect(),
            coherence: None,
        };
        vec![
            make(0, &[("machine", 0.1), ("learning", 0.09), ("algorithm", 0.08)]),
            make(1, &[("neural", 0.12), ("network", 0.11), ("deep", 0.10)]),
        ]
    }

    /// Four short tokenized documents covering both topics' vocabularies.
    fn sample_documents() -> Vec<Vec<String>> {
        [
            ["machine", "learning", "algorithm", "data"],
            ["neural", "network", "deep", "learning"],
            ["machine", "algorithm", "neural", "network"],
            ["deep", "learning", "machine", "data"],
        ]
        .iter()
        .map(|doc| doc.iter().map(|w| w.to_string()).collect())
        .collect()
    }

    #[test]
    fn test_cv_coherence() {
        let score = TopicCoherence::new()
            .cv_coherence(&sample_topics(), &sample_documents())
            .expect("Operation failed");
        // NPMI-based scores are bounded.
        assert!((-1.0..=1.0).contains(&score));
    }

    #[test]
    fn test_umass_coherence() {
        let score = TopicCoherence::new()
            .umass_coherence(&sample_topics(), &sample_documents())
            .expect("Operation failed");
        assert!(score.is_finite());
    }

    #[test]
    fn test_uci_coherence() {
        let score = TopicCoherence::new()
            .uci_coherence(&sample_topics(), &sample_documents())
            .expect("Operation failed");
        assert!(score.is_finite());
    }

    #[test]
    fn test_topic_diversity() {
        // The two sample topics share no words, so diversity is maximal.
        let diversity = TopicDiversity::calculate(&sample_topics());
        assert!((0.0..=1.0).contains(&diversity));
        assert_eq!(diversity, 1.0);
    }

    #[test]
    fn test_pairwise_distances() {
        let distances = TopicDiversity::pairwise_distances(&sample_topics());
        // Zero self-distance; disjoint word sets give Jaccard distance 1.
        assert_eq!(distances[[0, 0]], 0.0);
        assert_eq!(distances[[1, 1]], 0.0);
        assert_eq!(distances[[0, 1]], 1.0);
        assert_eq!(distances[[1, 0]], 1.0);
    }

    #[test]
    fn test_emptytopics() {
        let coherence = TopicCoherence::new();
        let no_topics: Vec<Topic> = Vec::new();
        let documents = sample_documents();
        let cv_score = coherence
            .cv_coherence(&no_topics, &documents)
            .expect("Operation failed");
        // Accept either NaN (0/0 mean) or a sanitized 0.0.
        assert!(cv_score.is_nan() || cv_score == 0.0);
        assert_eq!(TopicDiversity::calculate(&no_topics), 0.0);
    }
}