use std::sync::Arc;
use kbnf_syntax::simplified_grammar::SimplifiedGrammar;
#[cfg(feature = "python")]
use pyo3::pyclass;
use serde::{Deserialize, Serialize};
#[cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;
use crate::{
config::Config, engine_base::EngineBase, engine_like::EngineLike, grammar::Grammar, utils,
vocabulary::Vocabulary,
};
/// Engine-level switches.
///
/// With the `python` feature the struct is a `pyclass` whose fields are all readable and
/// writable from Python (`get_all`, `set_all`); with the `wasm` feature it is exported through
/// `wasm_bindgen`. It is `Copy`, hashable and (de)serializable with serde.
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "python", pyo3(get_all, set_all))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Copy)]
pub struct EngineConfig {
/// Cache toggle.
// NOTE(review): what is cached is decided by the engine implementation, which is not visible
// in this file — confirm against `EngineBase` before documenting further.
pub cache_enabled: bool,
/// Compaction toggle.
// NOTE(review): what is compacted is decided by the engine implementation, which is not
// visible in this file — confirm against `EngineBase` before documenting further.
pub compaction_enabled: bool,
}
/// The concrete [`EngineBase`] instantiations an [`Engine`] can wrap.
///
/// `Engine::with_config` picks one variant depending on how large the grammar and the configured
/// expected output length are. Each variant name spells out the five integer widths of its
/// payload.
// NOTE(review): judging by the checks in `Engine::with_config`, the five type parameters appear
// to be, in order: terminal/nonterminal id, dotted position, production id, start position
// (bounded by the expected output length) and state id — confirm against `EngineBase`.
#[derive(Debug, Clone)]
pub(crate) enum EngineUnion {
/// `EngineBase<u8, u8, u8, u8, u32>`; the first candidate tried by `Engine::with_config`.
U8U8U8U8U32(EngineBase<u8, u8, u8, u8, u32>),
/// `EngineBase<u8, u8, u16, u16, u16>`; the second candidate tried.
U8U8U16U16U16(EngineBase<u8, u8, u16, u16, u16>),
/// `EngineBase<u16, u16, u32, u32, u32>`; the last candidate tried.
U16U16U32U32U32(EngineBase<u16, u16, u32, u32, u32>),
}
/// Public engine type: a thin wrapper around one [`EngineUnion`] variant.
///
/// With the `python` feature it is a subclassable `pyclass` exposed under the Python name
/// `InternalEngine`; with the `wasm` feature it is exported through `wasm_bindgen`.
/// All [`EngineLike`] methods are forwarded to the wrapped variant.
#[cfg_attr(feature = "python", pyclass(subclass))]
#[cfg_attr(feature = "python", pyo3(name = "InternalEngine"))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Clone)]
pub struct Engine {
// The concrete engine instantiation selected at construction time.
union: EngineUnion,
}
#[derive(Debug, thiserror::Error)]
pub enum CreateEngineError {
#[error("{0}")] EngineBaseError(#[from] crate::engine_base::CreateEngineBaseError),
#[error("{0}")] GrammarError(#[from] crate::grammar::CreateGrammarError),
#[error("The grammar after simplification is empty.
This usually means that the grammar only contains empty terminals and/or self recursions like A::=A;")]
EmptyGrammarError,
#[error("The grammar and/or config's value range is not supported by the Engine.\n
This usually means that the grammar has more than 65536 nonterminals,
at least one nonterminal has more than 65536 alternations or repetitions, and/or the expected output length is more than 2^32.")]
InvalidInputError,
}
impl Engine {
    /// Builds an [`Engine`] from a KBNF grammar string and a [`Vocabulary`], using
    /// [`Config::default`] for every setting.
    ///
    /// # Errors
    ///
    /// Fails in exactly the situations described for [`Engine::with_config`].
    pub fn new(
        kbnf_syntax_grammar_str: &str,
        vocabulary: Vocabulary,
    ) -> Result<Engine, CreateEngineError> {
        Self::with_config(kbnf_syntax_grammar_str, vocabulary, Config::default())
    }

    /// Reports whether both the terminal table and the nonterminal table of `grammar` contain
    /// at most `value` entries.
    fn check_id_length(grammar: &SimplifiedGrammar, value: usize) -> bool {
        let interned = &grammar.interned_strings;
        if interned.terminals.len() > value {
            return false;
        }
        interned.nonterminals.len() <= value
    }

