// jieba_rs/keywords/tfidf.rs
1use std::cmp::Ordering;
2use std::collections::{BTreeSet, BinaryHeap};
3use std::io::{self, BufRead, BufReader};
4
5use include_flate::flate;
6use ordered_float::OrderedFloat;
7
8use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder};
9use crate::FxHashMap as HashMap;
10use crate::Jieba;
11
12flate!(static DEFAULT_IDF: str from "src/data/idf.txt");
13
/// Entry in the bounded `BinaryHeap` used to track the `top_k` highest
/// TF-IDF terms. Its `Ord` impl reverses the score comparison, so the
/// (max-)heap behaves as a min-heap on `tfidf` and `pop` evicts the
/// lowest-scoring candidate.
#[derive(Debug, Clone, Eq, PartialEq)]
struct HeapNode<'a> {
    /// TF-IDF weight; `OrderedFloat` lets an `f64` satisfy `Ord`.
    tfidf: OrderedFloat<f64>,
    /// Borrowed term text; used to break ties between equal weights.
    word: &'a str,
}
19
20impl Ord for HeapNode<'_> {
21 fn cmp(&self, other: &HeapNode) -> Ordering {
22 other.tfidf.cmp(&self.tfidf).then_with(|| self.word.cmp(other.word))
23 }
24}
25
26impl PartialOrd for HeapNode<'_> {
27 fn partial_cmp(&self, other: &HeapNode) -> Option<Ordering> {
28 Some(self.cmp(other))
29 }
30}
31
/// TF-IDF keywords extraction
///
/// Require `tfidf` feature to be enabled
#[derive(Debug)]
pub struct TfIdf {
    /// Per-word inverse document frequency, populated by `load_dict`.
    idf_dict: HashMap<String, f64>,
    /// Fallback IDF for words absent from `idf_dict`; recomputed from the
    /// most recently loaded dictionary.
    median_idf: f64,
    /// Segmentation and filtering options (stop words, min length, hmm).
    config: KeywordExtractConfig,
}
41
42/// Implementation of JiebaKeywordExtract using a TF-IDF dictionary.
43///
44/// This takes the segments produced by Jieba and attempts to extract keywords.
45/// Segments are filtered for stopwords and short terms. They are then matched
46/// against a loaded dictionary to calculate TF-IDF scores.
47impl TfIdf {
48 /// Creates an TfIdf.
49 ///
50 /// # Examples
51 ///
52 /// New instance with custom idf dictionary.
53 /// ```
54 /// use jieba_rs::{TfIdf, KeywordExtractConfig};
55 ///
56 /// let mut sample_idf = "劳动防护 13.900677652\n\
57 /// 生化学 13.900677652\n";
58 /// TfIdf::new(
59 /// Some(&mut sample_idf.as_bytes()),
60 /// KeywordExtractConfig::default());
61 /// ```
62 ///
63 /// New instance with module default stop words and no initial IDF
64 /// dictionary. Dictionary should be loaded later with `load_dict()` calls.
65 /// ```
66 /// use jieba_rs::{TfIdf, KeywordExtractConfig};
67 ///
68 /// TfIdf::new(
69 /// None::<&mut std::io::Empty>,
70 /// KeywordExtractConfig::default());
71 /// ```
72 pub fn new(opt_dict: Option<&mut impl BufRead>, config: KeywordExtractConfig) -> Self {
73 let mut instance = TfIdf {
74 idf_dict: HashMap::default(),
75 median_idf: 0.0,
76 config,
77 };
78 if let Some(dict) = opt_dict {
79 instance.load_dict(dict).unwrap();
80 }
81 instance
82 }
83
84 /// Merges entires from `dict` into the `idf_dict`.
85 ///
86 /// ```
87 /// use jieba_rs::{Jieba, KeywordExtract, Keyword, KeywordExtractConfig,
88 /// TfIdf};
89 ///
90 /// let jieba = Jieba::default();
91 /// let mut init_idf = "生化学 13.900677652\n";
92 ///
93 /// let mut tfidf = TfIdf::new(
94 /// Some(&mut init_idf.as_bytes()),
95 /// KeywordExtractConfig::default());
96 /// let top_k = tfidf.extract_keywords(&jieba, "生化学不是光化学的,", 3, vec![]);
97 /// assert_eq!(
98 /// top_k,
99 /// vec![
100 /// Keyword { keyword: "不是".to_string(), weight: 4.6335592173333335 },
101 /// Keyword { keyword: "光化学".to_string(), weight: 4.6335592173333335 },
102 /// Keyword { keyword: "生化学".to_string(), weight: 4.6335592173333335 }
103 /// ]
104 /// );
105 ///
106 /// let mut init_idf = "光化学 99.123456789\n";
107 /// tfidf.load_dict(&mut init_idf.as_bytes());
108 /// let new_top_k = tfidf.extract_keywords(&jieba, "生化学不是光化学的,", 3, vec![]);
109 /// assert_eq!(
110 /// new_top_k,
111 /// vec![
112 /// Keyword { keyword: "不是".to_string(), weight: 33.041152263 },
113 /// Keyword { keyword: "光化学".to_string(), weight: 33.041152263 },
114 /// Keyword { keyword: "生化学".to_string(), weight: 4.6335592173333335 }
115 /// ]
116 /// );
117 /// ```
118 pub fn load_dict(&mut self, dict: &mut impl BufRead) -> io::Result<()> {
119 let mut buf = String::new();
120 let mut idf_heap = BinaryHeap::new();
121 while dict.read_line(&mut buf)? > 0 {
122 let parts: Vec<&str> = buf.split_whitespace().collect();
123 if parts.is_empty() {
124 continue;
125 }
126
127 let word = parts[0];
128 if let Some(idf) = parts.get(1).and_then(|x| x.parse::<f64>().ok()) {
129 self.idf_dict.insert(word.to_string(), idf);
130 idf_heap.push(OrderedFloat(idf));
131 }
132
133 buf.clear();
134 }
135
136 let m = idf_heap.len() / 2;
137 for _ in 0..m {
138 idf_heap.pop();
139 }
140
141 self.median_idf = idf_heap.pop().unwrap().into_inner();
142
143 Ok(())
144 }
145
146 pub fn config(&self) -> &KeywordExtractConfig {
147 &self.config
148 }
149
150 pub fn config_mut(&mut self) -> &mut KeywordExtractConfig {
151 &mut self.config
152 }
153}
154
155/// TF-IDF keywords extraction.
156///
157/// Require `tfidf` feature to be enabled.
158impl Default for TfIdf {
159 /// Creates TfIdf with DEFAULT_STOP_WORDS, the default TfIdf dictionary,
160 /// 2 Unicode Scalar Value minimum for keywords, and no hmm in segmentation.
161 fn default() -> Self {
162 let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes());
163 TfIdf::new(
164 Some(&mut default_dict),
165 KeywordExtractConfigBuilder::default().build().unwrap(),
166 )
167 }
168}
169
170impl KeywordExtract for TfIdf {
171 /// Uses TF-IDF algorithm to extract the `top_k` keywords from `sentence`.
172 ///
173 /// If `allowed_pos` is not empty, then only terms matching those parts if
174 /// speech are considered.
175 ///
176 /// # Examples
177 /// ```
178 /// use jieba_rs::{Jieba, KeywordExtract, TfIdf};
179 ///
180 /// let jieba = Jieba::new();
181 /// let keyword_extractor = TfIdf::default();
182 /// let mut top_k = keyword_extractor.extract_keywords(
183 /// &jieba,
184 /// "今天纽约的天气真好啊,京华大酒店的张尧经理吃了一只北京烤鸭。后天纽约的天气不好,昨天纽约的天气也不好,北京烤鸭真好吃",
185 /// 3,
186 /// vec![],
187 /// );
188 /// assert_eq!(
189 /// top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
190 /// vec!["北京烤鸭", "纽约", "天气"]
191 /// );
192 ///
193 /// top_k = keyword_extractor.extract_keywords(
194 /// &jieba,
195 /// "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。",
196 /// 5,
197 /// vec![],
198 /// );
199 /// assert_eq!(
200 /// top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
201 /// vec!["欧亚", "吉林", "置业", "万元", "增资"]
202 /// );
203 ///
204 /// top_k = keyword_extractor.extract_keywords(
205 /// &jieba,
206 /// "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。",
207 /// 5,
208 /// vec![String::from("ns"), String::from("n"), String::from("vn"), String::from("v")],
209 /// );
210 /// assert_eq!(
211 /// top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
212 /// vec!["欧亚", "吉林", "置业", "增资", "实现"]
213 /// );
214 /// ```
215 fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
216 let tags = jieba.tag(sentence, self.config.use_hmm());
217 let mut allowed_pos_set = BTreeSet::new();
218
219 for s in allowed_pos {
220 allowed_pos_set.insert(s);
221 }
222
223 let mut term_freq: HashMap<String, u64> = HashMap::default();
224 for t in &tags {
225 if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) {
226 continue;
227 }
228
229 if !self.config.filter(t.word) {
230 continue;
231 }
232
233 let entry = term_freq.entry(String::from(t.word)).or_insert(0);
234 *entry += 1;
235 }
236
237 let total: u64 = term_freq.values().sum();
238 let mut heap = BinaryHeap::new();
239 for (cnt, (k, tf)) in term_freq.iter().enumerate() {
240 let idf = self.idf_dict.get(k).unwrap_or(&self.median_idf);
241 let node = HeapNode {
242 tfidf: OrderedFloat(*tf as f64 * idf / total as f64),
243 word: k,
244 };
245 heap.push(node);
246 if cnt >= top_k {
247 heap.pop();
248 }
249 }
250
251 let mut res = Vec::with_capacity(top_k);
252 for _ in 0..top_k {
253 if let Some(w) = heap.pop() {
254 res.push(Keyword {
255 keyword: String::from(w.word),
256 weight: w.tfidf.into_inner(),
257 });
258 }
259 }
260
261 res.reverse();
262 res
263 }
264}