use bincode::{Decode, Encode};
use regex_automata::dfa::dense::DFA;
use regex_automata::dfa::Automaton;
use regex_automata::util::primitives::StateID as AutomataStateId;
use regex_automata::Anchored;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::prelude::*;
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};

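/// An index over the states of a DFA compiled from a regex: for every state it
/// records which vocabulary tokens are allowed and the state each token leads to.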
#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct Index {
    /// The ID of the initial state in the automaton; walking begins here.
    initial_state: StateId,
    /// States in which the pattern is fully matched and generation may stop.
    final_states: HashSet<StateId>,
    /// Map from a state to the tokens allowed in it and the states they lead to.
    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
    /// The token ID reserved for the end-of-sequence token.
    eos_token_id: TokenId,
    /// The size of the vocabulary used to build the index.
    vocab_size: usize,
}

impl Index {
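    /// Builds an `Index` for `regex` over `vocabulary` by compiling the regex
    /// into a DFA and probing every vocabulary token from every reachable state.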
    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
        let vocab_size = vocabulary.len();
        let eos_token_id = vocabulary.eos_token_id();
        let dfa = DFA::new(regex).map_err(Box::new)?;
        let start_state = match dfa.universal_start_state(Anchored::Yes) {
            Some(s) => s,
            None => return Err(Error::DfaHasNoStartState),
        };

        let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
        let mut final_states: HashSet<StateId> = HashSet::default();

        let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
        let mut next_states: Vec<AutomataStateId> = vec![start_state];

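        // Explore the DFA from the anchored start state, expanding one pending
        // state at a time and probing every vocabulary token from it.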
        while let Some(current_state) = next_states.pop() {
            // A state is final if the DFA reports a match once end-of-input is seen.
            if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
                final_states.insert(current_state.as_u32());
            }

            'token_loop: for (token, ids) in vocabulary.tokens().iter() {
                if ids.contains(&eos_token_id) {
                    continue;
                }

                // Feed the token's bytes through the DFA; abandon the token if it
                // ever reaches a dead or quit state.
                let mut next_state = current_state;
                for transition_byte in token {
                    next_state = dfa.next_state(next_state, *transition_byte);
                    if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                        continue 'token_loop;
                    }
                }

                // Record the transition if the token leaves us mid-pattern or at a
                // state where the whole pattern matches.
                let is_intermediate_state = !dfa.is_match_state(next_state);
                let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
                if is_intermediate_state || is_full_match_state {
                    for token_id in ids {
                        transitions
                            .entry(current_state.as_u32())
                            .or_default()
                            .insert(*token_id, next_state.as_u32());
                    }
                }
                if !seen.contains(&next_state) {
                    seen.insert(next_state);
                    next_states.push(next_state);
                }
            }
        }

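        // Add an EOS self-loop to every final state so generation can stop there.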
        for &final_state in &final_states {
            transitions
                .entry(final_state)
                .or_default()
                .insert(eos_token_id, final_state);
        }

        Ok(Self {
            initial_state: start_state.as_u32(),
            final_states,
            transitions,
            eos_token_id,
            vocab_size,
        })
    }

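    /// Returns the ID of the initial state of the index.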
    pub fn initial_state(&self) -> StateId {
        self.initial_state
    }

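    /// Returns the set of final states.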
    pub fn final_states(&self) -> &HashSet<StateId> {
        &self.final_states
    }

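    /// Returns the full transition map: state -> (token ID -> next state).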
    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
        &self.transitions
    }

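    /// Checks whether the given state is a final state.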
    pub fn is_final_state(&self, state: &StateId) -> bool {
        self.final_states.contains(state)
    }

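    /// Collects the token IDs allowed in the given state, or `None` if the
    /// state has no outgoing transitions.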
    pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
        self.transitions
            .get(state)
            .map(|res| res.keys().cloned().collect())
    }

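    /// Like `allowed_tokens`, but borrows the token IDs instead of collecting them.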
    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
        self.transitions.get(state).map(|map| map.keys())
    }

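    /// Returns the state reached by consuming `token_id` in `state`; `None` for
    /// the EOS token or when no such transition exists.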
    pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
        if token_id == &self.eos_token_id {
            return None;
        }
        Some(*self.transitions.get(state)?.get(token_id)?)
    }

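    /// Returns the size of the vocabulary the index was built from.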
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }
}

impl std::fmt::Display for Index {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Index object with transitions:")?;
        for (state_id, token_ids) in self.transitions.iter() {
            writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let eos_token_id = 4;
        let mut vocabulary = Vocabulary::new(eos_token_id);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let initial_state = index.initial_state();
        assert_eq!(initial_state, 40);
        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
        assert!(!index.is_final_state(&initial_state));

        let expected = HashMap::from_iter([
            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
            (48, HashMap::from_iter([(4, 48)])),
            (40, HashMap::from_iter([(3, 48), (2, 56)])),
            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
        ]);
        assert_eq!(index.transitions(), &expected);

        let allowed_tokens = index
            .allowed_tokens(&initial_state)
            .expect("No allowed tokens");
        let token_id = allowed_tokens.first().expect("No first tokens");

        let state = 48;
        assert_eq!(index.next_state(&initial_state, token_id), Some(state));
        assert!(index.is_final_state(&state));

        assert_eq!(index.next_state(&state, &eos_token_id), None);
        assert_eq!(index.next_state(&state, token_id), None);
    }

    #[test]
    fn index_from_regex_initial_in_allowed() {
        let regex = "`\\n(\\.\\n)?`\\n";
        let mut vocabulary = Vocabulary::new(104);
        for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let allowed = index
            .allowed_tokens(&index.initial_state())
            .expect("No allowed tokens");
        assert!(allowed.contains(&101));
    }

    #[test]
    fn index_from_regex_multibyte() {
        let regex = "😇| [😈-😍][😇-😎]*";
        let mut vocabulary = Vocabulary::new(8);
        for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)]
        {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        for (token, token_id) in [
            (vec![32, 240, 159, 152], 7),
            (vec![32, 240, 159, 152, 141], 6),
            (vec![240, 159, 152, 141], 4),
        ] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

        let expected = HashMap::from_iter([
            (
                208,
                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
            ),
            (
                80,
                HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
            ),
            (128, HashMap::from_iter([(8, 128)])),
        ]);
        assert_eq!(index.transitions(), &expected);
    }
}