    /// Builds an [`Engine`] from a KBNF grammar string, a [`Vocabulary`] and an explicit
    /// [`Config`].
    ///
    /// After the grammar is constructed, its terminal/nonterminal counts, maximum dotted
    /// position, maximum production id and maximum state id, together with
    /// `config.expected_output_length`, are compared against the integer widths of each
    /// [`EngineUnion`] variant. The variants are tried in declaration order and the first one
    /// that can hold every value is used.
    ///
    /// # Errors
    ///
    /// * [`CreateEngineError::EmptyGrammarError`] if the constructed grammar is empty.
    /// * [`CreateEngineError::InvalidInputError`] if no supported width combination fits.
    /// * Any error raised while constructing the grammar or the underlying engine is
    ///   propagated.
    pub fn with_config(
        kbnf_syntax_grammar_str: &str,
        vocabulary: Vocabulary,
        config: Config,
    ) -> Result<Engine, CreateEngineError> {
        let expected_len = config.expected_output_length;
        let regex_config = config.regex_config;
        let internal_config = config.internal_config();
        let grammar =
            utils::construct_kbnf_syntax_grammar(kbnf_syntax_grammar_str, internal_config.clone())?;
        if grammar.is_empty() {
            return Err(CreateEngineError::EmptyGrammarError);
        }

        let max_dotted = utils::find_max_dotted_position_from_kbnf_syntax_grammar(&grammar);
        let max_production = utils::find_max_production_id_from_kbnf_syntax_grammar(&grammar);
        let max_state = utils::find_max_state_id_from_kbnf_syntax_grammar(&grammar);

        // Upper bounds of the integer widths an engine variant may use.
        let u8_limit = usize::from(u8::MAX);
        let u16_limit = usize::from(u16::MAX);
        let u32_limit = u32::MAX as usize;

        // All predicates are side-effect free, so they are evaluated up front and named.
        let ids_fit_u8 = Self::check_id_length(&grammar, u8_limit);
        let ids_fit_u16 = Self::check_id_length(&grammar, u16_limit);
        let dotted_fits_u8 = max_dotted <= u8::MAX.into();
        let dotted_fits_u16 = max_dotted <= u16::MAX.into();

        let fits_u8_u8_u8_u8_u32 = ids_fit_u8
            && dotted_fits_u8
            && max_production <= u8_limit
            && expected_len <= u8_limit
            && max_state <= u32_limit;
        let fits_u8_u8_u16_u16_u16 = ids_fit_u8
            && dotted_fits_u8
            && max_production <= u16_limit
            && expected_len <= u16_limit
            && max_state <= u16_limit;
        let fits_u16_u16_u32_u32_u32 = ids_fit_u16
            && dotted_fits_u16
            && max_production <= u32_limit
            && expected_len <= u32_limit
            && max_state <= u32_limit;

        let selected = if fits_u8_u8_u8_u8_u32 {
            let typed_grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
            EngineUnion::U8U8U8U8U32(EngineBase::new(
                Arc::new(vocabulary),
                Arc::new(typed_grammar),
                internal_config.engine_config,
            )?)
        } else if fits_u8_u8_u16_u16_u16 {
            let typed_grammar: Grammar<u8> = Grammar::new(grammar, &vocabulary, regex_config)?;
            EngineUnion::U8U8U16U16U16(EngineBase::new(
                Arc::new(vocabulary),
                Arc::new(typed_grammar),
                internal_config.engine_config,
            )?)
        } else if fits_u16_u16_u32_u32_u32 {
            let typed_grammar: Grammar<u16> = Grammar::new(grammar, &vocabulary, regex_config)?;
            EngineUnion::U16U16U32U32U32(EngineBase::new(
                Arc::new(vocabulary),
                Arc::new(typed_grammar),
                internal_config.engine_config,
            )?)
        } else {
            return Err(CreateEngineError::InvalidInputError);
        };
        Ok(Self { union: selected })
    }
}
/// Dispatches a call to whichever engine instantiation an `EngineUnion` value holds.
///
/// Invocation shape: `match_engine_union!(path[scrutinee, arg1, arg2, ...])`.
/// The scrutinee expression is matched against every `EngineUnion` variant, and `path` is
/// called with the variant's payload as first argument followed by the listed identifiers.
macro_rules! match_engine_union {
    ($callee:path[$scrutinee:expr $(, $arg:ident)*]) => {
        match $scrutinee {
            EngineUnion::U8U8U8U8U32(engine) => $callee(engine $(, $arg)*),
            EngineUnion::U8U8U16U16U16(engine) => $callee(engine $(, $arg)*),
            EngineUnion::U16U16U32U32U32(engine) => $callee(engine $(, $arg)*),
        }
    };
}
// Marker impl of the crate-internal `sealed::Sealed` trait for `Engine`.
// NOTE(review): presumably `EngineLike` has `Sealed` as a supertrait so that it cannot be
// implemented outside this crate — confirm in `engine_like`.
impl crate::engine_like::sealed::Sealed for Engine {}
// `EngineLike` for `Engine`: every method forwards, through `match_engine_union!`, to the
// `EngineBase` instantiation currently stored in `self.union`. Plain `//` comments are used on
// the methods so that the trait's own documentation stays visible in rustdoc.
impl EngineLike for Engine {
    // Forwards `token_id` to the wrapped engine.
    fn try_accept_new_token(
        &mut self,
        token_id: u32,
    ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
        let inner = &mut self.union;
        match_engine_union!(EngineLike::try_accept_new_token[inner, token_id])
    }

