llama_cpp_bindings/openai/
chat_template_result_grammar.rs1use std::collections::HashSet;
2
3use crate::model::{AddBos, ChatTemplateResult, GrammarTriggerType, LlamaModel};
4use crate::sampling::LlamaSampler;
5use crate::token::LlamaToken;
6
7use super::grammar_sampler_error::GrammarSamplerError;
8
/// Returns a copy of `value` in which every regex metacharacter is preceded
/// by a backslash, so the result matches `value` literally.
///
/// The characters treated as special are `. ^ $ | ( ) * + ? [ ] { } \`;
/// everything else is copied through unchanged.
fn regex_escape(value: &str) -> String {
    const METACHARACTERS: &str = ".^$|()*+?[]{}\\";

    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut output, symbol| {
            if METACHARACTERS.contains(symbol) {
                output.push('\\');
            }
            output.push(symbol);
            output
        })
}
24
/// Wraps `pattern` in `^` … `$` anchors so that it must match an entire input.
///
/// * An empty pattern becomes `^$`.
/// * A leading `^` is kept as-is; otherwise one is prepended.
/// * A trailing `$` is kept as-is only when it is a real anchor. If it is
///   preceded by an odd number of backslashes it is an escaped literal dollar
///   sign (e.g. `cost\$`), so an anchoring `$` is still appended.
fn anchor_pattern(pattern: &str) -> String {
    if pattern.is_empty() {
        return "^$".to_string();
    }

    let needs_start_anchor = !pattern.starts_with('^');

    let needs_end_anchor = match pattern.strip_suffix('$') {
        // An odd run of backslashes right before the `$` escapes it, which
        // makes it a literal character rather than an end-of-input anchor.
        Some(head) => {
            let preceding_backslashes = head.bytes().rev().take_while(|&byte| byte == b'\\').count();
            preceding_backslashes % 2 == 1
        }
        None => true,
    };

    // At most two extra characters (one per anchor) are added.
    let mut anchored = String::with_capacity(pattern.len() + 2);

    if needs_start_anchor {
        anchored.push('^');
    }

    anchored.push_str(pattern);

    if needs_end_anchor {
        anchored.push('$');
    }

    anchored
}
44
45impl ChatTemplateResult {
46 pub fn build_grammar_sampler(
54 &self,
55 model: &LlamaModel,
56 ) -> Result<(Option<LlamaSampler>, HashSet<LlamaToken>), GrammarSamplerError> {
57 let mut preserved = HashSet::new();
58
59 for token_str in &self.preserved_tokens {
60 let tokens = model
61 .str_to_token(token_str, AddBos::Never)
62 .map_err(|error| GrammarSamplerError::TokenizationFailed(error.to_string()))?;
63
64 if tokens.len() == 1 {
65 preserved.insert(tokens[0]);
66 }
67 }
68
69 let Some(grammar) = self.grammar.as_deref() else {
70 return Ok((None, preserved));
71 };
72
73 let grammar_sampler = if self.grammar_lazy {
74 if self.grammar_triggers.is_empty() {
75 return Err(GrammarSamplerError::MissingTriggers);
76 }
77
78 let mut trigger_patterns = Vec::new();
79 let mut trigger_tokens = Vec::new();
80
81 for trigger in &self.grammar_triggers {
82 match trigger.trigger_type {
83 GrammarTriggerType::Token => {
84 if let Some(token) = trigger.token {
85 trigger_tokens.push(token);
86 }
87 }
88 GrammarTriggerType::Word => {
89 let tokens =
90 model
91 .str_to_token(&trigger.value, AddBos::Never)
92 .map_err(|error| {
93 GrammarSamplerError::TokenizationFailed(error.to_string())
94 })?;
95
96 if tokens.len() == 1 {
97 if !preserved.contains(&tokens[0]) {
98 return Err(GrammarSamplerError::TriggerWordNotPreserved(
99 trigger.value.clone(),
100 ));
101 }
102 trigger_tokens.push(tokens[0]);
103 } else {
104 trigger_patterns.push(regex_escape(&trigger.value));
105 }
106 }
107 GrammarTriggerType::Pattern => {
108 trigger_patterns.push(trigger.value.clone());
109 }
110 GrammarTriggerType::PatternFull => {
111 trigger_patterns.push(anchor_pattern(&trigger.value));
112 }
113 }
114 }
115
116 LlamaSampler::grammar_lazy_patterns(
117 model,
118 grammar,
119 "root",
120 &trigger_patterns,
121 &trigger_tokens,
122 )?
123 } else {
124 LlamaSampler::grammar(model, grammar, "root")?
125 };
126
127 Ok((Some(grammar_sampler), preserved))
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::{anchor_pattern, regex_escape};

    /// Each regex metacharacter on its own gains a leading backslash.
    #[test]
    fn regex_escape_special_characters() {
        let cases = [
            (".", "\\."),
            ("^", "\\^"),
            ("$", "\\$"),
            ("|", "\\|"),
            ("(", "\\("),
            (")", "\\)"),
            ("*", "\\*"),
            ("+", "\\+"),
            ("?", "\\?"),
            ("[", "\\["),
            ("]", "\\]"),
            ("{", "\\{"),
            ("}", "\\}"),
            ("\\", "\\\\"),
        ];

        for (input, expected) in cases {
            assert_eq!(regex_escape(input), expected);
        }
    }

    /// Text without metacharacters is returned unchanged.
    #[test]
    fn regex_escape_normal_text() {
        let escaped = regex_escape("hello world");

        assert_eq!(escaped, "hello world");
    }

    /// The empty string stays empty.
    #[test]
    fn regex_escape_empty_string() {
        let escaped = regex_escape("");

        assert_eq!(escaped, "");
    }

    /// Only the metacharacters inside ordinary text are escaped.
    #[test]
    fn regex_escape_mixed_text() {
        let escaped = regex_escape("price: $5.00");

        assert_eq!(escaped, "price: \\$5\\.00");
    }

    /// An empty pattern collapses to the bare pair of anchors.
    #[test]
    fn anchor_pattern_empty_string() {
        let anchored = anchor_pattern("");

        assert_eq!(anchored, "^$");
    }

    /// A pattern carrying both anchors is left as it is.
    #[test]
    fn anchor_pattern_already_anchored() {
        let anchored = anchor_pattern("^hello$");

        assert_eq!(anchored, "^hello$");
    }

    /// A missing start anchor is prepended.
    #[test]
    fn anchor_pattern_needs_start_anchor() {
        let anchored = anchor_pattern("hello$");

        assert_eq!(anchored, "^hello$");
    }

    /// A missing end anchor is appended.
    #[test]
    fn anchor_pattern_needs_end_anchor() {
        let anchored = anchor_pattern("^hello");

        assert_eq!(anchored, "^hello$");
    }

    /// A pattern with no anchors receives both.
    #[test]
    fn anchor_pattern_needs_both_anchors() {
        let anchored = anchor_pattern("hello");

        assert_eq!(anchored, "^hello$");
    }
}