1use std::sync::Arc;
3
4use kbnf_syntax::simplified_grammar::SimplifiedGrammar;
5#[cfg(feature = "python")]
6use pyo3::pyclass;
7use serde::{Deserialize, Serialize};
8#[cfg(feature = "wasm")]
9use wasm_bindgen::prelude::*;
10
11use crate::{
12 config::Config, engine_base::EngineBase, engine_like::EngineLike, grammar::Grammar, utils,
13 vocabulary::Vocabulary,
14};
15
/// Boolean switches for the engine.
///
/// The type is `Copy`, serde-(de)serializable and hashable. When the `python` feature is
/// enabled it is a pyclass with getters/setters on every field; when the `wasm` feature is
/// enabled it is exported through `wasm_bindgen`.
///
/// NOTE(review): presumably this is the type of `internal_config.engine_config` that
/// `Engine::with_config` hands to `EngineBase::new` — confirm. The per-field notes below are
/// inferred from the field names only; verify them against `EngineBase`.
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyo3(get_all, set_all))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Copy)]
pub struct EngineConfig {
    /// Presumably toggles the engine's general-purpose cache — TODO confirm.
    pub cache_enabled: bool,
    /// Presumably toggles compaction of the engine's internal state — TODO confirm.
    pub compaction_enabled: bool,
    /// Presumably toggles a cache of rejected token prefixes — TODO confirm.
    pub rejected_token_prefix_cache_enabled: bool,
}
/// The concrete `EngineBase` instantiation stored inside an `Engine`.
///
/// Each variant fixes the five integer type parameters of `EngineBase`; the variant name spells
/// them out. `Engine::with_config` tries the variants in declaration order and keeps the first
/// one whose limits accommodate the simplified grammar and the configuration.
#[derive(Debug, Clone)]
pub(crate) enum EngineUnion {
    /// Chosen when the terminal and nonterminal counts, the max dotted position, the max
    /// production id and `expected_output_length` are all <= `u8::MAX`, and the max state id
    /// is <= `u32::MAX`. Built on a `Grammar<u8>`.
    U8U8U8U8U32(EngineBase<u8, u8, u8, u8, u32>),
    /// Chosen when the terminal and nonterminal counts and the max dotted position are
    /// <= `u8::MAX`, and the max production id, `expected_output_length` and the max state id
    /// are <= `u16::MAX`. Built on a `Grammar<u8>`.
    U8U8U16U16U16(EngineBase<u8, u8, u16, u16, u16>),
    /// Chosen when the terminal and nonterminal counts and the max dotted position are
    /// <= `u16::MAX`, and the max production id, `expected_output_length` and the max state id
    /// are <= `u32::MAX`. Built on a `Grammar<u16>`.
    U16U16U32U32U32(EngineBase<u16, u16, u32, u32, u32>),
}
/// Engine built from a KBNF grammar string and a [`Vocabulary`].
///
/// It stores exactly one `EngineUnion` variant, chosen by [`Engine::with_config`]. Every
/// [`EngineLike`] method is forwarded to that variant. With the `python` feature it is exposed
/// as the subclassable pyclass `InternalEngine`; with the `wasm` feature it is exported through
/// `wasm_bindgen`.
#[cfg_attr(feature = "python", pyclass(subclass))]
#[cfg_attr(feature = "python", pyo3(name = "InternalEngine"))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone)]
pub struct Engine {
    // The width-specialised engine that all calls are forwarded to.
    union: EngineUnion,
}
54#[derive(Debug, thiserror::Error)]
55pub enum CreateEngineError {
57 #[error("{0}")] EngineBaseError(#[from] crate::engine_base::CreateEngineBaseError),
60 #[error("{0}")] GrammarError(#[from] crate::grammar::CreateGrammarError),
63 #[error("The grammar after simplification is empty.
64 This usually means that the grammar only contains empty terminals and/or self recursions like A::=A;")]
65 EmptyGrammarError,
67 #[error("The grammar and/or config's value range is not supported by the Engine.\n
68 This usually means that the grammar has more than 65536 nonterminals,
69 at least one nonterminal has more than 65536 alternations or repetitions, and/or the expected output length is more than 2^32.")]
70 InvalidInputError,
72}
73
74impl Engine {
75 pub fn new(
91 kbnf_syntax_grammar_str: &str,
92 vocabulary: Vocabulary,
93 ) -> Result<Engine, CreateEngineError> {
94 let config = Config::default();
95 Self::with_config(kbnf_syntax_grammar_str, vocabulary, config)
96 }
97
98 fn check_id_length(grammar: &SimplifiedGrammar, value: usize) -> bool {
99 grammar.interned_strings.terminals.len() <= value
100 && grammar.interned_strings.nonterminals.len() <= value
101 }
102 pub fn with_config(
118 kbnf_syntax_grammar_str: &str,
119 vocabulary: Vocabulary,
120 config: Config,
121 ) -> Result<Engine, CreateEngineError> {
122 let tsp = config.expected_output_length;
123 let regex_config = config.regex_config;
124 let internal_config = config.internal_config();
125 let grammar =
126 utils::construct_kbnf_syntax_grammar(kbnf_syntax_grammar_str, internal_config.clone())?;
127 if grammar.is_empty() {
128 return Err(CreateEngineError::EmptyGrammarError);
129 }
130 let td = utils::find_max_dotted_position_from_kbnf_syntax_grammar(&grammar);
131 let tp = utils::find_max_production_id_from_kbnf_syntax_grammar(&grammar);
132 let ts = utils::find_max_state_id_from_kbnf_syntax_grammar(&grammar);
133 let engine = if Self::check_id_length(&grammar, u8::MAX.into())
134 && td <= u8::MAX.into()
135 && tp <= u8::MAX.into()
136 && tsp <= u8::MAX.into()
137 && ts <= u32::MAX as usize
138 {
139 let grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
140 let grammar = Arc::new(grammar);
141 let vocabulary = Arc::new(vocabulary);
142 EngineUnion::U8U8U8U8U32(EngineBase::new(
143 vocabulary,
144 grammar,
145 internal_config.engine_config,
146 )?)
147 } else if Self::check_id_length(&grammar, u8::MAX.into())
148 && td <= u8::MAX.into()
149 && tp <= u16::MAX.into()
150 && tsp <= u16::MAX.into()
151 && ts <= u16::MAX as usize
152 {
153 let grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
154 let grammar = Arc::new(grammar);
155 let vocabulary = Arc::new(vocabulary);
156 EngineUnion::U8U8U16U16U16(EngineBase::new(
157 vocabulary,
158 grammar,
159 internal_config.engine_config,
160 )?)
161 } else if Self::check_id_length(&grammar, u16::MAX.into())
162 && td <= u16::MAX.into()
163 && tp <= u32::MAX as usize
164 && tsp <= u32::MAX as usize
165 && ts <= u32::MAX as usize
166 {
167 let grammar: Grammar<u16> = Grammar::new(grammar, &vocabulary, regex_config)?;
168 let grammar = Arc::new(grammar);
169 let vocabulary = Arc::new(vocabulary);
170 EngineUnion::U16U16U32U32U32(EngineBase::new(
171 vocabulary,
172 grammar,
173 internal_config.engine_config,
174 )?)
175 } else {
176 return Err(CreateEngineError::InvalidInputError);
177 };
178 Ok(Self { union: engine })
179 }
180}
181
/// Dispatches a call to whichever `EngineBase` instantiation an `EngineUnion` holds.
///
/// `match_engine_union!(func[target, arg1, arg2])` expands to a `match` on `target`. In every
/// arm it evaluates `func(inner, arg1, arg2)`, where `inner` is bound to that variant's payload.
/// Extra arguments must be plain identifiers.
macro_rules! match_engine_union {
    ($func:path [$target:expr $(, $arg:ident)*]) => {
        match $target {
            EngineUnion::U8U8U8U8U32(inner) => $func(inner $(, $arg)*),
            EngineUnion::U8U8U16U16U16(inner) => $func(inner $(, $arg)*),
            EngineUnion::U16U16U32U32U32(inner) => $func(inner $(, $arg)*),
        }
    };
}
191
// Marker impl of the crate-private `sealed::Sealed` trait for `Engine`.
// NOTE(review): presumably `EngineLike` uses `Sealed` as a supertrait, so that only this crate
// can implement `EngineLike`; the trait definitions are not visible here — confirm.
impl crate::engine_like::sealed::Sealed for Engine {}
193
194impl Engine {
195 pub fn shrink_to_fit(&mut self) {
196 match &mut self.union {
197 EngineUnion::U8U8U8U8U32(engine) => {
198 engine.shrink_to_fit();
199 },
200 EngineUnion::U8U8U16U16U16(engine) => {
201 engine.shrink_to_fit();
202 },
203 EngineUnion::U16U16U32U32U32(engine) => {
204 engine.shrink_to_fit();
205 },
206 }
207 }
208}
209
210impl EngineLike for Engine {
211 fn try_accept_new_token(
212 &mut self,
213 token_id: u32,
214 ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
215 match_engine_union!(EngineLike::try_accept_new_token[&mut self.union, token_id])
216 }
217
218 fn try_accept_new_bytes(
219 &mut self,
220 bytes: &[u8],
221 ) -> Result<crate::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
222 match_engine_union!(EngineLike::try_accept_new_bytes[&mut self.union, bytes])
223 }
224
225 fn compute_allowed_token_ids(&mut self) {
226 match_engine_union!(EngineLike::compute_allowed_token_ids[&mut self.union])
227 }
228
229 fn mask_logits(&self, logits: &mut [f32]) -> Result<(), crate::engine_like::MaskLogitsError> {
230 match_engine_union!(EngineLike::mask_logits[&self.union, logits])
231 }
232
233 fn update_logits(
234 &mut self,
235 token_id: u32,
236 logits: &mut [f32],
237 ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::UpdateLogitsError> {
238 match_engine_union!(EngineLike::update_logits[&mut self.union, token_id, logits])
239 }
240
241 fn allowed_token_ids_from_last_computation(&self) -> &fixedbitset_stack::FixedBitSet {
242 match_engine_union!(EngineLike::allowed_token_ids_from_last_computation[&self.union])
243 }
244
245 fn write_disallowed_token_ids_to_buffer(
246 &self,
247 buffer: &mut [usize],
248 ) -> Result<(), crate::engine_like::WriteBufferError> {
249 match_engine_union!(EngineLike::write_disallowed_token_ids_to_buffer[&self.union, buffer])
250 }
251
252 fn write_allowed_token_ids_to_buffer(
253 &self,
254 buffer: &mut [usize],
255 ) -> Result<(), crate::engine_like::WriteBufferError> {
256 match_engine_union!(EngineLike::write_allowed_token_ids_to_buffer[&self.union, buffer])
257 }
258
259 fn is_finished(&self) -> bool {
260 match_engine_union!(EngineLike::is_finished[&self.union])
261 }
262
263 fn reset(&mut self) {
264 match_engine_union!(EngineLike::reset[&mut self.union])
265 }
266
267 fn into_boxed_engine(self) -> Box<dyn EngineLike> {
268 match_engine_union!(EngineLike::into_boxed_engine[self.union])
269 }
270 fn vocab(&self) -> Arc<Vocabulary> {
271 match_engine_union!(EngineLike::vocab[&self.union])
272 }
273}