use anyhow::{bail, ensure, Result};
use toktrie::{StepResult, TokenId};
use crate::{
api::StopReason,
loginfo,
output::{ParserOutput, Reporter},
panic_utils, TokenParser,
};
/// Wraps a `TokenParser` together with the state kept between calls: the most
/// recently stored `StepResult`, a progress `Reporter`, and a few flags.
#[derive(Clone)]
pub struct Constraint {
/// The underlying parser; every operation on the constraint is forwarded to it.
pub parser: TokenParser,
/// When `true`, each stored step result also writes `JSON-OUT: ` lines with
/// the JSON-serialized progress entries to the parser's logger.
pub log_json_progress: bool,
/// Last temperature reported by the parser. Initialized to `0.0` by `new()`
/// and left untouched whenever the parser reports no temperature.
pub temperature: f32,
/// Source of the progress entries returned by `flush_progress()`.
reporter: Reporter,
/// Result stored by the latest `compute_mask()` / `commit_token()` step.
last_res: StepResult,
/// Set once the parser has been started, with or without a prompt.
started: bool,
/// Set when the parser's `check_stop()` has returned `true`.
pending_stop: bool,
}
/// Summary of a `StepResult` handed back to callers of `commit_token()`.
#[derive(Debug, Clone, Default)]
pub struct CommitResult {
/// `true` when built by `CommitResult::stop()` or when the source
/// `StepResult` reports `is_stop()`.
pub stop: bool,
/// Backtrack value copied from the step's unconditional splice; `0` when
/// there is no such splice.
pub backtrack: u32,
/// Tokens copied from the step's unconditional splice; empty when there is
/// no such splice.
pub ff_tokens: Vec<TokenId>,
}
impl CommitResult {
    /// Builds a result that only signals a stop: no backtracking and no
    /// forced tokens.
    pub fn stop() -> Self {
        // The derived `Default` supplies `backtrack: 0` and an empty `ff_tokens`.
        Self {
            stop: true,
            ..Self::default()
        }
    }

    /// Summarizes a `StepResult`.
    ///
    /// `stop` mirrors `res.is_stop()`. When the step carries an unconditional
    /// splice, its `backtrack` and a clone of its `ff_tokens` are copied over;
    /// otherwise they are `0` and empty.
    pub fn from_step_result(res: &StepResult) -> Self {
        let stop = res.is_stop();
        let (backtrack, ff_tokens) = match res.unconditional_splice() {
            Some(splice) => (splice.backtrack, splice.ff_tokens.clone()),
            None => (0, Vec::new()),
        };
        CommitResult {
            stop,
            backtrack,
            ff_tokens,
        }
    }
}
impl Constraint {
    /// Creates a constraint around a parser that has not been used yet.
    ///
    /// # Panics
    /// Panics with "Parser was already used" when `parser.is_fresh()` is false.
    pub fn new(parser: TokenParser) -> Self {
        assert!(parser.is_fresh(), "Parser was already used");
        Self {
            parser,
            reporter: Reporter::default(),
            last_res: StepResult::noop(),
            started: false,
            log_json_progress: false,
            temperature: 0.0,
            pending_stop: false,
        }
    }

