//! Chat-template result handling for the llama.cpp FFI bindings
//! (`llama_cpp_bindings/model/chat_template_result.rs`).
1use std::ffi::{CStr, CString, c_char};
2use std::ptr::{self, NonNull};
3use std::slice;
4
5use crate::model::grammar_trigger::{GrammarTrigger, GrammarTriggerType};
6use crate::openai::ChatParseStateOaicompat;
7use crate::token::LlamaToken;
8use crate::{ApplyChatTemplateError, ChatParseError, status_is_ok, status_to_i32};
9
10/// Result of applying a chat template with tool grammar support.
/// Result of applying a chat template with tool grammar support.
///
/// Built by [`parse_chat_template_raw_result`]; every field is an owned copy
/// of the FFI data, so the value has no lifetime ties to C-side memory.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatTemplateResult {
    /// Rendered chat prompt.
    pub prompt: String,
    /// Optional grammar generated from tool definitions.
    pub grammar: Option<String>,
    /// Whether to use lazy grammar sampling.
    pub grammar_lazy: bool,
    /// Lazy grammar triggers derived from the template.
    pub grammar_triggers: Vec<GrammarTrigger>,
    /// Tokens that should be preserved for sampling.
    pub preserved_tokens: Vec<String>,
    /// Additional stop sequences added by the template.
    pub additional_stops: Vec<String>,
    /// Chat format used for parsing responses (raw C-side discriminant).
    pub chat_format: i32,
    /// Optional serialized PEG parser for tool-call parsing.
    pub parser: Option<String>,
    /// Whether the parser expects a forced-open thinking block.
    pub thinking_forced_open: bool,
    /// Whether tool calls should be parsed from the response.
    pub parse_tool_calls: bool,
}
34
35pub fn new_empty_chat_template_raw_result() -> llama_cpp_bindings_sys::llama_rs_chat_template_result
36{
37    llama_cpp_bindings_sys::llama_rs_chat_template_result {
38        prompt: ptr::null_mut(),
39        grammar: ptr::null_mut(),
40        parser: ptr::null_mut(),
41        chat_format: 0,
42        thinking_forced_open: false,
43        grammar_lazy: false,
44        grammar_triggers: ptr::null_mut(),
45        grammar_triggers_count: 0,
46        preserved_tokens: ptr::null_mut(),
47        preserved_tokens_count: 0,
48        additional_stops: ptr::null_mut(),
49        additional_stops_count: 0,
50    }
51}
52
53/// # Safety
54///
55/// `raw_cstr_array` must point to `count` valid, null-terminated C strings.
56unsafe fn parse_raw_cstr_array(
57    raw_cstr_array: *const *mut c_char,
58    count: usize,
59) -> Result<Vec<String>, ApplyChatTemplateError> {
60    if count == 0 {
61        return Ok(Vec::new());
62    }
63
64    if raw_cstr_array.is_null() {
65        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
66    }
67
68    let raw_entries = unsafe { slice::from_raw_parts(raw_cstr_array, count) };
69    let mut parsed = Vec::with_capacity(raw_entries.len());
70
71    for entry in raw_entries {
72        if entry.is_null() {
73            return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
74        }
75        let bytes = unsafe { CStr::from_ptr(*entry) }.to_bytes().to_vec();
76        parsed.push(String::from_utf8(bytes)?);
77    }
78
79    Ok(parsed)
80}
81
82/// # Safety
83///
84/// `raw_triggers` must point to `count` valid `llama_rs_grammar_trigger` structs.
85unsafe fn parse_raw_grammar_triggers(
86    raw_triggers: *const llama_cpp_bindings_sys::llama_rs_grammar_trigger,
87    count: usize,
88) -> Result<Vec<GrammarTrigger>, ApplyChatTemplateError> {
89    if count == 0 {
90        return Ok(Vec::new());
91    }
92
93    if raw_triggers.is_null() {
94        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
95    }
96
97    let triggers = unsafe { slice::from_raw_parts(raw_triggers, count) };
98    let mut parsed = Vec::with_capacity(triggers.len());
99
100    for trigger in triggers {
101        let trigger_type = match trigger.type_ {
102            0 => GrammarTriggerType::Token,
103            1 => GrammarTriggerType::Word,
104            2 => GrammarTriggerType::Pattern,
105            3 => GrammarTriggerType::PatternFull,
106            _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
107        };
108        let value = if trigger.value.is_null() {
109            return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
110        } else {
111            let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
112            String::from_utf8(bytes)?
113        };
114        let token = if trigger_type == GrammarTriggerType::Token {
115            Some(LlamaToken(trigger.token))
116        } else {
117            None
118        };
119        parsed.push(GrammarTrigger {
120            trigger_type,
121            value,
122            token,
123        });
124    }
125
126    Ok(parsed)
127}
128
129/// # Safety
130///
131/// `raw_result` must point to a valid, initialized `llama_rs_chat_template_result`.
132pub unsafe fn parse_chat_template_raw_result(
133    ffi_return_code: llama_cpp_bindings_sys::llama_rs_status,
134    raw_result: *mut llama_cpp_bindings_sys::llama_rs_chat_template_result,
135    parse_tool_calls: bool,
136) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
137    let result = (|| {
138        if !status_is_ok(ffi_return_code) {
139            return Err(ApplyChatTemplateError::FfiError(status_to_i32(
140                ffi_return_code,
141            )));
142        }
143
144        let raw = unsafe { &*raw_result };
145
146        if raw.prompt.is_null() {
147            return Err(ApplyChatTemplateError::NullResult);
148        }
149
150        let prompt_bytes = unsafe { CStr::from_ptr(raw.prompt) }.to_bytes().to_vec();
151        let prompt = String::from_utf8(prompt_bytes)?;
152
153        let grammar = if raw.grammar.is_null() {
154            None
155        } else {
156            let grammar_bytes = unsafe { CStr::from_ptr(raw.grammar) }.to_bytes().to_vec();
157            Some(String::from_utf8(grammar_bytes)?)
158        };
159
160        let parser = if raw.parser.is_null() {
161            None
162        } else {
163            let parser_bytes = unsafe { CStr::from_ptr(raw.parser) }.to_bytes().to_vec();
164            Some(String::from_utf8(parser_bytes)?)
165        };
166
167        let grammar_triggers = unsafe {
168            parse_raw_grammar_triggers(raw.grammar_triggers, raw.grammar_triggers_count)
169        }?;
170
171        let preserved_tokens =
172            unsafe { parse_raw_cstr_array(raw.preserved_tokens, raw.preserved_tokens_count) }?;
173
174        let additional_stops =
175            unsafe { parse_raw_cstr_array(raw.additional_stops, raw.additional_stops_count) }?;
176
177        Ok(ChatTemplateResult {
178            prompt,
179            grammar,
180            grammar_lazy: raw.grammar_lazy,
181            grammar_triggers,
182            preserved_tokens,
183            additional_stops,
184            chat_format: raw.chat_format,
185            parser,
186            thinking_forced_open: raw.thinking_forced_open,
187            parse_tool_calls,
188        })
189    })();
190
191    unsafe { llama_cpp_bindings_sys::llama_rs_chat_template_result_free(raw_result) };
192
193    result
194}
195
196impl ChatTemplateResult {
197    /// Parse a generated response into an OpenAI-compatible message JSON string.
198    ///
199    /// # Errors
200    /// Returns an error if the FFI call fails or the result is null.
201    pub fn parse_response_oaicompat(
202        &self,
203        text: &str,
204        is_partial: bool,
205    ) -> Result<String, ChatParseError> {
206        let text_cstr = CString::new(text)?;
207        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
208        let mut out_json: *mut c_char = ptr::null_mut();
209        let rc = unsafe {
210            llama_cpp_bindings_sys::llama_rs_chat_parse_to_oaicompat(
211                text_cstr.as_ptr(),
212                is_partial,
213                self.chat_format,
214                self.parse_tool_calls,
215                parser_cstr
216                    .as_ref()
217                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
218                self.thinking_forced_open,
219                &raw mut out_json,
220            )
221        };
222
223        let result = (|| {
224            if !status_is_ok(rc) {
225                return Err(ChatParseError::FfiError(status_to_i32(rc)));
226            }
227            if out_json.is_null() {
228                return Err(ChatParseError::NullResult);
229            }
230            let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
231            Ok(String::from_utf8(bytes)?)
232        })();
233
234        unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_json) };
235
236        result
237    }
238
239    /// Initialize a streaming parser for OpenAI-compatible chat deltas.
240    ///
241    /// # Errors
242    /// Returns an error if the parser state cannot be initialized.
243    pub fn streaming_state_oaicompat(&self) -> Result<ChatParseStateOaicompat, ChatParseError> {
244        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
245        let state = unsafe {
246            llama_cpp_bindings_sys::llama_rs_chat_parse_state_init_oaicompat(
247                self.chat_format,
248                self.parse_tool_calls,
249                parser_cstr
250                    .as_ref()
251                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
252                self.thinking_forced_open,
253            )
254        };
255        let state = NonNull::new(state).ok_or(ChatParseError::NullResult)?;
256
257        Ok(ChatParseStateOaicompat { state })
258    }
259}