mod common;
use common::*;
use toktrie::recognizer::{FunctionalRecognizer, StackRecognizer};
use toktrie::{Recognizer, TokenId};
/// `AlphaOnly` (wrapped in a `StackRecognizer`) admits every token id of the
/// test trie except 23, which the assertions below identify as the space.
#[test]
fn sample_stateless_functional_recognizer() {
    let trie = build_test_trie();
    let mut recognizer = StackRecognizer::from(AlphaOnly);
    let mut bias = trie.alloc_token_set();
    trie.add_bias(&mut recognizer, &mut bias, b"");

    // Ids 1..=25 with the single gap at 23.
    let want: Vec<TokenId> = (1..=25).filter(|&id| id != 23).collect();
    assert_eq!(allowed_set(&bias), want);

    assert!(!bias.is_allowed(23), "space should be rejected");
    assert!(bias.is_allowed(15), "\"cat\" should be allowed");
    assert!(bias.is_allowed(5), "\"apple\" should be allowed");
}
/// `CaPrefix` tracks state from byte to byte, so starting from its initial
/// state only token ids 13 through 16 survive the bias pass.
#[test]
fn sample_stateful_functional_recognizer() {
    let trie = build_test_trie();
    let mut bias = trie.alloc_token_set();
    let mut recognizer = StackRecognizer::from(CaPrefix);
    trie.add_bias(&mut recognizer, &mut bias, b"");

    let want: Vec<TokenId> = (13..=16).collect();
    assert_eq!(allowed_set(&bias), want);
}
/// Exercises the `StackRecognizer` accessors: `recognizer()` exposes the
/// wrapped functional recognizer, and `reset()` lets the same instance be
/// reused to produce an identical bias.
#[test]
fn sample_stack_recognizer_api() {
    let mut rec = StackRecognizer::from(CaPrefix);
    {
        let inner: &CaPrefix = rec.recognizer();
        assert_eq!(inner.initial(), 0);
    }

    let trie = build_test_trie();
    let mut first = trie.alloc_token_set();
    trie.add_bias(&mut rec, &mut first, b"");
    let first_allowed = allowed_set(&first);
    assert_eq!(first_allowed.len(), 4);

    // A reset recognizer must reproduce exactly the same allowed tokens.
    rec.reset();
    let mut second = trie.alloc_token_set();
    trie.add_bias(&mut rec, &mut second, b"");
    assert_eq!(allowed_set(&second), first_allowed);
}
/// Both flavours of "accept everything" — the direct `toktrie::AnythingGoes`
/// recognizer and the functional one wrapped in a `StackRecognizer` — allow
/// every token id 1..=25.
#[test]
fn sample_anything_goes() {
    let trie = build_test_trie();
    let every_token: Vec<TokenId> = (1..=25).collect();

    let mut direct = toktrie::AnythingGoes;
    let mut direct_set = trie.alloc_token_set();
    trie.add_bias(&mut direct, &mut direct_set, b"");
    assert_eq!(allowed_set(&direct_set), every_token);

    let mut stacked = StackRecognizer::from(toktrie::recognizer::AnythingGoes {});
    let mut stacked_set = trie.alloc_token_set();
    trie.add_bias(&mut stacked, &mut stacked_set, b"");
    assert_eq!(allowed_set(&stacked_set), every_token);
}
/// `add_bias` seeded with a non-empty `start` byte prefix, once with an
/// accept-everything recognizer and once with `AlphaOnly`.
#[test]
fn sample_add_bias_with_start_prefix() {
    let trie = build_test_trie();

    let mut any = toktrie::AnythingGoes;
    let mut after_app = trie.alloc_token_set();
    trie.add_bias(&mut any, &mut after_app, b"app");
    assert_eq!(allowed_set(&after_app), vec![1, 4, 5, 6, 7]);

    let mut alpha = StackRecognizer::from(AlphaOnly);
    let mut after_ba = trie.alloc_token_set();
    trie.add_bias(&mut alpha, &mut after_ba, b"ba");
    assert_eq!(allowed_set(&after_ba), (8..=12).collect::<Vec<TokenId>>());
}
/// Probes `has_valid_extensions` with several prefix/recognizer pairs. Each
/// probe gets a freshly built recognizer so no state carries over between cases.
#[test]
fn sample_has_valid_extensions() {
    let trie = build_test_trie();

    assert!(trie.has_valid_extensions(&mut StackRecognizer::from(AlphaOnly), b"app"));
    assert!(!trie.has_valid_extensions(&mut StackRecognizer::from(AlphaOnly), b"apple"));

    let mut anything = toktrie::AnythingGoes;
    assert!(!trie.has_valid_extensions(&mut anything, b"xyz"));

    assert!(!trie.has_valid_extensions(&mut StackRecognizer::from(CaPrefix), b"c"));
    assert!(trie.has_valid_extensions(&mut StackRecognizer::from(AlphaOnly), b"ba"));
}
/// A hand-written `Recognizer` that accepts any byte sequence of at most
/// `max_len` bytes.
///
/// It keeps its own stack of depths: each entry is the number of bytes
/// accepted at that point of the trie walk. Index 0 is the base entry
/// (depth 0 after construction). The `Recognizer` methods read `last()` and
/// rely on at least that base entry being present.
#[derive(Debug, Clone)]
struct MaxLenRecognizer {
    /// Maximum number of bytes accepted.
    max_len: usize,
    /// Depth history; index 0 is the base entry.
    stack: Vec<usize>,
}

