jieba_rs/keywords/tfidf.rs

use std::cmp::Ordering;
use std::collections::{BTreeSet, BinaryHeap};
use std::io::{self, BufRead, BufReader};

use include_flate::flate;
use ordered_float::OrderedFloat;

use super::{Keyword, KeywordExtract, KeywordExtractConfig, KeywordExtractConfigBuilder};
use crate::FxHashMap as HashMap;
use crate::Jieba;

flate!(static DEFAULT_IDF: str from "src/data/idf.txt");

// `HeapNode` reverses the ordering on `tfidf` (ties broken by `word`), so the
// max-oriented `BinaryHeap` behaves like a min-heap on the score. This lets
// `extract_keywords` evict the smallest score once the heap grows past `top_k`.
#[derive(Debug, Clone, Eq, PartialEq)]
struct HeapNode<'a> {
    tfidf: OrderedFloat<f64>,
    word: &'a str,
}

impl Ord for HeapNode<'_> {
    fn cmp(&self, other: &HeapNode) -> Ordering {
        other.tfidf.cmp(&self.tfidf).then_with(|| self.word.cmp(other.word))
    }
}

impl PartialOrd for HeapNode<'_> {
    fn partial_cmp(&self, other: &HeapNode) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// TF-IDF keywords extraction.
///
/// Requires the `tfidf` feature to be enabled.
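///
/// # Examples
///
/// A minimal end-to-end sketch using the module defaults; it mirrors the
/// doctest on `extract_keywords` below.
/// ```
///    use jieba_rs::{Jieba, KeywordExtract, TfIdf};
///
///    let jieba = Jieba::new();
///    let keyword_extractor = TfIdf::default();
///    let top_k = keyword_extractor.extract_keywords(
///        &jieba,
///        "今天纽约的天气真好啊,京华大酒店的张尧经理吃了一只北京烤鸭。后天纽约的天气不好,昨天纽约的天气也不好,北京烤鸭真好吃",
///        3,
///        vec![],
///    );
///    assert_eq!(
///        top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
///        vec!["北京烤鸭", "纽约", "天气"]
///    );
/// ```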
#[derive(Debug)]
pub struct TfIdf {
    idf_dict: HashMap<String, f64>,
    median_idf: f64,
    config: KeywordExtractConfig,
}

/// Implementation of `KeywordExtract` backed by a TF-IDF dictionary.
///
/// This takes the segments produced by Jieba and attempts to extract keywords.
/// Segments are filtered for stop words and short terms, then matched against
/// the loaded IDF dictionary to compute TF-IDF scores.
impl TfIdf {
    /// Creates a new `TfIdf`.
    ///
    /// # Examples
    ///
    /// New instance with a custom IDF dictionary.
    /// ```
    ///    use jieba_rs::{TfIdf, KeywordExtractConfig};
    ///
    ///    let sample_idf = "劳动防护 13.900677652\n\
    ///        生化学 13.900677652\n";
    ///    TfIdf::new(
    ///        Some(&mut sample_idf.as_bytes()),
    ///        KeywordExtractConfig::default());
    /// ```
    ///
    /// New instance with the module's default stop words and no initial IDF
    /// dictionary. The dictionary can be loaded later with `load_dict()`.
    /// ```
    ///    use jieba_rs::{TfIdf, KeywordExtractConfig};
    ///
    ///    TfIdf::new(
    ///        None::<&mut std::io::Empty>,
    ///        KeywordExtractConfig::default());
    /// ```
    pub fn new(opt_dict: Option<&mut impl BufRead>, config: KeywordExtractConfig) -> Self {
        let mut instance = TfIdf {
            idf_dict: HashMap::default(),
            median_idf: 0.0,
            config,
        };
        if let Some(dict) = opt_dict {
            instance.load_dict(dict).unwrap();
        }
        instance
    }

    /// Merges entries from `dict` into the `idf_dict`.
    ///
    /// The median IDF, used as a fallback for terms missing from the
    /// dictionary, is recomputed from the entries read in this call.
    ///
    /// ```
    ///    use jieba_rs::{Jieba, KeywordExtract, Keyword, KeywordExtractConfig,
    ///        TfIdf};
    ///
    ///    let jieba = Jieba::default();
    ///    let init_idf = "生化学 13.900677652\n";
    ///
    ///    let mut tfidf = TfIdf::new(
    ///        Some(&mut init_idf.as_bytes()),
    ///        KeywordExtractConfig::default());
    ///    let top_k = tfidf.extract_keywords(&jieba, "生化学不是光化学的,", 3, vec![]);
    ///    assert_eq!(
    ///        top_k,
    ///        vec![
    ///            Keyword { keyword: "不是".to_string(), weight: 4.6335592173333335 },
    ///            Keyword { keyword: "光化学".to_string(), weight: 4.6335592173333335 },
    ///            Keyword { keyword: "生化学".to_string(), weight: 4.6335592173333335 }
    ///        ]
    ///    );
    ///
    ///    let new_idf = "光化学 99.123456789\n";
    ///    tfidf.load_dict(&mut new_idf.as_bytes()).unwrap();
    ///    let new_top_k = tfidf.extract_keywords(&jieba, "生化学不是光化学的,", 3, vec![]);
    ///    assert_eq!(
    ///        new_top_k,
    ///        vec![
    ///            Keyword { keyword: "不是".to_string(), weight: 33.041152263 },
    ///            Keyword { keyword: "光化学".to_string(), weight: 33.041152263 },
    ///            Keyword { keyword: "生化学".to_string(), weight: 4.6335592173333335 }
    ///        ]
    ///    );
    /// ```
    pub fn load_dict(&mut self, dict: &mut impl BufRead) -> io::Result<()> {
        let mut buf = String::new();
        let mut idf_heap = BinaryHeap::new();
        // Each line is expected to be "<word> <idf>"; lines without a parsable
        // IDF value are skipped.
        while dict.read_line(&mut buf)? > 0 {
            let parts: Vec<&str> = buf.split_whitespace().collect();
            if parts.is_empty() {
                continue;
            }

            let word = parts[0];
            if let Some(idf) = parts.get(1).and_then(|x| x.parse::<f64>().ok()) {
                self.idf_dict.insert(word.to_string(), idf);
                idf_heap.push(OrderedFloat(idf));
            }

            buf.clear();
        }

        // Take the median IDF by discarding the larger half of the values and
        // reading the maximum of what remains.
        let m = idf_heap.len() / 2;
        for _ in 0..m {
            idf_heap.pop();
        }

        self.median_idf = idf_heap.pop().unwrap().into_inner();

        Ok(())
    }

    /// Returns a reference to the current [`KeywordExtractConfig`].
    pub fn config(&self) -> &KeywordExtractConfig {
        &self.config
    }

    /// Returns a mutable reference to the current [`KeywordExtractConfig`].
    pub fn config_mut(&mut self) -> &mut KeywordExtractConfig {
        &mut self.config
    }
}

/// TF-IDF keywords extraction.
///
/// Requires the `tfidf` feature to be enabled.
impl Default for TfIdf {
    /// Creates a `TfIdf` with `DEFAULT_STOP_WORDS`, the default IDF
    /// dictionary, a minimum keyword length of 2 Unicode scalar values, and
    /// HMM disabled during segmentation.
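    ///
    /// A minimal construction sketch; the bundled IDF dictionary is parsed
    /// eagerly here, so the instance is worth reusing.
    /// ```
    ///    use jieba_rs::TfIdf;
    ///
    ///    let keyword_extractor = TfIdf::default();
    /// ```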
    fn default() -> Self {
        let mut default_dict = BufReader::new(DEFAULT_IDF.as_bytes());
        TfIdf::new(
            Some(&mut default_dict),
            KeywordExtractConfigBuilder::default().build().unwrap(),
        )
    }
}

impl KeywordExtract for TfIdf {
    /// Uses the TF-IDF algorithm to extract the `top_k` keywords from `sentence`.
    ///
    /// If `allowed_pos` is not empty, only terms whose part-of-speech tag is in
    /// that list are considered.
    ///
    /// # Examples
    /// ```
    ///    use jieba_rs::{Jieba, KeywordExtract, TfIdf};
    ///
    ///    let jieba = Jieba::new();
    ///    let keyword_extractor = TfIdf::default();
    ///    let mut top_k = keyword_extractor.extract_keywords(
    ///        &jieba,
    ///        "今天纽约的天气真好啊,京华大酒店的张尧经理吃了一只北京烤鸭。后天纽约的天气不好,昨天纽约的天气也不好,北京烤鸭真好吃",
    ///        3,
    ///        vec![],
    ///    );
    ///    assert_eq!(
    ///        top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
    ///        vec!["北京烤鸭", "纽约", "天气"]
    ///    );
    ///
    ///    top_k = keyword_extractor.extract_keywords(
    ///        &jieba,
    ///        "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。",
    ///        5,
    ///        vec![],
    ///    );
    ///    assert_eq!(
    ///        top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
    ///        vec!["欧亚", "吉林", "置业", "万元", "增资"]
    ///    );
    ///
    ///    top_k = keyword_extractor.extract_keywords(
    ///        &jieba,
    ///        "此外,公司拟对全资子公司吉林欧亚置业有限公司增资4.3亿元,增资后,吉林欧亚置业注册资本由7000万元增加到5亿元。吉林欧亚置业主要经营范围为房地产开发及百货零售等业务。目前在建吉林欧亚城市商业综合体项目。2013年,实现营业收入0万元,实现净利润-139.13万元。",
    ///        5,
    ///        vec![String::from("ns"), String::from("n"), String::from("vn"), String::from("v")],
    ///    );
    ///    assert_eq!(
    ///        top_k.iter().map(|x| &x.keyword).collect::<Vec<&String>>(),
    ///        vec!["欧亚", "吉林", "置业", "增资", "实现"]
    ///    );
    /// ```
    fn extract_keywords(&self, jieba: &Jieba, sentence: &str, top_k: usize, allowed_pos: Vec<String>) -> Vec<Keyword> {
        let tags = jieba.tag(sentence, self.config.use_hmm());
        let mut allowed_pos_set = BTreeSet::new();

        for s in allowed_pos {
            allowed_pos_set.insert(s);
        }

        // Count term frequencies for segments that pass the part-of-speech and
        // stop-word/length filters.
        let mut term_freq: HashMap<String, u64> = HashMap::default();
        for t in &tags {
            if !allowed_pos_set.is_empty() && !allowed_pos_set.contains(t.tag) {
                continue;
            }

            if !self.config.filter(t.word) {
                continue;
            }

            let entry = term_freq.entry(String::from(t.word)).or_insert(0);
            *entry += 1;
        }

        // Score each term as tf * idf / total, falling back to the median IDF
        // for words missing from the dictionary. The reverse-ordered heap
        // (see `HeapNode`) keeps only the `top_k` highest-scoring terms.
        let total: u64 = term_freq.values().sum();
        let mut heap = BinaryHeap::new();
        for (cnt, (k, tf)) in term_freq.iter().enumerate() {
            let idf = self.idf_dict.get(k).unwrap_or(&self.median_idf);
            let node = HeapNode {
                tfidf: OrderedFloat(*tf as f64 * idf / total as f64),
                word: k,
            };
            heap.push(node);
            if cnt >= top_k {
                heap.pop();
            }
        }

        // Drain the heap (lowest score first) and reverse so keywords come back
        // in descending order of weight.
        let mut res = Vec::with_capacity(top_k);
        for _ in 0..top_k {
            if let Some(w) = heap.pop() {
                res.push(Keyword {
                    keyword: String::from(w.word),
                    weight: w.tfidf.into_inner(),
                });
            }
        }

        res.reverse();
        res
    }
}