llguidance/
output.rs

1use crate::HashSet;
2use serde::{Deserialize, Serialize};
3use toktrie::{bytes::to_hex_string, StepResult};
4
5use crate::{api::StopReason, earley, TokenParser};
6
7#[derive(Serialize, Deserialize)]
8pub struct BytesOutput {
9    pub str: String,
10    pub hex: String,
11}
12
13#[derive(Serialize, Deserialize)]
14#[serde(tag = "object", rename_all = "snake_case")]
15pub enum ParserOutput {
16    Capture {
17        name: String,
18        #[serde(flatten)]
19        bytes: BytesOutput,
20        log_prob: f64,
21    },
22    FinalText {
23        #[serde(flatten)]
24        bytes: BytesOutput,
25        stop_reason: StopReason,
26    },
27    Text {
28        #[serde(flatten)]
29        bytes: BytesOutput,
30        log_prob: f64,
31        num_tokens: usize,
32        is_generated: bool,
33        stats: ParserStats,
34    },
35}
36
37#[derive(Serialize, Deserialize)]
38pub struct ParserStats {
39    runtime_us: u64,
40    #[serde(flatten)]
41    stats: earley::ParserStats,
42}
43
44impl From<&[u8]> for BytesOutput {
45    fn from(bytes: &[u8]) -> Self {
46        BytesOutput::from_bytes(bytes)
47    }
48}
49
50impl BytesOutput {
51    pub fn from_bytes(bytes: &[u8]) -> Self {
52        BytesOutput {
53            str: String::from_utf8_lossy(bytes).to_string(),
54            hex: to_hex_string(bytes),
55        }
56    }
57}
58
59#[derive(Clone, Default)]
60pub struct Reporter {
61    reported_captures: usize,
62    text_ptr: usize,
63    token_ptr: usize,
64    prev_stats: earley::ParserStats,
65    is_generated: bool,
66}
67
68impl Reporter {
69    pub fn get_progress(
70        &mut self,
71        tok_parser: &TokenParser,
72        mid_res: &StepResult,
73    ) -> Vec<ParserOutput> {
74        let mut res = self.get_progress_core(tok_parser);
75        self.is_generated = !mid_res.is_stop() && mid_res.splices.is_empty();
76
77        if mid_res.is_stop() {
78            res.push(self.final_text(tok_parser));
79        }
80
81        res
82    }
83
84    pub fn final_text(&self, tok_parser: &TokenParser) -> ParserOutput {
85        ParserOutput::FinalText {
86            bytes: tok_parser.final_bytes().into(),
87            stop_reason: tok_parser.stop_reason(),
88        }
89    }
90
91    pub fn set_is_generated(&mut self, is_generated: bool) {
92        self.is_generated = is_generated;
93    }
94
95    pub fn get_progress_core(&mut self, tok_parser: &TokenParser) -> Vec<ParserOutput> {
96        let mut res = vec![];
97
98        // start with captures
99        let captures = &tok_parser.parser.captures()[self.reported_captures..];
100        self.reported_captures += captures.len();
101
102        // remove duplicate names
103        let mut seen = HashSet::default();
104        let captures = captures
105            .iter()
106            .rev()
107            .filter(|(name, _)| seen.insert(name))
108            .collect::<Vec<_>>();
109        for (name, val) in captures.iter().rev() {
110            res.push(ParserOutput::Capture {
111                name: name.clone(),
112                bytes: val.as_slice().into(),
113                log_prob: 0.0, // TODO
114            });
115        }
116
117        // compute stats
118        let delta = tok_parser.parser_stats().delta(&self.prev_stats);
119        self.prev_stats = tok_parser.parser_stats().clone();
120        let runtime_us = tok_parser.compute_mask_start_time.elapsed().as_micros() as u64;
121        let stats = ParserStats {
122            runtime_us,
123            stats: delta,
124        };
125
126        // report newly generated text
127        let num_tokens = tok_parser.num_tokens();
128        let new_text = tok_parser.bytes_since(self.text_ptr);
129        res.push(ParserOutput::Text {
130            bytes: new_text.into(),
131            log_prob: 0.0, // TODO
132            num_tokens: num_tokens.saturating_sub(self.token_ptr),
133            is_generated: self.is_generated,
134            stats,
135        });
136        self.text_ptr += new_text.len();
137        self.token_ptr = num_tokens;
138
139        res
140    }
141}