jieba_rs/keywords/
textrank.rs1use std::cmp::Ordering;
2use std::collections::{BTreeSet, BinaryHeap};
3
4use ordered_float::OrderedFloat;
5
6use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder};
7use crate::FxHashMap as HashMap;
8use crate::Jieba;
9
/// Numeric type used for edge weights and for the scores produced by the ranking.
type Weight = f64;

/// One directed half of an undirected graph edge.
#[derive(Clone)]
struct Edge {
    /// Index of the neighbouring node.
    dst: usize,
    /// Strength of the connection (a co-occurrence count where this file builds edges).
    weight: Weight,
}

impl Edge {
    /// Creates an edge pointing at node `dst` with the given `weight`.
    fn new(dst: usize, weight: Weight) -> Self {
        Self { dst, weight }
    }
}

/// Adjacency list of a single node.
type Edges = Vec<Edge>;
/// Whole graph as adjacency lists: entry `i` holds the edges leaving node `i`.
type Graph = Vec<Edges>;
26
27struct StateDiagram {
28 damping_factor: Weight,
29 g: Graph,
30}
31
32impl StateDiagram {
33 fn new(size: usize) -> Self {
34 StateDiagram {
35 damping_factor: 0.85,
36 g: vec![Vec::new(); size],
37 }
38 }
39
40 fn add_undirected_edge(&mut self, src: usize, dst: usize, weight: Weight) {
41 self.g[src].push(Edge::new(dst, weight));
42 self.g[dst].push(Edge::new(src, weight));
43 }
44
45 fn rank(&mut self) -> Vec<Weight> {
46 let n = self.g.len();
47 let default_weight = 1.0 / (n as f64);
48
49 let mut ranking_vector = vec![default_weight; n];
50
51 let mut outflow_weights = vec![0.0; n];
52 for (i, v) in self.g.iter().enumerate() {
53 outflow_weights[i] = v.iter().map(|e| e.weight).sum();
54 }
55
56 for _ in 0..20 {
57 for (i, v) in self.g.iter().enumerate() {
58 let s: f64 = v
59 .iter()
60 .map(|e| e.weight / outflow_weights[e.dst] * ranking_vector[e.dst])
61 .sum();
62
63 ranking_vector[i] = (1.0 - self.damping_factor) + self.damping_factor * s;
64 }
65 }
66
67 ranking_vector
68 }
69}
70
71#[derive(Debug)]
75pub struct TextRank {
76 span: usize,
77 config: KeywordExtractConfig,
78}
79
80impl TextRank {
81 pub fn new(span: usize, config: KeywordExtractConfig) -> Self {
98 TextRank { span, config }
99 }
100}
101
102impl Default for TextRank {
103 fn default() -> Self {
105 TextRank::new(5, KeywordExtractConfigBuilder::default().build().unwrap())
106 }
107}
108
109impl KeywordExtract for TextRank {
110 fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
145 let tags = jieba.tag(sentence, self.config.use_hmm());
146 let mut allowed_pos_set = BTreeSet::new();
147
148 for s in allowed_pos {
149 allowed_pos_set.insert(s);
150 }
151
152 let mut word2id: HashMap<String, usize> =
153 HashMap::with_capacity_and_hasher(tags.len() / 2, rustc_hash::FxBuildHasher);
154 let mut unique_words = Vec::with_capacity(tags.len() / 2);
155 for t in &tags {
156 if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) {
157 continue;
158 }
159
160 if !word2id.contains_key(t.word) {
161 unique_words.push(String::from(t.word));
162 word2id.insert(String::from(t.word), unique_words.len() - 1);
163 }
164 }
165
166 let mut cooccurence: HashMap<(usize, usize), usize> = HashMap::default();
167 for (i, t) in tags.iter().enumerate() {
168 if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) {
169 continue;
170 }
171
172 if !self.config.filter(t.word) {
173 continue;
174 }
175
176 for j in (i + 1)..(i + self.span) {
177 if j >= tags.len() {
178 break;
179 }
180
181 if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(tags[j].tag) {
182 continue;
183 }
184
185 if !self.config.filter(tags[j].word) {
186 continue;
187 }
188
189 let u = word2id.get(t.word).unwrap().to_owned();
190 let v = word2id.get(tags[j].word).unwrap().to_owned();
191 let entry = cooccurence.entry((u, v)).or_insert(0);
192 *entry += 1;
193 }
194 }
195
196 let mut diagram = StateDiagram::new(unique_words.len());
197 for (k, &v) in cooccurence.iter() {
198 diagram.add_undirected_edge(k.0, k.1, v as f64);
199 }
200
201 let ranking_vector = diagram.rank();
202
203 let mut heap = BinaryHeap::new();
204 for (k, v) in ranking_vector.iter().enumerate() {
205 heap.push(HeapNode {
206 rank: OrderedFloat(v * 1e10),
207 word_id: k,
208 });
209
210 if k >= top_k {
211 heap.pop();
212 }
213 }
214
215 let mut res = Vec::with_capacity(top_k);
216 for _ in 0..top_k {
217 if let Some(w) = heap.pop() {
218 res.push(Keyword {
219 keyword: unique_words[w.word_id].clone(),
220 weight: w.rank.into_inner(),
221 });
222 }
223 }
224
225 res.reverse();
226 res
227 }
228}
229
230#[derive(Debug, Clone, Eq, PartialEq)]
231struct HeapNode {
232 rank: OrderedFloat<f64>,
233 word_id: usize,
234}
235
236impl Ord for HeapNode {
237 fn cmp(&self, other: &HeapNode) -> Ordering {
238 other
239 .rank
240 .cmp(&self.rank)
241 .then_with(|| self.word_id.cmp(&other.word_id))
242 }
243}
244
245impl PartialOrd for HeapNode {
246 fn partial_cmp(&self, other: &HeapNode) -> Option<Ordering> {
247 Some(self.cmp(other))
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_init_state_diagram() {
        let diagram = StateDiagram::new(10);
        assert_eq!(diagram.g.len(), 10);
    }

    /// A node without edges only receives the constant `1 - damping_factor` term,
    /// which is below the score of any connected node.
    #[test]
    fn test_rank_isolated_node_gets_baseline() {
        let mut diagram = StateDiagram::new(3);
        diagram.add_undirected_edge(0, 1, 1.0);
        let ranks = diagram.rank();
        assert_eq!(ranks.len(), 3);
        assert!((ranks[2] - (1.0 - diagram.damping_factor)).abs() < 1e-12);
        assert!(ranks[0] > ranks[2]);
        assert!(ranks[1] > ranks[2]);
    }

    /// In a star graph the hub collects score from every leaf, and the symmetric
    /// leaves end with identical scores.
    #[test]
    fn test_rank_hub_outranks_leaves() {
        let mut diagram = StateDiagram::new(4);
        for leaf in 1..4 {
            diagram.add_undirected_edge(0, leaf, 1.0);
        }
        let ranks = diagram.rank();
        assert!(ranks[1..].iter().all(|&leaf| ranks[0] > leaf));
        assert!((ranks[1] - ranks[2]).abs() < 1e-12);
        assert!((ranks[2] - ranks[3]).abs() < 1e-12);
    }
}