outlines_core/index.rs

//! Building an `Index` to efficiently map vocabulary tokens to state transitions.

use bincode::{Decode, Encode};
use regex_automata::dfa::dense::DFA;
use regex_automata::dfa::Automaton;
use regex_automata::util::primitives::StateID as AutomataStateId;
use regex_automata::Anchored;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::prelude::*;
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
/// `Index` efficiently maps vocabulary tokens to state transitions.
#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct Index {
    /// The ID of the initial state in the automaton; processing begins from this state.
    initial_state: StateId,
    /// The set of states considered terminal (final) states.
    final_states: HashSet<StateId>,
    /// A mapping of state transitions, defined by token IDs and their corresponding state changes.
    ///
    /// ### Example
    /// ```ignore
    /// transitions = {
    ///    1: {10: 2, 15: 3},
    ///    2: {20: 4, 25: 3},
    ///    3: {30: 4},
    ///    4: {40: 4},
    /// }
    ///  +--------------------------------------+
    ///  |               State 1                |
    ///  |            Initial State             |
    ///  +--------------------------------------+
    ///              |                     |
    ///              +                     |
    ///         Token ID 10                |
    ///  +-----------------------+         |
    ///  |        State 2        |         |
    ///  +-----------------------+         |
    ///       |             |              |
    ///       |             +              +
    ///       |        Token ID 25    Token ID 15
    ///       |        +------------------------+
    ///       |        |        State 3         |
    ///       |        +------------------------+
    ///       |                            |
    ///       +                            +
    ///  Token ID 20                  Token ID 30
    ///  +--------------------------------------+
    ///  |               State 4                |
    ///  |             Final state              |
    ///  +--------------------------------------+
    /// ```
    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
    /// The token ID reserved for the "end-of-sequence" token.
    eos_token_id: TokenId,
    /// The size of the vocabulary used to build the index.
    vocab_size: usize,
}

/// The `Index` structure is designed to efficiently map tokens from a given vocabulary
/// to state transitions within a finite-state automaton.
///
/// ## Usage:
/// The `Index` is typically constructed from a vocabulary and a regular expression.
/// Once built, it can be used to efficiently evaluate token sequences or to validate input data.
///
/// ## Example:
/// ```rust
/// use outlines_core::prelude::*;
///
/// # fn run() -> Result<(), outlines_core::Error> {
/// let regex = "0|[1-9][0-9]*";
/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)?;
/// let index = Index::new(regex, &vocabulary)?;
///
/// let initial_state = index.initial_state();
/// println!("Initial state is {}", initial_state);
/// println!("Is initial state a final state? {}", index.is_final_state(&initial_state));
///
/// let allowed_tokens = index.allowed_tokens(&initial_state).expect("Some allowed tokens");
/// println!("Allowed tokens at initial state are {:?}", allowed_tokens);
///
/// let token_id = allowed_tokens.first().expect("First token");
/// println!("Next state for the token_id {} is {:?}", token_id, index.next_state(&initial_state, token_id));
///
/// println!("Final states are {:?}", index.final_states());
/// println!("Index has transitions from {} states", index.transitions().len());
/// # Ok(())
/// # }
/// ```
///
/// ## Performance:
/// - **Complexity**:
///   The `Index` can accommodate large vocabularies and complex regular expressions,
///   but its size may grow significantly with the complexity of the input.
/// - **Construction Cost**:
///   Building the `Index` walks the regex automaton once per vocabulary token per
///   reachable state, which can take considerable time and memory for large inputs.
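/// ## Serialization:
/// `Index` derives bincode's `Encode` and `Decode`, so a built index can be stored
/// and reloaded instead of being rebuilt. A minimal sketch, assuming bincode 2's
/// standard configuration (the configuration choice is not part of this crate's API):
/// ```ignore
/// // Hypothetical round-trip; any bincode 2 configuration works the same way.
/// let config = bincode::config::standard();
/// let bytes = bincode::encode_to_vec(&index, config)?;
/// let (restored, _read): (Index, usize) = bincode::decode_from_slice(&bytes, config)?;
/// assert_eq!(index, restored);
/// ```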
impl Index {
    /// Builds an `Index` from a regular expression and vocabulary tokens.
    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
        let vocab_size = vocabulary.len();
        let eos_token_id = vocabulary.eos_token_id();
        let dfa = DFA::new(regex).map_err(Box::new)?;
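        // The walk below requires a single anchored start state that does not
        // depend on the surrounding input; bail out if the DFA has none.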
        let start_state = match dfa.universal_start_state(Anchored::Yes) {
            Some(s) => s,
            None => return Err(Error::DfaHasNoStartState),
        };

        let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
        let mut final_states: HashSet<StateId> = HashSet::default();

        let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
        let mut next_states: Vec<AutomataStateId> = vec![start_state];

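        // Depth-first walk over every DFA state reachable from the start state,
        // using `next_states` as a stack and `seen` to avoid revisiting states.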
        while let Some(current_state) = next_states.pop() {
            let mut has_valid_transitions = false;

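            // The current state is final if the DFA reports a match once
            // end-of-input is fed from it.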
            if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
                final_states.insert(current_state.as_u32());
                has_valid_transitions = true;
            }

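            // Try to advance the DFA with every vocabulary token; the EOS token
            // is skipped here and handled after the walk.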
            'token_loop: for (token, ids) in vocabulary.tokens().iter() {
                if ids.contains(&eos_token_id) {
                    continue;
                }

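                // Feed the token's bytes through the DFA one at a time, abandoning
                // the token as soon as it hits a dead or quit state.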
                let mut next_state = current_state;
                for transition_byte in token {
                    next_state = dfa.next_state(next_state, *transition_byte);
                    if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                        continue 'token_loop;
                    }
                }

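                // Keep the transition only if the token leaves the DFA in a live
                // position: either the match is still in progress (intermediate) or
                // the bytes consumed so far form a complete match (full match).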
                let is_intermediate_state = !dfa.is_match_state(next_state);
                let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
                if is_intermediate_state || is_full_match_state {
                    has_valid_transitions = true;
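                    // Several token IDs may share the same byte sequence; they all
                    // receive the same transition.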
                    for token_id in ids {
                        transitions
                            .entry(current_state.as_u32())
                            .or_default()
                            .insert(*token_id, next_state.as_u32());
                    }
                }
                if !seen.contains(&next_state) {
                    seen.insert(next_state);
                    next_states.push(next_state);
                }
            }

            // If the current state has no valid transitions and is not a match state,
            // it means the vocabulary is incompatible with the regex.
            if !has_valid_transitions && !dfa.is_match_state(current_state) {
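                // Collect the bytes the DFA would still accept from this state so
                // the error can report what the vocabulary is missing.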
                let mut valid_characters = Vec::new();
                for byte in 0..=255u8 {
                    let test_state = dfa.next_state(current_state, byte);
                    if !dfa.is_dead_state(test_state) && !dfa.is_quit_state(test_state) {
                        if byte.is_ascii() {
                            valid_characters.push(char::from(byte).to_string());
                        } else {
                            valid_characters.push(format!("\\x{:02x}", byte));
                        }
                    }
                }

                return Err(Error::IncompatibleVocabulary {
                    regex: regex.to_string(),
                    error_state: current_state.as_u32(),
                    missing_tokens: valid_characters,
                });
            }
        }

        // Mark final states as accepting EOS: each final state maps `eos_token_id`
        // back to itself.
        for &final_state in &final_states {
            transitions
                .entry(final_state)
                .or_default()
                .insert(eos_token_id, final_state);
        }

        Ok(Self {
            initial_state: start_state.as_u32(),
            final_states,
            transitions,
            eos_token_id,
            vocab_size,
        })
    }

    /// Returns the ID of the initial state in the automaton.
    pub fn initial_state(&self) -> StateId {
        self.initial_state
    }

    /// Returns the set of final states.
    pub fn final_states(&self) -> &HashSet<StateId> {
        &self.final_states
    }

    /// Returns the state transition map of token IDs to their corresponding next states.
    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
        &self.transitions
    }

    /// Checks whether a state belongs to the set of final states.
    pub fn is_final_state(&self, state: &StateId) -> bool {
        self.final_states.contains(state)
    }

    /// Lists the allowed tokens for a given state ID, or `None` if the state is not found in the `Index`.
    pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
        self.transitions
            .get(state)
            .map(|res| res.keys().cloned().collect())
    }

    /// Returns an iterator over the allowed token IDs for a given state ID, or `None` if the state is not found in the `Index`.
    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
        self.transitions.get(state).map(|map| map.keys())
    }

    /// Returns the next state for a given state and token ID, or `None` if there is
    /// no such transition. The EOS token ID always yields `None`, ending generation.
    pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
        if token_id == &self.eos_token_id {
            return None;
        }
        Some(*self.transitions.get(state)?.get(token_id)?)
    }

    /// Returns the size of the vocabulary used to build the `Index`.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }
}

impl std::fmt::Display for Index {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Index object with transitions:")?;
        for (state_id, token_ids) in self.transitions.iter() {
            writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let eos_token_id = 4;
        let mut vocabulary = Vocabulary::new(eos_token_id);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let initial_state = index.initial_state();
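        // Concrete state IDs are internal to regex-automata's dense DFA
        // representation; these assertions pin the current layout.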
        assert_eq!(initial_state, 40);
        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
        assert!(!index.is_final_state(&initial_state));

        let expected = HashMap::from_iter([
            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
            (48, HashMap::from_iter([(4, 48)])),
            (40, HashMap::from_iter([(3, 48), (2, 56)])),
            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
        ]);
        assert_eq!(index.transitions(), &expected);

        let allowed_tokens = index
            .allowed_tokens(&initial_state)
            .expect("No allowed tokens");
        let token_id = allowed_tokens.first().expect("No first tokens");

        let state = 48;
        assert_eq!(index.next_state(&initial_state, token_id), Some(state));
        assert!(index.is_final_state(&state));

        assert_eq!(index.next_state(&state, &eos_token_id), None);
        assert_eq!(index.next_state(&state, token_id), None);
    }

    #[test]
    fn index_from_regex_initial_in_allowed() {
        let regex = "`\\n(\\.\\n)?`\\n";
        let mut vocabulary = Vocabulary::new(104);
        for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let allowed = index
            .allowed_tokens(&index.initial_state())
            .expect("No allowed tokens");
        assert!(allowed.contains(&101));
    }

    #[test]
    fn index_from_regex_multibyte() {
        let regex = "😇| [😈-😍][😇-😎]*";
        let mut vocabulary = Vocabulary::new(8);
        for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
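        // The same emoji tokens as raw UTF-8 byte sequences: " 😈" (id 7),
        // " 😍" (id 6) and "😍" (id 4).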
        for (token, token_id) in [
            (vec![32, 240, 159, 152, 136], 7),
            (vec![32, 240, 159, 152, 141], 6),
            (vec![240, 159, 152, 141], 4),
        ] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

        let expected = HashMap::from_iter([
            (
                208,
                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
            ),
            (
                80,
                HashMap::from_iter([(2, 128), (7, 208), (5, 208), (6, 208)]),
            ),
            (128, HashMap::from_iter([(8, 128)])),
        ]);
        assert_eq!(index.transitions(), &expected);
    }

    #[test]
    fn index_incompatible_vocabulary_error() {
        let regex = "0 1";
        let mut vocabulary = Vocabulary::new(3);
        for (token, token_id) in [("0", 0), ("0 ", 1), ("1", 2)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let result = Index::new(regex, &vocabulary);
        assert!(result.is_err());

        if let Err(Error::IncompatibleVocabulary {
            regex: _,
            missing_tokens,
            ..
        }) = result
        {
            assert!(missing_tokens.contains(&" ".to_string()));
        } else {
            panic!("Expected IncompatibleVocabulary error");
        }
    }

    #[test]
    fn index_incompatible_vocabulary_error_non_ascii() {
        let regex = "😈😍";
        let mut vocabulary = Vocabulary::new(3);
        for (token, token_id) in [("😈", 0), (" ", 1), ("b", 2)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let result = Index::new(regex, &vocabulary);
        assert!(result.is_err());

        if let Err(Error::IncompatibleVocabulary {
            regex: _,
            missing_tokens,
            ..
        }) = result
        {
            assert!(missing_tokens.contains(&"\\xf0".to_string()));
        } else {
            panic!("Expected IncompatibleVocabulary error");
        }
    }
}