Skip to main content

megahal_gen/
lib.rs

//! MegaHAL reply generation: babble, seeding, bidirectional generation,
//! and surprise evaluation.
//!
//! This crate implements the core MegaHAL reply generation algorithm:
//!
//! 1. Seed a starting symbol from the keyword list, falling back to a random
//!    symbol when no keyword is usable.
//! 2. Run a forward phase from the seed to generate the rest of the sentence.
//! 3. Run a backward phase from the seed to generate the beginning.
//! 4. Score each candidate by surprise.
//! 5. Select the highest-scoring candidate that differs from the input.
//!
//! It implements MegaHAL's keyword-biased random walk: keywords encountered
//! during the walk are selected greedily, while non-keywords fall back to
//! probability-weighted selection.
15
16use std::collections::HashSet;
17use std::time::{Duration, Instant};
18
19use megahal_markov::{BidirectionalModel, ContextWindow};
20use ngram_trie::Trie;
21use rand::{Rng, RngExt};
22use symbol_core::{ERROR_ID, FIN_ID, Symbol, SymbolId};
23use symbol_dict::SymbolDict;
24
/// Budget for the candidate-generation loop in `generate_reply`.
///
/// The loop always evaluates at least one keyword-seeded candidate and then
/// stops as soon as the configured budget is exhausted.
#[derive(Debug, Clone)]
pub enum GenerationLimit {
    /// Keep generating until this much wall-clock time has elapsed.
    Timeout(Duration),
    /// Keep generating until this many candidates have been evaluated.
    Iterations(usize),
    /// Stop at whichever of the two budgets runs out first.
    Both {
        timeout: Duration,
        max_iterations: usize,
    },
}

