Skip to main content

llama_cpp_bindings/openai/
chat_parse_state_oaicompat.rs

1use std::ffi::{CStr, CString, c_char};
2use std::mem;
3use std::ptr::{self, NonNull};
4use std::slice;
5
6use crate::{ChatParseError, status_is_ok, status_to_i32};
7
8const fn check_ffi_status(
9    status: llama_cpp_bindings_sys::llama_rs_status,
10) -> Result<(), ChatParseError> {
11    if status_is_ok(status) {
12        Ok(())
13    } else {
14        Err(ChatParseError::FfiError(status_to_i32(status)))
15    }
16}
17
18const fn check_not_null_with_count(
19    pointer: *const llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat,
20    count: usize,
21) -> Result<(), ChatParseError> {
22    if count > 0 && pointer.is_null() {
23        Err(ChatParseError::NullResult)
24    } else {
25        Ok(())
26    }
27}
28
29/// # Safety
30///
31/// `diffs_ptr` must point to at least `count` valid `llama_rs_chat_msg_diff_oaicompat`
32/// values that remain valid for the lifetime `'diffs`.
33const unsafe fn diffs_as_slice<'diffs>(
34    diffs_ptr: *const llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat,
35    count: usize,
36) -> &'diffs [llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat] {
37    if count == 0 {
38        &[]
39    } else {
40        unsafe { slice::from_raw_parts(diffs_ptr, count) }
41    }
42}
43
44const fn check_json_not_null(json_ptr: *const c_char) -> Result<(), ChatParseError> {
45    if json_ptr.is_null() {
46        Err(ChatParseError::NullResult)
47    } else {
48        Ok(())
49    }
50}
51
52fn handle_diff_json_error(
53    status: llama_cpp_bindings_sys::llama_rs_status,
54    json_ptr: *mut c_char,
55) -> Result<(), ChatParseError> {
56    if !status_is_ok(status) {
57        if !json_ptr.is_null() {
58            unsafe { llama_cpp_bindings_sys::llama_rs_string_free(json_ptr) };
59        }
60
61        return Err(ChatParseError::FfiError(status_to_i32(status)));
62    }
63
64    Ok(())
65}
66
67/// Streaming OpenAI-compatible parser state.
68#[derive(Debug)]
69pub struct ChatParseStateOaicompat {
70    /// Raw pointer to the underlying FFI parser state.
71    pub state: NonNull<llama_cpp_bindings_sys::llama_rs_chat_parse_state_oaicompat>,
72}
73
74impl ChatParseStateOaicompat {
75    /// Update the parser with additional text and return OpenAI-compatible deltas as JSON strings.
76    ///
77    /// # Errors
78    /// Returns an error if the FFI call fails or the result is null.
79    pub fn update(
80        &mut self,
81        text_added: &str,
82        is_partial: bool,
83    ) -> Result<Vec<String>, ChatParseError> {
84        let text_cstr = CString::new(text_added)?;
85        let mut out_msg: llama_cpp_bindings_sys::llama_rs_chat_msg_oaicompat =
86            unsafe { mem::zeroed() };
87        let mut out_diffs: *mut llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat =
88            ptr::null_mut();
89        let mut out_diffs_count: usize = 0;
90        let rc = unsafe {
91            llama_cpp_bindings_sys::llama_rs_chat_parse_state_update_oaicompat(
92                self.state.as_ptr(),
93                text_cstr.as_ptr(),
94                is_partial,
95                &raw mut out_msg,
96                &raw mut out_diffs,
97                &raw mut out_diffs_count,
98            )
99        };
100
101        let result = {
102            check_ffi_status(rc)?;
103            check_not_null_with_count(out_diffs, out_diffs_count)?;
104
105            let diffs = unsafe { diffs_as_slice(out_diffs, out_diffs_count) };
106            let mut deltas = Vec::with_capacity(diffs.len());
107
108            for diff in diffs {
109                let mut out_json: *mut c_char = ptr::null_mut();
110                let rc = unsafe {
111                    llama_cpp_bindings_sys::llama_rs_chat_msg_diff_to_oaicompat_json(
112                        diff,
113                        &raw mut out_json,
114                    )
115                };
116                handle_diff_json_error(rc, out_json)?;
117                check_json_not_null(out_json)?;
118
119                let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
120                unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_json) };
121                deltas.push(String::from_utf8(bytes)?);
122            }
123
124            Ok(deltas)
125        };
126
127        unsafe { llama_cpp_bindings_sys::llama_rs_chat_msg_free_oaicompat(&raw mut out_msg) };
128        unsafe {
129            llama_cpp_bindings_sys::llama_rs_chat_msg_diff_free_oaicompat(
130                out_diffs,
131                out_diffs_count,
132            );
133        };
134
135        result
136    }
137}
138
139impl Drop for ChatParseStateOaicompat {
140    fn drop(&mut self) {
141        unsafe {
142            llama_cpp_bindings_sys::llama_rs_chat_parse_state_free_oaicompat(self.state.as_ptr());
143        };
144    }
145}
146
#[cfg(test)]
mod tests {
    use crate::model::chat_template_result::ChatTemplateResult;

