outlines_core/index.rs

//! Building an `Index` to efficiently map vocabulary tokens to state transitions.

use bincode::{Decode, Encode};
use regex_automata::dfa::dense::DFA;
use regex_automata::dfa::Automaton;
use regex_automata::util::primitives::StateID as AutomataStateId;
use regex_automata::Anchored;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};

use crate::prelude::*;
use crate::vocabulary::Vocabulary;
use crate::{Error, Result};
13
14/// `Index` efficiently maps vocabulary tokens to state transitions.
15#[derive(Clone, Debug, PartialEq, Encode, Decode)]
16pub struct Index {
17    /// The ID of the initial state in the automaton, processing begins from this state.
18    initial_state: StateId,
19    /// A collection of states considered as terminal states.
20    final_states: HashSet<StateId>,
21    /// A mapping of state transitions, defined by tokens ids and their corresponding state changes.
22    ///
23    /// ### Example
24    /// ```ignore
25    /// transitions = {
26    ///    1: {10: 2, 15: 3},
27    ///    2: {20: 4, 25: 3},
28    ///    3: {30: 4},
29    ///    4: {40: 4},
30    /// }
31    ///  +--------------------------------------+
32    ///  |               State 1                |
33    ///  |            Initial State             |
34    ///  +--------------------------------------+
35    ///              |                     |
36    ///              +                     |
37    ///         Token ID 10                |
38    ///  +-----------------------+         |
39    ///  |        State 2        |         |
40    ///  +-----------------------+         |
41    ///       |             |              |
42    ///       |             +              +
43    ///       |        Token ID 25    Token ID 15
44    ///       |        +------------------------+
45    ///       |        |        State 3         |
46    ///       |        +------------------------+
47    ///       |                            |
48    ///       +                            +
49    ///  Token ID 20                  Token ID 30
50    ///  +--------------------------------------+
51    ///  |               State 4                |
52    ///  |             Final state              |
53    ///  +--------------------------------------+
54    /// ```
55    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
56    /// The token ID reserved for the "end-of-sequence" token.
57    eos_token_id: TokenId,
58    /// The size of the vocabulary used to build the index.
59    vocab_size: usize,
60}

/// The `Index` structure is designed to efficiently map tokens from a given vocabulary
/// to state transitions within a finite-state automaton.
///
/// ## Usage:
/// The `Index` is typically constructed by combining a vocabulary and a regular expression.
/// Once built, it can be used to efficiently evaluate token sequences or to validate input data.
///
/// ## Example:
/// ```rust
/// use outlines_core::prelude::*;
///
/// # fn run() -> Result<(), outlines_core::Error> {
/// let regex = "0|[1-9][0-9]*";
/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)?;
/// let index = Index::new(regex, &vocabulary)?;
///
/// let initial_state = index.initial_state();
/// println!("Initial state is {}", initial_state);
/// println!("Is initial state a final state? {}", index.is_final_state(&initial_state));
///
/// let allowed_tokens = index.allowed_tokens(&initial_state).expect("Some allowed tokens");
/// println!("Allowed tokens at initial state are {:?}", allowed_tokens);
///
/// let token_id = allowed_tokens.first().expect("First token");
/// println!("Next state for the token_id {} is {:?}", token_id, index.next_state(&initial_state, token_id));
///
/// println!("Final states are {:?}", index.final_states());
/// println!("Index has exactly {} transitions", index.transitions().len());
/// # Ok(())
/// # }
/// ```
///
/// ## Performance:
/// - **Complexity**:
///   The `Index` can accommodate large vocabularies and complex regular expressions.
///   However, its size may grow significantly with the complexity of the input.
/// - **Construction Cost**:
///   Building the `Index` involves processing the vocabulary and regular expressions,
///   which may require a considerable amount of time and computational resources.
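///
/// ## Decoding sketch:
/// A typical constrained-generation loop masks the model's logits to the allowed
/// tokens and advances the state with the sampled token. A minimal sketch, assuming
/// an `index` built as in the example above and a hypothetical `sample_from` helper
/// standing in for the model's sampler:
/// ```ignore
/// let eos_token_id = 50256; // GPT-2's EOS token ID, for illustration
/// let mut state = index.initial_state();
/// loop {
///     let allowed = index.allowed_tokens(&state).expect("no allowed tokens");
///     let token_id = *sample_from(&allowed); // hypothetical: pick one allowed token
///     if token_id == eos_token_id {
///         break; // EOS is only allowed once a final (match) state is reached
///     }
///     state = index.next_state(&state, &token_id).expect("transition exists");
/// }
/// ```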
impl Index {
    /// Builds an `Index` from a regular expression and vocabulary tokens.
    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
        let vocab_size = vocabulary.len();
        let eos_token_id = vocabulary.eos_token_id();
        let dfa = DFA::new(regex).map_err(Box::new)?;
        let start_state = match dfa.universal_start_state(Anchored::Yes) {
            Some(s) => s,
            None => return Err(Error::DfaHasNoStartState),
        };

        let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
        let mut final_states: HashSet<StateId> = HashSet::default();

        // Depth-first walk over the DFA states reachable from the start state.
        let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
        let mut next_states: Vec<AutomataStateId> = vec![start_state];

        while let Some(current_state) = next_states.pop() {
            // A state is final if the DFA reports a match once end-of-input is seen from it.
            if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
                final_states.insert(current_state.as_u32());
            }

            // Try to consume each vocabulary token from the current state, byte by byte.
            'token_loop: for (token, ids) in vocabulary.tokens().iter() {
                // Skip tokens sharing an ID with EOS; EOS transitions are added separately below.
                if ids.contains(&eos_token_id) {
                    continue;
                }

                let mut next_state = current_state;
                for transition_byte in token {
                    next_state = dfa.next_state(next_state, *transition_byte);
                    if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
                        // This token cannot be produced from the current state.
                        continue 'token_loop;
                    }
                }

                // Keep the transition if the token ends in a live intermediate state
                // or in a state that matches at end-of-input.
                let is_intermediate_state = !dfa.is_match_state(next_state);
                let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
                if is_intermediate_state || is_full_match_state {
                    for token_id in ids {
                        transitions
                            .entry(current_state.as_u32())
                            .or_default()
                            .insert(*token_id, next_state.as_u32());
                    }
                }
                if !seen.contains(&next_state) {
                    seen.insert(next_state);
                    next_states.push(next_state);
                }
            }
        }

        // Populate `transitions` with mappings from `final_states` to `eos_token_id`
        for &final_state in &final_states {
            transitions
                .entry(final_state)
                .or_default()
                .insert(eos_token_id, final_state);
        }

        Ok(Self {
            initial_state: start_state.as_u32(),
            final_states,
            transitions,
            eos_token_id,
            vocab_size,
        })
    }

    /// Returns the ID of the initial state in the automaton.
    pub fn initial_state(&self) -> StateId {
        self.initial_state
    }

    /// Returns the set of final states.
    pub fn final_states(&self) -> &HashSet<StateId> {
        &self.final_states
    }

    /// Returns the state transitions map: token IDs and their corresponding transition states.
    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
        &self.transitions
    }

    /// Checks whether a state is in the final states set.
    pub fn is_final_state(&self, state: &StateId) -> bool {
        self.final_states.contains(state)
    }

    /// Lists allowed tokens for a given state ID, or `None` if it is not found in the `Index`.
    pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
        self.transitions
            .get(state)
            .map(|res| res.keys().cloned().collect())
    }

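    /// Returns an iterator over the allowed token IDs for a given state ID, or `None`
    /// if the state is not found in the `Index`. Unlike [`Self::allowed_tokens`], this
    /// does not allocate a new `Vec`. A minimal sketch, assuming an `index` built as
    /// in the example above:
    /// ```ignore
    /// if let Some(token_ids) = index.allowed_tokens_iter(&index.initial_state()) {
    ///     for token_id in token_ids {
    ///         println!("Allowed token ID: {token_id}");
    ///     }
    /// }
    /// ```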
    pub fn allowed_tokens_iter(&self, state: &StateId) -> Option<impl Iterator<Item = &TokenId>> {
        self.transitions.get(state).map(|map| map.keys())
    }

    /// Returns the transition state for a given state and token ID, or `None` if no such
    /// transition exists. The EOS token never has a next state.
    pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
        if token_id == &self.eos_token_id {
            return None;
        }
        Some(*self.transitions.get(state)?.get(token_id)?)
    }

    /// Returns the size of the vocabulary used to build the `Index`.
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }
}

impl std::fmt::Display for Index {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Index object with transitions:")?;
        for (state_id, token_ids) in self.transitions.iter() {
            writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let eos_token_id = 4;
        let mut vocabulary = Vocabulary::new(eos_token_id);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let initial_state = index.initial_state();
        assert_eq!(initial_state, 40);
        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
        assert!(!index.is_final_state(&initial_state));

        let expected = HashMap::from_iter([
            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
            (48, HashMap::from_iter([(4, 48)])),
            (40, HashMap::from_iter([(3, 48), (2, 56)])),
            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
        ]);
        assert_eq!(index.transitions(), &expected);

        let allowed_tokens = index
            .allowed_tokens(&initial_state)
            .expect("No allowed tokens");
        let token_id = allowed_tokens.first().expect("No first token");

        let state = 48;
        assert_eq!(index.next_state(&initial_state, token_id), Some(state));
        assert!(index.is_final_state(&state));

        assert_eq!(index.next_state(&state, &eos_token_id), None);
        assert_eq!(index.next_state(&state, token_id), None);
    }

    #[test]
    fn index_from_regex_initial_in_allowed() {
        let regex = "`\\n(\\.\\n)?`\\n";
        let mut vocabulary = Vocabulary::new(104);
        for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let allowed = index
            .allowed_tokens(&index.initial_state())
            .expect("No allowed tokens");
        assert!(allowed.contains(&101));
    }

    #[test]
    fn index_from_regex_multibyte() {
        let regex = "😇| [😈-😍][😇-😎]*";
        let mut vocabulary = Vocabulary::new(8);
        for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)]
        {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        // Raw UTF-8 byte sequences: a partial emoji (with and without a leading
        // space) and the complete bytes of "😍".
        for (token, token_id) in [
            (vec![32, 240, 159, 152], 7),
            (vec![32, 240, 159, 152, 141], 6),
            (vec![240, 159, 152, 141], 4),
        ] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

        let expected = HashMap::from_iter([
            (
                208,
                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
            ),
            (
                80,
                HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
            ),
            (128, HashMap::from_iter([(8, 128)])),
        ]);
        assert_eq!(index.transitions(), &expected);
    }
}