Skip to main content

megahal_keywords/
lib.rs

1//! MegaHAL keyword extraction: two-pass algorithm with swap table,
2//! banned/auxiliary word lists.
3//!
4//! Keywords drive MegaHAL's reply generation by biasing the Markov walk toward
5//! topically relevant symbols. Extraction works in two passes:
6//!
7//! 1. **Primary**: select words from input (after swap substitution) that are
8//!    in the model dictionary, start with an alphanumeric character, and are
9//!    neither banned nor auxiliary.
10//! 2. **Auxiliary**: if at least one primary keyword was found, also add words
11//!    from the auxiliary list (pronouns, possessives) that appear in input.
12
13use std::collections::HashSet;
14
15use symbol_core::Symbol;
16use symbol_dict::SymbolDict;
17
/// Substitution table that flips conversational perspective.
///
/// During keyword extraction each input token is compared against the `from`
/// half of every pair. Each pair whose `from` matches contributes its `to`
/// half as a keyword candidate in place of the token, so a single token can
/// expand into several candidates (e.g., "YOU" → ["I", "ME"]).
#[derive(Clone, Debug, Default)]
pub struct SwapTable {
    /// `(from, to)` pairs, walked front to back for every token.
    pub pairs: Vec<(String, String)>,
}

impl SwapTable {
    /// Expand `token` through the swap table.
    ///
    /// Yields the `to` side of every pair whose `from` side equals `token`
    /// (ASCII case-insensitively), in table order. When nothing matches, the
    /// result is a single-element vector holding `token` unchanged.
    pub fn apply(&self, token: &str) -> Vec<String> {
        let swapped: Vec<String> = self
            .pairs
            .iter()
            .filter_map(|(from, to)| from.eq_ignore_ascii_case(token).then(|| to.clone()))
            .collect();

        if swapped.is_empty() {
            // No substitution applies: the token stands for itself.
            vec![token.to_owned()]
        } else {
            swapped
        }
    }
}
46
47/// Configuration for keyword extraction.
48#[derive(Debug, Clone, Default)]
49pub struct KeywordConfig {
50    /// Words that are never used as keywords (common function words).
51    pub banned: HashSet<String>,
52    /// Words used as keywords only to supplement existing primary keywords.
53    pub auxiliary: HashSet<String>,
54    /// Perspective-swapping substitutions.
55    pub swap: SwapTable,
56}
57
58/// Extract keywords from tokenized input per the MegaHAL two-pass algorithm.
59///
60/// `S` must implement `AsRef<[u8]>` so we can check if the first character is
61/// alphanumeric (a requirement of the extraction rules).
62///
63/// `make_symbol` constructs a `Symbol` from a string.
64///
65/// Returns keywords in first-occurrence input order, deduplicated.  The C
66/// reference builds the keyword dictionary with `add_word` in input order
67/// (`make_keywords`, megahal.c:2273-2342), and `seed()` scans that same
68/// dictionary from a random index.  Preserving input order here ensures the
69/// seed scan visits keywords in the same distribution as C.
70pub fn extract_keywords<S: Symbol + AsRef<[u8]>>(
71    tokens: &[S],
72    dict: &SymbolDict<S>,
73    config: &KeywordConfig,
74    make_symbol: impl Fn(&str) -> S,
75) -> Vec<String> {
76    let mut keywords: Vec<String> = Vec::new();
77    let mut seen: HashSet<String> = HashSet::new();
78
79    // Collect all swap-applied candidates for the two-pass algorithm.
80    let candidates: Vec<Vec<String>> = tokens
81        .iter()
82        .map(|tok| {
83            let tok_str = std::str::from_utf8(tok.as_ref()).unwrap_or("");
84            config.swap.apply(tok_str)
85        })
86        .collect();
87
88    // Pass 1: Primary keywords.
89    for candidate_group in &candidates {
90        for candidate in candidate_group {
91            if !is_keyword_eligible(candidate, dict, config, false, &make_symbol) {
92                continue;
93            }
94            if seen.insert(candidate.clone()) {
95                keywords.push(candidate.clone());
96            }
97        }
98    }
99
100    // Pass 2: Auxiliary keywords (only if primary pass found at least one).
101    if !keywords.is_empty() {
102        for candidate_group in &candidates {
103            for candidate in candidate_group {
104                if !is_keyword_eligible(candidate, dict, config, true, &make_symbol) {
105                    continue;
106                }
107                if seen.insert(candidate.clone()) {
108                    keywords.push(candidate.clone());
109                }
110            }
111        }
112    }
113
114    keywords
115}
116
117/// Check if a candidate word is eligible as a keyword.
118///
119/// In primary mode (`aux_pass = false`): must be in dict, start alphanumeric,
120/// not banned, not auxiliary.
121/// In auxiliary mode (`aux_pass = true`): must be in dict, start alphanumeric,
122/// and IS in auxiliary list.
123fn is_keyword_eligible<S: Symbol + AsRef<[u8]>>(
124    candidate: &str,
125    dict: &SymbolDict<S>,
126    config: &KeywordConfig,
127    aux_pass: bool,
128    make_symbol: &impl Fn(&str) -> S,
129) -> bool {
130    // Must start with an alphanumeric character.
131    let first_byte = candidate.as_bytes().first().copied();
132    if !first_byte.is_some_and(|b| b.is_ascii_alphanumeric()) {
133        return false;
134    }
135
136    // Must exist in the model dictionary (the model has seen this word).
137    let sym = make_symbol(candidate);
138    if dict.find(&sym).is_none() {
139        return false;
140    }
141
142    let upper = candidate.to_uppercase();
143
144    if aux_pass {
145        // Auxiliary pass: only add words that ARE in the auxiliary list.
146        config.auxiliary.contains(&upper)
147    } else {
148        // Primary pass: skip banned and auxiliary words.
149        if config.banned.contains(&upper) {
150            return false;
151        }
152        if config.auxiliary.contains(&upper) {
153            return false;
154        }
155        true
156    }
157}
158
159/// Check if a given string exists in the model dictionary.
160///
161/// This requires constructing a temporary Symbol, which is the responsibility
162/// of the caller (the facade crate knows how to create `MegaHalSymbol` from strings).
163pub fn word_in_dict<S: Symbol>(dict: &SymbolDict<S>, symbol: &S) -> bool {
164    dict.find(symbol)
165        .is_some_and(|id| id != symbol_core::ERROR_ID)
166}
167
#[cfg(test)]
mod tests {
    use super::*;

