Skip to main content

llama_cpp_bindings/openai/
chat_parse_state_oaicompat.rs

1use std::ffi::{CStr, CString, c_char};
2use std::mem;
3use std::ptr::{self, NonNull};
4use std::slice;
5
6use crate::{ChatParseError, status_is_ok, status_to_i32};
7
8/// Streaming OpenAI-compatible parser state.
9#[derive(Debug)]
10pub struct ChatParseStateOaicompat {
11    /// Raw pointer to the underlying FFI parser state.
12    pub state: NonNull<llama_cpp_bindings_sys::llama_rs_chat_parse_state_oaicompat>,
13}
14
15impl ChatParseStateOaicompat {
16    /// Update the parser with additional text and return OpenAI-compatible deltas as JSON strings.
17    ///
18    /// # Errors
19    /// Returns an error if the FFI call fails or the result is null.
20    pub fn update(
21        &mut self,
22        text_added: &str,
23        is_partial: bool,
24    ) -> Result<Vec<String>, ChatParseError> {
25        let text_cstr = CString::new(text_added)?;
26        let mut out_msg: llama_cpp_bindings_sys::llama_rs_chat_msg_oaicompat =
27            unsafe { mem::zeroed() };
28        let mut out_diffs: *mut llama_cpp_bindings_sys::llama_rs_chat_msg_diff_oaicompat =
29            ptr::null_mut();
30        let mut out_diffs_count: usize = 0;
31        let rc = unsafe {
32            llama_cpp_bindings_sys::llama_rs_chat_parse_state_update_oaicompat(
33                self.state.as_ptr(),
34                text_cstr.as_ptr(),
35                is_partial,
36                &raw mut out_msg,
37                &raw mut out_diffs,
38                &raw mut out_diffs_count,
39            )
40        };
41
42        let result = {
43            if !status_is_ok(rc) {
44                return Err(ChatParseError::FfiError(status_to_i32(rc)));
45            }
46            if out_diffs_count > 0 && out_diffs.is_null() {
47                return Err(ChatParseError::NullResult);
48            }
49            let diffs = if out_diffs_count == 0 {
50                &[]
51            } else {
52                unsafe { slice::from_raw_parts(out_diffs, out_diffs_count) }
53            };
54            let mut deltas = Vec::with_capacity(diffs.len());
55
56            for diff in diffs {
57                let mut out_json: *mut c_char = ptr::null_mut();
58                let rc = unsafe {
59                    llama_cpp_bindings_sys::llama_rs_chat_msg_diff_to_oaicompat_json(
60                        diff,
61                        &raw mut out_json,
62                    )
63                };
64                if !status_is_ok(rc) {
65                    if !out_json.is_null() {
66                        unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_json) };
67                    }
68
69                    return Err(ChatParseError::FfiError(status_to_i32(rc)));
70                }
71                if out_json.is_null() {
72                    return Err(ChatParseError::NullResult);
73                }
74                let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
75                unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_json) };
76                deltas.push(String::from_utf8(bytes)?);
77            }
78
79            Ok(deltas)
80        };
81
82        unsafe { llama_cpp_bindings_sys::llama_rs_chat_msg_free_oaicompat(&raw mut out_msg) };
83        unsafe {
84            llama_cpp_bindings_sys::llama_rs_chat_msg_diff_free_oaicompat(
85                out_diffs,
86                out_diffs_count,
87            );
88        };
89
90        result
91    }
92}
93
94impl Drop for ChatParseStateOaicompat {
95    fn drop(&mut self) {
96        unsafe {
97            llama_cpp_bindings_sys::llama_rs_chat_parse_state_free_oaicompat(self.state.as_ptr())
98        };
99    }
100}