use crate::HashSet;
use serde::{Deserialize, Serialize};
use toktrie::{bytes::to_hex_string, StepResult};
use crate::{api::StopReason, earley, TokenParser};
/// A chunk of bytes rendered two ways for consumers: a lossy UTF-8 decoding
/// and a hex encoding (via `to_hex_string`) that preserves the exact bytes.
#[derive(Serialize, Deserialize)]
pub struct BytesOutput {
    /// Lossy UTF-8 view of the bytes (invalid sequences are replaced).
    pub str: String,
    /// Hex-encoded form of the same bytes; lossless round-trip.
    pub hex: String,
}
/// Progress records emitted by the reporter. In serialized form the variant
/// is identified by an `"object"` tag with a snake_case name, and the
/// flattened `BytesOutput` fields (`str`, `hex`) appear inline.
#[derive(Serialize, Deserialize)]
#[serde(tag = "object", rename_all = "snake_case")]
pub enum ParserOutput {
    /// A named grammar capture and its raw value.
    Capture {
        name: String,
        #[serde(flatten)]
        bytes: BytesOutput,
        // Always reported as 0.0 by this module; see Reporter::get_progress_core.
        log_prob: f64,
    },
    /// Terminal record: the parser's final bytes and why it stopped.
    FinalText {
        #[serde(flatten)]
        bytes: BytesOutput,
        stop_reason: StopReason,
    },
    /// Incremental text produced since the previous report.
    Text {
        #[serde(flatten)]
        bytes: BytesOutput,
        // Always reported as 0.0 by this module.
        log_prob: f64,
        /// Number of new tokens covered by this chunk.
        num_tokens: usize,
        /// Whether the chunk is considered model-generated (false when the
        /// step involved splices or a stop — see Reporter::get_progress).
        is_generated: bool,
        stats: ParserStats,
    },
}
/// Per-report statistics: wall-clock runtime plus the delta of the
/// underlying Earley parser counters since the previous report.
#[derive(Serialize, Deserialize)]
pub struct ParserStats {
    /// Microseconds elapsed since the token parser's mask computation started.
    runtime_us: u64,
    #[serde(flatten)]
    stats: earley::ParserStats,
}
impl From<&[u8]> for BytesOutput {
fn from(bytes: &[u8]) -> Self {
BytesOutput::from_bytes(bytes)
}
}
impl BytesOutput {
    /// Build an output record carrying both a lossy UTF-8 decoding of
    /// `bytes` and a lossless hex encoding of the same data.
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let str = String::from_utf8_lossy(bytes).into_owned();
        let hex = to_hex_string(bytes);
        BytesOutput { str, hex }
    }
}
/// Incremental progress reporter over a `TokenParser`: remembers how much of
/// the parser's output (captures, bytes, tokens, stats) has already been
/// reported, so each call emits only the delta since the previous one.
#[derive(Clone, Default)]
pub struct Reporter {
    // Count of captures already emitted; index into tok_parser.parser.captures().
    reported_captures: usize,
    // Byte offset up to which text has been reported (fed to bytes_since()).
    text_ptr: usize,
    // Token count at the time of the last report.
    token_ptr: usize,
    // Stats snapshot from the last report, used to compute per-report deltas.
    prev_stats: earley::ParserStats,
    // Whether the next Text record should be flagged as model-generated.
    is_generated: bool,
}
impl Reporter {
    /// Collect everything produced since the last call (captures plus one
    /// `Text` record), appending a `FinalText` record when `mid_res`
    /// indicates the parser has stopped.
    pub fn get_progress(
        &mut self,
        tok_parser: &TokenParser,
        mid_res: &StepResult,
    ) -> Vec<ParserOutput> {
        let mut res = self.get_progress_core(tok_parser);
        // The next chunk counts as model-generated only when this step
        // neither stopped the parser nor forced tokens via splices.
        // (idiom fix: `.is_empty()` instead of `.len() == 0`)
        self.is_generated = !mid_res.is_stop() && mid_res.splices.is_empty();
        if mid_res.is_stop() {
            res.push(self.final_text(tok_parser));
        }
        res
    }

    /// Build the terminal `FinalText` record from the parser's final bytes
    /// and the reason it stopped.
    pub fn final_text(&self, tok_parser: &TokenParser) -> ParserOutput {
        ParserOutput::FinalText {
            bytes: tok_parser.final_bytes().into(),
            stop_reason: tok_parser.stop_reason(),
        }
    }

    /// Override the is-generated flag applied to the next `Text` record.
    pub fn set_is_generated(&mut self, is_generated: bool) {
        self.is_generated = is_generated;
    }

    /// Emit new captures (deduplicated by name, keeping the most recent
    /// value) followed by a single `Text` record covering all bytes and
    /// tokens added since the previous report. Advances the internal
    /// pointers so the next call reports only fresh data.
    pub fn get_progress_core(&mut self, tok_parser: &TokenParser) -> Vec<ParserOutput> {
        let mut res = vec![];

        // Only look at captures added since the last report. If the same
        // name was captured more than once, keep the latest value: scan in
        // reverse keeping first-seen names, then reverse again to restore
        // the original relative order of the survivors.
        let captures = &tok_parser.parser.captures()[self.reported_captures..];
        self.reported_captures += captures.len();
        let mut seen = HashSet::default();
        let captures = captures
            .iter()
            .rev()
            .filter(|(name, _)| seen.insert(name))
            .collect::<Vec<_>>();
        for (name, val) in captures.iter().rev() {
            res.push(ParserOutput::Capture {
                name: name.clone(),
                bytes: val.as_slice().into(),
                // Log-probabilities are not tracked here; always 0.0.
                log_prob: 0.0,
            });
        }

        // Stats delta since the previous report, plus elapsed wall-clock
        // time since the current mask computation started.
        let delta = tok_parser.parser_stats().delta(&self.prev_stats);
        self.prev_stats = tok_parser.parser_stats().clone();
        let runtime_us = tok_parser.compute_mask_start_time.elapsed().as_micros() as u64;
        let stats = ParserStats {
            runtime_us,
            stats: delta,
        };

        // One Text record for everything appended since the last report.
        let num_tokens = tok_parser.num_tokens();
        let new_text = tok_parser.bytes_since(self.text_ptr);
        res.push(ParserOutput::Text {
            bytes: new_text.into(),
            // Log-probabilities are not tracked here; always 0.0.
            log_prob: 0.0,
            num_tokens: num_tokens.saturating_sub(self.token_ptr),
            is_generated: self.is_generated,
            stats,
        });
        self.text_ptr += new_text.len();
        self.token_ptr = num_tokens;
        res
    }
}