1use bincode::{Decode, Encode};
4use regex_automata::dfa::dense::DFA;
5use regex_automata::dfa::Automaton;
6use regex_automata::util::primitives::StateID as AutomataStateId;
7use regex_automata::Anchored;
8use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
9
10use crate::prelude::*;
11use crate::vocabulary::Vocabulary;
12use crate::{Error, Result};
13
/// Token-level transition table for text constrained by a regular expression.
///
/// Built by [`Index::new`] from a regex and a [`Vocabulary`]. State ids are the
/// `u32` values of the states of the DFA compiled from the regex.
#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct Index {
    /// Id of the state the index starts in (the DFA's anchored start state).
    initial_state: StateId,
    /// Ids of the states from which feeding end-of-input to the DFA yields a match.
    final_states: HashSet<StateId>,
    /// For every state, the token ids that may be consumed there and the state each leads to.
    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
    /// Id of the end-of-sequence token, as reported by the vocabulary.
    eos_token_id: TokenId,
}
59impl Index {
100 pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
102 let eos_token_id = vocabulary.eos_token_id();
103 let dfa = DFA::new(regex).map_err(Box::new)?;
104 let start_state = match dfa.universal_start_state(Anchored::Yes) {
105 Some(s) => s,
106 None => return Err(Error::DfaHasNoStartState),
107 };
108
109 let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
110 let mut final_states: HashSet<StateId> = HashSet::default();
111
112 let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
113 let mut next_states: Vec<AutomataStateId> = vec![start_state];
114
115 while let Some(current_state) = next_states.pop() {
116 if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
117 final_states.insert(current_state.as_u32());
118 }
119
120 'token_loop: for (token, ids) in vocabulary.tokens().iter() {
121 if ids.contains(&eos_token_id) {
122 continue;
123 }
124
125 let mut next_state = current_state;
126 for transition_byte in token {
127 next_state = dfa.next_state(next_state, *transition_byte);
128 if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
129 continue 'token_loop;
130 }
131 }
132
133 let is_intermediate_state = !dfa.is_match_state(next_state);
134 let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
135 if is_intermediate_state || is_full_match_state {
136 for token_id in ids {
137 transitions
138 .entry(current_state.as_u32())
139 .or_default()
140 .insert(*token_id, next_state.as_u32());
141 }
142 }
143 if !seen.contains(&next_state) {
144 seen.insert(next_state);
145 next_states.push(next_state);
146 }
147 }
148 }
149
150 for &final_state in &final_states {
152 transitions
153 .entry(final_state)
154 .or_default()
155 .insert(eos_token_id, final_state);
156 }
157
158 Ok(Self {
159 initial_state: start_state.as_u32(),
160 final_states,
161 transitions,
162 eos_token_id,
163 })
164 }
165
166 pub fn initial_state(&self) -> StateId {
168 self.initial_state
169 }
170
171 pub fn final_states(&self) -> &HashSet<StateId> {
173 &self.final_states
174 }
175
176 pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
178 &self.transitions
179 }
180
181 pub fn is_final_state(&self, state: &StateId) -> bool {
183 self.final_states.contains(state)
184 }
185
186 pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
188 self.transitions
189 .get(state)
190 .map(|res| res.keys().cloned().collect())
191 }
192
193 pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
195 if token_id == &self.eos_token_id {
196 return None;
197 }
198 Some(*self.transitions.get(state)?.get(token_id)?)
199 }
200}
201
202impl std::fmt::Display for Index {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 writeln!(f, "Index object with transitions:")?;
205 for (state_id, token_ids) in self.transitions.iter() {
206 writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
207 }
208 Ok(())
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index for a small integer regex and checks its states, transition
    /// table and lookup helpers. The numeric state ids are the ids of the DFA states.
    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let eos_token_id = 4;
        let mut vocabulary = Vocabulary::new(eos_token_id);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let initial_state = index.initial_state();
        assert_eq!(initial_state, 40);
        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
        assert!(!index.is_final_state(&initial_state));

        // "blah" (0) and "1a" (1) can never be part of a match, so they appear nowhere.
        let expected = HashMap::from_iter([
            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
            (48, HashMap::from_iter([(4, 48)])),
            (40, HashMap::from_iter([(3, 48), (2, 56)])),
            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
        ]);
        assert_eq!(index.transitions(), &expected);

        let allowed_tokens = index
            .allowed_tokens(&initial_state)
            .expect("No allowed tokens");
        // `allowed_tokens` comes back in hash-map iteration order, so select the
        // "0" token by id instead of relying on it being the first element.
        let zero_token_id = 3;
        assert!(allowed_tokens.contains(&zero_token_id));

        let state = 48;
        assert_eq!(
            index.next_state(&initial_state, &zero_token_id),
            Some(state)
        );
        assert!(index.is_final_state(&state));

        // From the final state reached by "0": EOS yields no next state, and
        // neither does another "0".
        assert_eq!(index.next_state(&state, &eos_token_id), None);
        assert_eq!(index.next_state(&state, &zero_token_id), None);
    }

    /// The token that opens the pattern must be allowed in the initial state even
    /// though the same token also appears later in the regex.
    #[test]
    fn index_from_regex_initital_in_allowed() {
        let regex = "`\\n(\\.\\n)?`\\n";
        let mut vocabulary = Vocabulary::new(104);
        for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let allowed = index
            .allowed_tokens(&index.initial_state())
            .expect("No allowed tokens");
        assert!(allowed.contains(&101));
    }

    /// Multi-byte (emoji) regex with tokens given both as strings and as raw bytes,
    /// including a token that stops in the middle of a UTF-8 sequence.
    #[test]
    fn index_from_regex_multibyte() {
        let regex = "😇| [😈-😍][😇-😎]*";
        let mut vocabulary = Vocabulary::new(8);
        for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)]
        {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        // Raw-byte tokens: id 7 is a space plus the first three bytes of an emoji,
        // id 6 has the same bytes as " 😍" (id 5), and id 4 the same bytes as "😍" (id 3).
        for (token, token_id) in [
            (vec![32, 240, 159, 152], 7),
            (vec![32, 240, 159, 152, 141], 6),
            (vec![240, 159, 152, 141], 4),
        ] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

        let expected = HashMap::from_iter([
            (
                208,
                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
            ),
            (
                80,
                HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
            ),
            (128, HashMap::from_iter([(8, 128)])),
        ]);
        assert_eq!(index.transitions(), &expected);
    }
}