    /// Clones the constraint and replaces the cloned parser with the result
    /// of `TokenParser::deep_clone()`.
    pub fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.parser = self.parser.deep_clone();
        copy
    }

    /// Stores `res` as the last result, optionally logs the progress entries
    /// as `JSON-OUT: ` lines, and refreshes the cached temperature.
    fn save_progress_and_result(&mut self, res: StepResult) {
        self.last_res = res;
        if self.log_json_progress {
            for p in self.reporter.get_progress(&mut self.parser, &self.last_res) {
                self.parser.logger.write_buffer("JSON-OUT: ");
                self.parser
                    .logger
                    .write_buffer(&serde_json::to_string(&p).unwrap());
                self.parser.logger.write_buffer("\n");
            }
        }
        self.save_temperature();
    }

    /// Copies the parser's temperature into `self.temperature`; the previous
    /// value is kept when the parser reports `None`.
    fn save_temperature(&mut self) {
        if let Some(temp) = self.parser.parser.temperature() {
            self.temperature = temp;
        }
    }

    /// Starts the constraint with a prompt and returns the prompt to use.
    ///
    /// When the token environment reports canonical tokenization, the prompt
    /// is handed to the parser and the parser's returned tokens are passed
    /// back. Otherwise the parser is started without a prompt and `prompt` is
    /// returned unchanged.
    ///
    /// # Panics
    /// Panics if the constraint has already been started.
    pub fn process_prompt(&mut self, prompt: Vec<TokenId>) -> Vec<TokenId> {
        assert!(!self.started);
        self.started = true;
        let r = if self.parser.token_env.tokenize_is_canonical() {
            self.parser.process_prompt(prompt)
        } else {
            self.parser.start_without_prompt();
            prompt
        };
        self.save_temperature();
        r
    }

    /// Starts the constraint without any prompt.
    ///
    /// # Panics
    /// Panics if the constraint has already been started.
    pub fn start_without_prompt(&mut self) {
        assert!(!self.started);
        self.started = true;
        self.parser.start_without_prompt();
        self.save_temperature();
    }

    /// Feeds `tokens` to the parser one by one, in order.
    ///
    /// The value returned by each `consume_token()` call is discarded.
    ///
    /// # Errors
    /// Returns the first error reported by `consume_token()`; later tokens
    /// are not consumed.
    pub fn force_tokens(&mut self, tokens: &[TokenId]) -> Result<()> {
        for &t in tokens {
            self.parser.consume_token(t)?;
        }
        Ok(())
    }

    /// Returns whether the parser's `check_stop()` has reported a stop.
    pub fn has_pending_stop(&self) -> bool {
        self.pending_stop
    }

    /// Computes the next step and returns a reference to the stored result.
    ///
    /// The work runs inside `panic_utils::catch_unwind`.
    /// NOTE(review): presumably this turns panics into errors; confirm
    /// against `panic_utils`.
    ///
    /// # Errors
    /// Fails when called after a stop result was stored, or when the parser
    /// reports an error.
    pub fn compute_mask(&mut self) -> Result<&StepResult> {
        panic_utils::catch_unwind(std::panic::AssertUnwindSafe(|| self.compute_mask_inner()))
            .map(|_| &self.last_res)
    }

    /// Body of `compute_mask()`: stores either a stop or a sample result.
    fn compute_mask_inner(&mut self) -> Result<()> {
        loginfo!(self.parser.logger, "\ncompute_mask()");
        // Lazily start the parser if neither start method was called.
        if !self.started {
            self.started = true;
            self.parser.start_without_prompt();
            self.save_temperature();
        }
        ensure!(!self.last_res.is_stop(), "compute_mask() called after stop");
        if self.parser.check_stop()? {
            self.pending_stop = true;
            self.save_progress_and_result(StepResult::stop());
        } else {
            let mask = self.parser.compute_mask();
            // A failed mask with this particular stop reason is reported as a
            // regular stop; any other failure is propagated by `mask?`.
            if mask.is_err() && self.parser.stop_reason() == StopReason::NoExtensionBias {
                self.save_progress_and_result(StepResult::stop());
            } else {
                self.save_progress_and_result(StepResult::sample(mask?, self.parser.temperature()));
            }
        }
        Ok(())
    }

    /// Returns the result stored by the latest step.
    pub fn step_result(&self) -> &StepResult {
        &self.last_res
    }

    /// Converts the stored step result into a `CommitResult`.
    fn res_commit_result(&self) -> Result<CommitResult> {
        Ok(CommitResult::from_step_result(&self.last_res))
    }

    /// Delegates to the parser's `validate_tokens_raw()`.
    ///
    /// If the stored result is an unconditional splice, it is first replaced
    /// by a sample result built from a freshly allocated token set and the
    /// parser's current temperature.
    /// NOTE(review): the meaning of the returned count is defined by
    /// `TokenParser::validate_tokens_raw`; confirm there.
    pub fn validate_tokens_raw(&mut self, tokens: &[TokenId]) -> Result<usize> {
        if self.last_res.unconditional_splice().is_some() {
            self.save_progress_and_result(StepResult::sample(
                self.tok_trie().alloc_token_set(),
                self.parser.temperature(),
            ));
        }
        self.parser.validate_tokens_raw(tokens)
    }

    /// Commits the sampled token (if any) and returns the resulting splice.
    ///
    /// The work runs inside `panic_utils::catch_unwind`.
    ///
    /// # Errors
    /// Fails when a mask was stored but `sampled_token` is `None`, when the
    /// stored result is neither stop, splice nor mask, or when the parser
    /// reports an error.
    pub fn commit_token(&mut self, sampled_token: Option<TokenId>) -> Result<CommitResult> {
        panic_utils::catch_unwind(std::panic::AssertUnwindSafe(|| {
            self.commit_token_inner(sampled_token)
        }))
    }

    /// Body of `commit_token()`.
    fn commit_token_inner(&mut self, sampled_token: Option<TokenId>) -> Result<CommitResult> {
        let n_tokens = self.parser.num_tokens();
        loginfo!(
            self.parser.logger,
            "\ncommit_token({}) at #{}",
            sampled_token
                .map(|t| self.parser.token_env.tok_trie().token_dbg(t))
                // Lazy fallback: only allocate "None" when no token was given.
                .unwrap_or_else(|| "None".to_string()),
            n_tokens
        );

        // A stored stop is simply reported again.
        if self.last_res.is_stop() {
            return self.res_commit_result();
        }

        // A stored unconditional splice is reported again; it is only
        // expected when fast-forward tokens are enabled.
        if self.last_res.unconditional_splice().is_some() {
            assert!(self.parser.inference_caps.ff_tokens);
            return self.res_commit_result();
        }

        ensure!(
            self.last_res.sample_mask.is_some(),
            "internal error: invalid compute_mask() result"
        );

        let t = sampled_token
            .ok_or_else(|| anyhow::anyhow!("sampled_token is required when mask was present"))?;
        let mut bt = self.parser.consume_token(t)?;
        let mut tokens = vec![t];
        if bt > 0 {
            // The sampled token is dropped from the splice and the backtrack
            // count is reduced by one to account for it.
            loginfo!(self.parser.logger, "backtrack sampled");
            tokens.clear();
            bt -= 1;
        }
        if self.parser.inference_caps.ff_tokens {
            tokens.extend(self.parser.consume_ff_tokens()?);
        }
        if self.parser.check_stop()? {
            loginfo!(self.parser.logger, "set pending stop");
            self.pending_stop = true;
        }
        // NOTE(review): `as u32` truncates silently if `bt` is wider than
        // 32 bits; assumed safe because backtrack counts are small — confirm.
        self.save_progress_and_result(StepResult::splice(bt as u32, tokens));
        self.res_commit_result()
    }

    /// Returns the progress entries the reporter produces for the stored
    /// result.
    pub fn flush_progress(&mut self) -> Vec<ParserOutput> {
        self.reporter.get_progress(&mut self.parser, &self.last_res)
    }

    /// Returns the logs accumulated in the parser's logger and clears them.
    pub fn flush_logs(&mut self) -> String {
        self.parser.logger.get_and_clear_logs()
    }

    /// Returns the token trie of the parser's token environment.
    pub fn tok_trie(&self) -> &toktrie::TokTrie {
        self.parser.token_env.tok_trie()
    }
}