use std::sync::Arc;
use displaydoc::Display;
use fixedbitset_stack::FixedBitSet;
#[cfg(feature = "python")]
use pyo3::pyclass;
#[cfg(feature = "wasm")]
use wasm_bindgen::prelude::*;
use crate::vocabulary::Vocabulary;
/// Error returned when an [`EngineLike`] fails to accept a token or a byte sequence.
// `displaydoc::Display` turns each variant's doc comment into its `Display` message,
// so every variant needs one, and it must stay on a single line.
#[cfg_attr(feature = "python", pyclass(eq, eq_int))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AcceptTokenError {
    /// The token ID is unknown to the engine.
    UnknownTokenID,
    /// The input was rejected by the engine.
    Rejected,
    /// The engine is already finished, so the input was not accepted.
    Finished,
}
/// Outcome reported by an [`EngineLike`] after it successfully accepted input.
// `displaydoc::Display` turns each variant's doc comment into its `Display` message,
// so every variant needs one, and it must stay on a single line.
#[cfg_attr(feature = "python", pyclass(eq, eq_int))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AcceptTokenResult {
    /// The input was accepted and the engine is not finished yet.
    Ongoing,
    /// The input was accepted and the engine is now finished.
    Finished,
}
/// Error returned by [`EngineLike::mask_logits`].
// `displaydoc::Display` turns each variant's doc comment into its `Display` message,
// so every variant needs one, and it must stay on a single line.
#[cfg_attr(feature = "python", pyclass(eq, eq_int))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MaskLogitsError {
    /// The length of the logits slice is invalid.
    InvalidLogitsLength,
}
/// Error returned when an [`EngineLike`] writes token IDs into a caller-provided buffer.
// `displaydoc::Display` turns each variant's doc comment into its `Display` message,
// so every variant needs one, and it must stay on a single line.
#[cfg_attr(feature = "python", pyclass(eq, eq_int))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WriteBufferError {
    /// The provided buffer is too small to hold the token IDs.
    BufferTooSmall,
}
/// Error returned by [`EngineLike::update_logits`].
///
/// Its variants are the union of those of [`AcceptTokenError`] and [`MaskLogitsError`].
// `displaydoc::Display` turns each variant's doc comment into its `Display` message,
// so every variant needs one, and it must stay on a single line.
#[cfg_attr(feature = "python", pyclass(eq, eq_int))]
#[cfg_attr(feature = "wasm", wasm_bindgen)]
#[derive(Debug, Display, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpdateLogitsError {
    /// The token ID is unknown to the engine.
    UnknownTokenID,
    /// The token was rejected by the engine.
    Rejected,
    /// The engine is already finished, so the token was not accepted.
    Finished,
    /// The length of the logits slice is invalid.
    InvalidLogitsLength,
}
/// Crate-private module holding the `Sealed` marker trait.
///
/// The module is `pub(crate)`, so code outside this crate cannot name
/// `sealed::Sealed` and therefore cannot implement any trait that lists it as a
/// supertrait (such as `EngineLike`).
pub(crate) mod sealed {
    /// Marker trait with no methods; implementing it opts a type into the sealed traits.
    pub trait Sealed {}
}
/// Common interface shared by the engine types of this crate.
///
/// The trait is sealed: it requires `sealed::Sealed`, which lives in a `pub(crate)`
/// module, so it cannot be implemented outside of this crate.
pub trait EngineLike: sealed::Sealed {
    /// Tries to accept the token identified by `token_id`.
    ///
    /// # Errors
    ///
    /// Returns an [`AcceptTokenError`] when the token is not accepted.
    fn try_accept_new_token(
        &mut self,
        token_id: u32,
    ) -> Result<AcceptTokenResult, AcceptTokenError>;

    /// Tries to accept a raw byte sequence instead of a token ID.
    ///
    /// # Errors
    ///
    /// Returns an [`AcceptTokenError`] when the bytes are not accepted.
    fn try_accept_new_bytes(&mut self, bytes: &[u8])
        -> Result<AcceptTokenResult, AcceptTokenError>;

    /// Computes the allowed token IDs for the engine's current state.
    ///
    /// Nothing is returned; presumably the result is stored in the engine and exposed
    /// through [`EngineLike::allowed_token_ids_from_last_computation`] (confirm against
    /// the implementations).
    fn compute_allowed_token_ids(&mut self);

    /// Masks `logits` in place.
    ///
    /// Presumably the mask is derived from the allowed token IDs of the last
    /// computation; the exact masking value is implementation-defined from this view.
    ///
    /// # Errors
    ///
    /// Returns [`MaskLogitsError::InvalidLogitsLength`], the only variant of
    /// [`MaskLogitsError`].
    fn mask_logits(&self, logits: &mut [f32]) -> Result<(), MaskLogitsError>;

    /// Accepts `token_id` and updates `logits` in place in a single call.
    ///
    /// Judging by the error type, this likely combines
    /// [`EngineLike::try_accept_new_token`], [`EngineLike::compute_allowed_token_ids`]
    /// and [`EngineLike::mask_logits`] (confirm against the implementations).
    ///
    /// # Errors
    ///
    /// Returns an [`UpdateLogitsError`] when the token is not accepted or the logits
    /// slice has an invalid length.
    fn update_logits(
        &mut self,
        token_id: u32,
        logits: &mut [f32],
    ) -> Result<AcceptTokenResult, UpdateLogitsError>;

    /// Returns the set of allowed token IDs produced by the last computation.
    fn allowed_token_ids_from_last_computation(&self) -> &FixedBitSet;

    /// Writes the disallowed token IDs into `buffer`.
    ///
    /// # Errors
    ///
    /// Returns [`WriteBufferError::BufferTooSmall`], the only variant of
    /// [`WriteBufferError`].
    fn write_disallowed_token_ids_to_buffer(
        &self,
        buffer: &mut [usize],
    ) -> Result<(), WriteBufferError>;

    /// Writes the allowed token IDs into `buffer`.
    ///
    /// # Errors
    ///
    /// Returns [`WriteBufferError::BufferTooSmall`], the only variant of
    /// [`WriteBufferError`].
    fn write_allowed_token_ids_to_buffer(
        &self,
        buffer: &mut [usize],
    ) -> Result<(), WriteBufferError>;

    /// Reports whether the engine is finished.
    fn is_finished(&self) -> bool;

    /// Resets the engine; presumably back to its initial state (confirm against the
    /// implementations).
    fn reset(&mut self);

    /// Consumes the engine and returns it as a boxed trait object.
    fn into_boxed_engine(self) -> Box<dyn EngineLike>;

    /// Returns a shared handle to the [`Vocabulary`] used by the engine.
    fn vocab(&self) -> Arc<Vocabulary>;
}