// llguidance — matcher.rs

1use anyhow::{anyhow, ensure, Result};
2use toktrie::{SimpleVob, TokEnv, TokenId};
3
4use crate::{api::StopReason, panic_utils, TokenParser};
5
/// State of a live (non-errored) matcher: the wrapped token parser.
#[derive(Clone)]
struct MatcherInner {
    parser: TokenParser,
}
10
/// Either a working parser or a sticky error message.
///
/// Once any operation fails (or panics), `Matcher::with_inner` replaces
/// `Normal` with `Error`, and every subsequent call returns that error.
#[derive(Clone)]
// clippy flags the size gap between variants (the parser dwarfs the
// String); kept unboxed deliberately rather than boxing `Normal`.
#[allow(clippy::large_enum_variant)]
enum MatcherState {
    Normal(MatcherInner),
    Error(String),
}
17
/// This is meant to be used in server-side scenarios.
/// The Constraint interface is more for usage in Python Guidance.
///
/// Fallible operations poison the matcher on failure: after the first
/// error (or caught panic) the state flips to `MatcherState::Error` and
/// all later fallible calls return that same error message.
#[derive(Clone)]
pub struct Matcher(MatcherState);
22
23impl Matcher {
24    pub fn new(parser: Result<TokenParser>) -> Self {
25        match parser {
26            Ok(mut parser) => {
27                let caps = &parser.inference_caps;
28                if caps.backtrack {
29                    Self::new(Err(anyhow!("backtracking not supported")))
30                } else {
31                    // rest of caps is ignored
32                    if parser.is_fresh() {
33                        parser.start_without_prompt();
34                    }
35                    Matcher(MatcherState::Normal(MatcherInner { parser }))
36                }
37            }
38            Err(e) => Matcher(MatcherState::Error(e.to_string())),
39        }
40    }
41
42    fn with_inner<T>(&mut self, f: impl FnOnce(&mut MatcherInner) -> Result<T>) -> Result<T> {
43        match &mut self.0 {
44            MatcherState::Normal(ref mut inner) => {
45                // We catch any panics here and transform them into regular errors.
46                // They shouldn't happen, but if they do, we don't want to crash the whole program.
47                let r = panic_utils::catch_unwind(std::panic::AssertUnwindSafe(|| f(inner)));
48                match r {
49                    Ok(r) => Ok(r),
50                    Err(e) => {
51                        self.0 = MatcherState::Error(e.to_string());
52                        Err(e)
53                    }
54                }
55            }
56            MatcherState::Error(e) => Err(anyhow!("{}", e)),
57        }
58    }
59
60    /// Advance the parser by one token.
61    /// Also checks if the parser should stop after consuming the tokens
62    /// and puts the parser in stop state if necessary.
63    pub fn consume_tokens(&mut self, tokens: &[TokenId]) -> Result<()> {
64        self.with_inner(|inner| {
65            for &t in tokens {
66                let bt = inner.parser.consume_token(t)?;
67                ensure!(bt == 0, "unexpected backtracking");
68            }
69            let _ = inner.parser.check_stop()?;
70            Ok(())
71        })
72    }
73
74    pub fn rollback(&mut self, num_tokens: usize) -> Result<()> {
75        self.with_inner(|inner| inner.parser.rollback(num_tokens))
76    }
77
78    pub fn reset(&mut self) -> Result<()> {
79        self.with_inner(|inner| inner.parser.reset())
80    }
81
82    /// Compute which tokens can be consumed in the current state.
83    pub fn compute_mask(&mut self) -> Result<SimpleVob> {
84        self.with_inner(|inner| inner.parser.compute_mask())
85    }
86
87    /// Can the grammar be finished in the current state?
88    /// In other words, would the current token mask allow EOS token?
89    pub fn is_accepting(&mut self) -> Result<bool> {
90        self.with_inner(|inner| Ok(inner.parser.is_accepting()))
91    }
92
93    pub fn is_stopped(&self) -> bool {
94        match &self.0 {
95            MatcherState::Normal(inner) => inner.parser.stop_reason() != StopReason::NotStopped,
96            MatcherState::Error(_) => true,
97        }
98    }
99
100    pub fn stop_reason(&self) -> StopReason {
101        match &self.0 {
102            MatcherState::Normal(inner) => inner.parser.stop_reason(),
103            MatcherState::Error(_) => StopReason::InternalError,
104        }
105    }
106
107    /// This will always return [] for non-canonical tokenizers.
108    pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
109        self.with_inner(|inner| Ok(inner.parser.compute_ff_tokens()))
110            .unwrap_or_else(|_| vec![])
111    }
112
113    /// Return any bytes that are forced by the current parser state.
114    /// This also works for non-canonical tokenizers.
115    pub fn compute_ff_bytes(&mut self) -> Vec<u8> {
116        self.with_inner(|inner| Ok(inner.parser.force_bytes()))
117            .unwrap_or_else(|_| vec![])
118    }
119
120    /// Tries to advance the parser by consuming the given tokens.
121    /// Returns the number of tokens consumed.
122    /// Also checks if the parser should stop after consuming the tokens
123    /// and puts the parser in stop state if necessary.
124    pub fn try_consume_tokens(&mut self, tokens: &[TokenId]) -> Result<usize> {
125        self.with_inner(|inner| {
126            for (idx, &t) in tokens.iter().enumerate() {
127                if !inner.parser.validate_token(t)? {
128                    return Ok(idx);
129                }
130                let bt = inner.parser.consume_token(t)?;
131                ensure!(bt == 0, "unexpected backtracking");
132            }
133            let _ = inner.parser.check_stop()?;
134            Ok(tokens.len())
135        })
136    }
137
138    pub fn validate_tokens(&mut self, tokens: &[TokenId]) -> Result<usize> {
139        self.with_inner(|inner| inner.parser.validate_tokens_raw(tokens))
140    }
141
142    pub fn is_error(&self) -> bool {
143        matches!(self.0, MatcherState::Error(_))
144    }
145
146    pub fn get_error(&self) -> Option<String> {
147        match &self.0 {
148            MatcherState::Normal(_) => None,
149            MatcherState::Error(e) => Some(e.clone()),
150        }
151    }
152
153    pub fn tok_env(&self) -> Result<TokEnv> {
154        match &self.0 {
155            MatcherState::Normal(inner) => Ok(inner.parser.token_env.clone()),
156            MatcherState::Error(e) => Err(anyhow!("{}", e)),
157        }
158    }
159}