Skip to main content

llama_cpp_bindings/openai/
chat_template_result_grammar.rs

1use std::collections::HashSet;
2
3use crate::model::{AddBos, ChatTemplateResult, GrammarTriggerType, LlamaModel};
4use crate::sampling::LlamaSampler;
5use crate::token::LlamaToken;
6
7use super::grammar_sampler_error::GrammarSamplerError;
8
/// Returns `value` with every regex metacharacter backslash-escaped, so that the result matches
/// `value` literally when used as a regular expression.
///
/// The escaped characters are `. ^ $ | ( ) * + ? [ ] { } \`; all other characters (including
/// non-ASCII ones) are copied through unchanged.
fn regex_escape(value: &str) -> String {
    const METACHARACTERS: [char; 14] = [
        '.', '^', '$', '|', '(', ')', '*', '+', '?', '[', ']', '{', '}', '\\',
    ];

    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut out, ch| {
            if METACHARACTERS.contains(&ch) {
                out.push('\\');
            }
            out.push(ch);
            out
        })
}
24
/// Anchors a regex so that it must match the entire input.
///
/// A `^` is prepended unless the pattern already starts with one, and a `$` is appended unless
/// the pattern already ends with an *unescaped* `$`. A trailing `\$` matches a literal dollar
/// sign rather than the end of input, so such a pattern still receives a closing `$`.
/// An empty pattern becomes `^$`.
fn anchor_pattern(pattern: &str) -> String {
    if pattern.is_empty() {
        return "^$".to_string();
    }

    let needs_start_anchor = !pattern.starts_with('^');
    let needs_end_anchor = !ends_with_end_anchor(pattern);

    let mut anchored = String::with_capacity(pattern.len() + 2);

    if needs_start_anchor {
        anchored.push('^');
    }

    anchored.push_str(pattern);

    if needs_end_anchor {
        anchored.push('$');
    }

    anchored
}

