Skip to main content

llama_cpp_2/
openai.rs

//! OpenAI-specific utility methods.
2use crate::{status_is_ok, status_to_i32, ChatParseError};
3use std::ffi::{c_char, CStr, CString};
4use std::mem;
5use std::ptr::{self, NonNull};
6use std::slice;
7
/// Parameters for applying OpenAI-compatible chat templates.
///
/// Every string field borrows from the caller for the lifetime `'a`; the struct
/// owns no data itself. All fields are plain strings, optional strings, or
/// booleans, so the type supports total equality (`Eq`) in addition to
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OpenAIChatTemplateParams<'a> {
    /// OpenAI-compatible messages, as a JSON array.
    pub messages_json: &'a str,
    /// Optional OpenAI-compatible tools, as a JSON array.
    pub tools_json: Option<&'a str>,
    /// Optional tool choice string.
    pub tool_choice: Option<&'a str>,
    /// Optional JSON schema string for tool grammar generation.
    pub json_schema: Option<&'a str>,
    /// Optional custom grammar string.
    pub grammar: Option<&'a str>,
    /// Optional reasoning format string.
    pub reasoning_format: Option<&'a str>,
    /// Optional chat template kwargs, as a JSON object.
    pub chat_template_kwargs: Option<&'a str>,
    /// Whether to add the assistant generation prompt.
    pub add_generation_prompt: bool,
    /// Whether to render templates with Jinja.
    pub use_jinja: bool,
    /// Whether to allow parallel tool calls.
    pub parallel_tool_calls: bool,
    /// Whether thinking blocks are enabled.
    pub enable_thinking: bool,
    /// Whether to add BOS.
    pub add_bos: bool,
    /// Whether to add EOS.
    pub add_eos: bool,
    /// Whether to parse tool calls in responses.
    pub parse_tool_calls: bool,
}
40
/// Streaming OpenAI-compatible parser state.
///
/// Wraps a non-null pointer to the native
/// `llama_rs_chat_parse_state_oaicompat` object. The pointer is handed to
/// `llama_rs_chat_parse_state_free_oaicompat` when this value is dropped, so
/// the wrapper is expected to be the sole owner of the native state
/// (NOTE(review): construction happens outside this file — confirm).
#[derive(Debug)]
pub struct ChatParseStateOaicompat {
    // Raw handle passed to every native parser call; never null by construction of `NonNull`.
    pub(crate) state: NonNull<llama_cpp_sys_2::llama_rs_chat_parse_state_oaicompat>,
}
46
47impl ChatParseStateOaicompat {
48    /// Update the parser with additional text and return OpenAI-compatible deltas as JSON strings.
49    pub fn update(
50        &mut self,
51        text_added: &str,
52        is_partial: bool,
53    ) -> Result<Vec<String>, ChatParseError> {
54        let text_cstr = CString::new(text_added)?;
55        let mut out_msg: llama_cpp_sys_2::llama_rs_chat_msg_oaicompat = unsafe { mem::zeroed() };
56        let mut out_diffs: *mut llama_cpp_sys_2::llama_rs_chat_msg_diff_oaicompat = ptr::null_mut();
57        let mut out_diffs_count: usize = 0;
58        let rc = unsafe {
59            llama_cpp_sys_2::llama_rs_chat_parse_state_update_oaicompat(
60                self.state.as_ptr(),
61                text_cstr.as_ptr(),
62                is_partial,
63                &mut out_msg,
64                &mut out_diffs,
65                &mut out_diffs_count,
66            )
67        };
68
69        let result = {
70            if !status_is_ok(rc) {
71                return Err(ChatParseError::FfiError(status_to_i32(rc)));
72            }
73            if out_diffs_count > 0 && out_diffs.is_null() {
74                return Err(ChatParseError::NullResult);
75            }
76            let diffs = if out_diffs_count == 0 {
77                &[]
78            } else {
79                unsafe { slice::from_raw_parts(out_diffs, out_diffs_count) }
80            };
81            let mut deltas = Vec::with_capacity(diffs.len());
82            for diff in diffs {
83                let mut out_json: *mut c_char = ptr::null_mut();
84                let rc = unsafe {
85                    llama_cpp_sys_2::llama_rs_chat_msg_diff_to_oaicompat_json(diff, &mut out_json)
86                };
87                if !status_is_ok(rc) {
88                    if !out_json.is_null() {
89                        unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
90                    }
91                    return Err(ChatParseError::FfiError(status_to_i32(rc)));
92                }
93                if out_json.is_null() {
94                    return Err(ChatParseError::NullResult);
95                }
96                let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
97                unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
98                deltas.push(String::from_utf8(bytes)?);
99            }
100            Ok(deltas)
101        };
102
103        unsafe { llama_cpp_sys_2::llama_rs_chat_msg_free_oaicompat(&mut out_msg) };
104        unsafe {
105            llama_cpp_sys_2::llama_rs_chat_msg_diff_free_oaicompat(out_diffs, out_diffs_count)
106        };
107        result
108    }
109}
110
111impl Drop for ChatParseStateOaicompat {
112    fn drop(&mut self) {
113        unsafe { llama_cpp_sys_2::llama_rs_chat_parse_state_free_oaicompat(self.state.as_ptr()) };
114    }
115}