Skip to main content

llama_cpp_bindings/openai/
chat_template_result_grammar.rs

1use std::collections::HashSet;
2
3use crate::model::{AddBos, ChatTemplateResult, GrammarTriggerType, LlamaModel};
4use crate::sampling::LlamaSampler;
5use crate::token::LlamaToken;
6
7use super::grammar_sampler_error::GrammarSamplerError;
8
/// Returns a copy of `value` in which every regex metacharacter is preceded by a
/// backslash, so the result matches `value` literally when used as a pattern.
///
/// The escaped characters are `. ^ $ | ( ) * + ? [ ] { } \`; all others are copied as-is.
fn regex_escape(value: &str) -> String {
    // Characters that carry special meaning in a regular expression.
    const METACHARACTERS: &str = ".^$|()*+?[]{}\\";

    value
        .chars()
        .fold(String::with_capacity(value.len()), |mut output, symbol| {
            if METACHARACTERS.contains(symbol) {
                output.push('\\');
            }
            output.push(symbol);
            output
        })
}
24
/// Ensures `pattern` is anchored at both ends so it can only match a whole input.
///
/// - An empty pattern becomes `^$`.
/// - A leading `^` is added unless the pattern already starts with one.
/// - A trailing `$` is added unless the pattern already ends with an *unescaped* `$`.
///   A final `$` preceded by an odd number of backslashes (e.g. `cost\$`) is a literal
///   dollar sign, not an anchor, so the anchor is still appended in that case.
fn anchor_pattern(pattern: &str) -> String {
    if pattern.is_empty() {
        return "^$".to_string();
    }

    // A trailing `$` is a real anchor only when the run of backslashes right before it
    // has even length: each pair is an escaped backslash, a leftover one escapes the `$`.
    // Scanning bytes is safe because `\` is ASCII and never occurs inside a multi-byte
    // UTF-8 sequence.
    let has_end_anchor = match pattern.strip_suffix('$') {
        Some(head) => head.bytes().rev().take_while(|&byte| byte == b'\\').count() % 2 == 0,
        None => false,
    };

    // At most two anchor characters are added.
    let mut anchored = String::with_capacity(pattern.len() + 2);

    if !pattern.starts_with('^') {
        anchored.push('^');
    }

    anchored.push_str(pattern);

    if !has_end_anchor {
        anchored.push('$');
    }

    anchored
}
44
45impl ChatTemplateResult {
46    /// Builds a grammar sampler from this template result's grammar and trigger configuration.
47    ///
48    /// Returns `None` if no grammar is present. The returned `HashSet` contains preserved
49    /// token IDs that should be decoded with special token handling.
50    ///
51    /// # Errors
52    /// Returns an error if trigger processing or grammar sampler initialization fails.
53    pub fn build_grammar_sampler(
54        &self,
55        model: &LlamaModel,
56    ) -> Result<(Option<LlamaSampler>, HashSet<LlamaToken>), GrammarSamplerError> {
57        let mut preserved = HashSet::new();
58
59        for token_str in &self.preserved_tokens {
60            let tokens = model
61                .str_to_token(token_str, AddBos::Never)
62                .map_err(|error| GrammarSamplerError::TokenizationFailed(error.to_string()))?;
63
64            if tokens.len() == 1 {
65                preserved.insert(tokens[0]);
66            }
67        }
68
69        let Some(grammar) = self.grammar.as_deref() else {
70            return Ok((None, preserved));
71        };
72
73        let grammar_sampler = if self.grammar_lazy {
74            if self.grammar_triggers.is_empty() {
75                return Err(GrammarSamplerError::MissingTriggers);
76            }
77
78            let mut trigger_patterns = Vec::new();
79            let mut trigger_tokens = Vec::new();
80
81            for trigger in &self.grammar_triggers {
82                match trigger.trigger_type {
83                    GrammarTriggerType::Token => {
84                        if let Some(token) = trigger.token {
85                            trigger_tokens.push(token);
86                        }
87                    }
88                    GrammarTriggerType::Word => {
89                        let tokens =
90                            model
91                                .str_to_token(&trigger.value, AddBos::Never)
92                                .map_err(|error| {
93                                    GrammarSamplerError::TokenizationFailed(error.to_string())
94                                })?;
95
96                        if tokens.len() == 1 {
97                            if !preserved.contains(&tokens[0]) {
98                                return Err(GrammarSamplerError::TriggerWordNotPreserved(
99                                    trigger.value.clone(),
100                                ));
101                            }
102                            trigger_tokens.push(tokens[0]);
103                        } else {
104                            trigger_patterns.push(regex_escape(&trigger.value));
105                        }
106                    }
107                    GrammarTriggerType::Pattern => {
108                        trigger_patterns.push(trigger.value.clone());
109                    }
110                    GrammarTriggerType::PatternFull => {
111                        trigger_patterns.push(anchor_pattern(&trigger.value));
112                    }
113                }
114            }
115
116            LlamaSampler::grammar_lazy_patterns(
117                model,
118                grammar,
119                "root",
120                &trigger_patterns,
121                &trigger_tokens,
122            )?
123        } else {
124            LlamaSampler::grammar(model, grammar, "root")?
125        };
126
127        Ok((Some(grammar_sampler), preserved))
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::{anchor_pattern, regex_escape};

    /// Each regex metacharacter on its own must come back prefixed with one backslash.
    #[test]
    fn regex_escape_special_characters() {
        let cases = [
            (".", r"\."),
            ("^", r"\^"),
            ("$", r"\$"),
            ("|", r"\|"),
            ("(", r"\("),
            (")", r"\)"),
            ("*", r"\*"),
            ("+", r"\+"),
            ("?", r"\?"),
            ("[", r"\["),
            ("]", r"\]"),
            ("{", r"\{"),
            ("}", r"\}"),
            (r"\", r"\\"),
        ];

        for (input, expected) in cases {
            assert_eq!(regex_escape(input), expected);
        }
    }

    /// Text without metacharacters passes through untouched.
    #[test]
    fn regex_escape_normal_text() {
        let escaped = regex_escape("hello world");
        assert_eq!(escaped, "hello world");
    }

    /// The empty string stays empty.
    #[test]
    fn regex_escape_empty_string() {
        let escaped = regex_escape("");
        assert_eq!(escaped, "");
    }

    /// Only the metacharacters inside ordinary text are escaped.
    #[test]
    fn regex_escape_mixed_text() {
        let escaped = regex_escape("price: $5.00");
        assert_eq!(escaped, r"price: \$5\.00");
    }

    /// An empty pattern anchors to the empty-input pattern.
    #[test]
    fn anchor_pattern_empty_string() {
        let anchored = anchor_pattern("");
        assert_eq!(anchored, "^$");
    }

    /// A pattern with both anchors is returned unchanged.
    #[test]
    fn anchor_pattern_already_anchored() {
        let anchored = anchor_pattern("^hello$");
        assert_eq!(anchored, "^hello$");
    }

    /// Only the missing start anchor is added.
    #[test]
    fn anchor_pattern_needs_start_anchor() {
        let anchored = anchor_pattern("hello$");
        assert_eq!(anchored, "^hello$");
    }

    /// Only the missing end anchor is added.
    #[test]
    fn anchor_pattern_needs_end_anchor() {
        let anchored = anchor_pattern("^hello");
        assert_eq!(anchored, "^hello$");
    }

    /// Both anchors are added to a bare pattern.
    #[test]
    fn anchor_pattern_needs_both_anchors() {
        let anchored = anchor_pattern("hello");
        assert_eq!(anchored, "^hello$");
    }
}