llama_cpp_bindings/openai/
chat_template_result_grammar.rs1use std::collections::HashSet;
2
3use crate::model::{AddBos, ChatTemplateResult, GrammarTriggerType, LlamaModel};
4use crate::sampling::LlamaSampler;
5use crate::token::LlamaToken;
6
7use super::grammar_sampler_error::GrammarSamplerError;
8
/// Returns `value` with every regex metacharacter prefixed by a backslash, so
/// that the result, used as a pattern, matches `value` literally.
///
/// The escaped set is `. ^ $ | ( ) * + ? [ ] { } \`; every other character
/// (including non-ASCII ones) is copied through unchanged.
fn regex_escape(value: &str) -> String {
    const METACHARACTERS: &str = ".^$|()*+?[]{}\\";

    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            if METACHARACTERS.contains(ch) {
                out.push('\\');
            }
            out.push(ch);
            out
        })
}
24
/// Anchors `pattern` at both ends so it has to match the entire input. This is
/// the form used for `PatternFull` grammar triggers.
///
/// An existing leading `^` is reused rather than duplicated. An existing
/// trailing `$` is also reused, but only if it is a real anchor. A `$` escaped
/// by an odd number of backslashes (e.g. `cost\$`) is a literal dollar sign, so
/// an end anchor is still appended after it. An empty pattern becomes `^$`.
fn anchor_pattern(pattern: &str) -> String {
    if pattern.is_empty() {
        return "^$".to_string();
    }

    let needs_start_anchor = !pattern.starts_with('^');
    let needs_end_anchor = !ends_with_unescaped_dollar(pattern);

    let mut anchored = String::with_capacity(pattern.len() + 2);

    if needs_start_anchor {
        anchored.push('^');
    }

    anchored.push_str(pattern);

    if needs_end_anchor {
        anchored.push('$');
    }

    anchored
}

