use serde::{Deserialize, Serialize};
pub mod bytes;
pub mod recognizer;
pub mod rng;
mod svob;
mod tokenv;
mod toktree;
pub use svob::{SimpleVob, SimpleVobIter};
pub use tokenv::{parse_numeric_token, ApproximateTokEnv, TokEnv, TokEnvWithTrie, TokenizerEnv};
pub use toktree::{AnythingGoes, Recognizer, TokRxInfo, TokTrie, TokenId, TrieNode, INVALID_TOKEN};
/// Optional features an inference engine can advertise.
///
/// Every flag is `false` unless explicitly set; when deserializing, any missing
/// field is taken from [`InferenceCapabilities::default`] (all `false`).
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct InferenceCapabilities {
    /// Fast-forward token support (presumably: the engine can append forced tokens
    /// without sampling them — confirm against the engine integration).
    pub ff_tokens: bool,
    /// Conditional fast-forward support (presumably: forced tokens that depend on the
    /// sampled token, cf. `Splice::when_sampled` — confirm).
    pub conditional_ff_tokens: bool,
    /// Whether the engine can remove already-generated tokens (backtracking).
    pub backtrack: bool,
    /// Fork support (semantics defined by the engine integration — TODO document).
    pub fork: bool,
}
/// Describes how one step changes the accumulated token sequence.
///
/// Applying a step (see `StepArg::save_tokens`) first drops the last `backtrack`
/// tokens and then appends `tokens`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepArg {
    /// Number of trailing tokens to remove before appending `tokens`.
    pub backtrack: u32,
    /// Tokens appended after backtracking.
    pub tokens: Vec<TokenId>,
    /// The token that was sampled for this step, if the step came from sampling.
    pub sampled: Option<TokenId>,
}
impl StepArg {
pub fn empty() -> Self {
StepArg {
backtrack: 0,
tokens: vec![],
sampled: None,
}
}
pub fn save_tokens(&self, acc_tokens: &mut Vec<TokenId>) {
let bt = self.backtrack as usize;
assert!(
bt <= acc_tokens.len(),
"attempting to backtrack past beginning"
);
acc_tokens.truncate(acc_tokens.len() - bt);
acc_tokens.extend_from_slice(&self.tokens);
}
pub fn from_splice(s: &Splice, sampled: Option<TokenId>) -> Self {
StepArg {
backtrack: s.backtrack,
tokens: s.ff_tokens.clone(),
sampled,
}
}
pub fn from_sampled_token(tok: TokenId) -> Self {
StepArg {
backtrack: 0,
tokens: vec![tok],
sampled: Some(tok),
}
}
}
/// A token-level edit: optionally remove trailing tokens, then append forced tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Splice {
    /// Sampled tokens that select this splice; an empty list means the splice
    /// matches whatever token was sampled.
    pub when_sampled: Vec<TokenId>,
    /// Number of trailing tokens to remove before appending `ff_tokens`.
    pub backtrack: u32,
    /// Tokens appended after backtracking.
    pub ff_tokens: Vec<TokenId>,
}
impl Splice {
pub fn noop() -> Self {
Splice {
when_sampled: vec![],
backtrack: 0,
ff_tokens: vec![],
}
}
pub fn tokens(ff_tokens: Vec<TokenId>) -> Self {
Splice {
when_sampled: vec![],
backtrack: 0,
ff_tokens,
}
}
}
/// One possible continuation: an optional sampling mask plus splices that are
/// selected according to the sampled token.
///
/// `S` is the mask representation (e.g. [`SimpleVob`] via `StepResult`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Branch<S> {
    /// Set of tokens to sample from; `None` when the branch does not sample
    /// (as built by `Branch::stop` and `Branch::splice`).
    pub sample_mask: Option<S>,
    /// Optional sampling temperature (presumably overrides the engine default — confirm).
    pub temperature: Option<f32>,
    /// Candidate splices; the first one matching the sampled token is used.
    pub splices: Vec<Splice>,
}
/// Field-wise clone that only requires the mask type `S` to be `Clone`.
impl<S> Clone for Branch<S>
where
    S: Clone,
{
    fn clone(&self) -> Self {
        let Branch {
            sample_mask,
            temperature,
            splices,
        } = self;
        Branch {
            sample_mask: sample_mask.clone(),
            temperature: *temperature,
            splices: splices.clone(),
        }
    }
}
impl<S> Branch<S> {
    /// Returns a copy of this branch whose mask is `f` applied to the current mask.
    ///
    /// Temperature and splices are copied unchanged; `f` is called only when a
    /// mask is present.
    pub fn map_mask<F, T>(&self, f: F) -> Branch<T>
    where
        F: FnOnce(&S) -> T,
    {
        let mapped = self.sample_mask.as_ref().map(f);
        Branch {
            sample_mask: mapped,
            temperature: self.temperature,
            splices: self.splices.clone(),
        }
    }

