1use std::ffi::{CStr, CString, c_char};
2use std::ptr::{self, NonNull};
3use std::slice;
4
5use crate::model::grammar_trigger::{GrammarTrigger, GrammarTriggerType};
6use crate::openai::ChatParseStateOaicompat;
7use crate::token::LlamaToken;
8use crate::{ApplyChatTemplateError, ChatParseError, status_is_ok, status_to_i32};
9
10const fn check_chat_parse_status(
11 rc: llama_cpp_bindings_sys::llama_rs_status,
12) -> Result<(), ChatParseError> {
13 if !status_is_ok(rc) {
14 return Err(ChatParseError::FfiError(status_to_i32(rc)));
15 }
16
17 Ok(())
18}
19
20const fn check_chat_parse_not_null(json_ptr: *const c_char) -> Result<(), ChatParseError> {
21 if json_ptr.is_null() {
22 return Err(ChatParseError::NullResult);
23 }
24
25 Ok(())
26}
27
28#[derive(Debug, Clone, Default, PartialEq, Eq)]
30pub struct ChatTemplateResult {
31 pub prompt: String,
33 pub grammar: Option<String>,
35 pub grammar_lazy: bool,
37 pub grammar_triggers: Vec<GrammarTrigger>,
39 pub preserved_tokens: Vec<String>,
41 pub additional_stops: Vec<String>,
43 pub chat_format: i32,
45 pub parser: Option<String>,
47 pub supports_thinking: bool,
49 pub parse_tool_calls: bool,
51}
52
53#[must_use]
54pub const fn new_empty_chat_template_raw_result()
55-> llama_cpp_bindings_sys::llama_rs_chat_template_result {
56 llama_cpp_bindings_sys::llama_rs_chat_template_result {
57 prompt: ptr::null_mut(),
58 grammar: ptr::null_mut(),
59 parser: ptr::null_mut(),
60 chat_format: 0,
61 supports_thinking: false,
62 grammar_lazy: false,
63 grammar_triggers: ptr::null_mut(),
64 grammar_triggers_count: 0,
65 preserved_tokens: ptr::null_mut(),
66 preserved_tokens_count: 0,
67 additional_stops: ptr::null_mut(),
68 additional_stops_count: 0,
69 }
70}
71
72unsafe fn parse_raw_cstr_array(
76 raw_cstr_array: *const *mut c_char,
77 count: usize,
78) -> Result<Vec<String>, ApplyChatTemplateError> {
79 if count == 0 {
80 return Ok(Vec::new());
81 }
82
83 if raw_cstr_array.is_null() {
84 return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
85 }
86
87 let raw_entries = unsafe { slice::from_raw_parts(raw_cstr_array, count) };
88 let mut parsed = Vec::with_capacity(raw_entries.len());
89
90 for entry in raw_entries {
91 if entry.is_null() {
92 return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
93 }
94 let bytes = unsafe { CStr::from_ptr(*entry) }.to_bytes().to_vec();
95 parsed.push(String::from_utf8(bytes)?);
96 }
97
98 Ok(parsed)
99}
100
101unsafe fn parse_raw_grammar_triggers(
105 raw_triggers: *const llama_cpp_bindings_sys::llama_rs_grammar_trigger,
106 count: usize,
107) -> Result<Vec<GrammarTrigger>, ApplyChatTemplateError> {
108 if count == 0 {
109 return Ok(Vec::new());
110 }
111
112 if raw_triggers.is_null() {
113 return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
114 }
115
116 let triggers = unsafe { slice::from_raw_parts(raw_triggers, count) };
117 let mut parsed = Vec::with_capacity(triggers.len());
118
119 for trigger in triggers {
120 let trigger_type = match trigger.type_ {
121 0 => GrammarTriggerType::Token,
122 1 => GrammarTriggerType::Word,
123 2 => GrammarTriggerType::Pattern,
124 3 => GrammarTriggerType::PatternFull,
125 _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
126 };
127 let value = if trigger.value.is_null() {
128 return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
129 } else {
130 let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
131 String::from_utf8(bytes)?
132 };
133 let token = if trigger_type == GrammarTriggerType::Token {
134 Some(LlamaToken(trigger.token))
135 } else {
136 None
137 };
138 parsed.push(GrammarTrigger {
139 trigger_type,
140 value,
141 token,
142 });
143 }
144
145 Ok(parsed)
146}
147
148pub unsafe fn parse_chat_template_raw_result(
155 ffi_return_code: llama_cpp_bindings_sys::llama_rs_status,
156 raw_result: *mut llama_cpp_bindings_sys::llama_rs_chat_template_result,
157 parse_tool_calls: bool,
158) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
159 let result = (|| {
160 if !status_is_ok(ffi_return_code) {
161 return Err(ApplyChatTemplateError::FfiError(status_to_i32(
162 ffi_return_code,
163 )));
164 }
165
166 let raw = unsafe { &*raw_result };
167
168 if raw.prompt.is_null() {
169 return Err(ApplyChatTemplateError::NullResult);
170 }
171
172 let prompt_bytes = unsafe { CStr::from_ptr(raw.prompt) }.to_bytes().to_vec();
173 let prompt = String::from_utf8(prompt_bytes)?;
174
175 let grammar = if raw.grammar.is_null() {
176 None
177 } else {
178 let grammar_bytes = unsafe { CStr::from_ptr(raw.grammar) }.to_bytes().to_vec();
179 Some(String::from_utf8(grammar_bytes)?)
180 };
181
182 let parser = if raw.parser.is_null() {
183 None
184 } else {
185 let parser_bytes = unsafe { CStr::from_ptr(raw.parser) }.to_bytes().to_vec();
186 Some(String::from_utf8(parser_bytes)?)
187 };
188
189 let grammar_triggers = unsafe {
190 parse_raw_grammar_triggers(raw.grammar_triggers, raw.grammar_triggers_count)
191 }?;
192
193 let preserved_tokens =
194 unsafe { parse_raw_cstr_array(raw.preserved_tokens, raw.preserved_tokens_count) }?;
195
196 let additional_stops =
197 unsafe { parse_raw_cstr_array(raw.additional_stops, raw.additional_stops_count) }?;
198
199 Ok(ChatTemplateResult {
200 prompt,
201 grammar,
202 grammar_lazy: raw.grammar_lazy,
203 grammar_triggers,
204 preserved_tokens,
205 additional_stops,
206 chat_format: raw.chat_format,
207 parser,
208 supports_thinking: raw.supports_thinking,
209 parse_tool_calls,
210 })
211 })();
212
213 unsafe { llama_cpp_bindings_sys::llama_rs_chat_template_result_free(raw_result) };
214
215 result
216}
217
218impl ChatTemplateResult {
219 pub fn parse_response_oaicompat(
224 &self,
225 text: &str,
226 is_partial: bool,
227 ) -> Result<String, ChatParseError> {
228 let text_cstr = CString::new(text)?;
229 let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
230 let mut out_json: *mut c_char = ptr::null_mut();
231 let rc = unsafe {
232 llama_cpp_bindings_sys::llama_rs_chat_parse_to_oaicompat(
233 text_cstr.as_ptr(),
234 is_partial,
235 self.chat_format,
236 self.parse_tool_calls,
237 parser_cstr
238 .as_ref()
239 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
240 &raw mut out_json,
241 )
242 };
243
244 let result = (|| {
245 check_chat_parse_status(rc)?;
246 check_chat_parse_not_null(out_json)?;
247 let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
248 Ok(String::from_utf8(bytes)?)
249 })();
250
251 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_json) };
252
253 result
254 }
255
256 pub fn streaming_state_oaicompat(&self) -> Result<ChatParseStateOaicompat, ChatParseError> {
261 let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
262 let state = unsafe {
263 llama_cpp_bindings_sys::llama_rs_chat_parse_state_init_oaicompat(
264 self.chat_format,
265 self.parse_tool_calls,
266 parser_cstr
267 .as_ref()
268 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
269 )
270 };
271 let state = NonNull::new(state).ok_or(ChatParseError::NullResult)?;
272
273 Ok(ChatParseStateOaicompat { state })
274 }
275}
276
#[cfg(test)]
mod tests {
    use std::ffi::{CString, c_char};
    use std::fmt::Debug;
    use std::ptr;

