// llama_cpp_bindings/model/chat_template_result.rs
1use std::ffi::{CStr, CString, c_char};
2use std::ptr::{self, NonNull};
3use std::slice;
4
5use crate::model::grammar_trigger::{GrammarTrigger, GrammarTriggerType};
6use crate::openai::ChatParseStateOaicompat;
7use crate::token::LlamaToken;
8use crate::{ApplyChatTemplateError, ChatParseError, status_is_ok, status_to_i32};
9
10const fn check_chat_parse_status(
11    rc: llama_cpp_bindings_sys::llama_rs_status,
12) -> Result<(), ChatParseError> {
13    if !status_is_ok(rc) {
14        return Err(ChatParseError::FfiError(status_to_i32(rc)));
15    }
16
17    Ok(())
18}
19
20const fn check_chat_parse_not_null(json_ptr: *const c_char) -> Result<(), ChatParseError> {
21    if json_ptr.is_null() {
22        return Err(ChatParseError::NullResult);
23    }
24
25    Ok(())
26}
27
/// Result of applying a chat template with tool grammar support.
///
/// Owns fully Rust-native copies of everything the FFI layer produced; no raw
/// pointers are retained, so the value is freely `Clone`-able and `Send`-safe
/// as far as this struct is concerned.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ChatTemplateResult {
    /// Rendered chat prompt.
    pub prompt: String,
    /// Optional grammar generated from tool definitions.
    pub grammar: Option<String>,
    /// Whether to use lazy grammar sampling.
    pub grammar_lazy: bool,
    /// Lazy grammar triggers derived from the template.
    pub grammar_triggers: Vec<GrammarTrigger>,
    /// Tokens that should be preserved for sampling.
    pub preserved_tokens: Vec<String>,
    /// Additional stop sequences added by the template.
    pub additional_stops: Vec<String>,
    /// Chat format used for parsing responses.
    pub chat_format: i32,
    /// Optional serialized PEG parser for tool-call parsing.
    pub parser: Option<String>,
    /// Whether the model supports thinking/reasoning blocks.
    pub supports_thinking: bool,
    /// Whether tool calls should be parsed from the response.
    pub parse_tool_calls: bool,
}
52
53#[must_use]
54pub const fn new_empty_chat_template_raw_result()
55-> llama_cpp_bindings_sys::llama_rs_chat_template_result {
56    llama_cpp_bindings_sys::llama_rs_chat_template_result {
57        prompt: ptr::null_mut(),
58        grammar: ptr::null_mut(),
59        parser: ptr::null_mut(),
60        chat_format: 0,
61        supports_thinking: false,
62        grammar_lazy: false,
63        grammar_triggers: ptr::null_mut(),
64        grammar_triggers_count: 0,
65        preserved_tokens: ptr::null_mut(),
66        preserved_tokens_count: 0,
67        additional_stops: ptr::null_mut(),
68        additional_stops_count: 0,
69    }
70}
71
72/// # Safety
73///
74/// `raw_cstr_array` must point to `count` valid, null-terminated C strings.
75unsafe fn parse_raw_cstr_array(
76    raw_cstr_array: *const *mut c_char,
77    count: usize,
78) -> Result<Vec<String>, ApplyChatTemplateError> {
79    if count == 0 {
80        return Ok(Vec::new());
81    }
82
83    if raw_cstr_array.is_null() {
84        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
85    }
86
87    let raw_entries = unsafe { slice::from_raw_parts(raw_cstr_array, count) };
88    let mut parsed = Vec::with_capacity(raw_entries.len());
89
90    for entry in raw_entries {
91        if entry.is_null() {
92            return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
93        }
94        let bytes = unsafe { CStr::from_ptr(*entry) }.to_bytes().to_vec();
95        parsed.push(String::from_utf8(bytes)?);
96    }
97
98    Ok(parsed)
99}
100
101/// # Safety
102///
103/// `raw_triggers` must point to `count` valid `llama_rs_grammar_trigger` structs.
104unsafe fn parse_raw_grammar_triggers(
105    raw_triggers: *const llama_cpp_bindings_sys::llama_rs_grammar_trigger,
106    count: usize,
107) -> Result<Vec<GrammarTrigger>, ApplyChatTemplateError> {
108    if count == 0 {
109        return Ok(Vec::new());
110    }
111
112    if raw_triggers.is_null() {
113        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
114    }
115
116    let triggers = unsafe { slice::from_raw_parts(raw_triggers, count) };
117    let mut parsed = Vec::with_capacity(triggers.len());
118
119    for trigger in triggers {
120        let trigger_type = match trigger.type_ {
121            0 => GrammarTriggerType::Token,
122            1 => GrammarTriggerType::Word,
123            2 => GrammarTriggerType::Pattern,
124            3 => GrammarTriggerType::PatternFull,
125            _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
126        };
127        let value = if trigger.value.is_null() {
128            return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
129        } else {
130            let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
131            String::from_utf8(bytes)?
132        };
133        let token = if trigger_type == GrammarTriggerType::Token {
134            Some(LlamaToken(trigger.token))
135        } else {
136            None
137        };
138        parsed.push(GrammarTrigger {
139            trigger_type,
140            value,
141            token,
142        });
143    }
144
145    Ok(parsed)
146}
147
/// Parse (and then free) a raw FFI chat-template result into an owned
/// `ChatTemplateResult`.
///
/// # Safety
///
/// `raw_result` must point to a valid, initialized `llama_rs_chat_template_result`.
///
/// # Errors
/// Returns `ApplyChatTemplateError` if the FFI call failed or the result could not be parsed.
pub unsafe fn parse_chat_template_raw_result(
    ffi_return_code: llama_cpp_bindings_sys::llama_rs_status,
    raw_result: *mut llama_cpp_bindings_sys::llama_rs_chat_template_result,
    parse_tool_calls: bool,
) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
    // All parsing runs inside a closure so that every early `return Err(..)`
    // still falls through to the `llama_rs_chat_template_result_free` call
    // below — the FFI-owned buffers are released on success and failure alike.
    let result = (|| {
        if !status_is_ok(ffi_return_code) {
            return Err(ApplyChatTemplateError::FfiError(status_to_i32(
                ffi_return_code,
            )));
        }

        let raw = unsafe { &*raw_result };

        // A successful status with no prompt is treated as a null result.
        if raw.prompt.is_null() {
            return Err(ApplyChatTemplateError::NullResult);
        }

        // Every string is copied out of the FFI buffers before the free below.
        let prompt_bytes = unsafe { CStr::from_ptr(raw.prompt) }.to_bytes().to_vec();
        let prompt = String::from_utf8(prompt_bytes)?;

        let grammar = if raw.grammar.is_null() {
            None
        } else {
            let grammar_bytes = unsafe { CStr::from_ptr(raw.grammar) }.to_bytes().to_vec();
            Some(String::from_utf8(grammar_bytes)?)
        };

        let parser = if raw.parser.is_null() {
            None
        } else {
            let parser_bytes = unsafe { CStr::from_ptr(raw.parser) }.to_bytes().to_vec();
            Some(String::from_utf8(parser_bytes)?)
        };

        let grammar_triggers = unsafe {
            parse_raw_grammar_triggers(raw.grammar_triggers, raw.grammar_triggers_count)
        }?;

        let preserved_tokens =
            unsafe { parse_raw_cstr_array(raw.preserved_tokens, raw.preserved_tokens_count) }?;

        let additional_stops =
            unsafe { parse_raw_cstr_array(raw.additional_stops, raw.additional_stops_count) }?;

        Ok(ChatTemplateResult {
            prompt,
            grammar,
            grammar_lazy: raw.grammar_lazy,
            grammar_triggers,
            preserved_tokens,
            additional_stops,
            chat_format: raw.chat_format,
            parser,
            supports_thinking: raw.supports_thinking,
            parse_tool_calls,
        })
    })();

    // Unconditional free: the parsed result owns independent copies of all data.
    unsafe { llama_cpp_bindings_sys::llama_rs_chat_template_result_free(raw_result) };

    result
}
217
218impl ChatTemplateResult {
219    /// Parse a generated response into an OpenAI-compatible message JSON string.
220    ///
221    /// # Errors
222    /// Returns an error if the FFI call fails or the result is null.
223    pub fn parse_response_oaicompat(
224        &self,
225        text: &str,
226        is_partial: bool,
227    ) -> Result<String, ChatParseError> {
228        let text_cstr = CString::new(text)?;
229        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
230        let mut out_json: *mut c_char = ptr::null_mut();
231        let rc = unsafe {
232            llama_cpp_bindings_sys::llama_rs_chat_parse_to_oaicompat(
233                text_cstr.as_ptr(),
234                is_partial,
235                self.chat_format,
236                self.parse_tool_calls,
237                parser_cstr
238                    .as_ref()
239                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
240                &raw mut out_json,
241            )
242        };
243
244        let result = (|| {
245            check_chat_parse_status(rc)?;
246            check_chat_parse_not_null(out_json)?;
247            let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
248            Ok(String::from_utf8(bytes)?)
249        })();
250
251        unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_json) };
252
253        result
254    }
255
256    /// Initialize a streaming parser for OpenAI-compatible chat deltas.
257    ///
258    /// # Errors
259    /// Returns an error if the parser state cannot be initialized.
260    pub fn streaming_state_oaicompat(&self) -> Result<ChatParseStateOaicompat, ChatParseError> {
261        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
262        let state = unsafe {
263            llama_cpp_bindings_sys::llama_rs_chat_parse_state_init_oaicompat(
264                self.chat_format,
265                self.parse_tool_calls,
266                parser_cstr
267                    .as_ref()
268                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
269            )
270        };
271        let state = NonNull::new(state).ok_or(ChatParseError::NullResult)?;
272
273        Ok(ChatParseStateOaicompat { state })
274    }
275}
276
#[cfg(test)]
mod tests {
    //! Unit tests for chat-template result parsing. The FFI-facing tests
    //! hand-build raw structs with heap-allocated C strings and reclaim them
    //! with `CString::from_raw` to keep miri/leak checkers happy.

