use bincode::{Decode, Encode};
use regex_automata::dfa::dense::DFA;
use regex_automata::dfa::Automaton;
use regex_automata::util::primitives::StateID as AutomataStateId;
use regex_automata::Anchored;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::prelude::*;
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};

/// Maps a vocabulary's tokens onto the transitions of the deterministic
/// finite automaton compiled from a regular expression, so that token-by-token
/// generation can be restricted to strings matching that expression.
#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct Index {
    /// State from which processing starts.
    initial_state: StateId,
    /// States in which the input consumed so far is accepted.
    final_states: HashSet<StateId>,
    /// For each state, the token IDs allowed there and the states they lead to.
    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
    /// ID of the end-of-sequence token in the vocabulary.
    eos_token_id: TokenId,
    /// Size of the vocabulary the index was built from.
    vocab_size: usize,
}

impl Index {
    /// Builds an `Index` by compiling `regex` to a DFA and walking every
    /// vocabulary token through it from each reachable state, recording which
    /// tokens are allowed where.
    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
        let vocab_size = vocabulary.len();
        let eos_token_id = vocabulary.eos_token_id();
        let dfa = DFA::new(regex).map_err(Box::new)?;
        let start_state = match dfa.universal_start_state(Anchored::Yes) {
            Some(s) => s,
            None => return Err(Error::DfaHasNoStartState),
        };

        let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
        let mut final_states: HashSet<StateId> = HashSet::default();

        let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
        let mut next_states: Vec<AutomataStateId> = vec![start_state];

        while let Some(current_state) = next_states.pop() {
            let mut has_valid_transitions = false;

            // The state is final if the DFA accepts once end-of-input is seen.
            if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
                final_states.insert(current_state.as_u32());
                has_valid_transitions = true;
            }

            'token_loop: for (token, ids) in vocabulary.tokens().iter() {
                // EOS is added explicitly for final states after the walk.
                if ids.contains(&eos_token_id) {
                    continue;
                }

                // Walk the token's bytes through the DFA; abandon the token
                // as soon as it hits a dead or quit state.
                let mut next_state = current_state;
                for transition_byte in token {
                    next_state = dfa.next_state(next_state, *transition_byte);
                    if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                        continue 'token_loop;
                    }
                }

                // Keep the transition if the token lands in a live state:
                // either strictly inside the pattern or on a full match.
                let is_intermediate_state = !dfa.is_match_state(next_state);
                let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
                if is_intermediate_state || is_full_match_state {
                    has_valid_transitions = true;
                    for token_id in ids {
                        transitions
                            .entry(current_state.as_u32())
                            .or_default()
                            .insert(*token_id, next_state.as_u32());
                    }
                }
                // `insert` returns true only for newly discovered states.
                if seen.insert(next_state) {
                    next_states.push(next_state);
                }
            }

            // No token can leave this non-matching state: the vocabulary
            // cannot realize the regex. Report which bytes the regex would
            // have accepted here to make the error actionable.
            if !has_valid_transitions && !dfa.is_match_state(current_state) {
                let mut valid_characters = Vec::new();
                for byte in 0..=255u8 {
                    let test_state = dfa.next_state(current_state, byte);
                    if !dfa.is_dead_state(test_state) && !dfa.is_quit_state(test_state) {
                        if byte.is_ascii() {
                            valid_characters.push(char::from(byte).to_string());
                        } else {
                            valid_characters.push(format!("\\x{:02x}", byte));
                        }
                    }
                }

                return Err(Error::IncompatibleVocabulary {
                    regex: regex.to_string(),
                    error_state: current_state.as_u32(),
                    missing_tokens: valid_characters,
                });
            }
        }

        // Let EOS close generation: in every final state it maps back to the
        // same state.
        for &final_state in &final_states {
            transitions
                .entry(final_state)
                .or_default()
                .insert(eos_token_id, final_state);
        }

        Ok(Self {
            initial_state: start_state.as_u32(),
            final_states,
            transitions,
            eos_token_id,
            vocab_size,
        })
    }

    /// Returns the ID of the state from which processing starts.
    pub fn initial_state(&self) -> StateId {
        self.initial_state
    }

    /// Returns the set of accepting states.
    pub fn final_states(&self) -> &HashSet<StateId> {
        &self.final_states
    }

    /// Returns the complete state-transition table.
    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
        &self.transitions
    }

    /// Checks whether `state` is accepting.
    pub fn is_final_state(&self, state: &StateId) -> bool {
        self.final_states.contains(state)
    }

    /// Collects the token IDs allowed in `state`, or `None` if the state is unknown.
    pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
        self.transitions
            .get(state)
            .map(|res| res.keys().copied().collect())
    }

    /// Like [`Self::allowed_tokens`], but borrows the IDs instead of collecting them.
    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
        self.transitions.get(state).map(|map| map.keys())
    }

    /// Returns the state reached from `state` on `token_id`; `None` for EOS
    /// or a disallowed token.
    pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
        if token_id == &self.eos_token_id {
            return None;
        }
        Some(*self.transitions.get(state)?.get(token_id)?)
    }

    /// Returns the size of the vocabulary the index was built with.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }
}
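
// Illustrative sketch, not part of the crate's public API: how a constrained
// decoding loop would drive an `Index`. The `pick` closure is a hypothetical
// stand-in for a model's sampler over the allowed token IDs.
#[allow(dead_code)]
fn example_walk(index: &Index, mut pick: impl FnMut(&[TokenId]) -> TokenId) -> Vec<TokenId> {
    let mut state = index.initial_state();
    let mut output = Vec::new();
    while !index.is_final_state(&state) {
        // An unknown or dead-end state ends the walk.
        let Some(allowed) = index.allowed_tokens(&state) else { break };
        let token_id = pick(&allowed);
        match index.next_state(&state, &token_id) {
            Some(next) => {
                output.push(token_id);
                state = next;
            }
            // `None` means the sampler picked EOS or a disallowed token.
            None => break,
        }
    }
    output
}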

impl std::fmt::Display for Index {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Index object with transitions:")?;
        for (state_id, token_ids) in self.transitions.iter() {
            writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
        }
        Ok(())
    }
}
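
// Sketch of serializing an index, assuming a bincode 2.x dependency with the
// `derive` feature (consistent with the `Encode`/`Decode` derives above).
#[allow(dead_code)]
fn example_roundtrip(index: &Index) -> Index {
    let config = bincode::config::standard();
    let bytes = bincode::encode_to_vec(index, config).expect("encode failed");
    let (decoded, _read): (Index, usize) =
        bincode::decode_from_slice(&bytes, config).expect("decode failed");
    decoded
}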

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let eos_token_id = 4;
        let mut vocabulary = Vocabulary::new(eos_token_id);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let initial_state = index.initial_state();
        assert_eq!(initial_state, 40);
        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
        assert!(!index.is_final_state(&initial_state));

        let expected = HashMap::from_iter([
            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
            (48, HashMap::from_iter([(4, 48)])),
            (40, HashMap::from_iter([(3, 48), (2, 56)])),
            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
        ]);
        assert_eq!(index.transitions(), &expected);

        let allowed_tokens = index
            .allowed_tokens(&initial_state)
            .expect("No allowed tokens");
        let token_id = allowed_tokens.first().expect("No first token");

        let state = 48;
        assert_eq!(index.next_state(&initial_state, token_id), Some(state));
        assert!(index.is_final_state(&state));

        assert_eq!(index.next_state(&state, &eos_token_id), None);
        assert_eq!(index.next_state(&state, token_id), None);
    }
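
    // Sketch of a token-by-token walk, reusing the regex, vocabulary, and
    // transition table established by `index_from_regex` above: token 2 ("2")
    // followed by token 3 ("0") spells "20", which `0|[1-9][0-9]*` accepts.
    #[test]
    fn index_walk_token_by_token() {
        let regex = "0|[1-9][0-9]*";
        let mut vocabulary = Vocabulary::new(4);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        let index = Index::new(regex, &vocabulary).expect("Index failed");

        let s1 = index
            .next_state(&index.initial_state(), &2)
            .expect("No transition for \"2\"");
        let s2 = index.next_state(&s1, &3).expect("No transition for \"0\"");
        assert!(index.is_final_state(&s2));
    }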

    #[test]
    fn index_from_regex_initial_in_allowed() {
        let regex = "`\\n(\\.\\n)?`\\n";
        let mut vocabulary = Vocabulary::new(104);
        for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let allowed = index
            .allowed_tokens(&index.initial_state())
            .expect("No allowed tokens");
        assert!(allowed.contains(&101));
    }

    #[test]
    fn index_from_regex_multibyte() {
        let regex = "😇| [😈-😍][😇-😎]*";
        let mut vocabulary = Vocabulary::new(8);
        for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)]
        {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        for (token, token_id) in [
            (vec![32, 240, 159, 152, 136], 7),
            (vec![32, 240, 159, 152, 141], 6),
            (vec![240, 159, 152, 141], 4),
        ] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

        let expected = HashMap::from_iter([
            (
                208,
                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
            ),
            (
                80,
                HashMap::from_iter([(2, 128), (7, 208), (5, 208), (6, 208)]),
            ),
            (128, HashMap::from_iter([(8, 128)])),
        ]);
        assert_eq!(index.transitions(), &expected);
    }

    #[test]
    fn index_incompatible_vocabulary_error() {
        let regex = "0 1";
        let mut vocabulary = Vocabulary::new(3);
        for (token, token_id) in [("0", 0), ("0 ", 1), ("1", 2)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let result = Index::new(regex, &vocabulary);
        assert!(result.is_err());

        if let Err(Error::IncompatibleVocabulary {
            regex: _,
            missing_tokens,
            ..
        }) = result
        {
            assert!(missing_tokens.contains(&" ".to_string()));
        } else {
            panic!("Expected IncompatibleVocabulary error");
        }
    }

    #[test]
    fn index_incompatible_vocabulary_error_non_ascii() {
        let regex = "😈😍";
        let mut vocabulary = Vocabulary::new(3);
        for (token, token_id) in [("😈", 0), (" ", 1), ("b", 2)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let result = Index::new(regex, &vocabulary);
        assert!(result.is_err());

        if let Err(Error::IncompatibleVocabulary {
            regex: _,
            missing_tokens,
            ..
        }) = result
        {
            assert!(missing_tokens.contains(&"\\xf0".to_string()));
        } else {
            panic!("Expected IncompatibleVocabulary error");
        }
    }
}