/// Returns `true` when `pattern` ends in a `$` that acts as an end-of-input anchor, i.e. one
/// preceded by an even number (including zero) of consecutive backslashes. An odd number means
/// the `$` itself is escaped and is therefore a literal character.
fn ends_with_end_anchor(pattern: &str) -> bool {
    let Some(body) = pattern.strip_suffix('$') else {
        return false;
    };

    let preceding_backslashes = body.bytes().rev().take_while(|&byte| byte == b'\\').count();

    preceding_backslashes % 2 == 0
}
44
45impl ChatTemplateResult {
46    /// Builds a grammar sampler from this template result's grammar and trigger configuration.
47    ///
48    /// Returns `None` if no grammar is present. The returned `HashSet` contains preserved
49    /// token IDs that should be decoded with special token handling.
50    ///
51    /// # Errors
52    /// Returns an error if trigger processing or grammar sampler initialization fails.
53    pub fn build_grammar_sampler(
54        &self,
55        model: &LlamaModel,
56    ) -> Result<(Option<LlamaSampler>, HashSet<LlamaToken>), GrammarSamplerError> {
57        let mut preserved = HashSet::new();
58
59        for token_str in &self.preserved_tokens {
60            let tokens = model
61                .str_to_token(token_str, AddBos::Never)
62                .map_err(|error| GrammarSamplerError::TokenizationFailed(error.to_string()))?;
63
64            if tokens.len() == 1 {
65                preserved.insert(tokens[0]);
66            }
67        }
68
69        let Some(grammar) = self.grammar.as_deref() else {
70            return Ok((None, preserved));
71        };
72
73        let grammar_sampler = if self.grammar_lazy {
74            if self.grammar_triggers.is_empty() {
75                return Err(GrammarSamplerError::MissingTriggers);
76            }
77
78            let mut trigger_patterns = Vec::new();
79            let mut trigger_tokens = Vec::new();
80
81            for trigger in &self.grammar_triggers {
82                match trigger.trigger_type {
83                    GrammarTriggerType::Token => {
84                        if let Some(token) = trigger.token {
85                            trigger_tokens.push(token);
86                        }
87                    }
88                    GrammarTriggerType::Word => {
89                        let tokens =
90                            model
91                                .str_to_token(&trigger.value, AddBos::Never)
92                                .map_err(|error| {
93                                    GrammarSamplerError::TokenizationFailed(error.to_string())
94                                })?;
95
96                        if tokens.len() == 1 {
97                            if !preserved.contains(&tokens[0]) {
98                                return Err(GrammarSamplerError::TriggerWordNotPreserved(
99                                    trigger.value.clone(),
100                                ));
101                            }
102                            trigger_tokens.push(tokens[0]);
103                        } else {
104                            trigger_patterns.push(regex_escape(&trigger.value));
105                        }
106                    }
107                    GrammarTriggerType::Pattern => {
108                        trigger_patterns.push(trigger.value.clone());
109                    }
110                    GrammarTriggerType::PatternFull => {
111                        trigger_patterns.push(anchor_pattern(&trigger.value));
112                    }
113                }
114            }
115
116            LlamaSampler::grammar_lazy_patterns(
117                model,
118                grammar,
119                "root",
120                &trigger_patterns,
121                &trigger_tokens,
122            )?
123        } else {
124            LlamaSampler::grammar(model, grammar, "root")?
125        };
126
127        Ok((Some(grammar_sampler), preserved))
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::{anchor_pattern, regex_escape};

    #[test]
    fn regex_escape_special_characters() {
        assert_eq!(regex_escape("."), "\\.");
        assert_eq!(regex_escape("^"), "\\^");
        assert_eq!(regex_escape("$"), "\\$");
        assert_eq!(regex_escape("|"), "\\|");
        assert_eq!(regex_escape("("), "\\(");
        assert_eq!(regex_escape(")"), "\\)");
        assert_eq!(regex_escape("*"), "\\*");
        assert_eq!(regex_escape("+"), "\\+");
        assert_eq!(regex_escape("?"), "\\?");
        assert_eq!(regex_escape("["), "\\[");
        assert_eq!(regex_escape("]"), "\\]");
        assert_eq!(regex_escape("{"), "\\{");
        assert_eq!(regex_escape("}"), "\\}");
        assert_eq!(regex_escape("\\"), "\\\\");
    }

    #[test]
    fn regex_escape_normal_text() {
        assert_eq!(regex_escape("hello world"), "hello world");
    }

    #[test]
    fn regex_escape_empty_string() {
        assert_eq!(regex_escape(""), "");
    }

    #[test]
    fn regex_escape_mixed_text() {
        assert_eq!(regex_escape("price: $5.00"), "price: \\$5\\.00");
    }

    #[test]
    fn regex_escape_leaves_other_punctuation_and_unicode_untouched() {
        assert_eq!(regex_escape("a-b/c,d:e=f"), "a-b/c,d:e=f");
        assert_eq!(regex_escape("héllo.世界"), "héllo\\.世界");
    }

    #[test]
    fn anchor_pattern_empty_string() {
        assert_eq!(anchor_pattern(""), "^$");
    }

    #[test]
    fn anchor_pattern_already_anchored() {
        assert_eq!(anchor_pattern("^hello$"), "^hello$");
    }

    #[test]
    fn anchor_pattern_needs_start_anchor() {
        assert_eq!(anchor_pattern("hello$"), "^hello$");
    }

    #[test]
    fn anchor_pattern_needs_end_anchor() {
        assert_eq!(anchor_pattern("^hello"), "^hello$");
    }

    #[test]
    fn anchor_pattern_needs_both_anchors() {
        assert_eq!(anchor_pattern("hello"), "^hello$");
    }

    #[test]
    fn anchor_pattern_lone_anchor_characters() {
        assert_eq!(anchor_pattern("^"), "^$");
        assert_eq!(anchor_pattern("$"), "^$");
    }

    /// Tests that load a real model; only compiled with the `tests_that_use_llms` feature.
    #[cfg(feature = "tests_that_use_llms")]
    mod model_tests {
        use serial_test::serial;

        use crate::model::chat_template_result::ChatTemplateResult;
        use crate::model::grammar_trigger::{GrammarTrigger, GrammarTriggerType};
        use crate::test_model;
        use crate::token::LlamaToken;

        use super::super::GrammarSamplerError;

        #[test]
        #[serial]
        fn build_grammar_sampler_returns_none_without_grammar() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult::default();
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_none());
            assert!(preserved.is_empty());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_returns_sampler_with_non_lazy_grammar() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_without_triggers_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                build_result,
                Err(GrammarSamplerError::MissingTriggers)
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_word_trigger_multi_token() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "function_call".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_pattern_trigger() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Pattern,
                    value: "\\{.*".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_token_trigger() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Token,
                    value: "tool".to_string(),
                    token: Some(LlamaToken::new(1)),
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_pattern_full_trigger() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::PatternFull,
                    value: "^tool_call$".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_mixed_triggers() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![
                    GrammarTrigger {
                        trigger_type: GrammarTriggerType::Word,
                        value: "function_call".to_string(),
                        token: None,
                    },
                    GrammarTrigger {
                        trigger_type: GrammarTriggerType::Pattern,
                        value: "\\{.*".to_string(),
                        token: None,
                    },
                    GrammarTrigger {
                        trigger_type: GrammarTriggerType::PatternFull,
                        value: "tool_call".to_string(),
                        token: None,
                    },
                ],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_with_preserved_tokens() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                preserved_tokens: vec!["hello".to_string()],
                ..Default::default()
            };
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_none());
            assert!(!preserved.is_empty());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_word_trigger_single_token_not_preserved_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "\n".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                &build_result,
                Err(GrammarSamplerError::TriggerWordNotPreserved(word)) if word == "\n"
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_word_trigger_single_token_preserved() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                preserved_tokens: vec!["\n".to_string()],
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "\n".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
            assert!(!preserved.is_empty());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_word_trigger_with_null_byte_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Word,
                    value: "null\0byte".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                build_result,
                Err(GrammarSamplerError::TokenizationFailed(_))
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_invalid_grammar_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("this is not a valid grammar at all!!!".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![GrammarTrigger {
                    trigger_type: GrammarTriggerType::Pattern,
                    value: ".*".to_string(),
                    token: None,
                }],
                ..Default::default()
            };
            let build_result = result.build_grammar_sampler(&model);

            assert!(build_result.is_err());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_lazy_with_token_trigger_without_token_value() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("root ::= \"hello\"".to_string()),
                grammar_lazy: true,
                grammar_triggers: vec![
                    GrammarTrigger {
                        trigger_type: GrammarTriggerType::Token,
                        value: "tool".to_string(),
                        token: None,
                    },
                    GrammarTrigger {
                        trigger_type: GrammarTriggerType::Pattern,
                        value: ".*".to_string(),
                        token: None,
                    },
                ],
                ..Default::default()
            };
            let (sampler, _preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_some());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_with_multi_token_preserved_tokens() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                preserved_tokens: vec!["hello world this is a long sentence".to_string()],
                ..Default::default()
            };
            let (sampler, preserved) = result.build_grammar_sampler(&model).unwrap();

            assert!(sampler.is_none());
            // Multi-token strings are skipped, so the preserved set stays empty.
            assert!(preserved.is_empty());
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_preserved_token_with_null_byte_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                preserved_tokens: vec!["null\0byte".to_string()],
                grammar: Some("root ::= \"hello\"".to_string()),
                ..ChatTemplateResult::default()
            };

            let build_result = result.build_grammar_sampler(&model);

            assert!(matches!(
                build_result,
                Err(GrammarSamplerError::TokenizationFailed(_))
            ));
        }

        #[test]
        #[serial]
        fn build_grammar_sampler_invalid_grammar_returns_error() {
            let (_backend, model) = test_model::load_default_model().unwrap();
            let result = ChatTemplateResult {
                grammar: Some("this is not valid gbnf".to_string()),
                ..ChatTemplateResult::default()
            };

            let build_result = result.build_grammar_sampler(&model);

            assert!(build_result.is_err());
        }
    }
}