    use std::ffi::{CString, c_char};
    use std::ptr;

    use super::{
        ChatTemplateResult, new_empty_chat_template_raw_result, parse_chat_template_raw_result,
        parse_raw_cstr_array, parse_raw_grammar_triggers,
    };
    use crate::model::grammar_trigger::GrammarTriggerType;
    use crate::token::LlamaToken;

    /// Leak a heap `CString` as a raw pointer; callers must reclaim it with
    /// `CString::from_raw` (or hand ownership to an FFI free function).
    fn heap_cstring(value: &str) -> *mut c_char {
        CString::new(value).unwrap().into_raw()
    }

    // --- parse_raw_cstr_array ---

    #[test]
    fn parse_cstr_array_zero_count_returns_empty() {
        let result = unsafe { parse_raw_cstr_array(ptr::null(), 0) };
        assert_eq!(result.unwrap(), Vec::<String>::new());
    }

    #[test]
    fn parse_cstr_array_null_with_nonzero_count_returns_error() {
        let result = unsafe { parse_raw_cstr_array(ptr::null(), 1) };
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid grammar trigger data")
        );
    }

    #[test]
    fn parse_cstr_array_valid_single_string() {
        let raw_string = heap_cstring("hello");
        let array = [raw_string];
        let result = unsafe { parse_raw_cstr_array(array.as_ptr(), 1) };
        assert_eq!(result.unwrap(), vec!["hello".to_string()]);
        // Reclaim the leaked CString after the parser has copied it.
        unsafe { drop(CString::from_raw(array[0])) };
    }

    #[test]
    fn parse_cstr_array_null_entry_returns_error() {
        let raw_string = heap_cstring("valid");
        let array: [*mut c_char; 2] = [raw_string, ptr::null_mut()];
        let result = unsafe { parse_raw_cstr_array(array.as_ptr(), 2) };
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid grammar trigger data")
        );
        unsafe { drop(CString::from_raw(array[0])) };
    }

    // --- parse_raw_grammar_triggers ---

    #[test]
    fn parse_triggers_zero_count_returns_empty() {
        let result = unsafe { parse_raw_grammar_triggers(ptr::null(), 0) };
        assert_eq!(result.unwrap(), Vec::new());
    }

    #[test]
    fn parse_triggers_null_with_nonzero_count_returns_error() {
        let result = unsafe { parse_raw_grammar_triggers(ptr::null(), 1) };
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid grammar trigger data")
        );
    }

    #[test]
    fn parse_triggers_token_type_has_token() {
        let value_ptr = heap_cstring("<tool>");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 0,
            value: value_ptr,
            token: 42,
        };
        let result = unsafe { parse_raw_grammar_triggers(&raw const trigger, 1) };
        let parsed = result.unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].trigger_type, GrammarTriggerType::Token);
        assert_eq!(parsed[0].value, "<tool>");
        assert_eq!(parsed[0].token, Some(LlamaToken(42)));
        unsafe { drop(CString::from_raw(value_ptr)) };
    }

    #[test]
    fn parse_triggers_word_type_has_no_token() {
        let value_ptr = heap_cstring("function");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 1,
            value: value_ptr,
            // A token id is supplied but must be ignored for non-token types.
            token: 99,
        };
        let result = unsafe { parse_raw_grammar_triggers(&raw const trigger, 1) };
        let parsed = result.unwrap();
        assert_eq!(parsed[0].trigger_type, GrammarTriggerType::Word);
        assert_eq!(parsed[0].token, None);
        unsafe { drop(CString::from_raw(value_ptr)) };
    }

    #[test]
    fn parse_triggers_pattern_type() {
        let value_ptr = heap_cstring("\\{.*\\}");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 2,
            value: value_ptr,
            token: 0,
        };
        let result = unsafe { parse_raw_grammar_triggers(&raw const trigger, 1) };
        assert_eq!(result.unwrap()[0].trigger_type, GrammarTriggerType::Pattern);
        unsafe { drop(CString::from_raw(value_ptr)) };
    }

    #[test]
    fn parse_triggers_pattern_full_type() {
        let value_ptr = heap_cstring("^tool$");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 3,
            value: value_ptr,
            token: 0,
        };
        let result = unsafe { parse_raw_grammar_triggers(&raw const trigger, 1) };
        assert_eq!(
            result.unwrap()[0].trigger_type,
            GrammarTriggerType::PatternFull
        );
        unsafe { drop(CString::from_raw(value_ptr)) };
    }

    #[test]
    fn parse_triggers_invalid_type_returns_error() {
        let value_ptr = heap_cstring("x");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            // 4 is one past the last valid trigger-type tag.
            type_: 4,
            value: value_ptr,
            token: 0,
        };
        let result = unsafe { parse_raw_grammar_triggers(&raw const trigger, 1) };
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid grammar trigger data")
        );
        unsafe { drop(CString::from_raw(value_ptr)) };
    }

    #[test]
    fn parse_triggers_null_value_returns_error() {
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 1,
            value: ptr::null_mut(),
            token: 0,
        };
        let result = unsafe { parse_raw_grammar_triggers(&raw const trigger, 1) };
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid grammar trigger data")
        );
    }

    // --- parse_chat_template_raw_result ---
    // NOTE: parse_chat_template_raw_result frees the raw result (including
    // heap_cstring allocations handed to it), so these tests do not reclaim
    // those pointers themselves.

    #[test]
    fn parse_raw_result_error_status_returns_ffi_error() {
        let mut raw = new_empty_chat_template_raw_result();
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
                &raw mut raw,
                false,
            )
        };
        assert!(result.unwrap_err().to_string().contains("ffi error -1"));
    }

    #[test]
    fn parse_raw_result_null_prompt_returns_null_result() {
        let mut raw = new_empty_chat_template_raw_result();
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
                &raw mut raw,
                false,
            )
        };
        assert!(result.unwrap_err().to_string().contains("null result"));
    }

    #[test]
    fn parse_raw_result_minimal_prompt() {
        let mut raw = new_empty_chat_template_raw_result();
        raw.prompt = heap_cstring("Hello");
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
                &raw mut raw,
                false,
            )
        };
        let parsed = result.unwrap();
        assert_eq!(parsed.prompt, "Hello");
        assert_eq!(parsed.grammar, None);
        assert_eq!(parsed.parser, None);
        assert!(!parsed.supports_thinking);
        assert!(!parsed.grammar_lazy);
        assert!(!parsed.parse_tool_calls);
    }

    #[test]
    fn parse_raw_result_supports_thinking_true() {
        let mut raw = new_empty_chat_template_raw_result();
        raw.prompt = heap_cstring("test");
        raw.supports_thinking = true;
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
                &raw mut raw,
                false,
            )
        };
        assert!(result.unwrap().supports_thinking);
    }

    #[test]
    fn parse_raw_result_with_grammar_and_parser() {
        let mut raw = new_empty_chat_template_raw_result();
        raw.prompt = heap_cstring("prompt");
        raw.grammar = heap_cstring("root ::= .*");
        raw.parser = heap_cstring("peg_data");
        raw.grammar_lazy = true;
        raw.chat_format = 2;
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
                &raw mut raw,
                true,
            )
        };
        let parsed = result.unwrap();
        assert_eq!(parsed.grammar.as_deref(), Some("root ::= .*"));
        assert_eq!(parsed.parser.as_deref(), Some("peg_data"));
        assert!(parsed.grammar_lazy);
        assert_eq!(parsed.chat_format, 2);
        assert!(parsed.parse_tool_calls);
    }

    // --- parse_response_oaicompat ---

    #[test]
    fn parse_response_content_only_format() {
        let json_string = ChatTemplateResult::default()
            .parse_response_oaicompat("Hello, world!", false)
            .unwrap();
        let json_value: serde_json::Value = serde_json::from_str(&json_string).unwrap();
        assert_eq!(json_value["role"], "assistant");
        assert_eq!(json_value["content"], "Hello, world!");
    }

    #[test]
    fn parse_response_null_byte_returns_error() {
        // Interior NUL bytes cannot be represented in a CString.
        let result = ChatTemplateResult::default().parse_response_oaicompat("hello\0world", false);
        assert!(result.is_err());
    }

    // --- parse_chat_template_raw_result with invalid grammar triggers ---

    #[test]
    fn parse_raw_result_invalid_triggers_propagates_error() {
        let mut raw = new_empty_chat_template_raw_result();
        raw.prompt = heap_cstring("prompt");
        raw.grammar_triggers = ptr::null_mut();
        raw.grammar_triggers_count = 1;
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
                &raw mut raw,
                false,
            )
        };

        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("invalid grammar trigger data")
        );
    }

    // --- check_chat_parse_status / check_chat_parse_not_null ---

    #[test]
    fn check_chat_parse_status_ok() {
        let result = super::check_chat_parse_status(llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK);

        assert!(result.is_ok());
    }

    #[test]
    fn check_chat_parse_status_error() {
        let result = super::check_chat_parse_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
        );

        assert!(result.unwrap_err().to_string().contains("ffi error"));
    }

    #[test]
    fn check_chat_parse_not_null_ok() {
        let cstr = CString::new("test").unwrap();
        let result = super::check_chat_parse_not_null(cstr.as_ptr());

        assert!(result.is_ok());
    }

    #[test]
    fn check_chat_parse_not_null_error() {
        let result = super::check_chat_parse_not_null(ptr::null());

        assert!(result.unwrap_err().to_string().contains("null result"));
    }

    // --- streaming_state_oaicompat ---

    #[test]
    fn streaming_state_returns_valid_state() {
        let template_result = ChatTemplateResult::default();
        let state = template_result.streaming_state_oaicompat();
        assert!(state.is_ok());
    }

    #[test]
    fn parse_raw_result_null_preserved_token_propagates_error() {
        let mut raw = new_empty_chat_template_raw_result();
        raw.prompt = heap_cstring("test");
        raw.preserved_tokens_count = 1;
        // preserved_tokens pointer is null but count is 1
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
                &raw mut raw,
                false,
            )
        };

        assert!(result.is_err());
    }

    #[test]
    fn parse_raw_result_null_additional_stop_propagates_error() {
        let mut raw = new_empty_chat_template_raw_result();
        raw.prompt = heap_cstring("test");
        // valid preserved_tokens (empty)
        raw.additional_stops_count = 1;
        // additional_stops pointer is null but count is 1
        let result = unsafe {
            parse_chat_template_raw_result(
                llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
                &raw mut raw,
                false,
            )
        };

        assert!(result.is_err());
    }

    #[test]
    fn parse_response_with_null_byte_parser_returns_error() {
        let template_result = ChatTemplateResult {
            parser: Some("null\0byte".to_string()),
            ..ChatTemplateResult::default()
        };

        let result = template_result.parse_response_oaicompat("hello", false);

        assert!(result.is_err());
    }

    #[test]
    fn streaming_state_with_null_byte_parser_returns_error() {
        let template_result = ChatTemplateResult {
            parser: Some("null\0byte".to_string()),
            ..ChatTemplateResult::default()
        };

        let result = template_result.streaming_state_oaicompat();

        assert!(result.is_err());
    }

    #[test]
    fn parse_response_with_valid_parser() {
        let template_result = ChatTemplateResult {
            parser: Some(String::new()),
            ..ChatTemplateResult::default()
        };

        let result = template_result.parse_response_oaicompat("hello", false);

        assert!(result.is_ok());
    }

    #[test]
    fn streaming_state_with_valid_parser() {
        let template_result = ChatTemplateResult {
            parser: Some(String::new()),
            ..ChatTemplateResult::default()
        };

        let result = template_result.streaming_state_oaicompat();

        assert!(result.is_ok());
    }
}