1use anyhow::{anyhow, ensure, Result};
2use toktrie::{SimpleVob, TokEnv, TokenId};
3
4use crate::{api::StopReason, panic_utils, TokenParser};
5
/// Live state of a [`Matcher`] that is not in the error state: the token
/// parser that all of the matcher's operations delegate to.
#[derive(Clone)]
struct MatcherInner {
    /// Parser that consumes tokens, computes masks, and reports stop state.
    parser: TokenParser,
}
10
/// State of a [`Matcher`]: either operating normally or failed.
// NOTE(review): `Normal` is much larger than `Error` (hence the clippy lint);
// presumably it is left unboxed to avoid an extra indirection on every call in
// the common, non-error state — confirm.
#[derive(Clone)]
#[allow(clippy::large_enum_variant)]
enum MatcherState {
    /// Parser is usable; operations are forwarded to it.
    Normal(MatcherInner),
    /// Failed state holding the message of the error that caused it. None of
    /// the `Matcher` methods move out of this state.
    Error(String),
}
17
/// Token-level matcher wrapping a [`TokenParser`].
///
/// Operations that go through the parser catch panics and convert them into
/// errors. After any such failure the matcher stays in an error state, and
/// subsequent parser operations return that error (see `is_error` and
/// `get_error`). Parsers that require backtracking are rejected at
/// construction.
#[derive(Clone)]
pub struct Matcher(MatcherState);
22
23impl Matcher {
24 pub fn new(parser: Result<TokenParser>) -> Self {
25 match parser {
26 Ok(mut parser) => {
27 let caps = &parser.inference_caps;
28 if caps.backtrack {
29 Self::new(Err(anyhow!("backtracking not supported")))
30 } else {
31 if parser.is_fresh() {
33 parser.start_without_prompt();
34 }
35 Matcher(MatcherState::Normal(MatcherInner { parser }))
36 }
37 }
38 Err(e) => Matcher(MatcherState::Error(e.to_string())),
39 }
40 }
41
42 fn with_inner<T>(&mut self, f: impl FnOnce(&mut MatcherInner) -> Result<T>) -> Result<T> {
43 match &mut self.0 {
44 MatcherState::Normal(ref mut inner) => {
45 let r = panic_utils::catch_unwind(std::panic::AssertUnwindSafe(|| f(inner)));
48 match r {
49 Ok(r) => Ok(r),
50 Err(e) => {
51 self.0 = MatcherState::Error(e.to_string());
52 Err(e)
53 }
54 }
55 }
56 MatcherState::Error(e) => Err(anyhow!("{}", e)),
57 }
58 }
59
60 pub fn consume_tokens(&mut self, tokens: &[TokenId]) -> Result<()> {
64 self.with_inner(|inner| {
65 for &t in tokens {
66 let bt = inner.parser.consume_token(t)?;
67 ensure!(bt == 0, "unexpected backtracking");
68 }
69 let _ = inner.parser.check_stop()?;
70 Ok(())
71 })
72 }
73
74 pub fn rollback(&mut self, num_tokens: usize) -> Result<()> {
75 self.with_inner(|inner| inner.parser.rollback(num_tokens))
76 }
77
78 pub fn reset(&mut self) -> Result<()> {
79 self.with_inner(|inner| inner.parser.reset())
80 }
81
82 pub fn compute_mask(&mut self) -> Result<SimpleVob> {
84 self.with_inner(|inner| inner.parser.compute_mask())
85 }
86
87 pub fn is_accepting(&mut self) -> Result<bool> {
90 self.with_inner(|inner| Ok(inner.parser.is_accepting()))
91 }
92
93 pub fn is_stopped(&self) -> bool {
94 match &self.0 {
95 MatcherState::Normal(inner) => inner.parser.stop_reason() != StopReason::NotStopped,
96 MatcherState::Error(_) => true,
97 }
98 }
99
100 pub fn stop_reason(&self) -> StopReason {
101 match &self.0 {
102 MatcherState::Normal(inner) => inner.parser.stop_reason(),
103 MatcherState::Error(_) => StopReason::InternalError,
104 }
105 }
106
107 pub fn compute_ff_tokens(&mut self) -> Vec<TokenId> {
109 self.with_inner(|inner| Ok(inner.parser.compute_ff_tokens()))
110 .unwrap_or_else(|_| vec![])
111 }
112
113 pub fn compute_ff_bytes(&mut self) -> Vec<u8> {
116 self.with_inner(|inner| Ok(inner.parser.force_bytes()))
117 .unwrap_or_else(|_| vec![])
118 }
119
120 pub fn try_consume_tokens(&mut self, tokens: &[TokenId]) -> Result<usize> {
125 self.with_inner(|inner| {
126 for (idx, &t) in tokens.iter().enumerate() {
127 if !inner.parser.validate_token(t)? {
128 return Ok(idx);
129 }
130 let bt = inner.parser.consume_token(t)?;
131 ensure!(bt == 0, "unexpected backtracking");
132 }
133 let _ = inner.parser.check_stop()?;
134 Ok(tokens.len())
135 })
136 }
137
138 pub fn validate_tokens(&mut self, tokens: &[TokenId]) -> Result<usize> {
139 self.with_inner(|inner| inner.parser.validate_tokens_raw(tokens))
140 }
141
142 pub fn is_error(&self) -> bool {
143 matches!(self.0, MatcherState::Error(_))
144 }
145
146 pub fn get_error(&self) -> Option<String> {
147 match &self.0 {
148 MatcherState::Normal(_) => None,
149 MatcherState::Error(e) => Some(e.clone()),
150 }
151 }
152
153 pub fn tok_env(&self) -> Result<TokEnv> {
154 match &self.0 {
155 MatcherState::Normal(inner) => Ok(inner.parser.token_env.clone()),
156 MatcherState::Error(e) => Err(anyhow!("{}", e)),
157 }
158 }
159}