    // --- Test infrastructure ---

    /// Minimal string-backed symbol used to drive the generic code under test.
    #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct TestSym(String);

    impl Symbol for TestSym {
        fn error() -> Self {
            TestSym(String::from("<ERROR>"))
        }
        fn fin() -> Self {
            TestSym(String::from("<FIN>"))
        }
    }

    impl AsRef<[u8]> for TestSym {
        fn as_ref(&self) -> &[u8] {
            self.0.as_bytes()
        }
    }

    /// Build an upper-cased test symbol from `s`.
    fn sym(s: &str) -> TestSym {
        TestSym(s.to_uppercase())
    }

    /// Build a dictionary that has interned each of `words`.
    fn dict_with(words: &[&str]) -> SymbolDict<TestSym> {
        let mut dict = SymbolDict::new();
        for &word in words {
            dict.intern(sym(word));
        }
        dict
    }

    /// Build a swap table from borrowed `(from, to)` pairs.
    fn swap_of(pairs: &[(&str, &str)]) -> SwapTable {
        SwapTable {
            pairs: pairs
                .iter()
                .map(|&(from, to)| (from.to_owned(), to.to_owned()))
                .collect(),
        }
    }

    /// Build an owned word set from borrowed words.
    fn word_set(words: &[&str]) -> HashSet<String> {
        words.iter().map(|w| (*w).to_owned()).collect()
    }

    /// True when `word` is one of the extracted keywords.
    fn has(kws: &[String], word: &str) -> bool {
        kws.iter().any(|kw| kw.as_str() == word)
    }

    // --- SwapTable tests ---

    #[test]
    fn swap_table_basic() {
        let swap = swap_of(&[("I", "YOU"), ("YOU", "I"), ("YOU", "ME")]);

        assert_eq!(swap.apply("I"), ["YOU"]);

        // Every matching pair contributes, in table order.
        let you_swaps = swap.apply("YOU");
        assert_eq!(you_swaps, ["I", "ME"]);

        // No match → the token comes back unchanged.
        assert_eq!(swap.apply("HELLO"), ["HELLO"]);
    }

    #[test]
    fn swap_case_insensitive() {
        let swap = swap_of(&[("MY", "YOUR")]);
        assert_eq!(swap.apply("my"), ["YOUR"]);
        assert_eq!(swap.apply("My"), ["YOUR"]);
    }

    #[test]
    fn swap_empty_table() {
        let swap = SwapTable::default();
        assert_eq!(swap.apply("HELLO"), ["HELLO"]);
    }

    // --- KeywordConfig tests ---

    #[test]
    fn keyword_config_default() {
        let config = KeywordConfig::default();
        assert!(config.banned.is_empty());
        assert!(config.auxiliary.is_empty());
        assert!(config.swap.pairs.is_empty());
    }

    // --- extract_keywords tests ---

