1use crate::{status_is_ok, status_to_i32, ChatParseError};
3use std::ffi::{c_char, CStr, CString};
4use std::mem;
5use std::ptr::{self, NonNull};
6use std::slice;
7
/// Borrowed parameters describing an OpenAI-style chat request to be run through a
/// chat template.
///
/// NOTE(review): the consumer of this struct is not visible here. The per-field notes
/// below are inferred from field names and should be confirmed against it.
#[derive(Debug, Clone, PartialEq)]
pub struct OpenAIChatTemplateParams<'a> {
    /// JSON text of the chat messages (presumably an OpenAI `messages` array).
    pub messages_json: &'a str,
    /// Optional JSON text describing the available tools (presumably an OpenAI `tools` array).
    pub tools_json: Option<&'a str>,
    /// Optional tool-choice setting passed as a raw string (presumably OpenAI `tool_choice`).
    pub tool_choice: Option<&'a str>,
    /// Optional JSON schema text, presumably used to constrain the model output.
    pub json_schema: Option<&'a str>,
    /// Optional grammar text, presumably for constrained sampling (format unconfirmed).
    pub grammar: Option<&'a str>,
    /// Optional reasoning-format name, passed as a raw string.
    pub reasoning_format: Option<&'a str>,
    /// Optional extra template arguments as text (presumably a JSON object) — confirm.
    pub chat_template_kwargs: Option<&'a str>,
    /// Presumably: whether to append the prompt that starts the assistant's turn.
    pub add_generation_prompt: bool,
    /// Presumably: whether to render the template with the Jinja engine.
    pub use_jinja: bool,
    /// Presumably: whether the model may emit several tool calls in one turn.
    pub parallel_tool_calls: bool,
    /// Presumably: whether "thinking"/reasoning output is enabled in the template.
    pub enable_thinking: bool,
    /// Presumably: whether a beginning-of-sequence token is added.
    pub add_bos: bool,
    /// Presumably: whether an end-of-sequence token is added.
    pub add_eos: bool,
    /// Presumably: whether tool calls are parsed out of the model's output.
    pub parse_tool_calls: bool,
}
40
/// Owning handle to a native `llama_rs_chat_parse_state_oaicompat`.
///
/// Text is fed in incrementally through [`ChatParseStateOaicompat::update`]. The native
/// state is released when this handle is dropped.
#[derive(Debug)]
pub struct ChatParseStateOaicompat {
    /// Non-null pointer to the native parse state. It is exclusively owned by this
    /// handle and freed in `Drop`. (Assumes construction elsewhere in the crate hands
    /// over sole ownership — confirm.)
    pub(crate) state: NonNull<llama_cpp_sys_2::llama_rs_chat_parse_state_oaicompat>,
}
46
47impl ChatParseStateOaicompat {
48 pub fn update(
50 &mut self,
51 text_added: &str,
52 is_partial: bool,
53 ) -> Result<Vec<String>, ChatParseError> {
54 let text_cstr = CString::new(text_added)?;
55 let mut out_msg: llama_cpp_sys_2::llama_rs_chat_msg_oaicompat = unsafe { mem::zeroed() };
56 let mut out_diffs: *mut llama_cpp_sys_2::llama_rs_chat_msg_diff_oaicompat = ptr::null_mut();
57 let mut out_diffs_count: usize = 0;
58 let rc = unsafe {
59 llama_cpp_sys_2::llama_rs_chat_parse_state_update_oaicompat(
60 self.state.as_ptr(),
61 text_cstr.as_ptr(),
62 is_partial,
63 &mut out_msg,
64 &mut out_diffs,
65 &mut out_diffs_count,
66 )
67 };
68
69 let result = {
70 if !status_is_ok(rc) {
71 return Err(ChatParseError::FfiError(status_to_i32(rc)));
72 }
73 if out_diffs_count > 0 && out_diffs.is_null() {
74 return Err(ChatParseError::NullResult);
75 }
76 let diffs = if out_diffs_count == 0 {
77 &[]
78 } else {
79 unsafe { slice::from_raw_parts(out_diffs, out_diffs_count) }
80 };
81 let mut deltas = Vec::with_capacity(diffs.len());
82 for diff in diffs {
83 let mut out_json: *mut c_char = ptr::null_mut();
84 let rc = unsafe {
85 llama_cpp_sys_2::llama_rs_chat_msg_diff_to_oaicompat_json(diff, &mut out_json)
86 };
87 if !status_is_ok(rc) {
88 if !out_json.is_null() {
89 unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
90 }
91 return Err(ChatParseError::FfiError(status_to_i32(rc)));
92 }
93 if out_json.is_null() {
94 return Err(ChatParseError::NullResult);
95 }
96 let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
97 unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
98 deltas.push(String::from_utf8(bytes)?);
99 }
100 Ok(deltas)
101 };
102
103 unsafe { llama_cpp_sys_2::llama_rs_chat_msg_free_oaicompat(&mut out_msg) };
104 unsafe {
105 llama_cpp_sys_2::llama_rs_chat_msg_diff_free_oaicompat(out_diffs, out_diffs_count)
106 };
107 result
108 }
109}
110
111impl Drop for ChatParseStateOaicompat {
112 fn drop(&mut self) {
113 unsafe { llama_cpp_sys_2::llama_rs_chat_parse_state_free_oaicompat(self.state.as_ptr()) };
114 }
115}