openai_safe/openai.rs

use anyhow::{anyhow, Result};
use log::{error, info, warn};
use reqwest::{header, Client};
use schemars::{schema_for, JsonSchema};
use serde::{de::DeserializeOwned, Serialize};
use serde_json::Value;

use crate::{domain::OpenAIDataResponse, models::OpenAIModels, utils::get_tokenizer};

/// [Chat Completions API](https://platform.openai.com/docs/guides/text-generation/chat-completions-api)
///
/// Chat models take a list of messages as input and return a model-generated message as output.
/// Although the chat format is designed to make multi-turn conversations easy,
/// it’s just as useful for single-turn tasks without any conversation.
pub struct OpenAI {
    model: OpenAIModels,
    //Token budget shared between the prompt and the response
    max_tokens: usize,
    temperature: u32,
    input_json: Option<String>,
    debug: bool,
    function_call: bool,
    api_key: String,
}

impl OpenAI {
    /// Creates a new OpenAI client for the given model and API key.
    /// If no `max_tokens` limit is provided, the model's default maximum is used.
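    ///
    /// A minimal usage sketch, assuming the crate root re-exports `OpenAI` and `OpenAIModels` (adjust paths and the model variant to your setup):
    /// ```ignore
    /// use openai_safe::{OpenAI, OpenAIModels};
    ///
    /// //Gpt3_5Turbo is an assumed variant name; use whichever variant your OpenAIModels enum defines
    /// let openai = OpenAI::new("sk-...", OpenAIModels::Gpt3_5Turbo, None, None)
    ///     .debug()
    ///     .function_calling(false);
    /// ```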
    pub fn new(
        open_ai_key: &str,
        model: OpenAIModels,
        max_tokens: Option<usize>,
        temperature: Option<u32>,
    ) -> Self {
        OpenAI {
            //If no max tokens limit is provided we default to the max allowed for the model
            max_tokens: max_tokens.unwrap_or_else(|| model.default_max_tokens()),
            function_call: model.function_call_default(),
            model,
            temperature: temperature.unwrap_or(0u32), //A low value makes the output less random and more deterministic
            input_json: None,
            debug: false,
            api_key: open_ai_key.to_string(),
        }
    }

    /// Turns on debug mode, which logs the prompt via `info!` when it is executed.
    pub fn debug(mut self) -> Self {
        self.debug = true;
        self
    }

    /// Turns function-calling mode on or off for interactions with the OpenAI API.
    pub fn function_calling(mut self, function_call: bool) -> Self {
        self.function_call = function_call;
        self
    }

    /// Provides values that will be used as context for the prompt.
    /// You can provide multiple input values by calling this method multiple times; each new value is appended under its category name.
    /// It accepts any instance that implements the `Serialize` trait.
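    ///
    /// A minimal sketch; the `Customer` type and the field values are illustrative:
    /// ```ignore
    /// use serde::Serialize;
    ///
    /// #[derive(Serialize)]
    /// struct Customer {
    ///     name: String,
    ///     city: String,
    /// }
    ///
    /// let customer = Customer { name: "Jane".to_string(), city: "Lisbon".to_string() };
    /// let openai = openai
    ///     .set_context("customer", &customer)?
    ///     .set_context("notes", &"Prefers email contact")?;
    /// ```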
    pub fn set_context<T: Serialize>(mut self, input_name: &str, input_data: &T) -> Result<Self> {
        let input_json = serde_json::to_string(input_data)
            .map_err(|_| anyhow!("Unable to serialize provided input data."))?;
        let line_break = match self.input_json {
            Some(_) => "\n\n",
            None => "",
        };
        let new_json = format!(
            "{}{}{}: {}",
            self.input_json.unwrap_or_default(),
            line_break,
            input_name,
            input_json,
        );
        self.input_json = Some(new_json);
        Ok(self)
    }

    /// Checks how many tokens would most likely remain for the response.
    /// This is accomplished by estimating the number of tokens needed for the system/base instructions, the user prompt, and the function components, including the schema definition.
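    ///
    /// A minimal sketch; `Sentiment` is an illustrative response type:
    /// ```ignore
    /// use schemars::JsonSchema;
    /// use serde::Deserialize;
    ///
    /// #[derive(JsonSchema, Deserialize)]
    /// struct Sentiment {
    ///     label: String,
    ///     score: f32,
    /// }
    ///
    /// let prompt_tokens = openai.check_prompt_tokens::<Sentiment>("Classify the sentiment of the input data.")?;
    /// println!("Prompt uses approx {} tokens", prompt_tokens);
    /// ```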
    pub fn check_prompt_tokens<T: JsonSchema + DeserializeOwned>(
        &self,
        instructions: &str,
    ) -> Result<usize> {
        //Output schema is extracted from the type parameter
        let schema = schema_for!(T);
        let json_value: Value = serde_json::to_value(&schema)?;

        let prompt = format!(
            "Instructions:
            {instructions}

            Input data:
            {input_json}

            Respond ONLY with the data portion of a valid Json object. No schema definition required. No other words.",
            instructions = instructions,
            input_json = self.input_json.clone().unwrap_or_default(),
        );

        let full_prompt = format!(
            "{}{}{}",
            //Base (system) instructions
            self.model.get_base_instructions(Some(self.function_call)),
            //Instructions & context data
            prompt,
            //Output schema
            serde_json::to_string(&json_value).unwrap_or_default()
        );

        //Check how many tokens are required for the prompt
        let bpe = get_tokenizer(&self.model)?;
        let prompt_tokens = bpe.encode_with_special_tokens(&full_prompt).len();

        //Assume another 5% overhead for json formatting
        Ok((prompt_tokens as f64 * 1.05) as usize)
    }

    /// Calls the OpenAI API with the provided request body.
    ///
    /// It returns the response body as a `String`, which must be parsed according to `self.model`.
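    ///
    /// The body is normally produced by `OpenAIModels::get_body`; the shape shown here is illustrative only:
    /// ```ignore
    /// let body = serde_json::json!({
    ///     "model": "gpt-3.5-turbo",
    ///     "messages": [{ "role": "user", "content": "Hello" }],
    /// });
    /// let response_text = self.call_openai_api(&body).await?;
    /// ```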
    async fn call_openai_api(&self, body: &serde_json::Value) -> Result<String> {
        //Get the API url
        let model_url = self.model.get_endpoint();

        //Make the API call
        let client = Client::new();

        let response = client
            .post(model_url)
            .header(header::CONTENT_TYPE, "application/json")
            .bearer_auth(&self.api_key)
            .json(body)
            .send()
            .await?;

        let response_status = response.status();
        let response_text = response.text().await?;

        if self.debug {
            info!(
                "[debug] OpenAI API response: [{}] {:#?}",
                &response_status, &response_text
            );
        }

        Ok(response_text)
    }

    /// Submits a prompt to OpenAI and processes the response.
    /// When calling this method you need to specify the type parameter, as the response will match the schema of that type.
    /// The prompt is written to instruct OpenAI to behave like a computer function that calculates an output based on the provided input and its language model.
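    ///
    /// A minimal end-to-end sketch; `Sentiment` is an illustrative response type and error handling is elided:
    /// ```ignore
    /// use schemars::JsonSchema;
    /// use serde::Deserialize;
    ///
    /// #[derive(JsonSchema, Deserialize)]
    /// struct Sentiment {
    ///     label: String,
    ///     score: f32,
    /// }
    ///
    /// let sentiment: Sentiment = openai
    ///     .get_answer("Classify the sentiment of the input data.")
    ///     .await?;
    /// ```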
    pub async fn get_answer<T: JsonSchema + DeserializeOwned>(
        self,
        instructions: &str,
    ) -> Result<T> {
        //Output schema is extracted from the type parameter
        let schema = schema_for!(T);
        let json_value: Value = serde_json::to_value(&schema)?;

        let prompt = format!(
            "Instructions:
            {instructions}

            Input data:
            {input_json}

            Respond ONLY with the data portion of a valid Json object. No schema definition required. No other words.",
            instructions = instructions,
            input_json = self.input_json.clone().unwrap_or_default(),
        );

        //Validate how many tokens remain for the response (and how many are used for the prompt)
        let prompt_tokens = self
            .check_prompt_tokens::<T>(instructions)
            .unwrap_or_default();

        if prompt_tokens >= self.max_tokens {
            return Err(anyhow!(
                "The provided prompt requires more tokens than allocated."
            ));
        }
        let response_tokens = self.max_tokens - prompt_tokens;

        //Warn if, after processing the prompt, there might not be enough tokens left for the response.
        //This assumes the response will be of a similar size to the input. Because that is not always the case, this is a warning rather than an error.
        if prompt_tokens * 2 >= self.max_tokens {
            warn!(
                "{} tokens remaining for response: {} allocated, {} used for prompt",
                response_tokens, self.max_tokens, prompt_tokens,
            );
        }

        //Build the API body depending on the model used
        let model_body = self.model.get_body(
            &prompt,
            &json_value,
            self.function_call,
            &response_tokens,
            &self.temperature,
        );

        //Display debug info if requested
        if self.debug {
            info!("[debug] Model body: {:#?}", model_body);
            info!(
                "[debug] Prompt accounts for approx {} tokens, leaving {} tokens for the answer.",
                prompt_tokens, response_tokens,
            );
        }

        let response_text = self.call_openai_api(&model_body).await?;

        //Extract data from the returned response text based on the model used
        let response_string = self.model.get_data(&response_text, self.function_call)?;

        if self.debug {
            info!("[debug] OpenAI response data: {}", response_string);
        }
        //Deserialize the string response into the expected output type
        match serde_json::from_str::<T>(&response_string) {
            Ok(response_deser) => Ok(response_deser),
            Err(error) => {
                error!("[OpenAI] Response serialization error: {}", &error);
                //Sometimes OpenAI responds with a json object that has a `data` property. If that's the case, extract the `data` property and deserialize that instead.
                let response_deser: OpenAIDataResponse<T> =
                    serde_json::from_str(&response_string).map_err(|error| {
                        error!("[OpenAI] Response serialization error: {}", &error);
                        anyhow!("Error: {}", error)
                    })?;
                Ok(response_deser.data)
            }
        }
    }
}
256}