use bincode::{Decode, Encode};
use regex_automata::dfa::dense::DFA;
use regex_automata::dfa::Automaton;
use regex_automata::util::primitives::StateID as AutomataStateId;
use regex_automata::Anchored;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use crate::prelude::*;
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct Index {
    // DFA state the token walk starts from (anchored start state, as u32).
    initial_state: StateId,
    // States where the regex match is complete; EOS is allowed from these.
    final_states: HashSet<StateId>,
    // Transition table: source state -> (token id -> destination state).
    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
    // Token id that terminates generation; `next_state` returns None for it.
    eos_token_id: TokenId,
    // Size of the vocabulary this index was built against.
    vocab_size: usize,
}
impl Index {
    /// Builds an `Index` by walking the dense DFA compiled from `regex` and
    /// recording, for every reachable DFA state, which vocabulary tokens keep
    /// the match alive and which state each token leads to.
    ///
    /// # Errors
    ///
    /// * `Error::DfaHasNoStartState` — the DFA has no universal anchored
    ///   start state.
    /// * `Error::IncompatibleVocabulary` — a live (non-final) state has no
    ///   outgoing token transition, i.e. the vocabulary cannot produce any
    ///   byte sequence the regex still accepts from that state.
    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
        let vocab_size = vocabulary.len();
        let eos_token_id = vocabulary.eos_token_id();
        let dfa = DFA::new(regex).map_err(Box::new)?;
        let start_state = match dfa.universal_start_state(Anchored::Yes) {
            Some(s) => s,
            None => return Err(Error::DfaHasNoStartState),
        };

        let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
        let mut final_states: HashSet<StateId> = HashSet::default();
        // Depth-first walk over DFA states reachable via whole-token byte
        // sequences, starting from the anchored start state.
        let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
        let mut next_states: Vec<AutomataStateId> = vec![start_state];

        while let Some(current_state) = next_states.pop() {
            let mut has_valid_transitions = false;
            // A state is final when consuming end-of-input from it matches.
            if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
                final_states.insert(current_state.as_u32());
                has_valid_transitions = true;
            }
            'token_loop: for (token, ids) in vocabulary.tokens().iter() {
                // EOS is handled separately via the final-state self-loop
                // added after the walk.
                if ids.contains(&eos_token_id) {
                    continue;
                }
                // Feed the token's bytes through the DFA; abandon the token as
                // soon as it falls into a dead or quit state.
                let mut next_state = current_state;
                for transition_byte in token {
                    next_state = dfa.next_state(next_state, *transition_byte);
                    if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                        continue 'token_loop;
                    }
                }
                let is_intermediate_state = !dfa.is_match_state(next_state);
                let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
                if is_intermediate_state || is_full_match_state {
                    has_valid_transitions = true;
                    // Hoist the entry lookup out of the per-token-id loop:
                    // every id of this token shares the same source state.
                    let state_transitions = transitions.entry(current_state.as_u32()).or_default();
                    for token_id in ids {
                        state_transitions.insert(*token_id, next_state.as_u32());
                    }
                }
                // `HashSet::insert` returns true only for a previously unseen
                // state, so one lookup replaces a contains-then-insert pair.
                if seen.insert(next_state) {
                    next_states.push(next_state);
                }
            }
            // A live, non-final state with no token transitions means the
            // vocabulary cannot continue the match. Report which raw bytes the
            // DFA would still accept so the caller can diagnose the gap.
            if !has_valid_transitions && !dfa.is_match_state(current_state) {
                let mut valid_characters = Vec::new();
                for byte in 0..=255u8 {
                    let test_state = dfa.next_state(current_state, byte);
                    if !dfa.is_dead_state(test_state) && !dfa.is_quit_state(test_state) {
                        if byte.is_ascii() {
                            valid_characters.push(char::from(byte).to_string());
                        } else {
                            valid_characters.push(format!("\\x{:02x}", byte));
                        }
                    }
                }
                return Err(Error::IncompatibleVocabulary {
                    regex: regex.to_string(),
                    error_state: current_state.as_u32(),
                    missing_tokens: valid_characters,
                });
            }
        }

        // Every final state accepts EOS as a self-loop so generation can
        // terminate there; `next_state` still reports None for EOS.
        for &final_state in &final_states {
            transitions
                .entry(final_state)
                .or_default()
                .insert(eos_token_id, final_state);
        }

        Ok(Self {
            initial_state: start_state.as_u32(),
            final_states,
            transitions,
            eos_token_id,
            vocab_size,
        })
    }

    /// State the walk starts from.
    pub fn initial_state(&self) -> StateId {
        self.initial_state
    }

    /// States from which EOS may be emitted.
    pub fn final_states(&self) -> &HashSet<StateId> {
        &self.final_states
    }

    /// Full transition table: state -> (token id -> next state).
    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
        &self.transitions
    }

    /// Whether `state` accepts end-of-input.
    pub fn is_final_state(&self, state: &StateId) -> bool {
        self.final_states.contains(state)
    }

    /// Token ids allowed from `state`, or `None` for an unknown state.
    pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
        self.transitions
            .get(state)
            .map(|res| res.keys().copied().collect())
    }

    /// Borrowing variant of [`Self::allowed_tokens`] that avoids allocating.
    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
        self.transitions.get(state).map(|map| map.keys())
    }

    /// State reached by consuming `token_id` from `state`; `None` for EOS or
    /// a transition the index does not contain.
    pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
        if token_id == &self.eos_token_id {
            return None;
        }
        Some(*self.transitions.get(state)?.get(token_id)?)
    }

    /// Size of the vocabulary this index was built against.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }
}
impl std::fmt::Display for Index {
    /// Renders the index as a header line followed by one line per source
    /// state with its token-to-state map.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Index object with transitions:")?;
        self.transitions
            .iter()
            .try_for_each(|(state_id, token_ids)| writeln!(f, "{:?} -> {:#?}", state_id, token_ids))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let eos_token_id = 4;
        let mut vocabulary = Vocabulary::new(eos_token_id);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");