impl Default for GenerationLimit {
    /// One second of generation, the budget used by the original MegaHAL.
    fn default() -> Self {
        Self::Timeout(Duration::from_secs(1))
    }
}
45
46/// Generate the best reply for given input tokens and keywords.
47///
48/// Runs the candidate generation loop from MEGAHAL_SPEC.md Section 7.1:
49/// 1. Generate a baseline reply with empty keywords.
50/// 2. Repeatedly generate candidates with keywords, scoring each by surprise.
51/// 3. Return the highest-scoring candidate that differs from the input.
52///
53/// `keywords` is the ordered keyword list from `extract_keywords`; input order
54/// is preserved so that `seed` scans keywords in the same order as C's
55/// `make_keywords` dictionary (`megahal.c:2273-2342`).  A `HashSet` is built
56/// internally for the O(1) membership checks in `babble` and `evaluate_reply`.
57///
58/// C's loop is a `do/while` (`megahal.c:2228-2240`): it always generates and
59/// evaluates at least one keyword-seeded candidate before checking the limit.
60pub fn generate_reply<S, R>(
61    model: &BidirectionalModel<S>,
62    input_tokens: &[S],
63    keywords: &[S],
64    aux_set: &HashSet<S>,
65    limit: &GenerationLimit,
66    rng: &mut R,
67) -> Vec<S>
68where
69    S: Symbol + AsRef<[u8]>,
70    R: Rng,
71{
72    // Build a HashSet from the ordered keywords for O(1) membership checks.
73    let keywords_set: HashSet<S> = keywords.iter().cloned().collect();
74
75    let empty_keywords: &[S] = &[];
76    let empty_keywords_set: HashSet<S> = HashSet::new();
77    let empty_aux = HashSet::new();
78
79    // Baseline reply (no keyword bias). Per C's `dissimilar()` check
80    // (megahal.c:2215-2218), drop the baseline when it equals the input.
81    let mut best = generate_one_reply(model, empty_keywords, &empty_keywords_set, &empty_aux, rng);
82    if tokens_equal(&best, input_tokens) {
83        best = Vec::new();
84    }
85
86    let mut max_surprise: f64 = -1.0;
87    let start = Instant::now();
88    let mut iterations: usize = 0;
89
90    // C's loop is do/while: generate first, check limit after.
91    loop {
92        let candidate = generate_one_reply(model, keywords, &keywords_set, aux_set, rng);
93        let surprise = evaluate_reply(model, &candidate, &keywords_set);
94
95        if surprise > max_surprise && !tokens_equal(&candidate, input_tokens) {
96            max_surprise = surprise;
97            best = candidate;
98        }
99
100        iterations += 1;
101
102        // Check limits after generating (matching C's do/while).
103        match limit {
104            GenerationLimit::Timeout(d) => {
105                if start.elapsed() >= *d {
106                    break;
107                }
108            }
109            GenerationLimit::Iterations(n) => {
110                if iterations >= *n {
111                    break;
112                }
113            }
114            GenerationLimit::Both {
115                timeout,
116                max_iterations,
117            } => {
118                if start.elapsed() >= *timeout || iterations >= *max_iterations {
119                    break;
120                }
121            }
122        }
123    }
124
125    best
126}
127
128/// Generate a single candidate reply (forward + backward phases).
129///
130/// MEGAHAL_SPEC.md Section 7.2.
131fn generate_one_reply<S, R>(
132    model: &BidirectionalModel<S>,
133    keywords: &[S],
134    keywords_set: &HashSet<S>,
135    aux_set: &HashSet<S>,
136    rng: &mut R,
137) -> Vec<S>
138where
139    S: Symbol + AsRef<[u8]>,
140    R: Rng,
141{
142    let mut reply: Vec<SymbolId> = Vec::new();
143    let mut used_key = false;
144
145    // Forward phase. Per C `reply()` (megahal.c:2420-2471), the backward phase
146    // always runs even when forward produces nothing.
147    let mut ctx = model.forward_context();
148
149    let seed_id = seed(model, keywords, aux_set, rng);
150    if seed_id != ERROR_ID && seed_id != FIN_ID {
151        reply.push(seed_id);
152        ctx.advance(&model.forward, seed_id);
153
154        loop {
155            let sym = babble(
156                &model.forward,
157                &ctx,
158                &model.dictionary,
159                keywords_set,
160                aux_set,
161                &reply,
162                &mut used_key,
163                rng,
164            );
165            if sym == ERROR_ID || sym == FIN_ID {
166                break;
167            }
168            reply.push(sym);
169            ctx.advance(&model.forward, sym);
170        }
171    }
172
173    // Backward phase.
174    let mut ctx = model.backward_context();
175
176    // Re-establish backward context from the beginning of the reply.
177    // Spec 7.2.3: walk from index min(reply_length-1, order) down to 0.
178    // This matches the C code: for(i=MIN(size-1,order); i>=0; i--)
179    if !reply.is_empty() {
180        let start = (reply.len() - 1).min(model.order as usize);
181        for i in (0..=start).rev() {
182            ctx.advance(&model.backward, reply[i]);
183        }
184    }
185
186    loop {
187        let sym = babble(
188            &model.backward,
189            &ctx,
190            &model.dictionary,
191            keywords_set,
192            aux_set,
193            &reply,
194            &mut used_key,
195            rng,
196        );
197        if sym == ERROR_ID || sym == FIN_ID {
198            break;
199        }
200        reply.insert(0, sym);
201        ctx.advance(&model.backward, sym);
202    }
203
204    resolve_ids(model, &reply)
205}
206
207/// Select a seed symbol for forward generation.
208///
209/// MEGAHAL_SPEC.md Section 7.2.1.
210///
211/// `keywords` is the ordered slice from `extract_keywords`, preserving the
212/// input order C's `make_keywords` builds (`megahal.c:2273-2342`).  The scan
213/// starts at a random index and wraps around, matching C's `seed()`
214/// (`megahal.c:2694-2706`).  The `.sort()` that was here before introduced a
215/// different distribution because sorted order differs from input order.
216fn seed<S, R>(
217    model: &BidirectionalModel<S>,
218    keywords: &[S],
219    aux_set: &HashSet<S>,
220    rng: &mut R,
221) -> SymbolId
222where
223    S: Symbol + AsRef<[u8]>,
224    R: Rng,
225{
226    let root = model.forward.root();
227    let children = model.forward.children(root);
228
229    // Keywords are scanned first: a keyword that is in the dictionary and not
230    // auxiliary seeds the reply even when the forward root has no children,
231    // matching C seed() (megahal.c:2697-2706).
232    if !keywords.is_empty() {
233        let start = rng.random_range(0..keywords.len());
234
235        for offset in 0..keywords.len() {
236            let idx = (start + offset) % keywords.len();
237            let kw = &keywords[idx];
238
239            if let Some(id) = model.dictionary.find(kw)
240                && !aux_set.contains(kw)
241            {
242                return id;
243            }
244        }
245    }
246
247    // Default: a random child of the forward root, or ERROR if it has none.
248    if children.is_empty() {
249        return ERROR_ID;
250    }
251    let idx = rng.random_range(0..children.len());
252    model.forward.node(children[idx]).symbol
253}
254
255/// Keyword-biased random symbol selection (the "babble" function).
256///
257/// MEGAHAL_SPEC.md Section 7.3.
258#[allow(clippy::too_many_arguments)]
259fn babble<S, R>(
260    trie: &Trie,
261    ctx: &ContextWindow,
262    dict: &SymbolDict<S>,
263    keywords: &HashSet<S>,
264    aux_set: &HashSet<S>,
265    reply: &[SymbolId],
266    used_key: &mut bool,
267    rng: &mut R,
268) -> SymbolId
269where
270    S: Symbol + AsRef<[u8]>,
271    R: Rng,
272{
273    // Find deepest available context.
274    let node_ref = match ctx.deepest() {
275        Some(r) => r,
276        None => return ERROR_ID,
277    };
278
279    let node = trie.node(node_ref);
280    let children = trie.children(node_ref);
281
282    if children.is_empty() {
283        return ERROR_ID;
284    }
285
286    let branch = children.len();
287    // C `babble()` calls `rnd(node->usage)` which returns 0 when usage is 0
288    // and the loop falls through; `rng.random_range(0..0)` panics, so guard
289    // explicitly and treat usage==0 as sentence-terminating.
290    if node.usage == 0 {
291        return ERROR_ID;
292    }
293    let mut i = rng.random_range(0..branch);
294    let mut count = rng.random_range(0..node.usage as i64);
295
296    loop {
297        let child_ref = children[i];
298        let child = trie.node(child_ref);
299        let sym = child.symbol;
300
301        // Check if this symbol is a keyword we should greedily select.
302        let word = dict.resolve(sym);
303        let is_keyword = keywords.contains(word);
304        let is_aux = aux_set.contains(word);
305        let already_in_reply = reply.contains(&sym);
306
307        if is_keyword && (*used_key || !is_aux) && !already_in_reply {
308            *used_key = true;
309            return sym;
310        }
311
312        // Otherwise, probability-weighted selection.
313        count -= child.count as i64;
314        if count < 0 {
315            return sym;
316        }
317
318        i = (i + 1) % branch;
319    }
320}
321
322/// Score a candidate reply by surprise (Shannon entropy of keywords in context).
323///
324/// MEGAHAL_SPEC.md Section 8.
325fn evaluate_reply<S>(model: &BidirectionalModel<S>, candidate: &[S], keywords: &HashSet<S>) -> f64
326where
327    S: Symbol + AsRef<[u8]>,
328{
329    if candidate.is_empty() {
330        return 0.0;
331    }
332
333    let mut entropy: f64 = 0.0;
334    let mut num: usize = 0;
335
336    // Forward evaluation.
337    let mut ctx = model.forward_context();
338    for token in candidate {
339        let sym_id = match model.dictionary.find(token) {
340            Some(id) => id,
341            None => continue,
342        };
343
344        if keywords.contains(token) {
345            let mut prob: f64 = 0.0;
346            let mut ctx_count: usize = 0;
347
348            for j in 0..model.order as usize {
349                if let Some(parent_ref) = ctx.at_depth(j)
350                    && let Some(child_ref) = model.forward.find_child(parent_ref, sym_id)
351                {
352                    let child = model.forward.node(child_ref);
353                    let parent = model.forward.node(parent_ref);
354                    if parent.usage > 0 {
355                        prob += child.count as f64 / parent.usage as f64;
356                        ctx_count += 1;
357                    }
358                }
359            }
360
361            if ctx_count > 0 {
362                entropy -= (prob / ctx_count as f64).ln();
363            }
364            num += 1;
365        }
366
367        ctx.advance(&model.forward, sym_id);
368    }
369
370    // Backward evaluation.
371    let mut ctx = model.backward_context();
372    for token in candidate.iter().rev() {
373        let sym_id = match model.dictionary.find(token) {
374            Some(id) => id,
375            None => continue,
376        };
377
378        if keywords.contains(token) {
379            let mut prob: f64 = 0.0;
380            let mut ctx_count: usize = 0;
381
382            for j in 0..model.order as usize {
383                if let Some(parent_ref) = ctx.at_depth(j)
384                    && let Some(child_ref) = model.backward.find_child(parent_ref, sym_id)
385                {
386                    let child = model.backward.node(child_ref);
387                    let parent = model.backward.node(parent_ref);
388                    if parent.usage > 0 {
389                        prob += child.count as f64 / parent.usage as f64;
390                        ctx_count += 1;
391                    }
392                }
393            }
394
395            if ctx_count > 0 {
396                entropy -= (prob / ctx_count as f64).ln();
397            }
398            num += 1;
399        }
400
401        ctx.advance(&model.backward, sym_id);
402    }
403
404    // Length penalty.
405    if num >= 8 {
406        entropy /= ((num - 1) as f64).sqrt();
407    }
408    if num >= 16 {
409        entropy /= num as f64;
410    }
411
412    entropy
413}
414
/// Join tokens and apply MegaHAL sentence casing.
///
/// MEGAHAL_SPEC.md Section 9.1; mirrors C `capitalize()` in megahal.c. The
/// first ASCII letter of a sentence is upper-cased and every other ASCII
/// letter is lower-cased; non-ASCII characters pass through untouched.
///
/// A new sentence starts only at ASCII whitespace that directly follows `!`,
/// `.` or `?` and whose byte offset is greater than 2, exactly as in C. The
/// punctuation itself does not start a sentence, so `"a.b"` stays glued.
pub fn capitalize(tokens: &[String]) -> String {
    let joined = tokens.concat();
    let mut out = String::with_capacity(joined.len());
    let mut sentence_start = true;
    let mut previous: Option<char> = None;

    for (offset, ch) in joined.char_indices() {
        if ch.is_ascii_alphabetic() {
            let cased = if sentence_start {
                ch.to_ascii_uppercase()
            } else {
                ch.to_ascii_lowercase()
            };
            out.push(cased);
            sentence_start = false;
        } else {
            out.push(ch);
        }

        let follows_terminator = matches!(previous, Some('!' | '.' | '?'));
        if offset > 2 && ch.is_ascii_whitespace() && follows_terminator {
            sentence_start = true;
        }
        previous = Some(ch);
    }

    out
}
444
445/// Check if two token sequences are equal (case-insensitive, for dissimilarity test).
446fn tokens_equal<S: Symbol>(a: &[S], b: &[S]) -> bool {
447    if a.len() != b.len() {
448        return false;
449    }
450    a.iter().zip(b.iter()).all(|(x, y)| x == y)
451}
452
453/// Resolve a sequence of SymbolIds back to Symbol values.
454fn resolve_ids<S: Symbol>(model: &BidirectionalModel<S>, ids: &[SymbolId]) -> Vec<S> {
455    ids.iter()
456        .map(|&id| model.dictionary.resolve(id).clone())
457        .collect()
458}
459
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    // --- Test infrastructure ---

    /// Minimal string-backed symbol for exercising the generator.
    #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
    struct TSym(String);

    impl Symbol for TSym {
        fn error() -> Self {
            TSym("<ERROR>".into())
        }
        fn fin() -> Self {
            TSym("<FIN>".into())
        }
    }

    impl AsRef<[u8]> for TSym {
        fn as_ref(&self) -> &[u8] {
            self.0.as_bytes()
        }
    }

    fn ts(s: &str) -> TSym {
        TSym(s.to_string())
    }

    /// Builds a model of the given order, learning each sentence in turn.
    fn trained_model(order: u8, sentences: &[&[&str]]) -> BidirectionalModel<TSym> {
        let mut model = BidirectionalModel::new(order);
        for sentence in sentences {
            let tokens: Vec<TSym> = sentence.iter().map(|&s| ts(s)).collect();
            model.learn(&tokens);
        }
        model
    }

    fn make_rng(s: u64) -> SmallRng {
        SmallRng::seed_from_u64(s)
    }

    // --- GenerationLimit tests ---

    #[test]
    fn generation_limit_default_is_timeout() {
        let limit = GenerationLimit::default();
        assert!(matches!(limit, GenerationLimit::Timeout(_)));
    }

    // --- capitalize tests ---

    #[test]
    fn capitalize_basic() {
        let tokens = vec![
            "hello".to_string(),
            " ".to_string(),
            "world".to_string(),
            ".".to_string(),
        ];
        assert_eq!(capitalize(&tokens), "Hello world.");
    }

    #[test]
    fn capitalize_after_period() {
        let tokens = vec![
            "hello".to_string(),
            ". ".to_string(),
            "world".to_string(),
            ".".to_string(),
        ];
        assert_eq!(capitalize(&tokens), "Hello. World.");
    }

    #[test]
    fn capitalize_empty() {
        let tokens: Vec<String> = vec![];
        assert_eq!(capitalize(&tokens), "");
    }

    #[test]
    fn capitalize_after_exclamation() {
        let tokens = vec![
            "wow".to_string(),
            "! ".to_string(),
            "amazing".to_string(),
            ".".to_string(),
        ];
        assert_eq!(capitalize(&tokens), "Wow! Amazing.");
    }

    #[test]
    fn capitalize_after_question() {
        let tokens = vec![
            "really".to_string(),
            "? ".to_string(),
            "yes".to_string(),
            ".".to_string(),
        ];
        assert_eq!(capitalize(&tokens), "Really? Yes.");
    }

    #[test]
    fn capitalize_no_space_after_period_does_not_capitalize() {
        let tokens = vec!["a.b.c".to_string()];
        assert_eq!(capitalize(&tokens), "A.b.c");
    }

    #[test]
    fn capitalize_glued_sentences_not_split() {
        let tokens = vec!["hello.world".to_string()];
        assert_eq!(capitalize(&tokens), "Hello.world");
    }

    #[test]
    fn capitalize_leading_ellipsis() {
        // The first letter after a leading run of dots is still upper-cased
        // because the start flag was never cleared.
        let tokens = vec!["...hello".to_string()];
        assert_eq!(capitalize(&tokens), "...Hello");
    }

    // --- seed tests ---

    #[test]
    fn seed_selects_keyword() {
        let model = trained_model(
            2,
            &[
                &["THE", " ", "CAT", " ", "SAT"],
                &["THE", " ", "DOG", " ", "RAN"],
            ],
        );
        let kws = vec![ts("CAT")];
        let aux = HashSet::new();
        let mut rng = make_rng(42);
        let id = seed(&model, &kws, &aux, &mut rng);
        let cat_id = model.dictionary.find(&ts("CAT")).unwrap();
        assert_eq!(id, cat_id);
    }

    #[test]
    fn seed_skips_aux_keyword() {
        // MY is interned into the dictionary but never learned into the trie,
        // so the only way seed could return it is via the keyword scan. That
        // scan must skip it because it is auxiliary.
        let mut model = trained_model(2, &[&["THE", " ", "CAT"]]);
        let my_id = model.dictionary.intern(ts("MY"));
        let kws = vec![ts("MY")];

        let mut aux = HashSet::new();
        aux.insert(ts("MY"));
        let mut rng = make_rng(42);
        assert_ne!(seed(&model, &kws, &aux, &mut rng), my_id);

        // Control: without the auxiliary marking the same keyword is chosen.
        let no_aux: HashSet<TSym> = HashSet::new();
        let mut rng = make_rng(42);
        assert_eq!(seed(&model, &kws, &no_aux, &mut rng), my_id);
    }

    #[test]
    fn seed_with_empty_keywords_picks_random() {
        let model = trained_model(2, &[&["THE", " ", "CAT"]]);
        let kws: Vec<TSym> = vec![];
        let aux = HashSet::new();
        let mut rng = make_rng(42);
        let id = seed(&model, &kws, &aux, &mut rng);
        assert_ne!(id, ERROR_ID);
        assert_ne!(id, FIN_ID);
    }

    #[test]
    fn seed_returns_error_on_empty_model() {
        let model: BidirectionalModel<TSym> = BidirectionalModel::new(2);
        let kws: Vec<TSym> = vec![];
        let aux = HashSet::new();
        let mut rng = make_rng(42);
        let id = seed(&model, &kws, &aux, &mut rng);
        assert_eq!(id, ERROR_ID);
    }

    // seed visits keywords in input order, not sorted order.
    #[test]
    fn seed_visits_keywords_in_input_order() {
        // Both ZEBRA and APPLE are known, non-auxiliary keywords, so whichever
        // one the random start index points at is returned. With input order
        // [ZEBRA, APPLE], start 0 must yield ZEBRA; a sorted scan would
        // yield APPLE instead.
        let model = trained_model(2, &[&["ZEBRA", " ", "SAT"], &["APPLE", " ", "RAN"]]);
        let kws = vec![ts("ZEBRA"), ts("APPLE")];
        let aux: HashSet<TSym> = HashSet::new();

        let zebra_id = model.dictionary.find(&ts("ZEBRA")).unwrap();
        let apple_id = model.dictionary.find(&ts("APPLE")).unwrap();

        // For each RNG seed, replay seed()'s first draw on an identical RNG
        // to learn the start index, then check the result matches it.
        let mut found_start_zero = false;
        for seed_val in 0u64..200 {
            let mut rng = make_rng(seed_val);
            let mut rng_peek = make_rng(seed_val);
            let start_idx = rng_peek.random_range(0usize..2);
            let result = seed(&model, &kws, &aux, &mut rng);
            if start_idx == 0 {
                assert_eq!(
                    result, zebra_id,
                    "seed={seed_val}: start==0 should pick kws[0]=ZEBRA (input order), not APPLE (sorted order)"
                );
                found_start_zero = true;
            } else {
                assert_eq!(
                    result, apple_id,
                    "seed={seed_val}: start==1 should pick kws[1]=APPLE"
                );
            }
        }
        assert!(
            found_start_zero,
            "no RNG seed produced start index 0 in 200 tries"
        );
    }

    #[test]
    fn seed_returns_keyword_when_forward_root_has_no_children() {
        // A keyword in the dictionary seeds the reply even when the forward
        // root is childless, matching C seed() (megahal.c:2697-2706). This
        // state is reachable only via a degenerate loaded brain, since learning
        // couples dictionary and trie population.
        let mut model: BidirectionalModel<TSym> = BidirectionalModel::new(2);
        let hello_id = model.dictionary.intern(ts("HELLO"));
        assert!(model.forward.children(model.forward.root()).is_empty());

        let kws = vec![ts("HELLO")];
        let aux: HashSet<TSym> = HashSet::new();
        let mut rng = make_rng(0);
        assert_eq!(seed(&model, &kws, &aux, &mut rng), hello_id);
    }

    // --- evaluate_reply tests ---

    #[test]
    fn evaluate_empty_candidate_returns_zero() {
        let model = trained_model(2, &[&["A", "B", "C"]]);
        let kws = HashSet::new();
        let score = evaluate_reply(&model, &[], &kws);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn evaluate_no_keywords_returns_zero() {
        let model = trained_model(2, &[&["A", "B", "C"]]);
        let kws = HashSet::new();
        let candidate = vec![ts("A"), ts("B"), ts("C")];
        let score = evaluate_reply(&model, &candidate, &kws);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn evaluate_with_keywords_returns_positive() {
        let model = trained_model(2, &[&["A", "B", "C"], &["A", "B", "C"], &["A", "B", "C"]]);
        let mut kws = HashSet::new();
        kws.insert(ts("B"));
        let candidate = vec![ts("A"), ts("B"), ts("C")];
        let score = evaluate_reply(&model, &candidate, &kws);
        assert!(score > 0.0, "Expected positive surprise, got {score}");
    }

    #[test]
    fn evaluate_unknown_token_skipped() {
        let model = trained_model(2, &[&["A", "B", "C"]]);
        let mut kws = HashSet::new();
        kws.insert(ts("UNKNOWN"));
        // UNKNOWN is not in the dictionary, so find returns None and it is skipped.
        let candidate = vec![ts("UNKNOWN")];
        let score = evaluate_reply(&model, &candidate, &kws);
        assert_eq!(score, 0.0);
    }

    // --- tokens_equal tests ---

    #[test]
    fn tokens_equal_same() {
        let a = vec![ts("A"), ts("B")];
        let b = vec![ts("A"), ts("B")];
        assert!(tokens_equal(&a, &b));
    }

    #[test]
    fn tokens_equal_different_length() {
        let a = vec![ts("A"), ts("B")];
        let b = vec![ts("A")];
        assert!(!tokens_equal(&a, &b));
    }

    #[test]
    fn tokens_equal_different_content() {
        let a = vec![ts("A"), ts("B")];
        let b = vec![ts("A"), ts("C")];
        assert!(!tokens_equal(&a, &b));
    }

    #[test]
    fn tokens_equal_both_empty() {
        let a: Vec<TSym> = vec![];
        let b: Vec<TSym> = vec![];
        assert!(tokens_equal(&a, &b));
    }

    // --- generate_reply integration tests ---

    #[test]
    fn generate_reply_empty_model() {
        let model: BidirectionalModel<TSym> = BidirectionalModel::new(2);
        let kws: Vec<TSym> = vec![];
        let aux = HashSet::new();
        let limit = GenerationLimit::Iterations(10);
        let mut rng = make_rng(42);
        let reply = generate_reply(&model, &[], &kws, &aux, &limit, &mut rng);
        assert!(reply.is_empty());
    }

    #[test]
    fn generate_reply_produces_output() {
        let model = trained_model(
            2,
            &[
                &["THE", " ", "CAT", " ", "SAT"],
                &["THE", " ", "DOG", " ", "RAN"],
                &["A", " ", "BIG", " ", "CAT"],
            ],
        );
        let kws = vec![ts("CAT")];
        let aux = HashSet::new();
        let limit = GenerationLimit::Iterations(10);
        let mut rng = make_rng(42);
        let reply = generate_reply(&model, &[], &kws, &aux, &limit, &mut rng);
        assert!(!reply.is_empty());
    }

    #[test]
    fn generate_reply_deterministic() {
        let build = || {
            let model = trained_model(
                2,
                &[
                    &["THE", " ", "CAT", " ", "SAT"],
                    &["THE", " ", "DOG", " ", "RAN"],
                ],
            );
            let kws = vec![ts("CAT")];
            let aux = HashSet::new();
            let limit = GenerationLimit::Iterations(50);
            let mut rng = make_rng(42);
            generate_reply(&model, &[], &kws, &aux, &limit, &mut rng)
        };
        assert_eq!(build(), build());
    }

    // Iterations(0) still generates one keyword-seeded candidate (C do/while).
    #[test]
    fn generate_reply_iterations_zero_still_evaluates_one_candidate() {
        let model = trained_model(
            2,
            &[
                &["THE", " ", "CAT", " ", "SAT"],
                &["THE", " ", "DOG", " ", "RAN"],
            ],
        );
        let kws = vec![ts("CAT")];
        let aux = HashSet::new();
        let limit = GenerationLimit::Iterations(0);
        let mut rng = make_rng(42);
        // C's do/while runs the body once before checking the bound, so
        // Iterations(0) stops only after the first keyword-seeded candidate.
        let reply = generate_reply(&model, &[], &kws, &aux, &limit, &mut rng);
        assert!(
            !reply.is_empty(),
            "Iterations(0) must still generate one keyword-seeded candidate per C do/while"
        );
    }

    #[test]
    fn generate_reply_with_timeout() {
        let model = trained_model(
            2,
            &[
                &["THE", " ", "CAT", " ", "SAT"],
                &["THE", " ", "DOG", " ", "RAN"],
                &["A", " ", "BIG", " ", "CAT"],
            ],
        );
        let kws = vec![ts("CAT")];
        let aux = HashSet::new();
        let limit = GenerationLimit::Timeout(Duration::from_millis(50));
        let mut rng = make_rng(42);
        let reply = generate_reply(&model, &[], &kws, &aux, &limit, &mut rng);
        assert!(!reply.is_empty());
    }

    #[test]
    fn generate_reply_with_both_limit() {
        let model = trained_model(
            2,
            &[
                &["THE", " ", "CAT", " ", "SAT"],
                &["THE", " ", "DOG", " ", "RAN"],
            ],
        );
        let kws = vec![ts("CAT")];
        let aux = HashSet::new();
        let limit = GenerationLimit::Both {
            timeout: Duration::from_millis(50),
            max_iterations: 10,
        };
        let mut rng = make_rng(42);
        let reply = generate_reply(&model, &[], &kws, &aux, &limit, &mut rng);
        assert!(!reply.is_empty());
    }

    #[test]
    fn generate_reply_with_aux_keywords() {
        let model = trained_model(
            2,
            &[
                &["MY", " ", "CAT", " ", "SAT"],
                &["YOUR", " ", "DOG", " ", "RAN"],
            ],
        );
        let kws = vec![ts("CAT"), ts("MY")];
        let mut aux = HashSet::new();
        aux.insert(ts("MY"));
        let limit = GenerationLimit::Iterations(20);
        let mut rng = make_rng(42);
        let reply = generate_reply(&model, &[], &kws, &aux, &limit, &mut rng);
        assert!(!reply.is_empty());
    }

    #[test]
    fn generate_reply_dissimilarity_test() {
        // A reply identical to the input must never be returned. The baseline
        // is dropped when it echoes the input, and echoing candidates are
        // never accepted, so this holds for every RNG seed.
        let model = trained_model(
            2,
            &[
                &["A", " ", "B", " ", "C"],
                &["D", " ", "E", " ", "F"],
                &["A", " ", "B", " ", "C"],
            ],
        );
        let input = vec![ts("A"), ts(" "), ts("B"), ts(" "), ts("C")];
        let kws: Vec<TSym> = vec![];
        let aux = HashSet::new();
        let limit = GenerationLimit::Iterations(50);
        for seed_val in 0u64..20 {
            let mut rng = make_rng(seed_val);
            let reply = generate_reply(&model, &input, &kws, &aux, &limit, &mut rng);
            assert_ne!(
                reply, input,
                "seed={seed_val}: generate_reply echoed the input back"
            );
        }
    }
}