    // Forwards the raw byte slice to the wrapped engine.
    fn try_accept_new_bytes(
        &mut self,
        bytes: &[u8],
    ) -> Result<crate::AcceptTokenResult, crate::engine_like::AcceptTokenError> {
        let inner = &mut self.union;
        match_engine_union!(EngineLike::try_accept_new_bytes[inner, bytes])
    }

    // Asks the wrapped engine to recompute its allowed token ids.
    fn compute_allowed_token_ids(&mut self) {
        let inner = &mut self.union;
        match_engine_union!(EngineLike::compute_allowed_token_ids[inner])
    }

    // Lets the wrapped engine mask `logits` in place.
    fn mask_logits(&self, logits: &mut [f32]) -> Result<(), crate::engine_like::MaskLogitsError> {
        let inner = &self.union;
        match_engine_union!(EngineLike::mask_logits[inner, logits])
    }

    // Forwards both the token id and the logits buffer to the wrapped engine.
    fn update_logits(
        &mut self,
        token_id: u32,
        logits: &mut [f32],
    ) -> Result<crate::engine_like::AcceptTokenResult, crate::engine_like::UpdateLogitsError> {
        let inner = &mut self.union;
        match_engine_union!(EngineLike::update_logits[inner, token_id, logits])
    }

    // Borrows the wrapped engine's bit set of allowed token ids.
    fn allowed_token_ids_from_last_computation(&self) -> &fixedbitset_stack::FixedBitSet {
        let inner = &self.union;
        match_engine_union!(EngineLike::allowed_token_ids_from_last_computation[inner])
    }

    // Lets the wrapped engine fill `buffer` with its disallowed token ids.
    fn write_disallowed_token_ids_to_buffer(
        &self,
        buffer: &mut [usize],
    ) -> Result<(), crate::engine_like::WriteBufferError> {
        let inner = &self.union;
        match_engine_union!(EngineLike::write_disallowed_token_ids_to_buffer[inner, buffer])
    }

    // Reports the wrapped engine's finished flag.
    fn is_finished(&self) -> bool {
        let inner = &self.union;
        match_engine_union!(EngineLike::is_finished[inner])
    }

    // Resets the wrapped engine.
    fn reset(&mut self) {
        let inner = &mut self.union;
        match_engine_union!(EngineLike::reset[inner])
    }

    // Consumes `self` and boxes the wrapped engine (not the `Engine` wrapper itself).
    fn into_boxed_engine(self) -> Box<dyn EngineLike> {
        let inner = self.union;
        match_engine_union!(EngineLike::into_boxed_engine[inner])
    }

    // Returns the vocabulary handle held by the wrapped engine.
    fn vocab(&self) -> Arc<Vocabulary> {
        let inner = &self.union;
        match_engine_union!(EngineLike::vocab[inner])
    }
}