// outlines_core/index.rs

1//! Building an `Index` to efficiently map vocabulary tokens to state transitions.
2
3use bincode::{Decode, Encode};
4use regex_automata::dfa::dense::DFA;
5use regex_automata::dfa::Automaton;
6use regex_automata::util::primitives::StateID as AutomataStateId;
7use regex_automata::Anchored;
8use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
9
10use crate::prelude::*;
11use crate::vocabulary::Vocabulary;
12use crate::{Error, Result};
13
/// `Index` efficiently maps vocabulary tokens to state transitions.
///
/// NOTE(review): field declaration order is load-bearing — the derived bincode
/// `Encode`/`Decode` serialize fields in order, so reordering fields would break
/// compatibility with previously serialized indexes.
#[derive(Clone, Debug, PartialEq, Encode, Decode)]
pub struct Index {
    /// The ID of the initial state in the automaton, processing begins from this state.
    initial_state: StateId,
    /// A collection of states considered as terminal states.
    final_states: HashSet<StateId>,
    /// A mapping of state transitions, defined by token ids and their corresponding state changes.
    ///
    /// Outer key is the source state; the inner map takes a token ID to the
    /// destination state reached by consuming that token's bytes.
    ///
    /// ### Example
    /// ```ignore
    /// transitions = {
    ///    1: {10: 2, 15: 3},
    ///    2: {20: 4, 25: 3},
    ///    3: {30: 4},
    ///    4: {40: 4},
    /// }
    ///  +--------------------------------------+
    ///  |               State 1                |
    ///  |            Initial State             |
    ///  +--------------------------------------+
    ///              |                     |
    ///              +                     |
    ///         Token ID 10                |
    ///  +-----------------------+         |
    ///  |        State 2        |         |
    ///  +-----------------------+         |
    ///       |             |              |
    ///       |             +              +
    ///       |        Token ID 25    Token ID 15
    ///       |        +------------------------+
    ///       |        |        State 3         |
    ///       |        +------------------------+
    ///       |                            |
    ///       +                            +
    ///  Token ID 20                  Token ID 30
    ///  +--------------------------------------+
    ///  |               State 4                |
    ///  |             Final state              |
    ///  +--------------------------------------+
    /// ```
    transitions: HashMap<StateId, HashMap<TokenId, StateId>>,
    /// The token ID reserved for the "end-of-sequence" token.
    eos_token_id: TokenId,
}
59/// The `Index` structure is designed to efficiently map tokens from a given vocabulary
60/// to state transitions within a finite-state automaton.
61///
62/// ## Usage:
63/// The `Index` is typically constructed by combining a vocabulary and regular expressions.
64/// Once built, it can be used to efficiently evaluate token sequences or to validate input data.
65///
66/// ## Example:
67/// ```rust
68/// use outlines_core::prelude::*;
69///
70/// # fn run() -> Result<(), outlines_core::Error> {
71/// let regex = "0|[1-9][0-9]*";
72/// let vocabulary = Vocabulary::from_pretrained("openai-community/gpt2", None)?;
73/// let index = Index::new(regex, &vocabulary)?;
74///
75/// let initial_state = index.initial_state();
76/// println!("Initial state is {}", initial_state);
77/// println!("Is initial state a final state? {}", index.is_final_state(&initial_state));
78///
79/// let allowed_tokens = index.allowed_tokens(&initial_state).expect("Some allowed tokens");
80/// println!("Allowed tokens at initial state are {:?}", allowed_tokens);
81///
82/// let token_id = allowed_tokens.first().expect("First token");
83/// println!("Next state for the token_id {} is {:?}", token_id, index.next_state(&initial_state, token_id));
84///
85/// println!("Final states are {:?}", index.final_states());
86/// println!("Index has exactly {} transitions", index.transitions().len());
87/// # Ok(())
88/// # }
89///
90/// ```
91///
92/// ## Performance:
93/// - **Complexity**:
94///   The `Index` can accommodate large vocabularies and complex regular expressions.
95///   However, its size may grow significantly with the complexity of the input.
96/// - **Construction Cost**:
97///   Building the `Index` involves processing the vocabulary and regular expressions,
98///   which may require a considerable amount of time and computational resources.
99impl Index {
100    /// Builds an `Index` from regular expression and vocabulary tokens.
101    pub fn new(regex: &str, vocabulary: &Vocabulary) -> Result<Self> {
102        let eos_token_id = vocabulary.eos_token_id();
103        let dfa = DFA::new(regex).map_err(Box::new)?;
104        let start_state = match dfa.universal_start_state(Anchored::Yes) {
105            Some(s) => s,
106            None => return Err(Error::DfaHasNoStartState),
107        };
108
109        let mut transitions: HashMap<StateId, HashMap<TokenId, StateId>> = HashMap::default();
110        let mut final_states: HashSet<StateId> = HashSet::default();
111
112        let mut seen: HashSet<AutomataStateId> = HashSet::from_iter([start_state]);
113        let mut next_states: Vec<AutomataStateId> = vec![start_state];
114
115        while let Some(current_state) = next_states.pop() {
116            if dfa.is_match_state(dfa.next_eoi_state(current_state)) {
117                final_states.insert(current_state.as_u32());
118            }
119
120            'token_loop: for (token, ids) in vocabulary.tokens().iter() {
121                if ids.contains(&eos_token_id) {
122                    continue;
123                }
124
125                let mut next_state = current_state;
126                for transition_byte in token {
127                    next_state = dfa.next_state(next_state, *transition_byte);
128                    if dfa.is_dead_state(next_state) || dfa.is_quit_state(next_state) {
129                        continue 'token_loop;
130                    }
131                }
132
133                let is_intermediate_state = !dfa.is_match_state(next_state);
134                let is_full_match_state = dfa.is_match_state(dfa.next_eoi_state(next_state));
135                if is_intermediate_state || is_full_match_state {
136                    for token_id in ids {
137                        transitions
138                            .entry(current_state.as_u32())
139                            .or_default()
140                            .insert(*token_id, next_state.as_u32());
141                    }
142                }
143                if !seen.contains(&next_state) {
144                    seen.insert(next_state);
145                    next_states.push(next_state);
146                }
147            }
148        }
149
150        // Populate `transitions` with mappings from `final_states` to `eos_token_id`
151        for &final_state in &final_states {
152            transitions
153                .entry(final_state)
154                .or_default()
155                .insert(eos_token_id, final_state);
156        }
157
158        Ok(Self {
159            initial_state: start_state.as_u32(),
160            final_states,
161            transitions,
162            eos_token_id,
163        })
164    }
165
166    /// Returns the ID of the initial state in the automaton.
167    pub fn initial_state(&self) -> StateId {
168        self.initial_state
169    }
170
171    /// Returns set of final states.
172    pub fn final_states(&self) -> &HashSet<StateId> {
173        &self.final_states
174    }
175
176    /// Returns state transitions map of tokens ids and their corresponding transition states.
177    pub fn transitions(&self) -> &HashMap<StateId, HashMap<TokenId, StateId>> {
178        &self.transitions
179    }
180
181    /// Checks if state is in final states set or not.
182    pub fn is_final_state(&self, state: &StateId) -> bool {
183        self.final_states.contains(state)
184    }
185
186    /// Lists allowed tokens for a give state ID or `None` if it is not found in `Index`.
187    pub fn allowed_tokens(&self, state: &StateId) -> Option<Vec<TokenId>> {
188        self.transitions
189            .get(state)
190            .map(|res| res.keys().cloned().collect())
191    }
192
193    /// Returns transition state for a given state and token id or `None` otherwise.
194    pub fn next_state(&self, state: &StateId, token_id: &TokenId) -> Option<StateId> {
195        if token_id == &self.eos_token_id {
196            return None;
197        }
198        Some(*self.transitions.get(state)?.get(token_id)?)
199    }
200}
201
202impl std::fmt::Display for Index {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        writeln!(f, "Index object with transitions:")?;
205        for (state_id, token_ids) in self.transitions.iter() {
206            writeln!(f, "{:?} -> {:#?}", state_id, token_ids)?;
207        }
208        Ok(())
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the concrete state IDs asserted below (40, 24, 48, 56, 80,
    // 128, 192, 208) are internal `regex_automata` dense-DFA state identifiers;
    // presumably stable for a given regex_automata version, but they may shift
    // across dependency upgrades — confirm when bumping regex_automata.

    #[test]
    fn index_from_regex() {
        let regex = "0|[1-9][0-9]*";
        let eos_token_id = 4;
        let mut vocabulary = Vocabulary::new(eos_token_id);
        for (token, token_id) in [("blah", 0), ("1a", 1), ("2", 2), ("0", 3)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let initial_state = index.initial_state();
        assert_eq!(initial_state, 40);
        assert_eq!(index.final_states(), &HashSet::from_iter([24, 48, 56]));
        assert!(!index.is_final_state(&initial_state));

        let expected = HashMap::from_iter([
            (24, HashMap::from_iter([(3, 24), (4, 24), (2, 24)])),
            (48, HashMap::from_iter([(4, 48)])),
            (40, HashMap::from_iter([(3, 48), (2, 56)])),
            (56, HashMap::from_iter([(3, 24), (4, 56), (2, 24)])),
        ]);
        assert_eq!(index.transitions(), &expected);

        let allowed_tokens = index
            .allowed_tokens(&initial_state)
            .expect("No allowed tokens");
        let token_id = allowed_tokens.first().expect("No first tokens");

        let state = 48;
        assert_eq!(index.next_state(&initial_state, token_id), Some(state));
        assert!(index.is_final_state(&state));

        // EOS never yields a next state, and "0" is a complete match with no continuation.
        assert_eq!(index.next_state(&state, &eos_token_id), None);
        assert_eq!(index.next_state(&state, token_id), None);
    }

    #[test]
    fn index_from_regex_initial_in_allowed() {
        let regex = "`\\n(\\.\\n)?`\\n";
        let mut vocabulary = Vocabulary::new(104);
        for (token, token_id) in [("\n", 103), (".", 102), ("`", 101)] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        let allowed = index
            .allowed_tokens(&index.initial_state())
            .expect("No allowed tokens");
        assert!(allowed.contains(&101));
    }

    #[test]
    fn index_from_regex_multibyte() {
        let regex = "😇| [😈-😍][😇-😎]*";
        let mut vocabulary = Vocabulary::new(8);
        for (token, token_id) in [(" 😍", 5), ("blah", 0), ("😇", 2), ("😈a", 1), ("😍", 3)]
        {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }
        // Raw byte sequences: partial and complete UTF-8 encodings of emoji,
        // exercising byte-level (not char-level) DFA transitions.
        for (token, token_id) in [
            (vec![32, 240, 159, 152], 7),
            (vec![32, 240, 159, 152, 141], 6),
            (vec![240, 159, 152, 141], 4),
        ] {
            vocabulary
                .try_insert(token, token_id as u32)
                .expect("Insert failed");
        }

        let index = Index::new(regex, &vocabulary).expect("Index failed");
        assert_eq!(index.final_states(), &HashSet::from_iter([208, 128]));

        let expected = HashMap::from_iter([
            (
                208,
                HashMap::from_iter([(3, 208), (8, 208), (4, 208), (2, 208)]),
            ),
            (
                80,
                HashMap::from_iter([(2, 128), (7, 192), (5, 208), (6, 208)]),
            ),
            (128, HashMap::from_iter([(8, 128)])),
        ]);
        assert_eq!(index.transitions(), &expected);
    }
}
306}