use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::vec::Vec;
use crate::freq::FreqMap;
use crate::segmenter::Tokenizer;
use crate::stopwords::StopwordSet;
use crate::token::TokenKind;
#[derive(Debug, Clone, PartialEq)]
pub struct Keyword {
pub word: String,
pub score: f32,
pub count: usize,
}
pub struct KeyExtractor {
tokenizer: Tokenizer,
freq: FreqMap,
stops: StopwordSet,
max_corpus_freq: u32,
}
impl KeyExtractor {
pub fn builtin() -> Self {
let freq = FreqMap::builtin();
let max_corpus_freq = freq.max_freq();
Self {
tokenizer: Tokenizer::new(),
freq,
stops: StopwordSet::builtin(),
max_corpus_freq,
}
}
pub fn extract(&self, text: &str, max_n: usize) -> Vec<Keyword> {
if text.is_empty() || max_n == 0 {
return Vec::new();
}
let tokens = self.tokenizer.segment(text);
let mut total_content: usize = 0;
let mut counts: BTreeMap<String, usize> = BTreeMap::new();
for token in &tokens {
match token.kind {
TokenKind::Whitespace
| TokenKind::Punctuation
| TokenKind::Emoji
| TokenKind::Unknown => continue,
_ => {}
}
total_content += 1;
if token.text.chars().count() < 2 || self.stops.contains(token.text) {
continue;
}
*counts.entry(String::from(token.text)).or_insert(0) += 1;
}
if total_content == 0 || counts.is_empty() {
return Vec::new();
}
let total_f = total_content as f32;
let idf_num = self.max_corpus_freq as f32 + 1.0;
let mut results: Vec<Keyword> = counts
.into_iter()
.map(|(word, count)| {
let tf = count as f32 / total_f;
let corpus_freq = self.freq.get(&word);
let idf = idf_num / (corpus_freq as f32 + 1.0);
Keyword {
word,
score: tf * idf,
count,
}
})
.collect();
results.sort_unstable_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(core::cmp::Ordering::Equal)
.then(a.word.cmp(&b.word))
});
results.truncate(max_n);
results
}
}
#[cfg(test)]
mod tests {
use super::*;
fn kex() -> KeyExtractor {
KeyExtractor::builtin()
}
#[test]
fn empty_text_returns_empty() {
assert!(kex().extract("", 5).is_empty());
}
#[test]
fn zero_max_n_returns_empty() {
assert!(kex().extract("กินข้าวกับปลา", 0).is_empty());
}
#[test]
fn only_stopwords_returns_empty() {
assert!(kex().extract("และหรือของ", 5).is_empty());
}
#[test]
fn only_single_chars_returns_empty() {
assert!(kex().extract("ก ข ค ง", 5).is_empty());
}
#[test]
fn respects_max_n() {
let kws = kex().extract("การพัฒนาซอฟต์แวร์เป็นสิ่งสำคัญในยุคดิจิทัลสำหรับนักพัฒนา", 3);
assert!(kws.len() <= 3, "expected ≤ 3 results, got {}", kws.len());
}
#[test]
fn results_sorted_by_score_descending() {
let kws = kex().extract("การเรียนภาษาโปรแกรมมิ่งเป็นทักษะสำคัญสำหรับนักพัฒนาซอฟต์แวร์", 10);
for pair in kws.windows(2) {
assert!(
pair[0].score >= pair[1].score,
"sort order violated: {:?} before {:?}",
pair[0],
pair[1]
);
}
}
#[test]
fn count_reflects_occurrences() {
let kws = kex().extract("นักพัฒนาซอฟต์แวร์เขียนซอฟต์แวร์และทดสอบซอฟต์แวร์ทุกวัน", 10);
let sw = kws.iter().find(|k| k.word == "ซอฟต์แวร์");
assert!(sw.is_some(), "expected ซอฟต์แวร์ in keywords; got: {kws:?}");
assert_eq!(sw.unwrap().count, 3, "expected count=3 for ซอฟต์แวร์");
}
#[test]
fn stopwords_not_in_results() {
let kws = kex().extract("กินข้าวกับปลาและดื่มน้ำ", 20);
assert!(
kws.iter().all(|k| k.word != "กับ" && k.word != "และ"),
"stopword found in results: {kws:?}"
);
}
#[test]
fn all_scores_positive() {
let kws = kex().extract("การพัฒนาซอฟต์แวร์ต้องการทักษะและประสบการณ์", 10);
assert!(
kws.iter().all(|k| k.score > 0.0),
"expected all scores > 0; got: {kws:?}"
);
}
#[test]
fn rare_word_outranks_common_word_with_same_count() {
let kws = kex().extract("ไดโนเสาร์กินคน", 10);
let rare = kws.iter().find(|k| k.word == "ไดโนเสาร์");
let common = kws.iter().find(|k| k.word == "คน");
if let (Some(r), Some(c)) = (rare, common) {
assert!(
r.score > c.score,
"expected ไดโนเสาร์ ({}) to outscore คน ({})",
r.score,
c.score
);
}
}
#[test]
fn repeated_word_scores_higher_than_single_occurrence() {
let kws = kex().extract("นักพัฒนาซอฟต์แวร์เขียนซอฟต์แวร์และทดสอบซอฟต์แวร์", 10);
let sw = kws.iter().find(|k| k.word == "ซอฟต์แวร์");
let dev = kws.iter().find(|k| k.word == "นักพัฒนา");
if let (Some(s), Some(d)) = (sw, dev) {
assert!(
s.score > d.score,
"expected ซอฟต์แวร์ (×3, score {}) > นักพัฒนา (×1, score {})",
s.score,
d.score
);
}
}
#[test]
fn latin_tokens_included_as_candidates() {
let kws = kex().extract("เขียน Python และใช้ Python ทุกวัน", 10);
let py = kws.iter().find(|k| k.word == "Python");
assert!(py.is_some(), "expected Python in keywords; got: {kws:?}");
assert_eq!(py.unwrap().count, 2);
}
#[test]
fn punctuation_not_in_results() {
let kws = kex().extract("กินข้าว, ดื่มน้ำ. นอนหลับ!", 20);
assert!(
kws.iter()
.all(|k| !k.word.chars().all(|c| c.is_ascii_punctuation())),
"punctuation token found in results: {kws:?}"
);
}
}