    /// Returns the first splice that applies to `sampled`: either one with an
    /// empty `when_sampled` list, or one whose list contains `sampled`.
    pub fn find_splice(&self, sampled: TokenId) -> Option<&Splice> {
        self.splices.iter().find(|candidate| {
            let unconditional = candidate.when_sampled.is_empty();
            unconditional || candidate.when_sampled.contains(&sampled)
        })
    }

    /// Returns the splice to apply after sampling `sampled`.
    ///
    /// Falls back to a splice that simply appends `sampled` when no splice matches.
    pub fn spliced(&self, sampled: TokenId) -> Splice {
        match self.find_splice(sampled) {
            Some(found) => found.clone(),
            None => Splice {
                when_sampled: vec![],
                backtrack: 0,
                ff_tokens: vec![sampled],
            },
        }
    }

    /// Returns the splice when the branch has exactly one splice and it is
    /// unconditional (empty `when_sampled`); otherwise `None`.
    pub fn unconditional_splice(&self) -> Option<&Splice> {
        match self.splices.as_slice() {
            [only] if only.when_sampled.is_empty() => Some(only),
            _ => None,
        }
    }

    /// Whether any splice backtracks further than this branch tolerates.
    ///
    /// Without a mask no backtracking is tolerated; with a mask a backtrack of 1
    /// is still tolerated (presumably because it only retracts the freshly sampled
    /// token — TODO confirm).
    pub fn has_backtrack(&self) -> bool {
        let tolerated = u32::from(self.sample_mask.is_some());
        self.splices.iter().any(|splice| splice.backtrack > tolerated)
    }

    /// Whether the branch carries any splice at all.
    ///
    /// Note: this is `true` even when every splice has empty `ff_tokens`
    /// (e.g. for `Branch::noop`).
    pub fn has_ff_tokens(&self) -> bool {
        !self.splices.is_empty()
    }

    /// A branch with no mask, no temperature and no splices.
    pub fn stop() -> Self {
        Self {
            sample_mask: None,
            temperature: None,
            splices: Vec::new(),
        }
    }

    /// Whether this branch is a stop branch: no mask and no splices
    /// (temperature is not considered).
    pub fn is_stop(&self) -> bool {
        self.splices.is_empty() && self.sample_mask.is_none()
    }

    /// A non-sampling branch with a single unconditional splice that removes
    /// `backtrack` tokens and then appends `ff_tokens`.
    pub fn splice(backtrack: u32, ff_tokens: Vec<TokenId>) -> Self {
        let only = Splice {
            when_sampled: Vec::new(),
            backtrack,
            ff_tokens,
        };
        Self {
            sample_mask: None,
            temperature: None,
            splices: vec![only],
        }
    }

    /// A non-sampling branch whose single splice changes nothing.
    pub fn noop() -> Self {
        Self::splice(0, Vec::new())
    }

    /// A branch that samples from `set` with the given temperature and has no splices.
    pub fn sample(set: S, temperature: Option<f32>) -> Self {
        Self {
            sample_mask: Some(set),
            temperature,
            splices: Vec::new(),
        }
    }
}
/// A [`Branch`] whose sample mask is a [`SimpleVob`].
pub type StepResult = Branch<SimpleVob>;