1use crate::HashSet;
2use serde::{Deserialize, Serialize};
3use toktrie::{bytes::to_hex_string, StepResult};
4
5use crate::{api::StopReason, earley, TokenParser};
6
/// A byte string rendered two ways: a lossy UTF-8 string for human
/// consumption and a hex encoding that preserves the exact bytes.
#[derive(Serialize, Deserialize)]
pub struct BytesOutput {
    /// Lossy UTF-8 rendering (invalid sequences replaced with U+FFFD).
    pub str: String,
    /// Hex encoding of the exact underlying bytes.
    pub hex: String,
}
12
/// A single progress event emitted by the [`Reporter`].
///
/// Serialized as a JSON object tagged by an `"object"` field
/// (`"capture"`, `"final_text"`, or `"text"`); the inner [`BytesOutput`]
/// fields are flattened into the same object.
#[derive(Serialize, Deserialize)]
#[serde(tag = "object", rename_all = "snake_case")]
pub enum ParserOutput {
    /// A newly-completed named capture group and its value.
    Capture {
        name: String,
        #[serde(flatten)]
        bytes: BytesOutput,
        log_prob: f64,
    },
    /// The final accepted output, emitted once when parsing stops.
    FinalText {
        #[serde(flatten)]
        bytes: BytesOutput,
        stop_reason: StopReason,
    },
    /// An incremental chunk of output text since the previous report.
    Text {
        #[serde(flatten)]
        bytes: BytesOutput,
        log_prob: f64,
        /// Tokens covered by this chunk.
        num_tokens: usize,
        /// Whether this chunk was produced by generation (as opposed to
        /// forced/spliced output) — see `Reporter::set_is_generated`.
        is_generated: bool,
        stats: ParserStats,
    },
}
36
/// Per-report parser statistics: wall-clock runtime plus a delta of the
/// underlying Earley parser's counters since the previous report.
#[derive(Serialize, Deserialize)]
pub struct ParserStats {
    /// Elapsed wall-clock time for this step, in microseconds.
    runtime_us: u64,
    /// Earley-parser counters, flattened into the same JSON object.
    #[serde(flatten)]
    stats: earley::ParserStats,
}
43
44impl From<&[u8]> for BytesOutput {
45 fn from(bytes: &[u8]) -> Self {
46 BytesOutput::from_bytes(bytes)
47 }
48}
49
50impl BytesOutput {
51 pub fn from_bytes(bytes: &[u8]) -> Self {
52 BytesOutput {
53 str: String::from_utf8_lossy(bytes).to_string(),
54 hex: to_hex_string(bytes),
55 }
56 }
57}
58
/// Incrementally converts parser state into [`ParserOutput`] events.
///
/// Keeps cursors into the capture list, output text, and token count so that
/// each call reports only what is new since the previous report.
#[derive(Clone, Default)]
pub struct Reporter {
    /// Number of captures already reported.
    reported_captures: usize,
    /// Byte offset of output text already reported.
    text_ptr: usize,
    /// Token count already accounted for in previous reports.
    token_ptr: usize,
    /// Stats snapshot at the previous report, used to compute deltas.
    prev_stats: earley::ParserStats,
    /// Value of the `is_generated` flag attached to the next `Text` output.
    is_generated: bool,
}
67
68impl Reporter {
69 pub fn get_progress(
70 &mut self,
71 tok_parser: &TokenParser,
72 mid_res: &StepResult,
73 ) -> Vec<ParserOutput> {
74 let mut res = self.get_progress_core(tok_parser);
75 self.is_generated = !mid_res.is_stop() && mid_res.splices.is_empty();
76
77 if mid_res.is_stop() {
78 res.push(self.final_text(tok_parser));
79 }
80
81 res
82 }
83
84 pub fn final_text(&self, tok_parser: &TokenParser) -> ParserOutput {
85 ParserOutput::FinalText {
86 bytes: tok_parser.final_bytes().into(),
87 stop_reason: tok_parser.stop_reason(),
88 }
89 }
90
    /// Override the `is_generated` flag attached to the next `Text` report.
    pub fn set_is_generated(&mut self, is_generated: bool) {
        self.is_generated = is_generated;
    }
94
    /// Collect everything new since the previous report: freshly recorded
    /// capture groups (deduplicated by name, keeping the most recent value),
    /// followed by one `Text` event covering the bytes and tokens produced
    /// since the last call. Advances the internal cursors so each call
    /// reports only the delta.
    pub fn get_progress_core(&mut self, tok_parser: &TokenParser) -> Vec<ParserOutput> {
        let mut res = vec![];

        // Only consider captures added since the last report.
        let captures = &tok_parser.parser.captures()[self.reported_captures..];
        self.reported_captures += captures.len();

        // Deduplicate by capture name: iterate newest-first so `seen.insert`
        // keeps only the most recent value for each name...
        let mut seen = HashSet::default();
        let captures = captures
            .iter()
            .rev()
            .filter(|(name, _)| seen.insert(name))
            .collect::<Vec<_>>();
        // ...then reverse again so events are emitted oldest-first.
        for (name, val) in captures.iter().rev() {
            res.push(ParserOutput::Capture {
                name: name.clone(),
                bytes: val.as_slice().into(),
                // Log-probabilities are not computed here; reported as 0.
                log_prob: 0.0, });
        }

        // Stats are reported as a delta against the previous snapshot.
        let delta = tok_parser.parser_stats().delta(&self.prev_stats);
        self.prev_stats = tok_parser.parser_stats().clone();
        let runtime_us = tok_parser.compute_mask_start_time.elapsed().as_micros() as u64;
        let stats = ParserStats {
            runtime_us,
            stats: delta,
        };

        let num_tokens = tok_parser.num_tokens();
        // Bytes produced since the last report (from the text cursor onward).
        let new_text = tok_parser.bytes_since(self.text_ptr);
        res.push(ParserOutput::Text {
            bytes: new_text.into(),
            // log_prob is a placeholder (see above); num_tokens counts only
            // tokens added since the previous report (saturating for safety).
            log_prob: 0.0, num_tokens: num_tokens.saturating_sub(self.token_ptr),
            is_generated: self.is_generated,
            stats,
        });
        // Advance cursors so the next call reports only new text/tokens.
        self.text_ptr += new_text.len();
        self.token_ptr = num_tokens;

        res
    }
141}