        // Hard-coded state ids below are the dense-DFA state ids produced by
        // regex_automata for this pattern.
        let initial_state = index.initial_state();
        assert_eq!(initial_state, 40);
        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
        assert!(!index.is_final_state(&initial_state));

        let expected = HashMap::from_iter([
            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
            (48, HashMap::from_iter([(4, 48)])),
            (40, HashMap::from_iter([(3, 48), (2, 56)])),
            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
        ]);
        assert_eq!(index.transitions(), &expected);

        let allowed_tokens = index
            .allowed_tokens(&initial_state)
            .expect("No allowed tokens");
        let token_id = allowed_tokens.first().expect("No first tokens");

        let state = 48;
        assert_eq!(index.next_state(&initial_state, token_id), Some(state));
        assert!(index.is_final_state(&state));

        // EOS never yields a next state, and a final state such as "0" has no
        // outgoing digit transitions for this regex.
        assert_eq!(index.next_state(&state, &eos_token_id), None);
        assert_eq!(index.next_state(&state, token_id), None);
    }

    // Renamed from `index_from_regex_initital_in_allowed` (typo).
    #[test]
    fn index_from_regex_initial_in_allowed() {
        let regex = "`\\n(\\.\\n)?`\\n";
        let mut vocabulary = Vocabulary::new(104);
        for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        // The backtick token must be permitted from the initial state.
        let allowed = index
            .allowed_tokens(&index.initial_state())
            .expect("No allowed tokens");
        assert!(allowed.contains(&101));
    }

    #[test]
    fn index_from_regex_multibyte() {
        let regex = "😇| [😈-😍][😇-😎]*";
        let mut vocabulary = Vocabulary::new(8);
        for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)]
        {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        // Raw UTF-8 byte sequences for " 😈", " 😍", and "😍": the index must
        // treat byte-level tokens identically to their &str equivalents.
        for (token, token_id) in [
            (vec![32, 240, 159, 152, 136], 7),
            (vec![32, 240, 159, 152, 141], 6),
            (vec![240, 159, 152, 141], 4),
        ] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

        let expected = HashMap::from_iter([
            (
                208,
                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
            ),
            (
                80,
                HashMap::from_iter([(2, 128), (7, 208), (5, 208), (6, 208)]),
            ),
            (128, HashMap::from_iter([(8, 128)])),
        ]);
        assert_eq!(index.transitions(), &expected);
    }

    #[test]
    fn index_incompatible_vocabulary_error() {
        // The vocabulary has no token producing the space required after "0",
        // so building the index must fail and report the missing byte.
        let regex = "0 1";
        let mut vocabulary = Vocabulary::new(3);
        for (token, token_id) in [("0", 0), ("0 ", 1), ("1", 2)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let result = Index::new(regex, &vocabulary);
        assert!(result.is_err());
        if let Err(Error::IncompatibleVocabulary {
            regex: _,
            missing_tokens,
            ..
        }) = result
        {
            assert!(missing_tokens.contains(&" ".to_string()));
        } else {
            panic!("Expected IncompatibleVocabulary error");
        }
    }

    #[test]
    fn index_incompatible_vocabulary_error_non_ascii() {
        // Missing continuation of a multibyte emoji: non-ASCII bytes must be
        // reported in escaped \xNN form (0xf0 is the 😍 lead byte).
        let regex = "😈😍";
        let mut vocabulary = Vocabulary::new(3);
        for (token, token_id) in [("😈", 0), (" ", 1), ("b", 2)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let result = Index::new(regex, &vocabulary);
        assert!(result.is_err());
        if let Err(Error::IncompatibleVocabulary {
            regex: _,
            missing_tokens,
            ..
        }) = result
        {
            assert!(missing_tokens.contains(&"\\xf0".to_string()));
        } else {
            panic!("Expected IncompatibleVocabulary error");
        }
    }
}