1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
//! The main module that contains the [`Engine`] struct and its related types.
use std::sync::Arc;

use kbnf_syntax::simplified_grammar::SimplifiedGrammar;
#[cfg(feature = "python")]
use pyo3::pyclass;
use serde::{Deserialize, Serialize};
#[cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;

use crate::{
    config::Config, engine_base::EngineBase, engine_like::EngineLike, grammar::Grammar, utils,
    vocabulary::Vocabulary,
};

/// The specific config of the [`Engine`].
///
/// When the `python` feature is enabled the struct is exported as a Python class with
/// generated getters and setters for every field; when the `wasm` feature is enabled it is
/// exported through `wasm_bindgen`.
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyo3(get_all, set_all))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Copy)]
pub struct EngineConfig {
    /// Whether the cache is enabled. Caching speeds up the engine eventually if any of the following conditions are met:
    /// 1. The grammar is "simple". What exactly constitutes a simple grammar is not well defined at the moment but
    ///    all regular grammars should be simple.
    /// 2. The grammar is reused multiple times for inputs of similar lengths.
    ///
    /// It is enabled by default.
    pub cache_enabled: bool,
    /// Whether the compaction is enabled. Compaction reduces the memory usage of the engine and
    /// speeds up the engine in most cases. In particular, cache usually requires compaction to be effective.
    ///
    /// It is enabled by default.
    pub compaction_enabled: bool,
}
#[derive(Debug, Clone)]
/// An enum that represents the common type combinations of [`EngineBase`].
///
/// Each variant name spells out the five integer types used to instantiate [`EngineBase`].
/// Judging by the range checks in [`Engine::with_config`], the five slots bound, in order:
/// the terminal/nonterminal id, the dotted position, the production id, the expected output
/// length and the state id.
/// NOTE(review): confirm this against the generic parameter list of `EngineBase`.
pub(crate) enum EngineUnion {
    /// Typical simple grammar with complex dfa without any repetition
    ///
    /// Selected first by [`Engine::with_config`] when every bound except the state id fits in `u8`.
    U8U8U8U8U32(EngineBase<u8, u8, u8, u8, u32>),
    /// Typical simple grammar with simple dfa without any repetition
    ///
    /// Selected when ids and dotted positions fit in `u8` and the remaining bounds fit in `u16`.
    U8U8U16U16U16(EngineBase<u8, u8, u16, u16, u16>),
    /// Complex grammar with complex dfa without any repetition
    ///
    /// The widest supported combination: ids and dotted positions fit in `u16`, the rest in `u32`.
    U16U16U32U32U32(EngineBase<u16, u16, u32, u32, u32>),
}
#[cfg_attr(feature = "python", pyclass(subclass))]
#[cfg_attr(feature = "python", pyo3(name = "InternalEngine"))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone)]
/// The main struct that wraps the [`EngineBase`] so that users do not have to specify the generic types every time for common cases.
///
/// With the `python` feature enabled it is exported as the subclassable Python class
/// `InternalEngine`; with the `wasm` feature enabled it is exported through `wasm_bindgen`.
pub struct Engine {
    // The concrete `EngineBase` instantiation chosen by `Engine::with_config`.
    union: EngineUnion,
}
#[derive(Debug, thiserror::Error)]
/// Represents the error type for the [`Engine`] creation.
///
/// The two wrapper variants are produced through their `#[from]` conversions, so the `?`
/// operator in [`Engine::with_config`] can propagate the underlying errors directly.
pub enum CreateEngineError {
    #[error("{0}")] // inherits the error message from the wrapped EngineBaseError
    /// A wrapper for the [`CreateEngineBaseError`](crate::engine_base::CreateEngineBaseError) error type.
    EngineBaseError(#[from] crate::engine_base::CreateEngineBaseError),
    #[error("{0}")] // inherits the error message from the wrapped GrammarError
    /// A wrapper for the [`CreateGrammarError`](crate::grammar::CreateGrammarError) error type.
    GrammarError(#[from] crate::grammar::CreateGrammarError),
    #[error("The grammar after simplification is empty.
    This usually means that the grammar only contains empty terminals and/or self recursions like A::=A;")]
    /// The grammar is empty.
    ///
    /// Returned by [`Engine::with_config`] when the simplified grammar reports that it is empty.
    EmptyGrammarError,
    #[error("The grammar and/or config's value range is not supported by the Engine.\n
    This usually means that the grammar has more than 65536 nonterminals,
    at least one nonterminal has more than 65536 alternations or repetitions, and/or the expected output length is more than 2^32.")]
    /// The grammar and/or config's value range is not supported by the Engine.
    ///
    /// Returned by [`Engine::with_config`] when none of the supported integer-width
    /// combinations is large enough for the grammar and the expected output length.
    InvalidInputError,
}

impl Engine {
    /// Create a new [`Engine`] from an KBNF grammar string and a [`Vocabulary`].
    ///
    /// # Arguments
    ///
    /// * `kbnf_syntax_grammar_str` - The KBNF grammar string.
    ///
    /// * `vocabulary` - The [`Vocabulary`] object.
    ///
    /// # Returns
    ///
    /// * [`Engine`] - The new [`Engine`] object.
    ///
    /// # Errors
    ///
    /// Returns an [`CreateEngineError`] when the grammar is empty or the grammar and/or config's value range is not supported by the Engine.
    pub fn new(
        kbnf_syntax_grammar_str: &str,
        vocabulary: Vocabulary,
    ) -> Result<Engine, CreateEngineError> {
        let config = Config::default();
        Self::with_config(kbnf_syntax_grammar_str, vocabulary, config)
    }

    fn check_id_length(grammar: &SimplifiedGrammar, value: usize) -> bool {
        grammar.interned_strings.terminals.len() <= value
            && grammar.interned_strings.nonterminals.len() <= value
    }
    /// Create a new [`Engine`] from an KBNF grammar string, a [`Vocabulary`], and a [`Config`].
    ///
    /// # Arguments
    ///
    /// * `kbnf_syntax_grammar_str` - The KBNF grammar string.
    /// * `vocabulary` - The [`Vocabulary`] object.
    /// * `config` - The [`Config`] object.
    ///
    /// # Returns
    ///
    /// * [`Engine`] - The new [`Engine`] object.
    ///
    /// # Errors
    ///
    /// Returns an [`CreateEngineError`] when the grammar is empty or the grammar and/or config's value range is not supported by the Engine.
    pub fn with_config(
        kbnf_syntax_grammar_str: &str,
        vocabulary: Vocabulary,
        config: Config,
    ) -> Result<Engine, CreateEngineError> {
        let tsp = config.expected_output_length;
        let regex_config = config.regex_config;
        let internal_config = config.internal_config();
        let grammar =
            utils::construct_kbnf_syntax_grammar(kbnf_syntax_grammar_str, internal_config.clone())?;
        if grammar.is_empty() {
            return Err(CreateEngineError::EmptyGrammarError);
        }
        let td = utils::find_max_dotted_position_from_kbnf_syntax_grammar(&grammar);
        let tp = utils::find_max_production_id_from_kbnf_syntax_grammar(&grammar);
        let ts = utils::find_max_state_id_from_kbnf_syntax_grammar(&grammar);
        let engine = if Self::check_id_length(&grammar, u8::MAX.into())
            && td <= u8::MAX.into()
            && tp <= u8::MAX.into()
            && tsp <= u8::MAX.into()
            && ts <= u32::MAX as usize
        {
            let grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
            let grammar = Arc::new(grammar);
            let vocabulary = Arc::new(vocabulary);
            EngineUnion::U8U8U8U8U32(EngineBase::new(
                vocabulary,
                grammar,
                internal_config.engine_config,
            )?)
        } else if Self::check_id_length(&grammar, u8::MAX.into())
            && td <= u8::MAX.into()
            && tp <= u16::MAX.into()
            && tsp <= u16::MAX.into()
            && ts <= u16::MAX as usize
        {
            let grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
            let grammar = Arc::new(grammar);
            let vocabulary = Arc::new(vocabulary);
            EngineUnion::U8U8U16U16U16(EngineBase::new(
                vocabulary,
                grammar,
                internal_config.engine_config,
            )?)
        } else if Self::check_id_length(&grammar, u16::MAX.into())
            && td <= u16::MAX.into()
            && tp <= u32::MAX as usize
            && tsp <= u32::MAX as usize
            && ts <= u32::MAX as usize
        {
            let grammar: Grammar<u16> = Grammar::new(grammar, &vocabulary, regex_config)?;
            let grammar = Arc::new(grammar);
            let vocabulary = Arc::new(vocabulary);
            EngineUnion::U16U16U32U32U32(EngineBase::new(
                vocabulary,
                grammar,
                internal_config.engine_config,
            )?)
        } else {
            return Err(CreateEngineError::InvalidInputError);
        };
        Ok(Self { union: engine })
    }
}

// Dispatches a call to whichever `EngineBase` instantiation an `EngineUnion` holds.
//
// Invocation shape: `match_engine_union!(callee_path[union_expr, arg1, arg2, ...])`.
// `union_expr` is matched as written, so passing `&u`, `&mut u` or `u` hands the callee a
// shared reference, a mutable reference or the owned inner engine respectively. The extra
// arguments must be plain identifiers and are passed after the inner engine.
macro_rules! match_engine_union {
    ($callee:path[$union:expr $(, $arg:ident)*]) => {
        match $union {
            EngineUnion::U8U8U8U8U32(inner) => $callee(inner $(, $arg)*),
            EngineUnion::U8U8U16U16U16(inner) => $callee(inner $(, $arg)*),
            EngineUnion::U16U16U32U32U32(inner) => $callee(inner $(, $arg)*),
        }
    };
}

// Opts `Engine` into the crate-internal `sealed::Sealed` marker trait.
// NOTE(review): presumably `EngineLike` requires `Sealed` so that only types from this crate
// can implement it — confirm against the trait definition in `engine_like`.
impl crate::engine_like::sealed::Sealed for Engine {}

// `Engine` implements `EngineLike` purely by delegation: every method uses
// `match_engine_union!` to call the same-named `EngineLike` method on whichever
// `EngineBase` instantiation is stored in `self.union`, and returns that result unchanged.
impl EngineLike for Engine {
    /// Forwards `token_id` to the wrapped engine's `try_accept_new_token` (mutable borrow).
    fn try_accept_new_token(
        &mut self,
        token_id: u32,
    ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
        match_engine_union!(EngineLike::try_accept_new_token[&mut self.union, token_id])
    }

    /// Forwards `bytes` to the wrapped engine's `try_accept_new_bytes` (mutable borrow).
    fn try_accept_new_bytes(
        &mut self,
        bytes: &[u8],
    ) -> Result<crate::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
        match_engine_union!(EngineLike::try_accept_new_bytes[&mut self.union, bytes])
    }

    /// Forwards to the wrapped engine's `compute_allowed_token_ids` (mutable borrow).
    fn compute_allowed_token_ids(&mut self) {
        match_engine_union!(EngineLike::compute_allowed_token_ids[&mut self.union])
    }

    /// Forwards `logits` to the wrapped engine's `mask_logits`; the engine itself is only borrowed immutably.
    fn mask_logits(&self, logits: &mut [f32]) -> Result<(), crate::engine_like::MaskLogitsError> {
        match_engine_union!(EngineLike::mask_logits[&self.union, logits])
    }

    /// Forwards `token_id` and `logits` to the wrapped engine's `update_logits` (mutable borrow).
    fn update_logits(
        &mut self,
        token_id: u32,
        logits: &mut [f32],
    ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::UpdateLogitsError> {
        match_engine_union!(EngineLike::update_logits[&mut self.union, token_id, logits])
    }

    /// Returns the bit set borrowed from the wrapped engine's `allowed_token_ids_from_last_computation`.
    fn allowed_token_ids_from_last_computation(&self) -> &fixedbitset_stack::FixedBitSet {
        match_engine_union!(EngineLike::allowed_token_ids_from_last_computation[&self.union])
    }

    /// Forwards `buffer` to the wrapped engine's `write_disallowed_token_ids_to_buffer`.
    fn write_disallowed_token_ids_to_buffer(
        &self,
        buffer: &mut [usize],
    ) -> Result<(), crate::engine_like::WriteBufferError> {
        match_engine_union!(EngineLike::write_disallowed_token_ids_to_buffer[&self.union, buffer])
    }

    /// Returns the wrapped engine's `is_finished` result.
    fn is_finished(&self) -> bool {
        match_engine_union!(EngineLike::is_finished[&self.union])
    }

    /// Forwards to the wrapped engine's `reset` (mutable borrow).
    fn reset(&mut self) {
        match_engine_union!(EngineLike::reset[&mut self.union])
    }

    /// Consumes `self` and boxes the wrapped engine via its own `into_boxed_engine`,
    /// so the returned trait object is produced from the inner `EngineBase`, not from this `Engine`.
    fn into_boxed_engine(self) -> Box<dyn EngineLike> {
        match_engine_union!(EngineLike::into_boxed_engine[self.union])
    }
    /// Returns the `Arc<Vocabulary>` reported by the wrapped engine's `vocab`.
    fn vocab(&self) -> Arc<Vocabulary> {
        match_engine_union!(EngineLike::vocab[&self.union])
    }
}