/// Returns `true` when `pattern` ends with a `$` end-of-input anchor.
///
/// A trailing `$` counts as an anchor only when it is preceded by an even
/// number (including zero) of backslashes. An odd run means the `$` itself is
/// escaped.
fn ends_with_unescaped_dollar(pattern: &str) -> bool {
    let Some(before_dollar) = pattern.strip_suffix('$') else {
        return false;
    };

    let preceding_backslashes = before_dollar
        .bytes()
        .rev()
        .take_while(|&byte| byte == b'\\')
        .count();

    preceding_backslashes % 2 == 0
}
44
45impl ChatTemplateResult {
46 pub fn build_grammar_sampler(
54 &self,
55 model: &LlamaModel,
56 ) -> Result<(Option<LlamaSampler>, HashSet<LlamaToken>), GrammarSamplerError> {
57 let mut preserved = HashSet::new();
58
59 for token_str in &self.preserved_tokens {
60 let tokens = model
61 .str_to_token(token_str, AddBos::Never)
62 .map_err(|error| GrammarSamplerError::TokenizationFailed(error.to_string()))?;
63
64 if tokens.len() == 1 {
65 preserved.insert(tokens[0]);
66 }
67 }
68
69 let Some(grammar) = self.grammar.as_deref() else {
70 return Ok((None, preserved));
71 };
72
73 let grammar_sampler = if self.grammar_lazy {
74 if self.grammar_triggers.is_empty() {
75 return Err(GrammarSamplerError::MissingTriggers);
76 }
77
78 let mut trigger_patterns = Vec::new();
79 let mut trigger_tokens = Vec::new();
80
81 for trigger in &self.grammar_triggers {
82 match trigger.trigger_type {
83 GrammarTriggerType::Token => {
84 if let Some(token) = trigger.token {
85 trigger_tokens.push(token);
86 }
87 }
88 GrammarTriggerType::Word => {
89 let tokens =
90 model
91 .str_to_token(&trigger.value, AddBos::Never)
92 .map_err(|error| {
93 GrammarSamplerError::TokenizationFailed(error.to_string())
94 })?;
95
96 if tokens.len() == 1 {
97 if !preserved.contains(&tokens[0]) {
98 return Err(GrammarSamplerError::TriggerWordNotPreserved(
99 trigger.value.clone(),
100 ));
101 }
102 trigger_tokens.push(tokens[0]);
103 } else {
104 trigger_patterns.push(regex_escape(&trigger.value));
105 }
106 }
107 GrammarTriggerType::Pattern => {
108 trigger_patterns.push(trigger.value.clone());
109 }
110 GrammarTriggerType::PatternFull => {
111 trigger_patterns.push(anchor_pattern(&trigger.value));
112 }
113 }
114 }
115
116 LlamaSampler::grammar_lazy_patterns(
117 model,
118 grammar,
119 "root",
120 &trigger_patterns,
121 &trigger_tokens,
122 )?
123 } else {
124 LlamaSampler::grammar(model, grammar, "root")?
125 };
126
127 Ok((Some(grammar_sampler), preserved))
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::{anchor_pattern, regex_escape};

    #[test]
    fn regex_escape_special_characters() {
        assert_eq!(regex_escape("."), "\\.");
        assert_eq!(regex_escape("^"), "\\^");
        assert_eq!(regex_escape("$"), "\\$");
        assert_eq!(regex_escape("|"), "\\|");
        assert_eq!(regex_escape("("), "\\(");
        assert_eq!(regex_escape(")"), "\\)");
        assert_eq!(regex_escape("*"), "\\*");
        assert_eq!(regex_escape("+"), "\\+");
        assert_eq!(regex_escape("?"), "\\?");
        assert_eq!(regex_escape("["), "\\[");
        assert_eq!(regex_escape("]"), "\\]");
        assert_eq!(regex_escape("{"), "\\{");
        assert_eq!(regex_escape("}"), "\\}");
        assert_eq!(regex_escape("\\"), "\\\\");
    }

    #[test]
    fn regex_escape_normal_text() {
        assert_eq!(regex_escape("hello world"), "hello world");
    }

    #[test]
    fn regex_escape_empty_string() {
        assert_eq!(regex_escape(""), "");
    }

    #[test]
    fn regex_escape_mixed_text() {
        assert_eq!(regex_escape("price: $5.00"), "price: \\$5\\.00");
    }

    #[test]
    fn anchor_pattern_empty_string() {
        assert_eq!(anchor_pattern(""), "^$");
    }

    #[test]
    fn anchor_pattern_already_anchored() {
        assert_eq!(anchor_pattern("^hello$"), "^hello$");
    }

    #[test]
    fn anchor_pattern_needs_start_anchor() {
        assert_eq!(anchor_pattern("hello$"), "^hello$");
    }

    #[test]
    fn anchor_pattern_needs_end_anchor() {
        assert_eq!(anchor_pattern("^hello"), "^hello$");
    }

    #[test]
    fn anchor_pattern_needs_both_anchors() {
        assert_eq!(anchor_pattern("hello"), "^hello$");
    }

    /// Tests that need a real model (see `test_model::load_default_model`).
    ///
    /// Error-path tests match the specific `GrammarSamplerError` variant where it
    /// is known, so an unrelated failure (e.g. a tokenizer or grammar error)
    /// cannot make them pass by accident.
    #[cfg(feature = "tests_that_use_llms")]
    mod model_tests {
        use serial_test::serial;

        use super::super::GrammarSamplerError;
        use crate::model::chat_template_result::ChatTemplateResult;
        use crate::model::grammar_trigger::{GrammarTrigger, GrammarTriggerType};
        use crate::test_model;
        use crate::token::LlamaToken;

        #[test]
        #[serial]
        fn build_grammar_sampler_returns_none_without_grammar() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult::default();
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_none());
            assert!(preserved.is_empty());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_returns_sampler_with_non_lazy_grammar() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_without_triggers_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                build_result,
                Err(GrammarSamplerError::MissingTriggers)
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_word_trigger_multi_token() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "function_call".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_pattern_trigger() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Pattern,
                    value: "\\{.*".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_token_trigger() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Token,
                    value: "tool".to_string(),
                    token: Some(LlamaToken::new(1)),
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_pattern_full_trigger() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::PatternFull,
                    value: "^tool_call$".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_with_preserved_tokens() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                preserved_tokens: vec!["hello".to_string()],
                ..Default::default()
            };
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_none());
            // "hello" is assumed to be a single token, so exactly one id is preserved.
            assert_eq!(preserved.len(), 1);
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_word_trigger_single_token_not_preserved_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "\n".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                build_result,
                Err(GrammarSamplerError::TriggerWordNotPreserved(ref word)) if word == "\n"
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_word_trigger_single_token_preserved() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                preserved_tokens: vec!["\n".to_string()],
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "\n".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
            assert_eq!(preserved.len(), 1);
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_word_trigger_with_null_byte_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "null\0byte".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                build_result,
                Err(GrammarSamplerError::TokenizationFailed(_))
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_invalid_grammar_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("this is not a valid grammar at all!!!".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Pattern,
                    value: ".*".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            // The sampler-construction error variant is defined elsewhere, so
            // only failure is asserted here.
            assert!(build_result.is_err());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_token_trigger_without_token_value() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![
                    GrammarTrigger {
                        trigger_type: GrammarTriggerType::Token,
                        value: "tool".to_string(),
                        token: None,
                    },
                    GrammarTrigger {
                        trigger_type: GrammarTriggerType::Pattern,
                        value: ".*".to_string(),
                        token: None,
                    },
                ],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_with_multi_token_preserved_tokens() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                preserved_tokens: vec!["hello world this is a long sentence".to_string()],
                ..Default::default()
            };
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_none());
            assert!(preserved.is_empty());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_preserved_token_with_null_byte_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                preserved_tokens: vec!["null\0byte".to_string()],
                grammar: Some("root ::= \"hello\"".to_string()),
                ..Default::default()
            };

            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                build_result,
                Err(GrammarSamplerError::TokenizationFailed(_))
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_invalid_grammar_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("this is not valid gbnf".to_string()),
                ..Default::default()
            };

            let build_result = result.build_grammar_sampler(&model);

            // The sampler-construction error variant is defined elsewhere, so
            // only failure is asserted here.
            assert!(build_result.is_err());
        }
    }
}