jieba_rs/keywords/
textrank.rs

1use std::cmp::Ordering;
2use std::collections::{BTreeSet, BinaryHeap};
3
4use ordered_float::OrderedFloat;
5
6use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder};
7use crate::FxHashMap as HashMap;
8use crate::Jieba;
9
/// Scalar type shared by edge weights and node ranks.
type Weight = f64;

/// One directed half of a weighted link, kept in the adjacency list of the
/// node it starts from.
#[derive(Clone)]
struct Edge {
    /// Index of the node this edge points to.
    dst: usize,
    /// Weight carried by the edge.
    weight: Weight,
}

impl Edge {
    /// Builds an edge that points at node `dst` and carries `weight`.
    fn new(dst: usize, weight: Weight) -> Edge {
        Self { dst, weight }
    }
}

/// Adjacency list of a single node.
type Edges = Vec<Edge>;
/// Whole graph: one adjacency list per node, indexed by node id.
type Graph = Vec<Edges>;
26
27struct StateDiagram {
28    damping_factor: Weight,
29    g: Graph,
30}
31
32impl StateDiagram {
33    fn new(size: usize) -> Self {
34        StateDiagram {
35            damping_factor: 0.85,
36            g: vec![Vec::new(); size],
37        }
38    }
39
40    fn add_undirected_edge(&mut self, src: usize, dst: usize, weight: Weight) {
41        self.g[src].push(Edge::new(dst, weight));
42        self.g[dst].push(Edge::new(src, weight));
43    }
44
45    fn rank(&mut self) -> Vec<Weight> {
46        let n = self.g.len();
47        let default_weight = 1.0 / (n as f64);
48
49        let mut ranking_vector = vec![default_weight; n];
50
51        let mut outflow_weights = vec![0.0; n];
52        for (i, v) in self.g.iter().enumerate() {
53            outflow_weights[i] = v.iter().map(|e| e.weight).sum();
54        }
55
56        for _ in 0..20 {
57            for (i, v) in self.g.iter().enumerate() {
58                let s: f64 = v
59                    .iter()
60                    .map(|e| e.weight / outflow_weights[e.dst] * ranking_vector[e.dst])
61                    .sum();
62
63                ranking_vector[i] = (1.0 - self.damping_factor) + self.damping_factor * s;
64            }
65        }
66
67        ranking_vector
68    }
69}
70
71/// Text rank keywords extraction.
72///
73/// Requires `textrank` feature to be enabled.
74#[derive(Debug)]
75pub struct TextRank {
76    span: usize,
77    config: KeywordExtractConfig,
78}
79
80impl TextRank {
81    /// Creates an TextRank.
82    ///
83    /// # Examples
84    ///
85    /// New instance with custom stop words. Also uses hmm for unknown words
86    /// during segmentation.
87    /// ```
88    ///    use std::collections::BTreeSet;
89    ///    use jieba_rs::{TextRank, KeywordExtractConfig};
90    ///
91    ///    let stop_words : BTreeSet<String> =
92    ///        BTreeSet::from(["a", "the", "of"].map(|s| s.to_string()));
93    ///    TextRank::new(
94    ///        5,
95    ///        KeywordExtractConfig::default());
96    /// ```
97    pub fn new(span: usize, config: KeywordExtractConfig) -> Self {
98        TextRank { span, config }
99    }
100}
101
102impl Default for TextRank {
103    /// Creates TextRank with 5 Unicode Scalar Value spans
104    fn default() -> Self {
105        TextRank::new(5, KeywordExtractConfigBuilder::default().build().unwrap())
106    }
107}
108
109impl KeywordExtract for TextRank {
110    /// Uses TextRank algorithm to extract the `top_k` keywords from `sentence`.
111    ///
112    /// If `allowed_pos` is not empty, then only terms matching those parts if
113    /// speech are considered.
114    ///
115    /// # Examples
116    ///
117    /// ```
118    ///    use jieba_rs::{Jieba, KeywordExtract, TextRank};
119    ///
120    ///    let jieba = Jieba::new();
121    ///    let keyword_extractor = TextRank::default();
122    ///    let mut top_k = keyword_extractor.extract_keywords(
123    ///        &jieba,
124    ///        "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。",
125    ///        6,
126    ///        vec![String::from("ns"), String::from("n"), String::from("vn"), String::from("v")],
127    ///    );
128    ///    assert_eq!(
129    ///        top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
130    ///        vec!["吉林", "欧亚", "置业", "实现", "收入", "子公司"]
131    ///    );
132    ///
133    ///    top_k = keyword_extractor.extract_keywords(
134    ///        &jieba,
135    ///        "It is nice weather in New York City. and今天纽约的天气真好啊,and京华大酒店的张尧经理吃了一只北京烤鸭。and后天纽约的天气不好,and昨天纽约的天气也不好,and北京烤鸭真好吃",
136    ///        3,
137    ///        vec![],
138    ///    );
139    ///    assert_eq!(
140    ///        top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
141    ///        vec!["纽约", "天气", "不好"]
142    ///    );
143    /// ```
144    fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
145        let tags = jieba.tag(sentence, self.config.use_hmm());
146        let mut allowed_pos_set = BTreeSet::new();
147
148        for s in allowed_pos {
149            allowed_pos_set.insert(s);
150        }
151
152        let mut word2id: HashMap<String, usize> =
153            HashMap::with_capacity_and_hasher(tags.len() / 2, rustc_hash::FxBuildHasher);
154        let mut unique_words = Vec::with_capacity(tags.len() / 2);
155        for t in &tags {
156            if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) {
157                continue;
158            }
159
160            if !word2id.contains_key(t.word) {
161                unique_words.push(String::from(t.word));
162                word2id.insert(String::from(t.word), unique_words.len() - 1);
163            }
164        }
165
166        let mut cooccurence: HashMap<(usize, usize), usize> = HashMap::default();
167        for (i, t) in tags.iter().enumerate() {
168            if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) {
169                continue;
170            }
171
172            if !self.config.filter(t.word) {
173                continue;
174            }
175
176            for j in (i + 1)..(i + self.span) {
177                if j >= tags.len() {
178                    break;
179                }
180
181                if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(tags[j].tag) {
182                    continue;
183                }
184
185                if !self.config.filter(tags[j].word) {
186                    continue;
187                }
188
189                let u = word2id.get(t.word).unwrap().to_owned();
190                let v = word2id.get(tags[j].word).unwrap().to_owned();
191                let entry = cooccurence.entry((u, v)).or_insert(0);
192                *entry += 1;
193            }
194        }
195
196        let mut diagram = StateDiagram::new(unique_words.len());
197        for (k, &v) in cooccurence.iter() {
198            diagram.add_undirected_edge(k.0, k.1, v as f64);
199        }
200
201        let ranking_vector = diagram.rank();
202
203        let mut heap = BinaryHeap::new();
204        for (k, v) in ranking_vector.iter().enumerate() {
205            heap.push(HeapNode {
206                rank: OrderedFloat(v * 1e10),
207                word_id: k,
208            });
209
210            if k >= top_k {
211                heap.pop();
212            }
213        }
214
215        let mut res = Vec::with_capacity(top_k);
216        for _ in 0..top_k {
217            if let Some(w) = heap.pop() {
218                res.push(Keyword {
219                    keyword: unique_words[w.word_id].clone(),
220                    weight: w.rank.into_inner(),
221                });
222            }
223        }
224
225        res.reverse();
226        res
227    }
228}
229
230#[derive(Debug, Clone, Eq, PartialEq)]
231struct HeapNode {
232    rank: OrderedFloat<f64>,
233    word_id: usize,
234}
235
236impl Ord for HeapNode {
237    fn cmp(&self, other: &HeapNode) -> Ordering {
238        other
239            .rank
240            .cmp(&self.rank)
241            .then_with(|| self.word_id.cmp(&other.word_id))
242    }
243}
244
245impl PartialOrd for HeapNode {
246    fn partial_cmp(&self, other: &HeapNode) -> Option<Ordering> {
247        Some(self.cmp(other))
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// A new diagram holds exactly one adjacency list per requested node.
    #[test]
    fn test_init_state_diagram() {
        let node_count = StateDiagram::new(10).g.len();
        assert_eq!(node_count, 10);
    }
}