
// scirs2_text/topic_coherence.rs

//! Topic coherence metrics for evaluating topic models
//!
//! This module provides various coherence measures to evaluate the quality
//! of topics generated by topic modeling algorithms.

use crate::error::Result;
use crate::topic_modeling::Topic;
use scirs2_core::ndarray::Array2;
use std::collections::{HashMap, HashSet};

/// Topic coherence calculator
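///
/// # Examples
///
/// A minimal usage sketch; it is illustrative only, and the `topics` and
/// `documents` values are assumed to come from a topic model and a tokenizer
/// respectively:
///
/// ```ignore
/// let coherence = TopicCoherence::new().with_window_size(20);
/// // `topics: Vec<Topic>` from a topic model, `documents: Vec<Vec<String>>` pre-tokenized.
/// let score = coherence.cv_coherence(&topics, &documents)?;
/// assert!((-1.0..=1.0).contains(&score));
/// ```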
pub struct TopicCoherence {
    /// Window size for co-occurrence counting
    window_size: usize,
    /// Minimum word frequency (kept for API compatibility)
    _min_count: usize,
    /// Epsilon for smoothing
    epsilon: f64,
}

impl Default for TopicCoherence {
    fn default() -> Self {
        Self {
            window_size: 10,
            _min_count: 5, // Kept for API compatibility
            epsilon: 1e-12,
        }
    }
}

/// Type alias for document frequency map
type DocFreqMap = HashMap<String, usize>;
/// Type alias for co-document frequency map
type CoDocFreqMap = HashMap<(String, String), usize>;

impl TopicCoherence {
    /// Create a new coherence calculator
    pub fn new() -> Self {
        Self::default()
    }

    /// Set window size for co-occurrence
    pub fn with_window_size(mut self, window_size: usize) -> Self {
        self.window_size = window_size;
        self
    }

    /// Calculate C_v coherence (Röder et al., 2015)
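    ///
    /// In this implementation each topic is scored as the mean pairwise NPMI
    /// of its top words, with probabilities estimated from document-level
    /// co-occurrence counts (see `calculate_npmi`), and the final value is the
    /// average over all topics.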
    pub fn cv_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        // Get top words for each topic
        let top_words_per_topic: Vec<Vec<String>> = topics
            .iter()
            .map(|topic| {
                topic
                    .top_words
                    .iter()
                    .map(|(word, _)| word.clone())
                    .collect()
            })
            .collect();

        // Calculate segmented document frequency
        let (doc_freq, co_doc_freq) =
            self.calculate_document_frequencies(&top_words_per_topic, documents)?;

        // Calculate NPMI scores
        let mut coherence_scores = Vec::new();

        for topic_words in &top_words_per_topic {
            let topic_coherence = self.calculate_topic_coherence_cv(
                topic_words,
                &doc_freq,
                &co_doc_freq,
                documents.len(),
            )?;
            coherence_scores.push(topic_coherence);
        }

        // Average across all topics
        let avg_coherence = coherence_scores.iter().sum::<f64>() / coherence_scores.len() as f64;
        Ok(avg_coherence)
    }

    /// Calculate UMass coherence
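    ///
    /// For each ordered pair of top words `(w_i, w_j)` with `j < i`, the pair
    /// score is `ln((D(w_i, w_j) + eps) / D(w_j))`, where `D(..)` counts the
    /// documents containing the word(s), `eps` is the smoothing constant, and
    /// `D(w_j)` is clamped to at least 1; pair scores are averaged per topic
    /// and then across topics.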
    pub fn umass_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        // Convert documents to sets for efficient lookup
        let doc_sets: Vec<HashSet<String>> = documents
            .iter()
            .map(|doc| doc.iter().cloned().collect())
            .collect();

        let mut coherence_scores = Vec::new();

        for topic in topics {
            let top_words: Vec<&String> = topic.top_words.iter().map(|(word, _)| word).collect();

            let topic_coherence = self.calculate_topic_coherence_umass(&top_words, &doc_sets)?;
            coherence_scores.push(topic_coherence);
        }

        let avg_coherence = coherence_scores.iter().sum::<f64>() / coherence_scores.len() as f64;
        Ok(avg_coherence)
    }

    /// Calculate UCI coherence
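    ///
    /// Word and pair counts come from sliding windows of `window_size` tokens;
    /// each pair of top words is scored with
    /// `PMI(w_i, w_j) = ln(c(w_i, w_j) * T / (c(w_i) * c(w_j)))`, where `T` is
    /// the total token count, pairs with a zero count are skipped, and pair
    /// scores are averaged per topic and then across topics.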
    pub fn uci_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        // Build sliding window co-occurrence counts
        let (word_freq, co_occurrence) = self.build_co_occurrence_matrix(documents)?;

        let mut coherence_scores = Vec::new();

        for topic in topics {
            let top_words: Vec<&String> = topic.top_words.iter().map(|(word, _)| word).collect();

            let topic_coherence =
                self.calculate_topic_coherence_uci(&top_words, &word_freq, &co_occurrence)?;
            coherence_scores.push(topic_coherence);
        }

        let avg_coherence = coherence_scores.iter().sum::<f64>() / coherence_scores.len() as f64;
        Ok(avg_coherence)
    }

    /// Calculate document frequencies for words
    fn calculate_document_frequencies(
        &self,
        topics: &[Vec<String>],
        documents: &[Vec<String>],
    ) -> Result<(DocFreqMap, CoDocFreqMap)> {
        let mut doc_freq: HashMap<String, usize> = HashMap::new();
        let mut co_doc_freq: HashMap<(String, String), usize> = HashMap::new();

        // Get all unique words from topics
        let mut all_words: HashSet<String> = HashSet::new();
        for topic in topics {
            for word in topic {
                all_words.insert(word.clone());
            }
        }
        let words_vec: Vec<&String> = all_words.iter().collect();

        // Count document frequencies
        for doc in documents {
            let doc_set: HashSet<String> = doc.iter().cloned().collect();

            // Single word frequencies
            for word in &all_words {
                if doc_set.contains(word) {
                    *doc_freq.entry(word.clone()).or_insert(0) += 1;
                }
            }

            // Co-document frequencies
            for i in 0..words_vec.len() {
                for j in (i + 1)..words_vec.len() {
                    let word_1 = words_vec[i];
                    let word_2 = words_vec[j];

                    if doc_set.contains(word_1) && doc_set.contains(word_2) {
                        let key = if word_1 < word_2 {
                            (word_1.clone(), word_2.clone())
                        } else {
                            (word_2.clone(), word_1.clone())
                        };
                        *co_doc_freq.entry(key).or_insert(0) += 1;
                    }
                }
            }
        }

        Ok((doc_freq, co_doc_freq))
    }

    /// Calculate C_v coherence for a single topic
    fn calculate_topic_coherence_cv(
        &self,
        topic_words: &[String],
        doc_freq: &HashMap<String, usize>,
        co_doc_freq: &HashMap<(String, String), usize>,
        n_docs: usize,
    ) -> Result<f64> {
        let mut scores = Vec::new();

        for i in 0..topic_words.len() {
            for j in (i + 1)..topic_words.len() {
                let word_1 = &topic_words[i];
                let word_2 = &topic_words[j];

                let freq1 = doc_freq.get(word_1).copied().unwrap_or(0) as f64;
                let freq2 = doc_freq.get(word_2).copied().unwrap_or(0) as f64;

                let co_freq = co_doc_freq
                    .get(&if word_1 < word_2 {
                        (word_1.clone(), word_2.clone())
                    } else {
                        (word_2.clone(), word_1.clone())
                    })
                    .copied()
                    .unwrap_or(0) as f64;

                // Calculate NPMI
                let npmi = self.calculate_npmi(freq1, freq2, co_freq, n_docs as f64);
                scores.push(npmi);
            }
        }

        if scores.is_empty() {
            Ok(0.0)
        } else {
            Ok(scores.iter().sum::<f64>() / scores.len() as f64)
        }
    }

    /// Calculate UMass coherence for a single topic
    fn calculate_topic_coherence_umass(
        &self,
        topic_words: &[&String],
        doc_sets: &[HashSet<String>],
    ) -> Result<f64> {
        let mut scores = Vec::new();

        for i in 1..topic_words.len() {
            for j in 0..i {
                let word_i = topic_words[i];
                let word_j = topic_words[j];

                let mut count_j = 0;
                let mut count_both = 0;

                for doc_set in doc_sets {
                    let has_i = doc_set.contains(word_i);
                    let has_j = doc_set.contains(word_j);

                    if has_j {
                        count_j += 1;
                    }
                    if has_i && has_j {
                        count_both += 1;
                    }
                }

                // UMass score: log of the smoothed conditional probability P(w_i | w_j)
                let score = if count_both > 0 {
                    ((count_both as f64 + self.epsilon) / count_j as f64).ln()
                } else {
                    (self.epsilon / count_j.max(1) as f64).ln()
                };

                scores.push(score);
            }
        }

        if scores.is_empty() {
            Ok(0.0)
        } else {
            Ok(scores.iter().sum::<f64>() / scores.len() as f64)
        }
    }

    /// Calculate UCI coherence for a single topic
    fn calculate_topic_coherence_uci(
        &self,
        topic_words: &[&String],
        word_freq: &HashMap<String, usize>,
        co_occurrence: &HashMap<(String, String), usize>,
    ) -> Result<f64> {
        let mut scores = Vec::new();
        // Total token count, used to normalize counts into probabilities
        let total = word_freq.values().sum::<usize>() as f64;

        for i in 0..topic_words.len() {
            for j in (i + 1)..topic_words.len() {
                let word_1 = topic_words[i];
                let word_2 = topic_words[j];

                let freq1 = word_freq.get(word_1).copied().unwrap_or(0) as f64;
                let freq2 = word_freq.get(word_2).copied().unwrap_or(0) as f64;

                let co_freq = co_occurrence
                    .get(&if word_1 < word_2 {
                        (word_1.clone(), word_2.clone())
                    } else {
                        (word_2.clone(), word_1.clone())
                    })
                    .copied()
                    .unwrap_or(0) as f64;

                // Calculate PMI
                if freq1 > 0.0 && freq2 > 0.0 && co_freq > 0.0 {
                    let pmi = (co_freq * total / (freq1 * freq2)).ln();
                    scores.push(pmi);
                }
            }
        }

        if scores.is_empty() {
            Ok(0.0)
        } else {
            Ok(scores.iter().sum::<f64>() / scores.len() as f64)
        }
    }

    /// Build co-occurrence matrix using sliding windows
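    ///
    /// Each token at position `i` is paired with the following
    /// `window_size - 1` tokens (fewer near the end of a document); identical
    /// tokens are skipped, and each pair is keyed with the lexicographically
    /// smaller word first.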
    fn build_co_occurrence_matrix(
        &self,
        documents: &[Vec<String>],
    ) -> Result<(DocFreqMap, CoDocFreqMap)> {
        let mut word_freq: HashMap<String, usize> = HashMap::new();
        let mut co_occurrence: HashMap<(String, String), usize> = HashMap::new();

        for doc in documents {
            // Count word frequencies
            for word in doc {
                *word_freq.entry(word.clone()).or_insert(0) += 1;
            }

            // Count co-occurrences within windows
            for i in 0..doc.len() {
                let window_end = (i + self.window_size).min(doc.len());

                for j in (i + 1)..window_end {
                    let word_1 = &doc[i];
                    let word_2 = &doc[j];

                    if word_1 != word_2 {
                        let key = if word_1 < word_2 {
                            (word_1.clone(), word_2.clone())
                        } else {
                            (word_2.clone(), word_1.clone())
                        };
                        *co_occurrence.entry(key).or_insert(0) += 1;
                    }
                }
            }
        }

        Ok((word_freq, co_occurrence))
    }

    /// Calculate Normalized Pointwise Mutual Information
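    ///
    /// With `p_1 = freq1 / n_total`, `p_2 = freq2 / n_total` and
    /// `p_12 = co_freq / n_total`, this computes
    /// `NPMI = ln(p_12 / (p_1 * p_2)) / -ln(p_12)`, clamped to `[-1, 1]`.
    /// If any count is zero, `-1.0` is returned.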
    fn calculate_npmi(&self, freq1: f64, freq2: f64, co_freq: f64, n_total: f64) -> f64 {
        if freq1 == 0.0 || freq2 == 0.0 || co_freq == 0.0 {
            return -1.0;
        }

        let p1 = freq1 / n_total;
        let p2 = freq2 / n_total;
        let p12 = co_freq / n_total;

        let pmi = (p12 / (p1 * p2)).ln();
        let npmi = pmi / -(p12.ln());

        npmi.clamp(-1.0, 1.0)
    }
}

/// Topic diversity calculator
pub struct TopicDiversity;

impl TopicDiversity {
    /// Calculate topic diversity (fraction of unique words across topics)
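    ///
    /// Defined as the number of distinct words across all topics' top-word
    /// lists divided by the total number of top-word entries; the result lies
    /// in `[0.0, 1.0]`, with `1.0` meaning no word appears in more than one
    /// entry, and `0.0` returned for an empty topic list.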
    pub fn calculate(topics: &[Topic]) -> f64 {
        let mut all_words = Vec::new();
        let mut unique_words = HashSet::new();

        for topic in topics {
            for (word, _) in &topic.top_words {
                all_words.push(word.clone());
                unique_words.insert(word.clone());
            }
        }

        if all_words.is_empty() {
            return 0.0;
        }

        unique_words.len() as f64 / all_words.len() as f64
    }

    /// Calculate pairwise Jaccard distance between topics
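    ///
    /// Returns an `n_topics x n_topics` matrix whose entry `(i, j)` is
    /// `1 - |W_i ∩ W_j| / |W_i ∪ W_j|`, where `W_i` is the set of topic `i`'s
    /// top words; diagonal entries are `0.0`.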
    pub fn pairwise_distances(topics: &[Topic]) -> Array2<f64> {
        let n_topics = topics.len();
        let mut distances = Array2::zeros((n_topics, n_topics));

        for i in 0..n_topics {
            for j in 0..n_topics {
                if i == j {
                    distances[[i, j]] = 0.0;
                } else {
                    let words_i: HashSet<String> = topics[i]
                        .top_words
                        .iter()
                        .map(|(word, _)| word.clone())
                        .collect();
                    let words_j: HashSet<String> = topics[j]
                        .top_words
                        .iter()
                        .map(|(word, _)| word.clone())
                        .collect();

                    let intersection = words_i.intersection(&words_j).count();
                    let union = words_i.union(&words_j).count();

                    distances[[i, j]] = 1.0 - (intersection as f64 / union as f64);
                }
            }
        }

        distances
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_topics() -> Vec<Topic> {
        vec![
            Topic {
                id: 0,
                top_words: vec![
                    ("machine".to_string(), 0.1),
                    ("learning".to_string(), 0.09),
                    ("algorithm".to_string(), 0.08),
                ],
                coherence: None,
            },
            Topic {
                id: 1,
                top_words: vec![
                    ("neural".to_string(), 0.12),
                    ("network".to_string(), 0.11),
                    ("deep".to_string(), 0.10),
                ],
                coherence: None,
            },
        ]
    }

    fn create_test_documents() -> Vec<Vec<String>> {
        vec![
            vec!["machine", "learning", "algorithm", "data"]
                .into_iter()
                .map(String::from)
                .collect(),
            vec!["neural", "network", "deep", "learning"]
                .into_iter()
                .map(String::from)
                .collect(),
            vec!["machine", "algorithm", "neural", "network"]
                .into_iter()
                .map(String::from)
                .collect(),
            vec!["deep", "learning", "machine", "data"]
                .into_iter()
                .map(String::from)
                .collect(),
        ]
    }

    #[test]
    fn test_cv_coherence() {
        let coherence = TopicCoherence::new();
        let topics = create_test_topics();
        let documents = create_test_documents();

        let score = coherence
            .cv_coherence(&topics, &documents)
            .expect("Operation failed");
        assert!((-1.0..=1.0).contains(&score));
    }

    #[test]
    fn test_umass_coherence() {
        let coherence = TopicCoherence::new();
        let topics = create_test_topics();
        let documents = create_test_documents();

        let score = coherence
            .umass_coherence(&topics, &documents)
            .expect("Operation failed");
        assert!(score.is_finite());
    }

    #[test]
    fn test_uci_coherence() {
        let coherence = TopicCoherence::new();
        let topics = create_test_topics();
        let documents = create_test_documents();

        let score = coherence
            .uci_coherence(&topics, &documents)
            .expect("Operation failed");
        assert!(score.is_finite());
    }

    #[test]
    fn test_topic_diversity() {
        let topics = create_test_topics();
        let diversity = TopicDiversity::calculate(&topics);

        assert!((0.0..=1.0).contains(&diversity));
        // All words are unique in our test topics
        assert_eq!(diversity, 1.0);
    }

    #[test]
    fn test_pairwise_distances() {
        let topics = create_test_topics();
        let distances = TopicDiversity::pairwise_distances(&topics);

        // Diagonal should be zero
        assert_eq!(distances[[0, 0]], 0.0);
        assert_eq!(distances[[1, 1]], 0.0);

        // Topics have no overlap in our test case
        assert_eq!(distances[[0, 1]], 1.0);
        assert_eq!(distances[[1, 0]], 1.0);
    }

    #[test]
    fn test_empty_topics() {
        let coherence = TopicCoherence::new();
        let topics: Vec<Topic> = vec![];
        let documents = create_test_documents();

        let cv_score = coherence
            .cv_coherence(&topics, &documents)
            .expect("Operation failed");
        assert!(cv_score.is_nan() || cv_score == 0.0);

        let diversity = TopicDiversity::calculate(&topics);
        assert_eq!(diversity, 0.0);
    }
}