    fn content_only_template() -> ChatTemplateResult {
        ChatTemplateResult::default()
    }

    #[test]
    fn update_with_simple_text() {
        let mut state = content_only_template().streaming_state_oaicompat().unwrap();
        let deltas = state.update("Hello", true);
        assert!(deltas.is_ok());
    }

    #[test]
    fn update_null_byte_returns_error() {
        let mut state = content_only_template().streaming_state_oaicompat().unwrap();
        let result = state.update("hello\0world", true);
        assert!(result.unwrap_err().to_string().contains("nul byte"));
    }

    #[test]
    fn update_finalized_produces_deltas() {
        let mut state = content_only_template().streaming_state_oaicompat().unwrap();
        let deltas = state.update("Hello world", false).unwrap();

        assert!(!deltas.is_empty());
    }

    #[test]
    fn check_ffi_status_returns_error_for_invalid() {
        let result =
            super::check_ffi_status(llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT);

        assert!(result.unwrap_err().to_string().contains("ffi error"));
    }

    #[test]
    fn check_not_null_with_count_returns_error() {
        let result = super::check_not_null_with_count(std::ptr::null(), 1);

        assert!(result.unwrap_err().to_string().contains("null result"));
    }

    #[test]
    fn check_not_null_with_count_zero_is_ok() {
        let result = super::check_not_null_with_count(std::ptr::null(), 0);

        assert!(result.is_ok());
    }

    #[test]
    fn check_json_not_null_returns_error() {
        let result = super::check_json_not_null(std::ptr::null());

        assert!(result.unwrap_err().to_string().contains("null result"));
    }

    #[test]
    fn handle_diff_json_error_frees_and_returns_error() {
        let result = super::handle_diff_json_error(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
            std::ptr::null_mut(),
        );

        assert!(result.unwrap_err().to_string().contains("ffi error"));
    }

    #[test]
    fn handle_diff_json_error_frees_non_null_pointer_on_error() {
        // `handle_diff_json_error` releases the pointer with `llama_rs_string_free`,
        // so the pointer must come from the FFI allocator. A `CString::into_raw`
        // pointer may only be reclaimed with `CString::from_raw`; freeing it on the
        // FFI side is undefined behaviour. Obtain a genuine FFI-owned JSON string
        // by driving the parser directly instead.
        let state = content_only_template().streaming_state_oaicompat().unwrap();
        let text = std::ffi::CString::new("Hello world").unwrap();
        let mut out_msg: llama_cpp_bindings_sys::llama_rs_chat_msg_oaicompat =
            unsafe { std::mem::zeroed() };
        let mut out_diffs: *mut llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat =
            std::ptr::null_mut();
        let mut out_diffs_count: usize = 0;
        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_chat_parse_state_update_oaicompat(
                state.state.as_ptr(),
                text.as_ptr(),
                false,
                &raw mut out_msg,
                &raw mut out_diffs,
                &raw mut out_diffs_count,
            )
        };
        assert!(super::check_ffi_status(rc).is_ok());
        assert!(out_diffs_count > 0 && !out_diffs.is_null());

        let mut json: *mut std::ffi::c_char = std::ptr::null_mut();
        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_chat_msg_diff_to_oaicompat_json(
                out_diffs,
                &raw mut json,
            )
        };
        assert!(super::check_ffi_status(rc).is_ok());
        assert!(!json.is_null());

        // Ownership of `json` passes to `handle_diff_json_error`, which frees it.
        let result = super::handle_diff_json_error(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
            json,
        );
        assert!(result.unwrap_err().to_string().contains("ffi error"));

        unsafe {
            llama_cpp_bindings_sys::llama_rs_chat_msg_free_oaicompat(&raw mut out_msg);
            llama_cpp_bindings_sys::llama_rs_chat_msg_diff_free_oaicompat(
                out_diffs,
                out_diffs_count,
            );
        }
    }

    #[test]
    fn diffs_as_slice_returns_empty_for_zero_count() {
        let result = unsafe { super::diffs_as_slice(std::ptr::null(), 0) };

        assert!(result.is_empty());
    }
}