    use super::{
        ChatTemplateResult, new_empty_chat_template_raw_result, parse_chat_template_raw_result,
        parse_raw_cstr_array, parse_raw_grammar_triggers,
    };
    use crate::model::grammar_trigger::{GrammarTrigger, GrammarTriggerType};
    use crate::token::LlamaToken;

    /// Message fragment expected for every malformed trigger/array input.
    const INVALID_TRIGGER_MESSAGE: &str = "invalid grammar trigger data";

    /// Allocates a C string whose ownership is released to the caller.
    ///
    /// NOTE(review): several tests store these Rust-allocated strings in a raw
    /// result that `parse_chat_template_raw_result` then hands to
    /// `llama_rs_chat_template_result_free` — confirm that free routine is
    /// compatible with memory obtained from `CString::into_raw`.
    fn heap_cstring(value: &str) -> *mut c_char {
        let owned = CString::new(value).unwrap();
        owned.into_raw()
    }

    /// Takes back ownership of a string created by `heap_cstring` and drops it.
    ///
    /// # Safety
    ///
    /// `raw` must come from `heap_cstring` and must not have been freed already.
    unsafe fn reclaim_heap_cstring(raw: *mut c_char) {
        // SAFETY: guaranteed by the caller.
        drop(unsafe { CString::from_raw(raw) });
    }