    #[test]
    fn extract_skips_words_not_in_dict() {
        let dict = dict_with(&["HELLO", "WORLD"]);
        let config = KeywordConfig::default();
        let tokens = vec![sym("HELLO"), sym(" "), sym("UNKNOWN")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        assert!(has(&kws, "HELLO"));
        assert!(!has(&kws, "UNKNOWN"));
    }

    #[test]
    fn extract_skips_non_alphanumeric_start() {
        let dict = dict_with(&["HELLO", " ", "."]);
        let config = KeywordConfig::default();
        let tokens = vec![sym("HELLO"), sym(" "), sym(".")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        assert!(has(&kws, "HELLO"));
        assert!(!has(&kws, " "));
        assert!(!has(&kws, "."));
    }

    #[test]
    fn extract_skips_banned() {
        let dict = dict_with(&["THE", "CAT"]);
        let config = KeywordConfig {
            banned: word_set(&["THE"]),
            ..Default::default()
        };
        let tokens = vec![sym("THE"), sym("CAT")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        assert!(!has(&kws, "THE"));
        assert!(has(&kws, "CAT"));
    }

    #[test]
    fn extract_aux_added_when_primary_exists() {
        let dict = dict_with(&["MY", "CAT"]);
        let config = KeywordConfig {
            auxiliary: word_set(&["MY"]),
            ..Default::default()
        };
        let tokens = vec![sym("MY"), sym("CAT")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        // The primary pass picks up CAT (neither banned nor auxiliary); the
        // auxiliary pass then adds MY because a primary keyword exists.
        assert!(has(&kws, "CAT"));
        assert!(has(&kws, "MY"));
    }

    #[test]
    fn extract_no_aux_without_primary() {
        let dict = dict_with(&["MY"]);
        let config = KeywordConfig {
            auxiliary: word_set(&["MY"]),
            ..Default::default()
        };
        let tokens = vec![sym("MY")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        // Without any primary keyword the auxiliary pass never runs.
        assert!(kws.is_empty());
    }

    #[test]
    fn extract_with_swap_substitution() {
        let dict = dict_with(&["YOU", "CAT"]);
        let config = KeywordConfig {
            swap: swap_of(&[("I", "YOU")]),
            ..Default::default()
        };
        // "I" is swapped to "YOU", which the dictionary knows; "CAT" passes
        // through untouched.
        let tokens = vec![sym("I"), sym(" "), sym("CAT")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        assert!(has(&kws, "YOU"));
        assert!(has(&kws, "CAT"));
        assert!(!has(&kws, "I"));
    }

    #[test]
    fn extract_swap_target_must_be_in_dict() {
        // "YOU" is deliberately absent from the dictionary.
        let dict = dict_with(&["CAT"]);
        let config = KeywordConfig {
            swap: swap_of(&[("I", "YOU")]),
            ..Default::default()
        };
        let tokens = vec![sym("I"), sym(" "), sym("CAT")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        // "I" is swapped to "YOU", but an unknown swap target is skipped.
        assert!(!has(&kws, "YOU"));
        assert!(has(&kws, "CAT"));
    }

    #[test]
    fn extract_empty_input() {
        let dict = dict_with(&["HELLO"]);
        let config = KeywordConfig::default();
        let tokens: Vec<TestSym> = Vec::new();

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        assert!(kws.is_empty());
    }

    #[test]
    fn extract_all_banned_yields_empty() {
        let dict = dict_with(&["THE", "A", "IS"]);
        let config = KeywordConfig {
            banned: word_set(&["THE", "A", "IS"]),
            ..Default::default()
        };
        let tokens = vec![sym("THE"), sym("A"), sym("IS")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        assert!(kws.is_empty());
    }

    // Input order is preserved; duplicates are dropped (first occurrence wins).
    #[test]
    fn extract_preserves_input_order_not_sorted() {
        // ZEBRA sorts after APPLE alphabetically but appears first in input.
        let dict = dict_with(&["ZEBRA", "APPLE", "MANGO"]);
        let config = KeywordConfig::default();
        let tokens = vec![sym("ZEBRA"), sym("APPLE"), sym("MANGO")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        // The output must follow input order rather than being sorted.
        assert_eq!(kws, ["ZEBRA", "APPLE", "MANGO"]);
        // Sorted order would be APPLE, MANGO, ZEBRA — confirm we differ.
        let mut sorted = kws.clone();
        sorted.sort();
        assert_ne!(kws, sorted);
    }

    #[test]
    fn extract_deduplicates_keeping_first_occurrence() {
        let dict = dict_with(&["CAT", "DOG"]);
        let config = KeywordConfig::default();
        // CAT occurs twice; only its first occurrence may survive.
        let tokens = vec![sym("CAT"), sym("DOG"), sym("CAT")];

        let kws = extract_keywords(&tokens, &dict, &config, sym);

        assert_eq!(kws, ["CAT", "DOG"]);
    }

    // --- word_in_dict tests ---

    #[test]
    fn word_in_dict_found() {
        let dict = dict_with(&["HELLO"]);
        assert!(word_in_dict(&dict, &sym("HELLO")));
    }

    #[test]
    fn word_in_dict_missing() {
        let dict = dict_with(&["HELLO"]);
        assert!(!word_in_dict(&dict, &sym("NOPE")));
    }

    #[test]
    fn word_in_dict_rejects_error_sentinel() {
        let dict: SymbolDict<TestSym> = SymbolDict::new();
        // The ERROR sentinel sits at ID 0, yet word_in_dict must reject it.
        assert!(!word_in_dict(&dict, &TestSym::error()));
    }
}