impl MaxLenRecognizer {
    /// Creates a recognizer that accepts at most `max_len` bytes, starting
    /// from a base depth of zero.
    fn new(max_len: usize) -> Self {
        Self {
            max_len,
            stack: vec![0],
        }
    }
}
impl Recognizer for MaxLenRecognizer {
    /// Undoes the last `num` accepted bytes.
    ///
    /// # Panics
    /// Panics if `num` would remove the base entry as well. An unchecked
    /// `len - num` would overflow-panic in debug builds but wrap in release
    /// builds, where `truncate` with a huge length silently does nothing.
    fn pop_bytes(&mut self, num: usize) {
        let entries = self.stack.len();
        // The base entry (index 0) must survive, so at most `entries - 1` can go.
        assert!(
            num < entries,
            "MaxLenRecognizer::pop_bytes: cannot pop {} of {} stack entries",
            num,
            entries
        );
        self.stack.truncate(entries - num);
    }

    /// Drops all history below the current depth, keeping that depth as the
    /// sole (new base) entry.
    fn collapse(&mut self) {
        let top = *self
            .stack
            .last()
            .expect("MaxLenRecognizer stack always holds a base entry");
        self.stack.clear();
        self.stack.push(top);
    }

    /// Returns to the base entry once a trie walk ends, discarding any depths
    /// the walk left on the stack.
    fn trie_finished(&mut self) {
        self.stack.truncate(1);
    }

    /// Accepts any byte while fewer than `max_len` bytes have been accepted;
    /// the byte value itself is ignored.
    fn try_push_byte(&mut self, _byte: u8) -> bool {
        let depth = *self
            .stack
            .last()
            .expect("MaxLenRecognizer stack always holds a base entry");
        if depth < self.max_len {
            self.stack.push(depth + 1);
            true
        } else {
            false
        }
    }

