kbnf/
engine.rs

1//! The main module that contains the [`Engine`] struct and its related types.
2use std::sync::Arc;
3
4use kbnf_syntax::simplified_grammar::SimplifiedGrammar;
5#[cfg(feature = "python")]
6use pyo3::pyclass;
7use serde::{Deserialize, Serialize};
8#[cfg(feature = "wasm")]
9use wasm_bindgen::prelude::*;
10
11use crate::{
12    config::Config, engine_base::EngineBase, engine_like::EngineLike, grammar::Grammar, utils,
13    vocabulary::Vocabulary,
14};
15
/// The specific config of the [`Engine`].
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyo3(get_all, set_all))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Copy)]
pub struct EngineConfig {
    /// Whether the cache is enabled.
    ///
    /// Caching eventually speeds up the engine if any of the following conditions are met:
    /// 1. The grammar is "simple". What exactly constitutes a simple grammar is not well defined at the moment but
    ///    all regular grammars should be simple.
    /// 2. The grammar is reused multiple times for inputs of similar lengths.
    ///
    /// It is enabled by default.
    pub cache_enabled: bool,
    /// Whether the compaction is enabled.
    ///
    /// Compaction reduces the memory usage of the engine and speeds up the engine in most cases.
    /// In particular, cache usually requires compaction to be effective.
    ///
    /// It is enabled by default.
    pub compaction_enabled: bool,
    /// Whether rejected token prefixes are cached.
    ///
    /// It is enabled by default.
    pub rejected_token_prefix_cache_enabled: bool,
}
#[derive(Debug, Clone)]
/// An enum that represents the common type combinations of [`EngineBase`].
///
/// Each variant name spells out the five integer type parameters of [`EngineBase`] in order.
/// Judging by the range checks in [`Engine::with_config`], they bound, respectively:
/// (1) terminal/nonterminal ids, (2) the maximum dotted position, (3) the maximum production id,
/// (4) the expected output length and (5) the maximum state id.
/// NOTE(review): confirm this mapping against the definition of `EngineBase`.
pub(crate) enum EngineUnion {
    /// Typical simple grammar with complex dfa without any repetition
    U8U8U8U8U32(EngineBase<u8, u8, u8, u8, u32>),
    /// Typical simple grammar with simple dfa without any repetition
    U8U8U16U16U16(EngineBase<u8, u8, u16, u16, u16>),
    /// Complex grammar with complex dfa without any repetition
    U16U16U32U32U32(EngineBase<u16, u16, u32, u32, u32>),
}
#[cfg_attr(feature = "python", pyclass(subclass))]
#[cfg_attr(feature = "python", pyo3(name = "InternalEngine"))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone)]
/// The main struct that wraps the [`EngineBase`] so the user does not have to specify the generic types every time for common cases.
///
/// The concrete integer layout of the wrapped [`EngineBase`] is chosen when the engine is
/// constructed in [`Engine::with_config`].
pub struct Engine {
    /// The wrapped [`EngineBase`], tagged with the integer layout it was instantiated with.
    union: EngineUnion,
}
54#[derive(Debug, thiserror::Error)]
55/// Represents the error type for the [`Engine`] creation.
56pub enum CreateEngineError {
57    #[error("{0}")] // inherits the error message from the wrapped EngineBaseError
58    /// A wrapper for the [`CreateEngineBaseError`](crate::engine_base::CreateEngineBaseError) error type.
59    EngineBaseError(#[from] crate::engine_base::CreateEngineBaseError),
60    #[error("{0}")] // inherits the error message from the wrapped GrammarError
61    /// A wrapper for the [`CreateGrammarError`](crate::grammar::CreateGrammarError) error type.
62    GrammarError(#[from] crate::grammar::CreateGrammarError),
63    #[error("The grammar after simplification is empty.
64    This usually means that the grammar only contains empty terminals and/or self recursions like A::=A;")]
65    /// The grammar is empty.
66    EmptyGrammarError,
67    #[error("The grammar and/or config's value range is not supported by the Engine.\n
68    This usually means that the grammar has more than 65536 nonterminals,
69    at least one nonterminal has more than 65536 alternations or repetitions, and/or the expected output length is more than 2^32.")]
70    /// The grammar and/or config's value range is not supported by the Engine.
71    InvalidInputError,
72}
73
74impl Engine {
75    /// Create a new [`Engine`] from an KBNF grammar string and a [`Vocabulary`].
76    ///
77    /// # Arguments
78    ///
79    /// * `kbnf_syntax_grammar_str` - The KBNF grammar string.
80    ///
81    /// * `vocabulary` - The [`Vocabulary`] object.
82    ///
83    /// # Returns
84    ///
85    /// * [`Engine`] - The new [`Engine`] object.
86    ///
87    /// # Errors
88    ///
89    /// Returns an [`CreateEngineError`] when the grammar is empty or the grammar and/or config's value range is not supported by the Engine.
90    pub fn new(
91        kbnf_syntax_grammar_str: &str,
92        vocabulary: Vocabulary,
93    ) -> Result<Engine, CreateEngineError> {
94        let config = Config::default();
95        Self::with_config(kbnf_syntax_grammar_str, vocabulary, config)
96    }
97
98    fn check_id_length(grammar: &SimplifiedGrammar, value: usize) -> bool {
99        grammar.interned_strings.terminals.len() <= value
100            && grammar.interned_strings.nonterminals.len() <= value
101    }
102    /// Create a new [`Engine`] from an KBNF grammar string, a [`Vocabulary`], and a [`Config`].
103    ///
104    /// # Arguments
105    ///
106    /// * `kbnf_syntax_grammar_str` - The KBNF grammar string.
107    /// * `vocabulary` - The [`Vocabulary`] object.
108    /// * `config` - The [`Config`] object.
109    ///
110    /// # Returns
111    ///
112    /// * [`Engine`] - The new [`Engine`] object.
113    ///
114    /// # Errors
115    ///
116    /// Returns an [`CreateEngineError`] when the grammar is empty or the grammar and/or config's value range is not supported by the Engine.
117    pub fn with_config(
118        kbnf_syntax_grammar_str: &str,
119        vocabulary: Vocabulary,
120        config: Config,
121    ) -> Result<Engine, CreateEngineError> {
122        let tsp = config.expected_output_length;
123        let regex_config = config.regex_config;
124        let internal_config = config.internal_config();
125        let grammar =
126            utils::construct_kbnf_syntax_grammar(kbnf_syntax_grammar_str, internal_config.clone())?;
127        if grammar.is_empty() {
128            return Err(CreateEngineError::EmptyGrammarError);
129        }
130        let td = utils::find_max_dotted_position_from_kbnf_syntax_grammar(&grammar);
131        let tp = utils::find_max_production_id_from_kbnf_syntax_grammar(&grammar);
132        let ts = utils::find_max_state_id_from_kbnf_syntax_grammar(&grammar);
133        let engine = if Self::check_id_length(&grammar, u8::MAX.into())
134            && td <= u8::MAX.into()
135            && tp <= u8::MAX.into()
136            && tsp <= u8::MAX.into()
137            && ts <= u32::MAX as usize
138        {
139            let grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
140            let grammar = Arc::new(grammar);
141            let vocabulary = Arc::new(vocabulary);
142            EngineUnion::U8U8U8U8U32(EngineBase::new(
143                vocabulary,
144                grammar,
145                internal_config.engine_config,
146            )?)
147        } else if Self::check_id_length(&grammar, u8::MAX.into())
148            && td <= u8::MAX.into()
149            && tp <= u16::MAX.into()
150            && tsp <= u16::MAX.into()
151            && ts <= u16::MAX as usize
152        {
153            let grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
154            let grammar = Arc::new(grammar);
155            let vocabulary = Arc::new(vocabulary);
156            EngineUnion::U8U8U16U16U16(EngineBase::new(
157                vocabulary,
158                grammar,
159                internal_config.engine_config,
160            )?)
161        } else if Self::check_id_length(&grammar, u16::MAX.into())
162            && td <= u16::MAX.into()
163            && tp <= u32::MAX as usize
164            && tsp <= u32::MAX as usize
165            && ts <= u32::MAX as usize
166        {
167            let grammar: Grammar<u16> = Grammar::new(grammar, &vocabulary, regex_config)?;
168            let grammar = Arc::new(grammar);
169            let vocabulary = Arc::new(vocabulary);
170            EngineUnion::U16U16U32U32U32(EngineBase::new(
171                vocabulary,
172                grammar,
173                internal_config.engine_config,
174            )?)
175        } else {
176            return Err(CreateEngineError::InvalidInputError);
177        };
178        Ok(Self { union: engine })
179    }
180}
181
/// Forwards a call to the [`EngineBase`] stored in an [`EngineUnion`], whichever variant it is.
///
/// Syntax: `match_engine_union!(path::to::func[union_expr, arg_a, arg_b, ...])`.
/// This expands to a `match` on `union_expr` whose every arm evaluates
/// `path::to::func(inner, arg_a, arg_b, ...)`, where `inner` is the variant's payload.
/// The extra arguments must be plain identifiers.
macro_rules! match_engine_union {
    ($func:path[$target:expr $(, $arg:ident)*]) => {
        match $target {
            EngineUnion::U8U8U8U8U32(inner) => $func(inner $(, $arg)*),
            EngineUnion::U8U8U16U16U16(inner) => $func(inner $(, $arg)*),
            EngineUnion::U16U16U32U32U32(inner) => $func(inner $(, $arg)*),
        }
    };
}
191
// Marker impl of the crate-internal `engine_like::sealed::Sealed` trait.
// NOTE(review): presumably a supertrait of `EngineLike` (sealed-trait pattern) so that only
// this crate can implement `EngineLike`; confirm in `engine_like`.
impl crate::engine_like::sealed::Sealed for Engine {}
193
194impl Engine {
195    pub fn shrink_to_fit(&mut self) {
196        match &mut self.union {
197            EngineUnion::U8U8U8U8U32(engine) => {
198                engine.shrink_to_fit();
199            },
200            EngineUnion::U8U8U16U16U16(engine) => {
201                engine.shrink_to_fit();
202            },
203            EngineUnion::U16U16U32U32U32(engine) => {
204                engine.shrink_to_fit();
205            },
206        }
207    }
208}
209
210impl EngineLike for Engine {
211    fn try_accept_new_token(
212        &mut self,
213        token_id: u32,
214    ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
215        match_engine_union!(EngineLike::try_accept_new_token[&mut self.union, token_id])
216    }
217
218    fn try_accept_new_bytes(
219        &mut self,
220        bytes: &[u8],
221    ) -> Result<crate::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
222        match_engine_union!(EngineLike::try_accept_new_bytes[&mut self.union, bytes])
223    }
224
225    fn compute_allowed_token_ids(&mut self) {
226        match_engine_union!(EngineLike::compute_allowed_token_ids[&mut self.union])
227    }
228
229    fn mask_logits(&self, logits: &mut [f32]) -> Result<(), crate::engine_like::MaskLogitsError> {
230        match_engine_union!(EngineLike::mask_logits[&self.union, logits])
231    }
232
233    fn update_logits(
234        &mut self,
235        token_id: u32,
236        logits: &mut [f32],
237    ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::UpdateLogitsError> {
238        match_engine_union!(EngineLike::update_logits[&mut self.union, token_id, logits])
239    }
240
241    fn allowed_token_ids_from_last_computation(&self) -> &fixedbitset_stack::FixedBitSet {
242        match_engine_union!(EngineLike::allowed_token_ids_from_last_computation[&self.union])
243    }
244
245    fn write_disallowed_token_ids_to_buffer(
246        &self,
247        buffer: &mut [usize],
248    ) -> Result<(), crate::engine_like::WriteBufferError> {
249        match_engine_union!(EngineLike::write_disallowed_token_ids_to_buffer[&self.union, buffer])
250    }
251
252    fn write_allowed_token_ids_to_buffer(
253        &self,
254        buffer: &mut [usize],
255    ) -> Result<(), crate::engine_like::WriteBufferError> {
256        match_engine_union!(EngineLike::write_allowed_token_ids_to_buffer[&self.union, buffer])
257    }
258
259    fn is_finished(&self) -> bool {
260        match_engine_union!(EngineLike::is_finished[&self.union])
261    }
262
263    fn reset(&mut self) {
264        match_engine_union!(EngineLike::reset[&mut self.union])
265    }
266
267    fn into_boxed_engine(self) -> Box<dyn EngineLike> {
268        match_engine_union!(EngineLike::into_boxed_engine[self.union])
269    }
270    fn vocab(&self) -> Arc<Vocabulary> {
271        match_engine_union!(EngineLike::vocab[&self.union])
272    }
273}