    /// Asserts that `outcome` is an error whose display text contains `needle`.
    fn assert_error_contains<T: Debug, E: ToString>(outcome: Result<T, E>, needle: &str) {
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains(needle));
    }

    /// Parses exactly one raw trigger.
    fn parse_single_trigger(
        trigger: &llama_cpp_bindings_sys::llama_rs_grammar_trigger,
    ) -> Result<Vec<GrammarTrigger>, crate::ApplyChatTemplateError> {
        // SAFETY: the pointer comes from a live reference and the count is 1.
        unsafe { parse_raw_grammar_triggers(ptr::from_ref(trigger), 1) }
    }

    /// Feeds `raw` to `parse_chat_template_raw_result` with the given status.
    fn parse_raw_with_status(
        status: llama_cpp_bindings_sys::llama_rs_status,
        mut raw: llama_cpp_bindings_sys::llama_rs_chat_template_result,
        parse_tool_calls: bool,
    ) -> Result<ChatTemplateResult, crate::ApplyChatTemplateError> {
        // SAFETY: `raw` is a live local for the duration of the call.
        unsafe { parse_chat_template_raw_result(status, &raw mut raw, parse_tool_calls) }
    }

    /// Feeds `raw` to `parse_chat_template_raw_result` with an OK status.
    fn parse_raw_ok(
        raw: llama_cpp_bindings_sys::llama_rs_chat_template_result,
        parse_tool_calls: bool,
    ) -> Result<ChatTemplateResult, crate::ApplyChatTemplateError> {
        parse_raw_with_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
            raw,
            parse_tool_calls,
        )
    }

    // --- parse_raw_cstr_array ---

    #[test]
    fn parse_cstr_array_zero_count_returns_empty() {
        let parsed = unsafe { parse_raw_cstr_array(ptr::null(), 0) }.unwrap();
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_cstr_array_null_with_nonzero_count_returns_error() {
        let outcome = unsafe { parse_raw_cstr_array(ptr::null(), 1) };
        assert_error_contains(outcome, INVALID_TRIGGER_MESSAGE);
    }

    #[test]
    fn parse_cstr_array_valid_single_string() {
        let entries = [heap_cstring("hello")];
        let outcome = unsafe { parse_raw_cstr_array(entries.as_ptr(), 1) };
        assert_eq!(outcome.unwrap(), vec!["hello".to_string()]);
        unsafe { reclaim_heap_cstring(entries[0]) };
    }

    #[test]
    fn parse_cstr_array_null_entry_returns_error() {
        let entries: [*mut c_char; 2] = [heap_cstring("valid"), ptr::null_mut()];
        let outcome = unsafe { parse_raw_cstr_array(entries.as_ptr(), 2) };
        assert_error_contains(outcome, INVALID_TRIGGER_MESSAGE);
        unsafe { reclaim_heap_cstring(entries[0]) };
    }

    // --- parse_raw_grammar_triggers ---

    #[test]
    fn parse_triggers_zero_count_returns_empty() {
        let parsed = unsafe { parse_raw_grammar_triggers(ptr::null(), 0) }.unwrap();
        assert!(parsed.is_empty());
    }

    #[test]
    fn parse_triggers_null_with_nonzero_count_returns_error() {
        let outcome = unsafe { parse_raw_grammar_triggers(ptr::null(), 1) };
        assert_error_contains(outcome, INVALID_TRIGGER_MESSAGE);
    }

    #[test]
    fn parse_triggers_token_type_has_token() {
        let value_ptr = heap_cstring("<tool>");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 0,
            value: value_ptr,
            token: 42,
        };
        let parsed = parse_single_trigger(&trigger).unwrap();
        assert_eq!(parsed.len(), 1);
        let first = &parsed[0];
        assert_eq!(first.trigger_type, GrammarTriggerType::Token);
        assert_eq!(first.value, "<tool>");
        assert_eq!(first.token, Some(LlamaToken(42)));
        unsafe { reclaim_heap_cstring(value_ptr) };
    }

    #[test]
    fn parse_triggers_word_type_has_no_token() {
        let value_ptr = heap_cstring("function");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 1,
            value: value_ptr,
            token: 99,
        };
        let parsed = parse_single_trigger(&trigger).unwrap();
        let first = &parsed[0];
        assert_eq!(first.trigger_type, GrammarTriggerType::Word);
        assert_eq!(first.token, None);
        unsafe { reclaim_heap_cstring(value_ptr) };
    }

    #[test]
    fn parse_triggers_pattern_type() {
        let value_ptr = heap_cstring("\\{.*\\}");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 2,
            value: value_ptr,
            token: 0,
        };
        let parsed = parse_single_trigger(&trigger).unwrap();
        assert_eq!(parsed[0].trigger_type, GrammarTriggerType::Pattern);
        unsafe { reclaim_heap_cstring(value_ptr) };
    }

    #[test]
    fn parse_triggers_pattern_full_type() {
        let value_ptr = heap_cstring("^tool$");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 3,
            value: value_ptr,
            token: 0,
        };
        let parsed = parse_single_trigger(&trigger).unwrap();
        assert_eq!(parsed[0].trigger_type, GrammarTriggerType::PatternFull);
        unsafe { reclaim_heap_cstring(value_ptr) };
    }

    #[test]
    fn parse_triggers_invalid_type_returns_error() {
        let value_ptr = heap_cstring("x");
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 4,
            value: value_ptr,
            token: 0,
        };
        assert_error_contains(parse_single_trigger(&trigger), INVALID_TRIGGER_MESSAGE);
        unsafe { reclaim_heap_cstring(value_ptr) };
    }

    #[test]
    fn parse_triggers_null_value_returns_error() {
        let trigger = llama_cpp_bindings_sys::llama_rs_grammar_trigger {
            type_: 1,
            value: ptr::null_mut(),
            token: 0,
        };
        assert_error_contains(parse_single_trigger(&trigger), INVALID_TRIGGER_MESSAGE);
    }

    // --- parse_chat_template_raw_result ---

    #[test]
    fn parse_raw_result_error_status_returns_ffi_error() {
        let outcome = parse_raw_with_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
            new_empty_chat_template_raw_result(),
            false,
        );
        assert_error_contains(outcome, "ffi error -1");
    }

    #[test]
    fn parse_raw_result_null_prompt_returns_null_result() {
        let outcome = parse_raw_ok(new_empty_chat_template_raw_result(), false);
        assert_error_contains(outcome, "null result");
    }

    #[test]
    fn parse_raw_result_minimal_prompt() {
        let raw = llama_cpp_bindings_sys::llama_rs_chat_template_result {
            prompt: heap_cstring("Hello"),
            ..new_empty_chat_template_raw_result()
        };
        let parsed = parse_raw_ok(raw, false).unwrap();
        assert_eq!(parsed.prompt, "Hello");
        assert_eq!(parsed.grammar, None);
        assert_eq!(parsed.parser, None);
        assert!(!parsed.supports_thinking);
        assert!(!parsed.grammar_lazy);
        assert!(!parsed.parse_tool_calls);
    }

    #[test]
    fn parse_raw_result_supports_thinking_true() {
        let raw = llama_cpp_bindings_sys::llama_rs_chat_template_result {
            prompt: heap_cstring("test"),
            supports_thinking: true,
            ..new_empty_chat_template_raw_result()
        };
        let parsed = parse_raw_ok(raw, false).unwrap();
        assert!(parsed.supports_thinking);
    }

    #[test]
    fn parse_raw_result_with_grammar_and_parser() {
        let raw = llama_cpp_bindings_sys::llama_rs_chat_template_result {
            prompt: heap_cstring("prompt"),
            grammar: heap_cstring("root ::= .*"),
            parser: heap_cstring("peg_data"),
            grammar_lazy: true,
            chat_format: 2,
            ..new_empty_chat_template_raw_result()
        };
        let parsed = parse_raw_ok(raw, true).unwrap();
        assert_eq!(parsed.grammar.as_deref(), Some("root ::= .*"));
        assert_eq!(parsed.parser.as_deref(), Some("peg_data"));
        assert!(parsed.grammar_lazy);
        assert_eq!(parsed.chat_format, 2);
        assert!(parsed.parse_tool_calls);
    }

    // --- parse_response_oaicompat ---

    #[test]
    fn parse_response_content_only_format() {
        let template_result = ChatTemplateResult::default();
        let json_string = template_result
            .parse_response_oaicompat("Hello, world!", false)
            .unwrap();
        let json_value: serde_json::Value = serde_json::from_str(&json_string).unwrap();
        assert_eq!(json_value["role"], "assistant");
        assert_eq!(json_value["content"], "Hello, world!");
    }

    #[test]
    fn parse_response_null_byte_returns_error() {
        let template_result = ChatTemplateResult::default();
        let outcome = template_result.parse_response_oaicompat("hello\0world", false);
        assert!(outcome.is_err());
    }

    // --- error propagation from nested arrays ---

    #[test]
    fn parse_raw_result_invalid_triggers_propagates_error() {
        // A non-zero count paired with a null array must be rejected.
        let raw = llama_cpp_bindings_sys::llama_rs_chat_template_result {
            prompt: heap_cstring("prompt"),
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 1,
            ..new_empty_chat_template_raw_result()
        };
        assert_error_contains(parse_raw_ok(raw, false), INVALID_TRIGGER_MESSAGE);
    }

    // --- status / null-pointer helpers ---

    #[test]
    fn check_chat_parse_status_ok() {
        let outcome = super::check_chat_parse_status(llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK);
        assert!(outcome.is_ok());
    }

    #[test]
    fn check_chat_parse_status_error() {
        let outcome = super::check_chat_parse_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
        );
        assert_error_contains(outcome, "ffi error");
    }

    #[test]
    fn check_chat_parse_not_null_ok() {
        let cstr = CString::new("test").unwrap();
        let outcome = super::check_chat_parse_not_null(cstr.as_ptr());
        assert!(outcome.is_ok());
    }

    #[test]
    fn check_chat_parse_not_null_error() {
        let outcome = super::check_chat_parse_not_null(ptr::null());
        assert_error_contains(outcome, "null result");
    }

    // --- streaming state and parser handling ---

    #[test]
    fn streaming_state_returns_valid_state() {
        let template_result = ChatTemplateResult::default();
        assert!(template_result.streaming_state_oaicompat().is_ok());
    }

    #[test]
    fn parse_raw_result_null_preserved_token_propagates_error() {
        // The array pointer stays null while the count claims one entry.
        let raw = llama_cpp_bindings_sys::llama_rs_chat_template_result {
            prompt: heap_cstring("test"),
            preserved_tokens_count: 1,
            ..new_empty_chat_template_raw_result()
        };
        assert!(parse_raw_ok(raw, false).is_err());
    }

    #[test]
    fn parse_raw_result_null_additional_stop_propagates_error() {
        // The array pointer stays null while the count claims one entry.
        let raw = llama_cpp_bindings_sys::llama_rs_chat_template_result {
            prompt: heap_cstring("test"),
            additional_stops_count: 1,
            ..new_empty_chat_template_raw_result()
        };
        assert!(parse_raw_ok(raw, false).is_err());
    }

    #[test]
    fn parse_response_with_null_byte_parser_returns_error() {
        let template_result = ChatTemplateResult {
            parser: Some("null\0byte".to_string()),
            ..ChatTemplateResult::default()
        };
        let outcome = template_result.parse_response_oaicompat("hello", false);
        assert!(outcome.is_err());
    }

    #[test]
    fn streaming_state_with_null_byte_parser_returns_error() {
        let template_result = ChatTemplateResult {
            parser: Some("null\0byte".to_string()),
            ..ChatTemplateResult::default()
        };
        let outcome = template_result.streaming_state_oaicompat();
        assert!(outcome.is_err());
    }

    #[test]
    fn parse_response_with_valid_parser() {
        let template_result = ChatTemplateResult {
            parser: Some(String::new()),
            ..ChatTemplateResult::default()
        };
        let outcome = template_result.parse_response_oaicompat("hello", false);
        assert!(outcome.is_ok());
    }

    #[test]
    fn streaming_state_with_valid_parser() {
        let template_result = ChatTemplateResult {
            parser: Some(String::new()),
            ..ChatTemplateResult::default()
        };
        let outcome = template_result.streaming_state_oaicompat();
        assert!(outcome.is_ok());
    }
}