llama_cpp_bindings/openai/
chat_parse_state_oaicompat.rs1use std::ffi::{CStr, CString, c_char};
2use std::mem;
3use std::ptr::{self, NonNull};
4use std::slice;
5
6use crate::{ChatParseError, status_is_ok, status_to_i32};
7
8const fn check_ffi_status(
9 status: llama_cpp_bindings_sys::llama_rs_status,
10) -> Result<(), ChatParseError> {
11 if status_is_ok(status) {
12 Ok(())
13 } else {
14 Err(ChatParseError::FfiError(status_to_i32(status)))
15 }
16}
17
18const fn check_not_null_with_count(
19 pointer: *const llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat,
20 count: usize,
21) -> Result<(), ChatParseError> {
22 if count > 0 && pointer.is_null() {
23 Err(ChatParseError::NullResult)
24 } else {
25 Ok(())
26 }
27}
28
29const unsafe fn diffs_as_slice<'diffs>(
34 diffs_ptr: *const llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat,
35 count: usize,
36) -> &'diffs [llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat] {
37 if count == 0 {
38 &[]
39 } else {
40 unsafe { slice::from_raw_parts(diffs_ptr, count) }
41 }
42}
43
44const fn check_json_not_null(json_ptr: *const c_char) -> Result<(), ChatParseError> {
45 if json_ptr.is_null() {
46 Err(ChatParseError::NullResult)
47 } else {
48 Ok(())
49 }
50}
51
52fn handle_diff_json_error(
53 status: llama_cpp_bindings_sys::llama_rs_status,
54 json_ptr: *mut c_char,
55) -> Result<(), ChatParseError> {
56 if !status_is_ok(status) {
57 if !json_ptr.is_null() {
58 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(json_ptr) };
59 }
60
61 return Err(ChatParseError::FfiError(status_to_i32(status)));
62 }
63
64 Ok(())
65}
66
67#[derive(Debug)]
69pub struct ChatParseStateOaicompat {
70 pub state: NonNull<llama_cpp_bindings_sys::llama_rs_chat_parse_state_oaicompat>,
72}
73
74impl ChatParseStateOaicompat {
75 pub fn update(
80 &mut self,
81 text_added: &str,
82 is_partial: bool,
83 ) -> Result<Vec<String>, ChatParseError> {
84 let text_cstr = CString::new(text_added)?;
85 let mut out_msg: llama_cpp_bindings_sys::llama_rs_chat_msg_oaicompat =
86 unsafe { mem::zeroed() };
87 let mut out_diffs: *mut llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat =
88 ptr::null_mut();
89 let mut out_diffs_count: usize = 0;
90 let rc = unsafe {
91 llama_cpp_bindings_sys::llama_rs_chat_parse_state_update_oaicompat(
92 self.state.as_ptr(),
93 text_cstr.as_ptr(),
94 is_partial,
95 &raw mut out_msg,
96 &raw mut out_diffs,
97 &raw mut out_diffs_count,
98 )
99 };
100
101 let result = {
102 check_ffi_status(rc)?;
103 check_not_null_with_count(out_diffs, out_diffs_count)?;
104
105 let diffs = unsafe { diffs_as_slice(out_diffs, out_diffs_count) };
106 let mut deltas = Vec::with_capacity(diffs.len());
107
108 for diff in diffs {
109 let mut out_json: *mut c_char = ptr::null_mut();
110 let rc = unsafe {
111 llama_cpp_bindings_sys::llama_rs_chat_msg_diff_to_oaicompat_json(
112 diff,
113 &raw mut out_json,
114 )
115 };
116 handle_diff_json_error(rc, out_json)?;
117 check_json_not_null(out_json)?;
118
119 let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
120 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_json) };
121 deltas.push(String::from_utf8(bytes)?);
122 }
123
124 Ok(deltas)
125 };
126
127 unsafe { llama_cpp_bindings_sys::llama_rs_chat_msg_free_oaicompat(&raw mut out_msg) };
128 unsafe {
129 llama_cpp_bindings_sys::llama_rs_chat_msg_diff_free_oaicompat(
130 out_diffs,
131 out_diffs_count,
132 );
133 };
134
135 result
136 }
137}
138
139impl Drop for ChatParseStateOaicompat {
140 fn drop(&mut self) {
141 unsafe {
142 llama_cpp_bindings_sys::llama_rs_chat_parse_state_free_oaicompat(self.state.as_ptr());
143 };
144 }
145}
146
#[cfg(test)]
mod tests {
    use crate::model::chat_template_result::ChatTemplateResult;

    /// Returns `ChatTemplateResult::default()` as the template under test.
    fn content_only_template() -> ChatTemplateResult {
        ChatTemplateResult::default()
    }

    #[test]
    fn update_with_simple_text() {
        let mut parser = content_only_template().streaming_state_oaicompat().unwrap();

        assert!(parser.update("Hello", true).is_ok());
    }

    #[test]
    fn update_null_byte_returns_error() {
        let mut parser = content_only_template().streaming_state_oaicompat().unwrap();
        let message = parser.update("hello\0world", true).unwrap_err().to_string();

        assert!(message.contains("nul byte"));
    }

    #[test]
    fn update_finalized_produces_deltas() {
        let mut parser = content_only_template().streaming_state_oaicompat().unwrap();
        let json_deltas = parser.update("Hello world", false).unwrap();

        assert!(!json_deltas.is_empty());
    }

    #[test]
    fn check_ffi_status_returns_error_for_invalid() {
        let status = llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT;
        let message = super::check_ffi_status(status).unwrap_err().to_string();

        assert!(message.contains("ffi error"));
    }

    #[test]
    fn check_not_null_with_count_returns_error() {
        let message = super::check_not_null_with_count(std::ptr::null(), 1)
            .unwrap_err()
            .to_string();

        assert!(message.contains("null result"));
    }

    #[test]
    fn check_not_null_with_count_zero_is_ok() {
        let outcome = super::check_not_null_with_count(std::ptr::null(), 0);

        assert!(outcome.is_ok());
    }

    #[test]
    fn check_json_not_null_returns_error() {
        let message = super::check_json_not_null(std::ptr::null())
            .unwrap_err()
            .to_string();

        assert!(message.contains("null result"));
    }

    #[test]
    fn handle_diff_json_error_frees_and_returns_error() {
        let status = llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION;
        let message = super::handle_diff_json_error(status, std::ptr::null_mut())
            .unwrap_err()
            .to_string();

        assert!(message.contains("ffi error"));
    }

    #[test]
    fn handle_diff_json_error_frees_non_null_pointer_on_error() {
        // NOTE(review): this pointer comes from Rust's global allocator
        // (`CString::into_raw`), but the native `llama_rs_string_free` releases it.
        // `CString::into_raw` requires reclaiming the pointer via
        // `CString::from_raw`. Mixing allocators is undefined behavior and only
        // works if both sides happen to share one malloc. Prefer a string
        // allocated by the native library here.
        let raw_string = std::ffi::CString::new("test").unwrap().into_raw();
        let status = llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION;
        let message = super::handle_diff_json_error(status, raw_string)
            .unwrap_err()
            .to_string();

        assert!(message.contains("ffi error"));
    }

    #[test]
    fn diffs_as_slice_returns_empty_for_zero_count() {
        // SAFETY: a zero count never dereferences the (null) pointer.
        let view = unsafe { super::diffs_as_slice(std::ptr::null(), 0) };

        assert!(view.is_empty());
    }
}