    /// Reports an error once the current depth has reached `max_len`.
    fn get_error(&mut self) -> Option<String> {
        let depth = *self
            .stack
            .last()
            .expect("MaxLenRecognizer stack always holds a base entry");
        if depth >= self.max_len {
            Some(format!(
                "MaxLenRecognizer: reached maximum length of {} bytes",
                self.max_len
            ))
        } else {
            None
        }
    }
}
/// Drives `add_bias` with the hand-written `MaxLenRecognizer`. With a cap of
/// 2, only tokens of at most two bytes can be fully pushed.
#[test]
fn sample_direct_recognizer_impl() {
    let trie = build_test_trie();
    let mut capped = MaxLenRecognizer::new(2);
    let mut bias = trie.alloc_token_set();
    trie.add_bias(&mut capped, &mut bias, b"");

    let want: [TokenId; 11] = [1, 2, 8, 9, 13, 14, 18, 19, 22, 23, 24];
    assert_eq!(allowed_set(&bias), want.to_vec());
}
/// Token sets produced by different recognizers can be combined with `and`
/// (intersection) and `sub` (difference).
#[test]
fn sample_combining_token_sets() {
    let trie = build_test_trie();

    let mut set_alpha = trie.alloc_token_set();
    let mut rec_alpha = StackRecognizer::from(AlphaOnly);
    trie.add_bias(&mut rec_alpha, &mut set_alpha, b"");
    let alpha = allowed_set(&set_alpha);

    let mut set_ca = trie.alloc_token_set();
    let mut rec_ca = StackRecognizer::from(CaPrefix);
    trie.add_bias(&mut rec_ca, &mut set_ca, b"");
    let ca = allowed_set(&set_ca);

    // Every CaPrefix token is also an AlphaOnly token, so the intersection
    // is exactly the CaPrefix set.
    let mut intersection = set_alpha.clone();
    intersection.and(&set_ca);
    assert_eq!(allowed_set(&intersection), vec![13, 14, 15, 16]);
    assert_eq!(allowed_set(&intersection), ca);

    // `sub` must remove exactly the CaPrefix tokens and keep every other
    // AlphaOnly token; spot checks alone would miss over-eager removal.
    let mut diff = set_alpha.clone();
    diff.sub(&set_ca);
    let diff_set = allowed_set(&diff);
    let expected_diff: Vec<TokenId> = alpha
        .iter()
        .copied()
        .filter(|id| !ca.contains(id))
        .collect();
    assert_eq!(diff_set, expected_diff);

    assert!(!diff_set.contains(&13));
    assert!(!diff_set.contains(&15));
    assert!(diff_set.contains(&1));
    assert!(diff_set.contains(&5));
}
/// `AlphaOnly` reports the same error text both before and after accepting
/// a letter.
#[test]
fn sample_get_error_alpha_only() {
    const EXPECTED: &str = "AlphaOnly: expected lowercase ASCII letter (a-z)";
    let mut rec = StackRecognizer::from(AlphaOnly);

    assert_eq!(rec.get_error().as_deref(), Some(EXPECTED));
    assert!(rec.try_push_byte(b'h'));
    assert_eq!(rec.get_error().as_deref(), Some(EXPECTED));
}
/// Walks `CaPrefix` through "cat" one byte at a time. After each byte,
/// `get_error` should describe what the recognizer wants next. Once the
/// pattern is complete, further bytes are refused.
#[test]
fn sample_get_error_ca_prefix() {
    let mut rec = StackRecognizer::from(CaPrefix);
    assert_eq!(rec.get_error().as_deref(), Some("CaPrefix: expected 'c'"));

    let steps: [(u8, &str); 3] = [
        (b'c', "CaPrefix: expected 'a' after 'c'"),
        (b'a', "CaPrefix: expected lowercase letter after \"ca\""),
        (b't', "CaPrefix: pattern complete, no further bytes accepted"),
    ];
    for (byte, message) in steps {
        assert!(rec.try_push_byte(byte));
        assert_eq!(rec.get_error().as_deref(), Some(message));
    }

    for byte in [b'a', b'z'] {
        assert!(!rec.try_push_byte(byte));
    }
}
/// `MaxLenRecognizer` only reports an error once the byte cap is reached.
/// Popping a byte brings it back under the cap and clears the error.
#[test]
fn sample_get_error_max_len() {
    let at_limit = Some("MaxLenRecognizer: reached maximum length of 2 bytes");
    let mut rec = MaxLenRecognizer::new(2);

    assert!(rec.get_error().is_none());
    assert!(rec.try_push_byte(b'x'));
    assert!(rec.get_error().is_none());

    assert!(rec.try_push_byte(b'y'));
    assert_eq!(rec.get_error().as_deref(), at_limit);
    assert!(!rec.try_push_byte(b'z'));

    rec.pop_bytes(1);
    assert!(rec.get_error().is_none());
}